2755 Add ability to extend HostNode, NetworkNode, Service and Application outside PrimAITE.

This commit is contained in:
=
2024-09-04 15:49:37 +01:00
parent 08f742b3ec
commit 310876cd3b
17 changed files with 1926 additions and 19 deletions

View File

@@ -59,3 +59,17 @@ def data_manipulation_marl_config_path() -> Path:
_LOGGER.error(msg)
raise FileNotFoundError(msg)
return path
def get_extended_config_path() -> Path:
"""
Get the path to an 'extended' example config that contains nodes using the extension framework
:return: Path to the extended example config
:rtype: Path
"""
path = _EXAMPLE_CFG / "extended_config.yaml"
if not path.exists():
msg = f"Example config does not exist: {path}. Have you run `primaite setup`?"
_LOGGER.error(msg)
raise FileNotFoundError(msg)
return path

View File

@@ -20,9 +20,10 @@ from primaite.simulator import SIM_OUTPUT
from primaite.simulator.network.airspace import AirSpaceFrequency
from primaite.simulator.network.hardware.base import NetworkInterface, NodeOperatingState, UserManager
from primaite.simulator.network.hardware.nodes.host.computer import Computer
from primaite.simulator.network.hardware.nodes.host.host_node import NIC
from primaite.simulator.network.hardware.nodes.host.host_node import NIC, HostNode
from primaite.simulator.network.hardware.nodes.host.server import Printer, Server
from primaite.simulator.network.hardware.nodes.network.firewall import Firewall
from primaite.simulator.network.hardware.nodes.network.network_node import NetworkNode
from primaite.simulator.network.hardware.nodes.network.router import Router
from primaite.simulator.network.hardware.nodes.network.switch import Switch
from primaite.simulator.network.hardware.nodes.network.wireless_router import WirelessRouter
@@ -278,8 +279,25 @@ class PrimaiteGame:
for node_cfg in nodes_cfg:
n_type = node_cfg["type"]
new_node = None
if n_type == "computer":
# Handle extended nodes
if n_type.lower() in HostNode._registry:
new_node = HostNode._registry[n_type](
hostname=node_cfg["hostname"],
ip_address=node_cfg["ip_address"],
subnet_mask=IPv4Address(node_cfg.get("subnet_mask", "255.255.255.0")),
default_gateway=node_cfg.get("default_gateway"),
dns_server=node_cfg.get("dns_server", None),
operating_state=NodeOperatingState.ON
if not (p := node_cfg.get("operating_state"))
else NodeOperatingState[p.upper()])
elif n_type in NetworkNode._registry:
new_node = NetworkNode._registry[n_type](
**node_cfg
)
# Default PrimAITE nodes
elif n_type == "computer":
new_node = Computer(
hostname=node_cfg["hostname"],
ip_address=node_cfg["ip_address"],
@@ -351,10 +369,18 @@ class PrimaiteGame:
for service_cfg in node_cfg["services"]:
new_service = None
service_type = service_cfg["type"]
if service_type in SERVICE_TYPES_MAPPING:
service_class = None
# Handle extended services
if service_type.lower() in Service._registry:
service_class = Service._registry[service_type.lower()]
elif service_type in SERVICE_TYPES_MAPPING:
service_class = SERVICE_TYPES_MAPPING[service_type]
if service_class is not None:
_LOGGER.debug(f"installing {service_type} on node {new_node.hostname}")
new_node.software_manager.install(SERVICE_TYPES_MAPPING[service_type])
new_service = new_node.software_manager.software[service_type]
new_node.software_manager.install(service_class)
new_service = new_node.software_manager.software[service_class.__name__]
# fixing duration for the service
if "fix_duration" in service_cfg.get("options", {}):
@@ -398,8 +424,8 @@ class PrimaiteGame:
new_application = None
application_type = application_cfg["type"]
if application_type in Application._application_registry:
new_node.software_manager.install(Application._application_registry[application_type])
if application_type in Application._registry:
new_node.software_manager.install(Application._registry[application_type])
new_application = new_node.software_manager.software[application_type] # grab the instance
# fixing duration for the application

View File

@@ -12,7 +12,9 @@ from primaite import getLogger
from primaite.simulator.core import RequestManager, RequestType, SimComponent
from primaite.simulator.network.airspace import AirSpace
from primaite.simulator.network.hardware.base import Link, Node, WiredNetworkInterface
from primaite.simulator.network.hardware.nodes.host.host_node import HostNode
from primaite.simulator.network.hardware.nodes.host.server import Printer
from primaite.simulator.network.hardware.nodes.network.network_node import NetworkNode
from primaite.simulator.system.applications.application import Application
from primaite.simulator.system.services.service import Service
@@ -128,6 +130,16 @@ class Network(SimComponent):
def firewall_nodes(self) -> List[Node]:
"""The Firewalls in the Network."""
return [node for node in self.nodes.values() if node.__class__.__name__ == "Firewall"]
@property
def extended_hostnodes(self) -> List[Node]:
"""Extended nodes that inherited HostNode in the network"""
return [node for node in self.nodes.values() if node.__class__.__name__.lower() in HostNode._registry]
@property
def extended_networknodes(self) -> List[Node]:
"""Extended nodes that inherited NetworkNode in the network"""
return [node for node in self.nodes.values() if node.__class__.__name__.lower() in NetworkNode._registry]
@property
def printer_nodes(self) -> List[Node]:
@@ -160,6 +172,7 @@ class Network(SimComponent):
"Printer": self.printer_nodes,
"Wireless Router": self.wireless_router_nodes,
}
if nodes:
table = PrettyTable(["Node", "Type", "Operating State"])
if markdown:

View File

@@ -1699,7 +1699,7 @@ class Node(SimComponent):
if self.software_manager.software.get(application_name):
self.sys_log.warning(f"Can't install {application_name}. It's already installed.")
return RequestResponse(status="success", data={"reason": "already installed"})
application_class = Application._application_registry[application_name]
application_class = Application._registry[application_name]
self.software_manager.install(application_class)
application_instance = self.software_manager.software.get(application_name)
self.applications[application_instance.uuid] = application_instance

View File

