Merge remote-tracking branch 'origin/dev' into dev-v3.0.0b6

This commit is contained in:
Marek Wolan
2024-01-29 10:26:28 +00:00
57 changed files with 2329 additions and 461 deletions

View File

@@ -18,36 +18,43 @@ parameters:
py: '3.8'
img: 'ubuntu-latest'
every_time: false
publish_coverage: false
- job_name: 'UbuntuPython310'
py: '3.10'
img: 'ubuntu-latest'
every_time: true
publish_coverage: true
- job_name: 'WindowsPython38'
py: '3.8'
img: 'windows-latest'
every_time: false
publish_coverage: false
- job_name: 'WindowsPython310'
py: '3.10'
img: 'windows-latest'
every_time: false
publish_coverage: false
- job_name: 'MacOSPython38'
py: '3.8'
img: 'macOS-latest'
every_time: false
publish_coverage: false
- job_name: 'MacOSPython310'
py: '3.10'
img: 'macOS-latest'
every_time: false
publish_coverage: false
stages:
- stage: Test
jobs:
- ${{ each item in parameters.matrix }}:
- job: ${{ item.job_name }}
timeoutInMinutes: 90
cancelTimeoutInMinutes: 1
pool:
vmImage: ${{ item.img }}
condition: or( eq(variables['Build.Reason'], 'PullRequest'), ${{ item.every_time }} )
condition: and(succeeded(), or( eq(variables['Build.Reason'], 'PullRequest'), ${{ item.every_time }} ))
steps:
- task: UsePythonVersion@0
@@ -109,12 +116,12 @@ stages:
- publish: $(System.DefaultWorkingDirectory)/htmlcov/
# publish the html report - so we can debug the coverage if needed
condition: ${{ item.every_time }} # should only be run once
condition: ${{ item.publish_coverage }} # should only be run once
artifact: coverage_report
- task: PublishCodeCoverageResults@2
# publish the code coverage so it can be viewed in the run coverage page
condition: ${{ item.every_time }} # should only be run once
condition: ${{ item.publish_coverage }} # should only be run once
inputs:
codeCoverageTool: Cobertura
summaryFileLocation: '$(System.DefaultWorkingDirectory)/**/coverage.xml'

1
.gitignore vendored
View File

@@ -157,3 +157,4 @@ benchmark/output
src/primaite/notebooks/scratch.py
sandbox.py
sandbox/
sandbox.ipynb

View File

@@ -55,6 +55,19 @@ SessionManager.
- FTP Services: `FTPClient` and `FTPServer`
- HTTP Services: `WebBrowser` to simulate a web client and `WebServer`
- Fixed an issue where the services were still able to run even though the node the service is installed on is turned off
- NTP Services: `NTPClient` and `NTPServer`
- **RouterNIC Class**: Introduced a new class `RouterNIC`, extending the standard `NIC` functionality. This class is specifically designed for router operations, optimizing the processing and routing of network traffic.
- **Custom Layer-3 Processing**: The `RouterNIC` class includes custom handling for network frames, bypassing standard Node NIC's Layer 3 broadcast/unicast checks. This allows for more efficient routing behavior in network scenarios where router-specific frame processing is required.
- **Enhanced Frame Reception**: The `receive_frame` method in `RouterNIC` is tailored to handle frames based on Layer 2 (Ethernet) checks, focusing on MAC address-based routing and broadcast frame acceptance.
- **Subnet-Wide Broadcasting for Services and Applications**: Implemented the ability for services and applications to conduct broadcasts across an entire IPv4 subnet within the network simulation framework.
### Changed
- Integrated the RouteTable into the Routers frame processing.
- Frames are now dropped when their TTL reaches 0
- **NIC Functionality Update**: Updated the Network Interface Card (`NIC`) functionality to support Layer 3 (L3) broadcasts.
- **Layer 3 Broadcast Handling**: Enhanced the existing `NIC` classes to correctly process and handle Layer 3 broadcasts. This update allows devices using standard NICs to effectively participate in network activities that involve L3 broadcasting.
- **Improved Frame Reception Logic**: The `receive_frame` method of the `NIC` class has been updated to include additional checks and handling for L3 broadcasts, ensuring proper frame processing in a wider range of network scenarios.
### Removed
- Removed legacy simulation modules: `acl`, `common`, `environment`, `links`, `nodes`, `pol`

View File

@@ -0,0 +1,54 @@
.. only:: comment
© Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
NTP Client Server
=================
NTP Server
----------
The ``NTPServer`` provides a NTP Server simulation by extending the base Service class.
NTP Client
----------
The ``NTPClient`` provides a NTP Client simulation by extending the base Service class.
Key capabilities
^^^^^^^^^^^^^^^^
- Simulates NTP requests and NTPPacket transfer across a network
- Leverages the Service base class for install/uninstall, status tracking, etc.
Usage
^^^^^
- Install on a Node via the ``SoftwareManager`` to start the database service.
- Service runs on TCP port 123 by default.
Implementation
^^^^^^^^^^^^^^
- NTP request and responses use a ``NTPPacket`` object
- Extends Service class for integration with ``SoftwareManager``.
NTP Client
----------
The NTPClient provides a client interface for connecting to the ``NTPServer``.
Key features
^^^^^^^^^^^^
- Connects to the ``NTPServer`` via the ``SoftwareManager``.
Usage
^^^^^
- Install on a Node via the ``SoftwareManager`` to start the database service.
- Service runs on TCP port 123 by default.
Implementation
^^^^^^^^^^^^^^
- Leverages ``SoftwareManager`` for sending payloads over the network.
- Provides easy interface for Nodes to find IP addresses via domain names.
- Extends base Service class.

View File

@@ -107,23 +107,23 @@ agents:
num_files_per_folder: 1
num_nics_per_node: 2
nodes:
- node_ref: domain_controller
- node_hostname: domain_controller
services:
- service_ref: domain_controller_dns_server
- node_ref: web_server
- service_name: DNSServer
- node_hostname: web_server
services:
- service_ref: web_server_web_service
- node_ref: database_server
- service_name: web_server_web_service
- node_hostname: database_server
folders:
- folder_name: database
files:
- file_name: database.db
- node_ref: backup_server
- node_hostname: backup_server
# services:
# - service_ref: backup_service
- node_ref: security_suite
- node_ref: client_1
- node_ref: client_2
- node_hostname: security_suite
- node_hostname: client_1
- node_hostname: client_2
links:
- link_ref: router_1___switch_1
- link_ref: router_1___switch_2
@@ -138,7 +138,7 @@ agents:
acl:
options:
max_acl_rules: 10
router_node_ref: router_1
router_hostname: router_1
ip_address_order:
- node_ref: domain_controller
nic_num: 1
@@ -528,7 +528,7 @@ agents:
- type: DATABASE_FILE_INTEGRITY
weight: 0.5
options:
node_ref: database_server
node_hostname: database_server
folder_name: database
file_name: database.db
@@ -536,8 +536,8 @@ agents:
- type: WEB_SERVER_404_PENALTY
weight: 0.5
options:
node_ref: web_server
service_ref: web_server_web_service
node_hostname: web_server
service_name: web_server_web_service
agent_settings:

View File

@@ -627,7 +627,7 @@ class ActionManager:
max_nics_per_node: int = 8, # allows calculating shape
max_acl_rules: int = 10, # allows calculating shape
protocols: List[str] = ["TCP", "UDP", "ICMP"], # allow mapping index to protocol
ports: List[str] = ["HTTP", "DNS", "ARP", "FTP"], # allow mapping index to port
ports: List[str] = ["HTTP", "DNS", "ARP", "FTP", "NTP"], # allow mapping index to port
ip_address_list: Optional[List[str]] = None, # to allow us to map an index to an ip address.
act_map: Optional[Dict[int, Dict]] = None, # allows restricting set of possible actions
) -> None:

View File

@@ -4,7 +4,7 @@ from typing import Dict, List, Tuple
from gymnasium.core import ObsType
from primaite.game.agent.interface import AbstractScriptedAgent
from primaite.simulator.system.services.red_services.data_manipulation_bot import DataManipulationBot
from primaite.simulator.system.applications.red_applications.data_manipulation_bot import DataManipulationBot
class DataManipulationAgent(AbstractScriptedAgent):

View File

@@ -45,6 +45,7 @@ class AgentSettings(BaseModel):
start_settings: Optional[AgentStartSettings] = None
"Configuration for when an agent begins performing it's actions"
flatten_obs: bool = True
"Whether to flatten the observation space before passing it to the agent. True by default."
@classmethod
def from_config(cls, config: Optional[Dict]) -> "AgentSettings":
@@ -180,7 +181,7 @@ class ProxyAgent(AbstractAgent):
reward_function=reward_function,
)
self.most_recent_action: ActType
self.flatten_obs: bool = agent_settings.flatten_obs
self.flatten_obs: bool = agent_settings.flatten_obs if agent_settings else False
def get_action(self, obs: ObsType, reward: float = 0.0) -> Tuple[str, Dict]:
"""

View File

@@ -41,8 +41,7 @@ class AbstractObservation(ABC):
def from_config(cls, config: Dict, game: "PrimaiteGame"):
"""Create this observation space component form a serialised format.
The `game` parameter is for a the PrimaiteGame object that spawns this component. During deserialisation,
a subclass of this class may need to translate from a 'reference' to a UUID.
The `game` parameter is for a the PrimaiteGame object that spawns this component.
"""
pass
@@ -54,12 +53,12 @@ class FileObservation(AbstractObservation):
"""
Initialise file observation.
:param where: Store information about where in the simulation state dictionary to find the relevatn information.
:param where: Store information about where in the simulation state dictionary to find the relevant information.
Optional. If None, this corresponds that the file does not exist and the observation will be populated with
zeroes.
A typical location for a file looks like this:
['network','nodes',<node_uuid>,'file_system', 'folders',<folder_name>,'files',<file_name>]
['network','nodes',<node_hostname>,'file_system', 'folders',<folder_name>,'files',<file_name>]
:type where: Optional[List[str]]
"""
super().__init__()
@@ -121,7 +120,7 @@ class ServiceObservation(AbstractObservation):
zeroes.
A typical location for a service looks like this:
`['network','nodes',<node_uuid>,'services', <service_uuid>]`
`['network','nodes',<node_hostname>,'services', <service_name>]`
:type where: Optional[List[str]]
"""
super().__init__()
@@ -166,7 +165,7 @@ class ServiceObservation(AbstractObservation):
:return: Constructed service observation
:rtype: ServiceObservation
"""
return cls(where=parent_where + ["services", game.ref_map_services[config["service_ref"]]])
return cls(where=parent_where + ["services", config["service_name"]])
class LinkObservation(AbstractObservation):
@@ -183,7 +182,7 @@ class LinkObservation(AbstractObservation):
zeroes.
A typical location for a service looks like this:
`['network','nodes',<node_uuid>,'servics', <service_uuid>]`
`['network','nodes',<node_hostname>,'servics', <service_name>]`
:type where: Optional[List[str]]
"""
super().__init__()
@@ -249,7 +248,7 @@ class FolderObservation(AbstractObservation):
:param where: Where in the simulation state dictionary to find the relevant information for this folder.
A typical location for a file looks like this:
['network','nodes',<node_uuid>,'file_system', 'folders',<folder_name>]
['network','nodes',<node_hostname>,'file_system', 'folders',<folder_name>]
:type where: Optional[List[str]]
:param max_files: As size of the space must remain static, define max files that can be in this folder
, defaults to 5
@@ -328,7 +327,7 @@ class FolderObservation(AbstractObservation):
:type game: PrimaiteGame
:param parent_where: Where in the simulation state dictionary to find the information about this folder's
parent node. A typical location for a node ``where`` can be:
['network','nodes',<node_uuid>,'file_system']
['network','nodes',<node_hostname>,'file_system']
:type parent_where: Optional[List[str]]
:param num_files_per_folder: How many spaces for files are in this folder observation (to preserve static
observation size) , defaults to 2
@@ -354,7 +353,7 @@ class NicObservation(AbstractObservation):
:param where: Where in the simulation state dictionary to find the relevant information for this NIC. A typical
example may look like this:
['network','nodes',<node_uuid>,'NICs',<nic_uuid>]
['network','nodes',<node_hostname>,'NICs',<nic_number>]
If None, this denotes that the NIC does not exist and the observation will be populated with zeroes.
:type where: Optional[Tuple[str]], optional
"""
@@ -391,12 +390,12 @@ class NicObservation(AbstractObservation):
:param game: Reference to the PrimaiteGame object that spawned this observation.
:type game: PrimaiteGame
:param parent_where: Where in the simulation state dictionary to find the information about this NIC's parent
node. A typical location for a node ``where`` can be: ['network','nodes',<node_uuid>]
node. A typical location for a node ``where`` can be: ['network','nodes',<node_hostname>]
:type parent_where: Optional[List[str]]
:return: Constructed NIC observation
:rtype: NicObservation
"""
return cls(where=parent_where + ["NICs", config["nic_uuid"]])
return cls(where=parent_where + ["NICs", config["nic_num"]])
class NodeObservation(AbstractObservation):
@@ -419,9 +418,9 @@ class NodeObservation(AbstractObservation):
:param where: Where in the simulation state dictionary for find relevant information for this observation.
A typical location for a node looks like this:
['network','nodes',<node_uuid>]. If empty list, a default null observation will be output, defaults to []
['network','nodes',<hostname>]. If empty list, a default null observation will be output, defaults to []
:type where: List[str], optional
:param services: Mapping between position in observation space and service UUID, defaults to {}
:param services: Mapping between position in observation space and service name, defaults to {}
:type services: Dict[int,str], optional
:param max_services: Max number of services that can be presented in observation space for this node
, defaults to 2
@@ -430,7 +429,7 @@ class NodeObservation(AbstractObservation):
:type folders: Dict[int,str], optional
:param max_folders: Max number of folders in this node's obs space, defaults to 2
:type max_folders: int, optional
:param nics: Mapping between position in observation space and NIC UUID, defaults to {}
:param nics: Mapping between position in observation space and NIC idx, defaults to {}
:type nics: Dict[int,str], optional
:param max_nics: Max number of NICS in this node's obs space, defaults to 5
:type max_nics: int, optional
@@ -548,11 +547,11 @@ class NodeObservation(AbstractObservation):
:return: Constructed node observation
:rtype: NodeObservation
"""
node_uuid = game.ref_map_nodes[config["node_ref"]]
node_hostname = config["node_hostname"]
if parent_where is None:
where = ["network", "nodes", node_uuid]
where = ["network", "nodes", node_hostname]
else:
where = parent_where + ["nodes", node_uuid]
where = parent_where + ["nodes", node_hostname]
svc_configs = config.get("services", {})
services = [ServiceObservation.from_config(config=c, game=game, parent_where=where) for c in svc_configs]
@@ -563,8 +562,8 @@ class NodeObservation(AbstractObservation):
)
for c in folder_configs
]
nic_uuids = game.simulation.network.nodes[node_uuid].nics.keys()
nic_configs = [{"nic_uuid": n for n in nic_uuids}] if nic_uuids else []
# create some configs for the NIC observation in the format {"nic_num":1}, {"nic_num":2}, {"nic_num":3}, etc.
nic_configs = [{"nic_num": i for i in range(num_nics_per_node)}]
nics = [NicObservation.from_config(config=c, game=game, parent_where=where) for c in nic_configs]
logon_status = config.get("logon_status", False)
return cls(
@@ -605,7 +604,7 @@ class AclObservation(AbstractObservation):
:type protocols: list[str]
:param where: Where in the simulation state dictionary to find the relevant information for this ACL. A typical
example may look like this:
['network','nodes',<router_uuid>,'acl','acl']
['network','nodes',<router_hostname>,'acl','acl']
:type where: Optional[Tuple[str]], optional
:param num_rules: , defaults to 10
:type num_rules: int, optional
@@ -732,12 +731,12 @@ class AclObservation(AbstractObservation):
nic_obj = node_obj.ethernet_port[nic_num]
node_ip_to_idx[nic_obj.ip_address] = ip_idx + 2
router_uuid = game.ref_map_nodes[config["router_node_ref"]]
router_hostname = config["router_hostname"]
return cls(
node_ip_to_id=node_ip_to_idx,
ports=game.options.ports,
protocols=game.options.protocols,
where=["network", "nodes", router_uuid, "acl", "acl"],
where=["network", "nodes", router_hostname, "acl", "acl"],
num_rules=max_acl_rules,
)
@@ -867,6 +866,7 @@ class UC2BlueObservation(AbstractObservation):
:rtype: UC2BlueObservation
"""
node_configs = config["nodes"]
num_services_per_node = config["num_services_per_node"]
num_folders_per_node = config["num_folders_per_node"]
num_files_per_folder = config["num_files_per_folder"]

View File

