Merge remote-tracking branch 'origin/dev' into bugfix/2143-node-service-patch-main

This commit is contained in:
Marek Wolan
2024-01-03 14:43:46 +00:00
44 changed files with 1390 additions and 368 deletions

View File

@@ -18,36 +18,43 @@ parameters:
py: '3.8'
img: 'ubuntu-latest'
every_time: false
publish_coverage: false
- job_name: 'UbuntuPython310'
py: '3.10'
img: 'ubuntu-latest'
every_time: true
publish_coverage: true
- job_name: 'WindowsPython38'
py: '3.8'
img: 'windows-latest'
every_time: false
publish_coverage: false
- job_name: 'WindowsPython310'
py: '3.10'
img: 'windows-latest'
every_time: false
publish_coverage: false
- job_name: 'MacOSPython38'
py: '3.8'
img: 'macOS-latest'
every_time: false
publish_coverage: false
- job_name: 'MacOSPython310'
py: '3.10'
img: 'macOS-latest'
every_time: false
publish_coverage: false
stages:
- stage: Test
jobs:
- ${{ each item in parameters.matrix }}:
- job: ${{ item.job_name }}
timeoutInMinutes: 90
cancelTimeoutInMinutes: 1
pool:
vmImage: ${{ item.img }}
condition: or( eq(variables['Build.Reason'], 'PullRequest'), ${{ item.every_time }} )
condition: and(succeeded(), or( eq(variables['Build.Reason'], 'PullRequest'), ${{ item.every_time }} ))
steps:
- task: UsePythonVersion@0
@@ -109,12 +116,12 @@ stages:
- publish: $(System.DefaultWorkingDirectory)/htmlcov/
# publish the html report - so we can debug the coverage if needed
condition: ${{ item.every_time }} # should only be run once
condition: ${{ item.publish_coverage }} # should only be run once
artifact: coverage_report
- task: PublishCodeCoverageResults@2
# publish the code coverage so it can be viewed in the run coverage page
condition: ${{ item.every_time }} # should only be run once
condition: ${{ item.publish_coverage }} # should only be run once
inputs:
codeCoverageTool: Cobertura
summaryFileLocation: '$(System.DefaultWorkingDirectory)/**/coverage.xml'

1
.gitignore vendored
View File

@@ -156,3 +156,4 @@ benchmark/output
# src/primaite/notebooks/scratch.ipynb
src/primaite/notebooks/scratch.py
sandbox.py
sandbox.ipynb

View File

@@ -37,6 +37,7 @@ SessionManager.
- FTP Services: `FTPClient` and `FTPServer`
- HTTP Services: `WebBrowser` to simulate a web client and `WebServer`
- Fixed an issue where the services were still able to run even though the node the service is installed on is turned off
- NTP Services: `NTPClient` and `NTPServer`
### Removed
- Removed legacy simulation modules: `acl`, `common`, `environment`, `links`, `nodes`, `pol`

View File

@@ -0,0 +1,54 @@
.. only:: comment
© Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
NTP Client Server
=================
NTP Server
----------
The ``NTPServer`` provides a NTP Server simulation by extending the base Service class.
NTP Client
----------
The ``NTPClient`` provides a NTP Client simulation by extending the base Service class.
Key capabilities
^^^^^^^^^^^^^^^^
- Simulates NTP requests and NTPPacket transfer across a network
- Leverages the Service base class for install/uninstall, status tracking, etc.
Usage
^^^^^
- Install on a Node via the ``SoftwareManager`` to start the database service.
- Service runs on TCP port 123 by default.
Implementation
^^^^^^^^^^^^^^
- NTP request and responses use a ``NTPPacket`` object
- Extends Service class for integration with ``SoftwareManager``.
NTP Client
----------
The NTPClient provides a client interface for connecting to the ``NTPServer``.
Key features
^^^^^^^^^^^^
- Connects to the ``NTPServer`` via the ``SoftwareManager``.
Usage
^^^^^
- Install on a Node via the ``SoftwareManager`` to start the database service.
- Service runs on TCP port 123 by default.
Implementation
^^^^^^^^^^^^^^
- Leverages ``SoftwareManager`` for sending payloads over the network.
- Provides easy interface for Nodes to find IP addresses via domain names.
- Extends base Service class.

View File

@@ -105,25 +105,25 @@ agents:
num_files_per_folder: 1
num_nics_per_node: 2
nodes:
- node_ref: domain_controller
- node_hostname: domain_controller
services:
- service_ref: domain_controller_dns_server
- node_ref: web_server
- service_name: DNSServer
- node_hostname: web_server
services:
- service_ref: web_server_database_client
- node_ref: database_server
- service_name: web_server_database_client
- node_hostname: database_server
services:
- service_ref: database_service
- service_name: database_service
folders:
- folder_name: database
files:
- file_name: database.db
- node_ref: backup_server
- node_hostname: backup_server
# services:
# - service_ref: backup_service
- node_ref: security_suite
- node_ref: client_1
- node_ref: client_2
- node_hostname: security_suite
- node_hostname: client_1
- node_hostname: client_2
links:
- link_ref: router_1___switch_1
- link_ref: router_1___switch_2
@@ -138,7 +138,7 @@ agents:
acl:
options:
max_acl_rules: 10
router_node_ref: router_1
router_hostname: router_1
ip_address_order:
- node_ref: domain_controller
nic_num: 1
@@ -510,7 +510,7 @@ agents:
- type: DATABASE_FILE_INTEGRITY
weight: 0.5
options:
node_ref: database_server
node_hostname: database_server
folder_name: database
file_name: database.db
@@ -518,8 +518,8 @@ agents:
- type: WEB_SERVER_404_PENALTY
weight: 0.5
options:
node_ref: web_server
service_ref: web_server_web_service
node_hostname: web_server
service_name: web_server_web_service
agent_settings:

View File

@@ -611,7 +611,7 @@ class ActionManager:
max_nics_per_node: int = 8, # allows calculating shape
max_acl_rules: int = 10, # allows calculating shape
protocols: List[str] = ["TCP", "UDP", "ICMP"], # allow mapping index to protocol
ports: List[str] = ["HTTP", "DNS", "ARP", "FTP"], # allow mapping index to port
ports: List[str] = ["HTTP", "DNS", "ARP", "FTP", "NTP"], # allow mapping index to port
ip_address_list: Optional[List[str]] = None, # to allow us to map an index to an ip address.
act_map: Optional[Dict[int, Dict]] = None, # allows restricting set of possible actions
) -> None:

View File

@@ -4,7 +4,7 @@ from typing import Dict, List, Tuple
from gymnasium.core import ObsType
from primaite.game.agent.interface import AbstractScriptedAgent
from primaite.simulator.system.services.red_services.data_manipulation_bot import DataManipulationBot
from primaite.simulator.system.applications.red_applications.data_manipulation_bot import DataManipulationBot
class DataManipulationAgent(AbstractScriptedAgent):

View File