@@ -2,7 +2,7 @@
from __future__ import annotations
from ipaddress import IPv4Address
from typing import Any, ClassVar, Dict, Optional
from typing import Any, ClassVar, Dict, Optional, Type
from primaite import getLogger
from primaite.simulator.network.hardware.base import (
@@ -325,10 +325,30 @@ class HostNode(Node):
network_interface: Dict[int, NIC] = {}
"The NICs on the node by port id."
_registry: ClassVar[Dict[str, Type["HostNode"]]] = {}
"""Registry of application types. Automatically populated when subclasses are defined."""
def __init__(self, ip_address: IPV4Address, subnet_mask: IPV4Address, **kwargs):
super().__init__(**kwargs)
self.connect_nic(NIC(ip_address=ip_address, subnet_mask=subnet_mask))
def __init_subclass__(cls, identifier: str = 'default', **kwargs: Any) -> None:
"""
Register a hostnode type.
:param identifier: Uniquely specifies an hostnode class by name. Used for finding items by config.
:type identifier: str
:raises ValueError: When attempting to register an hostnode with a name that is already allocated.
"""
if identifier == 'default':
return
# Enforce lowercase registry entries because it makes comparisons everywhere else much easier.
identifier = identifier.lower()
super().__init_subclass__(**kwargs)
if identifier in cls._registry:
raise ValueError(f"Tried to define new hostnode {identifier}, but this name is already reserved.")
cls._registry[identifier] = cls
@property
def nmap(self) -> Optional[NMAP]:
"""

View File

@@ -1,6 +1,6 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
from abc import abstractmethod
from typing import Optional
from typing import Any, ClassVar, Dict, Optional, Type
from primaite.simulator.network.hardware.base import NetworkInterface, Node
from primaite.simulator.network.transmission.data_link_layer import Frame
@@ -16,6 +16,25 @@ class NetworkNode(Node):
provide functionality for receiving and processing frames received on their network interfaces.
"""
_registry: ClassVar[Dict[str, Type["NetworkNode"]]] = {}
"""Registry of application types. Automatically populated when subclasses are defined."""
def __init_subclass__(cls, identifier: str = 'default', **kwargs: Any) -> None:
"""
Register a networknode type.
:param identifier: Uniquely specifies an networknode class by name. Used for finding items by config.
:type identifier: str
:raises ValueError: When attempting to register an networknode with a name that is already allocated.
"""
if identifier == 'default':
return
identifier = identifier.lower()
super().__init_subclass__(**kwargs)
if identifier in cls._registry:
raise ValueError(f"Tried to define new networknode {identifier}, but this name is already reserved.")
cls._registry[identifier] = cls
@abstractmethod
def receive_frame(self, frame: Frame, from_network_interface: NetworkInterface):
"""

View File

@@ -41,10 +41,10 @@ class Application(IOSoftware):
install_countdown: Optional[int] = None
"The countdown to the end of the installation process. None if not currently installing"
_application_registry: ClassVar[Dict[str, Type["Application"]]] = {}
_registry: ClassVar[Dict[str, Type["Application"]]] = {}
"""Registry of application types. Automatically populated when subclasses are defined."""
def __init_subclass__(cls, identifier: str, **kwargs: Any) -> None:
def __init_subclass__(cls, identifier: str = 'default', **kwargs: Any) -> None:
"""
Register an application type.
@@ -52,10 +52,12 @@ class Application(IOSoftware):
:type identifier: str
:raises ValueError: When attempting to register an application with a name that is already allocated.
"""
if identifier == 'default':
return
super().__init_subclass__(**kwargs)
if identifier in cls._application_registry:
if identifier in cls._registry:
raise ValueError(f"Tried to define new application {identifier}, but this name is already reserved.")
cls._application_registry[identifier] = cls
cls._registry[identifier] = cls
def __init__(self, **kwargs):
super().__init__(**kwargs)

View File

@@ -3,7 +3,7 @@ from __future__ import annotations
from abc import abstractmethod
from enum import Enum
from typing import Any, Dict, Optional
from typing import Any, ClassVar, Dict, Optional, Type
from primaite import getLogger
from primaite.interface.request import RequestFormat, RequestResponse
@@ -46,9 +46,29 @@ class Service(IOSoftware):
restart_countdown: Optional[int] = None
"If currently restarting, how many timesteps remain until the restart is finished."
_registry: ClassVar[Dict[str, Type["Service"]]] = {}
"""Registry of service types. Automatically populated when subclasses are defined."""
def __init__(self, **kwargs):
super().__init__(**kwargs)
def __init_subclass__(cls, identifier: str = 'default', **kwargs: Any) -> None:
"""
Register a hostnode type.
:param identifier: Uniquely specifies an hostnode class by name. Used for finding items by config.
:type identifier: str
:raises ValueError: When attempting to register an hostnode with a name that is already allocated.
"""
if identifier == 'default':
return
# Enforce lowercase registry entries because it makes comparisons everywhere else much easier.
identifier = identifier.lower()
super().__init_subclass__(**kwargs)
if identifier in cls._registry:
raise ValueError(f"Tried to define new hostnode {identifier}, but this name is already reserved.")
cls._registry[identifier] = cls
def _can_perform_action(self) -> bool:
"""
Checks if the service can perform actions.

View File

@@ -0,0 +1,951 @@
io_settings:
save_agent_actions: true
save_step_metadata: false
save_pcap_logs: false
save_sys_logs: false
sys_log_level: WARNING
game:
max_episode_length: 128
ports:
- HTTP
- POSTGRES_SERVER
protocols:
- ICMP
- TCP
- UDP
thresholds:
nmne:
high: 10
medium: 5
low: 0
agents:
- ref: client_2_green_user
team: GREEN
type: ProbabilisticAgent
agent_settings:
action_probabilities:
0: 0.3
1: 0.6
2: 0.1
observation_space: null
action_space:
action_list:
- type: DONOTHING
- type: NODE_APPLICATION_EXECUTE
options:
nodes:
- node_name: client_2
applications:
- application_name: WebBrowser
- application_name: DatabaseClient
max_folders_per_node: 1
max_files_per_folder: 1
max_services_per_node: 1
max_applications_per_node: 2
action_map:
0:
action: DONOTHING
options: {}
1:
action: NODE_APPLICATION_EXECUTE
options:
node_id: 0
application_id: 0
2:
action: NODE_APPLICATION_EXECUTE
options:
node_id: 0
application_id: 1
reward_function:
reward_components:
- type: WEBPAGE_UNAVAILABLE_PENALTY
weight: 0.25
options:
node_hostname: client_2
- type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY
weight: 0.05
options:
node_hostname: client_2
- ref: client_1_green_user
team: GREEN
type: ProbabilisticAgent
agent_settings:
action_probabilities:
0: 0.3
1: 0.6
2: 0.1
observation_space: null
action_space:
action_list:
- type: DONOTHING
- type: NODE_APPLICATION_EXECUTE
options:
nodes:
- node_name: client_1
applications:
- application_name: WebBrowser
- application_name: DatabaseClient
max_folders_per_node: 1
max_files_per_folder: 1
max_services_per_node: 1
max_applications_per_node: 2
action_map:
0:
action: DONOTHING
options: {}
1:
action: NODE_APPLICATION_EXECUTE
options:
node_id: 0
application_id: 0
2:
action: NODE_APPLICATION_EXECUTE
options:
node_id: 0
application_id: 1
reward_function:
reward_components:
- type: WEBPAGE_UNAVAILABLE_PENALTY
weight: 0.25
options:
node_hostname: client_1
- type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY
weight: 0.05
options:
node_hostname: client_1
- ref: data_manipulation_attacker
team: RED
type: RedDatabaseCorruptingAgent
observation_space: null
action_space:
action_list:
- type: DONOTHING
- type: NODE_APPLICATION_EXECUTE
options:
nodes:
- node_name: client_1
applications:
- application_name: DataManipulationBot
- node_name: client_2
applications:
- application_name: DataManipulationBot
max_folders_per_node: 1
max_files_per_folder: 1
max_services_per_node: 1
reward_function:
reward_components:
- type: DUMMY
agent_settings: # options specific to this particular agent type, basically args of __init__(self)
start_settings:
start_step: 25
frequency: 20
variance: 5
- ref: defender
team: BLUE
type: ProxyAgent
observation_space:
type: CUSTOM
options:
components:
- type: NODES
label: NODES
options:
hosts:
- hostname: domain_controller
- hostname: web_server
services:
- service_name: WebServer
- hostname: database_server
folders:
- folder_name: database
files:
- file_name: database.db
- hostname: backup_server
- hostname: security_suite
- hostname: client_1
- hostname: client_2
num_services: 1
num_applications: 0
num_folders: 1
num_files: 1
num_nics: 2
include_num_access: false
include_nmne: true
monitored_traffic:
icmp:
- NONE
tcp:
- DNS
routers:
- hostname: router_1
num_ports: 0
ip_list:
- 192.168.1.10
- 192.168.1.12
- 192.168.1.14
- 192.168.1.16
- 192.168.1.110
- 192.168.10.21
- 192.168.10.22
- 192.168.10.110
wildcard_list:
- 0.0.0.1
port_list:
- 80
- 5432
protocol_list:
- ICMP
- TCP
- UDP
num_rules: 10
- type: LINKS
label: LINKS
options:
link_references:
- router_1:eth-1<->switch_1:eth-8
- router_1:eth-2<->switch_2:eth-8
- switch_1:eth-1<->domain_controller:eth-1
- switch_1:eth-2<->web_server:eth-1
- switch_1:eth-3<->database_server:eth-1
- switch_1:eth-4<->backup_server:eth-1
- switch_1:eth-7<->security_suite:eth-1
- switch_2:eth-1<->client_1:eth-1
- switch_2:eth-2<->client_2:eth-1
- switch_2:eth-7<->security_suite:eth-2
- type: "NONE"
label: ICS
options: {}
action_space:
action_list:
- type: DONOTHING
- type: NODE_SERVICE_SCAN
- type: NODE_SERVICE_STOP
- type: NODE_SERVICE_START
- type: NODE_SERVICE_PAUSE
- type: NODE_SERVICE_RESUME
- type: NODE_SERVICE_RESTART
- type: NODE_SERVICE_DISABLE
- type: NODE_SERVICE_ENABLE
- type: NODE_SERVICE_FIX
- type: NODE_FILE_SCAN
- type: NODE_FILE_CHECKHASH
- type: NODE_FILE_DELETE
- type: NODE_FILE_REPAIR
- type: NODE_FILE_RESTORE
- type: NODE_FOLDER_SCAN
- type: NODE_FOLDER_CHECKHASH
- type: NODE_FOLDER_REPAIR
- type: NODE_FOLDER_RESTORE
- type: NODE_OS_SCAN
- type: NODE_SHUTDOWN
- type: NODE_STARTUP
- type: NODE_RESET
- type: ROUTER_ACL_ADDRULE
- type: ROUTER_ACL_REMOVERULE
- type: HOST_NIC_ENABLE
- type: HOST_NIC_DISABLE
action_map:
0:
action: DONOTHING
options: {}
# scan webapp service
1:
action: NODE_SERVICE_SCAN
options:
node_id: 1
service_id: 0
# stop webapp service
2:
action: NODE_SERVICE_STOP
options:
node_id: 1
service_id: 0
# start webapp service
3:
action: "NODE_SERVICE_START"
options:
node_id: 1
service_id: 0
4:
action: "NODE_SERVICE_PAUSE"
options:
node_id: 1
service_id: 0
5:
action: "NODE_SERVICE_RESUME"
options:
node_id: 1
service_id: 0
6:
action: "NODE_SERVICE_RESTART"
options:
node_id: 1
service_id: 0
7:
action: "NODE_SERVICE_DISABLE"
options:
node_id: 1
service_id: 0
8:
action: "NODE_SERVICE_ENABLE"
options:
node_id: 1
service_id: 0
9: # check database.db file
action: "NODE_FILE_SCAN"
options:
node_id: 2
folder_id: 0
file_id: 0
10:
action: "NODE_FILE_CHECKHASH" # CHECKHASH replaced by SCAN - but the behaviour is the same in this context.
options:
node_id: 2
folder_id: 0
file_id: 0
11:
action: "NODE_FILE_DELETE"
options:
node_id: 2
folder_id: 0
file_id: 0
12:
action: "NODE_FILE_REPAIR"
options:
node_id: 2
folder_id: 0
file_id: 0
13:
action: "NODE_SERVICE_FIX"
options:
node_id: 2
service_id: 0
14:
action: "NODE_FOLDER_SCAN"
options:
node_id: 2
folder_id: 0
15:
action: "NODE_FOLDER_CHECKHASH" # CHECKHASH replaced by SCAN - but the behaviour is the same in this context.
options:
node_id: 2
folder_id: 0
16:
action: "NODE_FOLDER_REPAIR"
options:
node_id: 2
folder_id: 0
17:
action: "NODE_FOLDER_RESTORE"
options:
node_id: 2
folder_id: 0
18:
action: "NODE_OS_SCAN"
options:
node_id: 0
19:
action: "NODE_SHUTDOWN"
options:
node_id: 0
20:
action: NODE_STARTUP
options:
node_id: 0
21:
action: NODE_RESET
options:
node_id: 0
22:
action: "NODE_OS_SCAN"
options:
node_id: 1
23:
action: "NODE_SHUTDOWN"
options:
node_id: 1
24:
action: NODE_STARTUP
options:
node_id: 1
25:
action: NODE_RESET
options:
node_id: 1
26: # old action num: 18
action: "NODE_OS_SCAN"
options:
node_id: 2
27:
action: "NODE_SHUTDOWN"
options:
node_id: 2
28:
action: NODE_STARTUP
options:
node_id: 2
29:
action: NODE_RESET
options:
node_id: 2
30:
action: "NODE_OS_SCAN"
options:
node_id: 3
31:
action: "NODE_SHUTDOWN"
options:
node_id: 3
32:
action: NODE_STARTUP
options:
node_id: 3
33:
action: NODE_RESET
options:
node_id: 3
34:
action: "NODE_OS_SCAN"
options:
node_id: 4
35:
action: "NODE_SHUTDOWN"
options:
node_id: 4
36:
action: NODE_STARTUP
options:
node_id: 4
37:
action: NODE_RESET
options:
node_id: 4
38:
action: "NODE_OS_SCAN"
options:
node_id: 5
39: # old action num: 19 # shutdown client 1
action: "NODE_SHUTDOWN"
options:
node_id: 5
40: # old action num: 20
action: NODE_STARTUP
options:
node_id: 5
41: # old action num: 21
action: NODE_RESET
options:
node_id: 5
42:
action: "NODE_OS_SCAN"
options:
node_id: 6
43:
action: "NODE_SHUTDOWN"
options:
node_id: 6
44:
action: NODE_STARTUP
options:
node_id: 6
45:
action: NODE_RESET
options:
node_id: 6
46: # old action num: 22 # "ACL: ADDRULE - Block outgoing traffic from client 1"
action: "ROUTER_ACL_ADDRULE"
options:
target_router: router_1
position: 1
permission: 2
source_ip_id: 7 # client 1
dest_ip_id: 1 # ALL
source_port_id: 1
dest_port_id: 1
protocol_id: 1
source_wildcard_id: 0
dest_wildcard_id: 0
47: # old action num: 23 # "ACL: ADDRULE - Block outgoing traffic from client 2"
action: "ROUTER_ACL_ADDRULE"
options:
target_router: router_1
position: 2
permission: 2
source_ip_id: 8 # client 2
dest_ip_id: 1 # ALL
source_port_id: 1
dest_port_id: 1
protocol_id: 1
source_wildcard_id: 0
dest_wildcard_id: 0
48: # old action num: 24 # block tcp traffic from client 1 to web app
action: "ROUTER_ACL_ADDRULE"
options:
target_router: router_1
position: 3
permission: 2
source_ip_id: 7 # client 1
dest_ip_id: 3 # web server
source_port_id: 1
dest_port_id: 1
protocol_id: 3
source_wildcard_id: 0
dest_wildcard_id: 0
49: # old action num: 25 # block tcp traffic from client 2 to web app
action: "ROUTER_ACL_ADDRULE"
options:
target_router: router_1
position: 4
permission: 2
source_ip_id: 8 # client 2
dest_ip_id: 3 # web server
source_port_id: 1
dest_port_id: 1
protocol_id: 3
source_wildcard_id: 0
dest_wildcard_id: 0
50: # old action num: 26
action: "ROUTER_ACL_ADDRULE"
options:
target_router: router_1
position: 5
permission: 2
source_ip_id: 7 # client 1
dest_ip_id: 4 # database
source_port_id: 1
dest_port_id: 1
protocol_id: 3
source_wildcard_id: 0
dest_wildcard_id: 0
51: # old action num: 27
action: "ROUTER_ACL_ADDRULE"
options:
target_router: router_1
position: 6
permission: 2
source_ip_id: 8 # client 2
dest_ip_id: 4 # database
source_port_id: 1
dest_port_id: 1
protocol_id: 3
source_wildcard_id: 0
dest_wildcard_id: 0
52: # old action num: 28
action: "ROUTER_ACL_REMOVERULE"
options:
target_router: router_1
position: 0
53: # old action num: 29
action: "ROUTER_ACL_REMOVERULE"
options:
target_router: router_1
position: 1
54: # old action num: 30
action: "ROUTER_ACL_REMOVERULE"
options:
target_router: router_1
position: 2
55: # old action num: 31
action: "ROUTER_ACL_REMOVERULE"
options:
target_router: router_1
position: 3
56: # old action num: 32
action: "ROUTER_ACL_REMOVERULE"
options:
target_router: router_1
position: 4
57: # old action num: 33
action: "ROUTER_ACL_REMOVERULE"
options:
target_router: router_1
position: 5
58: # old action num: 34
action: "ROUTER_ACL_REMOVERULE"
options:
target_router: router_1
position: 6
59: # old action num: 35
action: "ROUTER_ACL_REMOVERULE"
options:
target_router: router_1
position: 7
60: # old action num: 36
action: "ROUTER_ACL_REMOVERULE"
options:
target_router: router_1
position: 8
61: # old action num: 37
action: "ROUTER_ACL_REMOVERULE"
options:
target_router: router_1
position: 9
62: # old action num: 38
action: "HOST_NIC_DISABLE"
options:
node_id: 0
nic_id: 0
63: # old action num: 39
action: "HOST_NIC_ENABLE"
options:
node_id: 0
nic_id: 0
64: # old action num: 40
action: "HOST_NIC_DISABLE"
options:
node_id: 1
nic_id: 0
65: # old action num: 41
action: "HOST_NIC_ENABLE"
options:
node_id: 1
nic_id: 0
66: # old action num: 42
action: "HOST_NIC_DISABLE"
options:
node_id: 2
nic_id: 0
67: # old action num: 43
action: "HOST_NIC_ENABLE"
options:
node_id: 2
nic_id: 0
68: # old action num: 44
action: "HOST_NIC_DISABLE"
options:
node_id: 3
nic_id: 0
69: # old action num: 45
action: "HOST_NIC_ENABLE"
options:
node_id: 3
nic_id: 0
70: # old action num: 46
action: "HOST_NIC_DISABLE"
options:
node_id: 4
nic_id: 0
71: # old action num: 47
action: "HOST_NIC_ENABLE"
options:
node_id: 4
nic_id: 0
72: # old action num: 48
action: "HOST_NIC_DISABLE"
options:
node_id: 4
nic_id: 1
73: # old action num: 49
action: "HOST_NIC_ENABLE"
options:
node_id: 4
nic_id: 1
74: # old action num: 50
action: "HOST_NIC_DISABLE"
options:
node_id: 5
nic_id: 0
75: # old action num: 51
action: "HOST_NIC_ENABLE"
options:
node_id: 5
nic_id: 0
76: # old action num: 52
action: "HOST_NIC_DISABLE"
options:
node_id: 6
nic_id: 0
77: # old action num: 53
action: "HOST_NIC_ENABLE"
options:
node_id: 6
nic_id: 0
options:
nodes:
- node_name: domain_controller
- node_name: web_server
applications:
- application_name: DatabaseClient
services:
- service_name: WebServer
- node_name: database_server
folders:
- folder_name: database
files:
- file_name: database.db
services:
- service_name: DatabaseService
- node_name: backup_server
- node_name: security_suite
- node_name: client_1
- node_name: client_2
max_folders_per_node: 2
max_files_per_folder: 2
max_services_per_node: 2
max_nics_per_node: 8
max_acl_rules: 10
ip_list:
- 192.168.1.10
- 192.168.1.12
- 192.168.1.14
- 192.168.1.16
- 192.168.1.110
- 192.168.10.21
- 192.168.10.22
- 192.168.10.110
reward_function:
reward_components:
- type: DATABASE_FILE_INTEGRITY
weight: 0.40
options:
node_hostname: database_server
folder_name: database
file_name: database.db
- type: SHARED_REWARD
weight: 1.0
options:
agent_name: client_1_green_user
- type: SHARED_REWARD
weight: 1.0
options:
agent_name: client_2_green_user
agent_settings:
flatten_obs: true
action_masking: true
simulation:
network:
nmne_config:
capture_nmne: true
nmne_capture_keywords:
- DELETE
nodes:
- hostname: router_1
type: router
num_ports: 5
ports:
1:
ip_address: 192.168.1.1
subnet_mask: 255.255.255.0
2:
ip_address: 192.168.10.1
subnet_mask: 255.255.255.0
acl:
18:
action: PERMIT
src_port: POSTGRES_SERVER
dst_port: POSTGRES_SERVER
19:
action: PERMIT
src_port: DNS
dst_port: DNS
20:
action: PERMIT
src_port: FTP
dst_port: FTP
21:
action: PERMIT
src_port: HTTP
dst_port: HTTP
22:
action: PERMIT
src_port: ARP
dst_port: ARP
23:
action: PERMIT
protocol: ICMP
- hostname: switch_1
type: switch
num_ports: 8
- hostname: switch_2
type: gigaswitch
num_ports: 8
- hostname: domain_controller
type: server
ip_address: 192.168.1.10
subnet_mask: 255.255.255.0
default_gateway: 192.168.1.1
services:
- type: DNSServer
options:
domain_mapping:
arcd.com: 192.168.1.12 # web server
- hostname: web_server
type: server
ip_address: 192.168.1.12
subnet_mask: 255.255.255.0
default_gateway: 192.168.1.1
dns_server: 192.168.1.10
services:
- type: WebServer
applications:
- type: DatabaseClient
options:
db_server_ip: 192.168.1.14
- hostname: database_server
type: server
ip_address: 192.168.1.14
subnet_mask: 255.255.255.0
default_gateway: 192.168.1.1
dns_server: 192.168.1.10
services:
- type: DatabaseService
options:
backup_server_ip: 192.168.1.16
- type: FTPClient
- hostname: backup_server
type: server
ip_address: 192.168.1.16
subnet_mask: 255.255.255.0
default_gateway: 192.168.1.1
dns_server: 192.168.1.10
services:
- type: FTPServer
- hostname: security_suite
type: server
ip_address: 192.168.1.110
subnet_mask: 255.255.255.0
default_gateway: 192.168.1.1
dns_server: 192.168.1.10
network_interfaces:
2: # unfortunately this number is currently meaningless, they're just added in order and take up the next available slot
ip_address: 192.168.10.110
subnet_mask: 255.255.255.0
- hostname: client_1
type: supercomputer
ip_address: 192.168.10.21
subnet_mask: 255.255.255.0
default_gateway: 192.168.10.1
dns_server: 192.168.1.10
applications:
- type: DataManipulationBot
options:
port_scan_p_of_success: 0.8
data_manipulation_p_of_success: 0.8
payload: "DELETE"
server_ip: 192.168.1.14
- type: WebBrowser
options:
target_url: http://arcd.com/users/
- type: ExtendedApplication
options:
target_url: http://arcd.com/users/
- type: DatabaseClient
options:
db_server_ip: 192.168.1.14
services:
- type: DNSClient
- type: DatabaseService
options:
backup_server_ip: 192.168.1.16
- type: ExtendedService
options:
backup_server_ip: 192.168.1.16
- hostname: client_2
type: computer
ip_address: 192.168.10.22
subnet_mask: 255.255.255.0
default_gateway: 192.168.10.1
dns_server: 192.168.1.10
applications:
- type: WebBrowser
options:
target_url: http://arcd.com/users/
- type: DataManipulationBot
options:
port_scan_p_of_success: 0.8
data_manipulation_p_of_success: 0.8
payload: "DELETE"
server_ip: 192.168.1.14
- type: DatabaseClient
options:
db_server_ip: 192.168.1.14
services:
- type: DNSClient
links:
- endpoint_a_hostname: router_1
endpoint_a_port: 1
endpoint_b_hostname: switch_1
endpoint_b_port: 8
- endpoint_a_hostname: router_1
endpoint_a_port: 2
endpoint_b_hostname: switch_2
endpoint_b_port: 8
- endpoint_a_hostname: switch_1
endpoint_a_port: 1
endpoint_b_hostname: domain_controller
endpoint_b_port: 1
- endpoint_a_hostname: switch_1
endpoint_a_port: 2
endpoint_b_hostname: web_server
endpoint_b_port: 1
- endpoint_a_hostname: switch_1
endpoint_a_port: 3
endpoint_b_hostname: database_server
endpoint_b_port: 1
- endpoint_a_hostname: switch_1
endpoint_a_port: 4
endpoint_b_hostname: backup_server
endpoint_b_port: 1
- endpoint_a_hostname: switch_1
endpoint_a_port: 7
endpoint_b_hostname: security_suite
endpoint_b_port: 1
- endpoint_a_hostname: switch_2
endpoint_a_port: 1
endpoint_b_hostname: client_1
endpoint_b_port: 1
- endpoint_a_hostname: switch_2
endpoint_a_port: 2
endpoint_b_hostname: client_2
endpoint_b_port: 1
- endpoint_a_hostname: switch_2
endpoint_a_port: 7
endpoint_b_hostname: security_suite
endpoint_b_port: 2

View File

@@ -86,7 +86,7 @@ def test_node_software_install():
assert client_2.software_manager.software.get(software.__name__) is not None
# check that applications have been installed on client 1
for applications in Application._application_registry:
for applications in Application._registry:
assert client_1.software_manager.software.get(applications) is not None
# check that services have been installed on client 1

View File

@@ -51,7 +51,7 @@ def test_fix_duration_set_from_config():
# in config - applications take 1 timestep to fix
# remove test applications from list
applications = set(Application._application_registry) - set(TestApplications)
applications = set(Application._registry) - set(TestApplications)
for application in ["RansomwareScript", "WebBrowser", "DataManipulationBot", "DoSBot", "DatabaseClient"]:
assert client_1.software_manager.software.get(application) is not None

View File

@@ -0,0 +1,220 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
from enum import Enum
from ipaddress import IPv4Address
from typing import Dict, List, Optional
from urllib.parse import urlparse
from pydantic import BaseModel, ConfigDict
from primaite import getLogger
from primaite.interface.request import RequestResponse
from primaite.simulator.core import RequestManager, RequestType
from primaite.simulator.network.protocols.http import (
HttpRequestMethod,
HttpRequestPacket,
HttpResponsePacket,
HttpStatusCode,
)
from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.applications.application import Application
from primaite.simulator.system.applications.web_browser import WebBrowser
from primaite.simulator.system.services.dns.dns_client import DNSClient
_LOGGER = getLogger(__name__)
class ExtendedApplication(Application, identifier="ExtendedApplication"):
"""
Clone of web browser that uses the extension framework instead of being part of PrimAITE directly.
The application requests and loads web pages using its domain name and requesting IP addresses using DNS.
"""
target_url: Optional[str] = None
domain_name_ip_address: Optional[IPv4Address] = None
"The IP address of the domain name for the webpage."
latest_response: Optional[HttpResponsePacket] = None
"""Keeps track of the latest HTTP response."""
history: List["BrowserHistoryItem"] = []
"""Keep a log of visited websites and information about the visit, such as response code."""
def __init__(self, **kwargs):
kwargs["name"] = "ExtendedApplication"
kwargs["protocol"] = IPProtocol.TCP
# default for web is port 80
if kwargs.get("port") is None:
kwargs["port"] = Port.HTTP
super().__init__(**kwargs)
self.run()
def _init_request_manager(self) -> RequestManager:
"""
Initialise the request manager.
More information in user guide and docstring for SimComponent._init_request_manager.
"""
rm = super()._init_request_manager()
rm.add_request(
name="execute",
request_type=RequestType(
func=lambda request, context: RequestResponse.from_bool(self.get_webpage())
), # noqa
)
return rm
def describe_state(self) -> Dict:
"""
Produce a dictionary describing the current state of the WebBrowser.
:return: A dictionary capturing the current state of the WebBrowser and its child objects.
"""
state = super().describe_state()
state["history"] = [hist_item.state() for hist_item in self.history]
return state
def get_webpage(self, url: Optional[str] = None) -> bool:
"""
Retrieve the webpage.
This should send a request to the web server which also requests for a list of users
:param: url: The address of the web page the browser requests
:type: url: str
"""
url = url or self.target_url
if not self._can_perform_action():
return False
self.num_executions += 1 # trying to connect counts as an execution
# reset latest response
self.latest_response = HttpResponsePacket(status_code=HttpStatusCode.NOT_FOUND)
try:
parsed_url = urlparse(url)
except Exception:
self.sys_log.warning(f"{url} is not a valid URL")
return False
# get the IP address of the domain name via DNS
dns_client: DNSClient = self.software_manager.software.get("DNSClient")
domain_exists = dns_client.check_domain_exists(target_domain=parsed_url.hostname)
# if domain does not exist, the request fails
if domain_exists:
# set current domain name IP address
self.domain_name_ip_address = dns_client.dns_cache[parsed_url.hostname]
else:
# check if url is an ip address
try:
self.domain_name_ip_address = IPv4Address(parsed_url.hostname)
except Exception:
# unable to deal with this request
self.sys_log.warning(f"{self.name}: Unable to resolve URL {url}")
return False
# create HTTPRequest payload
payload = HttpRequestPacket(request_method=HttpRequestMethod.GET, request_url=url)
# send request - As part of the self.send call, a response will be received and stored in the
# self.latest_response variable
if self.send(
payload=payload,
dest_ip_address=self.domain_name_ip_address,
dest_port=parsed_url.port if parsed_url.port else Port.HTTP,
):
self.sys_log.info(
f"{self.name}: Received HTTP {payload.request_method.name} "
f"Response {payload.request_url} - {self.latest_response.status_code.value}"
)
self.history.append(
WebBrowser.BrowserHistoryItem(
url=url,
status=self.BrowserHistoryItem._HistoryItemStatus.LOADED,
response_code=self.latest_response.status_code,
)
)
return self.latest_response.status_code is HttpStatusCode.OK
else:
self.sys_log.warning(f"{self.name}: Error sending Http Packet")
self.sys_log.debug(f"{self.name}: {payload=}")
self.history.append(
WebBrowser.BrowserHistoryItem(
url=url, status=self.BrowserHistoryItem._HistoryItemStatus.SERVER_UNREACHABLE
)
)
return False
def send(
self,
payload: HttpRequestPacket,
dest_ip_address: Optional[IPv4Address] = None,
dest_port: Optional[Port] = Port.HTTP,
session_id: Optional[str] = None,
**kwargs,
) -> bool:
"""
Sends a payload to the SessionManager.
:param payload: The payload to be sent.
:param dest_ip_address: The ip address of the payload destination.
:param dest_port: The port of the payload destination.
:param session_id: The Session ID the payload is to originate from. Optional.
:return: True if successful, False otherwise.
"""
self.sys_log.info(f"{self.name}: Sending HTTP {payload.request_method.name} {payload.request_url}")
return super().send(
payload=payload, dest_ip_address=dest_ip_address, dest_port=dest_port, session_id=session_id, **kwargs
)
def receive(self, payload: HttpResponsePacket, session_id: Optional[str] = None, **kwargs) -> bool:
"""
Receives a payload from the SessionManager.
:param payload: The payload to be sent.
:param session_id: The Session ID the payload is to originate from. Optional.
:return: True if successful, False otherwise.
"""
if not isinstance(payload, HttpResponsePacket):
self.sys_log.warning(f"{self.name} received a packet that is not an HttpResponsePacket")
self.sys_log.debug(f"{self.name}: {payload=}")
return False
self.sys_log.info(f"{self.name}: Received HTTP {payload.status_code.value}")
self.latest_response = payload
return True
class BrowserHistoryItem(BaseModel):
"""Simple representation of browser history, used for tracking success of web requests to calculate rewards."""
model_config = ConfigDict(extra="forbid")
"""Error if incorrect specification."""
url: str
"""The URL that was attempted to be fetched by the browser"""
class _HistoryItemStatus(Enum):
NOT_SENT = "NOT_SENT"
PENDING = "PENDING"
SERVER_UNREACHABLE = "SERVER_UNREACHABLE"
LOADED = "LOADED"
status: _HistoryItemStatus = _HistoryItemStatus.PENDING
response_code: Optional[HttpStatusCode] = None
"""HTTP response code that was received, or PENDING if a response was not yet received."""
def state(self) -> Dict:
"""Return the contents of this dataclass as a dict for use with describe_state method."""
if self.status == self._HistoryItemStatus.LOADED:
outcome = self.response_code.value
else:
outcome = self.status.value
return {"url": self.url, "outcome": outcome}

View File

@@ -0,0 +1,121 @@
from typing import Dict
from prettytable import MARKDOWN, PrettyTable
from primaite import _LOGGER
from primaite.exceptions import NetworkError
from primaite.simulator.network.hardware.base import Link
from primaite.simulator.network.hardware.nodes.network.network_node import NetworkNode
from primaite.simulator.network.hardware.nodes.network.switch import SwitchPort
from primaite.simulator.network.transmission.data_link_layer import Frame
class GigaSwitch(NetworkNode, identifier="gigaswitch"):
"""
A class representing a Layer 2 network switch.
:ivar num_ports: The number of ports on the switch. Default is 24.
"""
num_ports: int = 24
"The number of ports on the switch."
network_interfaces: Dict[str, SwitchPort] = {}
"The SwitchPorts on the Switch."
network_interface: Dict[int, SwitchPort] = {}
"The SwitchPorts on the Switch by port id."
mac_address_table: Dict[str, SwitchPort] = {}
"A MAC address table mapping destination MAC addresses to corresponding SwitchPorts."
def __init__(self, **kwargs):
print('--- Extended Component: GigaSwitch ---')
super().__init__(**kwargs)
for i in range(1, self.num_ports + 1):
self.connect_nic(SwitchPort())
def _install_system_software(self):
pass
def show(self, markdown: bool = False):
"""
Prints a table of the SwitchPorts on the Switch.
:param markdown: If True, outputs the table in markdown format. Default is False.
"""
table = PrettyTable(["Port", "MAC Address", "Speed", "Status"])
if markdown:
table.set_style(MARKDOWN)
table.align = "l"
table.title = f"{self.hostname} Switch Ports"
for port_num, port in self.network_interface.items():
table.add_row([port_num, port.mac_address, port.speed, "Enabled" if port.enabled else "Disabled"])
print(table)
def describe_state(self) -> Dict:
"""
Produce a dictionary describing the current state of this object.
:return: Current state of this object and child objects.
"""
state = super().describe_state()
state["ports"] = {port_num: port.describe_state() for port_num, port in self.network_interface.items()}
state["num_ports"] = self.num_ports # redundant?
state["mac_address_table"] = {mac: port.port_num for mac, port in self.mac_address_table.items()}
return state
def _add_mac_table_entry(self, mac_address: str, switch_port: SwitchPort):
"""
Private method to add an entry to the MAC address table.
:param mac_address: MAC address to be added.
:param switch_port: Corresponding SwitchPort object.
"""
mac_table_port = self.mac_address_table.get(mac_address)
if not mac_table_port:
self.mac_address_table[mac_address] = switch_port
self.sys_log.info(f"Added MAC table entry: Port {switch_port.port_num} -> {mac_address}")
else:
if mac_table_port != switch_port:
self.mac_address_table.pop(mac_address)
self.sys_log.info(f"Removed MAC table entry: Port {mac_table_port.port_num} -> {mac_address}")
self._add_mac_table_entry(mac_address, switch_port)
def receive_frame(self, frame: Frame, from_network_interface: SwitchPort):
"""
Forward a frame to the appropriate port based on the destination MAC address.
:param frame: The Frame being received.
:param from_network_interface: The SwitchPort that received the frame.
"""
src_mac = frame.ethernet.src_mac_addr
dst_mac = frame.ethernet.dst_mac_addr
self._add_mac_table_entry(src_mac, from_network_interface)
outgoing_port = self.mac_address_table.get(dst_mac)
if outgoing_port and dst_mac.lower() != "ff:ff:ff:ff:ff:ff":
outgoing_port.send_frame(frame)
else:
# If the destination MAC is not in the table, flood to all ports except incoming
for port in self.network_interface.values():
if port.enabled and port != from_network_interface:
port.send_frame(frame)
def disconnect_link_from_port(self, link: Link, port_number: int):
"""
Disconnect a given link from the specified port number on the switch.
:param link: The Link object to be disconnected.
:param port_number: The port number on the switch from where the link should be disconnected.
:raise NetworkError: When an invalid port number is provided or the link does not match the connection.
"""
port = self.network_interface.get(port_number)
if port is None:
msg = f"Invalid port number {port_number} on the switch"
_LOGGER.error(msg)
raise NetworkError(msg)
if port._connected_link != link:
msg = f"The link does not match the connection at port number {port_number}"
_LOGGER.error(msg)
raise NetworkError(msg)
port.disconnect_link()

View File

@@ -0,0 +1,43 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
from typing import ClassVar, Dict
from primaite.simulator.network.hardware.nodes.host.host_node import NIC, HostNode
from primaite.simulator.system.services.ftp.ftp_client import FTPClient
from primaite.utils.validators import IPV4Address
class SuperComputer(HostNode, identifier="supercomputer"):
"""
A basic Computer class.
Example:
>>> pc_a = Computer(
hostname="pc_a",
ip_address="192.168.1.10",
subnet_mask="255.255.255.0",
default_gateway="192.168.1.1"
)
>>> pc_a.power_on()
Instances of computer come 'pre-packaged' with the following:
* Core Functionality:
* Packet Capture
* Sys Log
* Services:
* ARP Service
* ICMP Service
* DNS Client
* FTP Client
* NTP Client
* Applications:
* Web Browser
"""
SYSTEM_SOFTWARE: ClassVar[Dict] = {**HostNode.SYSTEM_SOFTWARE, "FTPClient": FTPClient}
def __init__(self, ip_address: IPV4Address, subnet_mask: IPV4Address, **kwargs):
print('--- Extended Component: SuperComputer ---')
super().__init__(ip_address=ip_address, subnet_mask=subnet_mask, **kwargs)
pass

View File

@@ -0,0 +1,426 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
from ipaddress import IPv4Address
from typing import Any, Dict, List, Literal, Optional, Union
from uuid import uuid4
from primaite import getLogger
from primaite.simulator.file_system.file_system import File
from primaite.simulator.file_system.file_system_item_abc import FileSystemItemHealthStatus
from primaite.simulator.file_system.folder import Folder
from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.core.software_manager import SoftwareManager
from primaite.simulator.system.services.ftp.ftp_client import FTPClient
from primaite.simulator.system.services.service import Service, ServiceOperatingState
from primaite.simulator.system.software import SoftwareHealthState
_LOGGER = getLogger(__name__)
class ExtendedService(Service, identifier='extendedservice'):
"""
A copy of DatabaseService that uses the extension framework instead of being part of PrimAITE.
This class inherits from the `Service` class and provides methods to simulate a SQL database.
"""
password: Optional[str] = None
"""Password that needs to be provided by clients if they want to connect to the DatabaseService."""
backup_server_ip: IPv4Address = None
"""IP address of the backup server."""
latest_backup_directory: str = None
"""Directory of latest backup."""
latest_backup_file_name: str = None
"""File name of latest backup."""
def __init__(self, **kwargs):
kwargs["name"] = "ExtendedService"
kwargs["port"] = Port.POSTGRES_SERVER
kwargs["protocol"] = IPProtocol.TCP
super().__init__(**kwargs)
self._create_db_file()
if kwargs.get('options'):
opt = kwargs["options"]
self.password = opt.get("db_password", None)
if "backup_server_ip" in opt:
self.configure_backup(backup_server=IPv4Address(opt.get("backup_server_ip")))
def install(self):
"""
Perform first-time setup of the ExtendedService.
Installs an instance of FTPClient on the Node to enable database backup if it isn't installed already.
"""
super().install()
if not self.parent.software_manager.software.get("FTPClient"):
self.parent.sys_log.info(f"{self.name}: Installing FTPClient to enable database backups")
self.parent.software_manager.install(FTPClient)
def configure_backup(self, backup_server: IPv4Address):
"""
Set up the database backup.
:param: backup_server_ip: The IP address of the backup server
"""
self.backup_server_ip = backup_server
def backup_database(self) -> bool:
"""Create a backup of the database to the configured backup server."""
# check if this action can be performed
if not self._can_perform_action():
return False
# check if the backup server was configured
if self.backup_server_ip is None:
self.sys_log.warning(f"{self.name} - {self.sys_log.hostname}: not configured.")
return False
software_manager: SoftwareManager = self.software_manager
ftp_client_service: FTPClient = software_manager.software.get("FTPClient")
if not ftp_client_service:
self.sys_log.error(
f"{self.name}: Failed to perform database backup as the FTPClient software is not installed"
)
return False
# send backup copy of database file to FTP server
if not self.db_file:
self.sys_log.error(f"{self.name}: Attempted to backup database file but it doesn't exist.")
return False
response = ftp_client_service.send_file(
dest_ip_address=self.backup_server_ip,
src_file_name=self.db_file.name,
src_folder_name="database",
dest_folder_name=str(self.uuid),
# Prevent's a filename clash with the real DatabaseService service implementation
dest_file_name="extended_service_database.db",
)
if response:
return True
self.sys_log.error("Unable to create database backup.")
return False
def restore_backup(self) -> bool:
"""Restore a backup from backup server."""
# check if this action can be performed
if not self._can_perform_action():
return False
software_manager: SoftwareManager = self.software_manager
ftp_client_service: FTPClient = software_manager.software.get("FTPClient")
if not ftp_client_service:
self.sys_log.error(
f"{self.name}: Failed to restore database backup as the FTPClient software is not installed"
)
return False
# retrieve backup file from backup server
response = ftp_client_service.request_file(
src_folder_name=str(self.uuid),
src_file_name="extended_service_database.db",
dest_folder_name="downloads",
dest_file_name="extended_service_database.db",
dest_ip_address=self.backup_server_ip,
)
if not response:
self.sys_log.error("Unable to restore database backup.")
return False
old_visible_state = SoftwareHealthState.GOOD
# get db file regardless of whether or not it was deleted
db_file = self.file_system.get_file(folder_name="database", file_name="extended_service_database.db", include_deleted=True)
if db_file is None:
self.sys_log.warning("Database file not initialised.")
return False
# if the file was deleted, get the old visible health state
if db_file.deleted:
old_visible_state = db_file.visible_health_status
else:
old_visible_state = self.db_file.visible_health_status
self.file_system.delete_file(folder_name="database", file_name="extended_service_database.db")
# replace db file
self.file_system.copy_file(src_folder_name="downloads", src_file_name="extended_service_database.db", dst_folder_name="database")
if self.db_file is None:
self.sys_log.error("Copying database backup failed.")
return False
self.db_file.visible_health_status = old_visible_state
self.set_health_state(SoftwareHealthState.GOOD)
return True
def _create_db_file(self):
"""Creates the Simulation File and sqlite file in the file system."""
self.file_system.create_file(folder_name="database", file_name="extended_service_database.db")
@property
def db_file(self) -> File:
"""Returns the database file."""
return self.file_system.get_file(folder_name="database", file_name="extended_service_database.db")
def _return_database_folder(self) -> Folder:
"""Returns the database folder."""
return self.file_system.get_folder_by_id(self.db_file.folder_id)
def _generate_connection_id(self) -> str:
"""Generate a unique connection ID."""
return str(uuid4())
def _process_connect(
self,
src_ip: IPv4Address,
connection_request_id: str,
password: Optional[str] = None,
session_id: Optional[str] = None,
) -> Dict[str, Union[int, Dict[str, bool]]]:
"""Process an incoming connection request.
:param connection_id: A unique identifier for the connection
:type connection_id: str
:param password: Supplied password. It must match self.password for connection success, defaults to None
:type password: Optional[str], optional
:return: Response to connection request containing success info.
:rtype: Dict[str, Union[int, Dict[str, bool]]]
"""
self.sys_log.info(f"{self.name}: Processing new connection request ({connection_request_id}) from {src_ip}")
status_code = 500 # Default internal server error
connection_id = None
if self.operating_state == ServiceOperatingState.RUNNING:
status_code = 503 # service unavailable
if self.health_state_actual == SoftwareHealthState.OVERWHELMED:
self.sys_log.info(
f"{self.name}: Connection request ({connection_request_id}) from {src_ip} declined, service is at "
f"capacity."
)
if self.health_state_actual in [
SoftwareHealthState.GOOD,
SoftwareHealthState.FIXING,
SoftwareHealthState.COMPROMISED,
]:
if self.password == password:
status_code = 200 # ok
connection_id = self._generate_connection_id()
# try to create connection
if not self.add_connection(connection_id=connection_id, session_id=session_id):
status_code = 500
self.sys_log.info(
f"{self.name}: Connection request ({connection_request_id}) from {src_ip} declined, "
f"returning status code 500"
)
else:
status_code = 401 # Unauthorised
self.sys_log.info(
f"{self.name}: Connection request ({connection_request_id}) from {src_ip} unauthorised "
f"(incorrect password), returning status code 401"
)
else:
status_code = 404 # service not found
return {
"status_code": status_code,
"type": "connect_response",
"response": status_code == 200,
"connection_id": connection_id,
"connection_request_id": connection_request_id,
}
def _process_sql(
self,
query: Literal["SELECT", "DELETE", "INSERT", "ENCRYPT"],
query_id: str,
connection_id: Optional[str] = None,
) -> Dict[str, Union[int, List[Any]]]:
"""
Executes the given SQL query and returns the result.
Possible queries:
- SELECT : returns the data
- DELETE : deletes the data
- INSERT : inserts the data
- ENCRYPT : corrupts the data
:param query: The SQL query to be executed.
:return: Dictionary containing status code and data fetched.
"""
self.sys_log.info(f"{self.name}: Running {query}")
if not self.db_file:
self.sys_log.error(f"{self.name}: Failed to run {query} because the database file is missing.")
return {"status_code": 404, "type": "sql", "data": False}
if self.health_state_actual is not SoftwareHealthState.GOOD:
self.sys_log.error(f"{self.name}: Failed to run {query} because the database service is unavailable.")
return {"status_code": 500, "type": "sql", "data": False}
if query == "SELECT":
if self.db_file.health_status == FileSystemItemHealthStatus.CORRUPT:
return {
"status_code": 200,
"type": "sql",
"data": False,
"uuid": query_id,
"connection_id": connection_id,
}
elif self.db_file.health_status == FileSystemItemHealthStatus.GOOD:
return {
"status_code": 200,
"type": "sql",
"data": True,
"uuid": query_id,
"connection_id": connection_id,
}
else:
return {"status_code": 404, "type": "sql", "data": False}
elif query == "DELETE":
self.db_file.health_status = FileSystemItemHealthStatus.COMPROMISED
return {
"status_code": 200,
"type": "sql",
"data": False,
"uuid": query_id,
"connection_id": connection_id,
}
elif query == "ENCRYPT":
self.file_system.num_file_creations += 1
self.db_file.health_status = FileSystemItemHealthStatus.CORRUPT
self.db_file.num_access += 1
database_folder = self._return_database_folder()
database_folder.health_status = FileSystemItemHealthStatus.CORRUPT
self.file_system.num_file_deletions += 1
return {
"status_code": 200,
"type": "sql",
"data": False,
"uuid": query_id,
"connection_id": connection_id,
}
elif query == "INSERT":
if self.health_state_actual == SoftwareHealthState.GOOD:
return {
"status_code": 200,
"type": "sql",
"data": False,
"uuid": query_id,
"connection_id": connection_id,
}
else:
return {"status_code": 404, "type": "sql", "data": False}
elif query == "SELECT * FROM pg_stat_activity":
# Check if the connection is active.
if self.health_state_actual == SoftwareHealthState.GOOD:
return {
"status_code": 200,
"type": "sql",
"data": False,
"uuid": query_id,
"connection_id": connection_id,
}
else:
return {"status_code": 401, "data": False}
else:
# Invalid query
self.sys_log.warning(f"{self.name}: Invalid {query}")
return {"status_code": 500, "data": False}
def describe_state(self) -> Dict:
"""
Produce a dictionary describing the current state of this object.
Please see :py:meth:`primaite.simulator.core.SimComponent.describe_state` for a more detailed explanation.
:return: Current state of this object and child objects.
:rtype: Dict
"""
return super().describe_state()
def receive(self, payload: Any, session_id: str, **kwargs) -> bool:
"""
Processes the incoming SQL payload and sends the result back.
:param payload: The SQL query to be executed.
:param session_id: The session identifier.
:return: True if the Status Code is 200, otherwise False.
"""
result = {"status_code": 500, "data": []}
# if server service is down, return error
if not self._can_perform_action():
return False
if isinstance(payload, dict) and payload.get("type"):
if payload["type"] == "connect_request":
src_ip = kwargs.get("frame").ip.src_ip_address
result = self._process_connect(
src_ip=src_ip,
password=payload.get("password"),
connection_request_id=payload.get("connection_request_id"),
session_id=session_id,
)
elif payload["type"] == "disconnect":
if payload["connection_id"] in self.connections:
connection_id = payload["connection_id"]
connected_ip_address = self.connections[connection_id]["ip_address"]
frame = kwargs.get("frame")
if connected_ip_address == frame.ip.src_ip_address:
self.sys_log.info(
f"{self.name}: Received disconnect command for {connection_id=} from {connected_ip_address}"
)
self.terminate_connection(connection_id=payload["connection_id"], send_disconnect=False)
else:
self.sys_log.warning(
f"{self.name}: Ignoring disconnect command for {connection_id=} as the command source "
f"({frame.ip.src_ip_address}) doesn't match the connection source ({connected_ip_address})"
)
elif payload["type"] == "sql":
if payload.get("connection_id") in self.connections:
result = self._process_sql(
query=payload["sql"], query_id=payload["uuid"], connection_id=payload["connection_id"]
)
else:
result = {"status_code": 401, "type": "sql"}
else:
self.sys_log.info(f"{self.name}: Ignoring payload as it is not a Database payload")
self.send(payload=result, session_id=session_id)
return True
def send(self, payload: Any, session_id: str, **kwargs) -> bool:
"""
Send a SQL response back down to the SessionManager.
:param payload: The SQL query results.
:param session_id: The session identifier.
:return: True if the Status Code is 200, otherwise False.
"""
software_manager: SoftwareManager = self.software_manager
software_manager.send_payload_to_session_manager(payload=payload, session_id=session_id)
return payload["status_code"] == 200
def apply_timestep(self, timestep: int) -> None:
"""
Apply a single timestep of simulation dynamics to this service.
Here at the first step, the database backup is created, in addition to normal service update logic.
"""
if timestep == 1:
self.backup_database()
return super().apply_timestep(timestep)
def _update_fix_status(self) -> None:
"""Perform a database restore when the FIXING countdown is finished."""
super()._update_fix_status()
if self._fixing_countdown is None:
self.restore_backup()

View File

@@ -0,0 +1,32 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
from primaite.config.load import get_extended_config_path
from primaite.simulator.network.container import Network
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
from primaite.simulator.network.hardware.nodes.host.computer import Computer
from tests.integration_tests.configuration_file_parsing import BASIC_CONFIG, DMZ_NETWORK, load_config
import os
# Import the extended components so that PrimAITE registers them
from tests.integration_tests.extensions.nodes.super_computer import SuperComputer
from tests.integration_tests.extensions.nodes.giga_switch import GigaSwitch
from tests.integration_tests.extensions.services.extended_service import ExtendedService
from tests.integration_tests.extensions.applications.extended_application import ExtendedApplication
def test_extended_example_config():
"""Test that the example config can be parsed properly."""
config_path = os.path.join( "tests", "assets", "configs", "extended_config.yaml")
game = load_config(config_path)
network: Network = game.simulation.network
assert len(network.nodes) == 10 # 10 nodes in example network
assert len(network.computer_nodes) == 1
assert len(network.router_nodes) == 1 # 1 router in network
assert len(network.switch_nodes) == 1 # 1 switches in network
assert len(network.server_nodes) == 5 # 5 servers in network
assert len(network.extended_hostnodes) == 1 # One extended node based on HostNode
assert len(network.extended_networknodes) == 1 # One extended node based on NetworkNode
assert 'ExtendedApplication' in network.extended_hostnodes[0].software_manager.software
assert 'ExtendedService' in network.extended_hostnodes[0].software_manager.software

View File

@@ -8,7 +8,7 @@ def test_adding_to_app_registry():
class temp_application(Application, identifier="temp_app"):
pass
assert Application._application_registry["temp_app"] is temp_application
assert Application._registry["temp_app"] is temp_application
with pytest.raises(ValueError):
@@ -19,4 +19,4 @@ def test_adding_to_app_registry():
# Because pytest doesn't reimport classes from modules, registering this temporary test application will change the
# state of the Application registry for all subsequently run tests. So, we have to delete and unregister the class.
del temp_application
Application._application_registry.pop("temp_app")
Application._registry.pop("temp_app")