@@ -82,11 +82,11 @@ class DummyReward(AbstractReward):
class DatabaseFileIntegrity(AbstractReward):
"""Reward function component which rewards the agent for maintaining the integrity of a database file."""
def __init__(self, node_uuid: str, folder_name: str, file_name: str) -> None:
def __init__(self, node_hostname: str, folder_name: str, file_name: str) -> None:
"""Initialise the reward component.
:param node_uuid: UUID of the node which contains the database file.
:type node_uuid: str
:param node_hostname: Hostname of the node which contains the database file.
:type node_hostname: str
:param folder_name: folder which contains the database file.
:type folder_name: str
:param file_name: name of the database file.
@@ -95,7 +95,7 @@ class DatabaseFileIntegrity(AbstractReward):
self.location_in_state = [
"network",
"nodes",
node_uuid,
node_hostname,
"file_system",
"folders",
folder_name,
@@ -136,49 +136,29 @@ class DatabaseFileIntegrity(AbstractReward):
:return: The reward component.
:rtype: DatabaseFileIntegrity
"""
node_ref = config.get("node_ref")
node_hostname = config.get("node_hostname")
folder_name = config.get("folder_name")
file_name = config.get("file_name")
if not node_ref:
_LOGGER.error(
f"{cls.__name__} could not be initialised from config because node_ref parameter was not specified"
)
return DummyReward() # TODO: better error handling
if not folder_name:
_LOGGER.error(
f"{cls.__name__} could not be initialised from config because folder_name parameter was not specified"
)
return DummyReward() # TODO: better error handling
if not file_name:
_LOGGER.error(
f"{cls.__name__} could not be initialised from config because file_name parameter was not specified"
)
return DummyReward() # TODO: better error handling
node_uuid = game.ref_map_nodes[node_ref]
if not node_uuid:
_LOGGER.error(
(
f"{cls.__name__} could not be initialised from config because the referenced node could not be "
f"found in the simulation"
)
)
return DummyReward() # TODO: better error handling
if not (node_hostname and folder_name and file_name):
msg = f"{cls.__name__} could not be initialised with parameters {config}"
_LOGGER.error(msg)
raise ValueError(msg)
return cls(node_uuid=node_uuid, folder_name=folder_name, file_name=file_name)
return cls(node_hostname=node_hostname, folder_name=folder_name, file_name=file_name)
class WebServer404Penalty(AbstractReward):
"""Reward function component which penalises the agent when the web server returns a 404 error."""
def __init__(self, node_uuid: str, service_uuid: str) -> None:
def __init__(self, node_hostname: str, service_name: str) -> None:
"""Initialise the reward component.
:param node_uuid: UUID of the node which contains the web server service.
:type node_uuid: str
:param service_uuid: UUID of the web server service.
:type service_uuid: str
:param node_hostname: Hostname of the node which contains the web server service.
:type node_hostname: str
:param service_name: Name of the web server service.
:type service_name: str
"""
self.location_in_state = ["network", "nodes", node_uuid, "services", service_uuid]
self.location_in_state = ["network", "nodes", node_hostname, "services", service_name]
def calculate(self, state: Dict) -> float:
"""Calculate the reward for the current state.
@@ -209,26 +189,17 @@ class WebServer404Penalty(AbstractReward):
:return: The reward component.
:rtype: WebServer404Penalty
"""
node_ref = config.get("node_ref")
service_ref = config.get("service_ref")
if not (node_ref and service_ref):
node_hostname = config.get("node_hostname")
service_name = config.get("service_name")
if not (node_hostname and service_name):
msg = (
f"{cls.__name__} could not be initialised from config because node_ref and service_ref were not "
"found in reward config."
)
_LOGGER.warning(msg)
return DummyReward() # TODO: should we error out with incorrect inputs? Probably!
node_uuid = game.ref_map_nodes[node_ref]
service_uuid = game.ref_map_services[service_ref]
if not (node_uuid and service_uuid):
msg = (
f"{cls.__name__} could not be initialised because node {node_ref} and service {service_ref} were not"
" found in the simulator."
)
_LOGGER.warning(msg)
return DummyReward() # TODO: consider erroring here as well
raise ValueError(msg)
return cls(node_uuid=node_uuid, service_uuid=service_uuid)
return cls(node_hostname=node_hostname, service_name=service_name)
class RewardFunction:

View File

