Merge remote-tracking branch 'origin/dev' into 4.0.0-dev

This commit is contained in:
Marek Wolan
2025-02-10 14:39:28 +00:00
parent 0d1edf0362
commit 96549e68aa
71 changed files with 2700 additions and 367 deletions

View File

@@ -8,6 +8,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [4.0.0] = TBC
### Added
- Log observation space data by episode and step.
- Added `show_history` method to Agents, allowing you to view actions taken by an agent per step. By default, `do-nothing` actions are omitted.
- New ``node-send-local-command`` action implemented which grants agents the ability to execute commands locally. (Previously limited to remote only)
- Added ability to set the observation threshold for NMNE, file access and application executions
### Changed
- Agents now follow a common configuration format, simplifying the configuration of agents and their extensibilty.
@@ -24,6 +28,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Updated tests that don't use YAMLs to still use the new action and agent schemas
- Nodes now use a config schema and are extensible, allowing for plugin support.
- Node tests have been updated to use the new node config schemas when not using YAML files.
- ACLs are no longer applied to layer-2 traffic.
- Random number seed values are recorded in simulation/seed.log if the seed is set in the config file
or `generate_seed_value` is set to `true`.
- ARP .show() method will now include the port number associated with each entry.
- Added `services_requires_scan` and `applications_requires_scan` to agent observation space config to allow the agents to be able to see actual health states of services and applications without requiring scans (Default `True`, set to `False` to allow agents to see actual health state without scanning).
- Updated the `Terminal` class to provide response information when sending remote command execution.
### Fixed
- DNS client no longer fails to check its cache if a DNS server address is missing.

View File

@@ -21,7 +21,7 @@ Agents can be scripted (deterministic and stochastic), or controlled by a reinfo
team: GREEN
type: probabilistic-agent
observation_space:
type: UC2GreenObservation
type: UC2GreenObservation # TODO: what
action_space:
reward_function:
reward_components:
@@ -160,3 +160,4 @@ If ``True``, gymnasium flattening will be performed on the observation space bef
-----------------
Agents will record their action log for each step. This is a summary of what the agent did, along with response information from requests within the simulation.
A summary of the actions taken by the agent can be viewed using the `show_history()` function. By default, this will display all actions taken apart from ``DONOTHING``.

View File

@@ -54,6 +54,39 @@ Optional. Default value is ``3``.
The number of time steps required to occur in order for the node to cycle from ``ON`` to ``SHUTTING_DOWN`` and then finally ``OFF``.
``file_system``
---------------
Optional.
The file system of the node. This configuration allows nodes to be initialised with files and/or folders.
The file system takes a list of folders and files.
Example:
.. code-block:: yaml
simulation:
network:
nodes:
- hostname: client_1
type: computer
ip_address: 192.168.10.11
subnet_mask: 255.255.255.0
default_gateway: 192.168.10.1
file_system:
- empty_folder # example of an empty folder
- downloads:
- "test_1.txt" # files in the downloads folder
- "test_2.txt"
- root:
- passwords: # example of file with size and type
size: 69 # size in bytes
type: TXT # See FileType for list of available file types
List of file types: :py:mod:`primaite.simulator.file_system.file_type.FileType`
``users``
---------

View File

