Merge remote-tracking branch 'origin/dev' into dev-v3.0.0b6
This commit is contained in:
@@ -18,36 +18,43 @@ parameters:
|
||||
py: '3.8'
|
||||
img: 'ubuntu-latest'
|
||||
every_time: false
|
||||
publish_coverage: false
|
||||
- job_name: 'UbuntuPython310'
|
||||
py: '3.10'
|
||||
img: 'ubuntu-latest'
|
||||
every_time: true
|
||||
publish_coverage: true
|
||||
- job_name: 'WindowsPython38'
|
||||
py: '3.8'
|
||||
img: 'windows-latest'
|
||||
every_time: false
|
||||
publish_coverage: false
|
||||
- job_name: 'WindowsPython310'
|
||||
py: '3.10'
|
||||
img: 'windows-latest'
|
||||
every_time: false
|
||||
publish_coverage: false
|
||||
- job_name: 'MacOSPython38'
|
||||
py: '3.8'
|
||||
img: 'macOS-latest'
|
||||
every_time: false
|
||||
publish_coverage: false
|
||||
- job_name: 'MacOSPython310'
|
||||
py: '3.10'
|
||||
img: 'macOS-latest'
|
||||
every_time: false
|
||||
publish_coverage: false
|
||||
|
||||
stages:
|
||||
- stage: Test
|
||||
jobs:
|
||||
- ${{ each item in parameters.matrix }}:
|
||||
- job: ${{ item.job_name }}
|
||||
timeoutInMinutes: 90
|
||||
cancelTimeoutInMinutes: 1
|
||||
pool:
|
||||
vmImage: ${{ item.img }}
|
||||
|
||||
condition: or( eq(variables['Build.Reason'], 'PullRequest'), ${{ item.every_time }} )
|
||||
condition: and(succeeded(), or( eq(variables['Build.Reason'], 'PullRequest'), ${{ item.every_time }} ))
|
||||
|
||||
steps:
|
||||
- task: UsePythonVersion@0
|
||||
@@ -109,12 +116,12 @@ stages:
|
||||
|
||||
- publish: $(System.DefaultWorkingDirectory)/htmlcov/
|
||||
# publish the html report - so we can debug the coverage if needed
|
||||
condition: ${{ item.every_time }} # should only be run once
|
||||
condition: ${{ item.publish_coverage }} # should only be run once
|
||||
artifact: coverage_report
|
||||
|
||||
- task: PublishCodeCoverageResults@2
|
||||
# publish the code coverage so it can be viewed in the run coverage page
|
||||
condition: ${{ item.every_time }} # should only be run once
|
||||
condition: ${{ item.publish_coverage }} # should only be run once
|
||||
inputs:
|
||||
codeCoverageTool: Cobertura
|
||||
summaryFileLocation: '$(System.DefaultWorkingDirectory)/**/coverage.xml'
|
||||
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -157,3 +157,4 @@ benchmark/output
|
||||
src/primaite/notebooks/scratch.py
|
||||
sandbox.py
|
||||
sandbox/
|
||||
sandbox.ipynb
|
||||
|
||||
13
CHANGELOG.md
13
CHANGELOG.md
@@ -55,6 +55,19 @@ SessionManager.
|
||||
- FTP Services: `FTPClient` and `FTPServer`
|
||||
- HTTP Services: `WebBrowser` to simulate a web client and `WebServer`
|
||||
- Fixed an issue where the services were still able to run even though the node the service is installed on is turned off
|
||||
- NTP Services: `NTPClient` and `NTPServer`
|
||||
- **RouterNIC Class**: Introduced a new class `RouterNIC`, extending the standard `NIC` functionality. This class is specifically designed for router operations, optimizing the processing and routing of network traffic.
|
||||
- **Custom Layer-3 Processing**: The `RouterNIC` class includes custom handling for network frames, bypassing standard Node NIC's Layer 3 broadcast/unicast checks. This allows for more efficient routing behavior in network scenarios where router-specific frame processing is required.
|
||||
- **Enhanced Frame Reception**: The `receive_frame` method in `RouterNIC` is tailored to handle frames based on Layer 2 (Ethernet) checks, focusing on MAC address-based routing and broadcast frame acceptance.
|
||||
- **Subnet-Wide Broadcasting for Services and Applications**: Implemented the ability for services and applications to conduct broadcasts across an entire IPv4 subnet within the network simulation framework.
|
||||
|
||||
### Changed
|
||||
- Integrated the RouteTable into the Routers frame processing.
|
||||
- Frames are now dropped when their TTL reaches 0
|
||||
- **NIC Functionality Update**: Updated the Network Interface Card (`NIC`) functionality to support Layer 3 (L3) broadcasts.
|
||||
- **Layer 3 Broadcast Handling**: Enhanced the existing `NIC` classes to correctly process and handle Layer 3 broadcasts. This update allows devices using standard NICs to effectively participate in network activities that involve L3 broadcasting.
|
||||
- **Improved Frame Reception Logic**: The `receive_frame` method of the `NIC` class has been updated to include additional checks and handling for L3 broadcasts, ensuring proper frame processing in a wider range of network scenarios.
|
||||
|
||||
|
||||
### Removed
|
||||
- Removed legacy simulation modules: `acl`, `common`, `environment`, `links`, `nodes`, `pol`
|
||||
|
||||
@@ -0,0 +1,54 @@
|
||||
.. only:: comment
|
||||
|
||||
© Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
|
||||
NTP Client Server
|
||||
=================
|
||||
|
||||
NTP Server
|
||||
----------
|
||||
The ``NTPServer`` provides a NTP Server simulation by extending the base Service class.
|
||||
|
||||
NTP Client
|
||||
----------
|
||||
The ``NTPClient`` provides a NTP Client simulation by extending the base Service class.
|
||||
|
||||
Key capabilities
|
||||
^^^^^^^^^^^^^^^^
|
||||
|
||||
- Simulates NTP requests and NTPPacket transfer across a network
|
||||
- Leverages the Service base class for install/uninstall, status tracking, etc.
|
||||
|
||||
Usage
|
||||
^^^^^
|
||||
- Install on a Node via the ``SoftwareManager`` to start the database service.
|
||||
- Service runs on TCP port 123 by default.
|
||||
|
||||
Implementation
|
||||
^^^^^^^^^^^^^^
|
||||
|
||||
- NTP request and responses use a ``NTPPacket`` object
|
||||
- Extends Service class for integration with ``SoftwareManager``.
|
||||
|
||||
NTP Client
|
||||
----------
|
||||
|
||||
The NTPClient provides a client interface for connecting to the ``NTPServer``.
|
||||
|
||||
Key features
|
||||
^^^^^^^^^^^^
|
||||
|
||||
- Connects to the ``NTPServer`` via the ``SoftwareManager``.
|
||||
|
||||
Usage
|
||||
^^^^^
|
||||
|
||||
- Install on a Node via the ``SoftwareManager`` to start the database service.
|
||||
- Service runs on TCP port 123 by default.
|
||||
|
||||
Implementation
|
||||
^^^^^^^^^^^^^^
|
||||
|
||||
- Leverages ``SoftwareManager`` for sending payloads over the network.
|
||||
- Provides easy interface for Nodes to find IP addresses via domain names.
|
||||
- Extends base Service class.
|
||||
@@ -107,23 +107,23 @@ agents:
|
||||
num_files_per_folder: 1
|
||||
num_nics_per_node: 2
|
||||
nodes:
|
||||
- node_ref: domain_controller
|
||||
- node_hostname: domain_controller
|
||||
services:
|
||||
- service_ref: domain_controller_dns_server
|
||||
- node_ref: web_server
|
||||
- service_name: DNSServer
|
||||
- node_hostname: web_server
|
||||
services:
|
||||
- service_ref: web_server_web_service
|
||||
- node_ref: database_server
|
||||
- service_name: web_server_web_service
|
||||
- node_hostname: database_server
|
||||
folders:
|
||||
- folder_name: database
|
||||
files:
|
||||
- file_name: database.db
|
||||
- node_ref: backup_server
|
||||
- node_hostname: backup_server
|
||||
# services:
|
||||
# - service_ref: backup_service
|
||||
- node_ref: security_suite
|
||||
- node_ref: client_1
|
||||
- node_ref: client_2
|
||||
- node_hostname: security_suite
|
||||
- node_hostname: client_1
|
||||
- node_hostname: client_2
|
||||
links:
|
||||
- link_ref: router_1___switch_1
|
||||
- link_ref: router_1___switch_2
|
||||
@@ -138,7 +138,7 @@ agents:
|
||||
acl:
|
||||
options:
|
||||
max_acl_rules: 10
|
||||
router_node_ref: router_1
|
||||
router_hostname: router_1
|
||||
ip_address_order:
|
||||
- node_ref: domain_controller
|
||||
nic_num: 1
|
||||
@@ -528,7 +528,7 @@ agents:
|
||||
- type: DATABASE_FILE_INTEGRITY
|
||||
weight: 0.5
|
||||
options:
|
||||
node_ref: database_server
|
||||
node_hostname: database_server
|
||||
folder_name: database
|
||||
file_name: database.db
|
||||
|
||||
@@ -536,8 +536,8 @@ agents:
|
||||
- type: WEB_SERVER_404_PENALTY
|
||||
weight: 0.5
|
||||
options:
|
||||
node_ref: web_server
|
||||
service_ref: web_server_web_service
|
||||
node_hostname: web_server
|
||||
service_name: web_server_web_service
|
||||
|
||||
|
||||
agent_settings:
|
||||
|
||||
@@ -627,7 +627,7 @@ class ActionManager:
|
||||
max_nics_per_node: int = 8, # allows calculating shape
|
||||
max_acl_rules: int = 10, # allows calculating shape
|
||||
protocols: List[str] = ["TCP", "UDP", "ICMP"], # allow mapping index to protocol
|
||||
ports: List[str] = ["HTTP", "DNS", "ARP", "FTP"], # allow mapping index to port
|
||||
ports: List[str] = ["HTTP", "DNS", "ARP", "FTP", "NTP"], # allow mapping index to port
|
||||
ip_address_list: Optional[List[str]] = None, # to allow us to map an index to an ip address.
|
||||
act_map: Optional[Dict[int, Dict]] = None, # allows restricting set of possible actions
|
||||
) -> None:
|
||||
|
||||
@@ -4,7 +4,7 @@ from typing import Dict, List, Tuple
|
||||
from gymnasium.core import ObsType
|
||||
|
||||
from primaite.game.agent.interface import AbstractScriptedAgent
|
||||
from primaite.simulator.system.services.red_services.data_manipulation_bot import DataManipulationBot
|
||||
from primaite.simulator.system.applications.red_applications.data_manipulation_bot import DataManipulationBot
|
||||
|
||||
|
||||
class DataManipulationAgent(AbstractScriptedAgent):
|
||||
|
||||
@@ -45,6 +45,7 @@ class AgentSettings(BaseModel):
|
||||
start_settings: Optional[AgentStartSettings] = None
|
||||
"Configuration for when an agent begins performing it's actions"
|
||||
flatten_obs: bool = True
|
||||
"Whether to flatten the observation space before passing it to the agent. True by default."
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Optional[Dict]) -> "AgentSettings":
|
||||
@@ -180,7 +181,7 @@ class ProxyAgent(AbstractAgent):
|
||||
reward_function=reward_function,
|
||||
)
|
||||
self.most_recent_action: ActType
|
||||
self.flatten_obs: bool = agent_settings.flatten_obs
|
||||
self.flatten_obs: bool = agent_settings.flatten_obs if agent_settings else False
|
||||
|
||||
def get_action(self, obs: ObsType, reward: float = 0.0) -> Tuple[str, Dict]:
|
||||
"""
|
||||
|
||||
@@ -41,8 +41,7 @@ class AbstractObservation(ABC):
|
||||
def from_config(cls, config: Dict, game: "PrimaiteGame"):
|
||||
"""Create this observation space component form a serialised format.
|
||||
|
||||
The `game` parameter is for a the PrimaiteGame object that spawns this component. During deserialisation,
|
||||
a subclass of this class may need to translate from a 'reference' to a UUID.
|
||||
The `game` parameter is for a the PrimaiteGame object that spawns this component.
|
||||
"""
|
||||
pass
|
||||
|
||||
@@ -54,12 +53,12 @@ class FileObservation(AbstractObservation):
|
||||
"""
|
||||
Initialise file observation.
|
||||
|
||||
:param where: Store information about where in the simulation state dictionary to find the relevatn information.
|
||||
:param where: Store information about where in the simulation state dictionary to find the relevant information.
|
||||
Optional. If None, this corresponds that the file does not exist and the observation will be populated with
|
||||
zeroes.
|
||||
|
||||
A typical location for a file looks like this:
|
||||
['network','nodes',<node_uuid>,'file_system', 'folders',<folder_name>,'files',<file_name>]
|
||||
['network','nodes',<node_hostname>,'file_system', 'folders',<folder_name>,'files',<file_name>]
|
||||
:type where: Optional[List[str]]
|
||||
"""
|
||||
super().__init__()
|
||||
@@ -121,7 +120,7 @@ class ServiceObservation(AbstractObservation):
|
||||
zeroes.
|
||||
|
||||
A typical location for a service looks like this:
|
||||
`['network','nodes',<node_uuid>,'services', <service_uuid>]`
|
||||
`['network','nodes',<node_hostname>,'services', <service_name>]`
|
||||
:type where: Optional[List[str]]
|
||||
"""
|
||||
super().__init__()
|
||||
@@ -166,7 +165,7 @@ class ServiceObservation(AbstractObservation):
|
||||
:return: Constructed service observation
|
||||
:rtype: ServiceObservation
|
||||
"""
|
||||
return cls(where=parent_where + ["services", game.ref_map_services[config["service_ref"]]])
|
||||
return cls(where=parent_where + ["services", config["service_name"]])
|
||||
|
||||
|
||||
class LinkObservation(AbstractObservation):
|
||||
@@ -183,7 +182,7 @@ class LinkObservation(AbstractObservation):
|
||||
zeroes.
|
||||
|
||||
A typical location for a service looks like this:
|
||||
`['network','nodes',<node_uuid>,'servics', <service_uuid>]`
|
||||
`['network','nodes',<node_hostname>,'servics', <service_name>]`
|
||||
:type where: Optional[List[str]]
|
||||
"""
|
||||
super().__init__()
|
||||
@@ -249,7 +248,7 @@ class FolderObservation(AbstractObservation):
|
||||
|
||||
:param where: Where in the simulation state dictionary to find the relevant information for this folder.
|
||||
A typical location for a file looks like this:
|
||||
['network','nodes',<node_uuid>,'file_system', 'folders',<folder_name>]
|
||||
['network','nodes',<node_hostname>,'file_system', 'folders',<folder_name>]
|
||||
:type where: Optional[List[str]]
|
||||
:param max_files: As size of the space must remain static, define max files that can be in this folder
|
||||
, defaults to 5
|
||||
@@ -328,7 +327,7 @@ class FolderObservation(AbstractObservation):
|
||||
:type game: PrimaiteGame
|
||||
:param parent_where: Where in the simulation state dictionary to find the information about this folder's
|
||||
parent node. A typical location for a node ``where`` can be:
|
||||
['network','nodes',<node_uuid>,'file_system']
|
||||
['network','nodes',<node_hostname>,'file_system']
|
||||
:type parent_where: Optional[List[str]]
|
||||
:param num_files_per_folder: How many spaces for files are in this folder observation (to preserve static
|
||||
observation size) , defaults to 2
|
||||
@@ -354,7 +353,7 @@ class NicObservation(AbstractObservation):
|
||||
|
||||
:param where: Where in the simulation state dictionary to find the relevant information for this NIC. A typical
|
||||
example may look like this:
|
||||
['network','nodes',<node_uuid>,'NICs',<nic_uuid>]
|
||||
['network','nodes',<node_hostname>,'NICs',<nic_number>]
|
||||
If None, this denotes that the NIC does not exist and the observation will be populated with zeroes.
|
||||
:type where: Optional[Tuple[str]], optional
|
||||
"""
|
||||
@@ -391,12 +390,12 @@ class NicObservation(AbstractObservation):
|
||||
:param game: Reference to the PrimaiteGame object that spawned this observation.
|
||||
:type game: PrimaiteGame
|
||||
:param parent_where: Where in the simulation state dictionary to find the information about this NIC's parent
|
||||
node. A typical location for a node ``where`` can be: ['network','nodes',<node_uuid>]
|
||||
node. A typical location for a node ``where`` can be: ['network','nodes',<node_hostname>]
|
||||
:type parent_where: Optional[List[str]]
|
||||
:return: Constructed NIC observation
|
||||
:rtype: NicObservation
|
||||
"""
|
||||
return cls(where=parent_where + ["NICs", config["nic_uuid"]])
|
||||
return cls(where=parent_where + ["NICs", config["nic_num"]])
|
||||
|
||||
|
||||
class NodeObservation(AbstractObservation):
|
||||
@@ -419,9 +418,9 @@ class NodeObservation(AbstractObservation):
|
||||
|
||||
:param where: Where in the simulation state dictionary for find relevant information for this observation.
|
||||
A typical location for a node looks like this:
|
||||
['network','nodes',<node_uuid>]. If empty list, a default null observation will be output, defaults to []
|
||||
['network','nodes',<hostname>]. If empty list, a default null observation will be output, defaults to []
|
||||
:type where: List[str], optional
|
||||
:param services: Mapping between position in observation space and service UUID, defaults to {}
|
||||
:param services: Mapping between position in observation space and service name, defaults to {}
|
||||
:type services: Dict[int,str], optional
|
||||
:param max_services: Max number of services that can be presented in observation space for this node
|
||||
, defaults to 2
|
||||
@@ -430,7 +429,7 @@ class NodeObservation(AbstractObservation):
|
||||
:type folders: Dict[int,str], optional
|
||||
:param max_folders: Max number of folders in this node's obs space, defaults to 2
|
||||
:type max_folders: int, optional
|
||||
:param nics: Mapping between position in observation space and NIC UUID, defaults to {}
|
||||
:param nics: Mapping between position in observation space and NIC idx, defaults to {}
|
||||
:type nics: Dict[int,str], optional
|
||||
:param max_nics: Max number of NICS in this node's obs space, defaults to 5
|
||||
:type max_nics: int, optional
|
||||
@@ -548,11 +547,11 @@ class NodeObservation(AbstractObservation):
|
||||
:return: Constructed node observation
|
||||
:rtype: NodeObservation
|
||||
"""
|
||||
node_uuid = game.ref_map_nodes[config["node_ref"]]
|
||||
node_hostname = config["node_hostname"]
|
||||
if parent_where is None:
|
||||
where = ["network", "nodes", node_uuid]
|
||||
where = ["network", "nodes", node_hostname]
|
||||
else:
|
||||
where = parent_where + ["nodes", node_uuid]
|
||||
where = parent_where + ["nodes", node_hostname]
|
||||
|
||||
svc_configs = config.get("services", {})
|
||||
services = [ServiceObservation.from_config(config=c, game=game, parent_where=where) for c in svc_configs]
|
||||
@@ -563,8 +562,8 @@ class NodeObservation(AbstractObservation):
|
||||
)
|
||||
for c in folder_configs
|
||||
]
|
||||
nic_uuids = game.simulation.network.nodes[node_uuid].nics.keys()
|
||||
nic_configs = [{"nic_uuid": n for n in nic_uuids}] if nic_uuids else []
|
||||
# create some configs for the NIC observation in the format {"nic_num":1}, {"nic_num":2}, {"nic_num":3}, etc.
|
||||
nic_configs = [{"nic_num": i for i in range(num_nics_per_node)}]
|
||||
nics = [NicObservation.from_config(config=c, game=game, parent_where=where) for c in nic_configs]
|
||||
logon_status = config.get("logon_status", False)
|
||||
return cls(
|
||||
@@ -605,7 +604,7 @@ class AclObservation(AbstractObservation):
|
||||
:type protocols: list[str]
|
||||
:param where: Where in the simulation state dictionary to find the relevant information for this ACL. A typical
|
||||
example may look like this:
|
||||
['network','nodes',<router_uuid>,'acl','acl']
|
||||
['network','nodes',<router_hostname>,'acl','acl']
|
||||
:type where: Optional[Tuple[str]], optional
|
||||
:param num_rules: , defaults to 10
|
||||
:type num_rules: int, optional
|
||||
@@ -732,12 +731,12 @@ class AclObservation(AbstractObservation):
|
||||
nic_obj = node_obj.ethernet_port[nic_num]
|
||||
node_ip_to_idx[nic_obj.ip_address] = ip_idx + 2
|
||||
|
||||
router_uuid = game.ref_map_nodes[config["router_node_ref"]]
|
||||
router_hostname = config["router_hostname"]
|
||||
return cls(
|
||||
node_ip_to_id=node_ip_to_idx,
|
||||
ports=game.options.ports,
|
||||
protocols=game.options.protocols,
|
||||
where=["network", "nodes", router_uuid, "acl", "acl"],
|
||||
where=["network", "nodes", router_hostname, "acl", "acl"],
|
||||
num_rules=max_acl_rules,
|
||||
)
|
||||
|
||||
@@ -867,6 +866,7 @@ class UC2BlueObservation(AbstractObservation):
|
||||
:rtype: UC2BlueObservation
|
||||
"""
|
||||
node_configs = config["nodes"]
|
||||
|
||||
num_services_per_node = config["num_services_per_node"]
|
||||
num_folders_per_node = config["num_folders_per_node"]
|
||||
num_files_per_folder = config["num_files_per_folder"]
|
||||
|
||||
@@ -82,11 +82,11 @@ class DummyReward(AbstractReward):
|
||||
class DatabaseFileIntegrity(AbstractReward):
|
||||
"""Reward function component which rewards the agent for maintaining the integrity of a database file."""
|
||||
|
||||
def __init__(self, node_uuid: str, folder_name: str, file_name: str) -> None:
|
||||
def __init__(self, node_hostname: str, folder_name: str, file_name: str) -> None:
|
||||
"""Initialise the reward component.
|
||||
|
||||
:param node_uuid: UUID of the node which contains the database file.
|
||||
:type node_uuid: str
|
||||
:param node_hostname: Hostname of the node which contains the database file.
|
||||
:type node_hostname: str
|
||||
:param folder_name: folder which contains the database file.
|
||||
:type folder_name: str
|
||||
:param file_name: name of the database file.
|
||||
@@ -95,7 +95,7 @@ class DatabaseFileIntegrity(AbstractReward):
|
||||
self.location_in_state = [
|
||||
"network",
|
||||
"nodes",
|
||||
node_uuid,
|
||||
node_hostname,
|
||||
"file_system",
|
||||
"folders",
|
||||
folder_name,
|
||||
@@ -136,49 +136,29 @@ class DatabaseFileIntegrity(AbstractReward):
|
||||
:return: The reward component.
|
||||
:rtype: DatabaseFileIntegrity
|
||||
"""
|
||||
node_ref = config.get("node_ref")
|
||||
node_hostname = config.get("node_hostname")
|
||||
folder_name = config.get("folder_name")
|
||||
file_name = config.get("file_name")
|
||||
if not node_ref:
|
||||
_LOGGER.error(
|
||||
f"{cls.__name__} could not be initialised from config because node_ref parameter was not specified"
|
||||
)
|
||||
return DummyReward() # TODO: better error handling
|
||||
if not folder_name:
|
||||
_LOGGER.error(
|
||||
f"{cls.__name__} could not be initialised from config because folder_name parameter was not specified"
|
||||
)
|
||||
return DummyReward() # TODO: better error handling
|
||||
if not file_name:
|
||||
_LOGGER.error(
|
||||
f"{cls.__name__} could not be initialised from config because file_name parameter was not specified"
|
||||
)
|
||||
return DummyReward() # TODO: better error handling
|
||||
node_uuid = game.ref_map_nodes[node_ref]
|
||||
if not node_uuid:
|
||||
_LOGGER.error(
|
||||
(
|
||||
f"{cls.__name__} could not be initialised from config because the referenced node could not be "
|
||||
f"found in the simulation"
|
||||
)
|
||||
)
|
||||
return DummyReward() # TODO: better error handling
|
||||
if not (node_hostname and folder_name and file_name):
|
||||
msg = f"{cls.__name__} could not be initialised with parameters {config}"
|
||||
_LOGGER.error(msg)
|
||||
raise ValueError(msg)
|
||||
|
||||
return cls(node_uuid=node_uuid, folder_name=folder_name, file_name=file_name)
|
||||
return cls(node_hostname=node_hostname, folder_name=folder_name, file_name=file_name)
|
||||
|
||||
|
||||
class WebServer404Penalty(AbstractReward):
|
||||
"""Reward function component which penalises the agent when the web server returns a 404 error."""
|
||||
|
||||
def __init__(self, node_uuid: str, service_uuid: str) -> None:
|
||||
def __init__(self, node_hostname: str, service_name: str) -> None:
|
||||
"""Initialise the reward component.
|
||||
|
||||
:param node_uuid: UUID of the node which contains the web server service.
|
||||
:type node_uuid: str
|
||||
:param service_uuid: UUID of the web server service.
|
||||
:type service_uuid: str
|
||||
:param node_hostname: Hostname of the node which contains the web server service.
|
||||
:type node_hostname: str
|
||||
:param service_name: Name of the web server service.
|
||||
:type service_name: str
|
||||
"""
|
||||
self.location_in_state = ["network", "nodes", node_uuid, "services", service_uuid]
|
||||
self.location_in_state = ["network", "nodes", node_hostname, "services", service_name]
|
||||
|
||||
def calculate(self, state: Dict) -> float:
|
||||
"""Calculate the reward for the current state.
|
||||
@@ -209,26 +189,17 @@ class WebServer404Penalty(AbstractReward):
|
||||
:return: The reward component.
|
||||
:rtype: WebServer404Penalty
|
||||
"""
|
||||
node_ref = config.get("node_ref")
|
||||
service_ref = config.get("service_ref")
|
||||
if not (node_ref and service_ref):
|
||||
node_hostname = config.get("node_hostname")
|
||||
service_name = config.get("service_name")
|
||||
if not (node_hostname and service_name):
|
||||
msg = (
|
||||
f"{cls.__name__} could not be initialised from config because node_ref and service_ref were not "
|
||||
"found in reward config."
|
||||
)
|
||||
_LOGGER.warning(msg)
|
||||
return DummyReward() # TODO: should we error out with incorrect inputs? Probably!
|
||||
node_uuid = game.ref_map_nodes[node_ref]
|
||||
service_uuid = game.ref_map_services[service_ref]
|
||||
if not (node_uuid and service_uuid):
|
||||
msg = (
|
||||
f"{cls.__name__} could not be initialised because node {node_ref} and service {service_ref} were not"
|
||||
" found in the simulator."
|
||||
)
|
||||
_LOGGER.warning(msg)
|
||||
return DummyReward() # TODO: consider erroring here as well
|
||||
raise ValueError(msg)
|
||||
|
||||
return cls(node_uuid=node_uuid, service_uuid=service_uuid)
|
||||
return cls(node_hostname=node_hostname, service_name=service_name)
|
||||
|
||||
|
||||
class RewardFunction:
|
||||
|
||||
@@ -18,13 +18,15 @@ from primaite.simulator.network.hardware.nodes.server import Server
|
||||
from primaite.simulator.network.hardware.nodes.switch import Switch
|
||||
from primaite.simulator.sim_container import Simulation
|
||||
from primaite.simulator.system.applications.database_client import DatabaseClient
|
||||
from primaite.simulator.system.applications.red_applications.data_manipulation_bot import DataManipulationBot
|
||||
from primaite.simulator.system.applications.web_browser import WebBrowser
|
||||
from primaite.simulator.system.services.database.database_service import DatabaseService
|
||||
from primaite.simulator.system.services.dns.dns_client import DNSClient
|
||||
from primaite.simulator.system.services.dns.dns_server import DNSServer
|
||||
from primaite.simulator.system.services.ftp.ftp_client import FTPClient
|
||||
from primaite.simulator.system.services.ftp.ftp_server import FTPServer
|
||||
from primaite.simulator.system.services.red_services.data_manipulation_bot import DataManipulationBot
|
||||
from primaite.simulator.system.services.ntp.ntp_client import NTPClient
|
||||
from primaite.simulator.system.services.ntp.ntp_server import NTPServer
|
||||
from primaite.simulator.system.services.web_server.web_server import WebServer
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
@@ -244,6 +246,8 @@ class PrimaiteGame:
|
||||
"WebServer": WebServer,
|
||||
"FTPClient": FTPClient,
|
||||
"FTPServer": FTPServer,
|
||||
"NTPClient": NTPClient,
|
||||
"NTPServer": NTPServer,
|
||||
}
|
||||
if service_type in service_types_mapping:
|
||||
_LOGGER.debug(f"installing {service_type} on node {new_node.hostname}")
|
||||
@@ -292,6 +296,7 @@ class PrimaiteGame:
|
||||
opt = application_cfg["options"]
|
||||
new_application.configure(
|
||||
server_ip_address=IPv4Address(opt.get("server_ip")),
|
||||
server_password=opt.get("server_password"),
|
||||
payload=opt.get("payload"),
|
||||
port_scan_p_of_success=float(opt.get("port_scan_p_of_success", "0.1")),
|
||||
data_manipulation_p_of_success=float(opt.get("data_manipulation_p_of_success", "0.1")),
|
||||
|
||||
@@ -102,7 +102,7 @@ class DomainController(SimComponent):
|
||||
:rtype: Dict
|
||||
"""
|
||||
state = super().describe_state()
|
||||
state.update({"accounts": {uuid: acct.describe_state() for uuid, acct in self.accounts.items()}})
|
||||
state.update({"accounts": {acct.username: acct.describe_state() for acct in self.accounts.values()}})
|
||||
return state
|
||||
|
||||
def _register_account(self, account: Account) -> None:
|
||||
|
||||
@@ -199,10 +199,24 @@ class Network(SimComponent):
|
||||
state = super().describe_state()
|
||||
state.update(
|
||||
{
|
||||
"nodes": {uuid: node.describe_state() for uuid, node in self.nodes.items()},
|
||||
"links": {uuid: link.describe_state() for uuid, link in self.links.items()},
|
||||
"nodes": {node.hostname: node.describe_state() for node in self.nodes.values()},
|
||||
"links": {},
|
||||
}
|
||||
)
|
||||
# Update the links one-by-one. The key is a 4-tuple of `hostname_a, port_a, hostname_b, port_b`
|
||||
for uuid, link in self.links.items():
|
||||
node_a = link.endpoint_a._connected_node
|
||||
node_b = link.endpoint_b._connected_node
|
||||
hostname_a = node_a.hostname if node_a else None
|
||||
hostname_b = node_b.hostname if node_b else None
|
||||
port_a = link.endpoint_a._port_num_on_node
|
||||
port_b = link.endpoint_b._port_num_on_node
|
||||
state["links"][uuid] = link.describe_state()
|
||||
state["links"][uuid]["hostname_a"] = hostname_a
|
||||
state["links"][uuid]["hostname_b"] = hostname_b
|
||||
state["links"][uuid]["port_a"] = port_a
|
||||
state["links"][uuid]["port_b"] = port_b
|
||||
|
||||
return state
|
||||
|
||||
def add_node(self, node: Node) -> None:
|
||||
|
||||
148
src/primaite/simulator/network/creation.py
Normal file
148
src/primaite/simulator/network/creation.py
Normal file
@@ -0,0 +1,148 @@
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Optional
|
||||
|
||||
from primaite.simulator.network.container import Network
|
||||
from primaite.simulator.network.hardware.nodes.computer import Computer
|
||||
from primaite.simulator.network.hardware.nodes.router import ACLAction, Router
|
||||
from primaite.simulator.network.hardware.nodes.switch import Switch
|
||||
from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
|
||||
|
||||
def num_of_switches_required(num_nodes: int, max_switch_ports: int = 24) -> int:
|
||||
"""
|
||||
Calculate the minimum number of network switches required to connect a given number of nodes.
|
||||
|
||||
Each switch is assumed to have one port reserved for connecting to a router, reducing the effective
|
||||
number of ports available for PCs. The function calculates the total number of switches needed
|
||||
to accommodate all nodes under this constraint.
|
||||
|
||||
:param num_nodes: The total number of nodes that need to be connected in the network.
|
||||
:param max_switch_ports: The maximum number of ports available on each switch. Defaults to 24.
|
||||
|
||||
:return: The minimum number of switches required to connect all PCs.
|
||||
|
||||
Example:
|
||||
>>> num_of_switches_required(5)
|
||||
1
|
||||
>>> num_of_switches_required(24,24)
|
||||
2
|
||||
>>> num_of_switches_required(48,24)
|
||||
3
|
||||
>>> num_of_switches_required(25,10)
|
||||
3
|
||||
"""
|
||||
# Reduce the effective number of switch ports by 1 to leave space for the router
|
||||
effective_switch_ports = max_switch_ports - 1
|
||||
|
||||
# Calculate the number of fully utilised switches and any additional switch for remaining PCs
|
||||
full_switches = num_nodes // effective_switch_ports
|
||||
extra_pcs = num_nodes % effective_switch_ports
|
||||
|
||||
# Return the total number of switches required
|
||||
return full_switches + (1 if extra_pcs > 0 else 0)
|
||||
|
||||
|
||||
def create_office_lan(
|
||||
lan_name: str,
|
||||
subnet_base: int,
|
||||
pcs_ip_block_start: int,
|
||||
num_pcs: int,
|
||||
network: Optional[Network] = None,
|
||||
include_router: bool = True,
|
||||
) -> Network:
|
||||
"""
|
||||
Creates a 2-Tier or 3-Tier office local area network (LAN).
|
||||
|
||||
The LAN is configured with a specified number of personal computers (PCs), optionally including a router,
|
||||
and multiple edge switches to connect them. A core switch is added only if more than one edge switch is required.
|
||||
The network topology involves edge switches connected either directly to the router in a 2-Tier setup or
|
||||
to a core switch in a 3-Tier setup. If a router is included, it is connected to the core switch (if present)
|
||||
and configured with basic access control list (ACL) rules. PCs are distributed across the edge switches.
|
||||
|
||||
|
||||
:param str lan_name: The name to be assigned to the LAN.
|
||||
:param int subnet_base: The subnet base number to be used in the IP addresses.
|
||||
:param int pcs_ip_block_start: The starting block for assigning IP addresses to PCs.
|
||||
:param int num_pcs: The number of PCs to be added to the LAN.
|
||||
:param Optional[Network] network: The network to which the LAN components will be added. If None, a new network is
|
||||
created.
|
||||
:param bool include_router: Flag to determine if a router should be included in the LAN. Defaults to True.
|
||||
:return: The network object with the LAN components added.
|
||||
:raises ValueError: If pcs_ip_block_start is less than or equal to the number of required switches.
|
||||
"""
|
||||
# Initialise the network if not provided
|
||||
if not network:
|
||||
network = Network()
|
||||
|
||||
# Calculate the required number of switches
|
||||
num_of_switches = num_of_switches_required(num_nodes=num_pcs)
|
||||
effective_switch_ports = 23 # One port less for router connection
|
||||
if pcs_ip_block_start <= num_of_switches:
|
||||
raise ValueError(f"pcs_ip_block_start must be greater than the number of required switches {num_of_switches}")
|
||||
|
||||
# Create a core switch if more than one edge switch is needed
|
||||
if num_of_switches > 1:
|
||||
core_switch = Switch(hostname=f"switch_core_{lan_name}", start_up_duration=0)
|
||||
core_switch.power_on()
|
||||
network.add_node(core_switch)
|
||||
core_switch_port = 1
|
||||
|
||||
# Initialise the default gateway to None
|
||||
default_gateway = None
|
||||
|
||||
# Optionally include a router in the LAN
|
||||
if include_router:
|
||||
default_gateway = IPv4Address(f"192.168.{subnet_base}.1")
|
||||
router = Router(hostname=f"router_{lan_name}", start_up_duration=0)
|
||||
router.power_on()
|
||||
router.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.ARP, dst_port=Port.ARP, position=22)
|
||||
router.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol.ICMP, position=23)
|
||||
network.add_node(router)
|
||||
router.configure_port(port=1, ip_address=default_gateway, subnet_mask="255.255.255.0")
|
||||
router.enable_port(1)
|
||||
|
||||
# Initialise the first edge switch and connect to the router or core switch
|
||||
switch_port = 0
|
||||
switch_n = 1
|
||||
switch = Switch(hostname=f"switch_edge_{switch_n}_{lan_name}", start_up_duration=0)
|
||||
switch.power_on()
|
||||
network.add_node(switch)
|
||||
if num_of_switches > 1:
|
||||
network.connect(core_switch.switch_ports[core_switch_port], switch.switch_ports[24])
|
||||
else:
|
||||
network.connect(router.ethernet_ports[1], switch.switch_ports[24])
|
||||
|
||||
# Add PCs to the LAN and connect them to switches
|
||||
for i in range(1, num_pcs + 1):
|
||||
# Add a new edge switch if the current one is full
|
||||
if switch_port == effective_switch_ports:
|
||||
switch_n += 1
|
||||
switch_port = 0
|
||||
switch = Switch(hostname=f"switch_edge_{switch_n}_{lan_name}", start_up_duration=0)
|
||||
switch.power_on()
|
||||
network.add_node(switch)
|
||||
# Connect the new switch to the router or core switch
|
||||
if num_of_switches > 1:
|
||||
core_switch_port += 1
|
||||
network.connect(core_switch.switch_ports[core_switch_port], switch.switch_ports[24])
|
||||
else:
|
||||
network.connect(router.ethernet_ports[1], switch.switch_ports[24])
|
||||
|
||||
# Create and add a PC to the network
|
||||
pc = Computer(
|
||||
hostname=f"pc_{i}_{lan_name}",
|
||||
ip_address=f"192.168.{subnet_base}.{i+pcs_ip_block_start-1}",
|
||||
subnet_mask="255.255.255.0",
|
||||
default_gateway=default_gateway,
|
||||
start_up_duration=0,
|
||||
)
|
||||
pc.power_on()
|
||||
network.add_node(pc)
|
||||
|
||||
# Connect the PC to the switch
|
||||
switch_port += 1
|
||||
network.connect(switch.switch_ports[switch_port], pc.ethernet_port[1])
|
||||
switch.switch_ports[switch_port].enable()
|
||||
|
||||
return network
|
||||
@@ -4,7 +4,7 @@ import re
|
||||
import secrets
|
||||
from ipaddress import IPv4Address, IPv4Network
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Literal, Optional, Tuple, Union
|
||||
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
|
||||
|
||||
from prettytable import MARKDOWN, PrettyTable
|
||||
|
||||
@@ -91,6 +91,8 @@ class NIC(SimComponent):
|
||||
"Indicates if the NIC supports Wake-on-LAN functionality."
|
||||
_connected_node: Optional[Node] = None
|
||||
"The Node to which the NIC is connected."
|
||||
_port_num_on_node: Optional[int] = None
|
||||
"Which port number is assigned on this NIC"
|
||||
_connected_link: Optional[Link] = None
|
||||
"The Link to which the NIC is connected."
|
||||
enabled: bool = False
|
||||
@@ -148,7 +150,7 @@ class NIC(SimComponent):
|
||||
state = super().describe_state()
|
||||
state.update(
|
||||
{
|
||||
"ip_adress": str(self.ip_address),
|
||||
"ip_address": str(self.ip_address),
|
||||
"subnet_mask": str(self.subnet_mask),
|
||||
"mac_address": self.mac_address,
|
||||
"speed": self.speed,
|
||||
@@ -272,18 +274,40 @@ class NIC(SimComponent):
|
||||
|
||||
def receive_frame(self, frame: Frame) -> bool:
|
||||
"""
|
||||
Receive a network frame from the connected link if the NIC is enabled.
|
||||
Receive a network frame from the connected link, processing it if the NIC is enabled.
|
||||
|
||||
The Frame is passed to the Node.
|
||||
This method decrements the Time To Live (TTL) of the frame, captures it using PCAP (Packet Capture), and checks
|
||||
if the frame is either a broadcast or destined for this NIC. If the frame is acceptable, it is passed to the
|
||||
connected node. The method also handles the discarding of frames with TTL expired and logs this event.
|
||||
|
||||
:param frame: The network frame being received.
|
||||
The frame's reception is based on various conditions:
|
||||
- If the NIC is disabled, the frame is not processed.
|
||||
- If the TTL of the frame reaches zero after decrement, it is discarded and logged.
|
||||
- If the frame is a broadcast or its destination MAC/IP address matches this NIC's, it is accepted.
|
||||
- All other frames are dropped and logged or printed to the console.
|
||||
|
||||
:param frame: The network frame being received. This should be an instance of the Frame class.
|
||||
:return: Returns True if the frame is processed and passed to the node, False otherwise.
|
||||
"""
|
||||
if self.enabled:
|
||||
frame.decrement_ttl()
|
||||
if frame.ip and frame.ip.ttl < 1:
|
||||
self._connected_node.sys_log.info("Frame discarded as TTL limit reached")
|
||||
return False
|
||||
frame.set_received_timestamp()
|
||||
self.pcap.capture(frame)
|
||||
# If this destination or is broadcast
|
||||
if frame.ethernet.dst_mac_addr == self.mac_address or frame.ethernet.dst_mac_addr == "ff:ff:ff:ff:ff:ff":
|
||||
accept_frame = False
|
||||
|
||||
# Check if it's a broadcast:
|
||||
if frame.ethernet.dst_mac_addr == "ff:ff:ff:ff:ff:ff":
|
||||
if frame.ip.dst_ip_address in {self.ip_address, self.ip_network.broadcast_address}:
|
||||
accept_frame = True
|
||||
else:
|
||||
if frame.ethernet.dst_mac_addr == self.mac_address:
|
||||
accept_frame = True
|
||||
|
||||
if accept_frame:
|
||||
self._connected_node.receive_frame(frame=frame, from_nic=self)
|
||||
return True
|
||||
return False
|
||||
@@ -311,6 +335,8 @@ class SwitchPort(SimComponent):
|
||||
"The Maximum Transmission Unit (MTU) of the SwitchPort in Bytes. Default is 1500 B"
|
||||
_connected_node: Optional[Node] = None
|
||||
"The Node to which the SwitchPort is connected."
|
||||
_port_num_on_node: Optional[int] = None
|
||||
"The port num on the connected node."
|
||||
_connected_link: Optional[Link] = None
|
||||
"The Link to which the SwitchPort is connected."
|
||||
enabled: bool = False
|
||||
@@ -432,6 +458,9 @@ class SwitchPort(SimComponent):
|
||||
"""
|
||||
if self.enabled:
|
||||
frame.decrement_ttl()
|
||||
if frame.ip and frame.ip.ttl < 1:
|
||||
self._connected_node.sys_log.info("Frame discarded as TTL limit reached")
|
||||
return False
|
||||
self.pcap.capture(frame)
|
||||
connected_node: Node = self._connected_node
|
||||
connected_node.forward_frame(frame=frame, incoming_port=self)
|
||||
@@ -497,8 +526,8 @@ class Link(SimComponent):
|
||||
state = super().describe_state()
|
||||
state.update(
|
||||
{
|
||||
"endpoint_a": self.endpoint_a.uuid,
|
||||
"endpoint_b": self.endpoint_b.uuid,
|
||||
"endpoint_a": self.endpoint_a.uuid, # TODO: consider if using UUID is the best way to do this
|
||||
"endpoint_b": self.endpoint_b.uuid, # TODO: consider if using UUID is the best way to do this
|
||||
"bandwidth": self.bandwidth,
|
||||
"current_load": self.current_load,
|
||||
}
|
||||
@@ -667,17 +696,30 @@ class ARPCache:
|
||||
"""Clear the entire ARP cache, removing all stored entries."""
|
||||
self.arp.clear()
|
||||
|
||||
def send_arp_request(self, target_ip_address: Union[IPv4Address, str]):
|
||||
def send_arp_request(
|
||||
self, target_ip_address: Union[IPv4Address, str], ignore_networks: Optional[List[IPv4Address]] = None
|
||||
):
|
||||
"""
|
||||
Perform a standard ARP request for a given target IP address.
|
||||
|
||||
Broadcasts the request through all enabled NICs to determine the MAC address corresponding to the target IP
|
||||
address.
|
||||
address. This method can be configured to ignore specific networks when sending out ARP requests,
|
||||
which is useful in environments where certain addresses should not be queried.
|
||||
|
||||
:param target_ip_address: The target IP address to send an ARP request for.
|
||||
:param ignore_networks: An optional list of IPv4 addresses representing networks to be excluded from the ARP
|
||||
request broadcast. Each address in this list indicates a network which will not be queried during the ARP
|
||||
request process. This is particularly useful in complex network environments where traffic should be
|
||||
minimized or controlled to specific subnets. It is mainly used by the router to prevent ARP requests being
|
||||
sent back to their source.
|
||||
"""
|
||||
for nic in self.nics.values():
|
||||
if nic.enabled:
|
||||
use_nic = True
|
||||
if ignore_networks:
|
||||
for ipv4 in ignore_networks:
|
||||
if ipv4 in nic.ip_network:
|
||||
use_nic = False
|
||||
if nic.enabled and use_nic:
|
||||
self.sys_log.info(f"Sending ARP request from NIC {nic} for ip {target_ip_address}")
|
||||
tcp_header = TCPHeader(src_port=Port.ARP, dst_port=Port.ARP)
|
||||
|
||||
@@ -802,7 +844,6 @@ class ICMP:
|
||||
self.arp.send_arp_request(frame.ip.src_ip_address)
|
||||
self.process_icmp(frame=frame, from_nic=from_nic, is_reattempt=True)
|
||||
return
|
||||
tcp_header = TCPHeader(src_port=Port.ARP, dst_port=Port.ARP)
|
||||
|
||||
# Network Layer
|
||||
ip_packet = IPPacket(
|
||||
@@ -817,9 +858,7 @@ class ICMP:
|
||||
sequence=frame.icmp.sequence + 1,
|
||||
)
|
||||
payload = secrets.token_urlsafe(int(32 / 1.3)) # Standard ICMP 32 bytes size
|
||||
frame = Frame(
|
||||
ethernet=ethernet_header, ip=ip_packet, tcp=tcp_header, icmp=icmp_reply_packet, payload=payload
|
||||
)
|
||||
frame = Frame(ethernet=ethernet_header, ip=ip_packet, icmp=icmp_reply_packet, payload=payload)
|
||||
self.sys_log.info(f"Sending echo reply to {frame.ip.dst_ip_address}")
|
||||
|
||||
src_nic.send_frame(frame)
|
||||
@@ -1094,12 +1133,12 @@ class Node(SimComponent):
|
||||
{
|
||||
"hostname": self.hostname,
|
||||
"operating_state": self.operating_state.value,
|
||||
"NICs": {uuid: nic.describe_state() for uuid, nic in self.nics.items()},
|
||||
"NICs": {eth_num: nic.describe_state() for eth_num, nic in self.ethernet_port.items()},
|
||||
# "switch_ports": {uuid, sp for uuid, sp in self.switch_ports.items()},
|
||||
"file_system": self.file_system.describe_state(),
|
||||
"applications": {uuid: app.describe_state() for uuid, app in self.applications.items()},
|
||||
"services": {uuid: svc.describe_state() for uuid, svc in self.services.items()},
|
||||
"process": {uuid: proc.describe_state() for uuid, proc in self.processes.items()},
|
||||
"applications": {app.name: app.describe_state() for app in self.applications.values()},
|
||||
"services": {svc.name: svc.describe_state() for svc in self.services.values()},
|
||||
"process": {proc.name: proc.describe_state() for proc in self.processes.values()},
|
||||
"revealed_to_red": self.revealed_to_red,
|
||||
}
|
||||
)
|
||||
@@ -1316,6 +1355,7 @@ class Node(SimComponent):
|
||||
self.nics[nic.uuid] = nic
|
||||
self.ethernet_port[len(self.nics)] = nic
|
||||
nic._connected_node = self
|
||||
nic._port_num_on_node = len(self.nics)
|
||||
nic.parent = self
|
||||
self.sys_log.info(f"Connected NIC {nic}")
|
||||
if self.operating_state == NodeOperatingState.ON:
|
||||
@@ -1442,7 +1482,6 @@ class Node(SimComponent):
|
||||
service.parent = self
|
||||
service.install() # Perform any additional setup, such as creating files for this service on the node.
|
||||
self.sys_log.info(f"Installed service {service.name}")
|
||||
_LOGGER.info(f"Added service {service.uuid} to node {self.uuid}")
|
||||
self._service_request_manager.add_request(service.uuid, RequestType(func=service._request_manager))
|
||||
|
||||
def uninstall_service(self, service: Service) -> None:
|
||||
@@ -1475,7 +1514,6 @@ class Node(SimComponent):
|
||||
self.applications[application.uuid] = application
|
||||
application.parent = self
|
||||
self.sys_log.info(f"Installed application {application.name}")
|
||||
_LOGGER.info(f"Added application {application.uuid} to node {self.uuid}")
|
||||
self._application_request_manager.add_request(application.uuid, RequestType(func=application._request_manager))
|
||||
|
||||
def uninstall_application(self, application: Application) -> None:
|
||||
|
||||
@@ -357,11 +357,10 @@ class RouteEntry(SimComponent):
|
||||
"""
|
||||
Represents a single entry in a routing table.
|
||||
|
||||
Attributes:
|
||||
address (IPv4Address): The destination IP address or network address.
|
||||
subnet_mask (IPv4Address): The subnet mask for the network.
|
||||
next_hop_ip_address (IPv4Address): The next hop IP address to which packets should be forwarded.
|
||||
metric (int): The cost metric for this route. Default is 0.0.
|
||||
:ivar address: The destination IP address or network address.
|
||||
:ivar subnet_mask: The subnet mask for the network.
|
||||
:ivar next_hop_ip_address: The next hop IP address to which packets should be forwarded.
|
||||
:ivar metric: The cost metric for this route. Default is 0.0.
|
||||
|
||||
Example:
|
||||
>>> entry = RouteEntry(
|
||||
@@ -381,12 +380,6 @@ class RouteEntry(SimComponent):
|
||||
metric: float = 0.0
|
||||
"The cost metric for this route. Default is 0.0."
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
for key in {"address", "subnet_mask", "next_hop_ip_address"}:
|
||||
if not isinstance(kwargs[key], IPv4Address):
|
||||
kwargs[key] = IPv4Address(kwargs[key])
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def set_original_state(self):
|
||||
"""Sets the original state."""
|
||||
vals_to_include = {"address", "subnet_mask", "next_hop_ip_address", "metric"}
|
||||
@@ -421,6 +414,7 @@ class RouteTable(SimComponent):
|
||||
"""
|
||||
|
||||
routes: List[RouteEntry] = []
|
||||
default_route: Optional[RouteEntry] = None
|
||||
sys_log: SysLog
|
||||
|
||||
def set_original_state(self):
|
||||
@@ -465,12 +459,35 @@ class RouteTable(SimComponent):
|
||||
)
|
||||
self.routes.append(route)
|
||||
|
||||
def set_default_route_next_hop_ip_address(self, ip_address: IPv4Address):
|
||||
"""
|
||||
Sets the next-hop IP address for the default route in a routing table.
|
||||
|
||||
This method checks if a default route (0.0.0.0/0) exists in the routing table. If it does not exist,
|
||||
the method creates a new default route with the specified next-hop IP address. If a default route already
|
||||
exists, it updates the next-hop IP address of the existing default route. After setting the next-hop
|
||||
IP address, the method logs this action.
|
||||
|
||||
:param ip_address: The next-hop IP address to be set for the default route.
|
||||
"""
|
||||
if not self.default_route:
|
||||
self.default_route = RouteEntry(
|
||||
ip_address=IPv4Address("0.0.0.0"),
|
||||
subnet_mask=IPv4Address("0.0.0.0"),
|
||||
next_hop_ip_address=ip_address,
|
||||
)
|
||||
else:
|
||||
self.default_route.next_hop_ip_address = ip_address
|
||||
self.sys_log.info(f"Default configured to use {ip_address} as the next-hop")
|
||||
|
||||
def find_best_route(self, destination_ip: Union[str, IPv4Address]) -> Optional[RouteEntry]:
|
||||
"""
|
||||
Find the best route for a given destination IP.
|
||||
|
||||
This method uses the Longest Prefix Match algorithm and considers metrics to find the best route.
|
||||
|
||||
If no dedicated route exists but a default route does, then the default route is returned as a last resort.
|
||||
|
||||
:param destination_ip: The destination IP to find the route for.
|
||||
:return: The best matching RouteEntry, or None if no route matches.
|
||||
"""
|
||||
@@ -490,6 +507,9 @@ class RouteTable(SimComponent):
|
||||
longest_prefix = prefix_len
|
||||
lowest_metric = route.metric
|
||||
|
||||
if not best_route and self.default_route:
|
||||
best_route = self.default_route
|
||||
|
||||
return best_route
|
||||
|
||||
def show(self, markdown: bool = False):
|
||||
@@ -521,12 +541,26 @@ class RouterARPCache(ARPCache):
|
||||
super().__init__(sys_log)
|
||||
self.router: Router = router
|
||||
|
||||
def process_arp_packet(self, from_nic: NIC, frame: Frame):
|
||||
def process_arp_packet(
|
||||
self, from_nic: NIC, frame: Frame, route_table: RouteTable, is_reattempt: bool = False
|
||||
) -> None:
|
||||
"""
|
||||
Overridden method to process a received ARP packet in a router-specific way.
|
||||
Processes a received ARP (Address Resolution Protocol) packet in a router-specific way.
|
||||
|
||||
This method is responsible for handling both ARP requests and responses. It processes ARP packets received on a
|
||||
Network Interface Card (NIC) and performs actions based on whether the packet is a request or a reply. This
|
||||
includes updating the ARP cache, forwarding ARP replies, sending ARP requests for unknown destinations, and
|
||||
handling packet TTL (Time To Live).
|
||||
|
||||
The method first checks if the ARP packet is a request or a reply. For ARP replies, it updates the ARP cache
|
||||
and forwards the reply if necessary. For ARP requests, it checks if the target IP matches one of the router's
|
||||
NICs and sends an ARP reply if so. If the destination is not directly connected, it consults the routing table
|
||||
to find the best route and reattempts ARP request processing if needed.
|
||||
|
||||
:param from_nic: The NIC that received the ARP packet.
|
||||
:param frame: The original ARP frame.
|
||||
:param frame: The frame containing the ARP packet.
|
||||
:param route_table: The routing table of the router.
|
||||
:param is_reattempt: Flag to indicate if this is a reattempt of processing the ARP packet, defaults to False.
|
||||
"""
|
||||
arp_packet = frame.arp
|
||||
|
||||
@@ -554,7 +588,11 @@ class RouterARPCache(ARPCache):
|
||||
)
|
||||
arp_packet.sender_mac_addr = nic.mac_address
|
||||
frame.decrement_ttl()
|
||||
if frame.ip and frame.ip.ttl < 1:
|
||||
self.sys_log.info("Frame discarded as TTL limit reached")
|
||||
return
|
||||
nic.send_frame(frame)
|
||||
return
|
||||
|
||||
# ARP Request
|
||||
self.sys_log.info(
|
||||
@@ -565,16 +603,32 @@ class RouterARPCache(ARPCache):
|
||||
self.add_arp_cache_entry(
|
||||
ip_address=arp_packet.sender_ip_address, mac_address=arp_packet.sender_mac_addr, nic=from_nic
|
||||
)
|
||||
arp_packet = arp_packet.generate_reply(from_nic.mac_address)
|
||||
self.send_arp_reply(arp_packet, from_nic)
|
||||
|
||||
# If the target IP matches one of the router's NICs
|
||||
for nic in self.nics.values():
|
||||
if nic.enabled and nic.ip_address == arp_packet.target_ip_address:
|
||||
if arp_packet.target_ip_address in nic.ip_network:
|
||||
# if nic.enabled and nic.ip_address == arp_packet.target_ip_address:
|
||||
arp_reply = arp_packet.generate_reply(from_nic.mac_address)
|
||||
self.send_arp_reply(arp_reply, from_nic)
|
||||
return
|
||||
|
||||
# Check Route Table
|
||||
route = route_table.find_best_route(arp_packet.target_ip_address)
|
||||
if route:
|
||||
nic = self.get_arp_cache_nic(route.next_hop_ip_address)
|
||||
|
||||
if not nic:
|
||||
if not is_reattempt:
|
||||
self.send_arp_request(route.next_hop_ip_address, ignore_networks=[frame.ip.src_ip_address])
|
||||
return self.process_arp_packet(from_nic, frame, route_table, is_reattempt=True)
|
||||
else:
|
||||
self.sys_log.info("Ignoring ARP request as destination unavailable/No ARP entry found")
|
||||
return
|
||||
else:
|
||||
arp_reply = arp_packet.generate_reply(from_nic.mac_address)
|
||||
self.send_arp_reply(arp_reply, from_nic)
|
||||
return
|
||||
|
||||
|
||||
class RouterICMP(ICMP):
|
||||
"""
|
||||
@@ -645,7 +699,7 @@ class RouterICMP(ICMP):
|
||||
return
|
||||
|
||||
# Route the frame
|
||||
self.router.route_frame(frame, from_nic)
|
||||
self.router.process_frame(frame, from_nic)
|
||||
|
||||
elif frame.icmp.icmp_type == ICMPType.ECHO_REPLY:
|
||||
for nic in self.router.nics.values():
|
||||
@@ -665,7 +719,48 @@ class RouterICMP(ICMP):
|
||||
|
||||
return
|
||||
# Route the frame
|
||||
self.router.route_frame(frame, from_nic)
|
||||
self.router.process_frame(frame, from_nic)
|
||||
|
||||
|
||||
class RouterNIC(NIC):
|
||||
"""
|
||||
A Router-specific Network Interface Card (NIC) that extends the standard NIC functionality.
|
||||
|
||||
This class overrides the standard Node NIC's Layer 3 (L3) broadcast/unicast checks. It is designed
|
||||
to handle network frames in a manner specific to routers, allowing them to efficiently process
|
||||
and route network traffic.
|
||||
"""
|
||||
|
||||
def receive_frame(self, frame: Frame) -> bool:
|
||||
"""
|
||||
Receive and process a network frame from the connected link, provided the NIC is enabled.
|
||||
|
||||
This method is tailored for router behavior. It decrements the frame's Time To Live (TTL), checks for TTL
|
||||
expiration, and captures the frame using PCAP (Packet Capture). The frame is accepted if it is destined for
|
||||
this NIC's MAC address or is a broadcast frame.
|
||||
|
||||
Key Differences from Standard NIC:
|
||||
- Does not perform Layer 3 (IP-based) broadcast checks.
|
||||
- Only checks for Layer 2 (Ethernet) destination MAC address and broadcast frames.
|
||||
|
||||
:param frame: The network frame being received. This should be an instance of the Frame class.
|
||||
:return: Returns True if the frame is processed and passed to the connected node, False otherwise.
|
||||
"""
|
||||
if self.enabled:
|
||||
frame.decrement_ttl()
|
||||
if frame.ip and frame.ip.ttl < 1:
|
||||
self._connected_node.sys_log.info("Frame discarded as TTL limit reached")
|
||||
return False
|
||||
frame.set_received_timestamp()
|
||||
self.pcap.capture(frame)
|
||||
# If this destination or is broadcast
|
||||
if frame.ethernet.dst_mac_addr == self.mac_address or frame.ethernet.dst_mac_addr == "ff:ff:ff:ff:ff:ff":
|
||||
self._connected_node.receive_frame(frame=frame, from_nic=self)
|
||||
return True
|
||||
return False
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"{self.mac_address}/{self.ip_address}"
|
||||
|
||||
|
||||
class Router(Node):
|
||||
@@ -678,7 +773,7 @@ class Router(Node):
|
||||
"""
|
||||
|
||||
num_ports: int
|
||||
ethernet_ports: Dict[int, NIC] = {}
|
||||
ethernet_ports: Dict[int, RouterNIC] = {}
|
||||
acl: AccessControlList
|
||||
route_table: RouteTable
|
||||
arp: RouterARPCache
|
||||
@@ -697,7 +792,7 @@ class Router(Node):
|
||||
kwargs["icmp"] = RouterICMP(sys_log=kwargs.get("sys_log"), arp_cache=kwargs.get("arp"), router=self)
|
||||
super().__init__(hostname=hostname, num_ports=num_ports, **kwargs)
|
||||
for i in range(1, self.num_ports + 1):
|
||||
nic = NIC(ip_address="127.0.0.1", subnet_mask="255.0.0.0", gateway="0.0.0.0")
|
||||
nic = RouterNIC(ip_address="127.0.0.1", subnet_mask="255.0.0.0", gateway="0.0.0.0")
|
||||
self.connect_nic(nic)
|
||||
self.ethernet_ports[i] = nic
|
||||
|
||||
@@ -752,9 +847,9 @@ class Router(Node):
|
||||
state["acl"] = self.acl.describe_state()
|
||||
return state
|
||||
|
||||
def route_frame(self, frame: Frame, from_nic: NIC, re_attempt: bool = False) -> None:
|
||||
def process_frame(self, frame: Frame, from_nic: NIC, re_attempt: bool = False) -> None:
|
||||
"""
|
||||
Route a given frame from a source NIC to its destination.
|
||||
Process a Frame.
|
||||
|
||||
:param frame: The frame to be routed.
|
||||
:param from_nic: The source network interface.
|
||||
@@ -769,25 +864,57 @@ class Router(Node):
|
||||
return
|
||||
|
||||
if not nic:
|
||||
self.arp.send_arp_request(frame.ip.dst_ip_address)
|
||||
return self.route_frame(frame=frame, from_nic=from_nic, re_attempt=True)
|
||||
self.arp.send_arp_request(
|
||||
frame.ip.dst_ip_address, ignore_networks=[frame.ip.src_ip_address, from_nic.ip_address]
|
||||
)
|
||||
return self.process_frame(frame=frame, from_nic=from_nic, re_attempt=True)
|
||||
|
||||
if not nic.enabled:
|
||||
# TODO: Add sys_log here
|
||||
self.sys_log.info(f"Frame dropped as NIC {nic} is not enabled")
|
||||
return
|
||||
|
||||
if frame.ip.dst_ip_address in nic.ip_network:
|
||||
from_port = self._get_port_of_nic(from_nic)
|
||||
to_port = self._get_port_of_nic(nic)
|
||||
self.sys_log.info(f"Routing frame to internally from port {from_port} to port {to_port}")
|
||||
self.sys_log.info(f"Forwarding frame to internally from port {from_port} to port {to_port}")
|
||||
frame.decrement_ttl()
|
||||
if frame.ip and frame.ip.ttl < 1:
|
||||
self.sys_log.info("Frame discarded as TTL limit reached")
|
||||
return
|
||||
frame.ethernet.src_mac_addr = nic.mac_address
|
||||
frame.ethernet.dst_mac_addr = target_mac
|
||||
nic.send_frame(frame)
|
||||
return
|
||||
else:
|
||||
pass
|
||||
# TODO: Deal with routing from route tables
|
||||
self._route_frame(frame, from_nic)
|
||||
|
||||
def _route_frame(self, frame: Frame, from_nic: NIC, re_attempt: bool = False) -> None:
|
||||
route = self.route_table.find_best_route(frame.ip.dst_ip_address)
|
||||
if route:
|
||||
nic = self.arp.get_arp_cache_nic(route.next_hop_ip_address)
|
||||
target_mac = self.arp.get_arp_cache_mac_address(route.next_hop_ip_address)
|
||||
if re_attempt and not nic:
|
||||
self.sys_log.info(f"Destination {frame.ip.dst_ip_address} is unreachable")
|
||||
return
|
||||
|
||||
if not nic:
|
||||
self.arp.send_arp_request(frame.ip.dst_ip_address, ignore_networks=[frame.ip.src_ip_address])
|
||||
return self.process_frame(frame=frame, from_nic=from_nic, re_attempt=True)
|
||||
|
||||
if not nic.enabled:
|
||||
self.sys_log.info(f"Frame dropped as NIC {nic} is not enabled")
|
||||
return
|
||||
|
||||
from_port = self._get_port_of_nic(from_nic)
|
||||
to_port = self._get_port_of_nic(nic)
|
||||
self.sys_log.info(f"Routing frame to internally from port {from_port} to port {to_port}")
|
||||
frame.decrement_ttl()
|
||||
if frame.ip and frame.ip.ttl < 1:
|
||||
self.sys_log.info("Frame discarded as TTL limit reached")
|
||||
return
|
||||
frame.ethernet.src_mac_addr = nic.mac_address
|
||||
frame.ethernet.dst_mac_addr = target_mac
|
||||
nic.send_frame(frame)
|
||||
|
||||
def receive_frame(self, frame: Frame, from_nic: NIC):
|
||||
"""
|
||||
@@ -796,7 +923,7 @@ class Router(Node):
|
||||
:param frame: The incoming frame.
|
||||
:param from_nic: The network interface where the frame is coming from.
|
||||
"""
|
||||
route_frame = False
|
||||
process_frame = False
|
||||
protocol = frame.ip.protocol
|
||||
src_ip_address = frame.ip.src_ip_address
|
||||
dst_ip_address = frame.ip.dst_ip_address
|
||||
@@ -828,12 +955,12 @@ class Router(Node):
|
||||
self.icmp.process_icmp(frame=frame, from_nic=from_nic)
|
||||
else:
|
||||
if src_port == Port.ARP:
|
||||
self.arp.process_arp_packet(from_nic=from_nic, frame=frame)
|
||||
self.arp.process_arp_packet(from_nic=from_nic, frame=frame, route_table=self.route_table)
|
||||
else:
|
||||
# All other traffic
|
||||
route_frame = True
|
||||
if route_frame:
|
||||
self.route_frame(frame, from_nic)
|
||||
process_frame = True
|
||||
if process_frame:
|
||||
self.process_frame(frame, from_nic)
|
||||
|
||||
def configure_port(self, port: int, ip_address: Union[IPv4Address, str], subnet_mask: Union[IPv4Address, str]):
|
||||
"""
|
||||
|
||||
@@ -30,6 +30,7 @@ class Switch(Node):
|
||||
self.switch_ports = {i: SwitchPort() for i in range(1, self.num_ports + 1)}
|
||||
for port_num, port in self.switch_ports.items():
|
||||
port._connected_node = self
|
||||
port._port_num_on_node = port_num
|
||||
port.parent = self
|
||||
port.port_num = port_num
|
||||
|
||||
@@ -89,12 +90,12 @@ class Switch(Node):
|
||||
self._add_mac_table_entry(src_mac, incoming_port)
|
||||
|
||||
outgoing_port = self.mac_address_table.get(dst_mac)
|
||||
if outgoing_port or dst_mac != "ff:ff:ff:ff:ff:ff":
|
||||
if outgoing_port and dst_mac.lower() != "ff:ff:ff:ff:ff:ff":
|
||||
outgoing_port.send_frame(frame)
|
||||
else:
|
||||
# If the destination MAC is not in the table, flood to all ports except incoming
|
||||
for port in self.switch_ports.values():
|
||||
if port != incoming_port:
|
||||
if port.enabled and port != incoming_port:
|
||||
port.send_frame(frame)
|
||||
|
||||
def disconnect_link_from_port(self, link: Link, port_number: int):
|
||||
|
||||
@@ -9,10 +9,10 @@ from primaite.simulator.network.hardware.nodes.switch import Switch
|
||||
from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.system.applications.database_client import DatabaseClient
|
||||
from primaite.simulator.system.applications.red_applications.data_manipulation_bot import DataManipulationBot
|
||||
from primaite.simulator.system.services.database.database_service import DatabaseService
|
||||
from primaite.simulator.system.services.dns.dns_server import DNSServer
|
||||
from primaite.simulator.system.services.ftp.ftp_server import FTPServer
|
||||
from primaite.simulator.system.services.red_services.data_manipulation_bot import DataManipulationBot
|
||||
from primaite.simulator.system.services.web_server.web_server import WebServer
|
||||
|
||||
|
||||
@@ -252,9 +252,9 @@ def arcd_uc2_network() -> Network:
|
||||
database_service: DatabaseService = database_server.software_manager.software.get("DatabaseService") # noqa
|
||||
database_service.start()
|
||||
database_service.configure_backup(backup_server=IPv4Address("192.168.1.16"))
|
||||
database_service._process_sql(ddl, None) # noqa
|
||||
database_service._process_sql(ddl, None, None) # noqa
|
||||
for insert_statement in user_insert_statements:
|
||||
database_service._process_sql(insert_statement, None) # noqa
|
||||
database_service._process_sql(insert_statement, None, None) # noqa
|
||||
|
||||
# Web Server
|
||||
web_server = Server(
|
||||
|
||||
35
src/primaite/simulator/network/protocols/ntp.py
Normal file
35
src/primaite/simulator/network/protocols/ntp.py
Normal file
@@ -0,0 +1,35 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from primaite.simulator.network.protocols.packet import DataPacket
|
||||
|
||||
|
||||
class NTPReply(BaseModel):
|
||||
"""Represents a NTP Reply packet."""
|
||||
|
||||
ntp_datetime: datetime
|
||||
"NTP datetime object set by NTP Server."
|
||||
|
||||
|
||||
class NTPPacket(DataPacket):
|
||||
"""
|
||||
Represents the NTP layer of a network frame.
|
||||
|
||||
:param ntp_request: NTPRequest packet from NTP client.
|
||||
:param ntp_reply: NTPReply packet from NTP Server.
|
||||
"""
|
||||
|
||||
ntp_reply: Optional[NTPReply] = None
|
||||
|
||||
def generate_reply(self, ntp_server_time: datetime) -> NTPPacket:
|
||||
"""Generate a NTPPacket containing the time in a NTPReply object.
|
||||
|
||||
:param time: datetime object representing the time from the NTP server.
|
||||
:return: A new NTPPacket object.
|
||||
"""
|
||||
self.ntp_reply = NTPReply(ntp_datetime=ntp_server_time)
|
||||
return self
|
||||
@@ -5,7 +5,7 @@ from uuid import uuid4
|
||||
from primaite import getLogger
|
||||
from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.system.applications.application import Application, ApplicationOperatingState
|
||||
from primaite.simulator.system.applications.application import Application
|
||||
from primaite.simulator.system.core.software_manager import SoftwareManager
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
@@ -23,7 +23,6 @@ class DatabaseClient(Application):
|
||||
|
||||
server_ip_address: Optional[IPv4Address] = None
|
||||
server_password: Optional[str] = None
|
||||
connected: bool = False
|
||||
_query_success_tracker: Dict[str, bool] = {}
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
@@ -66,18 +65,24 @@ class DatabaseClient(Application):
|
||||
self.server_password = server_password
|
||||
self.sys_log.info(f"{self.name}: Configured the {self.name} with {server_ip_address=}, {server_password=}.")
|
||||
|
||||
def connect(self) -> bool:
|
||||
def connect(self, connection_id: Optional[str] = None) -> bool:
|
||||
"""Connect to a Database Service."""
|
||||
if not self._can_perform_action():
|
||||
return False
|
||||
|
||||
if not self.connected:
|
||||
return self._connect(self.server_ip_address, self.server_password)
|
||||
# already connected
|
||||
return True
|
||||
if not connection_id:
|
||||
connection_id = str(uuid4())
|
||||
|
||||
return self._connect(
|
||||
server_ip_address=self.server_ip_address, password=self.server_password, connection_id=connection_id
|
||||
)
|
||||
|
||||
def _connect(
|
||||
self, server_ip_address: IPv4Address, password: Optional[str] = None, is_reattempt: bool = False
|
||||
self,
|
||||
server_ip_address: IPv4Address,
|
||||
connection_id: Optional[str] = None,
|
||||
password: Optional[str] = None,
|
||||
is_reattempt: bool = False,
|
||||
) -> bool:
|
||||
"""
|
||||
Connects the DatabaseClient to the DatabaseServer.
|
||||
@@ -92,33 +97,58 @@ class DatabaseClient(Application):
|
||||
:type: is_reattempt: Optional[bool]
|
||||
"""
|
||||
if is_reattempt:
|
||||
if self.connected:
|
||||
self.sys_log.info(f"{self.name}: DatabaseClient connection to {server_ip_address} authorised")
|
||||
if self.connections.get(connection_id):
|
||||
self.sys_log.info(
|
||||
f"{self.name} {connection_id=}: DatabaseClient connection to {server_ip_address} authorised"
|
||||
)
|
||||
self.server_ip_address = server_ip_address
|
||||
return self.connected
|
||||
return True
|
||||
else:
|
||||
self.sys_log.info(f"{self.name}: DatabaseClient connection to {server_ip_address} declined")
|
||||
self.sys_log.info(
|
||||
f"{self.name} {connection_id=}: DatabaseClient connection to {server_ip_address} declined"
|
||||
)
|
||||
return False
|
||||
payload = {"type": "connect_request", "password": password}
|
||||
payload = {
|
||||
"type": "connect_request",
|
||||
"password": password,
|
||||
"connection_id": connection_id,
|
||||
}
|
||||
software_manager: SoftwareManager = self.software_manager
|
||||
software_manager.send_payload_to_session_manager(
|
||||
payload=payload, dest_ip_address=server_ip_address, dest_port=self.port
|
||||
)
|
||||
return self._connect(server_ip_address, password, True)
|
||||
return self._connect(
|
||||
server_ip_address=server_ip_address, password=password, connection_id=connection_id, is_reattempt=True
|
||||
)
|
||||
|
||||
def disconnect(self):
|
||||
def disconnect(self, connection_id: Optional[str] = None) -> bool:
|
||||
"""Disconnect from the Database Service."""
|
||||
if self.connected and self.operating_state is ApplicationOperatingState.RUNNING:
|
||||
software_manager: SoftwareManager = self.software_manager
|
||||
software_manager.send_payload_to_session_manager(
|
||||
payload={"type": "disconnect"}, dest_ip_address=self.server_ip_address, dest_port=self.port
|
||||
)
|
||||
if not self._can_perform_action():
|
||||
self.sys_log.error(f"Unable to disconnect - {self.name} is {self.operating_state.name}")
|
||||
return False
|
||||
|
||||
self.sys_log.info(f"{self.name}: DatabaseClient disconnected from {self.server_ip_address}")
|
||||
self.server_ip_address = None
|
||||
self.connected = False
|
||||
# if there are no connections - nothing to disconnect
|
||||
if not len(self.connections):
|
||||
self.sys_log.error(f"Unable to disconnect - {self.name} has no active connections.")
|
||||
return False
|
||||
|
||||
def _query(self, sql: str, query_id: str, is_reattempt: bool = False) -> bool:
|
||||
# if no connection provided, disconnect the first connection
|
||||
if not connection_id:
|
||||
connection_id = list(self.connections.keys())[0]
|
||||
|
||||
software_manager: SoftwareManager = self.software_manager
|
||||
software_manager.send_payload_to_session_manager(
|
||||
payload={"type": "disconnect", "connection_id": connection_id},
|
||||
dest_ip_address=self.server_ip_address,
|
||||
dest_port=self.port,
|
||||
)
|
||||
self.remove_connection(connection_id=connection_id)
|
||||
|
||||
self.sys_log.info(
|
||||
f"{self.name}: DatabaseClient disconnected connection {connection_id} from {self.server_ip_address}"
|
||||
)
|
||||
|
||||
def _query(self, sql: str, query_id: str, connection_id: str, is_reattempt: bool = False) -> bool:
|
||||
"""
|
||||
Send a query to the connected database server.
|
||||
|
||||
@@ -141,19 +171,17 @@ class DatabaseClient(Application):
|
||||
else:
|
||||
software_manager: SoftwareManager = self.software_manager
|
||||
software_manager.send_payload_to_session_manager(
|
||||
payload={"type": "sql", "sql": sql, "uuid": query_id},
|
||||
payload={"type": "sql", "sql": sql, "uuid": query_id, "connection_id": connection_id},
|
||||
dest_ip_address=self.server_ip_address,
|
||||
dest_port=self.port,
|
||||
)
|
||||
return self._query(sql=sql, query_id=query_id, is_reattempt=True)
|
||||
return self._query(sql=sql, query_id=query_id, connection_id=connection_id, is_reattempt=True)
|
||||
|
||||
def run(self) -> None:
|
||||
"""Run the DatabaseClient."""
|
||||
super().run()
|
||||
if self.operating_state == ApplicationOperatingState.RUNNING:
|
||||
self.connect()
|
||||
|
||||
def query(self, sql: str, is_reattempt: bool = False) -> bool:
|
||||
def query(self, sql: str, connection_id: Optional[str] = None) -> bool:
|
||||
"""
|
||||
Send a query to the Database Service.
|
||||
|
||||
@@ -164,20 +192,17 @@ class DatabaseClient(Application):
|
||||
if not self._can_perform_action():
|
||||
return False
|
||||
|
||||
if self.connected:
|
||||
query_id = str(uuid4())
|
||||
if connection_id is None:
|
||||
connection_id = str(uuid4())
|
||||
|
||||
if not self.connections.get(connection_id):
|
||||
if not self.connect(connection_id=connection_id):
|
||||
return False
|
||||
|
||||
# Initialise the tracker of this ID to False
|
||||
self._query_success_tracker[query_id] = False
|
||||
return self._query(sql=sql, query_id=query_id)
|
||||
else:
|
||||
if is_reattempt:
|
||||
return False
|
||||
|
||||
if not self.connect():
|
||||
return False
|
||||
|
||||
self.query(sql=sql, is_reattempt=True)
|
||||
uuid = str(uuid4())
|
||||
self._query_success_tracker[uuid] = False
|
||||
return self._query(sql=sql, query_id=uuid, connection_id=connection_id)
|
||||
|
||||
def receive(self, payload: Any, session_id: str, **kwargs) -> bool:
|
||||
"""
|
||||
@@ -192,13 +217,13 @@ class DatabaseClient(Application):
|
||||
|
||||
if isinstance(payload, dict) and payload.get("type"):
|
||||
if payload["type"] == "connect_response":
|
||||
self.connected = payload["response"] == True
|
||||
if payload["response"] is True:
|
||||
# add connection
|
||||
self.add_connection(connection_id=payload.get("connection_id"), session_id=session_id)
|
||||
elif payload["type"] == "sql":
|
||||
query_id = payload.get("uuid")
|
||||
status_code = payload.get("status_code")
|
||||
self._query_success_tracker[query_id] = status_code == 200
|
||||
if self._query_success_tracker[query_id]:
|
||||
_LOGGER.debug(f"Received payload {payload}")
|
||||
else:
|
||||
self.connected = False
|
||||
return True
|
||||
|
||||
@@ -5,7 +5,6 @@ from typing import Optional
|
||||
from primaite import getLogger
|
||||
from primaite.game.science import simulate_trial
|
||||
from primaite.simulator.core import RequestManager, RequestType
|
||||
from primaite.simulator.system.applications.application import ApplicationOperatingState
|
||||
from primaite.simulator.system.applications.database_client import DatabaseClient
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
@@ -149,9 +148,9 @@ class DataManipulationBot(DatabaseClient):
|
||||
if simulate_trial(p_of_success):
|
||||
self.sys_log.info(f"{self.name}: Performing data manipulation")
|
||||
# perform the attack
|
||||
if not self.connected:
|
||||
if not len(self.connections):
|
||||
self.connect()
|
||||
if self.connected:
|
||||
if len(self.connections):
|
||||
self.query(self.payload)
|
||||
self.sys_log.info(f"{self.name} payload delivered: {self.payload}")
|
||||
attack_successful = True
|
||||
@@ -183,9 +182,9 @@ class DataManipulationBot(DatabaseClient):
|
||||
|
||||
This is the core loop where the bot sequentially goes through the stages of the attack.
|
||||
"""
|
||||
if self.operating_state != ApplicationOperatingState.RUNNING:
|
||||
if not self._can_perform_action():
|
||||
return
|
||||
if self.server_ip_address and self.payload and self.operating_state:
|
||||
if self.server_ip_address and self.payload:
|
||||
self.sys_log.info(f"{self.name}: Running")
|
||||
self._logon()
|
||||
self._perform_port_scan(p_of_success=self.port_scan_p_of_success)
|
||||
@@ -0,0 +1,192 @@
|
||||
from enum import IntEnum
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Optional
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.game.science import simulate_trial
|
||||
from primaite.simulator.core import RequestManager, RequestType
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.system.applications.application import Application
|
||||
from primaite.simulator.system.applications.database_client import DatabaseClient
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
class DoSAttackStage(IntEnum):
|
||||
"""Enum representing the different stages of a Denial of Service attack."""
|
||||
|
||||
NOT_STARTED = 0
|
||||
"Attack not yet started."
|
||||
|
||||
PORT_SCAN = 1
|
||||
"Attack is in discovery stage - checking if provided ip and port are open."
|
||||
|
||||
ATTACKING = 2
|
||||
"Denial of Service attack is in progress."
|
||||
|
||||
COMPLETED = 3
|
||||
"Attack is completed."
|
||||
|
||||
|
||||
class DoSBot(DatabaseClient, Application):
|
||||
"""A bot that simulates a Denial of Service attack."""
|
||||
|
||||
target_ip_address: Optional[IPv4Address] = None
|
||||
"""IP address of the target service."""
|
||||
|
||||
target_port: Optional[Port] = None
|
||||
"""Port of the target service."""
|
||||
|
||||
payload: Optional[str] = None
|
||||
"""Payload to deliver to the target service as part of the denial of service attack."""
|
||||
|
||||
repeat: bool = False
|
||||
"""If true, the Denial of Service bot will keep performing the attack."""
|
||||
|
||||
attack_stage: DoSAttackStage = DoSAttackStage.NOT_STARTED
|
||||
"""Current stage of the DoS kill chain."""
|
||||
|
||||
port_scan_p_of_success: float = 0.1
|
||||
"""Probability of port scanning being sucessful."""
|
||||
|
||||
dos_intensity: float = 1.0
|
||||
"""How much of the max sessions will be used by the DoS when attacking."""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.name = "DoSBot"
|
||||
self.max_sessions = 1000 # override normal max sessions
|
||||
|
||||
def set_original_state(self):
|
||||
"""Set the original state of the Denial of Service Bot."""
|
||||
_LOGGER.debug(f"Setting {self.name} original state on node {self.software_manager.node.hostname}")
|
||||
super().set_original_state()
|
||||
vals_to_include = {
|
||||
"target_ip_address",
|
||||
"target_port",
|
||||
"payload",
|
||||
"repeat",
|
||||
"attack_stage",
|
||||
"max_sessions",
|
||||
"port_scan_p_of_success",
|
||||
"dos_intensity",
|
||||
}
|
||||
self._original_state.update(self.model_dump(include=vals_to_include))
|
||||
|
||||
def reset_component_for_episode(self, episode: int):
|
||||
"""Reset the original state of the SimComponent."""
|
||||
_LOGGER.debug(f"Resetting {self.name} state on node {self.software_manager.node.hostname}")
|
||||
super().reset_component_for_episode(episode)
|
||||
|
||||
def _init_request_manager(self) -> RequestManager:
|
||||
rm = super()._init_request_manager()
|
||||
|
||||
rm.add_request(name="execute", request_type=RequestType(func=lambda request, context: self.run()))
|
||||
|
||||
return rm
|
||||
|
||||
def configure(
|
||||
self,
|
||||
target_ip_address: IPv4Address,
|
||||
target_port: Optional[Port] = Port.POSTGRES_SERVER,
|
||||
payload: Optional[str] = None,
|
||||
repeat: bool = False,
|
||||
port_scan_p_of_success: float = 0.1,
|
||||
dos_intensity: float = 1.0,
|
||||
max_sessions: int = 1000,
|
||||
):
|
||||
"""
|
||||
Configure the Denial of Service bot.
|
||||
|
||||
:param: target_ip_address: The IP address of the Node containing the target service.
|
||||
:param: target_port: The port of the target service. Optional - Default is `Port.HTTP`
|
||||
:param: payload: The payload the DoS Bot will throw at the target service. Optional - Default is `None`
|
||||
:param: repeat: If True, the bot will maintain the attack. Optional - Default is `True`
|
||||
:param: port_scan_p_of_success: The chance of the port scan being sucessful. Optional - Default is 0.1 (10%)
|
||||
:param: dos_intensity: The intensity of the DoS attack.
|
||||
Multiplied with the application's max session - Default is 1.0
|
||||
:param: max_sessions: The maximum number of sessions the DoS bot will attack with. Optional - Default is 1000
|
||||
"""
|
||||
self.target_ip_address = target_ip_address
|
||||
self.target_port = target_port
|
||||
self.payload = payload
|
||||
self.repeat = repeat
|
||||
self.port_scan_p_of_success = port_scan_p_of_success
|
||||
self.dos_intensity = dos_intensity
|
||||
self.max_sessions = max_sessions
|
||||
self.sys_log.info(
|
||||
f"{self.name}: Configured the {self.name} with {target_ip_address=}, {target_port=}, {payload=}, "
|
||||
f"{repeat=}, {port_scan_p_of_success=}, {dos_intensity=}, {max_sessions=}."
|
||||
)
|
||||
|
||||
def run(self):
|
||||
"""Run the Denial of Service Bot."""
|
||||
super().run()
|
||||
self._application_loop()
|
||||
|
||||
def _application_loop(self):
|
||||
"""
|
||||
The main application loop for the Denial of Service bot.
|
||||
|
||||
The loop goes through the stages of a DoS attack.
|
||||
"""
|
||||
if not self._can_perform_action():
|
||||
return
|
||||
|
||||
# DoS bot cannot do anything without a target
|
||||
if not self.target_ip_address or not self.target_port:
|
||||
self.sys_log.error(
|
||||
f"{self.name} is not properly configured. {self.target_ip_address=}, {self.target_port=}"
|
||||
)
|
||||
return
|
||||
|
||||
self.clear_connections()
|
||||
self._perform_port_scan(p_of_success=self.port_scan_p_of_success)
|
||||
self._perform_dos()
|
||||
|
||||
if self.repeat and self.attack_stage is DoSAttackStage.ATTACKING:
|
||||
self.attack_stage = DoSAttackStage.NOT_STARTED
|
||||
else:
|
||||
self.attack_stage = DoSAttackStage.COMPLETED
|
||||
|
||||
def _perform_port_scan(self, p_of_success: Optional[float] = 0.1):
|
||||
"""
|
||||
Perform a simulated port scan to check for open SQL ports.
|
||||
|
||||
Advances the attack stage to `PORT_SCAN` if successful.
|
||||
|
||||
:param p_of_success: Probability of successful port scan, by default 0.1.
|
||||
"""
|
||||
if self.attack_stage == DoSAttackStage.NOT_STARTED:
|
||||
# perform a port scan to identify that the SQL port is open on the server
|
||||
if simulate_trial(p_of_success):
|
||||
self.sys_log.info(f"{self.name}: Performing port scan")
|
||||
# perform the port scan
|
||||
port_is_open = True # Temporary; later we can implement NMAP port scan.
|
||||
if port_is_open:
|
||||
self.sys_log.info(f"{self.name}: ")
|
||||
self.attack_stage = DoSAttackStage.PORT_SCAN
|
||||
|
||||
def _perform_dos(self):
|
||||
"""
|
||||
Perform the Denial of Service attack.
|
||||
|
||||
DoSBot does this by clogging up the available connections to a service.
|
||||
"""
|
||||
if not self.attack_stage == DoSAttackStage.PORT_SCAN:
|
||||
return
|
||||
self.attack_stage = DoSAttackStage.ATTACKING
|
||||
self.server_ip_address = self.target_ip_address
|
||||
self.port = self.target_port
|
||||
|
||||
dos_sessions = int(float(self.max_sessions) * self.dos_intensity)
|
||||
for i in range(dos_sessions):
|
||||
self.connect()
|
||||
|
||||
def apply_timestep(self, timestep: int) -> None:
|
||||
"""
|
||||
Apply a timestep to the bot, iterate through the application loop.
|
||||
|
||||
:param timestep: The timestep value to update the bot's state.
|
||||
"""
|
||||
self._application_loop()
|
||||
@@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from ipaddress import IPv4Address
|
||||
from ipaddress import IPv4Address, IPv4Network
|
||||
from typing import Any, Dict, Optional, Tuple, TYPE_CHECKING, Union
|
||||
|
||||
from prettytable import MARKDOWN, PrettyTable
|
||||
@@ -141,41 +141,76 @@ class SessionManager:
|
||||
def receive_payload_from_software_manager(
|
||||
self,
|
||||
payload: Any,
|
||||
dst_ip_address: Optional[IPv4Address] = None,
|
||||
dst_ip_address: Optional[Union[IPv4Address, IPv4Network]] = None,
|
||||
dst_port: Optional[Port] = None,
|
||||
session_id: Optional[str] = None,
|
||||
is_reattempt: bool = False,
|
||||
) -> Union[Any, None]:
|
||||
"""
|
||||
Receive a payload from the SoftwareManager.
|
||||
Receive a payload from the SoftwareManager and send it to the appropriate NIC for transmission.
|
||||
|
||||
If no session_id, a Session is established. Once established, the payload is sent to ``send_payload_to_nic``.
|
||||
This method supports both unicast and Layer 3 broadcast transmissions. If `dst_ip_address` is an
|
||||
IPv4Network, a broadcast is initiated. For unicast, the destination MAC address is resolved via ARP.
|
||||
A new session is established if `session_id` is not provided, and an existing session is used otherwise.
|
||||
|
||||
:param payload: The payload to be sent.
|
||||
:param session_id: The Session ID the payload is to originate from. Optional. If None, one will be created.
|
||||
:param dst_ip_address: The destination IP address or network for broadcast. Optional.
|
||||
:param dst_port: The destination port for the TCP packet. Optional.
|
||||
:param session_id: The Session ID from which the payload originates. Optional.
|
||||
:param is_reattempt: Flag to indicate if this is a reattempt after an ARP request. Default is False.
|
||||
:return: The outcome of sending the frame, or None if sending was unsuccessful.
|
||||
"""
|
||||
is_broadcast = False
|
||||
outbound_nic = None
|
||||
dst_mac_address = None
|
||||
|
||||
# Use session details if session_id is provided
|
||||
if session_id:
|
||||
session = self.sessions_by_uuid[session_id]
|
||||
dst_ip_address = self.sessions_by_uuid[session_id].with_ip_address
|
||||
dst_port = self.sessions_by_uuid[session_id].dst_port
|
||||
dst_ip_address = session.with_ip_address
|
||||
dst_port = session.dst_port
|
||||
|
||||
dst_mac_address = self.arp_cache.get_arp_cache_mac_address(dst_ip_address)
|
||||
# Determine if the payload is for broadcast or unicast
|
||||
|
||||
if dst_mac_address:
|
||||
outbound_nic = self.arp_cache.get_arp_cache_nic(dst_ip_address)
|
||||
# Handle broadcast transmission
|
||||
if isinstance(dst_ip_address, IPv4Network):
|
||||
is_broadcast = True
|
||||
dst_ip_address = dst_ip_address.broadcast_address
|
||||
if dst_ip_address:
|
||||
# Find a suitable NIC for the broadcast
|
||||
for nic in self.arp_cache.nics.values():
|
||||
if dst_ip_address in nic.ip_network and nic.enabled:
|
||||
dst_mac_address = "ff:ff:ff:ff:ff:ff"
|
||||
outbound_nic = nic
|
||||
else:
|
||||
if not is_reattempt:
|
||||
self.arp_cache.send_arp_request(dst_ip_address)
|
||||
return self.receive_payload_from_software_manager(
|
||||
payload=payload,
|
||||
dst_ip_address=dst_ip_address,
|
||||
dst_port=dst_port,
|
||||
session_id=session_id,
|
||||
is_reattempt=True,
|
||||
)
|
||||
else:
|
||||
return
|
||||
# Resolve MAC address for unicast transmission
|
||||
dst_mac_address = self.arp_cache.get_arp_cache_mac_address(dst_ip_address)
|
||||
|
||||
# Resolve outbound NIC for unicast transmission
|
||||
if dst_mac_address:
|
||||
outbound_nic = self.arp_cache.get_arp_cache_nic(dst_ip_address)
|
||||
|
||||
# If MAC address not found, initiate ARP request
|
||||
else:
|
||||
if not is_reattempt:
|
||||
self.arp_cache.send_arp_request(dst_ip_address)
|
||||
# Reattempt payload transmission after ARP request
|
||||
return self.receive_payload_from_software_manager(
|
||||
payload=payload,
|
||||
dst_ip_address=dst_ip_address,
|
||||
dst_port=dst_port,
|
||||
session_id=session_id,
|
||||
is_reattempt=True,
|
||||
)
|
||||
else:
|
||||
# Return None if reattempt fails
|
||||
return
|
||||
|
||||
# Check if outbound NIC and destination MAC address are resolved
|
||||
if not outbound_nic or not dst_mac_address:
|
||||
return False
|
||||
|
||||
# Construct the frame for transmission
|
||||
frame = Frame(
|
||||
ethernet=EthernetHeader(src_mac_addr=outbound_nic.mac_address, dst_mac_addr=dst_mac_address),
|
||||
ip=IPPacket(
|
||||
@@ -189,15 +224,17 @@ class SessionManager:
|
||||
payload=payload,
|
||||
)
|
||||
|
||||
if not session_id:
|
||||
# Manage session for unicast transmission
|
||||
if not (is_broadcast and session_id):
|
||||
session_key = self._get_session_key(frame, inbound_frame=False)
|
||||
session = self.sessions_by_key.get(session_key)
|
||||
if not session:
|
||||
# Create new session
|
||||
# Create a new session if it doesn't exist
|
||||
session = Session.from_session_key(session_key)
|
||||
self.sessions_by_key[session_key] = session
|
||||
self.sessions_by_uuid[session.uuid] = session
|
||||
|
||||
# Send the frame through the NIC
|
||||
return outbound_nic.send_frame(frame)
|
||||
|
||||
def receive_frame(self, frame: Frame):
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from ipaddress import IPv4Address
|
||||
from ipaddress import IPv4Address, IPv4Network
|
||||
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING, Union
|
||||
|
||||
from prettytable import MARKDOWN, PrettyTable
|
||||
@@ -130,20 +130,28 @@ class SoftwareManager:
|
||||
def send_payload_to_session_manager(
|
||||
self,
|
||||
payload: Any,
|
||||
dest_ip_address: Optional[IPv4Address] = None,
|
||||
dest_ip_address: Optional[Union[IPv4Address, IPv4Network]] = None,
|
||||
dest_port: Optional[Port] = None,
|
||||
session_id: Optional[str] = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Send a payload to the SessionManager.
|
||||
Sends a payload to the SessionManager for network transmission.
|
||||
|
||||
This method is responsible for initiating the process of sending network payloads. It supports both
|
||||
unicast and Layer 3 broadcast transmissions. For broadcasts, the destination IP should be specified
|
||||
as an IPv4Network.
|
||||
|
||||
:param payload: The payload to be sent.
|
||||
:param dest_ip_address: The ip address of the payload destination.
|
||||
:param dest_port: The port of the payload destination.
|
||||
:param session_id: The Session ID the payload is to originate from. Optional.
|
||||
:param dest_ip_address: The IP address or network (for broadcasts) of the payload destination.
|
||||
:param dest_port: The destination port for the payload. Optional.
|
||||
:param session_id: The Session ID from which the payload originates. Optional.
|
||||
:return: True if the payload was successfully sent, False otherwise.
|
||||
"""
|
||||
return self.session_manager.receive_payload_from_software_manager(
|
||||
payload=payload, dst_ip_address=dest_ip_address, dst_port=dest_port, session_id=session_id
|
||||
payload=payload,
|
||||
dst_ip_address=dest_ip_address,
|
||||
dst_port=dest_port,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
def receive_payload_from_session_manager(self, payload: Any, port: Port, protocol: IPProtocol, session_id: str):
|
||||
|
||||
@@ -41,5 +41,5 @@ class Process(Software):
|
||||
:rtype: Dict
|
||||
"""
|
||||
state = super().describe_state()
|
||||
state.update({"operating_state": self.operating_state.name})
|
||||
state.update({"operating_state": self.operating_state.value})
|
||||
return state
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
from datetime import datetime
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Any, Dict, List, Literal, Optional, Union
|
||||
|
||||
@@ -24,7 +23,6 @@ class DatabaseService(Service):
|
||||
"""
|
||||
|
||||
password: Optional[str] = None
|
||||
connections: Dict[str, datetime] = {}
|
||||
|
||||
backup_server_ip: IPv4Address = None
|
||||
"""IP address of the backup server."""
|
||||
@@ -58,7 +56,7 @@ class DatabaseService(Service):
|
||||
def reset_component_for_episode(self, episode: int):
|
||||
"""Reset the original state of the SimComponent."""
|
||||
_LOGGER.debug("Resetting DatabaseService original state on node {self.software_manager.node.hostname}")
|
||||
self.connections.clear()
|
||||
self.clear_connections()
|
||||
super().reset_component_for_episode(episode)
|
||||
|
||||
def configure_backup(self, backup_server: IPv4Address):
|
||||
@@ -151,24 +149,39 @@ class DatabaseService(Service):
|
||||
return self.file_system.get_folder_by_id(self.db_file.folder_id)
|
||||
|
||||
def _process_connect(
|
||||
self, session_id: str, password: Optional[str] = None
|
||||
self, connection_id: str, password: Optional[str] = None
|
||||
) -> Dict[str, Union[int, Dict[str, bool]]]:
|
||||
status_code = 500 # Default internal server error
|
||||
if self.operating_state == ServiceOperatingState.RUNNING:
|
||||
status_code = 503 # service unavailable
|
||||
if self.health_state_actual == SoftwareHealthState.OVERWHELMED:
|
||||
self.sys_log.error(
|
||||
f"{self.name}: Connect request for {connection_id=} declined. Service is at capacity."
|
||||
)
|
||||
if self.health_state_actual == SoftwareHealthState.GOOD:
|
||||
if self.password == password:
|
||||
status_code = 200 # ok
|
||||
self.connections[session_id] = datetime.now()
|
||||
self.sys_log.info(f"{self.name}: Connect request for {session_id=} authorised")
|
||||
# try to create connection
|
||||
if not self.add_connection(connection_id=connection_id):
|
||||
status_code = 500
|
||||
self.sys_log.info(f"{self.name}: Connect request for {connection_id=} declined")
|
||||
else:
|
||||
self.sys_log.info(f"{self.name}: Connect request for {connection_id=} authorised")
|
||||
else:
|
||||
status_code = 401 # Unauthorised
|
||||
self.sys_log.info(f"{self.name}: Connect request for {session_id=} declined")
|
||||
self.sys_log.info(f"{self.name}: Connect request for {connection_id=} declined")
|
||||
else:
|
||||
status_code = 404 # service not found
|
||||
return {"status_code": status_code, "type": "connect_response", "response": status_code == 200}
|
||||
return {
|
||||
"status_code": status_code,
|
||||
"type": "connect_response",
|
||||
"response": status_code == 200,
|
||||
"connection_id": connection_id,
|
||||
}
|
||||
|
||||
def _process_sql(self, query: Literal["SELECT", "DELETE"], query_id: str) -> Dict[str, Union[int, List[Any]]]:
|
||||
def _process_sql(
|
||||
self, query: Literal["SELECT", "DELETE"], query_id: str, connection_id: Optional[str] = None
|
||||
) -> Dict[str, Union[int, List[Any]]]:
|
||||
"""
|
||||
Executes the given SQL query and returns the result.
|
||||
|
||||
@@ -180,14 +193,21 @@ class DatabaseService(Service):
|
||||
:return: Dictionary containing status code and data fetched.
|
||||
"""
|
||||
self.sys_log.info(f"{self.name}: Running {query}")
|
||||
|
||||
if query == "SELECT":
|
||||
if self.db_file.health_status == FileSystemItemHealthStatus.GOOD:
|
||||
return {"status_code": 200, "type": "sql", "data": True, "uuid": query_id}
|
||||
return {
|
||||
"status_code": 200,
|
||||
"type": "sql",
|
||||
"data": True,
|
||||
"uuid": query_id,
|
||||
"connection_id": connection_id,
|
||||
}
|
||||
else:
|
||||
return {"status_code": 404, "data": False}
|
||||
elif query == "DELETE":
|
||||
self.db_file.health_status = FileSystemItemHealthStatus.COMPROMISED
|
||||
return {"status_code": 200, "type": "sql", "data": False, "uuid": query_id}
|
||||
return {"status_code": 200, "type": "sql", "data": False, "uuid": query_id, "connection_id": connection_id}
|
||||
else:
|
||||
# Invalid query
|
||||
return {"status_code": 500, "data": False}
|
||||
@@ -211,19 +231,25 @@ class DatabaseService(Service):
|
||||
:param session_id: The session identifier.
|
||||
:return: True if the Status Code is 200, otherwise False.
|
||||
"""
|
||||
if not super().receive(payload=payload, session_id=session_id, **kwargs):
|
||||
result = {"status_code": 500, "data": []}
|
||||
|
||||
# if server service is down, return error
|
||||
if not self._can_perform_action():
|
||||
return False
|
||||
|
||||
result = {"status_code": 500, "data": []}
|
||||
if isinstance(payload, dict) and payload.get("type"):
|
||||
if payload["type"] == "connect_request":
|
||||
result = self._process_connect(session_id=session_id, password=payload.get("password"))
|
||||
result = self._process_connect(
|
||||
connection_id=payload.get("connection_id"), password=payload.get("password")
|
||||
)
|
||||
elif payload["type"] == "disconnect":
|
||||
if session_id in self.connections:
|
||||
self.connections.pop(session_id)
|
||||
if payload["connection_id"] in self.connections:
|
||||
self.remove_connection(connection_id=payload["connection_id"])
|
||||
elif payload["type"] == "sql":
|
||||
if session_id in self.connections:
|
||||
result = self._process_sql(query=payload["sql"], query_id=payload["uuid"])
|
||||
if payload.get("connection_id") in self.connections:
|
||||
result = self._process_sql(
|
||||
query=payload["sql"], query_id=payload["uuid"], connection_id=payload["connection_id"]
|
||||
)
|
||||
else:
|
||||
result = {"status_code": 401, "type": "sql"}
|
||||
self.send(payload=result, session_id=session_id)
|
||||
|
||||
@@ -20,9 +20,6 @@ class FTPClient(FTPServiceABC):
|
||||
RFC 959: https://datatracker.ietf.org/doc/html/rfc959
|
||||
"""
|
||||
|
||||
connected: bool = False
|
||||
"""Keeps track of whether or not the FTP client is connected to an FTP server."""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
kwargs["name"] = "FTPClient"
|
||||
kwargs["port"] = Port.FTP
|
||||
@@ -129,10 +126,7 @@ class FTPClient(FTPServiceABC):
|
||||
software_manager.send_payload_to_session_manager(
|
||||
payload=payload, dest_ip_address=dest_ip_address, dest_port=dest_port
|
||||
)
|
||||
if payload.status_code == FTPStatusCode.OK:
|
||||
self.connected = False
|
||||
return True
|
||||
return False
|
||||
return payload.status_code == FTPStatusCode.OK
|
||||
|
||||
def send_file(
|
||||
self,
|
||||
@@ -179,9 +173,9 @@ class FTPClient(FTPServiceABC):
|
||||
return False
|
||||
|
||||
# check if FTP is currently connected to IP
|
||||
self.connected = self._connect_to_server(dest_ip_address=dest_ip_address, dest_port=dest_port)
|
||||
self._connect_to_server(dest_ip_address=dest_ip_address, dest_port=dest_port)
|
||||
|
||||
if not self.connected:
|
||||
if not len(self.connections):
|
||||
return False
|
||||
else:
|
||||
self.sys_log.info(f"Sending file {src_folder_name}/{src_file_name} to {str(dest_ip_address)}")
|
||||
@@ -230,9 +224,9 @@ class FTPClient(FTPServiceABC):
|
||||
:type: dest_port: Optional[Port]
|
||||
"""
|
||||
# check if FTP is currently connected to IP
|
||||
self.connected = self._connect_to_server(dest_ip_address=dest_ip_address, dest_port=dest_port)
|
||||
self._connect_to_server(dest_ip_address=dest_ip_address, dest_port=dest_port)
|
||||
|
||||
if not self.connected:
|
||||
if not len(self.connections):
|
||||
return False
|
||||
else:
|
||||
# send retrieve request
|
||||
@@ -286,6 +280,14 @@ class FTPClient(FTPServiceABC):
|
||||
self.sys_log.error(f"FTP Server could not be found - Error Code: {FTPStatusCode.NOT_FOUND.value}")
|
||||
return False
|
||||
|
||||
# if PORT succeeded, add the connection as an active connection list
|
||||
if payload.ftp_command is FTPCommand.PORT and payload.status_code is FTPStatusCode.OK:
|
||||
self.add_connection(connection_id=session_id, session_id=session_id)
|
||||
|
||||
# if QUIT succeeded, remove the session from active connection list
|
||||
if payload.ftp_command is FTPCommand.QUIT and payload.status_code is FTPStatusCode.OK:
|
||||
self.remove_connection(connection_id=session_id)
|
||||
|
||||
self.sys_log.info(f"{self.name}: Received FTP Response {payload.ftp_command.name} {payload.status_code.value}")
|
||||
|
||||
self._process_ftp_command(payload=payload, session_id=session_id)
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.simulator.network.protocols.ftp import FTPCommand, FTPPacket, FTPStatusCode
|
||||
@@ -21,9 +20,6 @@ class FTPServer(FTPServiceABC):
|
||||
server_password: Optional[str] = None
|
||||
"""Password needed to connect to FTP server. Default is None."""
|
||||
|
||||
connections: Dict[str, IPv4Address] = {}
|
||||
"""Current active connections to the FTP server."""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
kwargs["name"] = "FTPServer"
|
||||
kwargs["port"] = Port.FTP
|
||||
@@ -41,7 +37,7 @@ class FTPServer(FTPServiceABC):
|
||||
def reset_component_for_episode(self, episode: int):
|
||||
"""Reset the original state of the SimComponent."""
|
||||
_LOGGER.debug(f"Resetting FTPServer state on node {self.software_manager.node.hostname}")
|
||||
self.connections.clear()
|
||||
self.clear_connections()
|
||||
super().reset_component_for_episode(episode)
|
||||
|
||||
def _process_ftp_command(self, payload: FTPPacket, session_id: Optional[str] = None, **kwargs) -> FTPPacket:
|
||||
@@ -62,9 +58,6 @@ class FTPServer(FTPServiceABC):
|
||||
|
||||
self.sys_log.info(f"{self.name}: Received FTP {payload.ftp_command.name} {payload.ftp_command_args}")
|
||||
|
||||
if session_id:
|
||||
session_details = self._get_session_details(session_id)
|
||||
|
||||
if payload.ftp_command is not None:
|
||||
self.sys_log.info(f"Received FTP {payload.ftp_command.name} command.")
|
||||
|
||||
@@ -73,7 +66,7 @@ class FTPServer(FTPServiceABC):
|
||||
# check that the port is valid
|
||||
if isinstance(payload.ftp_command_args, Port) and payload.ftp_command_args.value in range(0, 65535):
|
||||
# return successful connection
|
||||
self.connections[session_id] = session_details.with_ip_address
|
||||
self.add_connection(connection_id=session_id, session_id=session_id)
|
||||
payload.status_code = FTPStatusCode.OK
|
||||
return payload
|
||||
|
||||
@@ -81,7 +74,7 @@ class FTPServer(FTPServiceABC):
|
||||
return payload
|
||||
|
||||
if payload.ftp_command == FTPCommand.QUIT:
|
||||
self.connections.pop(session_id)
|
||||
self.remove_connection(connection_id=session_id)
|
||||
payload.status_code = FTPStatusCode.OK
|
||||
return payload
|
||||
|
||||
|
||||
132
src/primaite/simulator/system/services/ntp/ntp_client.py
Normal file
132
src/primaite/simulator/system/services/ntp/ntp_client.py
Normal file
@@ -0,0 +1,132 @@
|
||||
from datetime import datetime
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Dict, Optional
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.simulator.network.protocols.ntp import NTPPacket
|
||||
from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.system.services.service import Service, ServiceOperatingState
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
class NTPClient(Service):
|
||||
"""Represents a NTP client as a service."""
|
||||
|
||||
ntp_server: Optional[IPv4Address] = None
|
||||
"The NTP server the client sends requests to."
|
||||
time: Optional[datetime] = None
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
kwargs["name"] = "NTPClient"
|
||||
kwargs["port"] = Port.NTP
|
||||
kwargs["protocol"] = IPProtocol.TCP
|
||||
super().__init__(**kwargs)
|
||||
self.start()
|
||||
|
||||
def configure(self, ntp_server_ip_address: IPv4Address) -> None:
|
||||
"""
|
||||
Set the IP address for the NTP server.
|
||||
|
||||
:param ntp_server_ip_address: IPv4 address of NTP server.
|
||||
:param ntp_client_ip_Address: IPv4 address of NTP client.
|
||||
"""
|
||||
self.ntp_server = ntp_server_ip_address
|
||||
self.sys_log.info(f"{self.name}: ntp_server: {self.ntp_server}")
|
||||
|
||||
def describe_state(self) -> Dict:
|
||||
"""
|
||||
Describes the current state of the software.
|
||||
|
||||
The specifics of the software's state, including its health, criticality,
|
||||
and any other pertinent information, should be implemented in subclasses.
|
||||
|
||||
:return: A dictionary containing key-value pairs representing the current state
|
||||
of the software.
|
||||
:rtype: Dict
|
||||
"""
|
||||
state = super().describe_state()
|
||||
return state
|
||||
|
||||
def reset_component_for_episode(self, episode: int):
|
||||
"""
|
||||
Resets the Service component for a new episode.
|
||||
|
||||
This method ensures the Service is ready for a new episode, including resetting any
|
||||
stateful properties or statistics, and clearing any message queues.
|
||||
"""
|
||||
pass
|
||||
|
||||
def send(
|
||||
self,
|
||||
payload: NTPPacket,
|
||||
session_id: Optional[str] = None,
|
||||
dest_ip_address: IPv4Address = None,
|
||||
dest_port: [Port] = Port.NTP,
|
||||
**kwargs,
|
||||
) -> bool:
|
||||
"""Requests NTP data from NTP server.
|
||||
|
||||
:param payload: The payload to be sent.
|
||||
:param session_id: The Session ID the payload is to originate from. Optional.
|
||||
:param dest_ip_address: The ip address of the payload destination.
|
||||
:param dest_port: The port of the payload destination.
|
||||
|
||||
:return: True if successful, False otherwise.
|
||||
"""
|
||||
return super().send(
|
||||
payload=payload,
|
||||
dest_ip_address=dest_ip_address,
|
||||
dest_port=dest_port,
|
||||
session_id=session_id,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def receive(
|
||||
self,
|
||||
payload: NTPPacket,
|
||||
session_id: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> bool:
|
||||
"""Receives time data from server.
|
||||
|
||||
:param payload: The payload to be sent.
|
||||
:param session_id: The Session ID the payload is to originate from. Optional.
|
||||
:return: True if successful, False otherwise.
|
||||
"""
|
||||
if not isinstance(payload, NTPPacket):
|
||||
_LOGGER.debug(f"{payload} is not a NTPPacket")
|
||||
return False
|
||||
if payload.ntp_reply.ntp_datetime:
|
||||
self.sys_log.info(
|
||||
f"{self.name}: \
|
||||
Received time update from NTP server{payload.ntp_reply.ntp_datetime}"
|
||||
)
|
||||
self.time = payload.ntp_reply.ntp_datetime
|
||||
return True
|
||||
|
||||
def request_time(self) -> None:
|
||||
"""Send request to ntp_server."""
|
||||
ntp_server_packet = NTPPacket()
|
||||
|
||||
self.send(payload=ntp_server_packet, dest_ip_address=self.ntp_server)
|
||||
|
||||
def apply_timestep(self, timestep: int) -> None:
|
||||
"""
|
||||
For each timestep request the time from the NTP server.
|
||||
|
||||
In this instance, if any multi-timestep processes are currently
|
||||
occurring (such as restarting or installation), then they are brought one step closer to
|
||||
being finished.
|
||||
|
||||
:param timestep: The current timestep number. (Amount of time since simulation episode began)
|
||||
:type timestep: int
|
||||
"""
|
||||
self.sys_log.info(f"{self.name} apply_timestep")
|
||||
super().apply_timestep(timestep)
|
||||
if self.operating_state == ServiceOperatingState.RUNNING:
|
||||
# request time from server
|
||||
self.request_time()
|
||||
else:
|
||||
self.sys_log.debug(f"{self.name} ntp client not running")
|
||||
73
src/primaite/simulator/system/services/ntp/ntp_server.py
Normal file
73
src/primaite/simulator/system/services/ntp/ntp_server.py
Normal file
@@ -0,0 +1,73 @@
|
||||
from datetime import datetime
|
||||
from typing import Dict, Optional
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.simulator.network.protocols.ntp import NTPPacket
|
||||
from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.system.services.service import Service
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
class NTPServer(Service):
|
||||
"""Represents a NTP server as a service."""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
kwargs["name"] = "NTPServer"
|
||||
kwargs["port"] = Port.NTP
|
||||
kwargs["protocol"] = IPProtocol.TCP
|
||||
super().__init__(**kwargs)
|
||||
self.start()
|
||||
|
||||
def describe_state(self) -> Dict:
|
||||
"""
|
||||
Describes the current state of the software.
|
||||
|
||||
The specifics of the software's state, including its health, criticality,
|
||||
and any other pertinent information, should be implemented in subclasses.
|
||||
|
||||
:return: A dictionary containing key-value pairs representing the current
|
||||
state of the software.
|
||||
:rtype: Dict
|
||||
"""
|
||||
state = super().describe_state()
|
||||
return state
|
||||
|
||||
def reset_component_for_episode(self, episode: int):
|
||||
"""
|
||||
Resets the Service component for a new episode.
|
||||
|
||||
This method ensures the Service is ready for a new episode, including
|
||||
resetting any stateful properties or statistics, and clearing any message
|
||||
queues.
|
||||
"""
|
||||
pass
|
||||
|
||||
def receive(
|
||||
self,
|
||||
payload: NTPPacket,
|
||||
session_id: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> bool:
|
||||
"""
|
||||
Receives a request from NTPClient.
|
||||
|
||||
Check that request has a valid IP address.
|
||||
|
||||
:param payload: The payload to send.
|
||||
:param session_id: Id of the session (Optional).
|
||||
|
||||
:return: True if valid NTP request else False.
|
||||
"""
|
||||
if not (isinstance(payload, NTPPacket)):
|
||||
_LOGGER.debug(f"{payload} is not a NTPPacket")
|
||||
return False
|
||||
payload: NTPPacket = payload
|
||||
|
||||
# generate a reply with the current time
|
||||
time = datetime.now()
|
||||
payload = payload.generate_reply(time)
|
||||
# send reply
|
||||
self.send(payload, session_id)
|
||||
return True
|
||||
@@ -41,6 +41,9 @@ class Service(IOSoftware):
|
||||
restart_countdown: Optional[int] = None
|
||||
"If currently restarting, how many timesteps remain until the restart is finished."
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def _can_perform_action(self) -> bool:
|
||||
"""
|
||||
Checks if the service can perform actions.
|
||||
@@ -53,7 +56,7 @@ class Service(IOSoftware):
|
||||
if not super()._can_perform_action():
|
||||
return False
|
||||
|
||||
if self.operating_state is not self.operating_state.RUNNING:
|
||||
if self.operating_state is not ServiceOperatingState.RUNNING:
|
||||
# service is not running
|
||||
_LOGGER.error(f"Cannot perform action: {self.name} is {self.operating_state.name}")
|
||||
return False
|
||||
@@ -75,9 +78,6 @@ class Service(IOSoftware):
|
||||
"""
|
||||
return super().receive(payload=payload, session_id=session_id, **kwargs)
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def set_original_state(self):
|
||||
"""Sets the original state."""
|
||||
super().set_original_state()
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
import copy
|
||||
from abc import abstractmethod
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Any, Dict, Optional
|
||||
from ipaddress import IPv4Address, IPv4Network
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
from primaite.simulator.core import _LOGGER, RequestManager, RequestType, SimComponent
|
||||
from primaite.simulator.file_system.file_system import FileSystem, Folder
|
||||
@@ -232,7 +234,7 @@ class IOSoftware(Software):
|
||||
|
||||
installing_count: int = 0
|
||||
"The number of times the software has been installed. Default is 0."
|
||||
max_sessions: int = 1
|
||||
max_sessions: int = 100
|
||||
"The maximum number of sessions that the software can handle simultaneously. Default is 0."
|
||||
tcp: bool = True
|
||||
"Indicates if the software uses TCP protocol for communication. Default is True."
|
||||
@@ -240,6 +242,8 @@ class IOSoftware(Software):
|
||||
"Indicates if the software uses UDP protocol for communication. Default is True."
|
||||
port: Port
|
||||
"The port to which the software is connected."
|
||||
_connections: Dict[str, Dict] = {}
|
||||
"Active connections."
|
||||
|
||||
def set_original_state(self):
|
||||
"""Sets the original state."""
|
||||
@@ -284,23 +288,85 @@ class IOSoftware(Software):
|
||||
return False
|
||||
return True
|
||||
|
||||
@property
|
||||
def connections(self) -> Dict[str, Dict]:
|
||||
"""Return the public version of connections."""
|
||||
return copy.copy(self._connections)
|
||||
|
||||
def add_connection(self, connection_id: str, session_id: Optional[str] = None) -> bool:
|
||||
"""
|
||||
Create a new connection to this service.
|
||||
|
||||
Returns true if connection successfully created
|
||||
|
||||
:param: connection_id: UUID of the connection to create
|
||||
:type: string
|
||||
"""
|
||||
# if over or at capacity, set to overwhelmed
|
||||
if len(self._connections) >= self.max_sessions:
|
||||
self.set_health_state(SoftwareHealthState.OVERWHELMED)
|
||||
self.sys_log.error(f"{self.name}: Connect request for {connection_id=} declined. Service is at capacity.")
|
||||
return False
|
||||
else:
|
||||
# if service was previously overwhelmed, set to good because there is enough space for connections
|
||||
if self.health_state_actual == SoftwareHealthState.OVERWHELMED:
|
||||
self.set_health_state(SoftwareHealthState.GOOD)
|
||||
|
||||
# check that connection already doesn't exist
|
||||
if not self._connections.get(connection_id):
|
||||
session_details = None
|
||||
if session_id:
|
||||
session_details = self._get_session_details(session_id)
|
||||
self._connections[connection_id] = {
|
||||
"ip_address": session_details.with_ip_address if session_details else None,
|
||||
"time": datetime.now(),
|
||||
}
|
||||
self.sys_log.info(f"{self.name}: Connect request for {connection_id=} authorised")
|
||||
return True
|
||||
# connection with given id already exists
|
||||
self.sys_log.error(
|
||||
f"{self.name}: Connect request for {connection_id=} declined. Connection already exists."
|
||||
)
|
||||
return False
|
||||
|
||||
def remove_connection(self, connection_id: str) -> bool:
|
||||
"""
|
||||
Remove a connection from this service.
|
||||
|
||||
Returns true if connection successfully removed
|
||||
|
||||
:param: connection_id: UUID of the connection to create
|
||||
:type: string
|
||||
"""
|
||||
if self.connections.get(connection_id):
|
||||
self._connections.pop(connection_id)
|
||||
self.sys_log.info(f"{self.name}: Connection {connection_id=} closed.")
|
||||
return True
|
||||
|
||||
def clear_connections(self):
|
||||
"""Clears all the connections from the software."""
|
||||
self._connections = {}
|
||||
|
||||
def send(
|
||||
self,
|
||||
payload: Any,
|
||||
session_id: Optional[str] = None,
|
||||
dest_ip_address: Optional[IPv4Address] = None,
|
||||
dest_ip_address: Optional[Union[IPv4Address, IPv4Network]] = None,
|
||||
dest_port: Optional[Port] = None,
|
||||
**kwargs,
|
||||
) -> bool:
|
||||
"""
|
||||
Sends a payload to the SessionManager.
|
||||
Sends a payload to the SessionManager for network transmission.
|
||||
|
||||
This method is responsible for initiating the process of sending network payloads. It supports both
|
||||
unicast and Layer 3 broadcast transmissions. For broadcasts, the destination IP should be specified
|
||||
as an IPv4Network. It delegates the actual sending process to the SoftwareManager.
|
||||
|
||||
:param payload: The payload to be sent.
|
||||
:param dest_ip_address: The ip address of the payload destination.
|
||||
:param dest_port: The port of the payload destination.
|
||||
:param session_id: The Session ID the payload is to originate from. Optional.
|
||||
|
||||
:return: True if successful, False otherwise.
|
||||
:param dest_ip_address: The IP address or network (for broadcasts) of the payload destination.
|
||||
:param dest_port: The destination port for the payload. Optional.
|
||||
:param session_id: The Session ID from which the payload originates. Optional.
|
||||
:return: True if the payload was successfully sent, False otherwise.
|
||||
"""
|
||||
if not self._can_perform_action():
|
||||
return False
|
||||
|
||||
@@ -93,25 +93,25 @@ agents:
|
||||
num_files_per_folder: 1
|
||||
num_nics_per_node: 2
|
||||
nodes:
|
||||
- node_ref: domain_controller
|
||||
- node_hostname: domain_controller
|
||||
services:
|
||||
- service_ref: domain_controller_dns_server
|
||||
- node_ref: web_server
|
||||
- service_name: domain_controller_dns_server
|
||||
- node_hostname: web_server
|
||||
services:
|
||||
- service_ref: web_server_database_client
|
||||
- node_ref: database_server
|
||||
- service_name: web_server_database_client
|
||||
- node_hostname: database_server
|
||||
services:
|
||||
- service_ref: database_service
|
||||
- service_name: database_service
|
||||
folders:
|
||||
- folder_name: database
|
||||
files:
|
||||
- file_name: database.db
|
||||
- node_ref: backup_server
|
||||
- node_hostname: backup_server
|
||||
# services:
|
||||
# - service_ref: backup_service
|
||||
- node_ref: security_suite
|
||||
- node_ref: client_1
|
||||
- node_ref: client_2
|
||||
- node_hostname: security_suite
|
||||
- node_hostname: client_1
|
||||
- node_hostname: client_2
|
||||
links:
|
||||
- link_ref: router_1___switch_1
|
||||
- link_ref: router_1___switch_2
|
||||
@@ -126,7 +126,7 @@ agents:
|
||||
acl:
|
||||
options:
|
||||
max_acl_rules: 10
|
||||
router_node_ref: router_1
|
||||
router_hostname: router_1
|
||||
ip_address_order:
|
||||
- node_ref: domain_controller
|
||||
nic_num: 1
|
||||
@@ -514,7 +514,7 @@ agents:
|
||||
- type: DATABASE_FILE_INTEGRITY
|
||||
weight: 0.5
|
||||
options:
|
||||
node_ref: database_server
|
||||
node_hostname: database_server
|
||||
folder_name: database
|
||||
file_name: database.db
|
||||
|
||||
@@ -522,8 +522,8 @@ agents:
|
||||
- type: WEB_SERVER_404_PENALTY
|
||||
weight: 0.5
|
||||
options:
|
||||
node_ref: web_server
|
||||
service_ref: web_server_web_service
|
||||
node_hostname: web_server
|
||||
service_name: web_server_web_service
|
||||
|
||||
|
||||
agent_settings:
|
||||
|
||||
@@ -31,13 +31,6 @@ agents:
|
||||
action_space:
|
||||
action_list:
|
||||
- type: DONOTHING
|
||||
# <not yet implemented>
|
||||
# - type: NODE_LOGON
|
||||
# - type: NODE_LOGOFF
|
||||
# - type: NODE_APPLICATION_EXECUTE
|
||||
# options:
|
||||
# execution_definition:
|
||||
# target_address: arcd.com
|
||||
|
||||
options:
|
||||
nodes:
|
||||
@@ -104,25 +97,25 @@ agents:
|
||||
num_files_per_folder: 1
|
||||
num_nics_per_node: 2
|
||||
nodes:
|
||||
- node_ref: domain_controller
|
||||
- node_hostname: domain_controller
|
||||
services:
|
||||
- service_ref: domain_controller_dns_server
|
||||
- node_ref: web_server
|
||||
- service_name: domain_controller_dns_server
|
||||
- node_hostname: web_server
|
||||
services:
|
||||
- service_ref: web_server_database_client
|
||||
- node_ref: database_server
|
||||
- service_name: web_server_database_client
|
||||
- node_hostname: database_server
|
||||
services:
|
||||
- service_ref: database_service
|
||||
- service_name: database_service
|
||||
folders:
|
||||
- folder_name: database
|
||||
files:
|
||||
- file_name: database.db
|
||||
- node_ref: backup_server
|
||||
- node_hostname: backup_server
|
||||
# services:
|
||||
# - service_ref: backup_service
|
||||
- node_ref: security_suite
|
||||
- node_ref: client_1
|
||||
- node_ref: client_2
|
||||
- node_hostname: security_suite
|
||||
- node_hostname: client_1
|
||||
- node_hostname: client_2
|
||||
links:
|
||||
- link_ref: router_1___switch_1
|
||||
- link_ref: router_1___switch_2
|
||||
@@ -137,7 +130,7 @@ agents:
|
||||
acl:
|
||||
options:
|
||||
max_acl_rules: 10
|
||||
router_node_ref: router_1
|
||||
router_hostname: router_1
|
||||
ip_address_order:
|
||||
- node_ref: domain_controller
|
||||
nic_num: 1
|
||||
@@ -525,7 +518,7 @@ agents:
|
||||
- type: DATABASE_FILE_INTEGRITY
|
||||
weight: 0.5
|
||||
options:
|
||||
node_ref: database_server
|
||||
node_hostname: database_server
|
||||
folder_name: database
|
||||
file_name: database.db
|
||||
|
||||
@@ -533,8 +526,8 @@ agents:
|
||||
- type: WEB_SERVER_404_PENALTY
|
||||
weight: 0.5
|
||||
options:
|
||||
node_ref: web_server
|
||||
service_ref: web_server_web_service
|
||||
node_hostname: web_server
|
||||
service_name: web_server_web_service
|
||||
|
||||
|
||||
agent_settings:
|
||||
|
||||
@@ -37,13 +37,6 @@ agents:
|
||||
action_space:
|
||||
action_list:
|
||||
- type: DONOTHING
|
||||
# <not yet implemented>
|
||||
# - type: NODE_LOGON
|
||||
# - type: NODE_LOGOFF
|
||||
# - type: NODE_APPLICATION_EXECUTE
|
||||
# options:
|
||||
# execution_definition:
|
||||
# target_address: arcd.com
|
||||
|
||||
options:
|
||||
nodes:
|
||||
@@ -111,25 +104,25 @@ agents:
|
||||
num_files_per_folder: 1
|
||||
num_nics_per_node: 2
|
||||
nodes:
|
||||
- node_ref: domain_controller
|
||||
- node_hostname: domain_controller
|
||||
services:
|
||||
- service_ref: domain_controller_dns_server
|
||||
- node_ref: web_server
|
||||
- service_name: domain_controller_dns_server
|
||||
- node_hostname: web_server
|
||||
services:
|
||||
- service_ref: web_server_database_client
|
||||
- node_ref: database_server
|
||||
- service_name: web_server_database_client
|
||||
- node_hostname: database_server
|
||||
services:
|
||||
- service_ref: database_service
|
||||
- service_name: database_service
|
||||
folders:
|
||||
- folder_name: database
|
||||
files:
|
||||
- file_name: database.db
|
||||
- node_ref: backup_server
|
||||
- node_hostname: backup_server
|
||||
# services:
|
||||
# - service_ref: backup_service
|
||||
- node_ref: security_suite
|
||||
- node_ref: client_1
|
||||
- node_ref: client_2
|
||||
- node_hostname: security_suite
|
||||
- node_hostname: client_1
|
||||
- node_hostname: client_2
|
||||
links:
|
||||
- link_ref: router_1___switch_1
|
||||
- link_ref: router_1___switch_2
|
||||
@@ -144,7 +137,7 @@ agents:
|
||||
acl:
|
||||
options:
|
||||
max_acl_rules: 10
|
||||
router_node_ref: router_1
|
||||
router_hostname: router_1
|
||||
ip_address_order:
|
||||
- node_ref: domain_controller
|
||||
nic_num: 1
|
||||
@@ -532,7 +525,7 @@ agents:
|
||||
- type: DATABASE_FILE_INTEGRITY
|
||||
weight: 0.5
|
||||
options:
|
||||
node_ref: database_server
|
||||
node_hostname: database_server
|
||||
folder_name: database
|
||||
file_name: database.db
|
||||
|
||||
@@ -540,8 +533,8 @@ agents:
|
||||
- type: WEB_SERVER_404_PENALTY
|
||||
weight: 0.5
|
||||
options:
|
||||
node_ref: web_server
|
||||
service_ref: web_server_web_service
|
||||
node_hostname: web_server
|
||||
service_name: web_server_web_service
|
||||
|
||||
|
||||
agent_settings:
|
||||
@@ -559,25 +552,25 @@ agents:
|
||||
num_files_per_folder: 1
|
||||
num_nics_per_node: 2
|
||||
nodes:
|
||||
- node_ref: domain_controller
|
||||
- node_hostname: domain_controller
|
||||
services:
|
||||
- service_ref: domain_controller_dns_server
|
||||
- node_ref: web_server
|
||||
- service_name: domain_controller_dns_server
|
||||
- node_hostname: web_server
|
||||
services:
|
||||
- service_ref: web_server_database_client
|
||||
- node_ref: database_server
|
||||
- service_name: web_server_database_client
|
||||
- node_hostname: database_server
|
||||
services:
|
||||
- service_ref: database_service
|
||||
- service_name: database_service
|
||||
folders:
|
||||
- folder_name: database
|
||||
files:
|
||||
- file_name: database.db
|
||||
- node_ref: backup_server
|
||||
- node_hostname: backup_server
|
||||
# services:
|
||||
# - service_ref: backup_service
|
||||
- node_ref: security_suite
|
||||
- node_ref: client_1
|
||||
- node_ref: client_2
|
||||
- node_hostname: security_suite
|
||||
- node_hostname: client_1
|
||||
- node_hostname: client_2
|
||||
links:
|
||||
- link_ref: router_1___switch_1
|
||||
- link_ref: router_1___switch_2
|
||||
@@ -592,7 +585,7 @@ agents:
|
||||
acl:
|
||||
options:
|
||||
max_acl_rules: 10
|
||||
router_node_ref: router_1
|
||||
router_hostname: router_1
|
||||
ip_address_order:
|
||||
- node_ref: domain_controller
|
||||
nic_num: 1
|
||||
@@ -980,7 +973,7 @@ agents:
|
||||
- type: DATABASE_FILE_INTEGRITY
|
||||
weight: 0.5
|
||||
options:
|
||||
node_ref: database_server
|
||||
node_hostname: database_server
|
||||
folder_name: database
|
||||
file_name: database.db
|
||||
|
||||
@@ -988,8 +981,8 @@ agents:
|
||||
- type: WEB_SERVER_404_PENALTY
|
||||
weight: 0.5
|
||||
options:
|
||||
node_ref: web_server
|
||||
service_ref: web_server_web_service
|
||||
node_hostname: web_server
|
||||
service_name: web_server_web_service
|
||||
|
||||
|
||||
agent_settings:
|
||||
|
||||
@@ -35,13 +35,6 @@ agents:
|
||||
action_space:
|
||||
action_list:
|
||||
- type: DONOTHING
|
||||
# <not yet implemented>
|
||||
# - type: NODE_LOGON
|
||||
# - type: NODE_LOGOFF
|
||||
# - type: NODE_APPLICATION_EXECUTE
|
||||
# options:
|
||||
# execution_definition:
|
||||
# target_address: arcd.com
|
||||
|
||||
options:
|
||||
nodes:
|
||||
@@ -109,25 +102,25 @@ agents:
|
||||
num_files_per_folder: 1
|
||||
num_nics_per_node: 2
|
||||
nodes:
|
||||
- node_ref: domain_controller
|
||||
- node_hostname: domain_controller
|
||||
services:
|
||||
- service_ref: domain_controller_dns_server
|
||||
- node_ref: web_server
|
||||
- service_name: domain_controller_dns_server
|
||||
- node_hostname: web_server
|
||||
services:
|
||||
- service_ref: web_server_database_client
|
||||
- node_ref: database_server
|
||||
- service_name: web_server_database_client
|
||||
- node_hostname: database_server
|
||||
services:
|
||||
- service_ref: database_service
|
||||
- service_name: database_service
|
||||
folders:
|
||||
- folder_name: database
|
||||
files:
|
||||
- file_name: database.db
|
||||
- node_ref: backup_server
|
||||
- node_hostname: backup_server
|
||||
# services:
|
||||
# - service_ref: backup_service
|
||||
- node_ref: security_suite
|
||||
- node_ref: client_1
|
||||
- node_ref: client_2
|
||||
# - service_name: backup_service
|
||||
- node_hostname: security_suite
|
||||
- node_hostname: client_1
|
||||
- node_hostname: client_2
|
||||
links:
|
||||
- link_ref: router_1___switch_1
|
||||
- link_ref: router_1___switch_2
|
||||
@@ -142,7 +135,7 @@ agents:
|
||||
acl:
|
||||
options:
|
||||
max_acl_rules: 10
|
||||
router_node_ref: router_1
|
||||
router_hostname: router_1
|
||||
ip_address_order:
|
||||
- node_ref: domain_controller
|
||||
nic_num: 1
|
||||
@@ -530,7 +523,7 @@ agents:
|
||||
- type: DATABASE_FILE_INTEGRITY
|
||||
weight: 0.5
|
||||
options:
|
||||
node_ref: database_server
|
||||
node_hostname: database_server
|
||||
folder_name: database
|
||||
file_name: database.db
|
||||
|
||||
@@ -538,8 +531,8 @@ agents:
|
||||
- type: WEB_SERVER_404_PENALTY
|
||||
weight: 0.5
|
||||
options:
|
||||
node_ref: web_server
|
||||
service_ref: web_server_web_service
|
||||
node_hostname: web_server
|
||||
service_name: web_server_web_service
|
||||
|
||||
|
||||
agent_settings:
|
||||
|
||||
@@ -105,25 +105,23 @@ agents:
|
||||
num_files_per_folder: 1
|
||||
num_nics_per_node: 2
|
||||
nodes:
|
||||
- node_ref: domain_controller
|
||||
- node_hostname: domain_controller
|
||||
services:
|
||||
- service_ref: domain_controller_dns_server
|
||||
- node_ref: web_server
|
||||
- service_name: domain_controller_dns_server
|
||||
- node_hostname: web_server
|
||||
services:
|
||||
- service_ref: web_server_database_client
|
||||
- node_ref: database_server
|
||||
- service_name: web_server_database_client
|
||||
- node_hostname: database_server
|
||||
services:
|
||||
- service_ref: database_service
|
||||
- service_name: database_service
|
||||
folders:
|
||||
- folder_name: database
|
||||
files:
|
||||
- file_name: database.db
|
||||
- node_ref: backup_server
|
||||
# services:
|
||||
# - service_ref: backup_service
|
||||
- node_ref: security_suite
|
||||
- node_ref: client_1
|
||||
- node_ref: client_2
|
||||
- node_hostname: backup_server
|
||||
- node_hostname: security_suite
|
||||
- node_hostname: client_1
|
||||
- node_hostname: client_2
|
||||
links:
|
||||
- link_ref: router_1___switch_1
|
||||
- link_ref: router_1___switch_2
|
||||
@@ -138,7 +136,7 @@ agents:
|
||||
acl:
|
||||
options:
|
||||
max_acl_rules: 10
|
||||
router_node_ref: router_1
|
||||
router_hostname: router_1
|
||||
ip_address_order:
|
||||
- node_ref: domain_controller
|
||||
nic_num: 1
|
||||
@@ -526,7 +524,7 @@ agents:
|
||||
- type: DATABASE_FILE_INTEGRITY
|
||||
weight: 0.5
|
||||
options:
|
||||
node_ref: database_server
|
||||
node_hostname: database_server
|
||||
folder_name: database
|
||||
file_name: database.db
|
||||
|
||||
@@ -534,8 +532,8 @@ agents:
|
||||
- type: WEB_SERVER_404_PENALTY
|
||||
weight: 0.5
|
||||
options:
|
||||
node_ref: web_server
|
||||
service_ref: web_server_web_service
|
||||
node_hostname: web_server
|
||||
service_name: web_service
|
||||
|
||||
|
||||
agent_settings:
|
||||
|
||||
@@ -170,7 +170,7 @@ def example_network() -> Network:
|
||||
-------------- --------------
|
||||
| client_1 |----- ----| server_1 |
|
||||
-------------- | -------------- -------------- -------------- | --------------
|
||||
------| switch_1 |------| router_1 |------| switch_2 |------
|
||||
------| switch_2 |------| router_1 |------| switch_1 |------
|
||||
-------------- | -------------- -------------- -------------- | --------------
|
||||
| client_2 |---- ----| server_2 |
|
||||
-------------- --------------
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
from primaite.simulator.network.hardware.nodes.computer import Computer
|
||||
from primaite.simulator.network.hardware.nodes.server import Server
|
||||
from primaite.simulator.system.applications.database_client import DatabaseClient
|
||||
from primaite.simulator.system.applications.red_applications.data_manipulation_bot import DataManipulationBot
|
||||
from primaite.simulator.system.services.database.database_service import DatabaseService
|
||||
from primaite.simulator.system.services.red_services.data_manipulation_bot import DataManipulationBot
|
||||
|
||||
|
||||
def test_data_manipulation(uc2_network):
|
||||
|
||||
@@ -14,7 +14,7 @@ def test_file_observation():
|
||||
state = sim.describe_state()
|
||||
|
||||
dog_file_obs = FileObservation(
|
||||
where=["network", "nodes", pc.uuid, "file_system", "folders", "root", "files", "dog.png"]
|
||||
where=["network", "nodes", pc.hostname, "file_system", "folders", "root", "files", "dog.png"]
|
||||
)
|
||||
assert dog_file_obs.observe(state) == {"health_status": 1}
|
||||
assert dog_file_obs.space == spaces.Dict({"health_status": spaces.Discrete(6)})
|
||||
|
||||
180
tests/integration_tests/network/test_broadcast.py
Normal file
180
tests/integration_tests/network/test_broadcast.py
Normal file
@@ -0,0 +1,180 @@
|
||||
from ipaddress import IPv4Address, IPv4Network
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
import pytest
|
||||
|
||||
from primaite.simulator.network.container import Network
|
||||
from primaite.simulator.network.hardware.nodes.computer import Computer
|
||||
from primaite.simulator.network.hardware.nodes.server import Server
|
||||
from primaite.simulator.network.hardware.nodes.switch import Switch
|
||||
from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.system.applications.application import Application
|
||||
from primaite.simulator.system.services.service import Service
|
||||
|
||||
|
||||
class BroadcastService(Service):
|
||||
"""A service for sending broadcast and unicast messages over a network."""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
# Set default service properties for broadcasting
|
||||
kwargs["name"] = "BroadcastService"
|
||||
kwargs["port"] = Port.HTTP
|
||||
kwargs["protocol"] = IPProtocol.TCP
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def describe_state(self) -> Dict:
|
||||
# Implement state description for the service
|
||||
pass
|
||||
|
||||
def unicast(self, ip_address: IPv4Address):
|
||||
# Send a unicast payload to a specific IP address
|
||||
super().send(
|
||||
payload="unicast",
|
||||
dest_ip_address=ip_address,
|
||||
dest_port=Port.HTTP,
|
||||
)
|
||||
|
||||
def broadcast(self, ip_network: IPv4Network):
|
||||
# Send a broadcast payload to an entire IP network
|
||||
super().send(
|
||||
payload="broadcast",
|
||||
dest_ip_address=ip_network,
|
||||
dest_port=Port.HTTP,
|
||||
)
|
||||
|
||||
|
||||
class BroadcastClient(Application):
|
||||
"""A client application to receive broadcast and unicast messages."""
|
||||
|
||||
payloads_received: List = []
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
# Set default client properties
|
||||
kwargs["name"] = "BroadcastClient"
|
||||
kwargs["port"] = Port.HTTP
|
||||
kwargs["protocol"] = IPProtocol.TCP
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def describe_state(self) -> Dict:
|
||||
# Implement state description for the application
|
||||
pass
|
||||
|
||||
def receive(self, payload: Any, session_id: str, **kwargs) -> bool:
|
||||
# Append received payloads to the list and print a message
|
||||
self.payloads_received.append(payload)
|
||||
print(f"Payload: {payload} received on node {self.sys_log.hostname}")
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def broadcast_network() -> Network:
|
||||
network = Network()
|
||||
|
||||
client_1 = Computer(
|
||||
hostname="client_1",
|
||||
ip_address="192.168.1.2",
|
||||
subnet_mask="255.255.255.0",
|
||||
default_gateway="192.168.1.1",
|
||||
start_up_duration=0,
|
||||
)
|
||||
client_1.power_on()
|
||||
client_1.software_manager.install(BroadcastClient)
|
||||
application_1 = client_1.software_manager.software["BroadcastClient"]
|
||||
application_1.run()
|
||||
|
||||
client_2 = Computer(
|
||||
hostname="client_2",
|
||||
ip_address="192.168.1.3",
|
||||
subnet_mask="255.255.255.0",
|
||||
default_gateway="192.168.1.1",
|
||||
start_up_duration=0,
|
||||
)
|
||||
client_2.power_on()
|
||||
client_2.software_manager.install(BroadcastClient)
|
||||
application_2 = client_2.software_manager.software["BroadcastClient"]
|
||||
application_2.run()
|
||||
|
||||
server_1 = Server(
|
||||
hostname="server_1",
|
||||
ip_address="192.168.1.1",
|
||||
subnet_mask="255.255.255.0",
|
||||
default_gateway="192.168.1.1",
|
||||
start_up_duration=0,
|
||||
)
|
||||
server_1.power_on()
|
||||
|
||||
server_1.software_manager.install(BroadcastService)
|
||||
service: BroadcastService = server_1.software_manager.software["BroadcastService"]
|
||||
service.start()
|
||||
|
||||
switch_1 = Switch(hostname="switch_1", num_ports=6, start_up_duration=0)
|
||||
switch_1.power_on()
|
||||
|
||||
network.connect(endpoint_a=client_1.ethernet_port[1], endpoint_b=switch_1.switch_ports[1])
|
||||
network.connect(endpoint_a=client_2.ethernet_port[1], endpoint_b=switch_1.switch_ports[2])
|
||||
network.connect(endpoint_a=server_1.ethernet_port[1], endpoint_b=switch_1.switch_ports[3])
|
||||
|
||||
return network
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def broadcast_service_and_clients(broadcast_network) -> Tuple[BroadcastService, BroadcastClient, BroadcastClient]:
|
||||
client_1: BroadcastClient = broadcast_network.get_node_by_hostname("client_1").software_manager.software[
|
||||
"BroadcastClient"
|
||||
]
|
||||
client_2: BroadcastClient = broadcast_network.get_node_by_hostname("client_2").software_manager.software[
|
||||
"BroadcastClient"
|
||||
]
|
||||
service: BroadcastService = broadcast_network.get_node_by_hostname("server_1").software_manager.software[
|
||||
"BroadcastService"
|
||||
]
|
||||
|
||||
return service, client_1, client_2
|
||||
|
||||
|
||||
def test_broadcast_correct_subnet(broadcast_service_and_clients):
|
||||
service, client_1, client_2 = broadcast_service_and_clients
|
||||
|
||||
assert not client_1.payloads_received
|
||||
assert not client_2.payloads_received
|
||||
|
||||
service.broadcast(IPv4Network("192.168.1.0/24"))
|
||||
|
||||
assert client_1.payloads_received == ["broadcast"]
|
||||
assert client_2.payloads_received == ["broadcast"]
|
||||
|
||||
|
||||
def test_broadcast_incorrect_subnet(broadcast_service_and_clients):
|
||||
service, client_1, client_2 = broadcast_service_and_clients
|
||||
|
||||
assert not client_1.payloads_received
|
||||
assert not client_2.payloads_received
|
||||
|
||||
service.broadcast(IPv4Network("192.168.2.0/24"))
|
||||
|
||||
assert not client_1.payloads_received
|
||||
assert not client_2.payloads_received
|
||||
|
||||
|
||||
def test_unicast_correct_address(broadcast_service_and_clients):
|
||||
service, client_1, client_2 = broadcast_service_and_clients
|
||||
|
||||
assert not client_1.payloads_received
|
||||
assert not client_2.payloads_received
|
||||
|
||||
service.unicast(IPv4Address("192.168.1.2"))
|
||||
|
||||
assert client_1.payloads_received == ["unicast"]
|
||||
assert not client_2.payloads_received
|
||||
|
||||
|
||||
def test_unicast_incorrect_address(broadcast_service_and_clients):
|
||||
service, client_1, client_2 = broadcast_service_and_clients
|
||||
|
||||
assert not client_1.payloads_received
|
||||
assert not client_2.payloads_received
|
||||
|
||||
service.unicast(IPv4Address("192.168.2.2"))
|
||||
|
||||
assert not client_1.payloads_received
|
||||
assert not client_2.payloads_received
|
||||
@@ -1,11 +1,16 @@
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Tuple
|
||||
|
||||
import pytest
|
||||
|
||||
from primaite.simulator.network.container import Network
|
||||
from primaite.simulator.network.hardware.base import Link, NIC, Node, NodeOperatingState
|
||||
from primaite.simulator.network.hardware.nodes.computer import Computer
|
||||
from primaite.simulator.network.hardware.nodes.router import ACLAction, Router
|
||||
from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.system.services.ntp.ntp_client import NTPClient
|
||||
from primaite.simulator.system.services.ntp.ntp_server import NTPServer
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
@@ -34,6 +39,69 @@ def pc_a_pc_b_router_1() -> Tuple[Node, Node, Router]:
|
||||
return pc_a, pc_b, router_1
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def multi_hop_network() -> Network:
|
||||
network = Network()
|
||||
|
||||
# Configure PC A
|
||||
pc_a = Computer(
|
||||
hostname="pc_a",
|
||||
ip_address="192.168.0.2",
|
||||
subnet_mask="255.255.255.0",
|
||||
default_gateway="192.168.0.1",
|
||||
start_up_duration=0,
|
||||
)
|
||||
pc_a.power_on()
|
||||
network.add_node(pc_a)
|
||||
|
||||
# Configure Router 1
|
||||
router_1 = Router(hostname="router_1", start_up_duration=0)
|
||||
router_1.power_on()
|
||||
network.add_node(router_1)
|
||||
|
||||
# Configure the connection between PC A and Router 1 port 2
|
||||
router_1.configure_port(2, "192.168.0.1", "255.255.255.0")
|
||||
network.connect(pc_a.ethernet_port[1], router_1.ethernet_ports[2])
|
||||
router_1.enable_port(2)
|
||||
|
||||
# Configure Router 1 ACLs
|
||||
router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.ARP, dst_port=Port.ARP, position=22)
|
||||
router_1.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol.ICMP, position=23)
|
||||
|
||||
# Configure PC B
|
||||
pc_b = Computer(
|
||||
hostname="pc_b",
|
||||
ip_address="192.168.2.2",
|
||||
subnet_mask="255.255.255.0",
|
||||
default_gateway="192.168.2.1",
|
||||
start_up_duration=0,
|
||||
)
|
||||
pc_b.power_on()
|
||||
network.add_node(pc_b)
|
||||
|
||||
# Configure Router 2
|
||||
router_2 = Router(hostname="router_2", start_up_duration=0)
|
||||
router_2.power_on()
|
||||
network.add_node(router_2)
|
||||
|
||||
# Configure the connection between PC B and Router 2 port 2
|
||||
router_2.configure_port(2, "192.168.2.1", "255.255.255.0")
|
||||
network.connect(pc_b.ethernet_port[1], router_2.ethernet_ports[2])
|
||||
router_2.enable_port(2)
|
||||
|
||||
# Configure Router 2 ACLs
|
||||
router_2.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.ARP, dst_port=Port.ARP, position=22)
|
||||
router_2.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol.ICMP, position=23)
|
||||
|
||||
# Configure the connection between Router 1 port 1 and Router 2 port 1
|
||||
router_2.configure_port(1, "192.168.1.2", "255.255.255.252")
|
||||
router_1.configure_port(1, "192.168.1.1", "255.255.255.252")
|
||||
network.connect(router_1.ethernet_ports[1], router_2.ethernet_ports[1])
|
||||
router_1.enable_port(1)
|
||||
router_2.enable_port(1)
|
||||
return network
|
||||
|
||||
|
||||
def test_ping_default_gateway(pc_a_pc_b_router_1):
|
||||
pc_a, pc_b, router_1 = pc_a_pc_b_router_1
|
||||
|
||||
@@ -50,3 +118,68 @@ def test_host_on_other_subnet(pc_a_pc_b_router_1):
|
||||
pc_a, pc_b, router_1 = pc_a_pc_b_router_1
|
||||
|
||||
assert pc_a.ping("192.168.1.10")
|
||||
|
||||
|
||||
def test_no_route_no_ping(multi_hop_network):
|
||||
pc_a = multi_hop_network.get_node_by_hostname("pc_a")
|
||||
pc_b = multi_hop_network.get_node_by_hostname("pc_b")
|
||||
|
||||
assert not pc_a.ping(pc_b.ethernet_port[1].ip_address)
|
||||
|
||||
|
||||
def test_with_routes_can_ping(multi_hop_network):
|
||||
pc_a = multi_hop_network.get_node_by_hostname("pc_a")
|
||||
pc_b = multi_hop_network.get_node_by_hostname("pc_b")
|
||||
|
||||
router_1: Router = multi_hop_network.get_node_by_hostname("router_1") # noqa
|
||||
router_2: Router = multi_hop_network.get_node_by_hostname("router_2") # noqa
|
||||
|
||||
# Configure Route from Router 1 to PC B subnet
|
||||
router_1.route_table.add_route(
|
||||
address="192.168.2.0", subnet_mask="255.255.255.0", next_hop_ip_address="192.168.1.2"
|
||||
)
|
||||
|
||||
# Configure Route from Router 2 to PC A subnet
|
||||
router_2.route_table.add_route(
|
||||
address="192.168.0.2", subnet_mask="255.255.255.0", next_hop_ip_address="192.168.1.1"
|
||||
)
|
||||
|
||||
assert pc_a.ping(pc_b.ethernet_port[1].ip_address)
|
||||
|
||||
|
||||
def test_routing_services(multi_hop_network):
|
||||
pc_a = multi_hop_network.get_node_by_hostname("pc_a")
|
||||
|
||||
pc_b = multi_hop_network.get_node_by_hostname("pc_b")
|
||||
|
||||
pc_a.software_manager.install(NTPClient)
|
||||
ntp_client = pc_a.software_manager.software["NTPClient"]
|
||||
ntp_client.start()
|
||||
|
||||
pc_b.software_manager.install(NTPServer)
|
||||
pc_b.software_manager.software["NTPServer"].start()
|
||||
|
||||
ntp_client.configure(ntp_server_ip_address=pc_b.ethernet_port[1].ip_address)
|
||||
|
||||
router_1: Router = multi_hop_network.get_node_by_hostname("router_1") # noqa
|
||||
router_2: Router = multi_hop_network.get_node_by_hostname("router_2") # noqa
|
||||
|
||||
router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.NTP, dst_port=Port.NTP, position=21)
|
||||
router_2.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.NTP, dst_port=Port.NTP, position=21)
|
||||
|
||||
assert ntp_client.time is None
|
||||
ntp_client.request_time()
|
||||
assert ntp_client.time is None
|
||||
|
||||
# Configure Route from Router 1 to PC B subnet
|
||||
router_1.route_table.add_route(
|
||||
address="192.168.2.0", subnet_mask="255.255.255.0", next_hop_ip_address="192.168.1.2"
|
||||
)
|
||||
|
||||
# Configure Route from Router 2 to PC A subnet
|
||||
router_2.route_table.add_route(
|
||||
address="192.168.0.2", subnet_mask="255.255.255.0", next_hop_ip_address="192.168.1.1"
|
||||
)
|
||||
|
||||
ntp_client.request_time()
|
||||
assert ntp_client.time is not None
|
||||
|
||||
@@ -0,0 +1,180 @@
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Tuple
|
||||
|
||||
import pytest
|
||||
|
||||
from primaite.simulator.network.container import Network
|
||||
from primaite.simulator.network.hardware.nodes.computer import Computer
|
||||
from primaite.simulator.network.hardware.nodes.router import ACLAction, Router
|
||||
from primaite.simulator.network.hardware.nodes.server import Server
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.system.applications.application import ApplicationOperatingState
|
||||
from primaite.simulator.system.applications.database_client import DatabaseClient
|
||||
from primaite.simulator.system.applications.red_applications.dos_bot import DoSAttackStage, DoSBot
|
||||
from primaite.simulator.system.services.database.database_service import DatabaseService
|
||||
from primaite.simulator.system.software import SoftwareHealthState
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def dos_bot_and_db_server(client_server) -> Tuple[DoSBot, Computer, DatabaseService, Server]:
|
||||
computer, server = client_server
|
||||
|
||||
# Install DoSBot on computer
|
||||
computer.software_manager.install(DoSBot)
|
||||
|
||||
dos_bot: DoSBot = computer.software_manager.software.get("DoSBot")
|
||||
dos_bot.configure(
|
||||
target_ip_address=IPv4Address(server.nics.get(next(iter(server.nics))).ip_address),
|
||||
target_port=Port.POSTGRES_SERVER,
|
||||
)
|
||||
|
||||
# Install DB Server service on server
|
||||
server.software_manager.install(DatabaseService)
|
||||
db_server_service: DatabaseService = server.software_manager.software.get("DatabaseService")
|
||||
db_server_service.start()
|
||||
|
||||
return dos_bot, computer, db_server_service, server
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def dos_bot_db_server_green_client(example_network) -> Network:
|
||||
network: Network = example_network
|
||||
|
||||
router_1: Router = example_network.get_node_by_hostname("router_1")
|
||||
router_1.acl.add_rule(
|
||||
action=ACLAction.PERMIT, src_port=Port.POSTGRES_SERVER, dst_port=Port.POSTGRES_SERVER, position=0
|
||||
)
|
||||
|
||||
client_1: Computer = network.get_node_by_hostname("client_1")
|
||||
client_2: Computer = network.get_node_by_hostname("client_2")
|
||||
server: Server = network.get_node_by_hostname("server_1")
|
||||
|
||||
# install DoS bot on client 1
|
||||
client_1.software_manager.install(DoSBot)
|
||||
|
||||
dos_bot: DoSBot = client_1.software_manager.software.get("DoSBot")
|
||||
dos_bot.configure(
|
||||
target_ip_address=IPv4Address(server.nics.get(next(iter(server.nics))).ip_address),
|
||||
target_port=Port.POSTGRES_SERVER,
|
||||
)
|
||||
|
||||
# install db server service on server
|
||||
server.software_manager.install(DatabaseService)
|
||||
db_server_service: DatabaseService = server.software_manager.software.get("DatabaseService")
|
||||
db_server_service.start()
|
||||
|
||||
# Install DB client (green) on client 2
|
||||
client_2.software_manager.install(DatabaseClient)
|
||||
|
||||
database_client: DatabaseClient = client_2.software_manager.software.get("DatabaseClient")
|
||||
database_client.configure(server_ip_address=IPv4Address("192.168.0.1"))
|
||||
database_client.run()
|
||||
|
||||
return network
|
||||
|
||||
|
||||
def test_repeating_dos_attack(dos_bot_and_db_server):
|
||||
dos_bot, computer, db_server_service, server = dos_bot_and_db_server
|
||||
|
||||
assert db_server_service.health_state_actual is SoftwareHealthState.GOOD
|
||||
|
||||
dos_bot.port_scan_p_of_success = 1
|
||||
dos_bot.repeat = True
|
||||
dos_bot.run()
|
||||
|
||||
assert len(dos_bot.connections) == db_server_service.max_sessions
|
||||
assert len(db_server_service.connections) == db_server_service.max_sessions
|
||||
assert len(dos_bot.connections) == db_server_service.max_sessions
|
||||
|
||||
assert dos_bot.attack_stage is DoSAttackStage.NOT_STARTED
|
||||
assert db_server_service.health_state_actual is SoftwareHealthState.OVERWHELMED
|
||||
|
||||
db_server_service.clear_connections()
|
||||
db_server_service.set_health_state(SoftwareHealthState.GOOD)
|
||||
assert len(db_server_service.connections) == 0
|
||||
|
||||
computer.apply_timestep(timestep=1)
|
||||
server.apply_timestep(timestep=1)
|
||||
|
||||
assert len(dos_bot.connections) == db_server_service.max_sessions
|
||||
assert len(db_server_service.connections) == db_server_service.max_sessions
|
||||
assert len(dos_bot.connections) == db_server_service.max_sessions
|
||||
|
||||
assert dos_bot.attack_stage is DoSAttackStage.NOT_STARTED
|
||||
assert db_server_service.health_state_actual is SoftwareHealthState.OVERWHELMED
|
||||
|
||||
|
||||
def test_non_repeating_dos_attack(dos_bot_and_db_server):
|
||||
dos_bot, computer, db_server_service, server = dos_bot_and_db_server
|
||||
|
||||
assert db_server_service.health_state_actual is SoftwareHealthState.GOOD
|
||||
|
||||
dos_bot.port_scan_p_of_success = 1
|
||||
dos_bot.repeat = False
|
||||
dos_bot.run()
|
||||
|
||||
assert len(dos_bot.connections) == db_server_service.max_sessions
|
||||
assert len(db_server_service.connections) == db_server_service.max_sessions
|
||||
assert len(dos_bot.connections) == db_server_service.max_sessions
|
||||
|
||||
assert dos_bot.attack_stage is DoSAttackStage.COMPLETED
|
||||
assert db_server_service.health_state_actual is SoftwareHealthState.OVERWHELMED
|
||||
|
||||
db_server_service.clear_connections()
|
||||
db_server_service.set_health_state(SoftwareHealthState.GOOD)
|
||||
assert len(db_server_service.connections) == 0
|
||||
|
||||
computer.apply_timestep(timestep=1)
|
||||
server.apply_timestep(timestep=1)
|
||||
|
||||
assert len(dos_bot.connections) == 0
|
||||
assert len(db_server_service.connections) == 0
|
||||
assert len(dos_bot.connections) == 0
|
||||
|
||||
assert dos_bot.attack_stage is DoSAttackStage.COMPLETED
|
||||
assert db_server_service.health_state_actual is SoftwareHealthState.GOOD
|
||||
|
||||
|
||||
def test_dos_bot_database_service_connection(dos_bot_and_db_server):
|
||||
dos_bot, computer, db_server_service, server = dos_bot_and_db_server
|
||||
|
||||
dos_bot.operating_state = ApplicationOperatingState.RUNNING
|
||||
dos_bot.attack_stage = DoSAttackStage.PORT_SCAN
|
||||
dos_bot._perform_dos()
|
||||
|
||||
assert len(dos_bot.connections) == db_server_service.max_sessions
|
||||
assert len(db_server_service.connections) == db_server_service.max_sessions
|
||||
assert len(dos_bot.connections) == db_server_service.max_sessions
|
||||
|
||||
|
||||
def test_dos_blocks_green_agent_connection(dos_bot_db_server_green_client):
|
||||
network: Network = dos_bot_db_server_green_client
|
||||
|
||||
client_1: Computer = network.get_node_by_hostname("client_1")
|
||||
dos_bot: DoSBot = client_1.software_manager.software.get("DoSBot")
|
||||
|
||||
client_2: Computer = network.get_node_by_hostname("client_2")
|
||||
green_db_client: DatabaseClient = client_2.software_manager.software.get("DatabaseClient")
|
||||
|
||||
server: Server = network.get_node_by_hostname("server_1")
|
||||
db_server_service: DatabaseService = server.software_manager.software.get("DatabaseService")
|
||||
|
||||
assert db_server_service.health_state_actual is SoftwareHealthState.GOOD
|
||||
|
||||
dos_bot.port_scan_p_of_success = 1
|
||||
dos_bot.repeat = False
|
||||
dos_bot.run()
|
||||
|
||||
# DoS bot fills up connection of db server service
|
||||
assert len(dos_bot.connections) == db_server_service.max_sessions
|
||||
assert len(db_server_service.connections) == db_server_service.max_sessions
|
||||
assert len(dos_bot.connections) == db_server_service.max_sessions
|
||||
assert len(green_db_client.connections) == 0
|
||||
|
||||
assert dos_bot.attack_stage is DoSAttackStage.COMPLETED
|
||||
# db server service is overwhelmed
|
||||
assert db_server_service.health_state_actual is SoftwareHealthState.OVERWHELMED
|
||||
|
||||
# green agent tries to connect but fails because service is overwhelmed
|
||||
assert green_db_client.connect() is False
|
||||
assert len(green_db_client.connections) == 0
|
||||
@@ -65,8 +65,8 @@ def test_server_turns_off_application(populated_node):
|
||||
assert app.operating_state is ApplicationOperatingState.CLOSED
|
||||
|
||||
|
||||
def test_application_cannot_be_turned_on_when_server_is_off(populated_node):
|
||||
"""Check that the application cannot be started when the server is off."""
|
||||
def test_application_cannot_be_turned_on_when_computer_is_off(populated_node):
|
||||
"""Check that the application cannot be started when the computer is off."""
|
||||
app, computer = populated_node
|
||||
|
||||
assert computer.operating_state is NodeOperatingState.ON
|
||||
@@ -86,8 +86,8 @@ def test_application_cannot_be_turned_on_when_server_is_off(populated_node):
|
||||
assert app.operating_state is ApplicationOperatingState.CLOSED
|
||||
|
||||
|
||||
def test_server_turns_on_application(populated_node):
|
||||
"""Check that turning on the server turns on application."""
|
||||
def test_computer_runs_applications(populated_node):
|
||||
"""Check that turning on the computer will turn on applications."""
|
||||
app, computer = populated_node
|
||||
|
||||
assert computer.operating_state is NodeOperatingState.ON
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Tuple
|
||||
|
||||
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
|
||||
import pytest
|
||||
|
||||
from primaite.simulator.network.hardware.base import Link, NIC, Node, NodeOperatingState
|
||||
from primaite.simulator.network.hardware.nodes.server import Server
|
||||
from primaite.simulator.system.applications.database_client import DatabaseClient
|
||||
from primaite.simulator.system.services.database.database_service import DatabaseService
|
||||
@@ -8,57 +11,109 @@ from primaite.simulator.system.services.ftp.ftp_server import FTPServer
|
||||
from primaite.simulator.system.services.service import ServiceOperatingState
|
||||
|
||||
|
||||
def test_database_client_server_connection(uc2_network):
|
||||
web_server: Server = uc2_network.get_node_by_hostname("web_server")
|
||||
db_client: DatabaseClient = web_server.software_manager.software.get("DatabaseClient")
|
||||
@pytest.fixture(scope="function")
|
||||
def peer_to_peer() -> Tuple[Node, Node]:
|
||||
node_a = Node(hostname="node_a", operating_state=NodeOperatingState.ON)
|
||||
nic_a = NIC(ip_address="192.168.0.10", subnet_mask="255.255.255.0", operating_state=NodeOperatingState.ON)
|
||||
node_a.connect_nic(nic_a)
|
||||
node_a.software_manager.get_open_ports()
|
||||
|
||||
db_server: Server = uc2_network.get_node_by_hostname("database_server")
|
||||
db_service: DatabaseService = db_server.software_manager.software.get("DatabaseService")
|
||||
node_b = Node(hostname="node_b", operating_state=NodeOperatingState.ON)
|
||||
nic_b = NIC(ip_address="192.168.0.11", subnet_mask="255.255.255.0")
|
||||
node_b.connect_nic(nic_b)
|
||||
|
||||
Link(endpoint_a=nic_a, endpoint_b=nic_b)
|
||||
|
||||
assert node_a.ping("192.168.0.11")
|
||||
|
||||
node_a.software_manager.install(DatabaseClient)
|
||||
node_a.software_manager.software["DatabaseClient"].configure(server_ip_address=IPv4Address("192.168.0.11"))
|
||||
node_a.software_manager.software["DatabaseClient"].run()
|
||||
|
||||
node_b.software_manager.install(DatabaseService)
|
||||
database_service: DatabaseService = node_b.software_manager.software["DatabaseService"] # noqa
|
||||
database_service.start()
|
||||
return node_a, node_b
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def peer_to_peer_secure_db() -> Tuple[Node, Node]:
|
||||
node_a = Node(hostname="node_a", operating_state=NodeOperatingState.ON)
|
||||
nic_a = NIC(ip_address="192.168.0.10", subnet_mask="255.255.255.0", operating_state=NodeOperatingState.ON)
|
||||
node_a.connect_nic(nic_a)
|
||||
node_a.software_manager.get_open_ports()
|
||||
|
||||
node_b = Node(hostname="node_b", operating_state=NodeOperatingState.ON)
|
||||
nic_b = NIC(ip_address="192.168.0.11", subnet_mask="255.255.255.0")
|
||||
node_b.connect_nic(nic_b)
|
||||
|
||||
Link(endpoint_a=nic_a, endpoint_b=nic_b)
|
||||
|
||||
assert node_a.ping("192.168.0.11")
|
||||
|
||||
node_a.software_manager.install(DatabaseClient)
|
||||
node_a.software_manager.software["DatabaseClient"].configure(server_ip_address=IPv4Address("192.168.0.11"))
|
||||
node_a.software_manager.software["DatabaseClient"].run()
|
||||
|
||||
node_b.software_manager.install(DatabaseService)
|
||||
database_service: DatabaseService = node_b.software_manager.software["DatabaseService"] # noqa
|
||||
database_service.password = "12345"
|
||||
database_service.start()
|
||||
return node_a, node_b
|
||||
|
||||
|
||||
def test_database_client_server_connection(peer_to_peer):
|
||||
node_a, node_b = peer_to_peer
|
||||
|
||||
db_client: DatabaseClient = node_a.software_manager.software["DatabaseClient"]
|
||||
|
||||
db_service: DatabaseService = node_b.software_manager.software["DatabaseService"]
|
||||
|
||||
db_client.connect()
|
||||
assert len(db_client.connections) == 1
|
||||
assert len(db_service.connections) == 1
|
||||
|
||||
db_client.disconnect()
|
||||
assert len(db_client.connections) == 0
|
||||
assert len(db_service.connections) == 0
|
||||
|
||||
|
||||
def test_database_client_server_correct_password(uc2_network):
|
||||
web_server: Server = uc2_network.get_node_by_hostname("web_server")
|
||||
db_client: DatabaseClient = web_server.software_manager.software.get("DatabaseClient")
|
||||
def test_database_client_server_correct_password(peer_to_peer_secure_db):
|
||||
node_a, node_b = peer_to_peer_secure_db
|
||||
|
||||
db_server: Server = uc2_network.get_node_by_hostname("database_server")
|
||||
db_service: DatabaseService = db_server.software_manager.software.get("DatabaseService")
|
||||
db_client: DatabaseClient = node_a.software_manager.software["DatabaseClient"]
|
||||
|
||||
db_client.disconnect()
|
||||
|
||||
db_client.configure(server_ip_address=IPv4Address("192.168.1.14"), server_password="12345")
|
||||
db_service.password = "12345"
|
||||
|
||||
assert db_client.connect()
|
||||
db_service: DatabaseService = node_b.software_manager.software["DatabaseService"]
|
||||
|
||||
db_client.configure(server_ip_address=IPv4Address("192.168.0.11"), server_password="12345")
|
||||
db_client.connect()
|
||||
assert len(db_client.connections) == 1
|
||||
assert len(db_service.connections) == 1
|
||||
|
||||
|
||||
def test_database_client_server_incorrect_password(uc2_network):
|
||||
web_server: Server = uc2_network.get_node_by_hostname("web_server")
|
||||
db_client: DatabaseClient = web_server.software_manager.software.get("DatabaseClient")
|
||||
def test_database_client_server_incorrect_password(peer_to_peer_secure_db):
|
||||
node_a, node_b = peer_to_peer_secure_db
|
||||
|
||||
db_server: Server = uc2_network.get_node_by_hostname("database_server")
|
||||
db_service: DatabaseService = db_server.software_manager.software.get("DatabaseService")
|
||||
db_client: DatabaseClient = node_a.software_manager.software["DatabaseClient"]
|
||||
|
||||
db_client.disconnect()
|
||||
db_client.configure(server_ip_address=IPv4Address("192.168.1.14"), server_password="54321")
|
||||
db_service.password = "12345"
|
||||
db_service: DatabaseService = node_b.software_manager.software["DatabaseService"]
|
||||
|
||||
assert not db_client.connect()
|
||||
# should fail
|
||||
db_client.connect()
|
||||
assert len(db_client.connections) == 0
|
||||
assert len(db_service.connections) == 0
|
||||
|
||||
db_client.configure(server_ip_address=IPv4Address("192.168.0.11"), server_password="wrongpass")
|
||||
db_client.connect()
|
||||
assert len(db_client.connections) == 0
|
||||
assert len(db_service.connections) == 0
|
||||
|
||||
|
||||
def test_database_client_query(uc2_network):
|
||||
"""Tests DB query across the network returns HTTP status 200 and date."""
|
||||
web_server: Server = uc2_network.get_node_by_hostname("web_server")
|
||||
db_client: DatabaseClient = web_server.software_manager.software.get("DatabaseClient")
|
||||
|
||||
assert db_client.connected
|
||||
db_client: DatabaseClient = web_server.software_manager.software["DatabaseClient"]
|
||||
db_client.connect()
|
||||
|
||||
assert db_client.query("SELECT")
|
||||
|
||||
@@ -66,13 +121,13 @@ def test_database_client_query(uc2_network):
|
||||
def test_create_database_backup(uc2_network):
|
||||
"""Run the backup_database method and check if the FTP server has the relevant file."""
|
||||
db_server: Server = uc2_network.get_node_by_hostname("database_server")
|
||||
db_service: DatabaseService = db_server.software_manager.software.get("DatabaseService")
|
||||
db_service: DatabaseService = db_server.software_manager.software["DatabaseService"]
|
||||
|
||||
# back up should be created
|
||||
assert db_service.backup_database() is True
|
||||
|
||||
backup_server: Server = uc2_network.get_node_by_hostname("backup_server")
|
||||
ftp_server: FTPServer = backup_server.software_manager.software.get("FTPServer")
|
||||
ftp_server: FTPServer = backup_server.software_manager.software["FTPServer"]
|
||||
|
||||
# backup file should exist in the backup server
|
||||
assert ftp_server.file_system.get_file(folder_name=db_service.uuid, file_name="database.db") is not None
|
||||
@@ -81,7 +136,7 @@ def test_create_database_backup(uc2_network):
|
||||
def test_restore_backup(uc2_network):
|
||||
"""Run the restore_backup method and check if the backup is properly restored."""
|
||||
db_server: Server = uc2_network.get_node_by_hostname("database_server")
|
||||
db_service: DatabaseService = db_server.software_manager.software.get("DatabaseService")
|
||||
db_service: DatabaseService = db_server.software_manager.software["DatabaseService"]
|
||||
|
||||
# create a back up
|
||||
assert db_service.backup_database() is True
|
||||
@@ -107,7 +162,7 @@ def test_database_client_cannot_query_offline_database_server(uc2_network):
|
||||
|
||||
web_server: Server = uc2_network.get_node_by_hostname("web_server")
|
||||
db_client: DatabaseClient = web_server.software_manager.software.get("DatabaseClient")
|
||||
assert db_client.connected
|
||||
assert len(db_client.connections)
|
||||
|
||||
assert db_client.query("SELECT") is True
|
||||
|
||||
|
||||
86
tests/integration_tests/system/test_ntp_client_server.py
Normal file
86
tests/integration_tests/system/test_ntp_client_server.py
Normal file
@@ -0,0 +1,86 @@
|
||||
from ipaddress import IPv4Address
|
||||
from time import sleep
|
||||
from typing import Tuple
|
||||
|
||||
import pytest
|
||||
|
||||
from primaite.simulator.network.container import Network
|
||||
from primaite.simulator.network.hardware.nodes.computer import Computer
|
||||
from primaite.simulator.network.hardware.nodes.server import Server
|
||||
from primaite.simulator.network.protocols.ntp import NTPPacket
|
||||
from primaite.simulator.system.services.ntp.ntp_client import NTPClient
|
||||
from primaite.simulator.system.services.ntp.ntp_server import NTPServer
|
||||
from primaite.simulator.system.services.service import ServiceOperatingState
|
||||
|
||||
# Create simple network for testing
|
||||
# Define one node to be an NTP server and another node to be a NTP Client.
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def create_ntp_network(client_server) -> Tuple[NTPClient, Computer, NTPServer, Server]:
|
||||
"""
|
||||
+------------+ +------------+
|
||||
| ntp | | ntp |
|
||||
| client_1 +------------+ server_1 |
|
||||
| | | |
|
||||
+------------+ +------------+
|
||||
|
||||
"""
|
||||
client, server = client_server
|
||||
|
||||
server.power_on()
|
||||
server.software_manager.install(NTPServer)
|
||||
ntp_server: NTPServer = server.software_manager.software.get("NTPServer")
|
||||
ntp_server.start()
|
||||
|
||||
client.power_on()
|
||||
client.software_manager.install(NTPClient)
|
||||
ntp_client: NTPClient = client.software_manager.software.get("NTPClient")
|
||||
ntp_client.start()
|
||||
|
||||
return ntp_client, client, ntp_server, server
|
||||
|
||||
|
||||
def test_ntp_client_server(create_ntp_network):
|
||||
ntp_client, client, ntp_server, server = create_ntp_network
|
||||
|
||||
ntp_server: NTPServer = server.software_manager.software["NTPServer"]
|
||||
ntp_client: NTPClient = client.software_manager.software["NTPClient"]
|
||||
|
||||
assert ntp_server.operating_state == ServiceOperatingState.RUNNING
|
||||
assert ntp_client.operating_state == ServiceOperatingState.RUNNING
|
||||
ntp_client.configure(ntp_server_ip_address=IPv4Address("192.168.0.2"))
|
||||
|
||||
assert ntp_client.time is None
|
||||
ntp_client.request_time()
|
||||
assert ntp_client.time is not None
|
||||
first_time = ntp_client.time
|
||||
sleep(0.1)
|
||||
ntp_client.apply_timestep(1) # Check time advances
|
||||
second_time = ntp_client.time
|
||||
assert first_time < second_time
|
||||
|
||||
|
||||
# Test ntp client behaviour when ntp server is unavailable.
|
||||
def test_ntp_server_failure(create_ntp_network):
|
||||
ntp_client, client, ntp_server, server = create_ntp_network
|
||||
|
||||
ntp_server: NTPServer = server.software_manager.software["NTPServer"]
|
||||
ntp_client: NTPClient = client.software_manager.software["NTPClient"]
|
||||
|
||||
assert ntp_client.operating_state == ServiceOperatingState.RUNNING
|
||||
assert ntp_client.operating_state == ServiceOperatingState.RUNNING
|
||||
ntp_client.configure(ntp_server_ip_address=IPv4Address("192.168.0.2"))
|
||||
|
||||
# Turn off ntp server.
|
||||
ntp_server.stop()
|
||||
assert ntp_server.operating_state == ServiceOperatingState.STOPPED
|
||||
# And request a time update.
|
||||
ntp_client.request_time()
|
||||
assert ntp_client.time is None
|
||||
|
||||
# Restart ntp server.
|
||||
ntp_server.start()
|
||||
assert ntp_server.operating_state == ServiceOperatingState.RUNNING
|
||||
ntp_client.request_time()
|
||||
assert ntp_client.time is not None
|
||||
@@ -53,12 +53,12 @@ def test_node_os_scan(node, service, application):
|
||||
# TODO implement processes
|
||||
|
||||
# add services to node
|
||||
service.health_state_actual = SoftwareHealthState.COMPROMISED
|
||||
service.set_health_state(SoftwareHealthState.COMPROMISED)
|
||||
node.install_service(service=service)
|
||||
assert service.health_state_visible == SoftwareHealthState.UNUSED
|
||||
|
||||
# add application to node
|
||||
application.health_state_actual = SoftwareHealthState.COMPROMISED
|
||||
application.set_health_state(SoftwareHealthState.COMPROMISED)
|
||||
node.install_application(application=application)
|
||||
assert application.health_state_visible == SoftwareHealthState.UNUSED
|
||||
|
||||
@@ -101,7 +101,7 @@ def test_node_red_scan(node, service, application):
|
||||
assert service.revealed_to_red is False
|
||||
|
||||
# add application to node
|
||||
application.health_state_actual = SoftwareHealthState.COMPROMISED
|
||||
application.set_health_state(SoftwareHealthState.COMPROMISED)
|
||||
node.install_application(application=application)
|
||||
assert application.revealed_to_red is False
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ from primaite.simulator.network.networks import arcd_uc2_network
|
||||
from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.system.applications.application import ApplicationOperatingState
|
||||
from primaite.simulator.system.services.red_services.data_manipulation_bot import (
|
||||
from primaite.simulator.system.applications.red_applications.data_manipulation_bot import (
|
||||
DataManipulationAttackStage,
|
||||
DataManipulationBot,
|
||||
)
|
||||
@@ -70,4 +70,4 @@ def test_dm_bot_perform_data_manipulation_success(dm_bot):
|
||||
dm_bot._perform_data_manipulation(p_of_success=1.0)
|
||||
|
||||
assert dm_bot.attack_stage in (DataManipulationAttackStage.SUCCEEDED, DataManipulationAttackStage.FAILED)
|
||||
assert dm_bot.connected
|
||||
assert len(dm_bot.connections)
|
||||
@@ -0,0 +1,90 @@
|
||||
from ipaddress import IPv4Address
|
||||
|
||||
import pytest
|
||||
|
||||
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
|
||||
from primaite.simulator.network.hardware.nodes.computer import Computer
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.system.applications.application import ApplicationOperatingState
|
||||
from primaite.simulator.system.applications.red_applications.dos_bot import DoSAttackStage, DoSBot
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def dos_bot() -> DoSBot:
|
||||
computer = Computer(
|
||||
hostname="compromised_pc",
|
||||
ip_address="192.168.0.1",
|
||||
subnet_mask="255.255.255.0",
|
||||
operating_state=NodeOperatingState.ON,
|
||||
)
|
||||
|
||||
computer.software_manager.install(DoSBot)
|
||||
|
||||
dos_bot: DoSBot = computer.software_manager.software.get("DoSBot")
|
||||
dos_bot.configure(target_ip_address=IPv4Address("192.168.0.1"))
|
||||
dos_bot.set_original_state()
|
||||
return dos_bot
|
||||
|
||||
|
||||
def test_dos_bot_creation(dos_bot):
|
||||
"""Test that the DoS bot is installed on a node."""
|
||||
assert dos_bot is not None
|
||||
|
||||
|
||||
def test_dos_bot_reset(dos_bot):
|
||||
assert dos_bot.target_ip_address == IPv4Address("192.168.0.1")
|
||||
assert dos_bot.target_port is Port.POSTGRES_SERVER
|
||||
assert dos_bot.payload is None
|
||||
assert dos_bot.repeat is False
|
||||
|
||||
dos_bot.configure(
|
||||
target_ip_address=IPv4Address("192.168.1.1"), target_port=Port.HTTP, payload="payload", repeat=True
|
||||
)
|
||||
|
||||
# should reset the relevant items
|
||||
dos_bot.reset_component_for_episode(episode=0)
|
||||
assert dos_bot.target_ip_address == IPv4Address("192.168.0.1")
|
||||
assert dos_bot.target_port is Port.POSTGRES_SERVER
|
||||
assert dos_bot.payload is None
|
||||
assert dos_bot.repeat is False
|
||||
|
||||
dos_bot.configure(
|
||||
target_ip_address=IPv4Address("192.168.1.1"), target_port=Port.HTTP, payload="payload", repeat=True
|
||||
)
|
||||
dos_bot.set_original_state()
|
||||
dos_bot.reset_component_for_episode(episode=1)
|
||||
# should reset to the configured value
|
||||
assert dos_bot.target_ip_address == IPv4Address("192.168.1.1")
|
||||
assert dos_bot.target_port is Port.HTTP
|
||||
assert dos_bot.payload == "payload"
|
||||
assert dos_bot.repeat is True
|
||||
|
||||
|
||||
def test_dos_bot_cannot_run_when_node_offline(dos_bot):
|
||||
dos_bot_node: Computer = dos_bot.parent
|
||||
assert dos_bot_node.operating_state is NodeOperatingState.ON
|
||||
|
||||
dos_bot_node.power_off()
|
||||
|
||||
for i in range(dos_bot_node.shut_down_duration + 1):
|
||||
dos_bot_node.apply_timestep(timestep=i)
|
||||
|
||||
assert dos_bot_node.operating_state is NodeOperatingState.OFF
|
||||
|
||||
dos_bot._application_loop()
|
||||
|
||||
# assert not run
|
||||
assert dos_bot.attack_stage is DoSAttackStage.NOT_STARTED
|
||||
|
||||
|
||||
def test_dos_bot_not_configured(dos_bot):
|
||||
dos_bot.target_ip_address = None
|
||||
|
||||
dos_bot.operating_state = ApplicationOperatingState.RUNNING
|
||||
dos_bot._application_loop()
|
||||
|
||||
|
||||
def test_dos_bot_perform_port_scan(dos_bot):
|
||||
dos_bot._perform_port_scan(p_of_success=1)
|
||||
|
||||
assert dos_bot.attack_stage is DoSAttackStage.PORT_SCAN
|
||||
@@ -0,0 +1,50 @@
|
||||
from primaite.simulator.system.applications.application import ApplicationOperatingState
|
||||
from primaite.simulator.system.software import SoftwareHealthState
|
||||
|
||||
|
||||
def test_scan(application):
|
||||
assert application.operating_state == ApplicationOperatingState.CLOSED
|
||||
assert application.health_state_visible == SoftwareHealthState.UNUSED
|
||||
|
||||
application.run()
|
||||
assert application.operating_state == ApplicationOperatingState.RUNNING
|
||||
assert application.health_state_visible == SoftwareHealthState.UNUSED
|
||||
|
||||
application.scan()
|
||||
assert application.operating_state == ApplicationOperatingState.RUNNING
|
||||
assert application.health_state_visible == SoftwareHealthState.GOOD
|
||||
|
||||
|
||||
def test_run_application(application):
|
||||
assert application.operating_state == ApplicationOperatingState.CLOSED
|
||||
assert application.health_state_actual == SoftwareHealthState.UNUSED
|
||||
|
||||
application.run()
|
||||
assert application.operating_state == ApplicationOperatingState.RUNNING
|
||||
assert application.health_state_actual == SoftwareHealthState.GOOD
|
||||
|
||||
|
||||
def test_close_application(application):
|
||||
application.run()
|
||||
assert application.operating_state == ApplicationOperatingState.RUNNING
|
||||
assert application.health_state_actual == SoftwareHealthState.GOOD
|
||||
|
||||
application.close()
|
||||
assert application.operating_state == ApplicationOperatingState.CLOSED
|
||||
assert application.health_state_actual == SoftwareHealthState.GOOD
|
||||
|
||||
|
||||
def test_application_describe_states(application):
|
||||
assert application.operating_state == ApplicationOperatingState.CLOSED
|
||||
assert application.health_state_actual == SoftwareHealthState.UNUSED
|
||||
|
||||
assert SoftwareHealthState.UNUSED.value == application.describe_state().get("health_state_actual")
|
||||
|
||||
application.run()
|
||||
assert SoftwareHealthState.GOOD.value == application.describe_state().get("health_state_actual")
|
||||
|
||||
application.set_health_state(SoftwareHealthState.COMPROMISED)
|
||||
assert SoftwareHealthState.COMPROMISED.value == application.describe_state().get("health_state_actual")
|
||||
|
||||
application.patch()
|
||||
assert SoftwareHealthState.PATCHING.value == application.describe_state().get("health_state_actual")
|
||||
@@ -1,5 +1,6 @@
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Tuple, Union
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -62,18 +63,26 @@ def test_disconnect_when_client_is_closed(database_client_on_computer):
|
||||
|
||||
|
||||
def test_disconnect(database_client_on_computer):
|
||||
"""Database client should set connected to False and remove the database server ip address."""
|
||||
"""Database client should remove the connection."""
|
||||
database_client, computer = database_client_on_computer
|
||||
|
||||
database_client.connected = True
|
||||
database_client._connections[str(uuid4())] = {"item": True}
|
||||
assert len(database_client.connections) == 1
|
||||
|
||||
assert database_client.operating_state is ApplicationOperatingState.RUNNING
|
||||
assert database_client.server_ip_address is not None
|
||||
|
||||
database_client.disconnect()
|
||||
|
||||
assert database_client.connected is False
|
||||
assert database_client.server_ip_address is None
|
||||
assert len(database_client.connections) == 0
|
||||
|
||||
uuid = str(uuid4())
|
||||
database_client._connections[uuid] = {"item": True}
|
||||
assert len(database_client.connections) == 1
|
||||
|
||||
database_client.disconnect(connection_id=uuid)
|
||||
|
||||
assert len(database_client.connections) == 0
|
||||
|
||||
|
||||
def test_query_when_client_is_closed(database_client_on_computer):
|
||||
@@ -86,19 +95,6 @@ def test_query_when_client_is_closed(database_client_on_computer):
|
||||
assert database_client.query(sql="test") is False
|
||||
|
||||
|
||||
def test_query_failed_reattempt(database_client_on_computer):
|
||||
"""Database client query should return False if the reattempt fails."""
|
||||
database_client, computer = database_client_on_computer
|
||||
|
||||
def return_false():
|
||||
return False
|
||||
|
||||
database_client.connect = return_false
|
||||
|
||||
database_client.connected = False
|
||||
assert database_client.query(sql="test", is_reattempt=True) is False
|
||||
|
||||
|
||||
def test_query_fail_to_connect(database_client_on_computer):
|
||||
"""Database client query should return False if the connect attempt fails."""
|
||||
database_client, computer = database_client_on_computer
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from uuid import uuid4
|
||||
|
||||
from primaite.simulator.system.services.service import ServiceOperatingState
|
||||
from primaite.simulator.system.software import SoftwareHealthState
|
||||
|
||||
@@ -17,52 +19,174 @@ def test_scan(service):
|
||||
|
||||
def test_start_service(service):
|
||||
assert service.operating_state == ServiceOperatingState.STOPPED
|
||||
assert service.health_state_actual == SoftwareHealthState.UNUSED
|
||||
service.start()
|
||||
|
||||
assert service.operating_state == ServiceOperatingState.RUNNING
|
||||
assert service.health_state_actual == SoftwareHealthState.GOOD
|
||||
|
||||
|
||||
def test_stop_service(service):
|
||||
service.start()
|
||||
assert service.operating_state == ServiceOperatingState.RUNNING
|
||||
assert service.health_state_actual == SoftwareHealthState.GOOD
|
||||
|
||||
service.stop()
|
||||
assert service.operating_state == ServiceOperatingState.STOPPED
|
||||
assert service.health_state_actual == SoftwareHealthState.GOOD
|
||||
|
||||
|
||||
def test_pause_and_resume_service(service):
|
||||
assert service.operating_state == ServiceOperatingState.STOPPED
|
||||
service.resume()
|
||||
assert service.operating_state == ServiceOperatingState.STOPPED
|
||||
assert service.health_state_actual == SoftwareHealthState.UNUSED
|
||||
|
||||
service.start()
|
||||
assert service.health_state_actual == SoftwareHealthState.GOOD
|
||||
service.pause()
|
||||
assert service.operating_state == ServiceOperatingState.PAUSED
|
||||
assert service.health_state_actual == SoftwareHealthState.GOOD
|
||||
|
||||
service.resume()
|
||||
assert service.operating_state == ServiceOperatingState.RUNNING
|
||||
assert service.health_state_actual == SoftwareHealthState.GOOD
|
||||
|
||||
|
||||
def test_restart(service):
|
||||
assert service.operating_state == ServiceOperatingState.STOPPED
|
||||
assert service.health_state_actual == SoftwareHealthState.UNUSED
|
||||
service.restart()
|
||||
# Service is STOPPED. Restart will only work if the service was PAUSED or RUNNING
|
||||
assert service.operating_state == ServiceOperatingState.STOPPED
|
||||
assert service.health_state_actual == SoftwareHealthState.UNUSED
|
||||
|
||||
service.start()
|
||||
assert service.operating_state == ServiceOperatingState.RUNNING
|
||||
assert service.health_state_actual == SoftwareHealthState.GOOD
|
||||
service.restart()
|
||||
# Service is RUNNING. Restart should work
|
||||
assert service.operating_state == ServiceOperatingState.RESTARTING
|
||||
assert service.health_state_actual == SoftwareHealthState.GOOD
|
||||
|
||||
timestep = 0
|
||||
while service.operating_state == ServiceOperatingState.RESTARTING:
|
||||
service.apply_timestep(timestep)
|
||||
assert service.health_state_actual == SoftwareHealthState.GOOD
|
||||
timestep += 1
|
||||
|
||||
assert service.operating_state == ServiceOperatingState.RUNNING
|
||||
assert service.health_state_actual == SoftwareHealthState.GOOD
|
||||
|
||||
|
||||
def test_restart_compromised(service):
|
||||
service.start()
|
||||
assert service.health_state_actual == SoftwareHealthState.GOOD
|
||||
|
||||
# compromise the service
|
||||
service.set_health_state(SoftwareHealthState.COMPROMISED)
|
||||
|
||||
service.restart()
|
||||
assert service.operating_state == ServiceOperatingState.RESTARTING
|
||||
assert service.health_state_actual == SoftwareHealthState.COMPROMISED
|
||||
|
||||
"""
|
||||
Service should be compromised even after reset.
|
||||
|
||||
Only way to remove compromised status is via patching.
|
||||
"""
|
||||
|
||||
timestep = 0
|
||||
while service.operating_state == ServiceOperatingState.RESTARTING:
|
||||
service.apply_timestep(timestep)
|
||||
assert service.health_state_actual == SoftwareHealthState.COMPROMISED
|
||||
timestep += 1
|
||||
|
||||
assert service.operating_state == ServiceOperatingState.RUNNING
|
||||
assert service.health_state_actual == SoftwareHealthState.COMPROMISED
|
||||
|
||||
|
||||
def test_compromised_service_remains_compromised(service):
|
||||
"""
|
||||
Tests that a compromised service stays compromised.
|
||||
|
||||
The only way that the service can be uncompromised is by running patch.
|
||||
"""
|
||||
service.start()
|
||||
assert service.health_state_actual == SoftwareHealthState.GOOD
|
||||
|
||||
service.set_health_state(SoftwareHealthState.COMPROMISED)
|
||||
|
||||
service.stop()
|
||||
assert service.health_state_actual == SoftwareHealthState.COMPROMISED
|
||||
|
||||
service.start()
|
||||
assert service.health_state_actual == SoftwareHealthState.COMPROMISED
|
||||
|
||||
service.disable()
|
||||
assert service.health_state_actual == SoftwareHealthState.COMPROMISED
|
||||
|
||||
service.enable()
|
||||
assert service.health_state_actual == SoftwareHealthState.COMPROMISED
|
||||
|
||||
service.pause()
|
||||
assert service.health_state_actual == SoftwareHealthState.COMPROMISED
|
||||
|
||||
service.resume()
|
||||
assert service.health_state_actual == SoftwareHealthState.COMPROMISED
|
||||
|
||||
|
||||
def test_service_patching(service):
|
||||
service.start()
|
||||
assert service.health_state_actual == SoftwareHealthState.GOOD
|
||||
|
||||
service.set_health_state(SoftwareHealthState.COMPROMISED)
|
||||
|
||||
service.patch()
|
||||
assert service.health_state_actual == SoftwareHealthState.PATCHING
|
||||
|
||||
for i in range(service.patching_duration + 1):
|
||||
service.apply_timestep(i)
|
||||
|
||||
assert service.health_state_actual == SoftwareHealthState.GOOD
|
||||
|
||||
|
||||
def test_enable_disable(service):
|
||||
service.disable()
|
||||
assert service.operating_state == ServiceOperatingState.DISABLED
|
||||
assert service.health_state_actual == SoftwareHealthState.UNUSED
|
||||
|
||||
service.enable()
|
||||
assert service.operating_state == ServiceOperatingState.STOPPED
|
||||
assert service.health_state_actual == SoftwareHealthState.UNUSED
|
||||
|
||||
|
||||
def test_overwhelm_service(service):
|
||||
service.max_sessions = 2
|
||||
service.start()
|
||||
|
||||
uuid = str(uuid4())
|
||||
assert service.add_connection(connection_id=uuid) # should be true
|
||||
assert service.health_state_actual == SoftwareHealthState.GOOD
|
||||
|
||||
assert not service.add_connection(connection_id=uuid) # fails because connection already exists
|
||||
assert service.health_state_actual == SoftwareHealthState.GOOD
|
||||
|
||||
assert service.add_connection(connection_id=str(uuid4())) # succeed
|
||||
assert service.health_state_actual == SoftwareHealthState.GOOD
|
||||
|
||||
assert not service.add_connection(connection_id=str(uuid4())) # fail because at capacity
|
||||
assert service.health_state_actual is SoftwareHealthState.OVERWHELMED
|
||||
|
||||
|
||||
def test_create_and_remove_connections(service):
|
||||
service.start()
|
||||
uuid = str(uuid4())
|
||||
|
||||
assert service.add_connection(connection_id=uuid) # should be true
|
||||
assert len(service.connections) == 1
|
||||
assert service.health_state_actual is SoftwareHealthState.GOOD
|
||||
|
||||
assert service.remove_connection(connection_id=uuid) # should be true
|
||||
assert len(service.connections) == 0
|
||||
assert service.health_state_actual is SoftwareHealthState.GOOD
|
||||
|
||||
@@ -0,0 +1,29 @@
|
||||
from typing import Dict
|
||||
|
||||
import pytest
|
||||
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.system.core.sys_log import SysLog
|
||||
from primaite.simulator.system.software import Software, SoftwareHealthState
|
||||
|
||||
|
||||
class TestSoftware(Software):
|
||||
def describe_state(self) -> Dict:
|
||||
pass
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def software(file_system):
|
||||
return TestSoftware(
|
||||
name="TestSoftware", port=Port.ARP, file_system=file_system, sys_log=SysLog(hostname="test_service")
|
||||
)
|
||||
|
||||
|
||||
def test_software_creation(software):
|
||||
assert software is not None
|
||||
|
||||
|
||||
def test_software_set_health_state(software):
|
||||
assert software.health_state_actual == SoftwareHealthState.UNUSED
|
||||
software.set_health_state(SoftwareHealthState.GOOD)
|
||||
assert software.health_state_actual == SoftwareHealthState.GOOD
|
||||
Reference in New Issue
Block a user