@@ -40,8 +40,7 @@ class AbstractObservation(ABC):
def from_config(cls, config: Dict, game: "PrimaiteGame"):
"""Create this observation space component form a serialised format.
The `game` parameter is for a the PrimaiteGame object that spawns this component. During deserialisation,
a subclass of this class may need to translate from a 'reference' to a UUID.
The `game` parameter is for a the PrimaiteGame object that spawns this component.
"""
pass
@@ -53,12 +52,12 @@ class FileObservation(AbstractObservation):
"""
Initialise file observation.
:param where: Store information about where in the simulation state dictionary to find the relevatn information.
:param where: Store information about where in the simulation state dictionary to find the relevant information.
Optional. If None, this corresponds that the file does not exist and the observation will be populated with
zeroes.
A typical location for a file looks like this:
['network','nodes',<node_uuid>,'file_system', 'folders',<folder_name>,'files',<file_name>]
['network','nodes',<node_hostname>,'file_system', 'folders',<folder_name>,'files',<file_name>]
:type where: Optional[List[str]]
"""
super().__init__()
@@ -120,7 +119,7 @@ class ServiceObservation(AbstractObservation):
zeroes.
A typical location for a service looks like this:
`['network','nodes',<node_uuid>,'services', <service_uuid>]`
`['network','nodes',<node_hostname>,'services', <service_name>]`
:type where: Optional[List[str]]
"""
super().__init__()
@@ -162,7 +161,7 @@ class ServiceObservation(AbstractObservation):
:return: Constructed service observation
:rtype: ServiceObservation
"""
return cls(where=parent_where + ["services", game.ref_map_services[config["service_ref"]]])
return cls(where=parent_where + ["services", config["service_name"]])
class LinkObservation(AbstractObservation):
@@ -179,7 +178,7 @@ class LinkObservation(AbstractObservation):
zeroes.
A typical location for a service looks like this:
`['network','nodes',<node_uuid>,'servics', <service_uuid>]`
`['network','nodes',<node_hostname>,'servics', <service_name>]`
:type where: Optional[List[str]]
"""
super().__init__()
@@ -242,7 +241,7 @@ class FolderObservation(AbstractObservation):
:param where: Where in the simulation state dictionary to find the relevant information for this folder.
A typical location for a file looks like this:
['network','nodes',<node_uuid>,'file_system', 'folders',<folder_name>]
['network','nodes',<node_hostname>,'file_system', 'folders',<folder_name>]
:type where: Optional[List[str]]
:param max_files: As size of the space must remain static, define max files that can be in this folder
, defaults to 5
@@ -321,7 +320,7 @@ class FolderObservation(AbstractObservation):
:type game: PrimaiteGame
:param parent_where: Where in the simulation state dictionary to find the information about this folder's
parent node. A typical location for a node ``where`` can be:
['network','nodes',<node_uuid>,'file_system']
['network','nodes',<node_hostname>,'file_system']
:type parent_where: Optional[List[str]]
:param num_files_per_folder: How many spaces for files are in this folder observation (to preserve static
observation size) , defaults to 2
@@ -347,7 +346,7 @@ class NicObservation(AbstractObservation):
:param where: Where in the simulation state dictionary to find the relevant information for this NIC. A typical
example may look like this:
['network','nodes',<node_uuid>,'NICs',<nic_uuid>]
['network','nodes',<node_hostname>,'NICs',<nic_number>]
If None, this denotes that the NIC does not exist and the observation will be populated with zeroes.
:type where: Optional[Tuple[str]], optional
"""
@@ -384,12 +383,12 @@ class NicObservation(AbstractObservation):
:param game: Reference to the PrimaiteGame object that spawned this observation.
:type game: PrimaiteGame
:param parent_where: Where in the simulation state dictionary to find the information about this NIC's parent
node. A typical location for a node ``where`` can be: ['network','nodes',<node_uuid>]
node. A typical location for a node ``where`` can be: ['network','nodes',<node_hostname>]
:type parent_where: Optional[List[str]]
:return: Constructed NIC observation
:rtype: NicObservation
"""
return cls(where=parent_where + ["NICs", config["nic_uuid"]])
return cls(where=parent_where + ["NICs", config["nic_num"]])
class NodeObservation(AbstractObservation):
@@ -412,9 +411,9 @@ class NodeObservation(AbstractObservation):
:param where: Where in the simulation state dictionary for find relevant information for this observation.
A typical location for a node looks like this:
['network','nodes',<node_uuid>]. If empty list, a default null observation will be output, defaults to []
['network','nodes',<hostname>]. If empty list, a default null observation will be output, defaults to []
:type where: List[str], optional
:param services: Mapping between position in observation space and service UUID, defaults to {}
:param services: Mapping between position in observation space and service name, defaults to {}
:type services: Dict[int,str], optional
:param max_services: Max number of services that can be presented in observation space for this node
, defaults to 2
@@ -423,7 +422,7 @@ class NodeObservation(AbstractObservation):
:type folders: Dict[int,str], optional
:param max_folders: Max number of folders in this node's obs space, defaults to 2
:type max_folders: int, optional
:param nics: Mapping between position in observation space and NIC UUID, defaults to {}
:param nics: Mapping between position in observation space and NIC idx, defaults to {}
:type nics: Dict[int,str], optional
:param max_nics: Max number of NICS in this node's obs space, defaults to 5
:type max_nics: int, optional
@@ -541,11 +540,11 @@ class NodeObservation(AbstractObservation):
:return: Constructed node observation
:rtype: NodeObservation
"""
node_uuid = game.ref_map_nodes[config["node_ref"]]
node_hostname = config["node_hostname"]
if parent_where is None:
where = ["network", "nodes", node_uuid]
where = ["network", "nodes", node_hostname]
else:
where = parent_where + ["nodes", node_uuid]
where = parent_where + ["nodes", node_hostname]
svc_configs = config.get("services", {})
services = [ServiceObservation.from_config(config=c, game=game, parent_where=where) for c in svc_configs]
@@ -556,8 +555,8 @@ class NodeObservation(AbstractObservation):
)
for c in folder_configs
]
nic_uuids = game.simulation.network.nodes[node_uuid].nics.keys()
nic_configs = [{"nic_uuid": n for n in nic_uuids}] if nic_uuids else []
# create some configs for the NIC observation in the format {"nic_num":1}, {"nic_num":2}, {"nic_num":3}, etc.
nic_configs = [{"nic_num": i for i in range(num_nics_per_node)}]
nics = [NicObservation.from_config(config=c, game=game, parent_where=where) for c in nic_configs]
logon_status = config.get("logon_status", False)
return cls(
@@ -598,7 +597,7 @@ class AclObservation(AbstractObservation):
:type protocols: list[str]
:param where: Where in the simulation state dictionary to find the relevant information for this ACL. A typical
example may look like this:
['network','nodes',<router_uuid>,'acl','acl']
['network','nodes',<router_hostname>,'acl','acl']
:type where: Optional[Tuple[str]], optional
:param num_rules: , defaults to 10
:type num_rules: int, optional
@@ -711,12 +710,12 @@ class AclObservation(AbstractObservation):
nic_obj = node_obj.ethernet_port[nic_num]
node_ip_to_idx[nic_obj.ip_address] = ip_idx + 2
router_uuid = game.ref_map_nodes[config["router_node_ref"]]
router_hostname = config["router_hostname"]
return cls(
node_ip_to_id=node_ip_to_idx,
ports=game.options.ports,
protocols=game.options.protocols,
where=["network", "nodes", router_uuid, "acl", "acl"],
where=["network", "nodes", router_hostname, "acl", "acl"],
num_rules=max_acl_rules,
)
@@ -846,6 +845,7 @@ class UC2BlueObservation(AbstractObservation):
:rtype: UC2BlueObservation
"""
node_configs = config["nodes"]
num_services_per_node = config["num_services_per_node"]
num_folders_per_node = config["num_folders_per_node"]
num_files_per_folder = config["num_files_per_folder"]

View File

@@ -82,11 +82,11 @@ class DummyReward(AbstractReward):
class DatabaseFileIntegrity(AbstractReward):
"""Reward function component which rewards the agent for maintaining the integrity of a database file."""
def __init__(self, node_uuid: str, folder_name: str, file_name: str) -> None:
def __init__(self, node_hostname: str, folder_name: str, file_name: str) -> None:
"""Initialise the reward component.
:param node_uuid: UUID of the node which contains the database file.
:type node_uuid: str
:param node_hostname: Hostname of the node which contains the database file.
:type node_hostname: str
:param folder_name: folder which contains the database file.
:type folder_name: str
:param file_name: name of the database file.
@@ -95,7 +95,7 @@ class DatabaseFileIntegrity(AbstractReward):
self.location_in_state = [
"network",
"nodes",
node_uuid,
node_hostname,
"file_system",
"folders",
folder_name,
@@ -129,49 +129,29 @@ class DatabaseFileIntegrity(AbstractReward):
:return: The reward component.
:rtype: DatabaseFileIntegrity
"""
node_ref = config.get("node_ref")
node_hostname = config.get("node_hostname")
folder_name = config.get("folder_name")
file_name = config.get("file_name")
if not node_ref:
_LOGGER.error(
f"{cls.__name__} could not be initialised from config because node_ref parameter was not specified"
)
return DummyReward() # TODO: better error handling
if not folder_name:
_LOGGER.error(
f"{cls.__name__} could not be initialised from config because folder_name parameter was not specified"
)
return DummyReward() # TODO: better error handling
if not file_name:
_LOGGER.error(
f"{cls.__name__} could not be initialised from config because file_name parameter was not specified"
)
return DummyReward() # TODO: better error handling
node_uuid = game.ref_map_nodes[node_ref]
if not node_uuid:
_LOGGER.error(
(
f"{cls.__name__} could not be initialised from config because the referenced node could not be "
f"found in the simulation"
)
)
return DummyReward() # TODO: better error handling
if not (node_hostname and folder_name and file_name):
msg = f"{cls.__name__} could not be initialised with parameters {config}"
_LOGGER.error(msg)
raise ValueError(msg)
return cls(node_uuid=node_uuid, folder_name=folder_name, file_name=file_name)
return cls(node_hostname=node_hostname, folder_name=folder_name, file_name=file_name)
class WebServer404Penalty(AbstractReward):
"""Reward function component which penalises the agent when the web server returns a 404 error."""
def __init__(self, node_uuid: str, service_uuid: str) -> None:
def __init__(self, node_hostname: str, service_name: str) -> None:
"""Initialise the reward component.
:param node_uuid: UUID of the node which contains the web server service.
:type node_uuid: str
:param service_uuid: UUID of the web server service.
:type service_uuid: str
:param node_hostname: Hostname of the node which contains the web server service.
:type node_hostname: str
:param service_name: Name of the web server service.
:type service_name: str
"""
self.location_in_state = ["network", "nodes", node_uuid, "services", service_uuid]
self.location_in_state = ["network", "nodes", node_hostname, "services", service_name]
def calculate(self, state: Dict) -> float:
"""Calculate the reward for the current state.
@@ -203,26 +183,17 @@ class WebServer404Penalty(AbstractReward):
:return: The reward component.
:rtype: WebServer404Penalty
"""
node_ref = config.get("node_ref")
service_ref = config.get("service_ref")
if not (node_ref and service_ref):
node_hostname = config.get("node_hostname")
service_name = config.get("service_name")
if not (node_hostname and service_name):
msg = (
f"{cls.__name__} could not be initialised from config because node_ref and service_ref were not "
"found in reward config."
)
_LOGGER.warning(msg)
return DummyReward() # TODO: should we error out with incorrect inputs? Probably!
node_uuid = game.ref_map_nodes[node_ref]
service_uuid = game.ref_map_services[service_ref]
if not (node_uuid and service_uuid):
msg = (
f"{cls.__name__} could not be initialised because node {node_ref} and service {service_ref} were not"
" found in the simulator."
)
_LOGGER.warning(msg)
return DummyReward() # TODO: consider erroring here as well
raise ValueError(msg)
return cls(node_uuid=node_uuid, service_uuid=service_uuid)
return cls(node_hostname=node_hostname, service_name=service_name)
class RewardFunction:

View File

@@ -20,13 +20,15 @@ from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.sim_container import Simulation
from primaite.simulator.system.applications.database_client import DatabaseClient
from primaite.simulator.system.applications.red_applications.data_manipulation_bot import DataManipulationBot
from primaite.simulator.system.applications.web_browser import WebBrowser
from primaite.simulator.system.services.database.database_service import DatabaseService
from primaite.simulator.system.services.dns.dns_client import DNSClient
from primaite.simulator.system.services.dns.dns_server import DNSServer
from primaite.simulator.system.services.ftp.ftp_client import FTPClient
from primaite.simulator.system.services.ftp.ftp_server import FTPServer
from primaite.simulator.system.services.red_services.data_manipulation_bot import DataManipulationBot
from primaite.simulator.system.services.ntp.ntp_client import NTPClient
from primaite.simulator.system.services.ntp.ntp_server import NTPServer
from primaite.simulator.system.services.web_server.web_server import WebServer
_LOGGER = getLogger(__name__)
@@ -266,6 +268,8 @@ class PrimaiteGame:
"WebServer": WebServer,
"FTPClient": FTPClient,
"FTPServer": FTPServer,
"NTPClient": NTPClient,
"NTPServer": NTPServer,
}
if service_type in service_types_mapping:
_LOGGER.debug(f"installing {service_type} on node {new_node.hostname}")
@@ -314,6 +318,7 @@ class PrimaiteGame:
opt = application_cfg["options"]
new_application.configure(
server_ip_address=IPv4Address(opt.get("server_ip")),
server_password=opt.get("server_password"),
payload=opt.get("payload"),
port_scan_p_of_success=float(opt.get("port_scan_p_of_success", "0.1")),
data_manipulation_p_of_success=float(opt.get("data_manipulation_p_of_success", "0.1")),

View File

@@ -102,7 +102,7 @@ class DomainController(SimComponent):
:rtype: Dict
"""
state = super().describe_state()
state.update({"accounts": {uuid: acct.describe_state() for uuid, acct in self.accounts.items()}})
state.update({"accounts": {acct.username: acct.describe_state() for acct in self.accounts.values()}})
return state
def _register_account(self, account: Account) -> None:

View File

@@ -199,10 +199,24 @@ class Network(SimComponent):
state = super().describe_state()
state.update(
{
"nodes": {uuid: node.describe_state() for uuid, node in self.nodes.items()},
"links": {uuid: link.describe_state() for uuid, link in self.links.items()},
"nodes": {node.hostname: node.describe_state() for node in self.nodes.values()},
"links": {},
}
)
# Update the links one-by-one. The key is a 4-tuple of `hostname_a, port_a, hostname_b, port_b`
for uuid, link in self.links.items():
node_a = link.endpoint_a._connected_node
node_b = link.endpoint_b._connected_node
hostname_a = node_a.hostname if node_a else None
hostname_b = node_b.hostname if node_b else None
port_a = link.endpoint_a._port_num_on_node
port_b = link.endpoint_b._port_num_on_node
state["links"][uuid] = link.describe_state()
state["links"][uuid]["hostname_a"] = hostname_a
state["links"][uuid]["hostname_b"] = hostname_b
state["links"][uuid]["port_a"] = port_a
state["links"][uuid]["port_b"] = port_b
return state
def add_node(self, node: Node) -> None:

View File

@@ -91,6 +91,8 @@ class NIC(SimComponent):
"Indicates if the NIC supports Wake-on-LAN functionality."
_connected_node: Optional[Node] = None
"The Node to which the NIC is connected."
_port_num_on_node: Optional[int] = None
"Which port number is assigned on this NIC"
_connected_link: Optional[Link] = None
"The Link to which the NIC is connected."
enabled: bool = False
@@ -148,7 +150,7 @@ class NIC(SimComponent):
state = super().describe_state()
state.update(
{
"ip_adress": str(self.ip_address),
"ip_address": str(self.ip_address),
"subnet_mask": str(self.subnet_mask),
"mac_address": self.mac_address,
"speed": self.speed,
@@ -311,6 +313,8 @@ class SwitchPort(SimComponent):
"The Maximum Transmission Unit (MTU) of the SwitchPort in Bytes. Default is 1500 B"
_connected_node: Optional[Node] = None
"The Node to which the SwitchPort is connected."
_port_num_on_node: Optional[int] = None
"The port num on the connected node."
_connected_link: Optional[Link] = None
"The Link to which the SwitchPort is connected."
enabled: bool = False
@@ -497,8 +501,8 @@ class Link(SimComponent):
state = super().describe_state()
state.update(
{
"endpoint_a": self.endpoint_a.uuid,
"endpoint_b": self.endpoint_b.uuid,
"endpoint_a": self.endpoint_a.uuid, # TODO: consider if using UUID is the best way to do this
"endpoint_b": self.endpoint_b.uuid, # TODO: consider if using UUID is the best way to do this
"bandwidth": self.bandwidth,
"current_load": self.current_load,
}
@@ -1094,12 +1098,12 @@ class Node(SimComponent):
{
"hostname": self.hostname,
"operating_state": self.operating_state.value,
"NICs": {uuid: nic.describe_state() for uuid, nic in self.nics.items()},
"NICs": {eth_num: nic.describe_state() for eth_num, nic in self.ethernet_port.items()},
# "switch_ports": {uuid, sp for uuid, sp in self.switch_ports.items()},
"file_system": self.file_system.describe_state(),
"applications": {uuid: app.describe_state() for uuid, app in self.applications.items()},
"services": {uuid: svc.describe_state() for uuid, svc in self.services.items()},
"process": {uuid: proc.describe_state() for uuid, proc in self.processes.items()},
"applications": {app.name: app.describe_state() for app in self.applications.values()},
"services": {svc.name: svc.describe_state() for svc in self.services.values()},
"process": {proc.name: proc.describe_state() for proc in self.processes.values()},
"revealed_to_red": self.revealed_to_red,
}
)
@@ -1316,6 +1320,7 @@ class Node(SimComponent):
self.nics[nic.uuid] = nic
self.ethernet_port[len(self.nics)] = nic
nic._connected_node = self
nic._port_num_on_node = len(self.nics)
nic.parent = self
self.sys_log.info(f"Connected NIC {nic}")
if self.operating_state == NodeOperatingState.ON:

View File

@@ -30,6 +30,7 @@ class Switch(Node):
self.switch_ports = {i: SwitchPort() for i in range(1, self.num_ports + 1)}
for port_num, port in self.switch_ports.items():
port._connected_node = self
port._port_num_on_node = port_num
port.parent = self
port.port_num = port_num

View File

@@ -9,10 +9,10 @@ from primaite.simulator.network.hardware.nodes.switch import Switch
from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.applications.database_client import DatabaseClient
from primaite.simulator.system.applications.red_applications.data_manipulation_bot import DataManipulationBot
from primaite.simulator.system.services.database.database_service import DatabaseService
from primaite.simulator.system.services.dns.dns_server import DNSServer
from primaite.simulator.system.services.ftp.ftp_server import FTPServer
from primaite.simulator.system.services.red_services.data_manipulation_bot import DataManipulationBot
from primaite.simulator.system.services.web_server.web_server import WebServer
@@ -252,9 +252,9 @@ def arcd_uc2_network() -> Network:
database_service: DatabaseService = database_server.software_manager.software.get("DatabaseService") # noqa
database_service.start()
database_service.configure_backup(backup_server=IPv4Address("192.168.1.16"))
database_service._process_sql(ddl, None) # noqa
database_service._process_sql(ddl, None, None) # noqa
for insert_statement in user_insert_statements:
database_service._process_sql(insert_statement, None) # noqa
database_service._process_sql(insert_statement, None, None) # noqa
# Web Server
web_server = Server(

View File

@@ -0,0 +1,35 @@
from __future__ import annotations
from datetime import datetime
from typing import Optional
from pydantic import BaseModel
from primaite.simulator.network.protocols.packet import DataPacket
class NTPReply(BaseModel):
"""Represents a NTP Reply packet."""
ntp_datetime: datetime
"NTP datetime object set by NTP Server."
class NTPPacket(DataPacket):
"""
Represents the NTP layer of a network frame.
:param ntp_request: NTPRequest packet from NTP client.
:param ntp_reply: NTPReply packet from NTP Server.
"""
ntp_reply: Optional[NTPReply] = None
def generate_reply(self, ntp_server_time: datetime) -> NTPPacket:
"""Generate a NTPPacket containing the time in a NTPReply object.
:param time: datetime object representing the time from the NTP server.
:return: A new NTPPacket object.
"""
self.ntp_reply = NTPReply(ntp_datetime=ntp_server_time)
return self

View File

@@ -5,7 +5,7 @@ from uuid import uuid4
from primaite import getLogger
from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.applications.application import Application, ApplicationOperatingState
from primaite.simulator.system.applications.application import Application
from primaite.simulator.system.core.software_manager import SoftwareManager
_LOGGER = getLogger(__name__)
@@ -23,7 +23,6 @@ class DatabaseClient(Application):
server_ip_address: Optional[IPv4Address] = None
server_password: Optional[str] = None
connected: bool = False
_query_success_tracker: Dict[str, bool] = {}
def __init__(self, **kwargs):
@@ -66,18 +65,24 @@ class DatabaseClient(Application):
self.server_password = server_password
self.sys_log.info(f"{self.name}: Configured the {self.name} with {server_ip_address=}, {server_password=}.")
def connect(self) -> bool:
def connect(self, connection_id: Optional[str] = None) -> bool:
"""Connect to a Database Service."""
if not self._can_perform_action():
return False
if not self.connected:
return self._connect(self.server_ip_address, self.server_password)
# already connected
return True
if not connection_id:
connection_id = str(uuid4())
return self._connect(
server_ip_address=self.server_ip_address, password=self.server_password, connection_id=connection_id
)
def _connect(
self, server_ip_address: IPv4Address, password: Optional[str] = None, is_reattempt: bool = False
self,
server_ip_address: IPv4Address,
connection_id: Optional[str] = None,
password: Optional[str] = None,
is_reattempt: bool = False,
) -> bool:
"""
Connects the DatabaseClient to the DatabaseServer.
@@ -92,33 +97,58 @@ class DatabaseClient(Application):
:type: is_reattempt: Optional[bool]
"""
if is_reattempt:
if self.connected:
self.sys_log.info(f"{self.name}: DatabaseClient connection to {server_ip_address} authorised")
if self.connections.get(connection_id):
self.sys_log.info(
f"{self.name} {connection_id=}: DatabaseClient connection to {server_ip_address} authorised"
)
self.server_ip_address = server_ip_address
return self.connected
return True
else:
self.sys_log.info(f"{self.name}: DatabaseClient connection to {server_ip_address} declined")
self.sys_log.info(
f"{self.name} {connection_id=}: DatabaseClient connection to {server_ip_address} declined"
)
return False
payload = {"type": "connect_request", "password": password}
payload = {
"type": "connect_request",
"password": password,
"connection_id": connection_id,
}
software_manager: SoftwareManager = self.software_manager
software_manager.send_payload_to_session_manager(
payload=payload, dest_ip_address=server_ip_address, dest_port=self.port
)
return self._connect(server_ip_address, password, True)
return self._connect(
server_ip_address=server_ip_address, password=password, connection_id=connection_id, is_reattempt=True
)
def disconnect(self):
def disconnect(self, connection_id: Optional[str] = None) -> bool:
"""Disconnect from the Database Service."""
if self.connected and self.operating_state is ApplicationOperatingState.RUNNING:
software_manager: SoftwareManager = self.software_manager
software_manager.send_payload_to_session_manager(
payload={"type": "disconnect"}, dest_ip_address=self.server_ip_address, dest_port=self.port
)
if not self._can_perform_action():
self.sys_log.error(f"Unable to disconnect - {self.name} is {self.operating_state.name}")
return False
self.sys_log.info(f"{self.name}: DatabaseClient disconnected from {self.server_ip_address}")
self.server_ip_address = None
self.connected = False
# if there are no connections - nothing to disconnect
if not len(self.connections):
self.sys_log.error(f"Unable to disconnect - {self.name} has no active connections.")
return False
def _query(self, sql: str, query_id: str, is_reattempt: bool = False) -> bool:
# if no connection provided, disconnect the first connection
if not connection_id:
connection_id = list(self.connections.keys())[0]
software_manager: SoftwareManager = self.software_manager
software_manager.send_payload_to_session_manager(
payload={"type": "disconnect", "connection_id": connection_id},
dest_ip_address=self.server_ip_address,
dest_port=self.port,
)
self.remove_connection(connection_id=connection_id)
self.sys_log.info(
f"{self.name}: DatabaseClient disconnected connection {connection_id} from {self.server_ip_address}"
)
def _query(self, sql: str, query_id: str, connection_id: str, is_reattempt: bool = False) -> bool:
"""
Send a query to the connected database server.
@@ -141,19 +171,17 @@ class DatabaseClient(Application):
else:
software_manager: SoftwareManager = self.software_manager
software_manager.send_payload_to_session_manager(
payload={"type": "sql", "sql": sql, "uuid": query_id},
payload={"type": "sql", "sql": sql, "uuid": query_id, "connection_id": connection_id},
dest_ip_address=self.server_ip_address,
dest_port=self.port,
)
return self._query(sql=sql, query_id=query_id, is_reattempt=True)
return self._query(sql=sql, query_id=query_id, connection_id=connection_id, is_reattempt=True)
def run(self) -> None:
"""Run the DatabaseClient."""
super().run()
if self.operating_state == ApplicationOperatingState.RUNNING:
self.connect()
def query(self, sql: str, is_reattempt: bool = False) -> bool:
def query(self, sql: str, connection_id: Optional[str] = None) -> bool:
"""
Send a query to the Database Service.
@@ -164,20 +192,17 @@ class DatabaseClient(Application):
if not self._can_perform_action():
return False
if self.connected:
query_id = str(uuid4())
if connection_id is None:
connection_id = str(uuid4())
if not self.connections.get(connection_id):
if not self.connect(connection_id=connection_id):
return False
# Initialise the tracker of this ID to False
self._query_success_tracker[query_id] = False
return self._query(sql=sql, query_id=query_id)
else:
if is_reattempt:
return False
if not self.connect():
return False
self.query(sql=sql, is_reattempt=True)
uuid = str(uuid4())
self._query_success_tracker[uuid] = False
return self._query(sql=sql, query_id=uuid, connection_id=connection_id)
def receive(self, payload: Any, session_id: str, **kwargs) -> bool:
"""
@@ -192,13 +217,13 @@ class DatabaseClient(Application):
if isinstance(payload, dict) and payload.get("type"):
if payload["type"] == "connect_response":
self.connected = payload["response"] == True
if payload["response"] is True:
# add connection
self.add_connection(connection_id=payload.get("connection_id"), session_id=session_id)
elif payload["type"] == "sql":
query_id = payload.get("uuid")
status_code = payload.get("status_code")
self._query_success_tracker[query_id] = status_code == 200
if self._query_success_tracker[query_id]:
_LOGGER.debug(f"Received payload {payload}")
else:
self.connected = False
return True

View File

@@ -5,7 +5,6 @@ from typing import Optional
from primaite import getLogger
from primaite.game.science import simulate_trial
from primaite.simulator.core import RequestManager, RequestType
from primaite.simulator.system.applications.application import ApplicationOperatingState
from primaite.simulator.system.applications.database_client import DatabaseClient
_LOGGER = getLogger(__name__)
@@ -149,9 +148,9 @@ class DataManipulationBot(DatabaseClient):
if simulate_trial(p_of_success):
self.sys_log.info(f"{self.name}: Performing data manipulation")
# perform the attack
if not self.connected:
if not len(self.connections):
self.connect()
if self.connected:
if len(self.connections):
self.query(self.payload)
self.sys_log.info(f"{self.name} payload delivered: {self.payload}")
attack_successful = True
@@ -177,9 +176,9 @@ class DataManipulationBot(DatabaseClient):
This is the core loop where the bot sequentially goes through the stages of the attack.
"""
if self.operating_state != ApplicationOperatingState.RUNNING:
if not self._can_perform_action():
return
if self.server_ip_address and self.payload and self.operating_state:
if self.server_ip_address and self.payload:
self.sys_log.info(f"{self.name}: Running")
self._logon()
self._perform_port_scan(p_of_success=self.port_scan_p_of_success)

View File

@@ -0,0 +1,192 @@
from enum import IntEnum
from ipaddress import IPv4Address
from typing import Optional
from primaite import getLogger
from primaite.game.science import simulate_trial
from primaite.simulator.core import RequestManager, RequestType
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.applications.application import Application
from primaite.simulator.system.applications.database_client import DatabaseClient
_LOGGER = getLogger(__name__)
class DoSAttackStage(IntEnum):
"""Enum representing the different stages of a Denial of Service attack."""
NOT_STARTED = 0
"Attack not yet started."
PORT_SCAN = 1
"Attack is in discovery stage - checking if provided ip and port are open."
ATTACKING = 2
"Denial of Service attack is in progress."
COMPLETED = 3
"Attack is completed."
class DoSBot(DatabaseClient, Application):
"""A bot that simulates a Denial of Service attack."""
target_ip_address: Optional[IPv4Address] = None
"""IP address of the target service."""
target_port: Optional[Port] = None
"""Port of the target service."""
payload: Optional[str] = None
"""Payload to deliver to the target service as part of the denial of service attack."""
repeat: bool = False
"""If true, the Denial of Service bot will keep performing the attack."""
attack_stage: DoSAttackStage = DoSAttackStage.NOT_STARTED
"""Current stage of the DoS kill chain."""
port_scan_p_of_success: float = 0.1
"""Probability of port scanning being sucessful."""
dos_intensity: float = 1.0
"""How much of the max sessions will be used by the DoS when attacking."""
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.name = "DoSBot"
self.max_sessions = 1000 # override normal max sessions
def set_original_state(self):
"""Set the original state of the Denial of Service Bot."""
_LOGGER.debug(f"Setting {self.name} original state on node {self.software_manager.node.hostname}")
super().set_original_state()
vals_to_include = {
"target_ip_address",
"target_port",
"payload",
"repeat",
"attack_stage",
"max_sessions",
"port_scan_p_of_success",
"dos_intensity",
}
self._original_state.update(self.model_dump(include=vals_to_include))
def reset_component_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
_LOGGER.debug(f"Resetting {self.name} state on node {self.software_manager.node.hostname}")
super().reset_component_for_episode(episode)
def _init_request_manager(self) -> RequestManager:
rm = super()._init_request_manager()
rm.add_request(name="execute", request_type=RequestType(func=lambda request, context: self.run()))
return rm
def configure(
self,
target_ip_address: IPv4Address,
target_port: Optional[Port] = Port.POSTGRES_SERVER,
payload: Optional[str] = None,
repeat: bool = False,
port_scan_p_of_success: float = 0.1,
dos_intensity: float = 1.0,
max_sessions: int = 1000,
):
"""
Configure the Denial of Service bot.
:param: target_ip_address: The IP address of the Node containing the target service.
:param: target_port: The port of the target service. Optional - Default is `Port.HTTP`
:param: payload: The payload the DoS Bot will throw at the target service. Optional - Default is `None`
:param: repeat: If True, the bot will maintain the attack. Optional - Default is `True`
:param: port_scan_p_of_success: The chance of the port scan being sucessful. Optional - Default is 0.1 (10%)
:param: dos_intensity: The intensity of the DoS attack.
Multiplied with the application's max session - Default is 1.0
:param: max_sessions: The maximum number of sessions the DoS bot will attack with. Optional - Default is 1000
"""
self.target_ip_address = target_ip_address
self.target_port = target_port
self.payload = payload
self.repeat = repeat
self.port_scan_p_of_success = port_scan_p_of_success
self.dos_intensity = dos_intensity
self.max_sessions = max_sessions
self.sys_log.info(
f"{self.name}: Configured the {self.name} with {target_ip_address=}, {target_port=}, {payload=}, "
f"{repeat=}, {port_scan_p_of_success=}, {dos_intensity=}, {max_sessions=}."
)
def run(self):
"""Run the Denial of Service Bot."""
super().run()
self._application_loop()
def _application_loop(self):
"""
The main application loop for the Denial of Service bot.
The loop goes through the stages of a DoS attack.
"""
if not self._can_perform_action():
return
# DoS bot cannot do anything without a target
if not self.target_ip_address or not self.target_port:
self.sys_log.error(
f"{self.name} is not properly configured. {self.target_ip_address=}, {self.target_port=}"
)
return
self.clear_connections()
self._perform_port_scan(p_of_success=self.port_scan_p_of_success)
self._perform_dos()
if self.repeat and self.attack_stage is DoSAttackStage.ATTACKING:
self.attack_stage = DoSAttackStage.NOT_STARTED
else:
self.attack_stage = DoSAttackStage.COMPLETED
def _perform_port_scan(self, p_of_success: Optional[float] = 0.1):
"""
Perform a simulated port scan to check for open SQL ports.
Advances the attack stage to `PORT_SCAN` if successful.
:param p_of_success: Probability of successful port scan, by default 0.1.
"""
if self.attack_stage == DoSAttackStage.NOT_STARTED:
# perform a port scan to identify that the SQL port is open on the server
if simulate_trial(p_of_success):
self.sys_log.info(f"{self.name}: Performing port scan")
# perform the port scan
port_is_open = True # Temporary; later we can implement NMAP port scan.
if port_is_open:
self.sys_log.info(f"{self.name}: ")
self.attack_stage = DoSAttackStage.PORT_SCAN
def _perform_dos(self):
"""
Perform the Denial of Service attack.
DoSBot does this by clogging up the available connections to a service.
"""
if not self.attack_stage == DoSAttackStage.PORT_SCAN:
return
self.attack_stage = DoSAttackStage.ATTACKING
self.server_ip_address = self.target_ip_address
self.port = self.target_port
dos_sessions = int(float(self.max_sessions) * self.dos_intensity)
for i in range(dos_sessions):
self.connect()
def apply_timestep(self, timestep: int) -> None:
"""
Apply a timestep to the bot, iterate through the application loop.
:param timestep: The timestep value to update the bot's state.
"""
self._application_loop()

View File

@@ -41,5 +41,5 @@ class Process(Software):
:rtype: Dict
"""
state = super().describe_state()
state.update({"operating_state": self.operating_state.name})
state.update({"operating_state": self.operating_state.value})
return state

View File

@@ -1,4 +1,3 @@
from datetime import datetime
from ipaddress import IPv4Address
from typing import Any, Dict, List, Literal, Optional, Union
@@ -22,7 +21,6 @@ class DatabaseService(Service):
"""
password: Optional[str] = None
connections: Dict[str, datetime] = {}
backup_server: IPv4Address = None
"""IP address of the backup server."""
@@ -47,7 +45,7 @@ class DatabaseService(Service):
super().set_original_state()
vals_to_include = {
"password",
"connections",
"_connections",
"backup_server",
"latest_backup_directory",
"latest_backup_file_name",
@@ -57,7 +55,7 @@ class DatabaseService(Service):
def reset_component_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
_LOGGER.debug("Resetting DatabaseService original state on node {self.software_manager.node.hostname}")
self.connections.clear()
self.clear_connections()
super().reset_component_for_episode(episode)
def configure_backup(self, backup_server: IPv4Address):
@@ -140,24 +138,39 @@ class DatabaseService(Service):
self.folder = self.file_system.get_folder_by_id(self._db_file.folder_id)
def _process_connect(
self, session_id: str, password: Optional[str] = None
self, connection_id: str, password: Optional[str] = None
) -> Dict[str, Union[int, Dict[str, bool]]]:
status_code = 500 # Default internal server error
if self.operating_state == ServiceOperatingState.RUNNING:
status_code = 503 # service unavailable
if self.health_state_actual == SoftwareHealthState.OVERWHELMED:
self.sys_log.error(
f"{self.name}: Connect request for {connection_id=} declined. Service is at capacity."
)
if self.health_state_actual == SoftwareHealthState.GOOD:
if self.password == password:
status_code = 200 # ok
self.connections[session_id] = datetime.now()
self.sys_log.info(f"{self.name}: Connect request for {session_id=} authorised")
# try to create connection
if not self.add_connection(connection_id=connection_id):
status_code = 500
self.sys_log.info(f"{self.name}: Connect request for {connection_id=} declined")
else:
self.sys_log.info(f"{self.name}: Connect request for {connection_id=} authorised")
else:
status_code = 401 # Unauthorised
self.sys_log.info(f"{self.name}: Connect request for {session_id=} declined")
self.sys_log.info(f"{self.name}: Connect request for {connection_id=} declined")
else:
status_code = 404 # service not found
return {"status_code": status_code, "type": "connect_response", "response": status_code == 200}
return {
"status_code": status_code,
"type": "connect_response",
"response": status_code == 200,
"connection_id": connection_id,
}
def _process_sql(self, query: Literal["SELECT", "DELETE"], query_id: str) -> Dict[str, Union[int, List[Any]]]:
def _process_sql(
self, query: Literal["SELECT", "DELETE"], query_id: str, connection_id: Optional[str] = None
) -> Dict[str, Union[int, List[Any]]]:
"""
Executes the given SQL query and returns the result.
@@ -169,15 +182,28 @@ class DatabaseService(Service):
:return: Dictionary containing status code and data fetched.
"""
self.sys_log.info(f"{self.name}: Running {query}")
if query == "SELECT":
if self.health_state_actual == SoftwareHealthState.GOOD:
return {"status_code": 200, "type": "sql", "data": True, "uuid": query_id}
return {
"status_code": 200,
"type": "sql",
"data": True,
"uuid": query_id,
"connection_id": connection_id,
}
else:
return {"status_code": 404, "data": False}
elif query == "DELETE":
if self.health_state_actual == SoftwareHealthState.GOOD:
self.health_state_actual = SoftwareHealthState.COMPROMISED
return {"status_code": 200, "type": "sql", "data": False, "uuid": query_id}
return {
"status_code": 200,
"type": "sql",
"data": False,
"uuid": query_id,
"connection_id": connection_id,
}
else:
return {"status_code": 404, "data": False}
else:
@@ -203,19 +229,25 @@ class DatabaseService(Service):
:param session_id: The session identifier.
:return: True if the Status Code is 200, otherwise False.
"""
if not super().receive(payload=payload, session_id=session_id, **kwargs):
result = {"status_code": 500, "data": []}
# if server service is down, return error
if not self._can_perform_action():
return False
result = {"status_code": 500, "data": []}
if isinstance(payload, dict) and payload.get("type"):
if payload["type"] == "connect_request":
result = self._process_connect(session_id=session_id, password=payload.get("password"))
result = self._process_connect(
connection_id=payload.get("connection_id"), password=payload.get("password")
)
elif payload["type"] == "disconnect":
if session_id in self.connections:
self.connections.pop(session_id)
if payload["connection_id"] in self.connections:
self.remove_connection(connection_id=payload["connection_id"])
elif payload["type"] == "sql":
if session_id in self.connections:
result = self._process_sql(query=payload["sql"], query_id=payload["uuid"])
if payload.get("connection_id") in self.connections:
result = self._process_sql(
query=payload["sql"], query_id=payload["uuid"], connection_id=payload["connection_id"]
)
else:
result = {"status_code": 401, "type": "sql"}
self.send(payload=result, session_id=session_id)

View File

@@ -20,9 +20,6 @@ class FTPClient(FTPServiceABC):
RFC 959: https://datatracker.ietf.org/doc/html/rfc959
"""
connected: bool = False
"""Keeps track of whether or not the FTP client is connected to an FTP server."""
def __init__(self, **kwargs):
kwargs["name"] = "FTPClient"
kwargs["port"] = Port.FTP
@@ -129,10 +126,7 @@ class FTPClient(FTPServiceABC):
software_manager.send_payload_to_session_manager(
payload=payload, dest_ip_address=dest_ip_address, dest_port=dest_port
)
if payload.status_code == FTPStatusCode.OK:
self.connected = False
return True
return False
return payload.status_code == FTPStatusCode.OK
def send_file(
self,
@@ -179,9 +173,9 @@ class FTPClient(FTPServiceABC):
return False
# check if FTP is currently connected to IP
self.connected = self._connect_to_server(dest_ip_address=dest_ip_address, dest_port=dest_port)
self._connect_to_server(dest_ip_address=dest_ip_address, dest_port=dest_port)
if not self.connected:
if not len(self.connections):
return False
else:
self.sys_log.info(f"Sending file {src_folder_name}/{src_file_name} to {str(dest_ip_address)}")
@@ -230,9 +224,9 @@ class FTPClient(FTPServiceABC):
:type: dest_port: Optional[Port]
"""
# check if FTP is currently connected to IP
self.connected = self._connect_to_server(dest_ip_address=dest_ip_address, dest_port=dest_port)
self._connect_to_server(dest_ip_address=dest_ip_address, dest_port=dest_port)
if not self.connected:
if not len(self.connections):
return False
else:
# send retrieve request
@@ -286,6 +280,14 @@ class FTPClient(FTPServiceABC):
self.sys_log.error(f"FTP Server could not be found - Error Code: {FTPStatusCode.NOT_FOUND.value}")
return False
# if PORT succeeded, add the connection as an active connection list
if payload.ftp_command is FTPCommand.PORT and payload.status_code is FTPStatusCode.OK:
self.add_connection(connection_id=session_id, session_id=session_id)
# if QUIT succeeded, remove the session from active connection list
if payload.ftp_command is FTPCommand.QUIT and payload.status_code is FTPStatusCode.OK:
self.remove_connection(connection_id=session_id)
self.sys_log.info(f"{self.name}: Received FTP Response {payload.ftp_command.name} {payload.status_code.value}")
self._process_ftp_command(payload=payload, session_id=session_id)

View File

@@ -1,5 +1,4 @@
from ipaddress import IPv4Address
from typing import Any, Dict, Optional
from typing import Any, Optional
from primaite import getLogger
from primaite.simulator.network.protocols.ftp import FTPCommand, FTPPacket, FTPStatusCode
@@ -21,9 +20,6 @@ class FTPServer(FTPServiceABC):
server_password: Optional[str] = None
"""Password needed to connect to FTP server. Default is None."""
connections: Dict[str, IPv4Address] = {}
"""Current active connections to the FTP server."""
def __init__(self, **kwargs):
kwargs["name"] = "FTPServer"
kwargs["port"] = Port.FTP
@@ -41,7 +37,7 @@ class FTPServer(FTPServiceABC):
def reset_component_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
_LOGGER.debug(f"Resetting FTPServer state on node {self.software_manager.node.hostname}")
self.connections.clear()
self.clear_connections()
super().reset_component_for_episode(episode)
def _process_ftp_command(self, payload: FTPPacket, session_id: Optional[str] = None, **kwargs) -> FTPPacket:
@@ -62,9 +58,6 @@ class FTPServer(FTPServiceABC):
self.sys_log.info(f"{self.name}: Received FTP {payload.ftp_command.name} {payload.ftp_command_args}")
if session_id:
session_details = self._get_session_details(session_id)
if payload.ftp_command is not None:
self.sys_log.info(f"Received FTP {payload.ftp_command.name} command.")
@@ -73,7 +66,7 @@ class FTPServer(FTPServiceABC):
# check that the port is valid
if isinstance(payload.ftp_command_args, Port) and payload.ftp_command_args.value in range(0, 65535):
# return successful connection
self.connections[session_id] = session_details.with_ip_address
self.add_connection(connection_id=session_id, session_id=session_id)
payload.status_code = FTPStatusCode.OK
return payload
@@ -81,7 +74,7 @@ class FTPServer(FTPServiceABC):
return payload
if payload.ftp_command == FTPCommand.QUIT:
self.connections.pop(session_id)
self.remove_connection(connection_id=session_id)
payload.status_code = FTPStatusCode.OK
return payload

View File

@@ -0,0 +1,132 @@
from datetime import datetime
from ipaddress import IPv4Address
from typing import Dict, Optional
from primaite import getLogger
from primaite.simulator.network.protocols.ntp import NTPPacket
from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.services.service import Service, ServiceOperatingState
_LOGGER = getLogger(__name__)
class NTPClient(Service):
"""Represents a NTP client as a service."""
ntp_server: Optional[IPv4Address] = None
"The NTP server the client sends requests to."
time: Optional[datetime] = None
def __init__(self, **kwargs):
kwargs["name"] = "NTPClient"
kwargs["port"] = Port.NTP
kwargs["protocol"] = IPProtocol.TCP
super().__init__(**kwargs)
self.start()
def configure(self, ntp_server_ip_address: IPv4Address) -> None:
"""
Set the IP address for the NTP server.
:param ntp_server_ip_address: IPv4 address of NTP server.
:param ntp_client_ip_Address: IPv4 address of NTP client.
"""
self.ntp_server = ntp_server_ip_address
self.sys_log.info(f"{self.name}: ntp_server: {self.ntp_server}")
def describe_state(self) -> Dict:
"""
Describes the current state of the software.
The specifics of the software's state, including its health, criticality,
and any other pertinent information, should be implemented in subclasses.
:return: A dictionary containing key-value pairs representing the current state
of the software.
:rtype: Dict
"""
state = super().describe_state()
return state
def reset_component_for_episode(self, episode: int):
"""
Resets the Service component for a new episode.
This method ensures the Service is ready for a new episode, including resetting any
stateful properties or statistics, and clearing any message queues.
"""
pass
def send(
self,
payload: NTPPacket,
session_id: Optional[str] = None,
dest_ip_address: IPv4Address = None,
dest_port: [Port] = Port.NTP,
**kwargs,
) -> bool:
"""Requests NTP data from NTP server.
:param payload: The payload to be sent.
:param session_id: The Session ID the payload is to originate from. Optional.
:param dest_ip_address: The ip address of the payload destination.
:param dest_port: The port of the payload destination.
:return: True if successful, False otherwise.
"""
return super().send(
payload=payload,
dest_ip_address=dest_ip_address,
dest_port=dest_port,
session_id=session_id,
**kwargs,
)
def receive(
self,
payload: NTPPacket,
session_id: Optional[str] = None,
**kwargs,
) -> bool:
"""Receives time data from server.
:param payload: The payload to be sent.
:param session_id: The Session ID the payload is to originate from. Optional.
:return: True if successful, False otherwise.
"""
if not isinstance(payload, NTPPacket):
_LOGGER.debug(f"{payload} is not a NTPPacket")
return False
if payload.ntp_reply.ntp_datetime:
self.sys_log.info(
f"{self.name}: \
Received time update from NTP server{payload.ntp_reply.ntp_datetime}"
)
self.time = payload.ntp_reply.ntp_datetime
return True
def request_time(self) -> None:
"""Send request to ntp_server."""
ntp_server_packet = NTPPacket()
self.send(payload=ntp_server_packet, dest_ip_address=self.ntp_server)
def apply_timestep(self, timestep: int) -> None:
"""
For each timestep request the time from the NTP server.
In this instance, if any multi-timestep processes are currently
occurring (such as restarting or installation), then they are brought one step closer to
being finished.
:param timestep: The current timestep number. (Amount of time since simulation episode began)
:type timestep: int
"""
self.sys_log.info(f"{self.name} apply_timestep")
super().apply_timestep(timestep)
if self.operating_state == ServiceOperatingState.RUNNING:
# request time from server
self.request_time()
else:
self.sys_log.debug(f"{self.name} ntp client not running")

View File

@@ -0,0 +1,73 @@
from datetime import datetime
from typing import Dict, Optional
from primaite import getLogger
from primaite.simulator.network.protocols.ntp import NTPPacket
from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.services.service import Service
_LOGGER = getLogger(__name__)
class NTPServer(Service):
"""Represents a NTP server as a service."""
def __init__(self, **kwargs):
kwargs["name"] = "NTPServer"
kwargs["port"] = Port.NTP
kwargs["protocol"] = IPProtocol.TCP
super().__init__(**kwargs)
self.start()
def describe_state(self) -> Dict:
"""
Describes the current state of the software.
The specifics of the software's state, including its health, criticality,
and any other pertinent information, should be implemented in subclasses.
:return: A dictionary containing key-value pairs representing the current
state of the software.
:rtype: Dict
"""
state = super().describe_state()
return state
def reset_component_for_episode(self, episode: int):
"""
Resets the Service component for a new episode.
This method ensures the Service is ready for a new episode, including
resetting any stateful properties or statistics, and clearing any message
queues.
"""
pass
def receive(
self,
payload: NTPPacket,
session_id: Optional[str] = None,
**kwargs,
) -> bool:
"""
Receives a request from NTPClient.
Check that request has a valid IP address.
:param payload: The payload to send.
:param session_id: Id of the session (Optional).
:return: True if valid NTP request else False.
"""
if not (isinstance(payload, NTPPacket)):
_LOGGER.debug(f"{payload} is not a NTPPacket")
return False
payload: NTPPacket = payload
# generate a reply with the current time
time = datetime.now()
payload = payload.generate_reply(time)
# send reply
self.send(payload, session_id)
return True

View File

@@ -40,6 +40,12 @@ class Service(IOSoftware):
restart_countdown: Optional[int] = None
"If currently restarting, how many timesteps remain until the restart is finished."
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.health_state_visible = SoftwareHealthState.UNUSED
self.health_state_actual = SoftwareHealthState.UNUSED
def _can_perform_action(self) -> bool:
"""
Checks if the service can perform actions.
@@ -52,7 +58,7 @@ class Service(IOSoftware):
if not super()._can_perform_action():
return False
if self.operating_state is not self.operating_state.RUNNING:
if self.operating_state is not ServiceOperatingState.RUNNING:
# service is not running
_LOGGER.error(f"Cannot perform action: {self.name} is {self.operating_state.name}")
return False
@@ -74,12 +80,6 @@ class Service(IOSoftware):
"""
return super().receive(payload=payload, session_id=session_id, **kwargs)
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.health_state_visible = SoftwareHealthState.UNUSED
self.health_state_actual = SoftwareHealthState.UNUSED
def set_original_state(self):
"""Sets the original state."""
super().set_original_state()

View File

@@ -1,4 +1,6 @@
import copy
from abc import abstractmethod
from datetime import datetime
from enum import Enum
from ipaddress import IPv4Address
from typing import Any, Dict, Optional
@@ -231,7 +233,7 @@ class IOSoftware(Software):
installing_count: int = 0
"The number of times the software has been installed. Default is 0."
max_sessions: int = 1
max_sessions: int = 100
"The maximum number of sessions that the software can handle simultaneously. Default is 0."
tcp: bool = True
"Indicates if the software uses TCP protocol for communication. Default is True."
@@ -239,6 +241,8 @@ class IOSoftware(Software):
"Indicates if the software uses UDP protocol for communication. Default is True."
port: Port
"The port to which the software is connected."
_connections: Dict[str, Dict] = {}
"Active connections."
def set_original_state(self):
"""Sets the original state."""
@@ -283,6 +287,65 @@ class IOSoftware(Software):
return False
return True
@property
def connections(self) -> Dict[str, Dict]:
"""Return the public version of connections."""
return copy.copy(self._connections)
def add_connection(self, connection_id: str, session_id: Optional[str] = None) -> bool:
"""
Create a new connection to this service.
Returns true if connection successfully created
:param: connection_id: UUID of the connection to create
:type: string
"""
# if over or at capacity, set to overwhelmed
if len(self._connections) >= self.max_sessions:
self.health_state_actual = SoftwareHealthState.OVERWHELMED
self.sys_log.error(f"{self.name}: Connect request for {connection_id=} declined. Service is at capacity.")
return False
else:
# if service was previously overwhelmed, set to good because there is enough space for connections
if self.health_state_actual == SoftwareHealthState.OVERWHELMED:
self.health_state_actual = SoftwareHealthState.GOOD
# check that connection already doesn't exist
if not self._connections.get(connection_id):
session_details = None
if session_id:
session_details = self._get_session_details(session_id)
self._connections[connection_id] = {
"ip_address": session_details.with_ip_address if session_details else None,
"time": datetime.now(),
}
self.sys_log.info(f"{self.name}: Connect request for {connection_id=} authorised")
return True
# connection with given id already exists
self.sys_log.error(
f"{self.name}: Connect request for {connection_id=} declined. Connection already exists."
)
return False
def remove_connection(self, connection_id: str) -> bool:
"""
Remove a connection from this service.
Returns true if connection successfully removed
:param: connection_id: UUID of the connection to create
:type: string
"""
if self.connections.get(connection_id):
self._connections.pop(connection_id)
self.sys_log.info(f"{self.name}: Connection {connection_id=} closed.")
return True
def clear_connections(self):
"""Clears all the connections from the software."""
self._connections = {}
def send(
self,
payload: Any,

View File

@@ -93,25 +93,25 @@ agents:
num_files_per_folder: 1
num_nics_per_node: 2
nodes:
- node_ref: domain_controller
- node_hostname: domain_controller
services:
- service_ref: domain_controller_dns_server
- node_ref: web_server
- service_name: domain_controller_dns_server
- node_hostname: web_server
services:
- service_ref: web_server_database_client
- node_ref: database_server
- service_name: web_server_database_client
- node_hostname: database_server
services:
- service_ref: database_service
- service_name: database_service
folders:
- folder_name: database
files:
- file_name: database.db
- node_ref: backup_server
- node_hostname: backup_server
# services:
# - service_ref: backup_service
- node_ref: security_suite
- node_ref: client_1
- node_ref: client_2
- node_hostname: security_suite
- node_hostname: client_1
- node_hostname: client_2
links:
- link_ref: router_1___switch_1
- link_ref: router_1___switch_2
@@ -126,7 +126,7 @@ agents:
acl:
options:
max_acl_rules: 10
router_node_ref: router_1
router_hostname: router_1
ip_address_order:
- node_ref: domain_controller
nic_num: 1
@@ -497,7 +497,7 @@ agents:
- type: DATABASE_FILE_INTEGRITY
weight: 0.5
options:
node_ref: database_server
node_hostname: database_server
folder_name: database
file_name: database.db
@@ -505,8 +505,8 @@ agents:
- type: WEB_SERVER_404_PENALTY
weight: 0.5
options:
node_ref: web_server
service_ref: web_server_web_service
node_hostname: web_server
service_name: web_server_web_service
agent_settings:

View File

@@ -31,13 +31,6 @@ agents:
action_space:
action_list:
- type: DONOTHING
# <not yet implemented>
# - type: NODE_LOGON
# - type: NODE_LOGOFF
# - type: NODE_APPLICATION_EXECUTE
# options:
# execution_definition:
# target_address: arcd.com
options:
nodes:
@@ -104,25 +97,25 @@ agents:
num_files_per_folder: 1
num_nics_per_node: 2
nodes:
- node_ref: domain_controller
- node_hostname: domain_controller
services:
- service_ref: domain_controller_dns_server
- node_ref: web_server
- service_name: domain_controller_dns_server
- node_hostname: web_server
services:
- service_ref: web_server_database_client
- node_ref: database_server
- service_name: web_server_database_client
- node_hostname: database_server
services:
- service_ref: database_service
- service_name: database_service
folders:
- folder_name: database
files:
- file_name: database.db
- node_ref: backup_server
- node_hostname: backup_server
# services:
# - service_ref: backup_service
- node_ref: security_suite
- node_ref: client_1
- node_ref: client_2
- node_hostname: security_suite
- node_hostname: client_1
- node_hostname: client_2
links:
- link_ref: router_1___switch_1
- link_ref: router_1___switch_2
@@ -137,7 +130,7 @@ agents:
acl:
options:
max_acl_rules: 10
router_node_ref: router_1
router_hostname: router_1
ip_address_order:
- node_ref: domain_controller
nic_num: 1
@@ -508,7 +501,7 @@ agents:
- type: DATABASE_FILE_INTEGRITY
weight: 0.5
options:
node_ref: database_server
node_hostname: database_server
folder_name: database
file_name: database.db
@@ -516,8 +509,8 @@ agents:
- type: WEB_SERVER_404_PENALTY
weight: 0.5
options:
node_ref: web_server
service_ref: web_server_web_service
node_hostname: web_server
service_name: web_server_web_service
agent_settings:

View File

@@ -37,13 +37,6 @@ agents:
action_space:
action_list:
- type: DONOTHING
# <not yet implemented>
# - type: NODE_LOGON
# - type: NODE_LOGOFF
# - type: NODE_APPLICATION_EXECUTE
# options:
# execution_definition:
# target_address: arcd.com
options:
nodes:
@@ -111,25 +104,25 @@ agents:
num_files_per_folder: 1
num_nics_per_node: 2
nodes:
- node_ref: domain_controller
- node_hostname: domain_controller
services:
- service_ref: domain_controller_dns_server
- node_ref: web_server
- service_name: domain_controller_dns_server
- node_hostname: web_server
services:
- service_ref: web_server_database_client
- node_ref: database_server
- service_name: web_server_database_client
- node_hostname: database_server
services:
- service_ref: database_service
- service_name: database_service
folders:
- folder_name: database
files:
- file_name: database.db
- node_ref: backup_server
- node_hostname: backup_server
# services:
# - service_ref: backup_service
- node_ref: security_suite
- node_ref: client_1
- node_ref: client_2
- node_hostname: security_suite
- node_hostname: client_1
- node_hostname: client_2
links:
- link_ref: router_1___switch_1
- link_ref: router_1___switch_2
@@ -144,7 +137,7 @@ agents:
acl:
options:
max_acl_rules: 10
router_node_ref: router_1
router_hostname: router_1
ip_address_order:
- node_ref: domain_controller
nic_num: 1
@@ -515,7 +508,7 @@ agents:
- type: DATABASE_FILE_INTEGRITY
weight: 0.5
options:
node_ref: database_server
node_hostname: database_server
folder_name: database
file_name: database.db
@@ -523,8 +516,8 @@ agents:
- type: WEB_SERVER_404_PENALTY
weight: 0.5
options:
node_ref: web_server
service_ref: web_server_web_service
node_hostname: web_server
service_name: web_server_web_service
agent_settings:
@@ -542,25 +535,25 @@ agents:
num_files_per_folder: 1
num_nics_per_node: 2
nodes:
- node_ref: domain_controller
- node_hostname: domain_controller
services:
- service_ref: domain_controller_dns_server
- node_ref: web_server
- service_name: domain_controller_dns_server
- node_hostname: web_server
services:
- service_ref: web_server_database_client
- node_ref: database_server
- service_name: web_server_database_client
- node_hostname: database_server
services:
- service_ref: database_service
- service_name: database_service
folders:
- folder_name: database
files:
- file_name: database.db
- node_ref: backup_server
- node_hostname: backup_server
# services:
# - service_ref: backup_service
- node_ref: security_suite
- node_ref: client_1
- node_ref: client_2
- node_hostname: security_suite
- node_hostname: client_1
- node_hostname: client_2
links:
- link_ref: router_1___switch_1
- link_ref: router_1___switch_2
@@ -575,7 +568,7 @@ agents:
acl:
options:
max_acl_rules: 10
router_node_ref: router_1
router_hostname: router_1
ip_address_order:
- node_ref: domain_controller
nic_num: 1
@@ -946,7 +939,7 @@ agents:
- type: DATABASE_FILE_INTEGRITY
weight: 0.5
options:
node_ref: database_server
node_hostname: database_server
folder_name: database
file_name: database.db
@@ -954,8 +947,8 @@ agents:
- type: WEB_SERVER_404_PENALTY
weight: 0.5
options:
node_ref: web_server
service_ref: web_server_web_service
node_hostname: web_server
service_name: web_server_web_service
agent_settings:

View File

@@ -35,13 +35,6 @@ agents:
action_space:
action_list:
- type: DONOTHING
# <not yet implemented>
# - type: NODE_LOGON
# - type: NODE_LOGOFF
# - type: NODE_APPLICATION_EXECUTE
# options:
# execution_definition:
# target_address: arcd.com
options:
nodes:
@@ -109,25 +102,25 @@ agents:
num_files_per_folder: 1
num_nics_per_node: 2
nodes:
- node_ref: domain_controller
- node_hostname: domain_controller
services:
- service_ref: domain_controller_dns_server
- node_ref: web_server
- service_name: domain_controller_dns_server
- node_hostname: web_server
services:
- service_ref: web_server_database_client
- node_ref: database_server
- service_name: web_server_database_client
- node_hostname: database_server
services:
- service_ref: database_service
- service_name: database_service
folders:
- folder_name: database
files:
- file_name: database.db
- node_ref: backup_server
- node_hostname: backup_server
# services:
# - service_ref: backup_service
- node_ref: security_suite
- node_ref: client_1
- node_ref: client_2
# - service_name: backup_service
- node_hostname: security_suite
- node_hostname: client_1
- node_hostname: client_2
links:
- link_ref: router_1___switch_1
- link_ref: router_1___switch_2
@@ -142,7 +135,7 @@ agents:
acl:
options:
max_acl_rules: 10
router_node_ref: router_1
router_hostname: router_1
ip_address_order:
- node_ref: domain_controller
nic_num: 1
@@ -513,7 +506,7 @@ agents:
- type: DATABASE_FILE_INTEGRITY
weight: 0.5
options:
node_ref: database_server
node_hostname: database_server
folder_name: database
file_name: database.db
@@ -521,8 +514,8 @@ agents:
- type: WEB_SERVER_404_PENALTY
weight: 0.5
options:
node_ref: web_server
service_ref: web_server_web_service
node_hostname: web_server
service_name: web_server_web_service
agent_settings:

View File

@@ -105,25 +105,23 @@ agents:
num_files_per_folder: 1
num_nics_per_node: 2
nodes:
- node_ref: domain_controller
- node_hostname: domain_controller
services:
- service_ref: domain_controller_dns_server
- node_ref: web_server
- service_name: domain_controller_dns_server
- node_hostname: web_server
services:
- service_ref: web_server_database_client
- node_ref: database_server
- service_name: web_server_database_client
- node_hostname: database_server
services:
- service_ref: database_service
- service_name: database_service
folders:
- folder_name: database
files:
- file_name: database.db
- node_ref: backup_server
# services:
# - service_ref: backup_service
- node_ref: security_suite
- node_ref: client_1
- node_ref: client_2
- node_hostname: backup_server
- node_hostname: security_suite
- node_hostname: client_1
- node_hostname: client_2
links:
- link_ref: router_1___switch_1
- link_ref: router_1___switch_2
@@ -138,7 +136,7 @@ agents:
acl:
options:
max_acl_rules: 10
router_node_ref: router_1
router_hostname: router_1
ip_address_order:
- node_ref: domain_controller
nic_num: 1
@@ -509,7 +507,7 @@ agents:
- type: DATABASE_FILE_INTEGRITY
weight: 0.5
options:
node_ref: database_server
node_hostname: database_server
folder_name: database
file_name: database.db
@@ -517,8 +515,8 @@ agents:
- type: WEB_SERVER_404_PENALTY
weight: 0.5
options:
node_ref: web_server
service_ref: web_server_web_service
node_hostname: web_server
service_name: web_service
agent_settings:

View File

@@ -1,8 +1,8 @@
from primaite.simulator.network.hardware.nodes.computer import Computer
from primaite.simulator.network.hardware.nodes.server import Server
from primaite.simulator.system.applications.database_client import DatabaseClient
from primaite.simulator.system.applications.red_applications.data_manipulation_bot import DataManipulationBot
from primaite.simulator.system.services.database.database_service import DatabaseService
from primaite.simulator.system.services.red_services.data_manipulation_bot import DataManipulationBot
def test_data_manipulation(uc2_network):

View File

@@ -14,7 +14,7 @@ def test_file_observation():
state = sim.describe_state()
dog_file_obs = FileObservation(
where=["network", "nodes", pc.uuid, "file_system", "folders", "root", "files", "dog.png"]
where=["network", "nodes", pc.hostname, "file_system", "folders", "root", "files", "dog.png"]
)
assert dog_file_obs.observe(state) == {"health_status": 1}
assert dog_file_obs.space == spaces.Dict({"health_status": spaces.Discrete(6)})

View File

@@ -0,0 +1,180 @@
from ipaddress import IPv4Address
from typing import Tuple
import pytest
from primaite.simulator.network.container import Network
from primaite.simulator.network.hardware.nodes.computer import Computer
from primaite.simulator.network.hardware.nodes.router import ACLAction, Router
from primaite.simulator.network.hardware.nodes.server import Server
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.applications.application import ApplicationOperatingState
from primaite.simulator.system.applications.database_client import DatabaseClient
from primaite.simulator.system.applications.red_applications.dos_bot import DoSAttackStage, DoSBot
from primaite.simulator.system.services.database.database_service import DatabaseService
from primaite.simulator.system.software import SoftwareHealthState
@pytest.fixture(scope="function")
def dos_bot_and_db_server(client_server) -> Tuple[DoSBot, Computer, DatabaseService, Server]:
computer, server = client_server
# Install DoSBot on computer
computer.software_manager.install(DoSBot)
dos_bot: DoSBot = computer.software_manager.software.get("DoSBot")
dos_bot.configure(
target_ip_address=IPv4Address(server.nics.get(next(iter(server.nics))).ip_address),
target_port=Port.POSTGRES_SERVER,
)
# Install DB Server service on server
server.software_manager.install(DatabaseService)
db_server_service: DatabaseService = server.software_manager.software.get("DatabaseService")
db_server_service.start()
return dos_bot, computer, db_server_service, server
@pytest.fixture(scope="function")
def dos_bot_db_server_green_client(example_network) -> Network:
network: Network = example_network
router_1: Router = example_network.get_node_by_hostname("router_1")
router_1.acl.add_rule(
action=ACLAction.PERMIT, src_port=Port.POSTGRES_SERVER, dst_port=Port.POSTGRES_SERVER, position=0
)
client_1: Computer = network.get_node_by_hostname("client_1")
client_2: Computer = network.get_node_by_hostname("client_2")
server: Server = network.get_node_by_hostname("server_1")
# install DoS bot on client 1
client_1.software_manager.install(DoSBot)
dos_bot: DoSBot = client_1.software_manager.software.get("DoSBot")
dos_bot.configure(
target_ip_address=IPv4Address(server.nics.get(next(iter(server.nics))).ip_address),
target_port=Port.POSTGRES_SERVER,
)
# install db server service on server
server.software_manager.install(DatabaseService)
db_server_service: DatabaseService = server.software_manager.software.get("DatabaseService")
db_server_service.start()
# Install DB client (green) on client 2
client_2.software_manager.install(DatabaseClient)
database_client: DatabaseClient = client_2.software_manager.software.get("DatabaseClient")
database_client.configure(server_ip_address=IPv4Address("192.168.0.1"))
database_client.run()
return network
def test_repeating_dos_attack(dos_bot_and_db_server):
dos_bot, computer, db_server_service, server = dos_bot_and_db_server
assert db_server_service.health_state_actual is SoftwareHealthState.GOOD
dos_bot.port_scan_p_of_success = 1
dos_bot.repeat = True
dos_bot.run()
assert len(dos_bot.connections) == db_server_service.max_sessions
assert len(db_server_service.connections) == db_server_service.max_sessions
assert len(dos_bot.connections) == db_server_service.max_sessions
assert dos_bot.attack_stage is DoSAttackStage.NOT_STARTED
assert db_server_service.health_state_actual is SoftwareHealthState.OVERWHELMED
db_server_service.clear_connections()
db_server_service.health_state_actual = SoftwareHealthState.GOOD
assert len(db_server_service.connections) == 0
computer.apply_timestep(timestep=1)
server.apply_timestep(timestep=1)
assert len(dos_bot.connections) == db_server_service.max_sessions
assert len(db_server_service.connections) == db_server_service.max_sessions
assert len(dos_bot.connections) == db_server_service.max_sessions
assert dos_bot.attack_stage is DoSAttackStage.NOT_STARTED
assert db_server_service.health_state_actual is SoftwareHealthState.OVERWHELMED
def test_non_repeating_dos_attack(dos_bot_and_db_server):
dos_bot, computer, db_server_service, server = dos_bot_and_db_server
assert db_server_service.health_state_actual is SoftwareHealthState.GOOD
dos_bot.port_scan_p_of_success = 1
dos_bot.repeat = False
dos_bot.run()
assert len(dos_bot.connections) == db_server_service.max_sessions
assert len(db_server_service.connections) == db_server_service.max_sessions
assert len(dos_bot.connections) == db_server_service.max_sessions
assert dos_bot.attack_stage is DoSAttackStage.COMPLETED
assert db_server_service.health_state_actual is SoftwareHealthState.OVERWHELMED
db_server_service.clear_connections()
db_server_service.health_state_actual = SoftwareHealthState.GOOD
assert len(db_server_service.connections) == 0
computer.apply_timestep(timestep=1)
server.apply_timestep(timestep=1)
assert len(dos_bot.connections) == 0
assert len(db_server_service.connections) == 0
assert len(dos_bot.connections) == 0
assert dos_bot.attack_stage is DoSAttackStage.COMPLETED
assert db_server_service.health_state_actual is SoftwareHealthState.GOOD
def test_dos_bot_database_service_connection(dos_bot_and_db_server):
dos_bot, computer, db_server_service, server = dos_bot_and_db_server
dos_bot.operating_state = ApplicationOperatingState.RUNNING
dos_bot.attack_stage = DoSAttackStage.PORT_SCAN
dos_bot._perform_dos()
assert len(dos_bot.connections) == db_server_service.max_sessions
assert len(db_server_service.connections) == db_server_service.max_sessions
assert len(dos_bot.connections) == db_server_service.max_sessions
def test_dos_blocks_green_agent_connection(dos_bot_db_server_green_client):
network: Network = dos_bot_db_server_green_client
client_1: Computer = network.get_node_by_hostname("client_1")
dos_bot: DoSBot = client_1.software_manager.software.get("DoSBot")
client_2: Computer = network.get_node_by_hostname("client_2")
green_db_client: DatabaseClient = client_2.software_manager.software.get("DatabaseClient")
server: Server = network.get_node_by_hostname("server_1")
db_server_service: DatabaseService = server.software_manager.software.get("DatabaseService")
assert db_server_service.health_state_actual is SoftwareHealthState.GOOD
dos_bot.port_scan_p_of_success = 1
dos_bot.repeat = False
dos_bot.run()
# DoS bot fills up connection of db server service
assert len(dos_bot.connections) == db_server_service.max_sessions
assert len(db_server_service.connections) == db_server_service.max_sessions
assert len(dos_bot.connections) == db_server_service.max_sessions
assert len(green_db_client.connections) == 0
assert dos_bot.attack_stage is DoSAttackStage.COMPLETED
# db server service is overwhelmed
assert db_server_service.health_state_actual is SoftwareHealthState.OVERWHELMED
# green agent tries to connect but fails because service is overwhelmed
assert green_db_client.connect() is False
assert len(green_db_client.connections) == 0

View File

@@ -1,6 +1,9 @@
from ipaddress import IPv4Address
from typing import Tuple
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
import pytest
from primaite.simulator.network.hardware.base import Link, NIC, Node, NodeOperatingState
from primaite.simulator.network.hardware.nodes.server import Server
from primaite.simulator.system.applications.database_client import DatabaseClient
from primaite.simulator.system.services.database.database_service import DatabaseService
@@ -8,57 +11,109 @@ from primaite.simulator.system.services.ftp.ftp_server import FTPServer
from primaite.simulator.system.services.service import ServiceOperatingState
def test_database_client_server_connection(uc2_network):
web_server: Server = uc2_network.get_node_by_hostname("web_server")
db_client: DatabaseClient = web_server.software_manager.software.get("DatabaseClient")
@pytest.fixture(scope="function")
def peer_to_peer() -> Tuple[Node, Node]:
node_a = Node(hostname="node_a", operating_state=NodeOperatingState.ON)
nic_a = NIC(ip_address="192.168.0.10", subnet_mask="255.255.255.0", operating_state=NodeOperatingState.ON)
node_a.connect_nic(nic_a)
node_a.software_manager.get_open_ports()
db_server: Server = uc2_network.get_node_by_hostname("database_server")
db_service: DatabaseService = db_server.software_manager.software.get("DatabaseService")
node_b = Node(hostname="node_b", operating_state=NodeOperatingState.ON)
nic_b = NIC(ip_address="192.168.0.11", subnet_mask="255.255.255.0")
node_b.connect_nic(nic_b)
Link(endpoint_a=nic_a, endpoint_b=nic_b)
assert node_a.ping("192.168.0.11")
node_a.software_manager.install(DatabaseClient)
node_a.software_manager.software["DatabaseClient"].configure(server_ip_address=IPv4Address("192.168.0.11"))
node_a.software_manager.software["DatabaseClient"].run()
node_b.software_manager.install(DatabaseService)
database_service: DatabaseService = node_b.software_manager.software["DatabaseService"] # noqa
database_service.start()
return node_a, node_b
@pytest.fixture(scope="function")
def peer_to_peer_secure_db() -> Tuple[Node, Node]:
node_a = Node(hostname="node_a", operating_state=NodeOperatingState.ON)
nic_a = NIC(ip_address="192.168.0.10", subnet_mask="255.255.255.0", operating_state=NodeOperatingState.ON)
node_a.connect_nic(nic_a)
node_a.software_manager.get_open_ports()
node_b = Node(hostname="node_b", operating_state=NodeOperatingState.ON)
nic_b = NIC(ip_address="192.168.0.11", subnet_mask="255.255.255.0")
node_b.connect_nic(nic_b)
Link(endpoint_a=nic_a, endpoint_b=nic_b)
assert node_a.ping("192.168.0.11")
node_a.software_manager.install(DatabaseClient)
node_a.software_manager.software["DatabaseClient"].configure(server_ip_address=IPv4Address("192.168.0.11"))
node_a.software_manager.software["DatabaseClient"].run()
node_b.software_manager.install(DatabaseService)
database_service: DatabaseService = node_b.software_manager.software["DatabaseService"] # noqa
database_service.password = "12345"
database_service.start()
return node_a, node_b
def test_database_client_server_connection(peer_to_peer):
node_a, node_b = peer_to_peer
db_client: DatabaseClient = node_a.software_manager.software["DatabaseClient"]
db_service: DatabaseService = node_b.software_manager.software["DatabaseService"]
db_client.connect()
assert len(db_client.connections) == 1
assert len(db_service.connections) == 1
db_client.disconnect()
assert len(db_client.connections) == 0
assert len(db_service.connections) == 0
def test_database_client_server_correct_password(uc2_network):
web_server: Server = uc2_network.get_node_by_hostname("web_server")
db_client: DatabaseClient = web_server.software_manager.software.get("DatabaseClient")
def test_database_client_server_correct_password(peer_to_peer_secure_db):
node_a, node_b = peer_to_peer_secure_db
db_server: Server = uc2_network.get_node_by_hostname("database_server")
db_service: DatabaseService = db_server.software_manager.software.get("DatabaseService")
db_client: DatabaseClient = node_a.software_manager.software["DatabaseClient"]
db_client.disconnect()
db_client.configure(server_ip_address=IPv4Address("192.168.1.14"), server_password="12345")
db_service.password = "12345"
assert db_client.connect()
db_service: DatabaseService = node_b.software_manager.software["DatabaseService"]
db_client.configure(server_ip_address=IPv4Address("192.168.0.11"), server_password="12345")
db_client.connect()
assert len(db_client.connections) == 1
assert len(db_service.connections) == 1
def test_database_client_server_incorrect_password(uc2_network):
web_server: Server = uc2_network.get_node_by_hostname("web_server")
db_client: DatabaseClient = web_server.software_manager.software.get("DatabaseClient")
def test_database_client_server_incorrect_password(peer_to_peer_secure_db):
node_a, node_b = peer_to_peer_secure_db
db_server: Server = uc2_network.get_node_by_hostname("database_server")
db_service: DatabaseService = db_server.software_manager.software.get("DatabaseService")
db_client: DatabaseClient = node_a.software_manager.software["DatabaseClient"]
db_client.disconnect()
db_client.configure(server_ip_address=IPv4Address("192.168.1.14"), server_password="54321")
db_service.password = "12345"
db_service: DatabaseService = node_b.software_manager.software["DatabaseService"]
assert not db_client.connect()
# should fail
db_client.connect()
assert len(db_client.connections) == 0
assert len(db_service.connections) == 0
db_client.configure(server_ip_address=IPv4Address("192.168.0.11"), server_password="wrongpass")
db_client.connect()
assert len(db_client.connections) == 0
assert len(db_service.connections) == 0
def test_database_client_query(uc2_network):
"""Tests DB query across the network returns HTTP status 200 and date."""
web_server: Server = uc2_network.get_node_by_hostname("web_server")
db_client: DatabaseClient = web_server.software_manager.software.get("DatabaseClient")
assert db_client.connected
db_client: DatabaseClient = web_server.software_manager.software["DatabaseClient"]
db_client.connect()
assert db_client.query("SELECT")
@@ -66,13 +121,13 @@ def test_database_client_query(uc2_network):
def test_create_database_backup(uc2_network):
"""Run the backup_database method and check if the FTP server has the relevant file."""
db_server: Server = uc2_network.get_node_by_hostname("database_server")
db_service: DatabaseService = db_server.software_manager.software.get("DatabaseService")
db_service: DatabaseService = db_server.software_manager.software["DatabaseService"]
# back up should be created
assert db_service.backup_database() is True
backup_server: Server = uc2_network.get_node_by_hostname("backup_server")
ftp_server: FTPServer = backup_server.software_manager.software.get("FTPServer")
ftp_server: FTPServer = backup_server.software_manager.software["FTPServer"]
# backup file should exist in the backup server
assert ftp_server.file_system.get_file(folder_name=db_service.uuid, file_name="database.db") is not None
@@ -81,7 +136,7 @@ def test_create_database_backup(uc2_network):
def test_restore_backup(uc2_network):
"""Run the restore_backup method and check if the backup is properly restored."""
db_server: Server = uc2_network.get_node_by_hostname("database_server")
db_service: DatabaseService = db_server.software_manager.software.get("DatabaseService")
db_service: DatabaseService = db_server.software_manager.software["DatabaseService"]
# create a back up
assert db_service.backup_database() is True
@@ -107,7 +162,7 @@ def test_database_client_cannot_query_offline_database_server(uc2_network):
web_server: Server = uc2_network.get_node_by_hostname("web_server")
db_client: DatabaseClient = web_server.software_manager.software.get("DatabaseClient")
assert db_client.connected
assert len(db_client.connections)
assert db_client.query("SELECT") is True

View File

@@ -0,0 +1,86 @@
from ipaddress import IPv4Address
from time import sleep
from typing import Tuple
import pytest
from primaite.simulator.network.container import Network
from primaite.simulator.network.hardware.nodes.computer import Computer
from primaite.simulator.network.hardware.nodes.server import Server
from primaite.simulator.network.protocols.ntp import NTPPacket
from primaite.simulator.system.services.ntp.ntp_client import NTPClient
from primaite.simulator.system.services.ntp.ntp_server import NTPServer
from primaite.simulator.system.services.service import ServiceOperatingState
# Create simple network for testing
# Define one node to be an NTP server and another node to be a NTP Client.
@pytest.fixture(scope="function")
def create_ntp_network(client_server) -> Tuple[NTPClient, Computer, NTPServer, Server]:
"""
+------------+ +------------+
| ntp | | ntp |
| client_1 +------------+ server_1 |
| | | |
+------------+ +------------+
"""
client, server = client_server
server.power_on()
server.software_manager.install(NTPServer)
ntp_server: NTPServer = server.software_manager.software.get("NTPServer")
ntp_server.start()
client.power_on()
client.software_manager.install(NTPClient)
ntp_client: NTPClient = client.software_manager.software.get("NTPClient")
ntp_client.start()
return ntp_client, client, ntp_server, server
def test_ntp_client_server(create_ntp_network):
ntp_client, client, ntp_server, server = create_ntp_network
ntp_server: NTPServer = server.software_manager.software["NTPServer"]
ntp_client: NTPClient = client.software_manager.software["NTPClient"]
assert ntp_server.operating_state == ServiceOperatingState.RUNNING
assert ntp_client.operating_state == ServiceOperatingState.RUNNING
ntp_client.configure(ntp_server_ip_address=IPv4Address("192.168.0.2"))
assert ntp_client.time is None
ntp_client.request_time()
assert ntp_client.time is not None
first_time = ntp_client.time
sleep(0.1)
ntp_client.apply_timestep(1) # Check time advances
second_time = ntp_client.time
assert first_time < second_time
# Test ntp client behaviour when ntp server is unavailable.
def test_ntp_server_failure(create_ntp_network):
ntp_client, client, ntp_server, server = create_ntp_network
ntp_server: NTPServer = server.software_manager.software["NTPServer"]
ntp_client: NTPClient = client.software_manager.software["NTPClient"]
assert ntp_client.operating_state == ServiceOperatingState.RUNNING
assert ntp_client.operating_state == ServiceOperatingState.RUNNING
ntp_client.configure(ntp_server_ip_address=IPv4Address("192.168.0.2"))
# Turn off ntp server.
ntp_server.stop()
assert ntp_server.operating_state == ServiceOperatingState.STOPPED
# And request a time update.
ntp_client.request_time()
assert ntp_client.time is None
# Restart ntp server.
ntp_server.start()
assert ntp_server.operating_state == ServiceOperatingState.RUNNING
ntp_client.request_time()
assert ntp_client.time is not None

View File

@@ -5,7 +5,7 @@ from primaite.simulator.network.networks import arcd_uc2_network
from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.applications.application import ApplicationOperatingState
from primaite.simulator.system.services.red_services.data_manipulation_bot import (
from primaite.simulator.system.applications.red_applications.data_manipulation_bot import (
DataManipulationAttackStage,
DataManipulationBot,
)
@@ -70,4 +70,4 @@ def test_dm_bot_perform_data_manipulation_success(dm_bot):
dm_bot._perform_data_manipulation(p_of_success=1.0)
assert dm_bot.attack_stage in (DataManipulationAttackStage.SUCCEEDED, DataManipulationAttackStage.FAILED)
assert dm_bot.connected
assert len(dm_bot.connections)

View File

@@ -0,0 +1,90 @@
from ipaddress import IPv4Address
import pytest
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
from primaite.simulator.network.hardware.nodes.computer import Computer
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.applications.application import ApplicationOperatingState
from primaite.simulator.system.applications.red_applications.dos_bot import DoSAttackStage, DoSBot
@pytest.fixture(scope="function")
def dos_bot() -> DoSBot:
computer = Computer(
hostname="compromised_pc",
ip_address="192.168.0.1",
subnet_mask="255.255.255.0",
operating_state=NodeOperatingState.ON,
)
computer.software_manager.install(DoSBot)
dos_bot: DoSBot = computer.software_manager.software.get("DoSBot")
dos_bot.configure(target_ip_address=IPv4Address("192.168.0.1"))
dos_bot.set_original_state()
return dos_bot
def test_dos_bot_creation(dos_bot):
"""Test that the DoS bot is installed on a node."""
assert dos_bot is not None
def test_dos_bot_reset(dos_bot):
assert dos_bot.target_ip_address == IPv4Address("192.168.0.1")
assert dos_bot.target_port is Port.POSTGRES_SERVER
assert dos_bot.payload is None
assert dos_bot.repeat is False
dos_bot.configure(
target_ip_address=IPv4Address("192.168.1.1"), target_port=Port.HTTP, payload="payload", repeat=True
)
# should reset the relevant items
dos_bot.reset_component_for_episode(episode=0)
assert dos_bot.target_ip_address == IPv4Address("192.168.0.1")
assert dos_bot.target_port is Port.POSTGRES_SERVER
assert dos_bot.payload is None
assert dos_bot.repeat is False
dos_bot.configure(
target_ip_address=IPv4Address("192.168.1.1"), target_port=Port.HTTP, payload="payload", repeat=True
)
dos_bot.set_original_state()
dos_bot.reset_component_for_episode(episode=1)
# should reset to the configured value
assert dos_bot.target_ip_address == IPv4Address("192.168.1.1")
assert dos_bot.target_port is Port.HTTP
assert dos_bot.payload == "payload"
assert dos_bot.repeat is True
def test_dos_bot_cannot_run_when_node_offline(dos_bot):
dos_bot_node: Computer = dos_bot.parent
assert dos_bot_node.operating_state is NodeOperatingState.ON
dos_bot_node.power_off()
for i in range(dos_bot_node.shut_down_duration + 1):
dos_bot_node.apply_timestep(timestep=i)
assert dos_bot_node.operating_state is NodeOperatingState.OFF
dos_bot._application_loop()
# assert not run
assert dos_bot.attack_stage is DoSAttackStage.NOT_STARTED
def test_dos_bot_not_configured(dos_bot):
dos_bot.target_ip_address = None
dos_bot.operating_state = ApplicationOperatingState.RUNNING
dos_bot._application_loop()
def test_dos_bot_perform_port_scan(dos_bot):
dos_bot._perform_port_scan(p_of_success=1)
assert dos_bot.attack_stage is DoSAttackStage.PORT_SCAN

View File

@@ -1,5 +1,6 @@
from ipaddress import IPv4Address
from typing import Tuple, Union
from uuid import uuid4
import pytest
@@ -62,18 +63,26 @@ def test_disconnect_when_client_is_closed(database_client_on_computer):
def test_disconnect(database_client_on_computer):
"""Database client should set connected to False and remove the database server ip address."""
"""Database client should remove the connection."""
database_client, computer = database_client_on_computer
database_client.connected = True
database_client._connections[str(uuid4())] = {"item": True}
assert len(database_client.connections) == 1
assert database_client.operating_state is ApplicationOperatingState.RUNNING
assert database_client.server_ip_address is not None
database_client.disconnect()
assert database_client.connected is False
assert database_client.server_ip_address is None
assert len(database_client.connections) == 0
uuid = str(uuid4())
database_client._connections[uuid] = {"item": True}
assert len(database_client.connections) == 1
database_client.disconnect(connection_id=uuid)
assert len(database_client.connections) == 0
def test_query_when_client_is_closed(database_client_on_computer):
@@ -86,19 +95,6 @@ def test_query_when_client_is_closed(database_client_on_computer):
assert database_client.query(sql="test") is False
def test_query_failed_reattempt(database_client_on_computer):
"""Database client query should return False if the reattempt fails."""
database_client, computer = database_client_on_computer
def return_false():
return False
database_client.connect = return_false
database_client.connected = False
assert database_client.query(sql="test", is_reattempt=True) is False
def test_query_fail_to_connect(database_client_on_computer):
"""Database client query should return False if the connect attempt fails."""
database_client, computer = database_client_on_computer

View File

@@ -1,3 +1,5 @@
from uuid import uuid4
from primaite.simulator.system.services.service import ServiceOperatingState
from primaite.simulator.system.software import SoftwareHealthState
@@ -66,3 +68,34 @@ def test_enable_disable(service):
service.enable()
assert service.operating_state == ServiceOperatingState.STOPPED
def test_overwhelm_service(service):
service.max_sessions = 2
service.start()
uuid = str(uuid4())
assert service.add_connection(connection_id=uuid) # should be true
assert service.health_state_actual is SoftwareHealthState.GOOD
assert not service.add_connection(connection_id=uuid) # fails because connection already exists
assert service.health_state_actual is SoftwareHealthState.GOOD
assert service.add_connection(connection_id=str(uuid4())) # succeed
assert service.health_state_actual is SoftwareHealthState.GOOD
assert not service.add_connection(connection_id=str(uuid4())) # fail because at capacity
assert service.health_state_actual is SoftwareHealthState.OVERWHELMED
def test_create_and_remove_connections(service):
service.start()
uuid = str(uuid4())
assert service.add_connection(connection_id=uuid) # should be true
assert len(service.connections) == 1
assert service.health_state_actual is SoftwareHealthState.GOOD
assert service.remove_connection(connection_id=uuid) # should be true
assert len(service.connections) == 0
assert service.health_state_actual is SoftwareHealthState.GOOD