@@ -1177,8 +1177,8 @@ ACLs permitting or denying traffic as per our configured ACL rules.
some_tech_storage_srv = network.get_node_by_hostname("some_tech_storage_srv")
some_tech_storage_srv.file_system.create_file(file_name="test.png")
pc_1_ftp_client: FTPClient = network.get_node_by_hostname("pc_1").software_manager.software["FTPClient"]
pc_2_ftp_client: FTPClient = network.get_node_by_hostname("pc_2").software_manager.software["FTPClient"]
pc_1_ftp_client: FTPClient = network.get_node_by_hostname("pc_1").software_manager.software["ftp-client"]
pc_2_ftp_client: FTPClient = network.get_node_by_hostname("pc_2").software_manager.software["ftp-client"]
assert not pc_1_ftp_client.request_file(
dest_ip_address=some_tech_storage_srv.network_interface[1].ip_address,
@@ -1224,7 +1224,7 @@ ACLs permitting or denying traffic as per our configured ACL rules.
web_server: Server = network.get_node_by_hostname("some_tech_web_srv")
web_ftp_client: FTPClient = web_server.software_manager.software["FTPClient"]
web_ftp_client: FTPClient = web_server.software_manager.software["ftp-client"]
assert not web_ftp_client.request_file(
dest_ip_address=some_tech_storage_srv.network_interface[1].ip_address,
@@ -1269,7 +1269,7 @@ ACLs permitting or denying traffic as per our configured ACL rules.
some_tech_storage_srv.file_system.create_file(file_name="test.png")
some_tech_snr_dev_pc: Computer = network.get_node_by_hostname("some_tech_snr_dev_pc")
snr_dev_ftp_client: FTPClient = some_tech_snr_dev_pc.software_manager.software["FTPClient"]
snr_dev_ftp_client: FTPClient = some_tech_snr_dev_pc.software_manager.software["ftp-client"]
assert snr_dev_ftp_client.request_file(
dest_ip_address=some_tech_storage_srv.network_interface[1].ip_address,
@@ -1294,7 +1294,7 @@ ACLs permitting or denying traffic as per our configured ACL rules.
some_tech_storage_srv.file_system.create_file(file_name="test.png")
some_tech_jnr_dev_pc: Computer = network.get_node_by_hostname("some_tech_jnr_dev_pc")
jnr_dev_ftp_client: FTPClient = some_tech_jnr_dev_pc.software_manager.software["FTPClient"]
jnr_dev_ftp_client: FTPClient = some_tech_jnr_dev_pc.software_manager.software["ftp-client"]
assert not jnr_dev_ftp_client.request_file(
dest_ip_address=some_tech_storage_srv.network_interface[1].ip_address,
@@ -1337,7 +1337,7 @@ ACLs permitting or denying traffic as per our configured ACL rules.
some_tech_storage_srv.file_system.create_file(file_name="test.png")
some_tech_hr_pc: Computer = network.get_node_by_hostname("some_tech_hr_1")
hr_ftp_client: FTPClient = some_tech_hr_pc.software_manager.software["FTPClient"]
hr_ftp_client: FTPClient = some_tech_hr_pc.software_manager.software["ftp-client"]
assert not hr_ftp_client.request_file(
dest_ip_address=some_tech_storage_srv.network_interface[1].ip_address,

View File

@@ -74,7 +74,7 @@ The subnet mask setting for the port.
``acl``
-------
Sets up the ACL rules for the router.
Sets up the ACL rules for the router to apply to layer-3 traffic. These are not applied to layer-2 traffic such as ARP.
e.g.
@@ -85,10 +85,6 @@ e.g.
...
acl:
1:
action: PERMIT
src_port: ARP
dst_port: ARP
2:
action: PERMIT
protocol: ICMP

View File

@@ -46,17 +46,13 @@ The core features that should be implemented in any new agent are detailed below
- ref: example_green_agent
team: GREEN
type: ExampleAgent
type: example-agent
action_space:
action_map:
0:
action: do-nothing
options: {}
reward_function:
reward_components:
- type: dummy
agent_settings:
start_step: 25
frequency: 20

View File

@@ -26,9 +26,9 @@ class Router(NetworkNode, identifier="router"):
""" Represents a network router within the simulation, managing routing and forwarding of IP packets across network interfaces."""
SYSTEM_SOFTWARE: ClassVar[Dict] = {
"UserSessionManager": UserSessionManager,
"UserManager": UserManager,
"Terminal": Terminal,
"user-session-manager": UserSessionManager,
"user-manager": UserManager,
"terminal": Terminal,
}
network_interfaces: Dict[str, RouterInterface] = {}
@@ -52,4 +52,4 @@ class Router(NetworkNode, identifier="router"):
Changes to YAML file.
=====================
While effort has been made to ensure that nodes defined within configuration YAML files for use with PrimAITE 3.X remain compatible with PrimAITE v4+, it is encouraged to review for minor changes needed.
While effort has been made to ensure that nodes defined within configuration YAML files for use with PrimAITE 3.X remain compatible with PrimAITE v4+, it is encouraged to review for minor changes needed.

View File

@@ -2,6 +2,8 @@
© Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
.. _request_system:
Request System
**************

View File

@@ -97,19 +97,19 @@ we'll use the following Network that has a client, server, two switches, and a r
network.connect(endpoint_a=switch_2.network_interface[1], endpoint_b=client_1.network_interface[1])
network.connect(endpoint_a=switch_1.network_interface[1], endpoint_b=server_1.network_interface[1])
8. Add ACL rules on the Router to allow ARP and ICMP traffic.
8. Add an ACL rule on the Router to allow ICMP traffic.
.. code-block:: python
router_1.acl.add_rule(
action=ACLAction.PERMIT,
src_port=Port["ARP"],
dst_port=Port["ARP"],
src_port=PORT_LOOKUP["ARP"],
dst_port=PORT_LOOKUP["ARP"],
position=22
)
router_1.acl.add_rule(
action=ACLAction.PERMIT,
protocol=IPProtocol["ICMP"],
protocol=PROTOCOL_LOOKUP["ICMP"],
position=23
)

View File

@@ -102,8 +102,8 @@ ICMP traffic, ensuring basic network connectivity and ping functionality.
network.connect(pc_a.network_interface[1], router_1.router_interface)
# Configure Router 1 ACLs
router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port["ARP"], dst_port=Port["ARP"], position=22)
router_1.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol["ICMP"], position=23)
router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=PORT_LOOKUP["ARP"], dst_port=PORT_LOOKUP["ARP"], position=22)
router_1.acl.add_rule(action=ACLAction.PERMIT, protocol=PROTOCOL_LOOKUP["ICMP"], position=23)
# Configure PC B
pc_b = Computer(

View File

@@ -183,7 +183,7 @@ Python
# Example command: Installing and configuring Ransomware:
ransomware_installation_command = { "commands": [
["software_manager","application","install","RansomwareScript"],
["software_manager","application","install","ransomware-script"],
],
"username": "admin",
"password": "admin",

View File

@@ -77,7 +77,7 @@ Python
network.connect(endpoint_b=client_1.network_interface[1], endpoint_a=switch_2.network_interface[1])
client_1.software_manager.install(DatabaseClient)
client_1.software_manager.install(DataManipulationBot)
data_manipulation_bot: DataManipulationBot = client_1.software_manager.software.get("DataManipulationBot")
data_manipulation_bot: DataManipulationBot = client_1.software_manager.software.get("data-manipulation-bot")
data_manipulation_bot.configure(server_ip_address=IPv4Address("192.168.1.14"), payload="DELETE")
data_manipulation_bot.run()
@@ -98,7 +98,7 @@ If not using the data manipulation bot manually, it needs to be used with a data
type: red-database-corrupting-agent
observation_space:
type: UC2RedObservation
type: uc2-red-observation #TODO what
options:
nodes:
- node_name: client_1

View File

@@ -59,7 +59,7 @@ Python
# install DatabaseClient
client.software_manager.install(DatabaseClient)
database_client: DatabaseClient = client.software_manager.software.get("DatabaseClient")
database_client: DatabaseClient = client.software_manager.software.get("database-sclient")
# Configure the DatabaseClient
database_client.configure(server_ip_address=IPv4Address("192.168.0.1")) # address of the DatabaseService

View File

@@ -62,7 +62,7 @@ Python
network.connect(endpoint_b=client_1.network_interface[1], endpoint_a=switch_2.network_interface[1])
client_1.software_manager.install(DatabaseClient)
client_1.software_manager.install(RansomwareScript)
RansomwareScript: RansomwareScript = client_1.software_manager.software.get("RansomwareScript")
RansomwareScript: RansomwareScript = client_1.software_manager.software.get("ransomware-script")
RansomwareScript.configure(server_ip_address=IPv4Address("192.168.1.14"))
RansomwareScript.execute()

View File

@@ -61,7 +61,7 @@ The :ref:`DNSClient` must be configured to use the :ref:`DNSServer`. The :ref:`D
# Install WebBrowser on computer
computer.software_manager.install(WebBrowser)
web_browser: WebBrowser = computer.software_manager.software.get("WebBrowser")
web_browser: WebBrowser = computer.software_manager.software.get("web-browser")
web_browser.run()
# configure the WebBrowser

View File

@@ -66,7 +66,7 @@ Python
# Install DatabaseService on server
server.software_manager.install(DatabaseService)
db_service: DatabaseService = server.software_manager.software.get("DatabaseService")
db_service: DatabaseService = server.software_manager.software.get("database-service")
db_service.start()
# configure DatabaseService

View File

@@ -56,7 +56,7 @@ Python
# Install DNSClient on server
server.software_manager.install(DNSClient)
dns_client: DNSClient = server.software_manager.software.get("DNSClient")
dns_client: DNSClient = server.software_manager.software.get("dns-client")
dns_client.start()
# configure DatabaseService

View File

@@ -53,7 +53,7 @@ Python
# Install DNSServer on server
server.software_manager.install(DNSServer)
dns_server: DNSServer = server.software_manager.software.get("DNSServer")
dns_server: DNSServer = server.software_manager.software.get("dns-server")
dns_server.start()
# configure DatabaseService

View File

@@ -60,7 +60,7 @@ Python
# Install FTPClient on server
server.software_manager.install(FTPClient)
ftp_client: FTPClient = server.software_manager.software.get("FTPClient")
ftp_client: FTPClient = server.software_manager.software.get("ftp-client")
ftp_client.start()

View File

@@ -55,7 +55,7 @@ Python
# Install FTPServer on server
server.software_manager.install(FTPServer)
ftp_server: FTPServer = server.software_manager.software.get("FTPServer")
ftp_server: FTPServer = server.software_manager.software.get("ftp-server")
ftp_server.start()
ftp_server.server_password = "test"

View File

@@ -53,7 +53,7 @@ Python
# Install NTPClient on server
server.software_manager.install(NTPClient)
ntp_client: NTPClient = server.software_manager.software.get("NTPClient")
ntp_client: NTPClient = server.software_manager.software.get("ntp-client")
ntp_client.start()
ntp_client.configure(ntp_server_ip_address=IPv4Address("192.168.0.10"))

View File

@@ -55,7 +55,7 @@ Python
# Install NTPServer on server
server.software_manager.install(NTPServer)
ntp_server: NTPServer = server.software_manager.software.get("NTPServer")
ntp_server: NTPServer = server.software_manager.software.get("ntp-server")
ntp_server.start()

View File

@@ -23,6 +23,14 @@ Key capabilities
- Simulates common Terminal processes/commands.
- Leverages the Service base class for install/uninstall, status tracking etc.
Usage
"""""
- Pre-Installs on any `Node` component (with the exception of `Switches`).
- Terminal Clients connect, execute commands and disconnect from remote nodes.
- Ensures that users are logged in to the component before executing any commands.
- Service runs on SSH port 22 by default.
- Enables Agents to send commands both remotely and locally.
Implementation
""""""""""""""
@@ -30,19 +38,112 @@ Implementation
- Manages remote connections in a dictionary by session ID.
- Processes commands, forwarding to the ``RequestManager`` or ``SessionManager`` where appropriate.
- Extends Service class.
- A detailed guide on the implementation and functionality of the Terminal class can be found in the "Terminal-Processing" jupyter notebook.
A detailed guide on the implementation and functionality of the Terminal class can be found in the "Terminal-Processing" jupyter notebook.
Command Format
^^^^^^^^^^^^^^
Terminals implement their commands through leveraging the pre-existing :ref:`request_system`.
Due to this Terminals will only accept commands passed within the ``RequestFormat``.
:py:class:`primaite.game.interface.RequestFormat`
For example, ``terminal`` command actions when used in ``yaml`` format are formatted as follows:
.. code-block:: yaml
command:
- "file_system"
- "create"
- "file"
- "downloads"
- "cat.png"
- "False
This is then loaded from yaml into a dictionary containing the terminal command:
.. code-block:: python
{"command":["file_system", "create", "file", "downloads", "cat.png", "False"]}
Which is then passed to the ``Terminals`` Request Manager to be executed.
Game Layer Usage (Agents)
========================
The below code examples demonstrate how to use terminal related actions in yaml files.
yaml
""""
``node-send-local-command``
"""""""""""""""""""""""""""
Agents can execute local commands without needing to perform a separate remote login action (``node-session-remote-login``).
.. code-block:: yaml
...
...
action: node-send-local-command
options:
node_id: 0
username: admin
password: admin
command: # Example command - Creates a file called 'cat.png' in the downloads folder.
- "file_system"
- "create"
- "file"
- "downloads"
- "cat.png"
- "False"
Usage
"""""
``node-session-remote-login``
"""""""""""""""""
- Pre-Installs on all ``Nodes`` (with the exception of ``Switches``).
- Terminal Clients connect, execute commands and disconnect from remote nodes.
- Ensures that users are logged in to the component before executing any commands.
- Service runs on SSH port 22 by default.
Agents are able to use the terminal to login into remote nodes via ``SSH`` which allows for agents to execute commands on remote hosts.
.. code-block:: yaml
...
...
action: node-session-remote-login
options:
node_id: 0
username: admin
password: admin
remote_ip: 192.168.0.10 # Example Ip Address. (The remote host's IP that will be used by ssh)
``node-send-remote-command``
""""""""""""""""""""""""""""
After remotely logging into another host, an agent can use the ``node-send-remote-command`` to execute commands across the network remotely.
.. code-block:: yaml
...
...
action: node-send-remote-command
options:
node_id: 0
remote_ip: 192.168.0.10
command:
- "file_system"
- "create"
- "file"
- "downloads"
- "cat.png"
- "False"
Simulation Layer Usage
======================
Usage
=====
The below code examples demonstrate how to create a terminal, a remote terminal, and how to send a basic application install command to a remote node.
@@ -65,7 +166,7 @@ Python
operating_state=NodeOperatingState.ON,
)
terminal: Terminal = client.software_manager.software.get("Terminal")
terminal: Terminal = client.software_manager.software.get("terminal")
Creating Remote Terminal Connection
"""""""""""""""""""""""""""""""""""
@@ -86,7 +187,7 @@ Creating Remote Terminal Connection
node_b.power_on()
network.connect(node_a.network_interface[1], node_b.network_interface[1])
terminal_a: Terminal = node_a.software_manager.software.get("Terminal")
terminal_a: Terminal = node_a.software_manager.software.get("terminal")
term_a_term_b_remote_connection: RemoteTerminalConnection = terminal_a.login(username="admin", password="Admin123!", ip_address="192.168.0.11")
@@ -112,12 +213,12 @@ Executing a basic application install command
node_b.power_on()
network.connect(node_a.network_interface[1], node_b.network_interface[1])
terminal_a: Terminal = node_a.software_manager.software.get("Terminal")
terminal_a: Terminal = node_a.software_manager.software.get("terminal")
term_a_term_b_remote_connection: RemoteTerminalConnection = terminal_a.login(username="admin", password="Admin123!", ip_address="192.168.0.11")
term_a_term_b_remote_connection.execute(["software_manager", "application", "install", "RansomwareScript"])
term_a_term_b_remote_connection.execute(["software_manager", "application", "install", "ransomware-script"])
@@ -140,7 +241,7 @@ Creating a folder on a remote node
node_b.power_on()
network.connect(node_a.network_interface[1], node_b.network_interface[1])
terminal_a: Terminal = node_a.software_manager.software.get("Terminal")
terminal_a: Terminal = node_a.software_manager.software.get("terminal")
term_a_term_b_remote_connection: RemoteTerminalConnection = terminal_a.login(username="admin", password="Admin123!", ip_address="192.168.0.11")
@@ -167,7 +268,7 @@ Disconnect from Remote Node
node_b.power_on()
network.connect(node_a.network_interface[1], node_b.network_interface[1])
terminal_a: Terminal = node_a.software_manager.software.get("Terminal")
terminal_a: Terminal = node_a.software_manager.software.get("terminal")
term_a_term_b_remote_connection: RemoteTerminalConnection = terminal_a.login(username="admin", password="Admin123!", ip_address="192.168.0.11")

View File

@@ -56,7 +56,7 @@ Python
# Install WebServer on server
server.software_manager.install(WebServer)
web_server: WebServer = server.software_manager.software.get("WebServer")
web_server: WebServer = server.software_manager.software.get("web-server")
web_server.start()
Via Configuration

View File

@@ -30,7 +30,7 @@ See :ref:`Node Start up and Shut down`
node.software_manager.install(WebServer)
web_server: WebServer = node.software_manager.software.get("WebServer")
web_server: WebServer = node.software_manager.software.get("web-server")
assert web_server.operating_state is ServiceOperatingState.RUNNING # service is immediately ran after install
node.power_off()

View File

@@ -1,6 +1,6 @@
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
from abc import ABC, abstractmethod
from typing import ClassVar, List, Optional, Union
from typing import ClassVar, List, Literal, Optional, Union
from primaite.game.agent.actions.manager import AbstractAction
from primaite.interface.request import RequestFormat
@@ -153,8 +153,6 @@ class NodeNMAPPortScanAction(NodeNMAPAbstractAction, discriminator="node-nmap-po
class NodeNetworkServiceReconAction(NodeNMAPAbstractAction, discriminator="node-network-service-recon"):
"""Action which performs an nmap network service recon (ping scan followed by port scan)."""
config: "NodeNetworkServiceReconAction.ConfigSchema"
class ConfigSchema(NodeNMAPAbstractAction.ConfigSchema):
"""Configuration schema for NodeNetworkServiceReconAction."""
@@ -179,3 +177,70 @@ class NodeNetworkServiceReconAction(NodeNMAPAbstractAction, discriminator="node-
"show": config.show,
},
]
class NodeAccountsAddUserAction(AbstractAction, discriminator="node-account-add-user"):
class ConfigSchema(AbstractAction.ConfigSchema):
type: Literal["node-account-add-user"] = "node-account-add-user"
node_name: str
username: str
password: str
is_admin: bool
@classmethod
@staticmethod
def form_request(config: ConfigSchema) -> RequestFormat:
return [
"network",
"node",
config.node_name,
"service",
"user-manager",
"add_user",
config.username,
config.password,
config.is_admin,
]
class NodeAccountsDisableUserAction(AbstractAction, discriminator="node-account-disable-user"):
class ConfigSchema(AbstractAction.ConfigSchema):
type: Literal["node-account-disable-user"] = "node-account-disable-user"
node_name: str
username: str
@classmethod
@staticmethod
def form_request(config: ConfigSchema) -> RequestFormat:
return [
"network",
"node",
config.node_name,
"service",
"user-manager",
"disable_user",
config.username,
]
class NodeSendLocalCommandAction(AbstractAction, discriminator="node-send-local-command"):
class ConfigSchema(AbstractAction.ConfigSchema):
type: Literal["node-send-local-command"] = "node-send-local-command"
node_name: str
username: str
password: str
command: RequestFormat
@staticmethod
def form_request(config: ConfigSchema) -> RequestFormat:
return [
"network",
"node",
config.node_name,
"service",
"terminal",
"send_local_command",
config.username,
config.password,
{"command": config.command},
]

View File

@@ -34,8 +34,6 @@ class NodeSessionAbstractAction(AbstractAction, ABC):
class NodeSessionsRemoteLoginAction(NodeSessionAbstractAction, discriminator="node-session-remote-login"):
"""Action which performs a remote session login."""
config: "NodeSessionsRemoteLoginAction.ConfigSchema"
class ConfigSchema(NodeSessionAbstractAction.ConfigSchema):
"""Configuration schema for NodeSessionsRemoteLoginAction."""
@@ -53,7 +51,7 @@ class NodeSessionsRemoteLoginAction(NodeSessionAbstractAction, discriminator="no
config.node_name,
"service",
"terminal",
"node-session-remote-login",
"node_session_remote_login",
config.username,
config.password,
config.remote_ip,
@@ -63,8 +61,6 @@ class NodeSessionsRemoteLoginAction(NodeSessionAbstractAction, discriminator="no
class NodeSessionsRemoteLogoutAction(NodeSessionAbstractAction, discriminator="node-session-remote-logoff"):
"""Action which performs a remote session logout."""
config: "NodeSessionsRemoteLogoutAction.ConfigSchema"
class ConfigSchema(NodeSessionAbstractAction.ConfigSchema):
"""Configuration schema for NodeSessionsRemoteLogoutAction."""
@@ -78,14 +74,13 @@ class NodeSessionsRemoteLogoutAction(NodeSessionAbstractAction, discriminator="n
return ["network", "node", config.node_name, "service", "terminal", config.verb, config.remote_ip]
class NodeAccountChangePasswordAction(NodeSessionAbstractAction, discriminator="node-account-change-password"):
class NodeAccountChangePasswordAction(AbstractAction, discriminator="node-account-change-password"):
"""Action which changes the password for a user."""
config: "NodeAccountChangePasswordAction.ConfigSchema"
class ConfigSchema(NodeSessionAbstractAction.ConfigSchema):
class ConfigSchema(AbstractAction.ConfigSchema):
"""Configuration schema for NodeAccountsChangePasswordAction."""
node_name: str
username: str
current_password: str
new_password: str

View File

@@ -6,6 +6,7 @@ from abc import ABC, abstractmethod
from typing import Any, ClassVar, Dict, List, Literal, Optional, Tuple, Type, TYPE_CHECKING
from gymnasium.core import ActType, ObsType
from prettytable import PrettyTable
from pydantic import BaseModel, ConfigDict, Field
from primaite.game.agent.actions import ActionManager
@@ -42,6 +43,9 @@ class AgentHistoryItem(BaseModel):
reward_info: Dict[str, Any] = {}
observation: Optional[ObsType] = None
"""The observation space data for this step."""
class AbstractAgent(BaseModel, ABC):
"""Base class for scripted and RL agents."""
@@ -67,6 +71,9 @@ class AbstractAgent(BaseModel, ABC):
default_factory=lambda: ObservationManager.ConfigSchema()
)
reward_function: RewardFunction.ConfigSchema = Field(default_factory=lambda: RewardFunction.ConfigSchema())
thresholds: Optional[Dict] = {}
# TODO: this is only relevant to some observations, need to refactor the way thresholds are dealt with (#3085)
"""A dict containing the observation thresholds."""
config: ConfigSchema = Field(default_factory=lambda: AbstractAgent.ConfigSchema())
@@ -90,10 +97,42 @@ class AbstractAgent(BaseModel, ABC):
def model_post_init(self, __context: Any) -> None:
"""Overwrite the default empty action, observation, and rewards with ones defined through the config."""
self.action_manager = ActionManager(config=self.config.action_space)
self.config.observation_space.options.thresholds = self.config.thresholds
self.observation_manager = ObservationManager(config=self.config.observation_space)
self.reward_function = RewardFunction(config=self.config.reward_function)
return super().model_post_init(__context)
def add_agent_action(self, item: AgentHistoryItem, table: PrettyTable) -> PrettyTable:
"""Update the given table with information from given AgentHistoryItem."""
node, application = "unknown", "unknown"
if (node_id := item.parameters.get("node_id")) is not None:
node = self.action_manager.node_names[node_id]
if (application_id := item.parameters.get("application_id")) is not None:
application = self.action_manager.application_names[node_id][application_id]
if (application_name := item.parameters.get("application_name")) is not None:
application = application_name
table.add_row([item.timestep, item.action, node, application, item.response.status])
return table
def show_history(self, ignored_actions: Optional[list] = None):
"""
Print an agent action provided it's not the DONOTHING action.
:param ignored_actions: OPTIONAL: List of actions to be ignored when displaying the history.
If not provided, defaults to ignore DONOTHING actions.
"""
if not ignored_actions:
ignored_actions = ["DONOTHING"]
table = PrettyTable()
table.field_names = ["Step", "Action", "Node", "Application", "Response"]
print(f"Actions for '{self.agent_name}':")
for item in self.history:
if item.action in ignored_actions:
pass
else:
table = self.add_agent_action(item=item, table=table)
print(table)
def update_observation(self, state: Dict) -> ObsType:
"""
Convert a state from the simulator into an observation for the agent using the observation space.
@@ -140,12 +179,23 @@ class AbstractAgent(BaseModel, ABC):
return request
def process_action_response(
self, timestep: int, action: str, parameters: Dict[str, Any], request: RequestFormat, response: RequestResponse
self,
timestep: int,
action: str,
parameters: Dict[str, Any],
request: RequestFormat,
response: RequestResponse,
observation: ObsType,
) -> None:
"""Process the response from the most recent action."""
self.history.append(
AgentHistoryItem(
timestep=timestep, action=action, parameters=parameters, request=request, response=response
timestep=timestep,
action=action,
parameters=parameters,
request=request,
response=response,
observation=observation,
)
)

View File

@@ -26,7 +26,13 @@ class FileObservation(AbstractObservation, discriminator="file"):
file_system_requires_scan: Optional[bool] = None
"""If True, the file must be scanned to update the health state. Tf False, the true state is always shown."""
def __init__(self, where: WhereType, include_num_access: bool, file_system_requires_scan: bool) -> None:
def __init__(
self,
where: WhereType,
include_num_access: bool,
file_system_requires_scan: bool,
thresholds: Optional[Dict] = {},
) -> None:
"""
Initialise a file observation instance.
@@ -48,10 +54,36 @@ class FileObservation(AbstractObservation, discriminator="file"):
if self.include_num_access:
self.default_observation["num_access"] = 0
# TODO: allow these to be configured in yaml
self.high_threshold = 10
self.med_threshold = 5
self.low_threshold = 0
if thresholds.get("file_access") is None:
self.low_file_access_threshold = 0
self.med_file_access_threshold = 5
self.high_file_access_threshold = 10
else:
self._set_file_access_threshold(
thresholds=[
thresholds.get("file_access")["low"],
thresholds.get("file_access")["medium"],
thresholds.get("file_access")["high"],
]
)
def _set_file_access_threshold(self, thresholds: List[int]):
"""
Method that validates and then sets the file access threshold.
:param: thresholds: The file access threshold to validate and set.
"""
if self._validate_thresholds(
thresholds=[
thresholds[0],
thresholds[1],
thresholds[2],
],
threshold_identifier="file_access",
):
self.low_file_access_threshold = thresholds[0]
self.med_file_access_threshold = thresholds[1]
self.high_file_access_threshold = thresholds[2]
def _categorise_num_access(self, num_access: int) -> int:
"""
@@ -60,11 +92,11 @@ class FileObservation(AbstractObservation, discriminator="file"):
:param num_access: Number of file accesses.
:return: Bin number corresponding to the number of accesses.
"""
if num_access > self.high_threshold:
if num_access > self.high_file_access_threshold:
return 3
elif num_access > self.med_threshold:
elif num_access > self.med_file_access_threshold:
return 2
elif num_access > self.low_threshold:
elif num_access > self.low_file_access_threshold:
return 1
return 0
@@ -122,6 +154,7 @@ class FileObservation(AbstractObservation, discriminator="file"):
where=parent_where + ["files", config.file_name],
include_num_access=config.include_num_access,
file_system_requires_scan=config.file_system_requires_scan,
thresholds=config.thresholds,
)
@@ -149,6 +182,7 @@ class FolderObservation(AbstractObservation, discriminator="folder"):
num_files: int,
include_num_access: bool,
file_system_requires_scan: bool,
thresholds: Optional[Dict] = {},
) -> None:
"""
Initialise a folder observation instance.
@@ -177,6 +211,7 @@ class FolderObservation(AbstractObservation, discriminator="folder"):
where=None,
include_num_access=include_num_access,
file_system_requires_scan=self.file_system_requires_scan,
thresholds=thresholds,
)
)
while len(self.files) > num_files:
@@ -253,6 +288,7 @@ class FolderObservation(AbstractObservation, discriminator="folder"):
for file_config in config.files:
file_config.include_num_access = config.include_num_access
file_config.file_system_requires_scan = config.file_system_requires_scan
file_config.thresholds = config.thresholds
files = [FileObservation.from_config(config=f, parent_where=where) for f in config.files]
return cls(
@@ -261,4 +297,5 @@ class FolderObservation(AbstractObservation, discriminator="folder"):
num_files=config.num_files,
include_num_access=config.include_num_access,
file_system_requires_scan=config.file_system_requires_scan,
thresholds=config.thresholds,
)

View File

@@ -54,7 +54,15 @@ class HostObservation(AbstractObservation, discriminator="host"):
"""
If True, files and folders must be scanned to update the health state. If False, true state is always shown.
"""
include_users: Optional[bool] = None
services_requires_scan: Optional[bool] = None
"""
If True, services must be scanned to update the health state. If False, true state is always shown.
"""
applications_requires_scan: Optional[bool] = None
"""
If True, applications must be scanned to update the health state. If False, true state is always shown.
"""
include_users: Optional[bool] = True
"""If True, report user session information."""
def __init__(
@@ -73,6 +81,8 @@ class HostObservation(AbstractObservation, discriminator="host"):
monitored_traffic: Optional[Dict],
include_num_access: bool,
file_system_requires_scan: bool,
services_requires_scan: bool,
applications_requires_scan: bool,
include_users: bool,
) -> None:
"""
@@ -108,6 +118,12 @@ class HostObservation(AbstractObservation, discriminator="host"):
:param file_system_requires_scan: If True, the files and folders must be scanned to update the health state.
If False, the true state is always shown.
:type file_system_requires_scan: bool
:param services_requires_scan: If True, services must be scanned to update the health state.
If False, the true state is always shown.
:type services_requires_scan: bool
:param applications_requires_scan: If True, applications must be scanned to update the health state.
If False, the true state is always shown.
:type applications_requires_scan: bool
:param include_users: If True, report user session information.
:type include_users: bool
"""
@@ -121,7 +137,7 @@ class HostObservation(AbstractObservation, discriminator="host"):
# Ensure lists have lengths equal to specified counts by truncating or padding
self.services: List[ServiceObservation] = services
while len(self.services) < num_services:
self.services.append(ServiceObservation(where=None))
self.services.append(ServiceObservation(where=None, services_requires_scan=services_requires_scan))
while len(self.services) > num_services:
truncated_service = self.services.pop()
msg = f"Too many services in Node observation space for node. Truncating service {truncated_service.where}"
@@ -129,7 +145,9 @@ class HostObservation(AbstractObservation, discriminator="host"):
self.applications: List[ApplicationObservation] = applications
while len(self.applications) < num_applications:
self.applications.append(ApplicationObservation(where=None))
self.applications.append(
ApplicationObservation(where=None, applications_requires_scan=applications_requires_scan)
)
while len(self.applications) > num_applications:
truncated_application = self.applications.pop()
msg = f"Too many applications in Node observation space for node. Truncating {truncated_application.where}"
@@ -153,7 +171,13 @@ class HostObservation(AbstractObservation, discriminator="host"):
self.nics: List[NICObservation] = network_interfaces
while len(self.nics) < num_nics:
self.nics.append(NICObservation(where=None, include_nmne=include_nmne, monitored_traffic=monitored_traffic))
self.nics.append(
NICObservation(
where=None,
include_nmne=include_nmne,
monitored_traffic=monitored_traffic,
)
)
while len(self.nics) > num_nics:
truncated_nic = self.nics.pop()
msg = f"Too many network_interfaces in Node observation space for node. Truncating {truncated_nic.where}"
@@ -269,8 +293,15 @@ class HostObservation(AbstractObservation, discriminator="host"):
folder_config.include_num_access = config.include_num_access
folder_config.num_files = config.num_files
folder_config.file_system_requires_scan = config.file_system_requires_scan
folder_config.thresholds = config.thresholds
for nic_config in config.network_interfaces:
nic_config.include_nmne = config.include_nmne
nic_config.thresholds = config.thresholds
for service_config in config.services:
service_config.services_requires_scan = config.services_requires_scan
for application_config in config.applications:
application_config.applications_requires_scan = config.applications_requires_scan
application_config.thresholds = config.thresholds
services = [ServiceObservation.from_config(config=c, parent_where=where) for c in config.services]
applications = [ApplicationObservation.from_config(config=c, parent_where=where) for c in config.applications]
@@ -281,7 +312,10 @@ class HostObservation(AbstractObservation, discriminator="host"):
count = 1
while len(nics) < config.num_nics:
nic_config = NICObservation.ConfigSchema(
nic_num=count, include_nmne=config.include_nmne, monitored_traffic=config.monitored_traffic
nic_num=count,
include_nmne=config.include_nmne,
monitored_traffic=config.monitored_traffic,
thresholds=config.thresholds,
)
nics.append(NICObservation.from_config(config=nic_config, parent_where=where))
count += 1
@@ -301,5 +335,7 @@ class HostObservation(AbstractObservation, discriminator="host"):
monitored_traffic=config.monitored_traffic,
include_num_access=config.include_num_access,
file_system_requires_scan=config.file_system_requires_scan,
services_requires_scan=config.services_requires_scan,
applications_requires_scan=config.applications_requires_scan,
include_users=config.include_users,
)

View File

@@ -1,13 +1,14 @@
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
from __future__ import annotations
from typing import Dict, List, Optional
from typing import ClassVar, Dict, List, Optional
from gymnasium import spaces
from gymnasium.core import ObsType
from primaite.game.agent.observations.observations import AbstractObservation, WhereType
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
from primaite.simulator.network.nmne import NMNEConfig
from primaite.utils.validation.ip_protocol import IPProtocol
from primaite.utils.validation.port import Port
@@ -15,6 +16,9 @@ from primaite.utils.validation.port import Port
class NICObservation(AbstractObservation, discriminator="network-interface"):
"""Status information about a network interface within the simulation environment."""
capture_nmne: ClassVar[bool] = NMNEConfig().capture_nmne
"A Boolean specifying whether malicious network events should be captured."
class ConfigSchema(AbstractObservation.ConfigSchema):
"""Configuration schema for NICObservation."""
@@ -25,7 +29,13 @@ class NICObservation(AbstractObservation, discriminator="network-interface"):
monitored_traffic: Optional[Dict[IPProtocol, List[Port]]] = None
"""A dict containing which traffic types are to be included in the observation."""
def __init__(self, where: WhereType, include_nmne: bool, monitored_traffic: Optional[Dict] = None) -> None:
def __init__(
self,
where: WhereType,
include_nmne: bool,
monitored_traffic: Optional[Dict] = None,
thresholds: Dict = {},
) -> None:
"""
Initialise a network interface observation instance.
@@ -45,10 +55,18 @@ class NICObservation(AbstractObservation, discriminator="network-interface"):
self.nmne_inbound_last_step: int = 0
self.nmne_outbound_last_step: int = 0
# TODO: allow these to be configured in yaml
self.high_nmne_threshold = 10
self.med_nmne_threshold = 5
self.low_nmne_threshold = 0
if thresholds.get("nmne") is None:
self.low_nmne_threshold = 0
self.med_nmne_threshold = 5
self.high_nmne_threshold = 10
else:
self._set_nmne_threshold(
thresholds=[
thresholds.get("nmne")["low"],
thresholds.get("nmne")["medium"],
thresholds.get("nmne")["high"],
]
)
self.monitored_traffic = monitored_traffic
if self.monitored_traffic:
@@ -105,6 +123,20 @@ class NICObservation(AbstractObservation, discriminator="network-interface"):
bandwidth_utilisation = traffic_value / nic_max_bandwidth
return int(bandwidth_utilisation * 9) + 1
def _set_nmne_threshold(self, thresholds: List[int]):
"""
Method that validates and then sets the NMNE threshold.
:param: thresholds: The NMNE threshold to validate and set.
"""
if self._validate_thresholds(
thresholds=thresholds,
threshold_identifier="nmne",
):
self.low_nmne_threshold = thresholds[0]
self.med_nmne_threshold = thresholds[1]
self.high_nmne_threshold = thresholds[2]
def observe(self, state: Dict) -> ObsType:
"""
Generate observation based on the current state of the simulation.
@@ -116,7 +148,7 @@ class NICObservation(AbstractObservation, discriminator="network-interface"):
"""
nic_state = access_from_nested_dict(state, self.where)
if nic_state is NOT_PRESENT_IN_STATE:
if nic_state is NOT_PRESENT_IN_STATE or self.where is None:
return self.default_observation
obs = {"nic_status": 1 if nic_state["enabled"] else 2}
@@ -164,7 +196,7 @@ class NICObservation(AbstractObservation, discriminator="network-interface"):
for port in self.monitored_traffic[protocol]:
obs["TRAFFIC"][protocol][port] = {"inbound": 0, "outbound": 0}
if self.include_nmne:
if self.capture_nmne and self.include_nmne:
obs.update({"NMNE": {}})
direction_dict = nic_state["nmne"].get("direction", {})
inbound_keywords = direction_dict.get("inbound", {}).get("keywords", {})
@@ -224,6 +256,7 @@ class NICObservation(AbstractObservation, discriminator="network-interface"):
where=parent_where + ["NICs", config.nic_num],
include_nmne=config.include_nmne,
monitored_traffic=config.monitored_traffic,
thresholds=config.thresholds,
)

View File

@@ -48,7 +48,13 @@ class NodesObservation(AbstractObservation, discriminator="nodes"):
include_num_access: Optional[bool] = None
"""Flag to include the number of accesses."""
file_system_requires_scan: bool = True
"""If True, the folder must be scanned to update the health state. Tf False, the true state is always shown."""
"""If True, the folder must be scanned to update the health state. If False, the true state is always shown."""
services_requires_scan: bool = True
"""If True, the services must be scanned to update the health state.
If False, the true state is always shown."""
applications_requires_scan: bool = True
"""If True, the applications must be scanned to update the health state.
If False, the true state is always shown."""
include_users: Optional[bool] = True
"""If True, report user session information."""
num_ports: Optional[int] = None
@@ -196,8 +202,14 @@ class NodesObservation(AbstractObservation, discriminator="nodes"):
host_config.include_num_access = config.include_num_access
if host_config.file_system_requires_scan is None:
host_config.file_system_requires_scan = config.file_system_requires_scan
if host_config.services_requires_scan is None:
host_config.services_requires_scan = config.services_requires_scan
if host_config.applications_requires_scan is None:
host_config.applications_requires_scan = config.applications_requires_scan
if host_config.include_users is None:
host_config.include_users = config.include_users
if not host_config.thresholds:
host_config.thresholds = config.thresholds
for router_config in config.routers:
if router_config.num_ports is None:
@@ -214,6 +226,8 @@ class NodesObservation(AbstractObservation, discriminator="nodes"):
router_config.num_rules = config.num_rules
if router_config.include_users is None:
router_config.include_users = config.include_users
if not router_config.thresholds:
router_config.thresholds = config.thresholds
for firewall_config in config.firewalls:
if firewall_config.ip_list is None:
@@ -228,6 +242,8 @@ class NodesObservation(AbstractObservation, discriminator="nodes"):
firewall_config.num_rules = config.num_rules
if firewall_config.include_users is None:
firewall_config.include_users = config.include_users
if not firewall_config.thresholds:
firewall_config.thresholds = config.thresholds
hosts = [HostObservation.from_config(config=c, parent_where=where) for c in config.hosts]
routers = [RouterObservation.from_config(config=c, parent_where=where) for c in config.routers]

View File

@@ -114,7 +114,9 @@ class NestedObservation(AbstractObservation, discriminator="custom"):
instances = dict()
for component in config.components:
obs_class = AbstractObservation._registry[component.type]
obs_instance = obs_class.from_config(config=obs_class.ConfigSchema(**component.options))
obs_instance = obs_class.from_config(
config=obs_class.ConfigSchema(**component.options, thresholds=config.thresholds)
)
instances[component.label] = obs_instance
return cls(components=instances)
@@ -242,8 +244,5 @@ class ObservationManager(BaseModel):
"""
if config is None:
return cls(NullObservation())
obs_type = config["type"]
obs_class = AbstractObservation._registry[obs_type]
observation = obs_class.from_config(config=obs_class.ConfigSchema(**config["options"]))
obs_manager = cls(observation)
obs_manager = cls(config=config)
return obs_manager

View File

@@ -1,7 +1,7 @@
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
"""Manages the observation space for the agent."""
from abc import ABC, abstractmethod
from typing import Any, Dict, Iterable, Optional, Type, Union
from typing import Any, Dict, Iterable, List, Optional, Type, Union
from gymnasium import spaces
from gymnasium.core import ObsType
@@ -19,6 +19,9 @@ class AbstractObservation(ABC):
class ConfigSchema(ABC, BaseModel):
"""Config schema for observations."""
thresholds: Optional[Dict] = {}
"""A dict containing the observation thresholds."""
model_config = ConfigDict(extra="forbid")
_registry: Dict[str, Type["AbstractObservation"]] = {}
@@ -69,3 +72,34 @@ class AbstractObservation(ABC):
def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> "AbstractObservation":
"""Create this observation space component form a serialised format."""
return cls()
def _validate_thresholds(self, thresholds: List[int] = None, threshold_identifier: Optional[str] = "") -> bool:
"""
Method that checks if the thresholds are non overlapping and in the correct (ascending) order.
Pass in the thresholds from low to high e.g.
thresholds=[low_threshold, med_threshold, ..._threshold, high_threshold]
Throws an error if the threshold is not valid
:param: thresholds: List of thresholds in ascending order.
:type: List[int]
:param: threshold_identifier: The name of the threshold option.
:type: Optional[str]
:returns: bool
"""
if thresholds is None or len(thresholds) < 2:
raise Exception(f"{threshold_identifier} thresholds are invalid {thresholds}")
for idx in range(1, len(thresholds)):
if not isinstance(thresholds[idx], int):
raise Exception(f"{threshold_identifier} threshold ({thresholds[idx]}) is not a valid int.")
if not isinstance(thresholds[idx - 1], int):
raise Exception(f"{threshold_identifier} threshold ({thresholds[idx]}) is not a valid int.")
if thresholds[idx] <= thresholds[idx - 1]:
raise Exception(
f"{threshold_identifier} threshold ({thresholds[idx - 1]}) "
f"is greater than or equal to ({thresholds[idx]}.)"
)
return True

View File

@@ -1,7 +1,7 @@
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
from __future__ import annotations
from typing import Dict
from typing import Dict, List, Optional
from gymnasium import spaces
from gymnasium.core import ObsType
@@ -19,7 +19,10 @@ class ServiceObservation(AbstractObservation, discriminator="service"):
service_name: str
"""Name of the service, used for querying simulation state dictionary"""
def __init__(self, where: WhereType) -> None:
services_requires_scan: Optional[bool] = None
"""If True, services must be scanned to update the health state. If False, true state is always shown."""
def __init__(self, where: WhereType, services_requires_scan: bool) -> None:
"""
Initialise a service observation instance.
@@ -28,6 +31,7 @@ class ServiceObservation(AbstractObservation, discriminator="service"):
:type where: WhereType
"""
self.where = where
self.services_requires_scan = services_requires_scan
self.default_observation = {"operating_status": 0, "health_status": 0}
def observe(self, state: Dict) -> ObsType:
@@ -44,7 +48,9 @@ class ServiceObservation(AbstractObservation, discriminator="service"):
return self.default_observation
return {
"operating_status": service_state["operating_state"],
"health_status": service_state["health_state_visible"],
"health_status": service_state["health_state_visible"]
if self.services_requires_scan
else service_state["health_state_actual"],
}
@property
@@ -70,7 +76,9 @@ class ServiceObservation(AbstractObservation, discriminator="service"):
:return: Constructed service observation instance.
:rtype: ServiceObservation
"""
return cls(where=parent_where + ["services", config.service_name])
return cls(
where=parent_where + ["services", config.service_name], services_requires_scan=config.services_requires_scan
)
class ApplicationObservation(AbstractObservation, discriminator="application"):
@@ -82,7 +90,12 @@ class ApplicationObservation(AbstractObservation, discriminator="application"):
application_name: str
"""Name of the application, used for querying simulation state dictionary"""
def __init__(self, where: WhereType) -> None:
applications_requires_scan: Optional[bool] = None
"""
If True, applications must be scanned to update the health state. If False, true state is always shown.
"""
def __init__(self, where: WhereType, applications_requires_scan: bool, thresholds: Optional[Dict] = {}) -> None:
"""
Initialise an application observation instance.
@@ -92,25 +105,52 @@ class ApplicationObservation(AbstractObservation, discriminator="application"):
:type where: WhereType
"""
self.where = where
self.applications_requires_scan = applications_requires_scan
self.default_observation = {"operating_status": 0, "health_status": 0, "num_executions": 0}
# TODO: allow these to be configured in yaml
self.high_threshold = 10
self.med_threshold = 5
self.low_threshold = 0
if thresholds.get("app_executions") is None:
self.low_app_execution_threshold = 0
self.med_app_execution_threshold = 5
self.high_app_execution_threshold = 10
else:
self._set_application_execution_thresholds(
thresholds=[
thresholds.get("app_executions")["low"],
thresholds.get("app_executions")["medium"],
thresholds.get("app_executions")["high"],
]
)
def _set_application_execution_thresholds(self, thresholds: List[int]):
"""
Method that validates and then sets the application execution threshold.
:param: thresholds: The application execution threshold to validate and set.
"""
if self._validate_thresholds(
thresholds=[
thresholds[0],
thresholds[1],
thresholds[2],
],
threshold_identifier="app_executions",
):
self.low_app_execution_threshold = thresholds[0]
self.med_app_execution_threshold = thresholds[1]
self.high_app_execution_threshold = thresholds[2]
def _categorise_num_executions(self, num_executions: int) -> int:
"""
Represent number of file accesses as a categorical variable.
Represent number of application executions as a categorical variable.
:param num_access: Number of file accesses.
:param num_access: Number of application executions.
:return: Bin number corresponding to the number of accesses.
"""
if num_executions > self.high_threshold:
if num_executions > self.high_app_execution_threshold:
return 3
elif num_executions > self.med_threshold:
elif num_executions > self.med_app_execution_threshold:
return 2
elif num_executions > self.low_threshold:
elif num_executions > self.low_app_execution_threshold:
return 1
return 0
@@ -128,7 +168,9 @@ class ApplicationObservation(AbstractObservation, discriminator="application"):
return self.default_observation
return {
"operating_status": application_state["operating_state"],
"health_status": application_state["health_state_visible"],
"health_status": application_state["health_state_visible"]
if self.applications_requires_scan
else application_state["health_state_actual"],
"num_executions": self._categorise_num_executions(application_state["num_executions"]),
}
@@ -161,4 +203,8 @@ class ApplicationObservation(AbstractObservation, discriminator="application"):
:return: Constructed application observation instance.
:rtype: ApplicationObservation
"""
return cls(where=parent_where + ["applications", config.application_name])
return cls(
where=parent_where + ["applications", config.application_name],
applications_requires_scan=config.applications_requires_scan,
thresholds=config.thresholds,
)

View File

@@ -7,6 +7,7 @@ from pydantic import BaseModel, ConfigDict
from primaite import DEFAULT_BANDWIDTH, getLogger
from primaite.game.agent.interface import AbstractAgent, ProxyAgent
from primaite.game.agent.observations import NICObservation
from primaite.game.agent.rewards import SharedReward
from primaite.game.science import graph_has_cycle, topological_sort
from primaite.simulator import SIM_OUTPUT
@@ -44,15 +45,15 @@ from primaite.utils.validation.port import Port, PORT_LOOKUP
_LOGGER = getLogger(__name__)
SERVICE_TYPES_MAPPING = {
"DNSClient": DNSClient,
"DNSServer": DNSServer,
"DatabaseService": DatabaseService,
"WebServer": WebServer,
"FTPClient": FTPClient,
"FTPServer": FTPServer,
"NTPClient": NTPClient,
"NTPServer": NTPServer,
"Terminal": Terminal,
"dns-client": DNSClient,
"dns-server": DNSServer,
"database-service": DatabaseService,
"web-server": WebServer,
"ftp-client": FTPClient,
"ftp-server": FTPServer,
"ntp-client": NTPClient,
"ntp-server": NTPServer,
"terminal": Terminal,
}
"""List of available services that can be installed on nodes in the PrimAITE Simulation."""
@@ -68,6 +69,8 @@ class PrimaiteGameOptions(BaseModel):
seed: int = None
"""Random number seed for RNGs."""
generate_seed_value: bool = False
"""Internally generated seed value."""
max_episode_length: int = 256
"""Maximum number of episodes for the PrimAITE game."""
ports: List[Port]
@@ -175,6 +178,7 @@ class PrimaiteGame:
parameters=parameters,
request=request,
response=response,
observation=obs,
)
def pre_timestep(self) -> None:
@@ -263,6 +267,7 @@ class PrimaiteGame:
node_sets_cfg = network_config.get("node_sets", [])
# Set the NMNE capture config
NetworkInterface.nmne_config = NMNEConfig(**network_config.get("nmne_config", {}))
NICObservation.capture_nmne = NMNEConfig(**network_config.get("nmne_config", {})).capture_nmne
for node_cfg in nodes_cfg:
n_type = node_cfg["type"]
@@ -293,6 +298,7 @@ class PrimaiteGame:
if "users" in node_cfg and new_node.software_manager.software.get("user-manager"):
user_manager: UserManager = new_node.software_manager.software["user-manager"] # noqa
for user_cfg in node_cfg["users"]:
user_manager.add_user(**user_cfg, bypass_can_perform_action=True)
@@ -407,6 +413,7 @@ class PrimaiteGame:
agents_cfg = cfg.get("agents", [])
for agent_cfg in agents_cfg:
agent_cfg = {**agent_cfg, "thresholds": game.options.thresholds}
new_agent = AbstractAgent.from_config(agent_cfg)
game.agents[agent_cfg["ref"]] = new_agent
if isinstance(new_agent, ProxyAgent):

View File

@@ -50,40 +50,22 @@
"custom_c2_agent = \"\"\"\n",
" - ref: CustomC2Agent\n",
" team: RED\n",
" type: ProxyAgent\n",
" type: proxy-a.gent\n",
"\n",
" action_space:\n",
" options:\n",
" nodes:\n",
" - node_name: web_server\n",
" applications:\n",
" - application_name: C2Beacon\n",
" - node_name: client_1\n",
" applications:\n",
" - application_name: C2Server\n",
" max_folders_per_node: 1\n",
" max_files_per_folder: 1\n",
" max_services_per_node: 2\n",
" max_nics_per_node: 8\n",
" max_acl_rules: 10\n",
" ip_list:\n",
" - 192.168.1.21\n",
" - 192.168.1.14\n",
" wildcard_list:\n",
" - 0.0.0.1\n",
" action_map:\n",
" 0:\n",
" action: do_nothing\n",
" options: {}\n",
" 1:\n",
" action: node_application_install\n",
" action: node-application-install\n",
" options:\n",
" node_id: 0\n",
" application_name: C2Beacon\n",
" node_name: web_server\n",
" application_name: c2-beacon\n",
" 2:\n",
" action: configure_c2_beacon\n",
" action: configure-c2-beacon\n",
" options:\n",
" node_id: 0\n",
" node_name: web_server\n",
" config:\n",
" c2_server_ip_address: 192.168.10.21\n",
" keep_alive_frequency:\n",
@@ -92,10 +74,10 @@
" 3:\n",
" action: node_application_execute\n",
" options:\n",
" node_id: 0\n",
" application_id: 0\n",
" node_name: web_server\n",
" application_name: c2-beacon\n",
" 4:\n",
" action: c2_server_terminal_command\n",
" action: c2-server-terminal-command\n",
" options:\n",
" node_id: 1\n",
" ip_address:\n",
@@ -111,14 +93,14 @@
" 5:\n",
" action: c2-server-ransomware-configure\n",
" options:\n",
" node_id: 1\n",
" node_name: client_1\n",
" config:\n",
" server_ip_address: 192.168.1.14\n",
" payload: ENCRYPT\n",
" 6:\n",
" action: c2_server_data_exfiltrate\n",
" action: c2-server-data-exfiltrate\n",
" options:\n",
" node_id: 1\n",
" node_name: client_1\n",
" target_file_name: \"database.db\"\n",
" target_folder_name: \"database\"\n",
" exfiltration_folder_name: \"spoils\"\n",
@@ -128,31 +110,27 @@
" password: admin\n",
"\n",
" 7:\n",
" action: c2_server_ransomware_launch\n",
" action: c2-server-ransomware-launch\n",
" options:\n",
" node_id: 1\n",
" node_name: client_1\n",
" 8:\n",
" action: configure_c2_beacon\n",
" action: configure-c2-beacon\n",
" options:\n",
" node_id: 0\n",
" node_name: web_server\n",
" config:\n",
" c2_server_ip_address: 192.168.10.21\n",
" keep_alive_frequency: 10\n",
" masquerade_protocol: TCP\n",
" masquerade_port: DNS\n",
" 9:\n",
" action: configure_c2_beacon\n",
" action: configure-c2-beacon\n",
" options:\n",
" node_id: 0\n",
" node_name: web_server\n",
" config:\n",
" c2_server_ip_address: 192.168.10.22\n",
" keep_alive_frequency:\n",
" masquerade_protocol:\n",
" masquerade_port:\n",
"\n",
" reward_function:\n",
" reward_components:\n",
" - type: DUMMY\n",
"\"\"\"\n",
"c2_agent_yaml = yaml.safe_load(custom_c2_agent)"
]
@@ -225,7 +203,7 @@
" nodes: # Node List\n",
" - node_name: web_server\n",
" applications: \n",
" - application_name: C2Beacon\n",
" - application_name: c2-beacon\n",
" ...\n",
" ...\n",
" action_map:\n",
@@ -233,7 +211,7 @@
" action: node_application_install \n",
" options:\n",
" node_id: 0 # Index 0 at the node list.\n",
" application_name: C2Beacon\n",
" application_name: c2-beacon\n",
"```"
]
},
@@ -268,7 +246,7 @@
" action_map:\n",
" ...\n",
" 2:\n",
" action: configure_c2_beacon\n",
" action: configure-c2-beacon\n",
" options:\n",
" node_id: 0 # Node Index\n",
" config: # Further information about these config options can be found at the bottom of this notebook.\n",
@@ -286,7 +264,7 @@
"outputs": [],
"source": [
"env.step(2)\n",
"c2_beacon: C2Beacon = web_server.software_manager.software[\"C2Beacon\"]\n",
"c2_beacon: C2Beacon = web_server.software_manager.software[\"c2-beacon\"]\n",
"web_server.software_manager.show()\n",
"c2_beacon.show()"
]
@@ -307,13 +285,13 @@
" nodes: # Node List\n",
" - node_name: web_server\n",
" applications: \n",
" - application_name: C2Beacon\n",
" - application_name: c2-beacon\n",
" ...\n",
" ...\n",
" action_map:\n",
" ...\n",
" 3:\n",
" action: node_application_execute\n",
" action: node-application-execute\n",
" options:\n",
" node_id: 0\n",
" application_id: 0\n",
@@ -374,11 +352,11 @@
" ...\n",
" - node_name: client_1\n",
" applications: \n",
" - application_name: C2Server\n",
" - application_name: c2-server\n",
" ...\n",
" action_map:\n",
" 4:\n",
" action: C2_SERVER_TERMINAL_COMMAND\n",
" action: c2-server-terminal-command\n",
" options:\n",
" node_id: 1\n",
" ip_address:\n",
@@ -431,7 +409,7 @@
" ...\n",
" - node_name: client_1\n",
" applications: \n",
" - application_name: C2Server\n",
" - application_name: c2-server\n",
" ...\n",
" action_map:\n",
" 5:\n",
@@ -459,7 +437,7 @@
"metadata": {},
"outputs": [],
"source": [
"ransomware_script: RansomwareScript = web_server.software_manager.software[\"RansomwareScript\"]\n",
"ransomware_script: RansomwareScript = web_server.software_manager.software[\"ransomware-script\"]\n",
"web_server.software_manager.show()\n",
"ransomware_script.show()"
]
@@ -483,11 +461,11 @@
" ...\n",
" - node_name: client_1\n",
" applications: \n",
" - application_name: C2Server\n",
" - application_name: c2-server\n",
" ...\n",
" action_map:\n",
" 6:\n",
" action: c2_server_data_exfiltrate\n",
" action: c2-server-data-exfiltrate\n",
" options:\n",
" node_id: 1\n",
" target_file_name: \"database.db\"\n",
@@ -549,11 +527,11 @@
" ...\n",
" - node_name: client_1\n",
" applications: \n",
" - application_name: C2Server\n",
" - application_name: c2-server\n",
" ...\n",
" action_map:\n",
" 7:\n",
" action: c2_server_ransomware_launch\n",
" action: c2-server-ransomware-launch\n",
" options:\n",
" node_id: 1\n",
"```\n"
@@ -598,20 +576,20 @@
"custom_blue_agent_yaml = \"\"\"\n",
" - ref: defender\n",
" team: BLUE\n",
" type: ProxyAgent\n",
" type: proxy-agent\n",
"\n",
" observation_space:\n",
" type: CUSTOM\n",
" type: custom\n",
" options:\n",
" components:\n",
" - type: NODES\n",
" - type: nodes\n",
" label: NODES\n",
" options:\n",
" hosts:\n",
" - hostname: web_server\n",
" applications:\n",
" - application_name: C2Beacon\n",
" - application_name: RansomwareScript\n",
" - application_name: c2-beacon\n",
" - application_name: ransomware-script\n",
" folders:\n",
" - folder_name: exfiltration_folder\n",
" files:\n",
@@ -661,7 +639,7 @@
" - UDP\n",
" num_rules: 10\n",
"\n",
" - type: LINKS\n",
" - type: links\n",
" label: LINKS\n",
" options:\n",
" link_references:\n",
@@ -675,7 +653,7 @@
" - switch_2:eth-1<->client_1:eth-1\n",
" - switch_2:eth-2<->client_2:eth-1\n",
" - switch_2:eth-7<->security_suite:eth-2\n",
" - type: \"NONE\"\n",
" - type: \"none\"\n",
" label: ICS\n",
" options: {}\n",
"\n",
@@ -685,16 +663,16 @@
" action: do_nothing\n",
" options: {}\n",
" 1:\n",
" action: node_application_remove\n",
" action: node-application-remove\n",
" options:\n",
" node_id: 0\n",
" node_name: web-server\n",
" application_name: C2Beacon\n",
" 2:\n",
" action: node_shutdown\n",
" action: node-shutdown\n",
" options:\n",
" node_id: 0\n",
" node_name: web-server\n",
" 3:\n",
" action: router_acl_add_rule\n",
" action: router-acl-add-rule\n",
" options:\n",
" target_router: router_1\n",
" position: 1\n",
@@ -707,36 +685,6 @@
" source_wildcard_id: 0\n",
" dest_wildcard_id: 0\n",
"\n",
"\n",
" options:\n",
" nodes:\n",
" - node_name: web_server\n",
" applications:\n",
" - application_name: C2Beacon\n",
"\n",
" - node_name: database_server\n",
" folders:\n",
" - folder_name: database\n",
" files:\n",
" - file_name: database.db\n",
" services:\n",
" - service_name: DatabaseService\n",
" - node_name: router_1\n",
"\n",
" max_folders_per_node: 2\n",
" max_files_per_folder: 2\n",
" max_services_per_node: 2\n",
" max_nics_per_node: 8\n",
" max_acl_rules: 10\n",
" ip_list:\n",
" - 192.168.10.21\n",
" - 192.168.1.12\n",
" wildcard_list:\n",
" - 0.0.0.1\n",
" reward_function:\n",
" reward_components:\n",
" - type: DUMMY\n",
"\n",
" agent_settings:\n",
" flatten_obs: False\n",
"\"\"\"\n",
@@ -875,7 +823,7 @@
"outputs": [],
"source": [
"# Installing RansomwareScript via C2 Terminal Commands\n",
"ransomware_install_command = {\"commands\":[[\"software_manager\", \"application\", \"install\", \"RansomwareScript\"]],\n",
"ransomware_install_command = {\"commands\":[[\"software_manager\", \"application\", \"install\", \"ransomware-script\"]],\n",
" \"username\": \"admin\",\n",
" \"password\": \"admin\"}\n",
"c2_server.send_command(C2Command.TERMINAL, command_options=ransomware_install_command)\n"
@@ -1034,11 +982,11 @@
" web_server: Server = given_env.game.simulation.network.get_node_by_hostname(\"web_server\")\n",
"\n",
" client_1.software_manager.install(C2Server)\n",
" c2_server: C2Server = client_1.software_manager.software[\"C2Server\"]\n",
" c2_server: C2Server = client_1.software_manager.software[\"c2-server\"]\n",
" c2_server.run()\n",
"\n",
" web_server.software_manager.install(C2Beacon)\n",
" c2_beacon: C2Beacon = web_server.software_manager.software[\"C2Beacon\"]\n",
" c2_beacon: C2Beacon = web_server.software_manager.software[\"c2-beacon\"]\n",
" c2_beacon.configure(c2_server_ip_address=\"192.168.10.21\")\n",
" c2_beacon.establish()\n",
"\n",
@@ -1132,11 +1080,11 @@
"outputs": [],
"source": [
"# Attempting to install the C2 RansomwareScript\n",
"ransomware_install_command = {\"commands\":[[\"software_manager\", \"application\", \"install\", \"RansomwareScript\"]],\n",
"ransomware_install_command = {\"commands\":[[\"software_manager\", \"application\", \"install\", \"ransomware-script\"]],\n",
" \"username\": \"admin\",\n",
" \"password\": \"admin\"}\n",
"\n",
"c2_server: C2Server = client_1.software_manager.software[\"C2Server\"]\n",
"c2_server: C2Server = client_1.software_manager.software[\"c2-server\"]\n",
"c2_server.send_command(C2Command.TERMINAL, command_options=ransomware_install_command)"
]
},
@@ -1220,11 +1168,11 @@
"outputs": [],
"source": [
"# Attempting to install the C2 RansomwareScript\n",
"ransomware_install_command = {\"commands\":[\"software_manager\", \"application\", \"install\", \"RansomwareScript\"],\n",
"ransomware_install_command = {\"commands\":[\"software_manager\", \"application\", \"install\", \"ransomware-script\"],\n",
" \"username\": \"admin\",\n",
" \"password\": \"admin\"}\n",
"\n",
"c2_server: C2Server = client_1.software_manager.software[\"C2Server\"]\n",
"c2_server: C2Server = client_1.software_manager.software[\"c2-server\"]\n",
"c2_server.send_command(C2Command.TERMINAL, command_options=ransomware_install_command)"
]
},
@@ -1345,7 +1293,7 @@
"metadata": {},
"outputs": [],
"source": [
"database_server: Server = blue_env.game.simulation.network.get_node_by_hostname(\"database_server\")\n",
"database_server: Server = blue_env.game.simulation.network.get_node_by_hostname(\"database-server\")\n",
"database_server.software_manager.file_system.show(full=True)"
]
},
@@ -1391,7 +1339,7 @@
"\n",
"``` YAML\n",
"...\n",
" action: configure_c2_beacon\n",
" action: configure-c2-beacon\n",
" options:\n",
" node_id: 0\n",
" config:\n",
@@ -1446,16 +1394,16 @@
"source": [
"web_server: Server = c2_config_env.game.simulation.network.get_node_by_hostname(\"web_server\")\n",
"web_server.software_manager.install(C2Beacon)\n",
"c2_beacon: C2Beacon = web_server.software_manager.software[\"C2Beacon\"]\n",
"c2_beacon: C2Beacon = web_server.software_manager.software[\"c2-beacon\"]\n",
"\n",
"client_1: Computer = c2_config_env.game.simulation.network.get_node_by_hostname(\"client_1\")\n",
"client_1.software_manager.install(C2Server)\n",
"c2_server_1: C2Server = client_1.software_manager.software[\"C2Server\"]\n",
"c2_server_1: C2Server = client_1.software_manager.software[\"c2-server\"]\n",
"c2_server_1.run()\n",
"\n",
"client_2: Computer = c2_config_env.game.simulation.network.get_node_by_hostname(\"client_2\")\n",
"client_2.software_manager.install(C2Server)\n",
"c2_server_2: C2Server = client_2.software_manager.software[\"C2Server\"]\n",
"c2_server_2: C2Server = client_2.software_manager.software[\"c2-server\"]\n",
"c2_server_2.run()"
]
},
@@ -1759,6 +1707,16 @@
"\n",
"display_obs_diffs(tcp_c2_obs, udp_c2_obs, blue_config_env.game.step_counter)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"env.game.agents[\"CustomC2Agent\"].show_history()"
]
}
],
"metadata": {

View File

@@ -47,7 +47,7 @@
"source": [
"def make_cfg_have_flat_obs(cfg):\n",
" for agent in cfg['agents']:\n",
" if agent['type'] == \"ProxyAgent\":\n",
" if agent['type'] == \"proxy-agent\":\n",
" agent['agent_settings']['flatten_obs'] = False"
]
},
@@ -76,9 +76,9 @@
" # parse the info dict form step output and write out what the red agent is doing\n",
" red_info : AgentHistoryItem = info['agent_actions']['data_manipulation_attacker']\n",
" red_action = red_info.action\n",
" if red_action == 'do_nothing':\n",
" if red_action == 'do-nothing':\n",
" red_str = 'DO NOTHING'\n",
" elif red_action == 'node_application_execute':\n",
" elif red_action == 'node-application-execute':\n",
" client = \"client 1\" if red_info.parameters['node_id'] == 0 else \"client 2\"\n",
" red_str = f\"ATTACK from {client}\"\n",
" return red_str"
@@ -147,36 +147,14 @@
"```yaml\n",
" - ref: data_manipulation_attacker # name of agent\n",
" team: RED # not used, just for human reference\n",
" type: RedDatabaseCorruptingAgent # type of agent - this lets primaite know which agent class to use\n",
" type: red-database-corrupting-agent # type of agent - this lets primaite know which agent class to use\n",
"\n",
" # Since the agent does not need to react to what is happening in the environment, the observation space is empty.\n",
" observation_space:\n",
" type: UC2RedObservation\n",
" type: uc2-red-observation # TODO: what\n",
" options:\n",
" nodes: {}\n",
"\n",
" action_space:\n",
" \n",
" # The agent has access to the DataManipulationBoth on clients 1 and 2.\n",
" options:\n",
" nodes:\n",
" - node_name: client_1 # The network should have a node called client_1\n",
" applications:\n",
" - application_name: DataManipulationBot # The node client_1 should have DataManipulationBot configured on it\n",
" - node_name: client_2 # The network should have a node called client_2\n",
" applications:\n",
" - application_name: DataManipulationBot # The node client_2 should have DataManipulationBot configured on it\n",
"\n",
" # not important\n",
" max_folders_per_node: 1\n",
" max_files_per_folder: 1\n",
" max_services_per_node: 1\n",
"\n",
" # red agent does not need a reward function\n",
" reward_function:\n",
" reward_components:\n",
" - type: DUMMY\n",
"\n",
" # These actions are passed to the RedDatabaseCorruptingAgent init method, they dictate the schedule of attacks\n",
" agent_settings:\n",
" start_settings:\n",
@@ -211,15 +189,13 @@
" \n",
" # \n",
" applications:\n",
" - ref: data_manipulation_bot\n",
" type: DataManipulationBot\n",
" - type: data-manipulation-bot\n",
" options:\n",
" port_scan_p_of_success: 0.8 # Probability that port scan is successful\n",
" data_manipulation_p_of_success: 0.8 # Probability that SQL attack is successful\n",
" payload: \"DELETE\" # The SQL query which causes the attack (this has to be DELETE)\n",
" server_ip: 192.168.1.14 # IP address of server hosting the database\n",
" - ref: client_1_database_client\n",
" type: DatabaseClient # Database client must be installed in order for DataManipulationBot to function\n",
" - type: database-client # Database client must be installed in order for DataManipulationBot to function\n",
" options:\n",
" db_server_ip: 192.168.1.14 # IP address of server hosting the database\n",
"```"
@@ -354,19 +330,16 @@
"# Make attack always succeed.\n",
"change = yaml.safe_load(\"\"\"\n",
" applications:\n",
" - ref: data_manipulation_bot\n",
" type: DataManipulationBot\n",
" - type: data-manipulation-bot\n",
" options:\n",
" port_scan_p_of_success: 1.0\n",
" data_manipulation_p_of_success: 1.0\n",
" payload: \"DELETE\"\n",
" server_ip: 192.168.1.14\n",
" - ref: client_1_web_browser\n",
" type: WebBrowser\n",
" - type: web-browser\n",
" options:\n",
" target_url: http://arcd.com/users/\n",
" - ref: client_1_database_client\n",
" type: DatabaseClient\n",
" - type: database-client\n",
" options:\n",
" db_server_ip: 192.168.1.14\n",
"\"\"\")\n",
@@ -399,19 +372,16 @@
"# Make attack always fail.\n",
"change = yaml.safe_load(\"\"\"\n",
" applications:\n",
" - ref: data_manipulation_bot\n",
" type: DataManipulationBot\n",
" - type: data-manipulation-bot\n",
" options:\n",
" port_scan_p_of_success: 0.0\n",
" data_manipulation_p_of_success: 0.0\n",
" payload: \"DELETE\"\n",
" server_ip: 192.168.1.14\n",
" - ref: client_1_web_browser\n",
" type: WebBrowser\n",
" - type: web-browser\n",
" options:\n",
" target_url: http://arcd.com/users/\n",
" - ref: client_1_database_client\n",
" type: DatabaseClient\n",
" - type: database-client\n",
" options:\n",
" db_server_ip: 192.168.1.14\n",
"\"\"\")\n",