@@ -18,13 +18,15 @@ from primaite.simulator.network.hardware.nodes.server import Server
from primaite.simulator.network.hardware.nodes.switch import Switch
from primaite.simulator.sim_container import Simulation
from primaite.simulator.system.applications.database_client import DatabaseClient
from primaite.simulator.system.applications.red_applications.data_manipulation_bot import DataManipulationBot
from primaite.simulator.system.applications.web_browser import WebBrowser
from primaite.simulator.system.services.database.database_service import DatabaseService
from primaite.simulator.system.services.dns.dns_client import DNSClient
from primaite.simulator.system.services.dns.dns_server import DNSServer
from primaite.simulator.system.services.ftp.ftp_client import FTPClient
from primaite.simulator.system.services.ftp.ftp_server import FTPServer
from primaite.simulator.system.services.red_services.data_manipulation_bot import DataManipulationBot
from primaite.simulator.system.services.ntp.ntp_client import NTPClient
from primaite.simulator.system.services.ntp.ntp_server import NTPServer
from primaite.simulator.system.services.web_server.web_server import WebServer
_LOGGER = getLogger(__name__)
@@ -244,6 +246,8 @@ class PrimaiteGame:
"WebServer": WebServer,
"FTPClient": FTPClient,
"FTPServer": FTPServer,
"NTPClient": NTPClient,
"NTPServer": NTPServer,
}
if service_type in service_types_mapping:
_LOGGER.debug(f"installing {service_type} on node {new_node.hostname}")
@@ -292,6 +296,7 @@ class PrimaiteGame:
opt = application_cfg["options"]
new_application.configure(
server_ip_address=IPv4Address(opt.get("server_ip")),
server_password=opt.get("server_password"),
payload=opt.get("payload"),
port_scan_p_of_success=float(opt.get("port_scan_p_of_success", "0.1")),
data_manipulation_p_of_success=float(opt.get("data_manipulation_p_of_success", "0.1")),

View File

@@ -102,7 +102,7 @@ class DomainController(SimComponent):
:rtype: Dict
"""
state = super().describe_state()
state.update({"accounts": {uuid: acct.describe_state() for uuid, acct in self.accounts.items()}})
state.update({"accounts": {acct.username: acct.describe_state() for acct in self.accounts.values()}})
return state
def _register_account(self, account: Account) -> None:

View File

@@ -199,10 +199,24 @@ class Network(SimComponent):
state = super().describe_state()
state.update(
{
"nodes": {uuid: node.describe_state() for uuid, node in self.nodes.items()},
"links": {uuid: link.describe_state() for uuid, link in self.links.items()},
"nodes": {node.hostname: node.describe_state() for node in self.nodes.values()},
"links": {},
}
)
# Update the links one-by-one. The key is a 4-tuple of `hostname_a, port_a, hostname_b, port_b`
for uuid, link in self.links.items():
node_a = link.endpoint_a._connected_node
node_b = link.endpoint_b._connected_node
hostname_a = node_a.hostname if node_a else None
hostname_b = node_b.hostname if node_b else None
port_a = link.endpoint_a._port_num_on_node
port_b = link.endpoint_b._port_num_on_node
state["links"][uuid] = link.describe_state()
state["links"][uuid]["hostname_a"] = hostname_a
state["links"][uuid]["hostname_b"] = hostname_b
state["links"][uuid]["port_a"] = port_a
state["links"][uuid]["port_b"] = port_b
return state
def add_node(self, node: Node) -> None:

View File

@@ -0,0 +1,148 @@
from ipaddress import IPv4Address
from typing import Optional
from primaite.simulator.network.container import Network
from primaite.simulator.network.hardware.nodes.computer import Computer
from primaite.simulator.network.hardware.nodes.router import ACLAction, Router
from primaite.simulator.network.hardware.nodes.switch import Switch
from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
def num_of_switches_required(num_nodes: int, max_switch_ports: int = 24) -> int:
"""
Calculate the minimum number of network switches required to connect a given number of nodes.
Each switch is assumed to have one port reserved for connecting to a router, reducing the effective
number of ports available for PCs. The function calculates the total number of switches needed
to accommodate all nodes under this constraint.
:param num_nodes: The total number of nodes that need to be connected in the network.
:param max_switch_ports: The maximum number of ports available on each switch. Defaults to 24.
:return: The minimum number of switches required to connect all PCs.
Example:
>>> num_of_switches_required(5)
1
>>> num_of_switches_required(24,24)
2
>>> num_of_switches_required(48,24)
3
>>> num_of_switches_required(25,10)
3
"""
# Reduce the effective number of switch ports by 1 to leave space for the router
effective_switch_ports = max_switch_ports - 1
# Calculate the number of fully utilised switches and any additional switch for remaining PCs
full_switches = num_nodes // effective_switch_ports
extra_pcs = num_nodes % effective_switch_ports
# Return the total number of switches required
return full_switches + (1 if extra_pcs > 0 else 0)
def create_office_lan(
lan_name: str,
subnet_base: int,
pcs_ip_block_start: int,
num_pcs: int,
network: Optional[Network] = None,
include_router: bool = True,
) -> Network:
"""
Creates a 2-Tier or 3-Tier office local area network (LAN).
The LAN is configured with a specified number of personal computers (PCs), optionally including a router,
and multiple edge switches to connect them. A core switch is added only if more than one edge switch is required.
The network topology involves edge switches connected either directly to the router in a 2-Tier setup or
to a core switch in a 3-Tier setup. If a router is included, it is connected to the core switch (if present)
and configured with basic access control list (ACL) rules. PCs are distributed across the edge switches.
:param str lan_name: The name to be assigned to the LAN.
:param int subnet_base: The subnet base number to be used in the IP addresses.
:param int pcs_ip_block_start: The starting block for assigning IP addresses to PCs.
:param int num_pcs: The number of PCs to be added to the LAN.
:param Optional[Network] network: The network to which the LAN components will be added. If None, a new network is
created.
:param bool include_router: Flag to determine if a router should be included in the LAN. Defaults to True.
:return: The network object with the LAN components added.
:raises ValueError: If pcs_ip_block_start is less than or equal to the number of required switches.
"""
# Initialise the network if not provided
if not network:
network = Network()
# Calculate the required number of switches
num_of_switches = num_of_switches_required(num_nodes=num_pcs)
effective_switch_ports = 23 # One port less for router connection
if pcs_ip_block_start <= num_of_switches:
raise ValueError(f"pcs_ip_block_start must be greater than the number of required switches {num_of_switches}")
# Create a core switch if more than one edge switch is needed
if num_of_switches > 1:
core_switch = Switch(hostname=f"switch_core_{lan_name}", start_up_duration=0)
core_switch.power_on()
network.add_node(core_switch)
core_switch_port = 1
# Initialise the default gateway to None
default_gateway = None
# Optionally include a router in the LAN
if include_router:
default_gateway = IPv4Address(f"192.168.{subnet_base}.1")
router = Router(hostname=f"router_{lan_name}", start_up_duration=0)
router.power_on()
router.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.ARP, dst_port=Port.ARP, position=22)
router.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol.ICMP, position=23)
network.add_node(router)
router.configure_port(port=1, ip_address=default_gateway, subnet_mask="255.255.255.0")
router.enable_port(1)
# Initialise the first edge switch and connect to the router or core switch
switch_port = 0
switch_n = 1
switch = Switch(hostname=f"switch_edge_{switch_n}_{lan_name}", start_up_duration=0)
switch.power_on()
network.add_node(switch)
if num_of_switches > 1:
network.connect(core_switch.switch_ports[core_switch_port], switch.switch_ports[24])
else:
network.connect(router.ethernet_ports[1], switch.switch_ports[24])
# Add PCs to the LAN and connect them to switches
for i in range(1, num_pcs + 1):
# Add a new edge switch if the current one is full
if switch_port == effective_switch_ports:
switch_n += 1
switch_port = 0
switch = Switch(hostname=f"switch_edge_{switch_n}_{lan_name}", start_up_duration=0)
switch.power_on()
network.add_node(switch)
# Connect the new switch to the router or core switch
if num_of_switches > 1:
core_switch_port += 1
network.connect(core_switch.switch_ports[core_switch_port], switch.switch_ports[24])
else:
network.connect(router.ethernet_ports[1], switch.switch_ports[24])
# Create and add a PC to the network
pc = Computer(
hostname=f"pc_{i}_{lan_name}",
ip_address=f"192.168.{subnet_base}.{i+pcs_ip_block_start-1}",
subnet_mask="255.255.255.0",
default_gateway=default_gateway,
start_up_duration=0,
)
pc.power_on()
network.add_node(pc)
# Connect the PC to the switch
switch_port += 1
network.connect(switch.switch_ports[switch_port], pc.ethernet_port[1])
switch.switch_ports[switch_port].enable()
return network

View File

@@ -4,7 +4,7 @@ import re
import secrets
from ipaddress import IPv4Address, IPv4Network
from pathlib import Path
from typing import Any, Dict, Literal, Optional, Tuple, Union
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
from prettytable import MARKDOWN, PrettyTable
@@ -91,6 +91,8 @@ class NIC(SimComponent):
"Indicates if the NIC supports Wake-on-LAN functionality."
_connected_node: Optional[Node] = None
"The Node to which the NIC is connected."
_port_num_on_node: Optional[int] = None
"Which port number is assigned on this NIC"
_connected_link: Optional[Link] = None
"The Link to which the NIC is connected."
enabled: bool = False
@@ -148,7 +150,7 @@ class NIC(SimComponent):
state = super().describe_state()
state.update(
{
"ip_adress": str(self.ip_address),
"ip_address": str(self.ip_address),
"subnet_mask": str(self.subnet_mask),
"mac_address": self.mac_address,
"speed": self.speed,
@@ -272,18 +274,40 @@ class NIC(SimComponent):
def receive_frame(self, frame: Frame) -> bool:
"""
Receive a network frame from the connected link if the NIC is enabled.
Receive a network frame from the connected link, processing it if the NIC is enabled.
The Frame is passed to the Node.
This method decrements the Time To Live (TTL) of the frame, captures it using PCAP (Packet Capture), and checks
if the frame is either a broadcast or destined for this NIC. If the frame is acceptable, it is passed to the
connected node. The method also handles the discarding of frames with TTL expired and logs this event.
:param frame: The network frame being received.
The frame's reception is based on various conditions:
- If the NIC is disabled, the frame is not processed.
- If the TTL of the frame reaches zero after decrement, it is discarded and logged.
- If the frame is a broadcast or its destination MAC/IP address matches this NIC's, it is accepted.
- All other frames are dropped and logged or printed to the console.
:param frame: The network frame being received. This should be an instance of the Frame class.
:return: Returns True if the frame is processed and passed to the node, False otherwise.
"""
if self.enabled:
frame.decrement_ttl()
if frame.ip and frame.ip.ttl < 1:
self._connected_node.sys_log.info("Frame discarded as TTL limit reached")
return False
frame.set_received_timestamp()
self.pcap.capture(frame)
# If this destination or is broadcast
if frame.ethernet.dst_mac_addr == self.mac_address or frame.ethernet.dst_mac_addr == "ff:ff:ff:ff:ff:ff":
accept_frame = False
# Check if it's a broadcast:
if frame.ethernet.dst_mac_addr == "ff:ff:ff:ff:ff:ff":
if frame.ip.dst_ip_address in {self.ip_address, self.ip_network.broadcast_address}:
accept_frame = True
else:
if frame.ethernet.dst_mac_addr == self.mac_address:
accept_frame = True
if accept_frame:
self._connected_node.receive_frame(frame=frame, from_nic=self)
return True
return False
@@ -311,6 +335,8 @@ class SwitchPort(SimComponent):
"The Maximum Transmission Unit (MTU) of the SwitchPort in Bytes. Default is 1500 B"
_connected_node: Optional[Node] = None
"The Node to which the SwitchPort is connected."
_port_num_on_node: Optional[int] = None
"The port num on the connected node."
_connected_link: Optional[Link] = None
"The Link to which the SwitchPort is connected."
enabled: bool = False
@@ -432,6 +458,9 @@ class SwitchPort(SimComponent):
"""
if self.enabled:
frame.decrement_ttl()
if frame.ip and frame.ip.ttl < 1:
self._connected_node.sys_log.info("Frame discarded as TTL limit reached")
return False
self.pcap.capture(frame)
connected_node: Node = self._connected_node
connected_node.forward_frame(frame=frame, incoming_port=self)
@@ -497,8 +526,8 @@ class Link(SimComponent):
state = super().describe_state()
state.update(
{
"endpoint_a": self.endpoint_a.uuid,
"endpoint_b": self.endpoint_b.uuid,
"endpoint_a": self.endpoint_a.uuid, # TODO: consider if using UUID is the best way to do this
"endpoint_b": self.endpoint_b.uuid, # TODO: consider if using UUID is the best way to do this
"bandwidth": self.bandwidth,
"current_load": self.current_load,
}
@@ -667,17 +696,30 @@ class ARPCache:
"""Clear the entire ARP cache, removing all stored entries."""
self.arp.clear()
def send_arp_request(self, target_ip_address: Union[IPv4Address, str]):
def send_arp_request(
self, target_ip_address: Union[IPv4Address, str], ignore_networks: Optional[List[IPv4Address]] = None
):
"""
Perform a standard ARP request for a given target IP address.
Broadcasts the request through all enabled NICs to determine the MAC address corresponding to the target IP
address.
address. This method can be configured to ignore specific networks when sending out ARP requests,
which is useful in environments where certain addresses should not be queried.
:param target_ip_address: The target IP address to send an ARP request for.
:param ignore_networks: An optional list of IPv4 addresses representing networks to be excluded from the ARP
request broadcast. Each address in this list indicates a network which will not be queried during the ARP
request process. This is particularly useful in complex network environments where traffic should be
minimized or controlled to specific subnets. It is mainly used by the router to prevent ARP requests being
sent back to their source.
"""
for nic in self.nics.values():
if nic.enabled:
use_nic = True
if ignore_networks:
for ipv4 in ignore_networks:
if ipv4 in nic.ip_network:
use_nic = False
if nic.enabled and use_nic:
self.sys_log.info(f"Sending ARP request from NIC {nic} for ip {target_ip_address}")
tcp_header = TCPHeader(src_port=Port.ARP, dst_port=Port.ARP)
@@ -802,7 +844,6 @@ class ICMP:
self.arp.send_arp_request(frame.ip.src_ip_address)
self.process_icmp(frame=frame, from_nic=from_nic, is_reattempt=True)
return
tcp_header = TCPHeader(src_port=Port.ARP, dst_port=Port.ARP)
# Network Layer
ip_packet = IPPacket(
@@ -817,9 +858,7 @@ class ICMP:
sequence=frame.icmp.sequence + 1,
)
payload = secrets.token_urlsafe(int(32 / 1.3)) # Standard ICMP 32 bytes size
frame = Frame(
ethernet=ethernet_header, ip=ip_packet, tcp=tcp_header, icmp=icmp_reply_packet, payload=payload
)
frame = Frame(ethernet=ethernet_header, ip=ip_packet, icmp=icmp_reply_packet, payload=payload)
self.sys_log.info(f"Sending echo reply to {frame.ip.dst_ip_address}")
src_nic.send_frame(frame)
@@ -1094,12 +1133,12 @@ class Node(SimComponent):
{
"hostname": self.hostname,
"operating_state": self.operating_state.value,
"NICs": {uuid: nic.describe_state() for uuid, nic in self.nics.items()},
"NICs": {eth_num: nic.describe_state() for eth_num, nic in self.ethernet_port.items()},
# "switch_ports": {uuid, sp for uuid, sp in self.switch_ports.items()},
"file_system": self.file_system.describe_state(),
"applications": {uuid: app.describe_state() for uuid, app in self.applications.items()},
"services": {uuid: svc.describe_state() for uuid, svc in self.services.items()},
"process": {uuid: proc.describe_state() for uuid, proc in self.processes.items()},
"applications": {app.name: app.describe_state() for app in self.applications.values()},
"services": {svc.name: svc.describe_state() for svc in self.services.values()},
"process": {proc.name: proc.describe_state() for proc in self.processes.values()},
"revealed_to_red": self.revealed_to_red,
}
)
@@ -1316,6 +1355,7 @@ class Node(SimComponent):
self.nics[nic.uuid] = nic
self.ethernet_port[len(self.nics)] = nic
nic._connected_node = self
nic._port_num_on_node = len(self.nics)
nic.parent = self
self.sys_log.info(f"Connected NIC {nic}")
if self.operating_state == NodeOperatingState.ON:
@@ -1442,7 +1482,6 @@ class Node(SimComponent):
service.parent = self
service.install() # Perform any additional setup, such as creating files for this service on the node.
self.sys_log.info(f"Installed service {service.name}")
_LOGGER.info(f"Added service {service.uuid} to node {self.uuid}")
self._service_request_manager.add_request(service.uuid, RequestType(func=service._request_manager))
def uninstall_service(self, service: Service) -> None:
@@ -1475,7 +1514,6 @@ class Node(SimComponent):
self.applications[application.uuid] = application
application.parent = self
self.sys_log.info(f"Installed application {application.name}")
_LOGGER.info(f"Added application {application.uuid} to node {self.uuid}")
self._application_request_manager.add_request(application.uuid, RequestType(func=application._request_manager))
def uninstall_application(self, application: Application) -> None:

View File

@@ -357,11 +357,10 @@ class RouteEntry(SimComponent):
"""
Represents a single entry in a routing table.
Attributes:
address (IPv4Address): The destination IP address or network address.
subnet_mask (IPv4Address): The subnet mask for the network.
next_hop_ip_address (IPv4Address): The next hop IP address to which packets should be forwarded.
metric (int): The cost metric for this route. Default is 0.0.
:ivar address: The destination IP address or network address.
:ivar subnet_mask: The subnet mask for the network.
:ivar next_hop_ip_address: The next hop IP address to which packets should be forwarded.
:ivar metric: The cost metric for this route. Default is 0.0.
Example:
>>> entry = RouteEntry(
@@ -381,12 +380,6 @@ class RouteEntry(SimComponent):
metric: float = 0.0
"The cost metric for this route. Default is 0.0."
def __init__(self, **kwargs):
for key in {"address", "subnet_mask", "next_hop_ip_address"}:
if not isinstance(kwargs[key], IPv4Address):
kwargs[key] = IPv4Address(kwargs[key])
super().__init__(**kwargs)
def set_original_state(self):
"""Sets the original state."""
vals_to_include = {"address", "subnet_mask", "next_hop_ip_address", "metric"}
@@ -421,6 +414,7 @@ class RouteTable(SimComponent):
"""
routes: List[RouteEntry] = []
default_route: Optional[RouteEntry] = None
sys_log: SysLog
def set_original_state(self):
@@ -465,12 +459,35 @@ class RouteTable(SimComponent):
)
self.routes.append(route)
def set_default_route_next_hop_ip_address(self, ip_address: IPv4Address):
"""
Sets the next-hop IP address for the default route in a routing table.
This method checks if a default route (0.0.0.0/0) exists in the routing table. If it does not exist,
the method creates a new default route with the specified next-hop IP address. If a default route already
exists, it updates the next-hop IP address of the existing default route. After setting the next-hop
IP address, the method logs this action.
:param ip_address: The next-hop IP address to be set for the default route.
"""
if not self.default_route:
self.default_route = RouteEntry(
ip_address=IPv4Address("0.0.0.0"),
subnet_mask=IPv4Address("0.0.0.0"),
next_hop_ip_address=ip_address,
)
else:
self.default_route.next_hop_ip_address = ip_address
self.sys_log.info(f"Default configured to use {ip_address} as the next-hop")
def find_best_route(self, destination_ip: Union[str, IPv4Address]) -> Optional[RouteEntry]:
"""
Find the best route for a given destination IP.
This method uses the Longest Prefix Match algorithm and considers metrics to find the best route.
If no dedicated route exists but a default route does, then the default route is returned as a last resort.
:param destination_ip: The destination IP to find the route for.
:return: The best matching RouteEntry, or None if no route matches.
"""
@@ -490,6 +507,9 @@ class RouteTable(SimComponent):
longest_prefix = prefix_len
lowest_metric = route.metric
if not best_route and self.default_route:
best_route = self.default_route
return best_route
def show(self, markdown: bool = False):
@@ -521,12 +541,26 @@ class RouterARPCache(ARPCache):
super().__init__(sys_log)
self.router: Router = router
def process_arp_packet(self, from_nic: NIC, frame: Frame):
def process_arp_packet(
self, from_nic: NIC, frame: Frame, route_table: RouteTable, is_reattempt: bool = False
) -> None:
"""
Overridden method to process a received ARP packet in a router-specific way.
Processes a received ARP (Address Resolution Protocol) packet in a router-specific way.
This method is responsible for handling both ARP requests and responses. It processes ARP packets received on a
Network Interface Card (NIC) and performs actions based on whether the packet is a request or a reply. This
includes updating the ARP cache, forwarding ARP replies, sending ARP requests for unknown destinations, and
handling packet TTL (Time To Live).
The method first checks if the ARP packet is a request or a reply. For ARP replies, it updates the ARP cache
and forwards the reply if necessary. For ARP requests, it checks if the target IP matches one of the router's
NICs and sends an ARP reply if so. If the destination is not directly connected, it consults the routing table
to find the best route and reattempts ARP request processing if needed.
:param from_nic: The NIC that received the ARP packet.
:param frame: The original ARP frame.
:param frame: The frame containing the ARP packet.
:param route_table: The routing table of the router.
:param is_reattempt: Flag to indicate if this is a reattempt of processing the ARP packet, defaults to False.
"""
arp_packet = frame.arp
@@ -554,7 +588,11 @@ class RouterARPCache(ARPCache):
)
arp_packet.sender_mac_addr = nic.mac_address
frame.decrement_ttl()
if frame.ip and frame.ip.ttl < 1:
self.sys_log.info("Frame discarded as TTL limit reached")
return
nic.send_frame(frame)
return
# ARP Request
self.sys_log.info(
@@ -565,16 +603,32 @@ class RouterARPCache(ARPCache):
self.add_arp_cache_entry(
ip_address=arp_packet.sender_ip_address, mac_address=arp_packet.sender_mac_addr, nic=from_nic
)
arp_packet = arp_packet.generate_reply(from_nic.mac_address)
self.send_arp_reply(arp_packet, from_nic)
# If the target IP matches one of the router's NICs
for nic in self.nics.values():
if nic.enabled and nic.ip_address == arp_packet.target_ip_address:
if arp_packet.target_ip_address in nic.ip_network:
# if nic.enabled and nic.ip_address == arp_packet.target_ip_address:
arp_reply = arp_packet.generate_reply(from_nic.mac_address)
self.send_arp_reply(arp_reply, from_nic)
return
# Check Route Table
route = route_table.find_best_route(arp_packet.target_ip_address)
if route:
nic = self.get_arp_cache_nic(route.next_hop_ip_address)
if not nic:
if not is_reattempt:
self.send_arp_request(route.next_hop_ip_address, ignore_networks=[frame.ip.src_ip_address])
return self.process_arp_packet(from_nic, frame, route_table, is_reattempt=True)
else:
self.sys_log.info("Ignoring ARP request as destination unavailable/No ARP entry found")
return
else:
arp_reply = arp_packet.generate_reply(from_nic.mac_address)
self.send_arp_reply(arp_reply, from_nic)
return
class RouterICMP(ICMP):
"""
@@ -645,7 +699,7 @@ class RouterICMP(ICMP):
return
# Route the frame
self.router.route_frame(frame, from_nic)
self.router.process_frame(frame, from_nic)
elif frame.icmp.icmp_type == ICMPType.ECHO_REPLY:
for nic in self.router.nics.values():
@@ -665,7 +719,48 @@ class RouterICMP(ICMP):
return
# Route the frame
self.router.route_frame(frame, from_nic)
self.router.process_frame(frame, from_nic)
class RouterNIC(NIC):
"""
A Router-specific Network Interface Card (NIC) that extends the standard NIC functionality.
This class overrides the standard Node NIC's Layer 3 (L3) broadcast/unicast checks. It is designed
to handle network frames in a manner specific to routers, allowing them to efficiently process
and route network traffic.
"""
def receive_frame(self, frame: Frame) -> bool:
"""
Receive and process a network frame from the connected link, provided the NIC is enabled.
This method is tailored for router behavior. It decrements the frame's Time To Live (TTL), checks for TTL
expiration, and captures the frame using PCAP (Packet Capture). The frame is accepted if it is destined for
this NIC's MAC address or is a broadcast frame.
Key Differences from Standard NIC:
- Does not perform Layer 3 (IP-based) broadcast checks.
- Only checks for Layer 2 (Ethernet) destination MAC address and broadcast frames.
:param frame: The network frame being received. This should be an instance of the Frame class.
:return: Returns True if the frame is processed and passed to the connected node, False otherwise.
"""
if self.enabled:
frame.decrement_ttl()
if frame.ip and frame.ip.ttl < 1:
self._connected_node.sys_log.info("Frame discarded as TTL limit reached")
return False
frame.set_received_timestamp()
self.pcap.capture(frame)
# If this destination or is broadcast
if frame.ethernet.dst_mac_addr == self.mac_address or frame.ethernet.dst_mac_addr == "ff:ff:ff:ff:ff:ff":
self._connected_node.receive_frame(frame=frame, from_nic=self)
return True
return False
def __str__(self) -> str:
return f"{self.mac_address}/{self.ip_address}"
class Router(Node):
@@ -678,7 +773,7 @@ class Router(Node):
"""
num_ports: int
ethernet_ports: Dict[int, NIC] = {}
ethernet_ports: Dict[int, RouterNIC] = {}
acl: AccessControlList
route_table: RouteTable
arp: RouterARPCache
@@ -697,7 +792,7 @@ class Router(Node):
kwargs["icmp"] = RouterICMP(sys_log=kwargs.get("sys_log"), arp_cache=kwargs.get("arp"), router=self)
super().__init__(hostname=hostname, num_ports=num_ports, **kwargs)
for i in range(1, self.num_ports + 1):
nic = NIC(ip_address="127.0.0.1", subnet_mask="255.0.0.0", gateway="0.0.0.0")
nic = RouterNIC(ip_address="127.0.0.1", subnet_mask="255.0.0.0", gateway="0.0.0.0")
self.connect_nic(nic)
self.ethernet_ports[i] = nic
@@ -752,9 +847,9 @@ class Router(Node):
state["acl"] = self.acl.describe_state()
return state
def route_frame(self, frame: Frame, from_nic: NIC, re_attempt: bool = False) -> None:
def process_frame(self, frame: Frame, from_nic: NIC, re_attempt: bool = False) -> None:
"""
Route a given frame from a source NIC to its destination.
Process a Frame.
:param frame: The frame to be routed.
:param from_nic: The source network interface.
@@ -769,25 +864,57 @@ class Router(Node):
return
if not nic:
self.arp.send_arp_request(frame.ip.dst_ip_address)
return self.route_frame(frame=frame, from_nic=from_nic, re_attempt=True)
self.arp.send_arp_request(
frame.ip.dst_ip_address, ignore_networks=[frame.ip.src_ip_address, from_nic.ip_address]
)
return self.process_frame(frame=frame, from_nic=from_nic, re_attempt=True)
if not nic.enabled:
# TODO: Add sys_log here
self.sys_log.info(f"Frame dropped as NIC {nic} is not enabled")
return
if frame.ip.dst_ip_address in nic.ip_network:
from_port = self._get_port_of_nic(from_nic)
to_port = self._get_port_of_nic(nic)
self.sys_log.info(f"Routing frame to internally from port {from_port} to port {to_port}")
self.sys_log.info(f"Forwarding frame to internally from port {from_port} to port {to_port}")
frame.decrement_ttl()
if frame.ip and frame.ip.ttl < 1:
self.sys_log.info("Frame discarded as TTL limit reached")
return
frame.ethernet.src_mac_addr = nic.mac_address
frame.ethernet.dst_mac_addr = target_mac
nic.send_frame(frame)
return
else:
pass
# TODO: Deal with routing from route tables
self._route_frame(frame, from_nic)
def _route_frame(self, frame: Frame, from_nic: NIC, re_attempt: bool = False) -> None:
route = self.route_table.find_best_route(frame.ip.dst_ip_address)
if route:
nic = self.arp.get_arp_cache_nic(route.next_hop_ip_address)
target_mac = self.arp.get_arp_cache_mac_address(route.next_hop_ip_address)
if re_attempt and not nic:
self.sys_log.info(f"Destination {frame.ip.dst_ip_address} is unreachable")
return
if not nic:
self.arp.send_arp_request(frame.ip.dst_ip_address, ignore_networks=[frame.ip.src_ip_address])
return self.process_frame(frame=frame, from_nic=from_nic, re_attempt=True)
if not nic.enabled:
self.sys_log.info(f"Frame dropped as NIC {nic} is not enabled")
return
from_port = self._get_port_of_nic(from_nic)
to_port = self._get_port_of_nic(nic)
self.sys_log.info(f"Routing frame to internally from port {from_port} to port {to_port}")
frame.decrement_ttl()
if frame.ip and frame.ip.ttl < 1:
self.sys_log.info("Frame discarded as TTL limit reached")
return
frame.ethernet.src_mac_addr = nic.mac_address
frame.ethernet.dst_mac_addr = target_mac
nic.send_frame(frame)
def receive_frame(self, frame: Frame, from_nic: NIC):
"""
@@ -796,7 +923,7 @@ class Router(Node):
:param frame: The incoming frame.
:param from_nic: The network interface where the frame is coming from.
"""
route_frame = False
process_frame = False
protocol = frame.ip.protocol
src_ip_address = frame.ip.src_ip_address
dst_ip_address = frame.ip.dst_ip_address
@@ -828,12 +955,12 @@ class Router(Node):
self.icmp.process_icmp(frame=frame, from_nic=from_nic)
else:
if src_port == Port.ARP:
self.arp.process_arp_packet(from_nic=from_nic, frame=frame)
self.arp.process_arp_packet(from_nic=from_nic, frame=frame, route_table=self.route_table)
else:
# All other traffic
route_frame = True
if route_frame:
self.route_frame(frame, from_nic)
process_frame = True
if process_frame:
self.process_frame(frame, from_nic)
def configure_port(self, port: int, ip_address: Union[IPv4Address, str], subnet_mask: Union[IPv4Address, str]):
"""

View File

@@ -30,6 +30,7 @@ class Switch(Node):
self.switch_ports = {i: SwitchPort() for i in range(1, self.num_ports + 1)}
for port_num, port in self.switch_ports.items():
port._connected_node = self
port._port_num_on_node = port_num
port.parent = self
port.port_num = port_num
@@ -89,12 +90,12 @@ class Switch(Node):
self._add_mac_table_entry(src_mac, incoming_port)
outgoing_port = self.mac_address_table.get(dst_mac)
if outgoing_port or dst_mac != "ff:ff:ff:ff:ff:ff":
if outgoing_port and dst_mac.lower() != "ff:ff:ff:ff:ff:ff":
outgoing_port.send_frame(frame)
else:
# If the destination MAC is not in the table, flood to all ports except incoming
for port in self.switch_ports.values():
if port != incoming_port:
if port.enabled and port != incoming_port:
port.send_frame(frame)
def disconnect_link_from_port(self, link: Link, port_number: int):

View File

@@ -9,10 +9,10 @@ from primaite.simulator.network.hardware.nodes.switch import Switch
from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.applications.database_client import DatabaseClient
from primaite.simulator.system.applications.red_applications.data_manipulation_bot import DataManipulationBot
from primaite.simulator.system.services.database.database_service import DatabaseService
from primaite.simulator.system.services.dns.dns_server import DNSServer
from primaite.simulator.system.services.ftp.ftp_server import FTPServer
from primaite.simulator.system.services.red_services.data_manipulation_bot import DataManipulationBot
from primaite.simulator.system.services.web_server.web_server import WebServer
@@ -252,9 +252,9 @@ def arcd_uc2_network() -> Network:
database_service: DatabaseService = database_server.software_manager.software.get("DatabaseService") # noqa
database_service.start()
database_service.configure_backup(backup_server=IPv4Address("192.168.1.16"))
database_service._process_sql(ddl, None) # noqa
database_service._process_sql(ddl, None, None) # noqa
for insert_statement in user_insert_statements:
database_service._process_sql(insert_statement, None) # noqa
database_service._process_sql(insert_statement, None, None) # noqa
# Web Server
web_server = Server(

View File

@@ -0,0 +1,35 @@
from __future__ import annotations
from datetime import datetime
from typing import Optional
from pydantic import BaseModel
from primaite.simulator.network.protocols.packet import DataPacket
class NTPReply(BaseModel):
"""Represents a NTP Reply packet."""
ntp_datetime: datetime
"NTP datetime object set by NTP Server."
class NTPPacket(DataPacket):
"""
Represents the NTP layer of a network frame.
:param ntp_request: NTPRequest packet from NTP client.
:param ntp_reply: NTPReply packet from NTP Server.
"""
ntp_reply: Optional[NTPReply] = None
def generate_reply(self, ntp_server_time: datetime) -> NTPPacket:
"""Generate a NTPPacket containing the time in a NTPReply object.
:param time: datetime object representing the time from the NTP server.
:return: A new NTPPacket object.
"""
self.ntp_reply = NTPReply(ntp_datetime=ntp_server_time)
return self

View File

@@ -5,7 +5,7 @@ from uuid import uuid4
from primaite import getLogger
from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.applications.application import Application, ApplicationOperatingState
from primaite.simulator.system.applications.application import Application
from primaite.simulator.system.core.software_manager import SoftwareManager
_LOGGER = getLogger(__name__)
@@ -23,7 +23,6 @@ class DatabaseClient(Application):
server_ip_address: Optional[IPv4Address] = None
server_password: Optional[str] = None
connected: bool = False
_query_success_tracker: Dict[str, bool] = {}
def __init__(self, **kwargs):
@@ -66,18 +65,24 @@ class DatabaseClient(Application):
self.server_password = server_password
self.sys_log.info(f"{self.name}: Configured the {self.name} with {server_ip_address=}, {server_password=}.")
def connect(self) -> bool:
def connect(self, connection_id: Optional[str] = None) -> bool:
"""Connect to a Database Service."""
if not self._can_perform_action():
return False
if not self.connected:
return self._connect(self.server_ip_address, self.server_password)
# already connected
return True
if not connection_id:
connection_id = str(uuid4())
return self._connect(
server_ip_address=self.server_ip_address, password=self.server_password, connection_id=connection_id
)
def _connect(
self, server_ip_address: IPv4Address, password: Optional[str] = None, is_reattempt: bool = False
self,
server_ip_address: IPv4Address,
connection_id: Optional[str] = None,
password: Optional[str] = None,
is_reattempt: bool = False,
) -> bool:
"""
Connects the DatabaseClient to the DatabaseServer.
@@ -92,33 +97,58 @@ class DatabaseClient(Application):
:type: is_reattempt: Optional[bool]
"""
if is_reattempt:
if self.connected:
self.sys_log.info(f"{self.name}: DatabaseClient connection to {server_ip_address} authorised")
if self.connections.get(connection_id):
self.sys_log.info(
f"{self.name} {connection_id=}: DatabaseClient connection to {server_ip_address} authorised"
)
self.server_ip_address = server_ip_address
return self.connected
return True
else:
self.sys_log.info(f"{self.name}: DatabaseClient connection to {server_ip_address} declined")
self.sys_log.info(
f"{self.name} {connection_id=}: DatabaseClient connection to {server_ip_address} declined"
)
return False
payload = {"type": "connect_request", "password": password}
payload = {
"type": "connect_request",
"password": password,
"connection_id": connection_id,
}
software_manager: SoftwareManager = self.software_manager
software_manager.send_payload_to_session_manager(
payload=payload, dest_ip_address=server_ip_address, dest_port=self.port
)
return self._connect(server_ip_address, password, True)
return self._connect(
server_ip_address=server_ip_address, password=password, connection_id=connection_id, is_reattempt=True
)
def disconnect(self):
def disconnect(self, connection_id: Optional[str] = None) -> bool:
"""Disconnect from the Database Service."""
if self.connected and self.operating_state is ApplicationOperatingState.RUNNING:
software_manager: SoftwareManager = self.software_manager
software_manager.send_payload_to_session_manager(
payload={"type": "disconnect"}, dest_ip_address=self.server_ip_address, dest_port=self.port
)
if not self._can_perform_action():
self.sys_log.error(f"Unable to disconnect - {self.name} is {self.operating_state.name}")
return False
self.sys_log.info(f"{self.name}: DatabaseClient disconnected from {self.server_ip_address}")
self.server_ip_address = None
self.connected = False
# if there are no connections - nothing to disconnect
if not len(self.connections):
self.sys_log.error(f"Unable to disconnect - {self.name} has no active connections.")
return False
def _query(self, sql: str, query_id: str, is_reattempt: bool = False) -> bool:
# if no connection provided, disconnect the first connection
if not connection_id:
connection_id = list(self.connections.keys())[0]
software_manager: SoftwareManager = self.software_manager
software_manager.send_payload_to_session_manager(
payload={"type": "disconnect", "connection_id": connection_id},
dest_ip_address=self.server_ip_address,
dest_port=self.port,
)
self.remove_connection(connection_id=connection_id)
self.sys_log.info(
f"{self.name}: DatabaseClient disconnected connection {connection_id} from {self.server_ip_address}"
)
def _query(self, sql: str, query_id: str, connection_id: str, is_reattempt: bool = False) -> bool:
"""
Send a query to the connected database server.
@@ -141,19 +171,17 @@ class DatabaseClient(Application):
else:
software_manager: SoftwareManager = self.software_manager
software_manager.send_payload_to_session_manager(
payload={"type": "sql", "sql": sql, "uuid": query_id},
payload={"type": "sql", "sql": sql, "uuid": query_id, "connection_id": connection_id},
dest_ip_address=self.server_ip_address,
dest_port=self.port,
)
return self._query(sql=sql, query_id=query_id, is_reattempt=True)
return self._query(sql=sql, query_id=query_id, connection_id=connection_id, is_reattempt=True)
def run(self) -> None:
"""Run the DatabaseClient."""
super().run()
if self.operating_state == ApplicationOperatingState.RUNNING:
self.connect()
def query(self, sql: str, is_reattempt: bool = False) -> bool:
def query(self, sql: str, connection_id: Optional[str] = None) -> bool:
"""
Send a query to the Database Service.
@@ -164,20 +192,17 @@ class DatabaseClient(Application):
if not self._can_perform_action():
return False
if self.connected:
query_id = str(uuid4())
if connection_id is None:
connection_id = str(uuid4())
if not self.connections.get(connection_id):
if not self.connect(connection_id=connection_id):
return False
# Initialise the tracker of this ID to False
self._query_success_tracker[query_id] = False
return self._query(sql=sql, query_id=query_id)
else:
if is_reattempt:
return False
if not self.connect():
return False
self.query(sql=sql, is_reattempt=True)
uuid = str(uuid4())
self._query_success_tracker[uuid] = False
return self._query(sql=sql, query_id=uuid, connection_id=connection_id)
def receive(self, payload: Any, session_id: str, **kwargs) -> bool:
"""
@@ -192,13 +217,13 @@ class DatabaseClient(Application):
if isinstance(payload, dict) and payload.get("type"):
if payload["type"] == "connect_response":
self.connected = payload["response"] == True
if payload["response"] is True:
# add connection
self.add_connection(connection_id=payload.get("connection_id"), session_id=session_id)
elif payload["type"] == "sql":
query_id = payload.get("uuid")
status_code = payload.get("status_code")
self._query_success_tracker[query_id] = status_code == 200
if self._query_success_tracker[query_id]:
_LOGGER.debug(f"Received payload {payload}")
else:
self.connected = False
return True

View File

@@ -5,7 +5,6 @@ from typing import Optional
from primaite import getLogger
from primaite.game.science import simulate_trial
from primaite.simulator.core import RequestManager, RequestType
from primaite.simulator.system.applications.application import ApplicationOperatingState
from primaite.simulator.system.applications.database_client import DatabaseClient
_LOGGER = getLogger(__name__)
@@ -149,9 +148,9 @@ class DataManipulationBot(DatabaseClient):
if simulate_trial(p_of_success):
self.sys_log.info(f"{self.name}: Performing data manipulation")
# perform the attack
if not self.connected:
if not len(self.connections):
self.connect()
if self.connected:
if len(self.connections):
self.query(self.payload)
self.sys_log.info(f"{self.name} payload delivered: {self.payload}")
attack_successful = True
@@ -183,9 +182,9 @@ class DataManipulationBot(DatabaseClient):
This is the core loop where the bot sequentially goes through the stages of the attack.
"""
if self.operating_state != ApplicationOperatingState.RUNNING:
if not self._can_perform_action():
return
if self.server_ip_address and self.payload and self.operating_state:
if self.server_ip_address and self.payload:
self.sys_log.info(f"{self.name}: Running")
self._logon()
self._perform_port_scan(p_of_success=self.port_scan_p_of_success)

View File

@@ -0,0 +1,192 @@
from enum import IntEnum
from ipaddress import IPv4Address
from typing import Optional
from primaite import getLogger
from primaite.game.science import simulate_trial
from primaite.simulator.core import RequestManager, RequestType
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.applications.application import Application
from primaite.simulator.system.applications.database_client import DatabaseClient
_LOGGER = getLogger(__name__)
class DoSAttackStage(IntEnum):
"""Enum representing the different stages of a Denial of Service attack."""
NOT_STARTED = 0
"Attack not yet started."
PORT_SCAN = 1
"Attack is in discovery stage - checking if provided ip and port are open."
ATTACKING = 2
"Denial of Service attack is in progress."
COMPLETED = 3
"Attack is completed."
class DoSBot(DatabaseClient, Application):
"""A bot that simulates a Denial of Service attack."""
target_ip_address: Optional[IPv4Address] = None
"""IP address of the target service."""
target_port: Optional[Port] = None
"""Port of the target service."""
payload: Optional[str] = None
"""Payload to deliver to the target service as part of the denial of service attack."""
repeat: bool = False
"""If true, the Denial of Service bot will keep performing the attack."""
attack_stage: DoSAttackStage = DoSAttackStage.NOT_STARTED
"""Current stage of the DoS kill chain."""
port_scan_p_of_success: float = 0.1
"""Probability of port scanning being sucessful."""
dos_intensity: float = 1.0
"""How much of the max sessions will be used by the DoS when attacking."""
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.name = "DoSBot"
self.max_sessions = 1000 # override normal max sessions
def set_original_state(self):
"""Set the original state of the Denial of Service Bot."""
_LOGGER.debug(f"Setting {self.name} original state on node {self.software_manager.node.hostname}")
super().set_original_state()
vals_to_include = {
"target_ip_address",
"target_port",
"payload",
"repeat",
"attack_stage",
"max_sessions",
"port_scan_p_of_success",
"dos_intensity",
}
self._original_state.update(self.model_dump(include=vals_to_include))
def reset_component_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
_LOGGER.debug(f"Resetting {self.name} state on node {self.software_manager.node.hostname}")
super().reset_component_for_episode(episode)
def _init_request_manager(self) -> RequestManager:
rm = super()._init_request_manager()
rm.add_request(name="execute", request_type=RequestType(func=lambda request, context: self.run()))
return rm
def configure(
self,
target_ip_address: IPv4Address,
target_port: Optional[Port] = Port.POSTGRES_SERVER,
payload: Optional[str] = None,
repeat: bool = False,
port_scan_p_of_success: float = 0.1,
dos_intensity: float = 1.0,
max_sessions: int = 1000,
):
"""
Configure the Denial of Service bot.
:param: target_ip_address: The IP address of the Node containing the target service.
:param: target_port: The port of the target service. Optional - Default is `Port.HTTP`
:param: payload: The payload the DoS Bot will throw at the target service. Optional - Default is `None`
:param: repeat: If True, the bot will maintain the attack. Optional - Default is `True`
:param: port_scan_p_of_success: The chance of the port scan being sucessful. Optional - Default is 0.1 (10%)
:param: dos_intensity: The intensity of the DoS attack.
Multiplied with the application's max session - Default is 1.0
:param: max_sessions: The maximum number of sessions the DoS bot will attack with. Optional - Default is 1000
"""
self.target_ip_address = target_ip_address
self.target_port = target_port
self.payload = payload
self.repeat = repeat
self.port_scan_p_of_success = port_scan_p_of_success
self.dos_intensity = dos_intensity
self.max_sessions = max_sessions
self.sys_log.info(
f"{self.name}: Configured the {self.name} with {target_ip_address=}, {target_port=}, {payload=}, "
f"{repeat=}, {port_scan_p_of_success=}, {dos_intensity=}, {max_sessions=}."
)
def run(self):
"""Run the Denial of Service Bot."""
super().run()
self._application_loop()
def _application_loop(self):
"""
The main application loop for the Denial of Service bot.
The loop goes through the stages of a DoS attack.
"""
if not self._can_perform_action():
return
# DoS bot cannot do anything without a target
if not self.target_ip_address or not self.target_port:
self.sys_log.error(
f"{self.name} is not properly configured. {self.target_ip_address=}, {self.target_port=}"
)
return
self.clear_connections()
self._perform_port_scan(p_of_success=self.port_scan_p_of_success)
self._perform_dos()
if self.repeat and self.attack_stage is DoSAttackStage.ATTACKING:
self.attack_stage = DoSAttackStage.NOT_STARTED
else:
self.attack_stage = DoSAttackStage.COMPLETED
def _perform_port_scan(self, p_of_success: Optional[float] = 0.1):
"""
Perform a simulated port scan to check for open SQL ports.
Advances the attack stage to `PORT_SCAN` if successful.
:param p_of_success: Probability of successful port scan, by default 0.1.
"""
if self.attack_stage == DoSAttackStage.NOT_STARTED:
# perform a port scan to identify that the SQL port is open on the server
if simulate_trial(p_of_success):
self.sys_log.info(f"{self.name}: Performing port scan")
# perform the port scan
port_is_open = True # Temporary; later we can implement NMAP port scan.
if port_is_open:
self.sys_log.info(f"{self.name}: ")
self.attack_stage = DoSAttackStage.PORT_SCAN
def _perform_dos(self):
"""
Perform the Denial of Service attack.
DoSBot does this by clogging up the available connections to a service.
"""
if not self.attack_stage == DoSAttackStage.PORT_SCAN:
return
self.attack_stage = DoSAttackStage.ATTACKING
self.server_ip_address = self.target_ip_address
self.port = self.target_port
dos_sessions = int(float(self.max_sessions) * self.dos_intensity)
for i in range(dos_sessions):
self.connect()
def apply_timestep(self, timestep: int) -> None:
"""
Apply a timestep to the bot, iterate through the application loop.
:param timestep: The timestep value to update the bot's state.
"""
self._application_loop()

View File

@@ -1,6 +1,6 @@
from __future__ import annotations
from ipaddress import IPv4Address
from ipaddress import IPv4Address, IPv4Network
from typing import Any, Dict, Optional, Tuple, TYPE_CHECKING, Union
from prettytable import MARKDOWN, PrettyTable
@@ -141,41 +141,76 @@ class SessionManager:
def receive_payload_from_software_manager(
self,
payload: Any,
dst_ip_address: Optional[IPv4Address] = None,
dst_ip_address: Optional[Union[IPv4Address, IPv4Network]] = None,
dst_port: Optional[Port] = None,
session_id: Optional[str] = None,
is_reattempt: bool = False,
) -> Union[Any, None]:
"""
Receive a payload from the SoftwareManager.
Receive a payload from the SoftwareManager and send it to the appropriate NIC for transmission.
If no session_id, a Session is established. Once established, the payload is sent to ``send_payload_to_nic``.
This method supports both unicast and Layer 3 broadcast transmissions. If `dst_ip_address` is an
IPv4Network, a broadcast is initiated. For unicast, the destination MAC address is resolved via ARP.
A new session is established if `session_id` is not provided, and an existing session is used otherwise.
:param payload: The payload to be sent.
:param session_id: The Session ID the payload is to originate from. Optional. If None, one will be created.
:param dst_ip_address: The destination IP address or network for broadcast. Optional.
:param dst_port: The destination port for the TCP packet. Optional.
:param session_id: The Session ID from which the payload originates. Optional.
:param is_reattempt: Flag to indicate if this is a reattempt after an ARP request. Default is False.
:return: The outcome of sending the frame, or None if sending was unsuccessful.
"""
is_broadcast = False
outbound_nic = None
dst_mac_address = None
# Use session details if session_id is provided
if session_id:
session = self.sessions_by_uuid[session_id]
dst_ip_address = self.sessions_by_uuid[session_id].with_ip_address
dst_port = self.sessions_by_uuid[session_id].dst_port
dst_ip_address = session.with_ip_address
dst_port = session.dst_port
dst_mac_address = self.arp_cache.get_arp_cache_mac_address(dst_ip_address)
# Determine if the payload is for broadcast or unicast
if dst_mac_address:
outbound_nic = self.arp_cache.get_arp_cache_nic(dst_ip_address)
# Handle broadcast transmission
if isinstance(dst_ip_address, IPv4Network):
is_broadcast = True
dst_ip_address = dst_ip_address.broadcast_address
if dst_ip_address:
# Find a suitable NIC for the broadcast
for nic in self.arp_cache.nics.values():
if dst_ip_address in nic.ip_network and nic.enabled:
dst_mac_address = "ff:ff:ff:ff:ff:ff"
outbound_nic = nic
else:
if not is_reattempt:
self.arp_cache.send_arp_request(dst_ip_address)
return self.receive_payload_from_software_manager(
payload=payload,
dst_ip_address=dst_ip_address,
dst_port=dst_port,
session_id=session_id,
is_reattempt=True,
)
else:
return
# Resolve MAC address for unicast transmission
dst_mac_address = self.arp_cache.get_arp_cache_mac_address(dst_ip_address)
# Resolve outbound NIC for unicast transmission
if dst_mac_address:
outbound_nic = self.arp_cache.get_arp_cache_nic(dst_ip_address)
# If MAC address not found, initiate ARP request
else:
if not is_reattempt:
self.arp_cache.send_arp_request(dst_ip_address)
# Reattempt payload transmission after ARP request
return self.receive_payload_from_software_manager(
payload=payload,
dst_ip_address=dst_ip_address,
dst_port=dst_port,
session_id=session_id,
is_reattempt=True,
)
else:
# Return None if reattempt fails
return
# Check if outbound NIC and destination MAC address are resolved
if not outbound_nic or not dst_mac_address:
return False
# Construct the frame for transmission
frame = Frame(
ethernet=EthernetHeader(src_mac_addr=outbound_nic.mac_address, dst_mac_addr=dst_mac_address),
ip=IPPacket(
@@ -189,15 +224,17 @@ class SessionManager:
payload=payload,
)
if not session_id:
# Manage session for unicast transmission
if not (is_broadcast and session_id):
session_key = self._get_session_key(frame, inbound_frame=False)
session = self.sessions_by_key.get(session_key)
if not session:
# Create new session
# Create a new session if it doesn't exist
session = Session.from_session_key(session_key)
self.sessions_by_key[session_key] = session
self.sessions_by_uuid[session.uuid] = session
# Send the frame through the NIC
return outbound_nic.send_frame(frame)
def receive_frame(self, frame: Frame):

View File

@@ -1,4 +1,4 @@
from ipaddress import IPv4Address
from ipaddress import IPv4Address, IPv4Network
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING, Union
from prettytable import MARKDOWN, PrettyTable
@@ -130,20 +130,28 @@ class SoftwareManager:
def send_payload_to_session_manager(
self,
payload: Any,
dest_ip_address: Optional[IPv4Address] = None,
dest_ip_address: Optional[Union[IPv4Address, IPv4Network]] = None,
dest_port: Optional[Port] = None,
session_id: Optional[str] = None,
) -> bool:
"""
Send a payload to the SessionManager.
Sends a payload to the SessionManager for network transmission.
This method is responsible for initiating the process of sending network payloads. It supports both
unicast and Layer 3 broadcast transmissions. For broadcasts, the destination IP should be specified
as an IPv4Network.
:param payload: The payload to be sent.
:param dest_ip_address: The ip address of the payload destination.
:param dest_port: The port of the payload destination.
:param session_id: The Session ID the payload is to originate from. Optional.
:param dest_ip_address: The IP address or network (for broadcasts) of the payload destination.
:param dest_port: The destination port for the payload. Optional.
:param session_id: The Session ID from which the payload originates. Optional.
:return: True if the payload was successfully sent, False otherwise.
"""
return self.session_manager.receive_payload_from_software_manager(
payload=payload, dst_ip_address=dest_ip_address, dst_port=dest_port, session_id=session_id
payload=payload,
dst_ip_address=dest_ip_address,
dst_port=dest_port,
session_id=session_id,
)
def receive_payload_from_session_manager(self, payload: Any, port: Port, protocol: IPProtocol, session_id: str):

View File

@@ -41,5 +41,5 @@ class Process(Software):
:rtype: Dict
"""
state = super().describe_state()
state.update({"operating_state": self.operating_state.name})
state.update({"operating_state": self.operating_state.value})
return state

View File

@@ -1,4 +1,3 @@
from datetime import datetime
from ipaddress import IPv4Address
from typing import Any, Dict, List, Literal, Optional, Union
@@ -24,7 +23,6 @@ class DatabaseService(Service):
"""
password: Optional[str] = None
connections: Dict[str, datetime] = {}
backup_server_ip: IPv4Address = None
"""IP address of the backup server."""
@@ -58,7 +56,7 @@ class DatabaseService(Service):
def reset_component_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
_LOGGER.debug("Resetting DatabaseService original state on node {self.software_manager.node.hostname}")
self.connections.clear()
self.clear_connections()
super().reset_component_for_episode(episode)
def configure_backup(self, backup_server: IPv4Address):
@@ -151,24 +149,39 @@ class DatabaseService(Service):
return self.file_system.get_folder_by_id(self.db_file.folder_id)
def _process_connect(
self, session_id: str, password: Optional[str] = None
self, connection_id: str, password: Optional[str] = None
) -> Dict[str, Union[int, Dict[str, bool]]]:
status_code = 500 # Default internal server error
if self.operating_state == ServiceOperatingState.RUNNING:
status_code = 503 # service unavailable
if self.health_state_actual == SoftwareHealthState.OVERWHELMED:
self.sys_log.error(
f"{self.name}: Connect request for {connection_id=} declined. Service is at capacity."
)
if self.health_state_actual == SoftwareHealthState.GOOD:
if self.password == password:
status_code = 200 # ok
self.connections[session_id] = datetime.now()
self.sys_log.info(f"{self.name}: Connect request for {session_id=} authorised")
# try to create connection
if not self.add_connection(connection_id=connection_id):
status_code = 500
self.sys_log.info(f"{self.name}: Connect request for {connection_id=} declined")
else:
self.sys_log.info(f"{self.name}: Connect request for {connection_id=} authorised")
else:
status_code = 401 # Unauthorised
self.sys_log.info(f"{self.name}: Connect request for {session_id=} declined")
self.sys_log.info(f"{self.name}: Connect request for {connection_id=} declined")
else:
status_code = 404 # service not found
return {"status_code": status_code, "type": "connect_response", "response": status_code == 200}
return {
"status_code": status_code,
"type": "connect_response",
"response": status_code == 200,
"connection_id": connection_id,
}
def _process_sql(self, query: Literal["SELECT", "DELETE"], query_id: str) -> Dict[str, Union[int, List[Any]]]:
def _process_sql(
self, query: Literal["SELECT", "DELETE"], query_id: str, connection_id: Optional[str] = None
) -> Dict[str, Union[int, List[Any]]]:
"""
Executes the given SQL query and returns the result.
@@ -180,14 +193,21 @@ class DatabaseService(Service):
:return: Dictionary containing status code and data fetched.
"""
self.sys_log.info(f"{self.name}: Running {query}")
if query == "SELECT":
if self.db_file.health_status == FileSystemItemHealthStatus.GOOD:
return {"status_code": 200, "type": "sql", "data": True, "uuid": query_id}
return {
"status_code": 200,
"type": "sql",
"data": True,
"uuid": query_id,
"connection_id": connection_id,
}
else:
return {"status_code": 404, "data": False}
elif query == "DELETE":
self.db_file.health_status = FileSystemItemHealthStatus.COMPROMISED
return {"status_code": 200, "type": "sql", "data": False, "uuid": query_id}
return {"status_code": 200, "type": "sql", "data": False, "uuid": query_id, "connection_id": connection_id}
else:
# Invalid query
return {"status_code": 500, "data": False}
@@ -211,19 +231,25 @@ class DatabaseService(Service):
:param session_id: The session identifier.
:return: True if the Status Code is 200, otherwise False.
"""
if not super().receive(payload=payload, session_id=session_id, **kwargs):
result = {"status_code": 500, "data": []}
# if server service is down, return error
if not self._can_perform_action():
return False
result = {"status_code": 500, "data": []}
if isinstance(payload, dict) and payload.get("type"):
if payload["type"] == "connect_request":
result = self._process_connect(session_id=session_id, password=payload.get("password"))
result = self._process_connect(
connection_id=payload.get("connection_id"), password=payload.get("password")
)
elif payload["type"] == "disconnect":
if session_id in self.connections:
self.connections.pop(session_id)
if payload["connection_id"] in self.connections:
self.remove_connection(connection_id=payload["connection_id"])
elif payload["type"] == "sql":
if session_id in self.connections:
result = self._process_sql(query=payload["sql"], query_id=payload["uuid"])
if payload.get("connection_id") in self.connections:
result = self._process_sql(
query=payload["sql"], query_id=payload["uuid"], connection_id=payload["connection_id"]
)
else:
result = {"status_code": 401, "type": "sql"}
self.send(payload=result, session_id=session_id)

View File

@@ -20,9 +20,6 @@ class FTPClient(FTPServiceABC):
RFC 959: https://datatracker.ietf.org/doc/html/rfc959
"""
connected: bool = False
"""Keeps track of whether or not the FTP client is connected to an FTP server."""
def __init__(self, **kwargs):
kwargs["name"] = "FTPClient"
kwargs["port"] = Port.FTP
@@ -129,10 +126,7 @@ class FTPClient(FTPServiceABC):
software_manager.send_payload_to_session_manager(
payload=payload, dest_ip_address=dest_ip_address, dest_port=dest_port
)
if payload.status_code == FTPStatusCode.OK:
self.connected = False
return True
return False
return payload.status_code == FTPStatusCode.OK
def send_file(
self,
@@ -179,9 +173,9 @@ class FTPClient(FTPServiceABC):
return False
# check if FTP is currently connected to IP
self.connected = self._connect_to_server(dest_ip_address=dest_ip_address, dest_port=dest_port)
self._connect_to_server(dest_ip_address=dest_ip_address, dest_port=dest_port)
if not self.connected:
if not len(self.connections):
return False
else:
self.sys_log.info(f"Sending file {src_folder_name}/{src_file_name} to {str(dest_ip_address)}")
@@ -230,9 +224,9 @@ class FTPClient(FTPServiceABC):
:type: dest_port: Optional[Port]
"""
# check if FTP is currently connected to IP
self.connected = self._connect_to_server(dest_ip_address=dest_ip_address, dest_port=dest_port)
self._connect_to_server(dest_ip_address=dest_ip_address, dest_port=dest_port)
if not self.connected:
if not len(self.connections):
return False
else:
# send retrieve request
@@ -286,6 +280,14 @@ class FTPClient(FTPServiceABC):
self.sys_log.error(f"FTP Server could not be found - Error Code: {FTPStatusCode.NOT_FOUND.value}")
return False
# if PORT succeeded, add the connection as an active connection list
if payload.ftp_command is FTPCommand.PORT and payload.status_code is FTPStatusCode.OK:
self.add_connection(connection_id=session_id, session_id=session_id)
# if QUIT succeeded, remove the session from active connection list
if payload.ftp_command is FTPCommand.QUIT and payload.status_code is FTPStatusCode.OK:
self.remove_connection(connection_id=session_id)
self.sys_log.info(f"{self.name}: Received FTP Response {payload.ftp_command.name} {payload.status_code.value}")
self._process_ftp_command(payload=payload, session_id=session_id)

View File

@@ -1,5 +1,4 @@
from ipaddress import IPv4Address
from typing import Any, Dict, Optional
from typing import Any, Optional
from primaite import getLogger
from primaite.simulator.network.protocols.ftp import FTPCommand, FTPPacket, FTPStatusCode
@@ -21,9 +20,6 @@ class FTPServer(FTPServiceABC):
server_password: Optional[str] = None
"""Password needed to connect to FTP server. Default is None."""
connections: Dict[str, IPv4Address] = {}
"""Current active connections to the FTP server."""
def __init__(self, **kwargs):
kwargs["name"] = "FTPServer"
kwargs["port"] = Port.FTP
@@ -41,7 +37,7 @@ class FTPServer(FTPServiceABC):
def reset_component_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
_LOGGER.debug(f"Resetting FTPServer state on node {self.software_manager.node.hostname}")
self.connections.clear()
self.clear_connections()
super().reset_component_for_episode(episode)
def _process_ftp_command(self, payload: FTPPacket, session_id: Optional[str] = None, **kwargs) -> FTPPacket:
@@ -62,9 +58,6 @@ class FTPServer(FTPServiceABC):
self.sys_log.info(f"{self.name}: Received FTP {payload.ftp_command.name} {payload.ftp_command_args}")
if session_id:
session_details = self._get_session_details(session_id)
if payload.ftp_command is not None:
self.sys_log.info(f"Received FTP {payload.ftp_command.name} command.")
@@ -73,7 +66,7 @@ class FTPServer(FTPServiceABC):
# check that the port is valid
if isinstance(payload.ftp_command_args, Port) and payload.ftp_command_args.value in range(0, 65535):
# return successful connection
self.connections[session_id] = session_details.with_ip_address
self.add_connection(connection_id=session_id, session_id=session_id)
payload.status_code = FTPStatusCode.OK
return payload
@@ -81,7 +74,7 @@ class FTPServer(FTPServiceABC):
return payload
if payload.ftp_command == FTPCommand.QUIT:
self.connections.pop(session_id)
self.remove_connection(connection_id=session_id)
payload.status_code = FTPStatusCode.OK
return payload

View File

@@ -0,0 +1,132 @@
from datetime import datetime
from ipaddress import IPv4Address
from typing import Dict, Optional
from primaite import getLogger
from primaite.simulator.network.protocols.ntp import NTPPacket
from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.services.service import Service, ServiceOperatingState
_LOGGER = getLogger(__name__)
class NTPClient(Service):
"""Represents a NTP client as a service."""
ntp_server: Optional[IPv4Address] = None
"The NTP server the client sends requests to."
time: Optional[datetime] = None
def __init__(self, **kwargs):
kwargs["name"] = "NTPClient"
kwargs["port"] = Port.NTP
kwargs["protocol"] = IPProtocol.TCP
super().__init__(**kwargs)
self.start()
def configure(self, ntp_server_ip_address: IPv4Address) -> None:
"""
Set the IP address for the NTP server.
:param ntp_server_ip_address: IPv4 address of NTP server.
:param ntp_client_ip_Address: IPv4 address of NTP client.
"""
self.ntp_server = ntp_server_ip_address
self.sys_log.info(f"{self.name}: ntp_server: {self.ntp_server}")
def describe_state(self) -> Dict:
"""
Describes the current state of the software.
The specifics of the software's state, including its health, criticality,
and any other pertinent information, should be implemented in subclasses.
:return: A dictionary containing key-value pairs representing the current state
of the software.
:rtype: Dict
"""
state = super().describe_state()
return state
def reset_component_for_episode(self, episode: int):
"""
Resets the Service component for a new episode.
This method ensures the Service is ready for a new episode, including resetting any
stateful properties or statistics, and clearing any message queues.
"""
pass
def send(
self,
payload: NTPPacket,
session_id: Optional[str] = None,
dest_ip_address: IPv4Address = None,
dest_port: [Port] = Port.NTP,
**kwargs,
) -> bool:
"""Requests NTP data from NTP server.
:param payload: The payload to be sent.
:param session_id: The Session ID the payload is to originate from. Optional.
:param dest_ip_address: The ip address of the payload destination.
:param dest_port: The port of the payload destination.
:return: True if successful, False otherwise.
"""
return super().send(
payload=payload,
dest_ip_address=dest_ip_address,
dest_port=dest_port,
session_id=session_id,
**kwargs,
)
def receive(
self,
payload: NTPPacket,
session_id: Optional[str] = None,
**kwargs,
) -> bool:
"""Receives time data from server.
:param payload: The payload to be sent.
:param session_id: The Session ID the payload is to originate from. Optional.
:return: True if successful, False otherwise.
"""
if not isinstance(payload, NTPPacket):
_LOGGER.debug(f"{payload} is not a NTPPacket")
return False
if payload.ntp_reply.ntp_datetime:
self.sys_log.info(
f"{self.name}: \
Received time update from NTP server{payload.ntp_reply.ntp_datetime}"
)
self.time = payload.ntp_reply.ntp_datetime
return True
def request_time(self) -> None:
"""Send request to ntp_server."""
ntp_server_packet = NTPPacket()
self.send(payload=ntp_server_packet, dest_ip_address=self.ntp_server)
def apply_timestep(self, timestep: int) -> None:
"""
For each timestep request the time from the NTP server.
In this instance, if any multi-timestep processes are currently
occurring (such as restarting or installation), then they are brought one step closer to
being finished.
:param timestep: The current timestep number. (Amount of time since simulation episode began)
:type timestep: int
"""
self.sys_log.info(f"{self.name} apply_timestep")
super().apply_timestep(timestep)
if self.operating_state == ServiceOperatingState.RUNNING:
# request time from server
self.request_time()
else:
self.sys_log.debug(f"{self.name} ntp client not running")

View File

@@ -0,0 +1,73 @@
from datetime import datetime
from typing import Dict, Optional
from primaite import getLogger
from primaite.simulator.network.protocols.ntp import NTPPacket
from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.services.service import Service
_LOGGER = getLogger(__name__)
class NTPServer(Service):
"""Represents a NTP server as a service."""
def __init__(self, **kwargs):
kwargs["name"] = "NTPServer"
kwargs["port"] = Port.NTP
kwargs["protocol"] = IPProtocol.TCP
super().__init__(**kwargs)
self.start()
def describe_state(self) -> Dict:
"""
Describes the current state of the software.
The specifics of the software's state, including its health, criticality,
and any other pertinent information, should be implemented in subclasses.
:return: A dictionary containing key-value pairs representing the current
state of the software.
:rtype: Dict
"""
state = super().describe_state()
return state
def reset_component_for_episode(self, episode: int):
"""
Resets the Service component for a new episode.
This method ensures the Service is ready for a new episode, including
resetting any stateful properties or statistics, and clearing any message
queues.
"""
pass
def receive(
self,
payload: NTPPacket,
session_id: Optional[str] = None,
**kwargs,
) -> bool:
"""
Receives a request from NTPClient.
Check that request has a valid IP address.
:param payload: The payload to send.
:param session_id: Id of the session (Optional).
:return: True if valid NTP request else False.
"""
if not (isinstance(payload, NTPPacket)):
_LOGGER.debug(f"{payload} is not a NTPPacket")
return False
payload: NTPPacket = payload
# generate a reply with the current time
time = datetime.now()
payload = payload.generate_reply(time)
# send reply
self.send(payload, session_id)
return True

View File

@@ -41,6 +41,9 @@ class Service(IOSoftware):
restart_countdown: Optional[int] = None
"If currently restarting, how many timesteps remain until the restart is finished."
def __init__(self, **kwargs):
super().__init__(**kwargs)
def _can_perform_action(self) -> bool:
"""
Checks if the service can perform actions.
@@ -53,7 +56,7 @@ class Service(IOSoftware):
if not super()._can_perform_action():
return False
if self.operating_state is not self.operating_state.RUNNING:
if self.operating_state is not ServiceOperatingState.RUNNING:
# service is not running
_LOGGER.error(f"Cannot perform action: {self.name} is {self.operating_state.name}")
return False
@@ -75,9 +78,6 @@ class Service(IOSoftware):
"""
return super().receive(payload=payload, session_id=session_id, **kwargs)
def __init__(self, **kwargs):
super().__init__(**kwargs)
def set_original_state(self):
"""Sets the original state."""
super().set_original_state()

View File

@@ -1,7 +1,9 @@
import copy
from abc import abstractmethod
from datetime import datetime
from enum import Enum
from ipaddress import IPv4Address
from typing import Any, Dict, Optional
from ipaddress import IPv4Address, IPv4Network
from typing import Any, Dict, Optional, Union
from primaite.simulator.core import _LOGGER, RequestManager, RequestType, SimComponent
from primaite.simulator.file_system.file_system import FileSystem, Folder
@@ -232,7 +234,7 @@ class IOSoftware(Software):
installing_count: int = 0
"The number of times the software has been installed. Default is 0."
max_sessions: int = 1
max_sessions: int = 100
"The maximum number of sessions that the software can handle simultaneously. Default is 0."
tcp: bool = True
"Indicates if the software uses TCP protocol for communication. Default is True."
@@ -240,6 +242,8 @@ class IOSoftware(Software):
"Indicates if the software uses UDP protocol for communication. Default is True."
port: Port
"The port to which the software is connected."
_connections: Dict[str, Dict] = {}
"Active connections."
def set_original_state(self):
"""Sets the original state."""
@@ -284,23 +288,85 @@ class IOSoftware(Software):
return False
return True
@property
def connections(self) -> Dict[str, Dict]:
"""Return the public version of connections."""
return copy.copy(self._connections)
def add_connection(self, connection_id: str, session_id: Optional[str] = None) -> bool:
"""
Create a new connection to this service.
Returns true if connection successfully created
:param: connection_id: UUID of the connection to create
:type: string
"""
# if over or at capacity, set to overwhelmed
if len(self._connections) >= self.max_sessions:
self.set_health_state(SoftwareHealthState.OVERWHELMED)
self.sys_log.error(f"{self.name}: Connect request for {connection_id=} declined. Service is at capacity.")
return False
else:
# if service was previously overwhelmed, set to good because there is enough space for connections
if self.health_state_actual == SoftwareHealthState.OVERWHELMED:
self.set_health_state(SoftwareHealthState.GOOD)
# check that connection already doesn't exist
if not self._connections.get(connection_id):
session_details = None
if session_id:
session_details = self._get_session_details(session_id)
self._connections[connection_id] = {
"ip_address": session_details.with_ip_address if session_details else None,
"time": datetime.now(),
}
self.sys_log.info(f"{self.name}: Connect request for {connection_id=} authorised")
return True
# connection with given id already exists
self.sys_log.error(
f"{self.name}: Connect request for {connection_id=} declined. Connection already exists."
)
return False
def remove_connection(self, connection_id: str) -> bool:
"""
Remove a connection from this service.
Returns true if connection successfully removed
:param: connection_id: UUID of the connection to create
:type: string
"""
if self.connections.get(connection_id):
self._connections.pop(connection_id)
self.sys_log.info(f"{self.name}: Connection {connection_id=} closed.")
return True
def clear_connections(self):
"""Clears all the connections from the software."""
self._connections = {}
def send(
self,
payload: Any,
session_id: Optional[str] = None,
dest_ip_address: Optional[IPv4Address] = None,
dest_ip_address: Optional[Union[IPv4Address, IPv4Network]] = None,
dest_port: Optional[Port] = None,
**kwargs,
) -> bool:
"""
Sends a payload to the SessionManager.
Sends a payload to the SessionManager for network transmission.
This method is responsible for initiating the process of sending network payloads. It supports both
unicast and Layer 3 broadcast transmissions. For broadcasts, the destination IP should be specified
as an IPv4Network. It delegates the actual sending process to the SoftwareManager.
:param payload: The payload to be sent.
:param dest_ip_address: The ip address of the payload destination.
:param dest_port: The port of the payload destination.
:param session_id: The Session ID the payload is to originate from. Optional.
:return: True if successful, False otherwise.
:param dest_ip_address: The IP address or network (for broadcasts) of the payload destination.
:param dest_port: The destination port for the payload. Optional.
:param session_id: The Session ID from which the payload originates. Optional.
:return: True if the payload was successfully sent, False otherwise.
"""
if not self._can_perform_action():
return False

View File

@@ -93,25 +93,25 @@ agents:
num_files_per_folder: 1
num_nics_per_node: 2
nodes:
- node_ref: domain_controller
- node_hostname: domain_controller
services:
- service_ref: domain_controller_dns_server
- node_ref: web_server
- service_name: domain_controller_dns_server
- node_hostname: web_server
services:
- service_ref: web_server_database_client
- node_ref: database_server
- service_name: web_server_database_client
- node_hostname: database_server
services:
- service_ref: database_service
- service_name: database_service
folders:
- folder_name: database
files:
- file_name: database.db
- node_ref: backup_server
- node_hostname: backup_server
# services:
# - service_ref: backup_service
- node_ref: security_suite
- node_ref: client_1
- node_ref: client_2
- node_hostname: security_suite
- node_hostname: client_1
- node_hostname: client_2
links:
- link_ref: router_1___switch_1
- link_ref: router_1___switch_2
@@ -126,7 +126,7 @@ agents:
acl:
options:
max_acl_rules: 10
router_node_ref: router_1
router_hostname: router_1
ip_address_order:
- node_ref: domain_controller
nic_num: 1
@@ -514,7 +514,7 @@ agents:
- type: DATABASE_FILE_INTEGRITY
weight: 0.5
options:
node_ref: database_server
node_hostname: database_server
folder_name: database
file_name: database.db
@@ -522,8 +522,8 @@ agents:
- type: WEB_SERVER_404_PENALTY
weight: 0.5
options:
node_ref: web_server
service_ref: web_server_web_service
node_hostname: web_server
service_name: web_server_web_service
agent_settings:

View File

@@ -31,13 +31,6 @@ agents:
action_space:
action_list:
- type: DONOTHING
# <not yet implemented>
# - type: NODE_LOGON
# - type: NODE_LOGOFF
# - type: NODE_APPLICATION_EXECUTE
# options:
# execution_definition:
# target_address: arcd.com
options:
nodes:
@@ -104,25 +97,25 @@ agents:
num_files_per_folder: 1
num_nics_per_node: 2
nodes:
- node_ref: domain_controller
- node_hostname: domain_controller
services:
- service_ref: domain_controller_dns_server
- node_ref: web_server
- service_name: domain_controller_dns_server
- node_hostname: web_server
services:
- service_ref: web_server_database_client
- node_ref: database_server
- service_name: web_server_database_client
- node_hostname: database_server
services:
- service_ref: database_service
- service_name: database_service
folders:
- folder_name: database
files:
- file_name: database.db
- node_ref: backup_server
- node_hostname: backup_server
# services:
# - service_ref: backup_service
- node_ref: security_suite
- node_ref: client_1
- node_ref: client_2
- node_hostname: security_suite
- node_hostname: client_1
- node_hostname: client_2
links:
- link_ref: router_1___switch_1
- link_ref: router_1___switch_2
@@ -137,7 +130,7 @@ agents:
acl:
options:
max_acl_rules: 10
router_node_ref: router_1
router_hostname: router_1
ip_address_order:
- node_ref: domain_controller
nic_num: 1
@@ -525,7 +518,7 @@ agents:
- type: DATABASE_FILE_INTEGRITY
weight: 0.5
options:
node_ref: database_server
node_hostname: database_server
folder_name: database
file_name: database.db
@@ -533,8 +526,8 @@ agents:
- type: WEB_SERVER_404_PENALTY
weight: 0.5
options:
node_ref: web_server
service_ref: web_server_web_service
node_hostname: web_server
service_name: web_server_web_service
agent_settings:

View File

@@ -37,13 +37,6 @@ agents:
action_space:
action_list:
- type: DONOTHING
# <not yet implemented>
# - type: NODE_LOGON
# - type: NODE_LOGOFF
# - type: NODE_APPLICATION_EXECUTE
# options:
# execution_definition:
# target_address: arcd.com
options:
nodes:
@@ -111,25 +104,25 @@ agents:
num_files_per_folder: 1
num_nics_per_node: 2
nodes:
- node_ref: domain_controller
- node_hostname: domain_controller
services:
- service_ref: domain_controller_dns_server
- node_ref: web_server
- service_name: domain_controller_dns_server
- node_hostname: web_server
services:
- service_ref: web_server_database_client
- node_ref: database_server
- service_name: web_server_database_client
- node_hostname: database_server
services:
- service_ref: database_service
- service_name: database_service
folders:
- folder_name: database
files:
- file_name: database.db
- node_ref: backup_server
- node_hostname: backup_server
# services:
# - service_ref: backup_service
- node_ref: security_suite
- node_ref: client_1
- node_ref: client_2
- node_hostname: security_suite
- node_hostname: client_1
- node_hostname: client_2
links:
- link_ref: router_1___switch_1
- link_ref: router_1___switch_2
@@ -144,7 +137,7 @@ agents:
acl:
options:
max_acl_rules: 10
router_node_ref: router_1
router_hostname: router_1
ip_address_order:
- node_ref: domain_controller
nic_num: 1
@@ -532,7 +525,7 @@ agents:
- type: DATABASE_FILE_INTEGRITY
weight: 0.5
options:
node_ref: database_server
node_hostname: database_server
folder_name: database
file_name: database.db
@@ -540,8 +533,8 @@ agents:
- type: WEB_SERVER_404_PENALTY
weight: 0.5
options:
node_ref: web_server
service_ref: web_server_web_service
node_hostname: web_server
service_name: web_server_web_service
agent_settings:
@@ -559,25 +552,25 @@ agents:
num_files_per_folder: 1
num_nics_per_node: 2
nodes:
- node_ref: domain_controller
- node_hostname: domain_controller
services:
- service_ref: domain_controller_dns_server
- node_ref: web_server
- service_name: domain_controller_dns_server
- node_hostname: web_server
services:
- service_ref: web_server_database_client
- node_ref: database_server
- service_name: web_server_database_client
- node_hostname: database_server
services:
- service_ref: database_service
- service_name: database_service
folders:
- folder_name: database
files:
- file_name: database.db
- node_ref: backup_server
- node_hostname: backup_server
# services:
# - service_ref: backup_service
- node_ref: security_suite
- node_ref: client_1
- node_ref: client_2
- node_hostname: security_suite
- node_hostname: client_1
- node_hostname: client_2
links:
- link_ref: router_1___switch_1
- link_ref: router_1___switch_2
@@ -592,7 +585,7 @@ agents:
acl:
options:
max_acl_rules: 10
router_node_ref: router_1
router_hostname: router_1
ip_address_order:
- node_ref: domain_controller
nic_num: 1
@@ -980,7 +973,7 @@ agents:
- type: DATABASE_FILE_INTEGRITY
weight: 0.5
options:
node_ref: database_server
node_hostname: database_server
folder_name: database
file_name: database.db
@@ -988,8 +981,8 @@ agents:
- type: WEB_SERVER_404_PENALTY
weight: 0.5
options:
node_ref: web_server
service_ref: web_server_web_service
node_hostname: web_server
service_name: web_server_web_service
agent_settings:

View File

@@ -35,13 +35,6 @@ agents:
action_space:
action_list:
- type: DONOTHING
# <not yet implemented>
# - type: NODE_LOGON
# - type: NODE_LOGOFF
# - type: NODE_APPLICATION_EXECUTE
# options:
# execution_definition:
# target_address: arcd.com
options:
nodes:
@@ -109,25 +102,25 @@ agents:
num_files_per_folder: 1
num_nics_per_node: 2
nodes:
- node_ref: domain_controller
- node_hostname: domain_controller
services:
- service_ref: domain_controller_dns_server
- node_ref: web_server
- service_name: domain_controller_dns_server
- node_hostname: web_server
services:
- service_ref: web_server_database_client
- node_ref: database_server
- service_name: web_server_database_client
- node_hostname: database_server
services:
- service_ref: database_service
- service_name: database_service
folders:
- folder_name: database
files:
- file_name: database.db
- node_ref: backup_server
- node_hostname: backup_server
# services:
# - service_ref: backup_service
- node_ref: security_suite
- node_ref: client_1
- node_ref: client_2
# - service_name: backup_service
- node_hostname: security_suite
- node_hostname: client_1
- node_hostname: client_2
links:
- link_ref: router_1___switch_1
- link_ref: router_1___switch_2
@@ -142,7 +135,7 @@ agents:
acl:
options:
max_acl_rules: 10
router_node_ref: router_1
router_hostname: router_1
ip_address_order:
- node_ref: domain_controller
nic_num: 1
@@ -530,7 +523,7 @@ agents:
- type: DATABASE_FILE_INTEGRITY
weight: 0.5
options:
node_ref: database_server
node_hostname: database_server
folder_name: database
file_name: database.db
@@ -538,8 +531,8 @@ agents:
- type: WEB_SERVER_404_PENALTY
weight: 0.5
options:
node_ref: web_server
service_ref: web_server_web_service
node_hostname: web_server
service_name: web_server_web_service
agent_settings:

View File

@@ -105,25 +105,23 @@ agents:
num_files_per_folder: 1
num_nics_per_node: 2
nodes:
- node_ref: domain_controller
- node_hostname: domain_controller
services:
- service_ref: domain_controller_dns_server
- node_ref: web_server
- service_name: domain_controller_dns_server
- node_hostname: web_server
services:
- service_ref: web_server_database_client
- node_ref: database_server
- service_name: web_server_database_client
- node_hostname: database_server
services:
- service_ref: database_service
- service_name: database_service
folders:
- folder_name: database
files:
- file_name: database.db
- node_ref: backup_server
# services:
# - service_ref: backup_service
- node_ref: security_suite
- node_ref: client_1
- node_ref: client_2
- node_hostname: backup_server
- node_hostname: security_suite
- node_hostname: client_1
- node_hostname: client_2
links:
- link_ref: router_1___switch_1
- link_ref: router_1___switch_2
@@ -138,7 +136,7 @@ agents:
acl:
options:
max_acl_rules: 10
router_node_ref: router_1
router_hostname: router_1
ip_address_order:
- node_ref: domain_controller
nic_num: 1
@@ -526,7 +524,7 @@ agents:
- type: DATABASE_FILE_INTEGRITY
weight: 0.5
options:
node_ref: database_server
node_hostname: database_server
folder_name: database
file_name: database.db
@@ -534,8 +532,8 @@ agents:
- type: WEB_SERVER_404_PENALTY
weight: 0.5
options:
node_ref: web_server
service_ref: web_server_web_service
node_hostname: web_server
service_name: web_service
agent_settings:

View File

@@ -170,7 +170,7 @@ def example_network() -> Network:
-------------- --------------
| client_1 |----- ----| server_1 |
-------------- | -------------- -------------- -------------- | --------------
------| switch_1 |------| router_1 |------| switch_2 |------
------| switch_2 |------| router_1 |------| switch_1 |------
-------------- | -------------- -------------- -------------- | --------------
| client_2 |---- ----| server_2 |
-------------- --------------

View File

@@ -1,8 +1,8 @@
from primaite.simulator.network.hardware.nodes.computer import Computer
from primaite.simulator.network.hardware.nodes.server import Server
from primaite.simulator.system.applications.database_client import DatabaseClient
from primaite.simulator.system.applications.red_applications.data_manipulation_bot import DataManipulationBot
from primaite.simulator.system.services.database.database_service import DatabaseService
from primaite.simulator.system.services.red_services.data_manipulation_bot import DataManipulationBot
def test_data_manipulation(uc2_network):

View File

@@ -14,7 +14,7 @@ def test_file_observation():
state = sim.describe_state()
dog_file_obs = FileObservation(
where=["network", "nodes", pc.uuid, "file_system", "folders", "root", "files", "dog.png"]
where=["network", "nodes", pc.hostname, "file_system", "folders", "root", "files", "dog.png"]
)
assert dog_file_obs.observe(state) == {"health_status": 1}
assert dog_file_obs.space == spaces.Dict({"health_status": spaces.Discrete(6)})

View File

@@ -0,0 +1,180 @@
from ipaddress import IPv4Address, IPv4Network
from typing import Any, Dict, List, Tuple
import pytest
from primaite.simulator.network.container import Network
from primaite.simulator.network.hardware.nodes.computer import Computer
from primaite.simulator.network.hardware.nodes.server import Server
from primaite.simulator.network.hardware.nodes.switch import Switch
from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.applications.application import Application
from primaite.simulator.system.services.service import Service
class BroadcastService(Service):
"""A service for sending broadcast and unicast messages over a network."""
def __init__(self, **kwargs):
# Set default service properties for broadcasting
kwargs["name"] = "BroadcastService"
kwargs["port"] = Port.HTTP
kwargs["protocol"] = IPProtocol.TCP
super().__init__(**kwargs)
def describe_state(self) -> Dict:
# Implement state description for the service
pass
def unicast(self, ip_address: IPv4Address):
# Send a unicast payload to a specific IP address
super().send(
payload="unicast",
dest_ip_address=ip_address,
dest_port=Port.HTTP,
)
def broadcast(self, ip_network: IPv4Network):
# Send a broadcast payload to an entire IP network
super().send(
payload="broadcast",
dest_ip_address=ip_network,
dest_port=Port.HTTP,
)
class BroadcastClient(Application):
"""A client application to receive broadcast and unicast messages."""
payloads_received: List = []
def __init__(self, **kwargs):
# Set default client properties
kwargs["name"] = "BroadcastClient"
kwargs["port"] = Port.HTTP
kwargs["protocol"] = IPProtocol.TCP
super().__init__(**kwargs)
def describe_state(self) -> Dict:
# Implement state description for the application
pass
def receive(self, payload: Any, session_id: str, **kwargs) -> bool:
# Append received payloads to the list and print a message
self.payloads_received.append(payload)
print(f"Payload: {payload} received on node {self.sys_log.hostname}")
@pytest.fixture(scope="function")
def broadcast_network() -> Network:
network = Network()
client_1 = Computer(
hostname="client_1",
ip_address="192.168.1.2",
subnet_mask="255.255.255.0",
default_gateway="192.168.1.1",
start_up_duration=0,
)
client_1.power_on()
client_1.software_manager.install(BroadcastClient)
application_1 = client_1.software_manager.software["BroadcastClient"]
application_1.run()
client_2 = Computer(
hostname="client_2",
ip_address="192.168.1.3",
subnet_mask="255.255.255.0",
default_gateway="192.168.1.1",
start_up_duration=0,
)
client_2.power_on()
client_2.software_manager.install(BroadcastClient)
application_2 = client_2.software_manager.software["BroadcastClient"]
application_2.run()
server_1 = Server(
hostname="server_1",
ip_address="192.168.1.1",
subnet_mask="255.255.255.0",
default_gateway="192.168.1.1",
start_up_duration=0,
)
server_1.power_on()
server_1.software_manager.install(BroadcastService)
service: BroadcastService = server_1.software_manager.software["BroadcastService"]
service.start()
switch_1 = Switch(hostname="switch_1", num_ports=6, start_up_duration=0)
switch_1.power_on()
network.connect(endpoint_a=client_1.ethernet_port[1], endpoint_b=switch_1.switch_ports[1])
network.connect(endpoint_a=client_2.ethernet_port[1], endpoint_b=switch_1.switch_ports[2])
network.connect(endpoint_a=server_1.ethernet_port[1], endpoint_b=switch_1.switch_ports[3])
return network
@pytest.fixture(scope="function")
def broadcast_service_and_clients(broadcast_network) -> Tuple[BroadcastService, BroadcastClient, BroadcastClient]:
client_1: BroadcastClient = broadcast_network.get_node_by_hostname("client_1").software_manager.software[
"BroadcastClient"
]
client_2: BroadcastClient = broadcast_network.get_node_by_hostname("client_2").software_manager.software[
"BroadcastClient"
]
service: BroadcastService = broadcast_network.get_node_by_hostname("server_1").software_manager.software[
"BroadcastService"
]
return service, client_1, client_2
def test_broadcast_correct_subnet(broadcast_service_and_clients):
service, client_1, client_2 = broadcast_service_and_clients
assert not client_1.payloads_received
assert not client_2.payloads_received
service.broadcast(IPv4Network("192.168.1.0/24"))
assert client_1.payloads_received == ["broadcast"]
assert client_2.payloads_received == ["broadcast"]
def test_broadcast_incorrect_subnet(broadcast_service_and_clients):
service, client_1, client_2 = broadcast_service_and_clients
assert not client_1.payloads_received
assert not client_2.payloads_received
service.broadcast(IPv4Network("192.168.2.0/24"))
assert not client_1.payloads_received
assert not client_2.payloads_received
def test_unicast_correct_address(broadcast_service_and_clients):
service, client_1, client_2 = broadcast_service_and_clients
assert not client_1.payloads_received
assert not client_2.payloads_received
service.unicast(IPv4Address("192.168.1.2"))
assert client_1.payloads_received == ["unicast"]
assert not client_2.payloads_received
def test_unicast_incorrect_address(broadcast_service_and_clients):
service, client_1, client_2 = broadcast_service_and_clients
assert not client_1.payloads_received
assert not client_2.payloads_received
service.unicast(IPv4Address("192.168.2.2"))
assert not client_1.payloads_received
assert not client_2.payloads_received

View File

@@ -1,11 +1,16 @@
from ipaddress import IPv4Address
from typing import Tuple
import pytest
from primaite.simulator.network.container import Network
from primaite.simulator.network.hardware.base import Link, NIC, Node, NodeOperatingState
from primaite.simulator.network.hardware.nodes.computer import Computer
from primaite.simulator.network.hardware.nodes.router import ACLAction, Router
from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.services.ntp.ntp_client import NTPClient
from primaite.simulator.system.services.ntp.ntp_server import NTPServer
@pytest.fixture(scope="function")
@@ -34,6 +39,69 @@ def pc_a_pc_b_router_1() -> Tuple[Node, Node, Router]:
return pc_a, pc_b, router_1
@pytest.fixture(scope="function")
def multi_hop_network() -> Network:
network = Network()
# Configure PC A
pc_a = Computer(
hostname="pc_a",
ip_address="192.168.0.2",
subnet_mask="255.255.255.0",
default_gateway="192.168.0.1",
start_up_duration=0,
)
pc_a.power_on()
network.add_node(pc_a)
# Configure Router 1
router_1 = Router(hostname="router_1", start_up_duration=0)
router_1.power_on()
network.add_node(router_1)
# Configure the connection between PC A and Router 1 port 2
router_1.configure_port(2, "192.168.0.1", "255.255.255.0")
network.connect(pc_a.ethernet_port[1], router_1.ethernet_ports[2])
router_1.enable_port(2)
# Configure Router 1 ACLs
router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.ARP, dst_port=Port.ARP, position=22)
router_1.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol.ICMP, position=23)
# Configure PC B
pc_b = Computer(
hostname="pc_b",
ip_address="192.168.2.2",
subnet_mask="255.255.255.0",
default_gateway="192.168.2.1",
start_up_duration=0,
)
pc_b.power_on()
network.add_node(pc_b)
# Configure Router 2
router_2 = Router(hostname="router_2", start_up_duration=0)
router_2.power_on()
network.add_node(router_2)
# Configure the connection between PC B and Router 2 port 2
router_2.configure_port(2, "192.168.2.1", "255.255.255.0")
network.connect(pc_b.ethernet_port[1], router_2.ethernet_ports[2])
router_2.enable_port(2)
# Configure Router 2 ACLs
router_2.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.ARP, dst_port=Port.ARP, position=22)
router_2.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol.ICMP, position=23)
# Configure the connection between Router 1 port 1 and Router 2 port 1
router_2.configure_port(1, "192.168.1.2", "255.255.255.252")
router_1.configure_port(1, "192.168.1.1", "255.255.255.252")
network.connect(router_1.ethernet_ports[1], router_2.ethernet_ports[1])
router_1.enable_port(1)
router_2.enable_port(1)
return network
def test_ping_default_gateway(pc_a_pc_b_router_1):
pc_a, pc_b, router_1 = pc_a_pc_b_router_1
@@ -50,3 +118,68 @@ def test_host_on_other_subnet(pc_a_pc_b_router_1):
pc_a, pc_b, router_1 = pc_a_pc_b_router_1
assert pc_a.ping("192.168.1.10")
def test_no_route_no_ping(multi_hop_network):
pc_a = multi_hop_network.get_node_by_hostname("pc_a")
pc_b = multi_hop_network.get_node_by_hostname("pc_b")
assert not pc_a.ping(pc_b.ethernet_port[1].ip_address)
def test_with_routes_can_ping(multi_hop_network):
pc_a = multi_hop_network.get_node_by_hostname("pc_a")
pc_b = multi_hop_network.get_node_by_hostname("pc_b")
router_1: Router = multi_hop_network.get_node_by_hostname("router_1") # noqa
router_2: Router = multi_hop_network.get_node_by_hostname("router_2") # noqa
# Configure Route from Router 1 to PC B subnet
router_1.route_table.add_route(
address="192.168.2.0", subnet_mask="255.255.255.0", next_hop_ip_address="192.168.1.2"
)
# Configure Route from Router 2 to PC A subnet
router_2.route_table.add_route(
address="192.168.0.2", subnet_mask="255.255.255.0", next_hop_ip_address="192.168.1.1"
)
assert pc_a.ping(pc_b.ethernet_port[1].ip_address)
def test_routing_services(multi_hop_network):
pc_a = multi_hop_network.get_node_by_hostname("pc_a")
pc_b = multi_hop_network.get_node_by_hostname("pc_b")
pc_a.software_manager.install(NTPClient)
ntp_client = pc_a.software_manager.software["NTPClient"]
ntp_client.start()
pc_b.software_manager.install(NTPServer)
pc_b.software_manager.software["NTPServer"].start()
ntp_client.configure(ntp_server_ip_address=pc_b.ethernet_port[1].ip_address)
router_1: Router = multi_hop_network.get_node_by_hostname("router_1") # noqa
router_2: Router = multi_hop_network.get_node_by_hostname("router_2") # noqa
router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.NTP, dst_port=Port.NTP, position=21)
router_2.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.NTP, dst_port=Port.NTP, position=21)
assert ntp_client.time is None
ntp_client.request_time()
assert ntp_client.time is None
# Configure Route from Router 1 to PC B subnet
router_1.route_table.add_route(
address="192.168.2.0", subnet_mask="255.255.255.0", next_hop_ip_address="192.168.1.2"
)
# Configure Route from Router 2 to PC A subnet
router_2.route_table.add_route(
address="192.168.0.2", subnet_mask="255.255.255.0", next_hop_ip_address="192.168.1.1"
)
ntp_client.request_time()
assert ntp_client.time is not None

View File

@@ -0,0 +1,180 @@
from ipaddress import IPv4Address
from typing import Tuple
import pytest
from primaite.simulator.network.container import Network
from primaite.simulator.network.hardware.nodes.computer import Computer
from primaite.simulator.network.hardware.nodes.router import ACLAction, Router
from primaite.simulator.network.hardware.nodes.server import Server
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.applications.application import ApplicationOperatingState
from primaite.simulator.system.applications.database_client import DatabaseClient
from primaite.simulator.system.applications.red_applications.dos_bot import DoSAttackStage, DoSBot
from primaite.simulator.system.services.database.database_service import DatabaseService
from primaite.simulator.system.software import SoftwareHealthState
@pytest.fixture(scope="function")
def dos_bot_and_db_server(client_server) -> Tuple[DoSBot, Computer, DatabaseService, Server]:
computer, server = client_server
# Install DoSBot on computer
computer.software_manager.install(DoSBot)
dos_bot: DoSBot = computer.software_manager.software.get("DoSBot")
dos_bot.configure(
target_ip_address=IPv4Address(server.nics.get(next(iter(server.nics))).ip_address),
target_port=Port.POSTGRES_SERVER,
)
# Install DB Server service on server
server.software_manager.install(DatabaseService)
db_server_service: DatabaseService = server.software_manager.software.get("DatabaseService")
db_server_service.start()
return dos_bot, computer, db_server_service, server
@pytest.fixture(scope="function")
def dos_bot_db_server_green_client(example_network) -> Network:
network: Network = example_network
router_1: Router = example_network.get_node_by_hostname("router_1")
router_1.acl.add_rule(
action=ACLAction.PERMIT, src_port=Port.POSTGRES_SERVER, dst_port=Port.POSTGRES_SERVER, position=0
)
client_1: Computer = network.get_node_by_hostname("client_1")
client_2: Computer = network.get_node_by_hostname("client_2")
server: Server = network.get_node_by_hostname("server_1")
# install DoS bot on client 1
client_1.software_manager.install(DoSBot)
dos_bot: DoSBot = client_1.software_manager.software.get("DoSBot")
dos_bot.configure(
target_ip_address=IPv4Address(server.nics.get(next(iter(server.nics))).ip_address),
target_port=Port.POSTGRES_SERVER,
)
# install db server service on server
server.software_manager.install(DatabaseService)
db_server_service: DatabaseService = server.software_manager.software.get("DatabaseService")
db_server_service.start()
# Install DB client (green) on client 2
client_2.software_manager.install(DatabaseClient)
database_client: DatabaseClient = client_2.software_manager.software.get("DatabaseClient")
database_client.configure(server_ip_address=IPv4Address("192.168.0.1"))
database_client.run()
return network
def test_repeating_dos_attack(dos_bot_and_db_server):
dos_bot, computer, db_server_service, server = dos_bot_and_db_server
assert db_server_service.health_state_actual is SoftwareHealthState.GOOD
dos_bot.port_scan_p_of_success = 1
dos_bot.repeat = True
dos_bot.run()
assert len(dos_bot.connections) == db_server_service.max_sessions
assert len(db_server_service.connections) == db_server_service.max_sessions
assert len(dos_bot.connections) == db_server_service.max_sessions
assert dos_bot.attack_stage is DoSAttackStage.NOT_STARTED
assert db_server_service.health_state_actual is SoftwareHealthState.OVERWHELMED
db_server_service.clear_connections()
db_server_service.set_health_state(SoftwareHealthState.GOOD)
assert len(db_server_service.connections) == 0
computer.apply_timestep(timestep=1)
server.apply_timestep(timestep=1)
assert len(dos_bot.connections) == db_server_service.max_sessions
assert len(db_server_service.connections) == db_server_service.max_sessions
assert len(dos_bot.connections) == db_server_service.max_sessions
assert dos_bot.attack_stage is DoSAttackStage.NOT_STARTED
assert db_server_service.health_state_actual is SoftwareHealthState.OVERWHELMED
def test_non_repeating_dos_attack(dos_bot_and_db_server):
dos_bot, computer, db_server_service, server = dos_bot_and_db_server
assert db_server_service.health_state_actual is SoftwareHealthState.GOOD
dos_bot.port_scan_p_of_success = 1
dos_bot.repeat = False
dos_bot.run()
assert len(dos_bot.connections) == db_server_service.max_sessions
assert len(db_server_service.connections) == db_server_service.max_sessions
assert len(dos_bot.connections) == db_server_service.max_sessions
assert dos_bot.attack_stage is DoSAttackStage.COMPLETED
assert db_server_service.health_state_actual is SoftwareHealthState.OVERWHELMED
db_server_service.clear_connections()
db_server_service.set_health_state(SoftwareHealthState.GOOD)
assert len(db_server_service.connections) == 0
computer.apply_timestep(timestep=1)
server.apply_timestep(timestep=1)
assert len(dos_bot.connections) == 0
assert len(db_server_service.connections) == 0
assert len(dos_bot.connections) == 0
assert dos_bot.attack_stage is DoSAttackStage.COMPLETED
assert db_server_service.health_state_actual is SoftwareHealthState.GOOD
def test_dos_bot_database_service_connection(dos_bot_and_db_server):
dos_bot, computer, db_server_service, server = dos_bot_and_db_server
dos_bot.operating_state = ApplicationOperatingState.RUNNING
dos_bot.attack_stage = DoSAttackStage.PORT_SCAN
dos_bot._perform_dos()
assert len(dos_bot.connections) == db_server_service.max_sessions
assert len(db_server_service.connections) == db_server_service.max_sessions
assert len(dos_bot.connections) == db_server_service.max_sessions
def test_dos_blocks_green_agent_connection(dos_bot_db_server_green_client):
network: Network = dos_bot_db_server_green_client
client_1: Computer = network.get_node_by_hostname("client_1")
dos_bot: DoSBot = client_1.software_manager.software.get("DoSBot")
client_2: Computer = network.get_node_by_hostname("client_2")
green_db_client: DatabaseClient = client_2.software_manager.software.get("DatabaseClient")
server: Server = network.get_node_by_hostname("server_1")
db_server_service: DatabaseService = server.software_manager.software.get("DatabaseService")
assert db_server_service.health_state_actual is SoftwareHealthState.GOOD
dos_bot.port_scan_p_of_success = 1
dos_bot.repeat = False
dos_bot.run()
# DoS bot fills up connection of db server service
assert len(dos_bot.connections) == db_server_service.max_sessions
assert len(db_server_service.connections) == db_server_service.max_sessions
assert len(dos_bot.connections) == db_server_service.max_sessions
assert len(green_db_client.connections) == 0
assert dos_bot.attack_stage is DoSAttackStage.COMPLETED
# db server service is overwhelmed
assert db_server_service.health_state_actual is SoftwareHealthState.OVERWHELMED
# green agent tries to connect but fails because service is overwhelmed
assert green_db_client.connect() is False
assert len(green_db_client.connections) == 0

View File

@@ -65,8 +65,8 @@ def test_server_turns_off_application(populated_node):
assert app.operating_state is ApplicationOperatingState.CLOSED
def test_application_cannot_be_turned_on_when_server_is_off(populated_node):
"""Check that the application cannot be started when the server is off."""
def test_application_cannot_be_turned_on_when_computer_is_off(populated_node):
"""Check that the application cannot be started when the computer is off."""
app, computer = populated_node
assert computer.operating_state is NodeOperatingState.ON
@@ -86,8 +86,8 @@ def test_application_cannot_be_turned_on_when_server_is_off(populated_node):
assert app.operating_state is ApplicationOperatingState.CLOSED
def test_server_turns_on_application(populated_node):
"""Check that turning on the server turns on application."""
def test_computer_runs_applications(populated_node):
"""Check that turning on the computer will turn on applications."""
app, computer = populated_node
assert computer.operating_state is NodeOperatingState.ON

View File

@@ -1,6 +1,9 @@
from ipaddress import IPv4Address
from typing import Tuple
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
import pytest
from primaite.simulator.network.hardware.base import Link, NIC, Node, NodeOperatingState
from primaite.simulator.network.hardware.nodes.server import Server
from primaite.simulator.system.applications.database_client import DatabaseClient
from primaite.simulator.system.services.database.database_service import DatabaseService
@@ -8,57 +11,109 @@ from primaite.simulator.system.services.ftp.ftp_server import FTPServer
from primaite.simulator.system.services.service import ServiceOperatingState
def test_database_client_server_connection(uc2_network):
web_server: Server = uc2_network.get_node_by_hostname("web_server")
db_client: DatabaseClient = web_server.software_manager.software.get("DatabaseClient")
@pytest.fixture(scope="function")
def peer_to_peer() -> Tuple[Node, Node]:
node_a = Node(hostname="node_a", operating_state=NodeOperatingState.ON)
nic_a = NIC(ip_address="192.168.0.10", subnet_mask="255.255.255.0", operating_state=NodeOperatingState.ON)
node_a.connect_nic(nic_a)
node_a.software_manager.get_open_ports()
db_server: Server = uc2_network.get_node_by_hostname("database_server")
db_service: DatabaseService = db_server.software_manager.software.get("DatabaseService")
node_b = Node(hostname="node_b", operating_state=NodeOperatingState.ON)
nic_b = NIC(ip_address="192.168.0.11", subnet_mask="255.255.255.0")
node_b.connect_nic(nic_b)
Link(endpoint_a=nic_a, endpoint_b=nic_b)
assert node_a.ping("192.168.0.11")
node_a.software_manager.install(DatabaseClient)
node_a.software_manager.software["DatabaseClient"].configure(server_ip_address=IPv4Address("192.168.0.11"))
node_a.software_manager.software["DatabaseClient"].run()
node_b.software_manager.install(DatabaseService)
database_service: DatabaseService = node_b.software_manager.software["DatabaseService"] # noqa
database_service.start()
return node_a, node_b
@pytest.fixture(scope="function")
def peer_to_peer_secure_db() -> Tuple[Node, Node]:
node_a = Node(hostname="node_a", operating_state=NodeOperatingState.ON)
nic_a = NIC(ip_address="192.168.0.10", subnet_mask="255.255.255.0", operating_state=NodeOperatingState.ON)
node_a.connect_nic(nic_a)
node_a.software_manager.get_open_ports()
node_b = Node(hostname="node_b", operating_state=NodeOperatingState.ON)
nic_b = NIC(ip_address="192.168.0.11", subnet_mask="255.255.255.0")
node_b.connect_nic(nic_b)
Link(endpoint_a=nic_a, endpoint_b=nic_b)
assert node_a.ping("192.168.0.11")
node_a.software_manager.install(DatabaseClient)
node_a.software_manager.software["DatabaseClient"].configure(server_ip_address=IPv4Address("192.168.0.11"))
node_a.software_manager.software["DatabaseClient"].run()
node_b.software_manager.install(DatabaseService)
database_service: DatabaseService = node_b.software_manager.software["DatabaseService"] # noqa
database_service.password = "12345"
database_service.start()
return node_a, node_b
def test_database_client_server_connection(peer_to_peer):
node_a, node_b = peer_to_peer
db_client: DatabaseClient = node_a.software_manager.software["DatabaseClient"]
db_service: DatabaseService = node_b.software_manager.software["DatabaseService"]
db_client.connect()
assert len(db_client.connections) == 1
assert len(db_service.connections) == 1
db_client.disconnect()
assert len(db_client.connections) == 0
assert len(db_service.connections) == 0
def test_database_client_server_correct_password(uc2_network):
web_server: Server = uc2_network.get_node_by_hostname("web_server")
db_client: DatabaseClient = web_server.software_manager.software.get("DatabaseClient")
def test_database_client_server_correct_password(peer_to_peer_secure_db):
node_a, node_b = peer_to_peer_secure_db
db_server: Server = uc2_network.get_node_by_hostname("database_server")
db_service: DatabaseService = db_server.software_manager.software.get("DatabaseService")
db_client: DatabaseClient = node_a.software_manager.software["DatabaseClient"]
db_client.disconnect()
db_client.configure(server_ip_address=IPv4Address("192.168.1.14"), server_password="12345")
db_service.password = "12345"
assert db_client.connect()
db_service: DatabaseService = node_b.software_manager.software["DatabaseService"]
db_client.configure(server_ip_address=IPv4Address("192.168.0.11"), server_password="12345")
db_client.connect()
assert len(db_client.connections) == 1
assert len(db_service.connections) == 1
def test_database_client_server_incorrect_password(uc2_network):
web_server: Server = uc2_network.get_node_by_hostname("web_server")
db_client: DatabaseClient = web_server.software_manager.software.get("DatabaseClient")
def test_database_client_server_incorrect_password(peer_to_peer_secure_db):
node_a, node_b = peer_to_peer_secure_db
db_server: Server = uc2_network.get_node_by_hostname("database_server")
db_service: DatabaseService = db_server.software_manager.software.get("DatabaseService")
db_client: DatabaseClient = node_a.software_manager.software["DatabaseClient"]
db_client.disconnect()
db_client.configure(server_ip_address=IPv4Address("192.168.1.14"), server_password="54321")
db_service.password = "12345"
db_service: DatabaseService = node_b.software_manager.software["DatabaseService"]
assert not db_client.connect()
# should fail
db_client.connect()
assert len(db_client.connections) == 0
assert len(db_service.connections) == 0
db_client.configure(server_ip_address=IPv4Address("192.168.0.11"), server_password="wrongpass")
db_client.connect()
assert len(db_client.connections) == 0
assert len(db_service.connections) == 0
def test_database_client_query(uc2_network):
"""Tests DB query across the network returns HTTP status 200 and date."""
web_server: Server = uc2_network.get_node_by_hostname("web_server")
db_client: DatabaseClient = web_server.software_manager.software.get("DatabaseClient")
assert db_client.connected
db_client: DatabaseClient = web_server.software_manager.software["DatabaseClient"]
db_client.connect()
assert db_client.query("SELECT")
@@ -66,13 +121,13 @@ def test_database_client_query(uc2_network):
def test_create_database_backup(uc2_network):
"""Run the backup_database method and check if the FTP server has the relevant file."""
db_server: Server = uc2_network.get_node_by_hostname("database_server")
db_service: DatabaseService = db_server.software_manager.software.get("DatabaseService")
db_service: DatabaseService = db_server.software_manager.software["DatabaseService"]
# back up should be created
assert db_service.backup_database() is True
backup_server: Server = uc2_network.get_node_by_hostname("backup_server")
ftp_server: FTPServer = backup_server.software_manager.software.get("FTPServer")
ftp_server: FTPServer = backup_server.software_manager.software["FTPServer"]
# backup file should exist in the backup server
assert ftp_server.file_system.get_file(folder_name=db_service.uuid, file_name="database.db") is not None
@@ -81,7 +136,7 @@ def test_create_database_backup(uc2_network):
def test_restore_backup(uc2_network):
"""Run the restore_backup method and check if the backup is properly restored."""
db_server: Server = uc2_network.get_node_by_hostname("database_server")
db_service: DatabaseService = db_server.software_manager.software.get("DatabaseService")
db_service: DatabaseService = db_server.software_manager.software["DatabaseService"]
# create a back up
assert db_service.backup_database() is True
@@ -107,7 +162,7 @@ def test_database_client_cannot_query_offline_database_server(uc2_network):
web_server: Server = uc2_network.get_node_by_hostname("web_server")
db_client: DatabaseClient = web_server.software_manager.software.get("DatabaseClient")
assert db_client.connected
assert len(db_client.connections)
assert db_client.query("SELECT") is True

View File

@@ -0,0 +1,86 @@
from ipaddress import IPv4Address
from time import sleep
from typing import Tuple
import pytest
from primaite.simulator.network.container import Network
from primaite.simulator.network.hardware.nodes.computer import Computer
from primaite.simulator.network.hardware.nodes.server import Server
from primaite.simulator.network.protocols.ntp import NTPPacket
from primaite.simulator.system.services.ntp.ntp_client import NTPClient
from primaite.simulator.system.services.ntp.ntp_server import NTPServer
from primaite.simulator.system.services.service import ServiceOperatingState
# Create simple network for testing
# Define one node to be an NTP server and another node to be a NTP Client.
@pytest.fixture(scope="function")
def create_ntp_network(client_server) -> Tuple[NTPClient, Computer, NTPServer, Server]:
"""
+------------+ +------------+
| ntp | | ntp |
| client_1 +------------+ server_1 |
| | | |
+------------+ +------------+
"""
client, server = client_server
server.power_on()
server.software_manager.install(NTPServer)
ntp_server: NTPServer = server.software_manager.software.get("NTPServer")
ntp_server.start()
client.power_on()
client.software_manager.install(NTPClient)
ntp_client: NTPClient = client.software_manager.software.get("NTPClient")
ntp_client.start()
return ntp_client, client, ntp_server, server
def test_ntp_client_server(create_ntp_network):
ntp_client, client, ntp_server, server = create_ntp_network
ntp_server: NTPServer = server.software_manager.software["NTPServer"]
ntp_client: NTPClient = client.software_manager.software["NTPClient"]
assert ntp_server.operating_state == ServiceOperatingState.RUNNING
assert ntp_client.operating_state == ServiceOperatingState.RUNNING
ntp_client.configure(ntp_server_ip_address=IPv4Address("192.168.0.2"))
assert ntp_client.time is None
ntp_client.request_time()
assert ntp_client.time is not None
first_time = ntp_client.time
sleep(0.1)
ntp_client.apply_timestep(1) # Check time advances
second_time = ntp_client.time
assert first_time < second_time
# Test ntp client behaviour when ntp server is unavailable.
def test_ntp_server_failure(create_ntp_network):
ntp_client, client, ntp_server, server = create_ntp_network
ntp_server: NTPServer = server.software_manager.software["NTPServer"]
ntp_client: NTPClient = client.software_manager.software["NTPClient"]
assert ntp_client.operating_state == ServiceOperatingState.RUNNING
assert ntp_client.operating_state == ServiceOperatingState.RUNNING
ntp_client.configure(ntp_server_ip_address=IPv4Address("192.168.0.2"))
# Turn off ntp server.
ntp_server.stop()
assert ntp_server.operating_state == ServiceOperatingState.STOPPED
# And request a time update.
ntp_client.request_time()
assert ntp_client.time is None
# Restart ntp server.
ntp_server.start()
assert ntp_server.operating_state == ServiceOperatingState.RUNNING
ntp_client.request_time()
assert ntp_client.time is not None

View File

@@ -53,12 +53,12 @@ def test_node_os_scan(node, service, application):
# TODO implement processes
# add services to node
service.health_state_actual = SoftwareHealthState.COMPROMISED
service.set_health_state(SoftwareHealthState.COMPROMISED)
node.install_service(service=service)
assert service.health_state_visible == SoftwareHealthState.UNUSED
# add application to node
application.health_state_actual = SoftwareHealthState.COMPROMISED
application.set_health_state(SoftwareHealthState.COMPROMISED)
node.install_application(application=application)
assert application.health_state_visible == SoftwareHealthState.UNUSED
@@ -101,7 +101,7 @@ def test_node_red_scan(node, service, application):
assert service.revealed_to_red is False
# add application to node
application.health_state_actual = SoftwareHealthState.COMPROMISED
application.set_health_state(SoftwareHealthState.COMPROMISED)
node.install_application(application=application)
assert application.revealed_to_red is False

View File

@@ -5,7 +5,7 @@ from primaite.simulator.network.networks import arcd_uc2_network
from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.applications.application import ApplicationOperatingState
from primaite.simulator.system.services.red_services.data_manipulation_bot import (
from primaite.simulator.system.applications.red_applications.data_manipulation_bot import (
DataManipulationAttackStage,
DataManipulationBot,
)
@@ -70,4 +70,4 @@ def test_dm_bot_perform_data_manipulation_success(dm_bot):
dm_bot._perform_data_manipulation(p_of_success=1.0)
assert dm_bot.attack_stage in (DataManipulationAttackStage.SUCCEEDED, DataManipulationAttackStage.FAILED)
assert dm_bot.connected
assert len(dm_bot.connections)

View File

@@ -0,0 +1,90 @@
from ipaddress import IPv4Address
import pytest
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
from primaite.simulator.network.hardware.nodes.computer import Computer
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.applications.application import ApplicationOperatingState
from primaite.simulator.system.applications.red_applications.dos_bot import DoSAttackStage, DoSBot
@pytest.fixture(scope="function")
def dos_bot() -> DoSBot:
computer = Computer(
hostname="compromised_pc",
ip_address="192.168.0.1",
subnet_mask="255.255.255.0",
operating_state=NodeOperatingState.ON,
)
computer.software_manager.install(DoSBot)
dos_bot: DoSBot = computer.software_manager.software.get("DoSBot")
dos_bot.configure(target_ip_address=IPv4Address("192.168.0.1"))
dos_bot.set_original_state()
return dos_bot
def test_dos_bot_creation(dos_bot):
"""Test that the DoS bot is installed on a node."""
assert dos_bot is not None
def test_dos_bot_reset(dos_bot):
assert dos_bot.target_ip_address == IPv4Address("192.168.0.1")
assert dos_bot.target_port is Port.POSTGRES_SERVER
assert dos_bot.payload is None
assert dos_bot.repeat is False
dos_bot.configure(
target_ip_address=IPv4Address("192.168.1.1"), target_port=Port.HTTP, payload="payload", repeat=True
)
# should reset the relevant items
dos_bot.reset_component_for_episode(episode=0)
assert dos_bot.target_ip_address == IPv4Address("192.168.0.1")
assert dos_bot.target_port is Port.POSTGRES_SERVER
assert dos_bot.payload is None
assert dos_bot.repeat is False
dos_bot.configure(
target_ip_address=IPv4Address("192.168.1.1"), target_port=Port.HTTP, payload="payload", repeat=True
)
dos_bot.set_original_state()
dos_bot.reset_component_for_episode(episode=1)
# should reset to the configured value
assert dos_bot.target_ip_address == IPv4Address("192.168.1.1")
assert dos_bot.target_port is Port.HTTP
assert dos_bot.payload == "payload"
assert dos_bot.repeat is True
def test_dos_bot_cannot_run_when_node_offline(dos_bot):
dos_bot_node: Computer = dos_bot.parent
assert dos_bot_node.operating_state is NodeOperatingState.ON
dos_bot_node.power_off()
for i in range(dos_bot_node.shut_down_duration + 1):
dos_bot_node.apply_timestep(timestep=i)
assert dos_bot_node.operating_state is NodeOperatingState.OFF
dos_bot._application_loop()
# assert not run
assert dos_bot.attack_stage is DoSAttackStage.NOT_STARTED
def test_dos_bot_not_configured(dos_bot):
dos_bot.target_ip_address = None
dos_bot.operating_state = ApplicationOperatingState.RUNNING
dos_bot._application_loop()
def test_dos_bot_perform_port_scan(dos_bot):
dos_bot._perform_port_scan(p_of_success=1)
assert dos_bot.attack_stage is DoSAttackStage.PORT_SCAN

View File

@@ -0,0 +1,50 @@
from primaite.simulator.system.applications.application import ApplicationOperatingState
from primaite.simulator.system.software import SoftwareHealthState
def test_scan(application):
assert application.operating_state == ApplicationOperatingState.CLOSED
assert application.health_state_visible == SoftwareHealthState.UNUSED
application.run()
assert application.operating_state == ApplicationOperatingState.RUNNING
assert application.health_state_visible == SoftwareHealthState.UNUSED
application.scan()
assert application.operating_state == ApplicationOperatingState.RUNNING
assert application.health_state_visible == SoftwareHealthState.GOOD
def test_run_application(application):
assert application.operating_state == ApplicationOperatingState.CLOSED
assert application.health_state_actual == SoftwareHealthState.UNUSED
application.run()
assert application.operating_state == ApplicationOperatingState.RUNNING
assert application.health_state_actual == SoftwareHealthState.GOOD
def test_close_application(application):
application.run()
assert application.operating_state == ApplicationOperatingState.RUNNING
assert application.health_state_actual == SoftwareHealthState.GOOD
application.close()
assert application.operating_state == ApplicationOperatingState.CLOSED
assert application.health_state_actual == SoftwareHealthState.GOOD
def test_application_describe_states(application):
assert application.operating_state == ApplicationOperatingState.CLOSED
assert application.health_state_actual == SoftwareHealthState.UNUSED
assert SoftwareHealthState.UNUSED.value == application.describe_state().get("health_state_actual")
application.run()
assert SoftwareHealthState.GOOD.value == application.describe_state().get("health_state_actual")
application.set_health_state(SoftwareHealthState.COMPROMISED)
assert SoftwareHealthState.COMPROMISED.value == application.describe_state().get("health_state_actual")
application.patch()
assert SoftwareHealthState.PATCHING.value == application.describe_state().get("health_state_actual")

View File

@@ -1,5 +1,6 @@
from ipaddress import IPv4Address
from typing import Tuple, Union
from uuid import uuid4
import pytest
@@ -62,18 +63,26 @@ def test_disconnect_when_client_is_closed(database_client_on_computer):
def test_disconnect(database_client_on_computer):
"""Database client should set connected to False and remove the database server ip address."""
"""Database client should remove the connection."""
database_client, computer = database_client_on_computer
database_client.connected = True
database_client._connections[str(uuid4())] = {"item": True}
assert len(database_client.connections) == 1
assert database_client.operating_state is ApplicationOperatingState.RUNNING
assert database_client.server_ip_address is not None
database_client.disconnect()
assert database_client.connected is False
assert database_client.server_ip_address is None
assert len(database_client.connections) == 0
uuid = str(uuid4())
database_client._connections[uuid] = {"item": True}
assert len(database_client.connections) == 1
database_client.disconnect(connection_id=uuid)
assert len(database_client.connections) == 0
def test_query_when_client_is_closed(database_client_on_computer):
@@ -86,19 +95,6 @@ def test_query_when_client_is_closed(database_client_on_computer):
assert database_client.query(sql="test") is False
def test_query_failed_reattempt(database_client_on_computer):
"""Database client query should return False if the reattempt fails."""
database_client, computer = database_client_on_computer
def return_false():
return False
database_client.connect = return_false
database_client.connected = False
assert database_client.query(sql="test", is_reattempt=True) is False
def test_query_fail_to_connect(database_client_on_computer):
"""Database client query should return False if the connect attempt fails."""
database_client, computer = database_client_on_computer

View File

@@ -1,3 +1,5 @@
from uuid import uuid4
from primaite.simulator.system.services.service import ServiceOperatingState
from primaite.simulator.system.software import SoftwareHealthState
@@ -17,52 +19,174 @@ def test_scan(service):
def test_start_service(service):
assert service.operating_state == ServiceOperatingState.STOPPED
assert service.health_state_actual == SoftwareHealthState.UNUSED
service.start()
assert service.operating_state == ServiceOperatingState.RUNNING
assert service.health_state_actual == SoftwareHealthState.GOOD
def test_stop_service(service):
service.start()
assert service.operating_state == ServiceOperatingState.RUNNING
assert service.health_state_actual == SoftwareHealthState.GOOD
service.stop()
assert service.operating_state == ServiceOperatingState.STOPPED
assert service.health_state_actual == SoftwareHealthState.GOOD
def test_pause_and_resume_service(service):
assert service.operating_state == ServiceOperatingState.STOPPED
service.resume()
assert service.operating_state == ServiceOperatingState.STOPPED
assert service.health_state_actual == SoftwareHealthState.UNUSED
service.start()
assert service.health_state_actual == SoftwareHealthState.GOOD
service.pause()
assert service.operating_state == ServiceOperatingState.PAUSED
assert service.health_state_actual == SoftwareHealthState.GOOD
service.resume()
assert service.operating_state == ServiceOperatingState.RUNNING
assert service.health_state_actual == SoftwareHealthState.GOOD
def test_restart(service):
assert service.operating_state == ServiceOperatingState.STOPPED
assert service.health_state_actual == SoftwareHealthState.UNUSED
service.restart()
# Service is STOPPED. Restart will only work if the service was PAUSED or RUNNING
assert service.operating_state == ServiceOperatingState.STOPPED
assert service.health_state_actual == SoftwareHealthState.UNUSED
service.start()
assert service.operating_state == ServiceOperatingState.RUNNING
assert service.health_state_actual == SoftwareHealthState.GOOD
service.restart()
# Service is RUNNING. Restart should work
assert service.operating_state == ServiceOperatingState.RESTARTING
assert service.health_state_actual == SoftwareHealthState.GOOD
timestep = 0
while service.operating_state == ServiceOperatingState.RESTARTING:
service.apply_timestep(timestep)
assert service.health_state_actual == SoftwareHealthState.GOOD
timestep += 1
assert service.operating_state == ServiceOperatingState.RUNNING
assert service.health_state_actual == SoftwareHealthState.GOOD
def test_restart_compromised(service):
service.start()
assert service.health_state_actual == SoftwareHealthState.GOOD
# compromise the service
service.set_health_state(SoftwareHealthState.COMPROMISED)
service.restart()
assert service.operating_state == ServiceOperatingState.RESTARTING
assert service.health_state_actual == SoftwareHealthState.COMPROMISED
"""
Service should be compromised even after reset.
Only way to remove compromised status is via patching.
"""
timestep = 0
while service.operating_state == ServiceOperatingState.RESTARTING:
service.apply_timestep(timestep)
assert service.health_state_actual == SoftwareHealthState.COMPROMISED
timestep += 1
assert service.operating_state == ServiceOperatingState.RUNNING
assert service.health_state_actual == SoftwareHealthState.COMPROMISED
def test_compromised_service_remains_compromised(service):
"""
Tests that a compromised service stays compromised.
The only way that the service can be uncompromised is by running patch.
"""
service.start()
assert service.health_state_actual == SoftwareHealthState.GOOD
service.set_health_state(SoftwareHealthState.COMPROMISED)
service.stop()
assert service.health_state_actual == SoftwareHealthState.COMPROMISED
service.start()
assert service.health_state_actual == SoftwareHealthState.COMPROMISED
service.disable()
assert service.health_state_actual == SoftwareHealthState.COMPROMISED
service.enable()
assert service.health_state_actual == SoftwareHealthState.COMPROMISED
service.pause()
assert service.health_state_actual == SoftwareHealthState.COMPROMISED
service.resume()
assert service.health_state_actual == SoftwareHealthState.COMPROMISED
def test_service_patching(service):
service.start()
assert service.health_state_actual == SoftwareHealthState.GOOD
service.set_health_state(SoftwareHealthState.COMPROMISED)
service.patch()
assert service.health_state_actual == SoftwareHealthState.PATCHING
for i in range(service.patching_duration + 1):
service.apply_timestep(i)
assert service.health_state_actual == SoftwareHealthState.GOOD
def test_enable_disable(service):
service.disable()
assert service.operating_state == ServiceOperatingState.DISABLED
assert service.health_state_actual == SoftwareHealthState.UNUSED
service.enable()
assert service.operating_state == ServiceOperatingState.STOPPED
assert service.health_state_actual == SoftwareHealthState.UNUSED
def test_overwhelm_service(service):
service.max_sessions = 2
service.start()
uuid = str(uuid4())
assert service.add_connection(connection_id=uuid) # should be true
assert service.health_state_actual == SoftwareHealthState.GOOD
assert not service.add_connection(connection_id=uuid) # fails because connection already exists
assert service.health_state_actual == SoftwareHealthState.GOOD
assert service.add_connection(connection_id=str(uuid4())) # succeed
assert service.health_state_actual == SoftwareHealthState.GOOD
assert not service.add_connection(connection_id=str(uuid4())) # fail because at capacity
assert service.health_state_actual is SoftwareHealthState.OVERWHELMED
def test_create_and_remove_connections(service):
service.start()
uuid = str(uuid4())
assert service.add_connection(connection_id=uuid) # should be true
assert len(service.connections) == 1
assert service.health_state_actual is SoftwareHealthState.GOOD
assert service.remove_connection(connection_id=uuid) # should be true
assert len(service.connections) == 0
assert service.health_state_actual is SoftwareHealthState.GOOD

View File

@@ -0,0 +1,29 @@
from typing import Dict
import pytest
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.core.sys_log import SysLog
from primaite.simulator.system.software import Software, SoftwareHealthState
class TestSoftware(Software):
def describe_state(self) -> Dict:
pass
@pytest.fixture(scope="function")
def software(file_system):
return TestSoftware(
name="TestSoftware", port=Port.ARP, file_system=file_system, sys_log=SysLog(hostname="test_service")
)
def test_software_creation(software):
assert software is not None
def test_software_set_health_state(software):
assert software.health_state_actual == SoftwareHealthState.UNUSED
software.set_health_state(SoftwareHealthState.GOOD)
assert software.health_state_actual == SoftwareHealthState.GOOD