View File

@@ -684,6 +684,15 @@
" print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker'].action}, Blue reward:{reward:.2f}\" )"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"env.game.agents[\"data_manipulation_attacker\"].show_history()"
]
},
{
"cell_type": "markdown",
"metadata": {},
@@ -717,7 +726,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
"version": "3.10.11"
}
},
"nbformat": 4,

View File

@@ -153,6 +153,49 @@
"PRIMAITE_CONFIG[\"developer_mode\"][\"enabled\"] = was_enabled\n",
"PRIMAITE_CONFIG[\"developer_mode\"][\"output_sys_logs\"] = was_syslogs_enabled"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Viewing Agent history\n",
"\n",
"It's possible to view the actions carried out by an agent for a given training session using the `show_history()` method. By default, this will be all actions apart from DONOTHING actions."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"with open(data_manipulation_config_path(), 'r') as f:\n",
" cfg = yaml.safe_load(f)\n",
"\n",
"env = PrimaiteGymEnv(env_config=cfg)\n",
"\n",
"# Run the training session to generate some resultant data.\n",
"for i in range(100):\n",
" env.step(0)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Calling `.show_history()` should show us when the Data Manipulation used the `NODE_APPLICATION_EXECUTE` action."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"attacker = env.game.agents[\"data_manipulation_attacker\"]\n",
"\n",
"attacker.show_history()"
]
}
],
"metadata": {
@@ -171,7 +214,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.8"
"version": "3.10.11"
}
},
"nbformat": 4,

View File

@@ -0,0 +1,479 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# PrimAITE Developer mode\n",
"\n",
"PrimAITE has built in developer tools.\n",
"\n",
"The dev-mode is designed to help make the development of PrimAITE easier.\n",
"\n",
"`NOTE: For the purposes of the notebook, the commands are preceeded by \"!\". When running the commands, run it without the \"!\".`\n",
"\n",
"To display the available dev-mode options, run the command below:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!primaite dev-mode --help"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Save the current PRIMAITE_CONFIG to restore after the notebook runs\n",
"\n",
"from primaite import PRIMAITE_CONFIG\n",
"\n",
"temp_config = PRIMAITE_CONFIG.copy()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Dev mode options"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### enable\n",
"\n",
"Enables the dev mode for PrimAITE.\n",
"\n",
"This will enable the developer mode for PrimAITE.\n",
"\n",
"By default, when developer mode is enabled, session logs will be generated in the PRIMAITE_ROOT/sessions folder unless configured to be generated in another location."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!primaite dev-mode enable"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### disable\n",
"\n",
"Disables the dev mode for PrimAITE.\n",
"\n",
"This will disable the developer mode for PrimAITE."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!primaite dev-mode disable"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### show\n",
"\n",
"Shows if PrimAITE is running in dev mode or production mode.\n",
"\n",
"The command will also show the developer mode configuration."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!primaite dev-mode show"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### config\n",
"\n",
"Configure the PrimAITE developer mode"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!primaite dev-mode config --help"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### path\n",
"\n",
"Set the path where generated session files will be output.\n",
"\n",
"By default, this value will be in PRIMAITE_ROOT/sessions.\n",
"\n",
"To reset the path to default, run:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!primaite dev-mode config path -root\n",
"\n",
"# or\n",
"\n",
"!primaite dev-mode config path --default"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### --sys-log-level or -slevel\n",
"\n",
"Set the system log level.\n",
"\n",
"This will override the system log level in configurations and will make PrimAITE include the set log level and above.\n",
"\n",
"Available options are:\n",
"- `DEBUG`\n",
"- `INFO`\n",
"- `WARNING`\n",
"- `ERROR`\n",
"- `CRITICAL`\n",
"\n",
"Default value is `DEBUG`\n",
"\n",
"Example:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!primaite dev-mode config --sys-log-level DEBUG\n",
"\n",
"# or\n",
"\n",
"!primaite dev-mode config -slevel DEBUG"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### --agent-log-level or -alevel\n",
"\n",
"Set the agent log level.\n",
"\n",
"This will override the agent log level in configurations and will make PrimAITE include the set log level and above.\n",
"\n",
"Available options are:\n",
"- `DEBUG`\n",
"- `INFO`\n",
"- `WARNING`\n",
"- `ERROR`\n",
"- `CRITICAL`\n",
"\n",
"Default value is `DEBUG`\n",
"\n",
"Example:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!primaite dev-mode config --agent-log-level DEBUG\n",
"\n",
"# or\n",
"\n",
"!primaite dev-mode config -alevel DEBUG"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### --output-sys-logs or -sys\n",
"\n",
"If enabled, developer mode will output system logs.\n",
"\n",
"Example:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!primaite dev-mode config --output-sys-logs\n",
"\n",
"# or\n",
"\n",
"!primaite dev-mode config -sys"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To disable outputting sys logs:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!primaite dev-mode config --no-sys-logs\n",
"\n",
"# or\n",
"\n",
"!primaite dev-mode config -nsys"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### --output-agent-logs or -agent\n",
"\n",
"If enabled, developer mode will output agent action logs.\n",
"\n",
"Example:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!primaite dev-mode config --output-agent-logs\n",
"\n",
"# or\n",
"\n",
"!primaite dev-mode config -agent"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To disable outputting agent action logs:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!primaite dev-mode config --no-agent-logs\n",
"\n",
"# or\n",
"\n",
"!primaite dev-mode config -nagent"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### --output-pcap-logs or -pcap\n",
"\n",
"If enabled, developer mode will output PCAP logs.\n",
"\n",
"Example:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!primaite dev-mode config --output-pcap-logs\n",
"\n",
"# or\n",
"\n",
"!primaite dev-mode config -pcap"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To disable outputting PCAP logs:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!primaite dev-mode config --no-pcap-logs\n",
"\n",
"# or\n",
"\n",
"!primaite dev-mode config -npcap"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### --output-to-terminal or -t\n",
"\n",
"If enabled, developer mode will output logs to the terminal.\n",
"\n",
"Example:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!primaite dev-mode config --output-to-terminal\n",
"\n",
"# or\n",
"\n",
"!primaite dev-mode config -t"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To disable terminal outputs:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!primaite dev-mode config --no-terminal\n",
"\n",
"# or\n",
"\n",
"!primaite dev-mode config -nt"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Combining commands\n",
"\n",
"It is possible to combine commands to set the configuration.\n",
"\n",
"This saves having to enter multiple commands and allows for a much more efficient setting of PrimAITE developer mode configurations."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Example of setting system log level and enabling the system logging:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!primaite dev-mode config -slevel WARNING -sys"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Another example where the system log and agent action log levels are set and enabled and should be printed to terminal:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!primaite dev-mode config -slevel ERROR -sys -alevel ERROR -agent -t"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Restore PRIMAITE_CONFIG\n",
"from primaite.utils.cli.primaite_config_utils import update_primaite_application_config\n",
"\n",
"\n",
"global PRIMAITE_CONFIG\n",
"PRIMAITE_CONFIG[\"developer_mode\"] = temp_config[\"developer_mode\"]\n",
"update_primaite_application_config(config=PRIMAITE_CONFIG)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.11"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@@ -114,7 +114,7 @@
"metadata": {},
"outputs": [],
"source": [
"print(f\"DNS Client state: {client.software_manager.software.get('DNSClient').operating_state.name}\")"
"print(f\"DNS Client state: {client.software_manager.software.get('dns-client').operating_state.name}\")"
]
},
{

View File

@@ -9,6 +9,13 @@
"© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Simulation Layer Implementation."
]
},
{
"cell_type": "markdown",
"metadata": {},
@@ -67,9 +74,9 @@
"source": [
"network: Network = basic_network()\n",
"computer_a: Computer = network.get_node_by_hostname(\"node_a\")\n",
"terminal_a: Terminal = computer_a.software_manager.software.get(\"Terminal\")\n",
"terminal_a: Terminal = computer_a.software_manager.software.get(\"terminal\")\n",
"computer_b: Computer = network.get_node_by_hostname(\"node_b\")\n",
"terminal_b: Terminal = computer_b.software_manager.software.get(\"Terminal\")"
"terminal_b: Terminal = computer_b.software_manager.software.get(\"terminal\")"
]
},
{
@@ -121,7 +128,7 @@
"metadata": {},
"outputs": [],
"source": [
"term_a_term_b_remote_connection.execute([\"software_manager\", \"application\", \"install\", \"RansomwareScript\"])"
"term_a_term_b_remote_connection.execute([\"software_manager\", \"application\", \"install\", \"ransomware-script\"])"
]
},
{
@@ -169,6 +176,22 @@
"computer_b.file_system.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Information about the latest response when executing a remote command can be seen by calling the `last_response` attribute within `Terminal`"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print(terminal_a.last_response)"
]
},
{
"cell_type": "markdown",
"metadata": {},
@@ -207,6 +230,263 @@
"source": [
"computer_b.user_session_manager.show(include_historic=True, include_session_id=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Game Layer Implementation\n",
"\n",
"This notebook section will detail the implementation of how the game layer utilises the terminal to support different agent actions.\n",
"\n",
"The ``Terminal`` is used in a variety of different ways in the game layer. Specifically, the terminal is leveraged to implement the following actions:\n",
"\n",
"\n",
"| Game Layer Action | Simulation Layer |\n",
"|-----------------------------------|--------------------------|\n",
"| ``node-send-local-command`` | Uses the given user credentials, creates a ``LocalTerminalSession`` and executes the given command and returns the ``RequestResponse``.\n",
"| ``node-session-remote-login`` | Uses the given user credentials and remote IP to create a ``RemoteTerminalSession``.\n",
"| ``node-send-remote-command`` | Uses the given remote IP to locate the correct ``RemoteTerminalSession``, executes the given command and returns the ``RequestsResponse``."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Game Layer Setup\n",
"\n",
"Similar to other notebooks, the next code cells create a custom proxy agent to demonstrate how these commands can be leveraged by agents in the ``UC2`` network environment.\n",
"\n",
"If you're unfamiliar with ``UC2`` then please refer to the [UC2-E2E-Demo notebook for further reference](./Data-Manipulation-E2E-Demonstration.ipynb)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import yaml\n",
"from primaite.config.load import data_manipulation_config_path\n",
"from primaite.session.environment import PrimaiteGymEnv"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"custom_terminal_agent = \"\"\"\n",
" - ref: CustomC2Agent\n",
" team: RED\n",
" type: proxy-agent\n",
" observation_space: null\n",
" action_space:\n",
" options:\n",
" nodes:\n",
" - node_name: client_1\n",
" max_folders_per_node: 1\n",
" max_files_per_folder: 1\n",
" max_services_per_node: 2\n",
" max_nics_per_node: 8\n",
" max_acl_rules: 10\n",
" ip_list:\n",
" - 192.168.1.21\n",
" - 192.168.1.14\n",
" wildcard_list:\n",
" - 0.0.0.1\n",
" action_map:\n",
" 0:\n",
" action: do-nothing\n",
" options: {}\n",
" 1:\n",
" action: node-send-local-command\n",
" options:\n",
" node_name: client_1\n",
" username: admin\n",
" password: admin\n",
" command:\n",
" - file_system\n",
" - create\n",
" - file\n",
" - downloads\n",
" - dog.png\n",
" - False\n",
" 2:\n",
" action: node-session-remote-login\n",
" options:\n",
" node_name: client_1\n",
" username: admin\n",
" password: admin\n",
" remote_ip: 192.168.10.22\n",
" 3:\n",
" action: node-send-remote-command\n",
" options:\n",
" node_name: client_1\n",
" remote_ip: 192.168.10.22\n",
" command:\n",
" - file_system\n",
" - create\n",
" - file\n",
" - downloads\n",
" - cat.png\n",
" - False\n",
"\"\"\"\n",
"custom_terminal_agent_yaml = yaml.safe_load(custom_terminal_agent)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"with open(data_manipulation_config_path()) as f:\n",
" cfg = yaml.safe_load(f)\n",
" # removing all agents & adding the custom agent.\n",
" cfg['agents'] = {}\n",
" cfg['agents'] = custom_terminal_agent_yaml\n",
"\n",
"env = PrimaiteGymEnv(env_config=cfg)\n",
"\n",
"client_1: Computer = env.game.simulation.network.get_node_by_hostname(\"client_1\")\n",
"client_2: Computer = env.game.simulation.network.get_node_by_hostname(\"client_2\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Terminal Action | ``node-send-local-command`` \n",
"\n",
"The yaml snippet below shows all the relevant agent options for this action:\n",
"\n",
"```yaml\n",
"\n",
" action_space:\n",
" action_list:\n",
" ...\n",
" - type: node-send-local-command\n",
" ...\n",
" options:\n",
" nodes: # Node List\n",
" - node_name: client_1\n",
" ...\n",
" ...\n",
" action_map:\n",
" 1:\n",
" action: node-send-local-command\n",
" options:\n",
" node_id: 0 # Index 0 at the node list.\n",
" username: admin\n",
" password: admin\n",
" command:\n",
" - file_system\n",
" - create\n",
" - file\n",
" - downloads\n",
" - dog.png\n",
" - False\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"env.step(1)\n",
"client_1.file_system.show(full=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Terminal Action | ``node-session-remote-login`` \n",
"\n",
"The yaml snippet below shows all the relevant agent options for this action:\n",
"\n",
"```yaml\n",
"\n",
" action_space:\n",
" action_list:\n",
" ...\n",
" - type: node-session-remote-login\n",
" ...\n",
" options:\n",
" nodes: # Node List\n",
" - node_name: client_1\n",
" ...\n",
" ...\n",
" action_map:\n",
" 2:\n",
" action: node-session-remote-login\n",
" options:\n",
" node_id: 0 # Index 0 at the node list.\n",
" username: admin\n",
" password: admin\n",
" remote_ip: 192.168.10.22 # client_2's ip address.\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"env.step(2)\n",
"client_2.session_manager.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Terminal Action | ``node-send-remote-command``\n",
"\n",
"The yaml snippet below shows all the relevant agent options for this action:\n",
"\n",
"```yaml\n",
"\n",
" action_space:\n",
" action_list:\n",
" ...\n",
" - type: node-send-remote-command\n",
" ...\n",
" options:\n",
" nodes: # Node List\n",
" - node_name: client_1\n",
" ...\n",
" ...\n",
" action_map:\n",
" 1:\n",
" action: node-send-remote-command\n",
" options:\n",
" node_id: 0 # Index 0 at the node list.\n",
" remote_ip: 192.168.10.22\n",
" commands:\n",
" - file_system\n",
" - create\n",
" - file\n",
" - downloads\n",
" - cat.png\n",
" - False\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"env.step(3)\n",
"client_2.file_system.show(full=True)"
]
}
],
"metadata": {

View File

@@ -26,14 +26,26 @@ except ModuleNotFoundError:
_LOGGER.debug("Torch not available for importing")
def set_random_seed(seed: int) -> Union[None, int]:
def set_random_seed(seed: int, generate_seed_value: bool) -> Union[None, int]:
"""
Set random number generators.
If seed is None or -1 and generate_seed_value is True randomly generate a
seed value.
If seed is > -1 and generate_seed_value is True ignore the latter and use
the provide seed value.
:param seed: int
:param generate_seed_value: bool
:return: None or the int representing the seed used.
"""
if seed is None or seed == -1:
return None
if generate_seed_value:
rng = np.random.default_rng()
# 2**32-1 is highest value for python RNG seed.
seed = int(rng.integers(low=0, high=2**32 - 1))
else:
return None
elif seed < -1:
raise ValueError("Invalid random number seed")
# Seed python RNG
@@ -50,6 +62,13 @@ def set_random_seed(seed: int) -> Union[None, int]:
return seed
def log_seed_value(seed: int):
"""Log the selected seed value to file."""
path = SIM_OUTPUT.path / "seed.log"
with open(path, "w") as file:
file.write(f"Seed value = {seed}")
class PrimaiteGymEnv(gymnasium.Env):
"""
Thin wrapper env to provide agents with a gymnasium API.
@@ -65,7 +84,8 @@ class PrimaiteGymEnv(gymnasium.Env):
"""Object that returns a config corresponding to the current episode."""
self.seed = self.episode_scheduler(0).get("game", {}).get("seed")
"""Get RNG seed from config file. NB: Must be before game instantiation."""
self.seed = set_random_seed(self.seed)
self.generate_seed_value = self.episode_scheduler(0).get("game", {}).get("generate_seed_value")
self.seed = set_random_seed(self.seed, self.generate_seed_value)
self.io = PrimaiteIO.from_config(self.episode_scheduler(0).get("io_settings", {}))
"""Handles IO for the environment. This produces sys logs, agent logs, etc."""
self.game: PrimaiteGame = PrimaiteGame.from_config(self.episode_scheduler(0))
@@ -79,6 +99,8 @@ class PrimaiteGymEnv(gymnasium.Env):
_LOGGER.info(f"PrimaiteGymEnv RNG seed = {self.seed}")
log_seed_value(self.seed)
def action_masks(self) -> np.ndarray:
"""
Return the action mask for the agent.
@@ -146,7 +168,7 @@ class PrimaiteGymEnv(gymnasium.Env):
f"avg. reward: {self.agent.reward_function.total_reward}"
)
if seed is not None:
set_random_seed(seed)
set_random_seed(seed, self.generate_seed_value)
self.total_reward_per_episode[self.episode_counter] = self.agent.reward_function.total_reward
if self.io.settings.save_agent_actions:

View File

@@ -864,7 +864,21 @@ class UserManager(Service, discriminator="user-manager"):
"""
rm = super()._init_request_manager()
# todo add doc about requeest schemas
# todo add doc about request schemas
rm.add_request(
"add_user",
RequestType(
func=lambda request, context: RequestResponse.from_bool(
self.add_user(username=request[0], password=request[1], is_admin=request[2])
)
),
)
rm.add_request(
"disable_user",
RequestType(
func=lambda request, context: RequestResponse.from_bool(self.disable_user(username=request[0]))
),
)
rm.add_request(
"change_password",
RequestType(
@@ -1572,7 +1586,7 @@ class Node(SimComponent, ABC):
operating_state: Any = None
users: Any = None # Temporary to appease "extra=forbid"
users: List[Dict] = [] # Temporary to appease "extra=forbid"
config: ConfigSchema = Field(default_factory=lambda: Node.ConfigSchema())
"""Configuration items within Node"""
@@ -1638,6 +1652,8 @@ class Node(SimComponent, ABC):
self._install_system_software()
self.session_manager.node = self
self.session_manager.software_manager = self.software_manager
for user in self.config.users:
self.user_manager.add_user(**user, bypass_can_perform_action=True)
@property
def user_manager(self) -> Optional[UserManager]:
@@ -1769,7 +1785,7 @@ class Node(SimComponent, ABC):
"""
application_name = request[0]
if self.software_manager.software.get(application_name):
self.sys_log.warning(f"Can't install {application_name}. It's already installed.")
self.sys_log.info(f"Can't install {application_name}. It's already installed.")
return RequestResponse(status="success", data={"reason": "already installed"})
application_class = Application._registry[application_name]
self.software_manager.install(application_class)

View File

@@ -2,11 +2,12 @@
from __future__ import annotations
from ipaddress import IPv4Address
from typing import Any, ClassVar, Dict, Literal, Optional
from typing import Any, ClassVar, Dict, List, Literal, Optional
from pydantic import Field
from primaite import getLogger
from primaite.simulator.file_system.file_type import FileType
from primaite.simulator.network.hardware.base import (
IPWiredNetworkInterface,
Link,
@@ -313,7 +314,7 @@ class HostNode(Node, discriminator="host-node"):
"""
SYSTEM_SOFTWARE: ClassVar[Dict] = {
"HostARP": HostARP,
"host-arp": HostARP,
"icmp": ICMP,
"dns-client": DNSClient,
"ntp-client": NTPClient,
@@ -339,7 +340,7 @@ class HostNode(Node, discriminator="host-node"):
ip_address: IPV4Address
services: Any = None # temporarily unset to appease extra="forbid"
applications: Any = None # temporarily unset to appease extra="forbid"
folders: Any = None # temporarily unset to appease extra="forbid"
folders: List[Dict] = {} # temporarily unset to appease extra="forbid"
network_interfaces: Any = None # temporarily unset to appease extra="forbid"
config: ConfigSchema = Field(default_factory=lambda: HostNode.ConfigSchema())
@@ -348,6 +349,18 @@ class HostNode(Node, discriminator="host-node"):
super().__init__(**kwargs)
self.connect_nic(NIC(ip_address=kwargs["config"].ip_address, subnet_mask=kwargs["config"].subnet_mask))
for folder in self.config.folders:
# handle empty foler defined by just a string
self.file_system.create_folder(folder["folder_name"])
for file in folder.get("files", []):
self.file_system.create_file(
folder_name=folder["folder_name"],
file_name=file["file_name"],
size=file.get("size", 0),
file_type=FileType[file.get("type", "UNKNOWN").upper()],
)
@property
def nmap(self) -> Optional[NMAP]:
"""

View File

@@ -49,7 +49,7 @@ class Firewall(Router, discriminator="firewall"):
Example:
>>> from primaite.simulator.network.transmission.network_layer import IPProtocol
>>> from primaite.simulator.network.transmission.transport_layer import Port
>>> from primaite.utils.validation.port import Port
>>> firewall = Firewall(hostname="Firewall1")
>>> firewall.configure_internal_port(ip_address="192.168.1.1", subnet_mask="255.255.255.0")
>>> firewall.configure_external_port(ip_address="10.0.0.1", subnet_mask="255.255.255.0")

View File

@@ -467,6 +467,7 @@ class AccessControlList(SimComponent):
"""Check if a packet with the given properties is permitted through the ACL."""
permitted = False
rule: ACLRule = None
for _rule in self._acl:
if not _rule:
continue
@@ -1215,9 +1216,9 @@ class Router(NetworkNode, discriminator="router"):
config: ConfigSchema = Field(default_factory=lambda: Router.ConfigSchema())
SYSTEM_SOFTWARE: ClassVar[Dict] = {
"UserSessionManager": UserSessionManager,
"UserManager": UserManager,
"Terminal": Terminal,
"user-session-manager": UserSessionManager,
"user-manager": UserManager,
"terminal": Terminal,
}
network_interfaces: Dict[str, RouterInterface] = {}
@@ -1385,6 +1386,12 @@ class Router(NetworkNode, discriminator="router"):
return False
def subject_to_acl(self, frame: Frame) -> bool:
"""Check that frame is subject to ACL rules."""
if frame.ip.protocol == "udp" and frame.is_arp:
return False
return True
def receive_frame(self, frame: Frame, from_network_interface: RouterInterface):
"""
Processes an incoming frame received on one of the router's interfaces.
@@ -1398,8 +1405,12 @@ class Router(NetworkNode, discriminator="router"):
if self.operating_state != NodeOperatingState.ON:
return
# Check if it's permitted
permitted, rule = self.acl.is_permitted(frame)
if self.subject_to_acl(frame=frame):
# Check if it's permitted
permitted, rule = self.acl.is_permitted(frame)
else:
permitted = True
rule = None
if not permitted:
at_port = self._get_port_of_nic(from_network_interface)

View File

@@ -163,7 +163,7 @@ class Frame(BaseModel):
"""
Checks if the Frame is an ARP (Address Resolution Protocol) packet.
This is determined by checking if the destination port of the TCP header is equal to the ARP port.
This is determined by checking if the destination and source port of the UDP header is equal to the ARP port.
:return: True if the Frame is an ARP packet, otherwise False.
"""

View File

@@ -55,7 +55,7 @@ class ARP(Service, discriminator="arp"):
:param markdown: If True, format the output as Markdown. Otherwise, use plain text.
"""
table = PrettyTable(["IP Address", "MAC Address", "Via"])
table = PrettyTable(["IP Address", "MAC Address", "Via", "Port"])
if markdown:
table.set_style(MARKDOWN)
table.align = "l"
@@ -66,6 +66,7 @@ class ARP(Service, discriminator="arp"):
str(ip),
arp.mac_address,
self.software_manager.node.network_interfaces[arp.network_interface_uuid].mac_address,
self.software_manager.node.network_interfaces[arp.network_interface_uuid].port_num,
]
)
print(table)

View File

@@ -142,12 +142,20 @@ class Terminal(Service, discriminator="terminal"):
_client_connection_requests: Dict[str, Optional[Union[str, TerminalClientConnection]]] = {}
"""Dictionary of connect requests made to remote nodes."""
_last_response: Optional[RequestResponse] = None
"""Last response received from RequestManager, for returning remote RequestResponse."""
def __init__(self, **kwargs):
kwargs["name"] = "terminal"
kwargs["port"] = PORT_LOOKUP["SSH"]
kwargs["protocol"] = PROTOCOL_LOOKUP["TCP"]
super().__init__(**kwargs)
@property
def last_response(self) -> Optional[RequestResponse]:
"""Public version of _last_response attribute."""
return self._last_response
def describe_state(self) -> Dict:
"""
Produce a dictionary describing the current state of this object.
@@ -186,7 +194,7 @@ class Terminal(Service, discriminator="terminal"):
return RequestResponse(status="failure", data={})
rm.add_request(
"node-session-remote-login",
"node_session_remote_login",
request_type=RequestType(func=_remote_login),
)
@@ -209,28 +217,45 @@ class Terminal(Service, discriminator="terminal"):
command: str = request[1]["command"]
remote_connection = self._get_connection_from_ip(ip_address=ip_address)
if remote_connection:
outcome = remote_connection.execute(command)
if outcome:
return RequestResponse(
status="success",
data={},
)
else:
return RequestResponse(
status="failure",
data={},
)
remote_connection.execute(command)
return self.last_response if not None else RequestResponse(status="failure", data={})
return RequestResponse(
status="failure",
data={"reason": "Failed to execute command."},
)
rm.add_request(
"send_remote_command",
request_type=RequestType(func=remote_execute_request),
)
def local_execute_request(request: RequestFormat, context: Dict) -> RequestResponse:
"""Executes a command using a local terminal session."""
command: str = request[2]["command"]
local_connection = self._process_local_login(username=request[0], password=request[1])
if local_connection:
outcome = local_connection.execute(command)
if outcome:
return RequestResponse(
status="success",
data={"reason": outcome},
)
return RequestResponse(
status="success",
data={"reason": "Local Terminal failed to resolve command. Potentially invalid credentials?"},
)
rm.add_request(
"send_local_command",
request_type=RequestType(func=local_execute_request),
)
return rm
def execute(self, command: List[Any]) -> Optional[RequestResponse]:
"""Execute a passed ssh command via the request manager."""
return self.parent.apply_request(command)
self._last_response = self.parent.apply_request(command)
return self._last_response
def _get_connection_from_ip(self, ip_address: IPv4Address) -> Optional[RemoteTerminalConnection]:
"""Find Remote Terminal Connection from a given IP."""
@@ -409,6 +434,8 @@ class Terminal(Service, discriminator="terminal"):
"""
source_ip = kwargs["frame"].ip.src_ip_address
self.sys_log.info(f"{self.name}: Received payload: {payload}. Source: {source_ip}")
self._last_response = None # Clear last response
if isinstance(payload, SSHPacket):
if payload.transport_message == SSHTransportMessage.SSH_MSG_USERAUTH_REQUEST:
# validate & add connection
@@ -457,6 +484,9 @@ class Terminal(Service, discriminator="terminal"):
session_id=session_id,
source_ip=source_ip,
)
self._last_response: RequestResponse = RequestResponse(
status="success", data={"reason": "Login Successful"}
)
elif payload.transport_message == SSHTransportMessage.SSH_MSG_SERVICE_REQUEST:
# Requesting a command to be executed
@@ -468,12 +498,32 @@ class Terminal(Service, discriminator="terminal"):
payload.connection_uuid
)
remote_session.last_active_step = self.software_manager.node.user_session_manager.current_timestep
self.execute(command)
self._last_response: RequestResponse = self.execute(command)
if self._last_response.status == "success":
transport_message = SSHTransportMessage.SSH_MSG_SERVICE_SUCCESS
else:
transport_message = SSHTransportMessage.SSH_MSG_SERVICE_FAILED
payload: SSHPacket = SSHPacket(
payload=self._last_response,
transport_message=transport_message,
connection_message=SSHConnectionMessage.SSH_MSG_CHANNEL_DATA,
)
self.software_manager.send_payload_to_session_manager(
payload=payload, dest_port=self.port, session_id=session_id
)
return True
else:
self.sys_log.error(
f"{self.name}: Connection UUID:{payload.connection_uuid} is not valid. Rejecting Command."
)
elif (
payload.transport_message == SSHTransportMessage.SSH_MSG_SERVICE_SUCCESS
or SSHTransportMessage.SSH_MSG_SERVICE_FAILED
):
# Likely receiving command ack from remote.
self._last_response = payload.payload
if isinstance(payload, dict) and payload.get("type"):
if payload["type"] == "disconnect":

View File

@@ -117,37 +117,44 @@ class WebServer(Service, discriminator="web-server"):
:type: payload: HttpRequestPacket
"""
response = HttpResponsePacket(status_code=HttpStatusCode.NOT_FOUND, payload=payload)
try:
parsed_url = urlparse(payload.request_url)
path = parsed_url.path.strip("/")
if len(path) < 1:
parsed_url = urlparse(payload.request_url)
path = parsed_url.path.strip("/") if parsed_url and parsed_url.path else ""
if len(path) < 1:
# query succeeded
response.status_code = HttpStatusCode.OK
if path.startswith("users"):
# get data from DatabaseServer
# get all users
if not self._establish_db_connection():
# unable to create a db connection
response.status_code = HttpStatusCode.INTERNAL_SERVER_ERROR
return response
if self.db_connection.query("SELECT"):
# query succeeded
self.set_health_state(SoftwareHealthState.GOOD)
response.status_code = HttpStatusCode.OK
else:
self.set_health_state(SoftwareHealthState.COMPROMISED)
return response
if path.startswith("users"):
# get data from DatabaseServer
# get all users
if not self.db_connection:
self._establish_db_connection()
if self.db_connection.query("SELECT"):
# query succeeded
self.set_health_state(SoftwareHealthState.GOOD)
response.status_code = HttpStatusCode.OK
else:
self.set_health_state(SoftwareHealthState.COMPROMISED)
return response
except Exception: # TODO: refactor this. Likely to cause silent bugs. (ADO ticket #2345 )
# something went wrong on the server
response.status_code = HttpStatusCode.INTERNAL_SERVER_ERROR
return response
def _establish_db_connection(self) -> None:
def _establish_db_connection(self) -> bool:
"""Establish a connection to db."""
# if active db connection, return true
if self.db_connection:
return True
# otherwise, try to create db connection
db_client = self.software_manager.software.get("database-client")
if db_client is None:
return False # database client not installed
self.db_connection: DatabaseClientConnection = db_client.get_new_connection()
return self.db_connection is not None
def send(
self,

View File

@@ -25,7 +25,19 @@ game:
- ICMP
- TCP
- UDP
thresholds:
nmne:
high: 100
medium: 25
low: 5
file_access:
high: 10
medium: 5
low: 2
app_executions:
high: 5
medium: 3
low: 2
agents:
- ref: client_2_green_user
team: GREEN
@@ -64,10 +76,16 @@ agents:
options:
hosts:
- hostname: client_1
applications:
- application_name: WebBrowser
folders:
- folder_name: root
files:
- file_name: "test.txt"
- hostname: client_2
- hostname: client_3
num_services: 1
num_applications: 0
num_applications: 1
num_folders: 1
num_files: 1
num_nics: 2
@@ -182,6 +200,10 @@ simulation:
options:
ntp_server_ip: 192.168.1.10
- type: ntp-server
folders:
- folder_name: root
files:
- file_name: test.txt
- hostname: client_2
type: computer
ip_address: 192.168.10.22

View File

@@ -0,0 +1,226 @@
# Basic Switched network
#
# -------------- -------------- --------------
# | client_1 |------| switch_1 |------| client_2 |
# -------------- -------------- --------------
#
io_settings:
save_step_metadata: false
save_pcap_logs: true
save_sys_logs: true
sys_log_level: WARNING
agent_log_level: INFO
save_agent_logs: true
write_agent_log_to_terminal: True
game:
max_episode_length: 256
ports:
- ARP
- DNS
- HTTP
- POSTGRES_SERVER
protocols:
- ICMP
- TCP
- UDP
agents:
- ref: client_2_green_user
team: GREEN
type: periodic-agent
action_space:
action_map:
0:
action: do-nothing
options: {}
1:
action: node-application-execute
options:
node_id: 0
application_id: 0
agent_settings:
possible_start_nodes: [client_2,]
target_application: web-browser
start_step: 5
frequency: 4
variance: 3
- ref: defender
team: BLUE
type: proxy-agent
observation_space:
type: custom
options:
components:
- type: nodes
label: NODES
options:
hosts:
- hostname: client_1
- hostname: client_2
- hostname: client_3
num_services: 1
num_applications: 0
num_folders: 1
num_files: 1
num_nics: 2
include_num_access: false
monitored_traffic:
icmp:
- NONE
tcp:
- DNS
include_nmne: false
routers:
- hostname: router_1
num_ports: 0
ip_list:
- 192.168.10.21
- 192.168.10.22
- 192.168.10.23
wildcard_list:
- 0.0.0.1
port_list:
- 80
- 5432
protocol_list:
- ICMP
- TCP
- UDP
num_rules: 10
- type: links
label: LINKS
options:
link_references:
- switch_1:eth-1<->client_1:eth-1
- switch_1:eth-2<->client_2:eth-1
- type: none
label: ICS
options: {}
action_space:
action_map:
0:
action: do-nothing
options: {}
reward_function:
reward_components:
- type: database-file-integrity
weight: 0.5
options:
node_hostname: database_server
folder_name: database
file_name: database.db
- type: web-server-404-penalty
weight: 0.5
options:
node_hostname: web_server
service_name: web_server_web_service
agent_settings:
flatten_obs: true
simulation:
network:
nodes:
- type: switch
hostname: switch_1
num_ports: 8
- hostname: client_1
type: computer
ip_address: 192.168.10.21
subnet_mask: 255.255.255.0
default_gateway: 192.168.10.1
dns_server: 192.168.1.10
applications:
- type: ransomware-script
- type: web-browser
options:
target_url: http://arcd.com/users/
- type: database-client
options:
db_server_ip: 192.168.1.10
server_password: arcd
- type: data-manipulation-bot
options:
port_scan_p_of_success: 0.8
data_manipulation_p_of_success: 0.8
payload: "DELETE"
server_ip: 192.168.1.21
server_password: arcd
- type: dos-bot
options:
target_ip_address: 192.168.10.21
payload: SPOOF DATA
port_scan_p_of_success: 0.8
services:
- type: dns-client
options:
dns_server: 192.168.1.10
- type: dns-server
options:
domain_mapping:
arcd.com: 192.168.1.10
- type: database-service
options:
backup_server_ip: 192.168.1.10
- type: web-server
- type: ftp-server
options:
server_password: arcd
- type: ntp-client
options:
ntp_server_ip: 192.168.1.10
- type: ntp-server
- hostname: client_2
type: computer
ip_address: 192.168.10.22
subnet_mask: 255.255.255.0
default_gateway: 192.168.10.1
dns_server: 192.168.1.10
folders:
- folder_name: empty_folder
- folder_name: downloads
files:
- file_name: "test.txt"
- file_name: "another_file.pwtwoti"
- folder_name: root
files:
- file_name: passwords
size: 663
type: TXT
# pre installed services and applications
- hostname: client_3
type: computer
ip_address: 192.168.10.23
subnet_mask: 255.255.255.0
default_gateway: 192.168.10.1
dns_server: 192.168.1.10
start_up_duration: 0
shut_down_duration: 0
operating_state: "OFF"
# pre installed services and applications
links:
- endpoint_a_hostname: switch_1
endpoint_a_port: 1
endpoint_b_hostname: client_1
endpoint_b_port: 1
bandwidth: 200
- endpoint_a_hostname: switch_1
endpoint_a_port: 2
endpoint_b_hostname: client_2
endpoint_b_port: 1
bandwidth: 200

View File

@@ -8,7 +8,7 @@ from primaite.config.load import data_manipulation_config_path
from primaite.game.game import PrimaiteGame
from tests import TEST_ASSETS_ROOT
BASIC_CONFIG = TEST_ASSETS_ROOT / "configs/basic_switched_network.yaml"
BASIC_SWITCHED_NETWORK_CONFIG = TEST_ASSETS_ROOT / "configs/basic_switched_network.yaml"
def load_config(config_path: Union[str, Path]) -> PrimaiteGame:
@@ -24,3 +24,42 @@ def test_thresholds():
game = load_config(data_manipulation_config_path())
assert game.options.thresholds is not None
def test_nmne_threshold():
"""Test that the NMNE thresholds are properly loaded in by observation."""
game = load_config(BASIC_SWITCHED_NETWORK_CONFIG)
assert game.options.thresholds["nmne"] is not None
# get NIC observation
nic_obs = game.agents["defender"].observation_manager.obs.components["NODES"].hosts[0].nics[0]
assert nic_obs.low_nmne_threshold == 5
assert nic_obs.med_nmne_threshold == 25
assert nic_obs.high_nmne_threshold == 100
def test_file_access_threshold():
"""Test that the NMNE thresholds are properly loaded in by observation."""
game = load_config(BASIC_SWITCHED_NETWORK_CONFIG)
assert game.options.thresholds["file_access"] is not None
# get file observation
file_obs = game.agents["defender"].observation_manager.obs.components["NODES"].hosts[0].folders[0].files[0]
assert file_obs.low_file_access_threshold == 2
assert file_obs.med_file_access_threshold == 5
assert file_obs.high_file_access_threshold == 10
def test_app_executions_threshold():
"""Test that the NMNE thresholds are properly loaded in by observation."""
game = load_config(BASIC_SWITCHED_NETWORK_CONFIG)
assert game.options.thresholds["app_executions"] is not None
# get application observation
app_obs = game.agents["defender"].observation_manager.obs.components["NODES"].hosts[0].applications[0]
assert app_obs.low_app_execution_threshold == 2
assert app_obs.med_app_execution_threshold == 3
assert app_obs.high_app_execution_threshold == 5

View File

@@ -0,0 +1,64 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
from pathlib import Path
from typing import Union
import yaml
from primaite.game.game import PrimaiteGame
from primaite.simulator.file_system.file_type import FileType
from tests import TEST_ASSETS_ROOT
BASIC_CONFIG = TEST_ASSETS_ROOT / "configs/nodes_with_initial_files.yaml"
def load_config(config_path: Union[str, Path]) -> PrimaiteGame:
"""Returns a PrimaiteGame object which loads the contents of a given yaml path."""
with open(config_path, "r") as f:
cfg = yaml.safe_load(f)
return PrimaiteGame.from_config(cfg)
def test_node_file_system_from_config():
"""Test that the appropriate files are instantiated in nodes when loaded from config."""
game = load_config(BASIC_CONFIG)
client_1 = game.simulation.network.get_node_by_hostname("client_1")
assert client_1.software_manager.software.get("database-service") # database service should be installed
assert client_1.file_system.get_file(folder_name="database", file_name="database.db") # database files should exist
assert client_1.software_manager.software.get("web-server") # web server should be installed
assert client_1.file_system.get_file(folder_name="primaite", file_name="index.html") # web files should exist
client_2 = game.simulation.network.get_node_by_hostname("client_2")
# database service should not be installed
assert client_2.software_manager.software.get("database-service") is None
# database files should not exist
assert client_2.file_system.get_file(folder_name="database", file_name="database.db") is None
# web server should not be installed
assert client_2.software_manager.software.get("web-server") is None
# web files should not exist
assert client_2.file_system.get_file(folder_name="primaite", file_name="index.html") is None
empty_folder = client_2.file_system.get_folder(folder_name="empty_folder")
assert empty_folder
assert len(empty_folder.files) == 0 # should have no files
password_file = client_2.file_system.get_file(folder_name="root", file_name="passwords.txt")
assert password_file # should exist
assert password_file.file_type is FileType.TXT
assert password_file.size == 663
downloads_folder = client_2.file_system.get_folder(folder_name="downloads")
assert downloads_folder # downloads folder should exist
test_txt = downloads_folder.get_file(file_name="test.txt")
assert test_txt # test.txt should exist
assert test_txt.file_type is FileType.TXT
unknown_file_type = downloads_folder.get_file(file_name="another_file.pwtwoti")
assert unknown_file_type # unknown_file_type should exist
assert unknown_file_type.file_type is FileType.UNKNOWN

View File

@@ -49,7 +49,7 @@ class GigaSwitch(NetworkNode, discriminator="gigaswitch"):
if markdown:
table.set_style(MARKDOWN)
table.align = "l"
table.title = f"{self.hostname} Switch Ports"
table.title = f"{self.config.hostname} Switch Ports"
for port_num, port in self.network_interface.items():
table.add_row([port_num, port.mac_address, port.speed, "Enabled" if port.enabled else "Disabled"])
print(table)

View File

@@ -106,7 +106,6 @@ def test_remote_login_change_password(game_and_agent_fixture: Tuple[PrimaiteGame
"username": "user123",
"current_password": "password",
"new_password": "different_password",
"remote_ip": str(server_1.network_interface[1].ip_address),
},
)
agent.store_action(action)
@@ -146,7 +145,6 @@ def test_change_password_logs_out_user(game_and_agent_fixture: Tuple[PrimaiteGam
"username": "user123",
"current_password": "password",
"new_password": "different_password",
"remote_ip": str(server_1.network_interface[1].ip_address),
},
)
agent.store_action(action)
@@ -166,3 +164,55 @@ def test_change_password_logs_out_user(game_and_agent_fixture: Tuple[PrimaiteGam
assert server_1.file_system.get_folder("folder123") is None
assert server_1.file_system.get_file("folder123", "doggo.pdf") is None
def test_local_terminal(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAgent]):
game, agent = game_and_agent_fixture
client_1 = game.simulation.network.get_node_by_hostname("client_1")
# create a new user account on server_1 that will be logged into remotely
client_1_usm: UserManager = client_1.software_manager.software["user-manager"]
client_1_usm.add_user("user123", "password", is_admin=True)
action = (
"node-send-local-command",
{
"node_name": "client_1",
"username": "user123",
"password": "password",
"command": ["file_system", "create", "file", "folder123", "doggo.pdf", False],
},
)
agent.store_action(action)
game.step()
assert client_1.file_system.get_folder("folder123")
assert client_1.file_system.get_file("folder123", "doggo.pdf")
# Change password
action = (
"node-account-change-password",
{
"node_name": "client_1",
"username": "user123",
"current_password": "password",
"new_password": "different_password",
},
)
agent.store_action(action)
game.step()
action = (
"node-send-local-command",
{
"node_name": "client_1",
"username": "user123",
"password": "password",
"command": ["file_system", "create", "file", "folder123", "cat.pdf", False],
},
)
agent.store_action(action)
game.step()
assert client_1.file_system.get_file("folder123", "cat.pdf") is None
client_1.session_manager.show()

View File

@@ -0,0 +1,176 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
import pytest
from primaite.simulator.network.hardware.nodes.host.computer import Computer
from primaite.simulator.network.hardware.nodes.network.router import ACLAction
from primaite.utils.validation.port import Port, PORT_LOOKUP
@pytest.fixture
def game_and_agent_fixture(game_and_agent):
"""Create a game with a simple agent that can be controlled by the tests."""
game, agent = game_and_agent
client_1: Computer = game.simulation.network.get_node_by_hostname("client_1")
client_1.start_up_duration = 3
return (game, agent)
def test_user_account_add_user_action(game_and_agent_fixture):
"""Tests the add user account action."""
game, agent = game_and_agent_fixture
client_1 = game.simulation.network.get_node_by_hostname("client_1")
assert len(client_1.user_manager.users) == 1 # admin is created by default
assert len(client_1.user_manager.admins) == 1
# add admin account
action = (
"node-account-add-user",
{"node_name": "client_1", "username": "admin_2", "password": "e-tronic-boogaloo", "is_admin": True},
)
agent.store_action(action)
game.step()
assert len(client_1.user_manager.users) == 2 # new user added
assert len(client_1.user_manager.admins) == 2
# add non admin account
action = (
"node-account-add-user",
{"node_name": "client_1", "username": "leeroy.jenkins", "password": "no_plan_needed", "is_admin": False},
)
agent.store_action(action)
game.step()
assert len(client_1.user_manager.users) == 3 # new user added
assert len(client_1.user_manager.admins) == 2
def test_user_account_disable_user_action(game_and_agent_fixture):
"""Tests the disable user account action."""
game, agent = game_and_agent_fixture
client_1 = game.simulation.network.get_node_by_hostname("client_1")
client_1.user_manager.add_user(username="test", password="password", is_admin=True)
assert len(client_1.user_manager.users) == 2 # new user added
assert len(client_1.user_manager.admins) == 2
test_user = client_1.user_manager.users.get("test")
assert test_user
assert test_user.disabled is not True
# disable test account
action = (
"node-account-disable-user",
{
"node_name": "client_1",
"username": "test",
},
)
agent.store_action(action)
game.step()
assert test_user.disabled
def test_user_account_change_password_action(game_and_agent_fixture):
"""Tests the change password user account action."""
game, agent = game_and_agent_fixture
client_1 = game.simulation.network.get_node_by_hostname("client_1")
client_1.user_manager.add_user(username="test", password="password", is_admin=True)
test_user = client_1.user_manager.users.get("test")
assert test_user.password == "password"
# change account password
action = (
"node-account-change-password",
{"node_name": "client_1", "username": "test", "current_password": "password", "new_password": "2Hard_2_Hack"},
)
agent.store_action(action)
game.step()
assert test_user.password == "2Hard_2_Hack"
def test_user_account_create_terminal_action(game_and_agent_fixture):
"""Tests that agents can use the terminal to create new users."""
game, agent = game_and_agent_fixture
router = game.simulation.network.get_node_by_hostname("router")
router.acl.add_rule(action=ACLAction.PERMIT, src_port=PORT_LOOKUP["SSH"], dst_port=PORT_LOOKUP["SSH"], position=4)
server_1 = game.simulation.network.get_node_by_hostname("server_1")
server_1_usm = server_1.software_manager.software["user-manager"]
server_1_usm.add_user("user123", "password", is_admin=True)
action = (
"node-session-remote-login",
{
"node_name": "client_1",
"username": "user123",
"password": "password",
"remote_ip": str(server_1.network_interface[1].ip_address),
},
)
agent.store_action(action)
game.step()
assert agent.history[-1].response.status == "success"
# Create a new user account via terminal.
action = (
"node-send-remote-command",
{
"node_name": "client_1",
"remote_ip": str(server_1.network_interface[1].ip_address),
"command": ["service", "user-manager", "add_user", "new_user", "new_pass", True],
},
)
agent.store_action(action)
game.step()
new_user = server_1.user_manager.users.get("new_user")
assert new_user
assert new_user.password == "new_pass"
assert new_user.disabled is not True
def test_user_account_disable_terminal_action(game_and_agent_fixture):
"""Tests that agents can use the terminal to disable users."""
game, agent = game_and_agent_fixture
router = game.simulation.network.get_node_by_hostname("router")
router.acl.add_rule(action=ACLAction.PERMIT, src_port=PORT_LOOKUP["SSH"], dst_port=PORT_LOOKUP["SSH"], position=4)
server_1 = game.simulation.network.get_node_by_hostname("server_1")
server_1_usm = server_1.software_manager.software["user-manager"]
server_1_usm.add_user("user123", "password", is_admin=True)
action = (
"node-session-remote-login",
{
"node_name": "client_1",
"username": "user123",
"password": "password",
"remote_ip": str(server_1.network_interface[1].ip_address),
},
)
agent.store_action(action)
game.step()
assert agent.history[-1].response.status == "success"
# Disable a user via terminal
action = (
"node-send-remote-command",
{
"node_name": "client_1",
"remote_ip": str(server_1.network_interface[1].ip_address),
"command": ["service", "user-manager", "disable_user", "user123"],
},
)
agent.store_action(action)
game.step()
new_user = server_1.user_manager.users.get("user123")
assert new_user
assert new_user.disabled is True

View File

@@ -44,6 +44,38 @@ def test_file_observation(simulation):
assert observation_state.get("health_status") == 3 # corrupted
def test_config_file_access_categories(simulation):
pc: Computer = simulation.network.get_node_by_hostname("client_1")
file_obs = FileObservation(
where=["network", "nodes", pc.config.hostname, "file_system", "folders", "root", "files", "dog.png"],
include_num_access=False,
file_system_requires_scan=True,
thresholds={"file_access": {"low": 3, "medium": 6, "high": 9}},
)
assert file_obs.high_file_access_threshold == 9
assert file_obs.med_file_access_threshold == 6
assert file_obs.low_file_access_threshold == 3
with pytest.raises(Exception):
# should throw an error
FileObservation(
where=["network", "nodes", pc.config.hostname, "file_system", "folders", "root", "files", "dog.png"],
include_num_access=False,
file_system_requires_scan=True,
thresholds={"file_access": {"low": 9, "medium": 6, "high": 9}},
)
with pytest.raises(Exception):
# should throw an error
FileObservation(
where=["network", "nodes", pc.config.hostname, "file_system", "folders", "root", "files", "dog.png"],
include_num_access=False,
file_system_requires_scan=True,
thresholds={"file_access": {"low": 3, "medium": 9, "high": 9}},
)
def test_folder_observation(simulation):
"""Test the folder observation."""
pc: Computer = simulation.network.get_node_by_hostname("client_1")

View File

@@ -77,6 +77,14 @@ def test_nic(simulation):
nic_obs = NICObservation(where=["network", "nodes", pc.config.hostname, "NICs", 1], include_nmne=True)
# The Simulation object created by the fixture also creates the
# NICObservation class with the NICObservation.capture_nmnme class variable
# set to False. Under normal (non-test) circumstances this class variable
# is set from a config file such as data_manipulation.yaml. So although
# capture_nmne is set to True in the NetworkInterface class it's still False
# in the NICObservation class so we set it now.
nic_obs.capture_nmne = True
# Set the NMNE configuration to capture DELETE/ENCRYPT queries as MNEs
nmne_config = {
"capture_nmne": True, # Enable the capture of MNEs
@@ -115,14 +123,11 @@ def test_nic_categories(simulation):
assert nic_obs.low_nmne_threshold == 0 # default
@pytest.mark.skip(reason="Feature not implemented yet")
def test_config_nic_categories(simulation):
pc: Computer = simulation.network.get_node_by_hostname("client_1")
nic_obs = NICObservation(
where=["network", "nodes", pc.hostname, "NICs", 1],
low_nmne_threshold=3,
med_nmne_threshold=6,
high_nmne_threshold=9,
where=["network", "nodes", pc.config.hostname, "NICs", 1],
thresholds={"nmne": {"low": 3, "medium": 6, "high": 9}},
include_nmne=True,
)
@@ -133,20 +138,16 @@ def test_config_nic_categories(simulation):
with pytest.raises(Exception):
# should throw an error
NICObservation(
where=["network", "nodes", pc.hostname, "NICs", 1],
low_nmne_threshold=9,
med_nmne_threshold=6,
high_nmne_threshold=9,
where=["network", "nodes", pc.config.hostname, "NICs", 1],
thresholds={"nmne": {"low": 9, "medium": 6, "high": 9}},
include_nmne=True,
)
with pytest.raises(Exception):
# should throw an error
NICObservation(
where=["network", "nodes", pc.hostname, "NICs", 1],
low_nmne_threshold=3,
med_nmne_threshold=9,
high_nmne_threshold=9,
where=["network", "nodes", pc.config.hostname, "NICs", 1],
thresholds={"nmne": {"low": 3, "medium": 9, "high": 9}},
include_nmne=True,
)

View File

@@ -39,6 +39,8 @@ def test_host_observation(simulation):
folders=[],
network_interfaces=[],
file_system_requires_scan=True,
services_requires_scan=True,
applications_requires_scan=True,
include_users=False,
)

View File

@@ -0,0 +1,28 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
import json
from primaite.session.environment import PrimaiteGymEnv
from primaite.session.io import PrimaiteIO
from tests import TEST_ASSETS_ROOT
DATA_MANIPULATION_CONFIG = TEST_ASSETS_ROOT / "configs" / "data_manipulation.yaml"
def test_obs_data_in_log_file():
"""Create a log file of AgentHistoryItems and check observation data is
included. Assumes that data_manipulation.yaml has an agent labelled
'defender' with a non-null observation space.
The log file will be in:
primaite/VERSION/sessions/YYYY-MM-DD/HH-MM-SS/agent_actions
"""
env = PrimaiteGymEnv(DATA_MANIPULATION_CONFIG)
env.reset()
for _ in range(10):
env.step(0)
env.reset()
io = PrimaiteIO()
path = io.generate_agent_actions_save_path(episode=1)
with open(path, "r") as f:
j = json.load(f)
assert type(j["0"]["defender"]["observation"]) == dict

View File

@@ -29,7 +29,9 @@ def test_service_observation(simulation):
ntp_server = pc.software_manager.software.get("ntp-server")
assert ntp_server
service_obs = ServiceObservation(where=["network", "nodes", pc.config.hostname, "services", "ntp-server"])
service_obs = ServiceObservation(
where=["network", "nodes", pc.config.hostname, "services", "ntp-server"], services_requires_scan=True
)
assert service_obs.space["operating_status"] == spaces.Discrete(7)
assert service_obs.space["health_status"] == spaces.Discrete(5)
@@ -54,7 +56,9 @@ def test_application_observation(simulation):
web_browser: WebBrowser = pc.software_manager.software.get("web-browser")
assert web_browser
app_obs = ApplicationObservation(where=["network", "nodes", pc.config.hostname, "applications", "web-browser"])
app_obs = ApplicationObservation(
where=["network", "nodes", pc.config.hostname, "applications", "web-browser"], applications_requires_scan=True
)
web_browser.close()
observation_state = app_obs.observe(simulation.describe_state())
@@ -69,3 +73,33 @@ def test_application_observation(simulation):
assert observation_state.get("health_status") == 1
assert observation_state.get("operating_status") == 1 # running
assert observation_state.get("num_executions") == 1
def test_application_executions_categories(simulation):
pc: Computer = simulation.network.get_node_by_hostname("client_1")
app_obs = ApplicationObservation(
where=["network", "nodes", pc.config.hostname, "applications", "WebBrowser"],
applications_requires_scan=False,
thresholds={"app_executions": {"low": 3, "medium": 6, "high": 9}},
)
assert app_obs.high_app_execution_threshold == 9
assert app_obs.med_app_execution_threshold == 6
assert app_obs.low_app_execution_threshold == 3
with pytest.raises(Exception):
# should throw an error
ApplicationObservation(
where=["network", "nodes", pc.config.hostname, "applications", "WebBrowser"],
applications_requires_scan=False,
thresholds={"app_executions": {"low": 9, "medium": 6, "high": 9}},
)
with pytest.raises(Exception):
# should throw an error
ApplicationObservation(
where=["network", "nodes", pc.config.hostname, "applications", "WebBrowser"],
applications_requires_scan=False,
thresholds={"app_executions": {"low": 3, "medium": 9, "high": 9}},
)

View File

@@ -7,6 +7,7 @@ import yaml
from primaite.config.load import data_manipulation_config_path
from primaite.game.agent.interface import AgentHistoryItem
from primaite.session.environment import PrimaiteGymEnv
from primaite.simulator import SIM_OUTPUT
@pytest.fixture()
@@ -33,6 +34,11 @@ def test_rng_seed_set(create_env):
assert a == b
# Check that seed log file was created.
path = SIM_OUTPUT.path / "seed.log"
with open(path, "r") as file:
assert file
def test_rng_seed_unset(create_env):
"""Test with no RNG seed."""
@@ -48,3 +54,19 @@ def test_rng_seed_unset(create_env):
b = [item.timestep for item in env.game.agents["client_2_green_user"].history if item.action != "do-nothing"]
assert a != b
def test_for_generated_seed():
"""
Show that setting generate_seed_value to true producess a valid seed.
"""
with open(data_manipulation_config_path(), "r") as f:
cfg = yaml.safe_load(f)
cfg["game"]["generate_seed_value"] = True
PrimaiteGymEnv(env_config=cfg)
path = SIM_OUTPUT.path / "seed.log"
with open(path, "r") as file:
data = file.read()
assert data.split(" ")[3] != None

View File

@@ -22,6 +22,7 @@ from primaite.game.game import PrimaiteGame
from primaite.session.environment import PrimaiteGymEnv
from primaite.simulator.file_system.file_system_item_abc import FileSystemItemHealthStatus
from primaite.simulator.network.hardware.nodes.network.firewall import Firewall
from primaite.simulator.network.hardware.nodes.network.router import Router
from primaite.simulator.system.applications.application import ApplicationOperatingState
from primaite.simulator.system.applications.web_browser import WebBrowser
from primaite.simulator.system.software import SoftwareHealthState
@@ -107,7 +108,7 @@ def test_router_acl_addrule_integration(game_and_agent: Tuple[PrimaiteGame, Prox
"""
Test that the RouterACLAddRuleAction can form a request and that it is accepted by the simulation.
The acl starts off with 4 rules, and we add a rule, and check that the acl now has 5 rules.
The ACL starts off with 4 rules, and we add a rule, and check that the ACL now has 5 rules.
"""
game, agent = game_and_agent
@@ -164,11 +165,9 @@ def test_router_acl_addrule_integration(game_and_agent: Tuple[PrimaiteGame, Prox
},
)
agent.store_action(action)
print(agent.most_recent_action)
game.step()
print(agent.most_recent_action)
# 5: Check that the ACL now has 6 rules, but that server_1 can still ping server_2
print(router.acl.show())
assert router.acl.num_rules == 6
assert server_1.ping("10.0.2.3") # Can ping server_2
@@ -180,7 +179,8 @@ def test_router_acl_removerule_integration(game_and_agent: Tuple[PrimaiteGame, P
# 1: Check that http traffic is going across the network nicely.
client_1 = game.simulation.network.get_node_by_hostname("client_1")
server_1 = game.simulation.network.get_node_by_hostname("server_1")
router = game.simulation.network.get_node_by_hostname("router")
router: Router = game.simulation.network.get_node_by_hostname("router")
assert router.acl.num_rules == 4
browser: WebBrowser = client_1.software_manager.software.get("web-browser")
browser.run()

View File

@@ -1,5 +1,11 @@
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
from itertools import product
import yaml
from primaite.config.load import data_manipulation_config_path
from primaite.game.agent.observations.nic_observations import NICObservation
from primaite.session.environment import PrimaiteGymEnv
from primaite.simulator.network.container import Network
from primaite.simulator.network.hardware.nodes.host.host_node import NIC
from primaite.simulator.network.hardware.nodes.host.server import Server
@@ -277,3 +283,19 @@ def test_capture_nmne_observations(uc2_network: Network):
assert web_nic_obs["outbound"] == expected_nmne
assert db_nic_obs["inbound"] == expected_nmne
uc2_network.apply_timestep(timestep=0)
def test_nmne_parameter_settings():
"""
Check that the four permutations of the values of capture_nmne and
include_nmne work as expected.
"""
with open(data_manipulation_config_path(), "r") as f:
cfg = yaml.safe_load(f)
DEFENDER = 3
for capture, include in product([True, False], [True, False]):
cfg["simulation"]["network"]["nmne_config"]["capture_nmne"] = capture
cfg["agents"][DEFENDER]["observation_space"]["options"]["components"][0]["options"]["include_nmne"] = include
PrimaiteGymEnv(env_config=cfg)

View File

@@ -1,6 +1,7 @@
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
from primaite.simulator.network.hardware.nodes.network.router import RouterARP
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router, RouterARP
from primaite.simulator.system.services.arp.arp import ARP
from primaite.utils.validation.port import PORT_LOOKUP
from tests.integration_tests.network.test_routing import multi_hop_network
@@ -48,3 +49,19 @@ def test_arp_fails_for_network_address_between_routers(multi_hop_network):
actual_result = router_1_arp.get_arp_cache_mac_address(router_1.network_interface[1].ip_network.network_address)
assert actual_result == expected_result
def test_arp_not_affected_by_acl(multi_hop_network):
pc_a = multi_hop_network.get_node_by_hostname("pc_a")
router_1: Router = multi_hop_network.get_node_by_hostname("router_1")
# Add explicit rule to block ARP traffic. This shouldn't actually stop ARP traffic
# as it operates a different layer within the network.
router_1.acl.add_rule(action=ACLAction.DENY, src_port=PORT_LOOKUP["ARP"], dst_port=PORT_LOOKUP["ARP"], position=23)
pc_a_arp: ARP = pc_a.software_manager.arp
expected_result = router_1.network_interface[2].mac_address
actual_result = pc_a_arp.get_arp_cache_mac_address(router_1.network_interface[2].ip_address)
assert actual_result == expected_result

View File

@@ -1,10 +1,11 @@
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
import json
from typing import List
import pytest
import yaml
from primaite.game.agent.observations import ObservationManager
from primaite.game.agent.observations import ApplicationObservation, ObservationManager, ServiceObservation
from primaite.game.agent.observations.file_system_observations import FileObservation, FolderObservation
from primaite.game.agent.observations.host_observations import HostObservation
@@ -136,3 +137,227 @@ class TestFileSystemRequiresScan:
[], files=[], num_files=0, include_num_access=False, file_system_requires_scan=True
)
assert obs_requiring_scan.observe(folder_state)["health_status"] == 1
class TestServicesRequiresScan:
@pytest.mark.parametrize(
("yaml_option_string", "expected_val"),
(
("services_requires_scan: true", True),
("services_requires_scan: false", False),
(" ", True),
),
)
def test_obs_config(self, yaml_option_string, expected_val):
"""Check that the default behaviour is to set service_requires_scan to True."""
obs_cfg_yaml = f"""
type: custom
options:
components:
- type: nodes
label: NODES
options:
hosts:
- hostname: domain_controller
- hostname: web_server
services:
- service_name: web-server
- service_name: dns-client
- hostname: database_server
folders:
- folder_name: database
files:
- file_name: database.db
- hostname: backup_server
services:
- service_name: ftp-server
- hostname: security_suite
- hostname: client_1
- hostname: client_2
num_services: 3
num_applications: 0
num_folders: 1
num_files: 1
num_nics: 2
include_num_access: false
{yaml_option_string}
include_nmne: true
monitored_traffic:
icmp:
- NONE
tcp:
- DNS
routers:
- hostname: router_1
num_ports: 0
ip_list:
- 192.168.1.10
- 192.168.1.12
- 192.168.1.14
- 192.168.1.16
- 192.168.1.110
- 192.168.10.21
- 192.168.10.22
- 192.168.10.110
wildcard_list:
- 0.0.0.1
port_list:
- 80
- 5432
protocol_list:
- ICMP
- TCP
- UDP
num_rules: 10
- type: links
label: LINKS
options:
link_references:
- router_1:eth-1<->switch_1:eth-8
- router_1:eth-2<->switch_2:eth-8
- switch_1:eth-1<->domain_controller:eth-1
- switch_1:eth-2<->web_server:eth-1
- switch_1:eth-3<->database_server:eth-1
- switch_1:eth-4<->backup_server:eth-1
- switch_1:eth-7<->security_suite:eth-1
- switch_2:eth-1<->client_1:eth-1
- switch_2:eth-2<->client_2:eth-1
- switch_2:eth-7<->security_suite:eth-2
- type: none
label: ICS
options: {{}}
"""
cfg = yaml.safe_load(obs_cfg_yaml)
manager = ObservationManager.from_config(cfg)
hosts: List[HostObservation] = manager.obs.components["NODES"].hosts
for i, host in enumerate(hosts):
services: List[ServiceObservation] = host.services
for j, service in enumerate(services):
val = service.services_requires_scan
print(f"host {i} service {j} {val}")
assert val == expected_val # Make sure services require scan by default
def test_services_requires_scan(self):
state = {"health_state_actual": 3, "health_state_visible": 1, "operating_state": 1}
obs_requiring_scan = ServiceObservation([], services_requires_scan=True)
assert obs_requiring_scan.observe(state)["health_status"] == 1 # should be visible value
obs_not_requiring_scan = ServiceObservation([], services_requires_scan=False)
assert obs_not_requiring_scan.observe(state)["health_status"] == 3 # should be actual value
class TestApplicationsRequiresScan:
@pytest.mark.parametrize(
("yaml_option_string", "expected_val"),
(
("applications_requires_scan: true", True),
("applications_requires_scan: false", False),
(" ", True),
),
)
def test_obs_config(self, yaml_option_string, expected_val):
"""Check that the default behaviour is to set applications_requires_scan to True."""
obs_cfg_yaml = f"""
type: custom
options:
components:
- type: nodes
label: NODES
options:
hosts:
- hostname: domain_controller
- hostname: web_server
- hostname: database_server
folders:
- folder_name: database
files:
- file_name: database.db
- hostname: backup_server
- hostname: security_suite
- hostname: client_1
applications:
- application_name: web-browser
- hostname: client_2
applications:
- application_name: web-browser
- application_name: database-client
num_services: 0
num_applications: 3
num_folders: 1
num_files: 1
num_nics: 2
include_num_access: false
{yaml_option_string}
include_nmne: true
monitored_traffic:
icmp:
- NONE
tcp:
- DNS
routers:
- hostname: router_1
num_ports: 0
ip_list:
- 192.168.1.10
- 192.168.1.12
- 192.168.1.14
- 192.168.1.16
- 192.168.1.110
- 192.168.10.21
- 192.168.10.22
- 192.168.10.110
wildcard_list:
- 0.0.0.1
port_list:
- 80
- 5432
protocol_list:
- ICMP
- TCP
- UDP
num_rules: 10
- type: links
label: LINKS
options:
link_references:
- router_1:eth-1<->switch_1:eth-8
- router_1:eth-2<->switch_2:eth-8
- switch_1:eth-1<->domain_controller:eth-1
- switch_1:eth-2<->web_server:eth-1
- switch_1:eth-3<->database_server:eth-1
- switch_1:eth-4<->backup_server:eth-1
- switch_1:eth-7<->security_suite:eth-1
- switch_2:eth-1<->client_1:eth-1
- switch_2:eth-2<->client_2:eth-1
- switch_2:eth-7<->security_suite:eth-2
- type: none
label: ICS
options: {{}}
"""
cfg = yaml.safe_load(obs_cfg_yaml)
manager = ObservationManager.from_config(cfg)
hosts: List[HostObservation] = manager.obs.components["NODES"].hosts
for i, host in enumerate(hosts):
services: List[ServiceObservation] = host.services
for j, service in enumerate(services):
val = service.services_requires_scan
print(f"host {i} service {j} {val}")
assert val == expected_val # Make sure applications require scan by default
def test_applications_requires_scan(self):
state = {"health_state_actual": 3, "health_state_visible": 1, "operating_state": 1, "num_executions": 1}
obs_requiring_scan = ApplicationObservation([], applications_requires_scan=True)
assert obs_requiring_scan.observe(state)["health_status"] == 1 # should be visible value
obs_not_requiring_scan = ApplicationObservation([], applications_requires_scan=False)
assert obs_not_requiring_scan.observe(state)["health_status"] == 3 # should be actual value

View File

@@ -73,7 +73,7 @@ def test_ftp_should_not_process_commands_if_service_not_running(ftp_client):
assert ftp_client_service._process_ftp_command(payload=payload).status_code is FTPStatusCode.ERROR
def test_ftp_tries_to_senf_file__that_does_not_exist(ftp_client):
def test_ftp_tries_to_send_file__that_does_not_exist(ftp_client):
"""Method send_file should return false if no file to send."""
assert ftp_client.file_system.get_file(folder_name="root", file_name="test.txt") is None

View File

@@ -6,6 +6,7 @@ import pytest
from primaite.game.agent.interface import ProxyAgent
from primaite.game.game import PrimaiteGame
from primaite.interface.request import RequestResponse
from primaite.simulator.network.container import Network
from primaite.simulator.network.hardware.nodes.host.computer import Computer
from primaite.simulator.network.hardware.nodes.host.server import Server
@@ -442,3 +443,59 @@ def test_terminal_connection_timeout(basic_network):
assert len(computer_b.user_session_manager.remote_sessions) == 0
assert not remote_connection.is_active
def test_terminal_last_response_updates(basic_network):
"""Test that the _last_response within Terminal correctly updates."""
network: Network = basic_network
computer_a: Computer = network.get_node_by_hostname("node_a")
terminal_a: Terminal = computer_a.software_manager.software.get("terminal")
computer_b: Computer = network.get_node_by_hostname("node_b")
assert terminal_a.last_response is None
remote_connection = terminal_a.login(username="admin", password="admin", ip_address="192.168.0.11")
# Last response should be a successful logon
assert terminal_a.last_response == RequestResponse(status="success", data={"reason": "Login Successful"})
remote_connection.execute(command=["software_manager", "application", "install", "ransomware-script"])
# Last response should now update following successful install
assert terminal_a.last_response == RequestResponse(status="success", data={})
remote_connection.execute(command=["software_manager", "application", "install", "ransomware-script"])
# Last response should now update to success, but with supplied reason.
assert terminal_a.last_response == RequestResponse(status="success", data={"reason": "already installed"})
remote_connection.execute(command=["file_system", "create", "file", "folder123", "doggo.pdf", False])
# Check file was created.
assert computer_b.file_system.access_file(folder_name="folder123", file_name="doggo.pdf")
# Last response should be confirmation of file creation.
assert terminal_a.last_response == RequestResponse(
status="success",
data={"file_name": "doggo.pdf", "folder_name": "folder123", "file_type": "PDF", "file_size": 102400},
)
remote_connection.execute(
command=[
"service",
"ftp-client",
"send",
{
"dest_ip_address": "192.168.0.2",
"src_folder": "folder123",
"src_file_name": "cat.pdf",
"dest_folder": "root",
"dest_file_name": "cat.pdf",
},
]
)
assert terminal_a.last_response == RequestResponse(
status="failure",
data={"reason": "Unable to locate given file on local file system. Perhaps given options are invalid?"},
)