Merge remote-tracking branch 'origin/dev' into feature/2137-refactor-request-api
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -156,4 +156,5 @@ benchmark/output
|
||||
# src/primaite/notebooks/scratch.ipynb
|
||||
src/primaite/notebooks/scratch.py
|
||||
sandbox.py
|
||||
sandbox/
|
||||
sandbox.ipynb
|
||||
|
||||
30
CHANGELOG.md
30
CHANGELOG.md
@@ -6,6 +6,24 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
|
||||
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
||||
|
||||
## [Unreleased]
|
||||
- Fixed a bug where ACL rules were not resetting on episode reset.
|
||||
- Fixed a bug where blue agent's ACL actions were being applied against the wrong IP addresses
|
||||
- Fixed a bug where deleted files and folders did not reset correctly on episode reset.
|
||||
- Fixed a bug where service health status was using the actual health state instead of the visible health state
|
||||
- Fixed a bug where the database file health status was using the incorrect value for negative rewards
|
||||
- Fixed a bug preventing file actions from reaching their intended file
|
||||
- Made database patch correctly take 2 timesteps instead of being immediate
|
||||
- Made database patch only possible when the software is compromised or good, it's no longer possible when the software is OFF or RESETTING
|
||||
- Temporarily disable the blue agent file delete action due to crashes. This issue is resolved in another branch that will be merged into dev soon.
|
||||
- Fix a bug where ACLs were not showing up correctly in the observation space.
|
||||
- Added a notebook which explains Data manipulation scenario, demonstrates the attack, and shows off blue agent's action space, observation space, and reward function.
|
||||
- Made packet capture and system logging optional (off by default). To turn on, change the io_settings.save_pcap_logs and io_settings.save_sys_logs settings in the config.
|
||||
- Made observation space flattening optional (on by default). To turn off for an agent, change the agent_settings.flatten_obs setting in the config.
|
||||
- Fixed an issue where the data manipulation attack was triggered at episode start.
|
||||
- Fixed a bug where FTP STOR stored an additional copy on the client machine's filesystem
|
||||
- Fixed a bug where the red agent acted to early
|
||||
- Fixed the order of service health state
|
||||
- Fixed an issue where starting a node didn't start the services on it
|
||||
|
||||
|
||||
|
||||
@@ -38,6 +56,18 @@ SessionManager.
|
||||
- HTTP Services: `WebBrowser` to simulate a web client and `WebServer`
|
||||
- Fixed an issue where the services were still able to run even though the node the service is installed on is turned off
|
||||
- NTP Services: `NTPClient` and `NTPServer`
|
||||
- **RouterNIC Class**: Introduced a new class `RouterNIC`, extending the standard `NIC` functionality. This class is specifically designed for router operations, optimizing the processing and routing of network traffic.
|
||||
- **Custom Layer-3 Processing**: The `RouterNIC` class includes custom handling for network frames, bypassing standard Node NIC's Layer 3 broadcast/unicast checks. This allows for more efficient routing behavior in network scenarios where router-specific frame processing is required.
|
||||
- **Enhanced Frame Reception**: The `receive_frame` method in `RouterNIC` is tailored to handle frames based on Layer 2 (Ethernet) checks, focusing on MAC address-based routing and broadcast frame acceptance.
|
||||
- **Subnet-Wide Broadcasting for Services and Applications**: Implemented the ability for services and applications to conduct broadcasts across an entire IPv4 subnet within the network simulation framework.
|
||||
|
||||
### Changed
|
||||
- Integrated the RouteTable into the Routers frame processing.
|
||||
- Frames are now dropped when their TTL reaches 0
|
||||
- **NIC Functionality Update**: Updated the Network Interface Card (`NIC`) functionality to support Layer 3 (L3) broadcasts.
|
||||
- **Layer 3 Broadcast Handling**: Enhanced the existing `NIC` classes to correctly process and handle Layer 3 broadcasts. This update allows devices using standard NICs to effectively participate in network activities that involve L3 broadcasting.
|
||||
- **Improved Frame Reception Logic**: The `receive_frame` method of the `NIC` class has been updated to include additional checks and handling for L3 broadcasts, ensuring proper frame processing in a wider range of network scenarios.
|
||||
|
||||
|
||||
### Removed
|
||||
- Removed legacy simulation modules: `acl`, `common`, `environment`, `links`, `nodes`, `pol`
|
||||
|
||||
@@ -33,7 +33,7 @@ Currently, the PrimAITE wheel can only be installed from GitHub. This may change
|
||||
#### Windows (PowerShell)
|
||||
|
||||
**Prerequisites:**
|
||||
* Manual install of Python >= 3.8 < 3.11
|
||||
* Manual install of Python >= 3.8 < 3.12
|
||||
|
||||
**Install:**
|
||||
|
||||
@@ -56,7 +56,7 @@ primaite session
|
||||
#### Unix
|
||||
|
||||
**Prerequisites:**
|
||||
* Manual install of Python >= 3.8 < 3.11
|
||||
* Manual install of Python >= 3.8 < 3.12
|
||||
|
||||
``` bash
|
||||
sudo add-apt-repository ppa:deadsnakes/ppa
|
||||
@@ -82,6 +82,7 @@ primaite session
|
||||
```
|
||||
|
||||
|
||||
|
||||
### Developer Install from Source
|
||||
To make your own changes to PrimAITE, perform the install from source (developer install)
|
||||
|
||||
@@ -138,3 +139,7 @@ make html
|
||||
cd docs
|
||||
.\make.bat html
|
||||
```
|
||||
|
||||
|
||||
## Example notebooks
|
||||
Check out the example notebooks to learn more about how PrimAITE works and how you can use it to train agents. They are automatically copied to your primaite installation directory when you run `primaite setup`.
|
||||
|
||||
@@ -13,7 +13,25 @@ This section allows selecting which training framework and algorithm to use, and
|
||||
|
||||
``io_settings``
|
||||
---------------
|
||||
This section configures how the ``PrimaiteSession`` saves data.
|
||||
This section configures how PrimAITE saves data during simulation and training.
|
||||
|
||||
**save_final_model**: Only used if training with PrimaiteSession, if true, the policy will be saved after the final training iteration.
|
||||
|
||||
**save_checkpoints**: Only used if training with PrimaiteSession, if true, the policy will be saved periodically during training.
|
||||
|
||||
**checkpoint_interval**: Only used if training with PrimaiteSession and if ``save_checkpoints`` is true. Defines how often to save the policy during training.
|
||||
|
||||
**save_logs**: *currently unused*.
|
||||
|
||||
**save_transactions**: *currently unused*.
|
||||
|
||||
**save_tensorboard_logs**: *currently unused*.
|
||||
|
||||
**save_step_metadata**: Whether to save the RL agents' action, environment state, and other data at every single step.
|
||||
|
||||
**save_pcap_logs**: Whether to save pcap files of all network traffic during the simulation.
|
||||
|
||||
**save_sys_logs**: Whether to save system logs from all nodes during the simulation.
|
||||
|
||||
``game``
|
||||
--------
|
||||
@@ -56,6 +74,10 @@ Description of configurable items:
|
||||
**agent_settings**:
|
||||
Settings passed to the agent during initialisation. These depend on the agent class.
|
||||
|
||||
Reinforcement learning agents use the ``ProxyAgent`` class, they accept these agent settings:
|
||||
|
||||
**flatten_obs**: If true, gymnasium flattening will be performed on the observation space before sending to the agent. Set this to true if your agent does not support nested observation spaces.
|
||||
|
||||
``simulation``
|
||||
--------------
|
||||
In this section the network layout is defined. This part of the config follows a hierarchical structure. Almost every component defines a ``ref`` field which acts as a human-readable unique identifier, used by other parts of the config, such as agents.
|
||||
|
||||
@@ -1 +1 @@
|
||||
3.0.0b4dev
|
||||
3.0.0b6
|
||||
|
||||
@@ -14,6 +14,8 @@ io_settings:
|
||||
save_checkpoints: true
|
||||
checkpoint_interval: 5
|
||||
save_step_metadata: false
|
||||
save_pcap_logs: true
|
||||
save_sys_logs: true
|
||||
|
||||
|
||||
game:
|
||||
@@ -29,7 +31,7 @@ game:
|
||||
- UDP
|
||||
|
||||
agents:
|
||||
- ref: client_1_green_user
|
||||
- ref: client_2_green_user
|
||||
team: GREEN
|
||||
type: GreenWebBrowsingAgent
|
||||
observation_space:
|
||||
@@ -110,10 +112,8 @@ agents:
|
||||
- service_name: DNSServer
|
||||
- node_hostname: web_server
|
||||
services:
|
||||
- service_name: DatabaseClient
|
||||
- service_name: web_server_web_service
|
||||
- node_hostname: database_server
|
||||
services:
|
||||
- service_name: DatabaseService
|
||||
folders:
|
||||
- folder_name: database
|
||||
files:
|
||||
@@ -302,63 +302,63 @@ agents:
|
||||
action: "NODE_RESET"
|
||||
options:
|
||||
node_id: 5
|
||||
22:
|
||||
22: # "ACL: ADDRULE - Block outgoing traffic from client 1" (not supported in Primaite)
|
||||
action: "NETWORK_ACL_ADDRULE"
|
||||
options:
|
||||
position: 1
|
||||
permission: 2
|
||||
source_ip_id: 7
|
||||
dest_ip_id: 1
|
||||
source_ip_id: 7 # client 1
|
||||
dest_ip_id: 1 # ALL
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 1
|
||||
23:
|
||||
23: # "ACL: ADDRULE - Block outgoing traffic from client 2" (not supported in Primaite)
|
||||
action: "NETWORK_ACL_ADDRULE"
|
||||
options:
|
||||
position: 1
|
||||
position: 2
|
||||
permission: 2
|
||||
source_ip_id: 8
|
||||
dest_ip_id: 1
|
||||
source_ip_id: 8 # client 2
|
||||
dest_ip_id: 1 # ALL
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 1
|
||||
24:
|
||||
24: # block tcp traffic from client 1 to web app
|
||||
action: "NETWORK_ACL_ADDRULE"
|
||||
options:
|
||||
position: 1
|
||||
position: 3
|
||||
permission: 2
|
||||
source_ip_id: 7
|
||||
dest_ip_id: 3
|
||||
source_ip_id: 7 # client 1
|
||||
dest_ip_id: 3 # web server
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 3
|
||||
25:
|
||||
25: # block tcp traffic from client 2 to web app
|
||||
action: "NETWORK_ACL_ADDRULE"
|
||||
options:
|
||||
position: 1
|
||||
position: 4
|
||||
permission: 2
|
||||
source_ip_id: 8
|
||||
dest_ip_id: 3
|
||||
source_ip_id: 8 # client 2
|
||||
dest_ip_id: 3 # web server
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 3
|
||||
26:
|
||||
action: "NETWORK_ACL_ADDRULE"
|
||||
options:
|
||||
position: 1
|
||||
position: 5
|
||||
permission: 2
|
||||
source_ip_id: 7
|
||||
dest_ip_id: 4
|
||||
source_ip_id: 7 # client 1
|
||||
dest_ip_id: 4 # database
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 3
|
||||
27:
|
||||
action: "NETWORK_ACL_ADDRULE"
|
||||
options:
|
||||
position: 1
|
||||
position: 6
|
||||
permission: 2
|
||||
source_ip_id: 8
|
||||
dest_ip_id: 4
|
||||
source_ip_id: 8 # client 2
|
||||
dest_ip_id: 4 # database
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 3
|
||||
@@ -507,6 +507,24 @@ agents:
|
||||
max_services_per_node: 2
|
||||
max_nics_per_node: 8
|
||||
max_acl_rules: 10
|
||||
ip_address_order:
|
||||
- node_ref: domain_controller
|
||||
nic_num: 1
|
||||
- node_ref: web_server
|
||||
nic_num: 1
|
||||
- node_ref: database_server
|
||||
nic_num: 1
|
||||
- node_ref: backup_server
|
||||
nic_num: 1
|
||||
- node_ref: security_suite
|
||||
nic_num: 1
|
||||
- node_ref: client_1
|
||||
nic_num: 1
|
||||
- node_ref: client_2
|
||||
nic_num: 1
|
||||
- node_ref: security_suite
|
||||
nic_num: 2
|
||||
|
||||
|
||||
reward_function:
|
||||
reward_components:
|
||||
@@ -526,7 +544,7 @@ agents:
|
||||
|
||||
|
||||
agent_settings:
|
||||
# ...
|
||||
flatten_obs: true
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -25,7 +25,7 @@ game:
|
||||
- UDP
|
||||
|
||||
agents:
|
||||
- ref: client_1_green_user
|
||||
- ref: client_2_green_user
|
||||
team: GREEN
|
||||
type: GreenWebBrowsingAgent
|
||||
observation_space:
|
||||
|
||||
@@ -296,6 +296,16 @@ class NodeFileDeleteAction(NodeFileAbstractAction):
|
||||
super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, num_files=num_files, **kwargs)
|
||||
self.verb: str = "delete"
|
||||
|
||||
def form_request(self, node_id: int, folder_id: int, file_id: int) -> List[str]:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
node_uuid = self.manager.get_node_uuid_by_idx(node_id)
|
||||
folder_uuid = self.manager.get_folder_uuid_by_idx(node_idx=node_id, folder_idx=folder_id)
|
||||
file_uuid = self.manager.get_file_uuid_by_idx(node_idx=node_id, folder_idx=folder_id, file_idx=file_id)
|
||||
if node_uuid is None or folder_uuid is None or file_uuid is None:
|
||||
return ["do_nothing"]
|
||||
return ["do_nothing"]
|
||||
# return ["network", "node", node_uuid, "file_system", "delete", "file", folder_uuid, file_uuid]
|
||||
|
||||
|
||||
class NodeFileRepairAction(NodeFileAbstractAction):
|
||||
"""Action which repairs a file."""
|
||||
@@ -443,27 +453,33 @@ class NetworkACLAddRuleAction(AbstractAction):
|
||||
protocol = self.manager.get_internet_protocol_by_idx(protocol_id - 2)
|
||||
# subtract 2 to account for UNUSED=0 and ALL=1.
|
||||
|
||||
if source_ip_id in [0, 1]:
|
||||
if source_ip_id == 0:
|
||||
return ["do_nothing"] # invalid formulation
|
||||
elif source_ip_id == 1:
|
||||
src_ip = "ALL"
|
||||
return ["do_nothing"] # NOT SUPPORTED, JUST DO NOTHING IF WE COME ACROSS THIS
|
||||
else:
|
||||
src_ip = self.manager.get_ip_address_by_idx(source_ip_id - 2)
|
||||
# subtract 2 to account for UNUSED=0, and ALL=1
|
||||
|
||||
if source_port_id == 1:
|
||||
if source_port_id == 0:
|
||||
return ["do_nothing"] # invalid formulation
|
||||
elif source_port_id == 1:
|
||||
src_port = "ALL"
|
||||
else:
|
||||
src_port = self.manager.get_port_by_idx(source_port_id - 2)
|
||||
# subtract 2 to account for UNUSED=0, and ALL=1
|
||||
|
||||
if dest_ip_id in (0, 1):
|
||||
if source_ip_id == 0:
|
||||
return ["do_nothing"] # invalid formulation
|
||||
elif dest_ip_id == 1:
|
||||
dst_ip = "ALL"
|
||||
return ["do_nothing"] # NOT SUPPORTED, JUST DO NOTHING IF WE COME ACROSS THIS
|
||||
else:
|
||||
dst_ip = self.manager.get_ip_address_by_idx(dest_ip_id - 2)
|
||||
# subtract 2 to account for UNUSED=0, and ALL=1
|
||||
|
||||
if dest_port_id == 1:
|
||||
if dest_port_id == 0:
|
||||
return ["do_nothing"] # invalid formulation
|
||||
elif dest_port_id == 1:
|
||||
dst_port = "ALL"
|
||||
else:
|
||||
dst_port = self.manager.get_port_by_idx(dest_port_id - 2)
|
||||
@@ -943,13 +959,22 @@ class ActionManager:
|
||||
:return: The constructed ActionManager.
|
||||
:rtype: ActionManager
|
||||
"""
|
||||
ip_address_order = cfg["options"].pop("ip_address_order", {})
|
||||
ip_address_list = []
|
||||
for entry in ip_address_order:
|
||||
node_ref = entry["node_ref"]
|
||||
nic_num = entry["nic_num"]
|
||||
node_obj = game.simulation.network.get_node_by_hostname(node_ref)
|
||||
ip_address = node_obj.ethernet_port[nic_num].ip_address
|
||||
ip_address_list.append(ip_address)
|
||||
|
||||
obj = cls(
|
||||
game=game,
|
||||
actions=cfg["action_list"],
|
||||
**cfg["options"],
|
||||
protocols=game.options.protocols,
|
||||
ports=game.options.ports,
|
||||
ip_address_list=None,
|
||||
ip_address_list=ip_address_list or None,
|
||||
act_map=cfg.get("action_map"),
|
||||
)
|
||||
|
||||
|
||||
@@ -15,7 +15,6 @@ class DataManipulationAgent(AbstractScriptedAgent):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
self._set_next_execution_timestep(self.agent_settings.start_settings.start_step)
|
||||
|
||||
def _set_next_execution_timestep(self, timestep: int) -> None:
|
||||
@@ -46,3 +45,8 @@ class DataManipulationAgent(AbstractScriptedAgent):
|
||||
self._set_next_execution_timestep(current_timestep + self.agent_settings.start_settings.frequency)
|
||||
|
||||
return "NODE_APPLICATION_EXECUTE", {"node_id": 0, "application_id": 0}
|
||||
|
||||
def reset_agent_for_episode(self) -> None:
|
||||
"""Set the next execution timestep when the episode resets."""
|
||||
super().reset_agent_for_episode()
|
||||
self._set_next_execution_timestep(self.agent_settings.start_settings.start_step)
|
||||
|
||||
@@ -44,6 +44,8 @@ class AgentSettings(BaseModel):
|
||||
|
||||
start_settings: Optional[AgentStartSettings] = None
|
||||
"Configuration for when an agent begins performing it's actions"
|
||||
flatten_obs: bool = True
|
||||
"Whether to flatten the observation space before passing it to the agent. True by default."
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Optional[Dict]) -> "AgentSettings":
|
||||
@@ -134,6 +136,10 @@ class AbstractAgent(ABC):
|
||||
request = self.action_manager.form_request(action_identifier=action, action_options=options)
|
||||
return request
|
||||
|
||||
def reset_agent_for_episode(self) -> None:
|
||||
"""Agent reset logic should go here."""
|
||||
pass
|
||||
|
||||
|
||||
class AbstractScriptedAgent(AbstractAgent):
|
||||
"""Base class for actors which generate their own behaviour."""
|
||||
@@ -166,6 +172,7 @@ class ProxyAgent(AbstractAgent):
|
||||
action_space: Optional[ActionManager],
|
||||
observation_space: Optional[ObservationManager],
|
||||
reward_function: Optional[RewardFunction],
|
||||
agent_settings: Optional[AgentSettings] = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
agent_name=agent_name,
|
||||
@@ -174,6 +181,7 @@ class ProxyAgent(AbstractAgent):
|
||||
reward_function=reward_function,
|
||||
)
|
||||
self.most_recent_action: ActType
|
||||
self.flatten_obs: bool = agent_settings.flatten_obs if agent_settings else False
|
||||
|
||||
def get_action(self, obs: ObsType, reward: float = 0.0) -> Tuple[str, Dict]:
|
||||
"""
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
"""Manages the observation space for the agent."""
|
||||
from abc import ABC, abstractmethod
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING
|
||||
|
||||
from gymnasium import spaces
|
||||
@@ -78,7 +79,7 @@ class FileObservation(AbstractObservation):
|
||||
file_state = access_from_nested_dict(state, self.where)
|
||||
if file_state is NOT_PRESENT_IN_STATE:
|
||||
return self.default_observation
|
||||
return {"health_status": file_state["health_status"]}
|
||||
return {"health_status": file_state["visible_status"]}
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
@@ -204,12 +205,15 @@ class LinkObservation(AbstractObservation):
|
||||
|
||||
bandwidth = link_state["bandwidth"]
|
||||
load = link_state["current_load"]
|
||||
utilisation_fraction = load / bandwidth
|
||||
# 0 is UNUSED, 1 is 0%-10%. 2 is 10%-20%. 3 is 20%-30%. And so on... 10 is exactly 100%
|
||||
utilisation_category = int(utilisation_fraction * 10) + 1
|
||||
if load == 0:
|
||||
utilisation_category = 0
|
||||
else:
|
||||
utilisation_fraction = load / bandwidth
|
||||
# 0 is UNUSED, 1 is 0%-10%. 2 is 10%-20%. 3 is 20%-30%. And so on... 10 is exactly 100%
|
||||
utilisation_category = int(utilisation_fraction * 9) + 1
|
||||
|
||||
# TODO: once the links support separte load per protocol, this needs amendment to reflect that.
|
||||
return {"PROTOCOLS": {"ALL": utilisation_category}}
|
||||
return {"PROTOCOLS": {"ALL": min(utilisation_category, 10)}}
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
@@ -554,7 +558,7 @@ class NodeObservation(AbstractObservation):
|
||||
folder_configs = config.get("folders", {})
|
||||
folders = [
|
||||
FolderObservation.from_config(
|
||||
config=c, game=game, parent_where=where, num_files_per_folder=num_files_per_folder
|
||||
config=c, game=game, parent_where=where + ["file_system"], num_files_per_folder=num_files_per_folder
|
||||
)
|
||||
for c in folder_configs
|
||||
]
|
||||
@@ -644,10 +648,13 @@ class AclObservation(AbstractObservation):
|
||||
|
||||
# TODO: what if the ACL has more rules than num of max rules for obs space
|
||||
obs = {}
|
||||
for i, rule_state in acl_state.items():
|
||||
acl_items = dict(acl_state.items())
|
||||
i = 1 # don't show rule 0 for compatibility reasons.
|
||||
while i < self.num_rules + 1:
|
||||
rule_state = acl_items[i]
|
||||
if rule_state is None:
|
||||
obs[i + 1] = {
|
||||
"position": i,
|
||||
obs[i] = {
|
||||
"position": i - 1,
|
||||
"permission": 0,
|
||||
"source_node_id": 0,
|
||||
"source_port": 0,
|
||||
@@ -656,15 +663,26 @@ class AclObservation(AbstractObservation):
|
||||
"protocol": 0,
|
||||
}
|
||||
else:
|
||||
obs[i + 1] = {
|
||||
"position": i,
|
||||
src_ip = rule_state["src_ip_address"]
|
||||
src_node_id = 1 if src_ip is None else self.node_to_id[IPv4Address(src_ip)]
|
||||
dst_ip = rule_state["dst_ip_address"]
|
||||
dst_node_ip = 1 if dst_ip is None else self.node_to_id[IPv4Address(dst_ip)]
|
||||
src_port = rule_state["src_port"]
|
||||
src_port_id = 1 if src_port is None else self.port_to_id[src_port]
|
||||
dst_port = rule_state["dst_port"]
|
||||
dst_port_id = 1 if dst_port is None else self.port_to_id[dst_port]
|
||||
protocol = rule_state["protocol"]
|
||||
protocol_id = 1 if protocol is None else self.protocol_to_id[protocol]
|
||||
obs[i] = {
|
||||
"position": i - 1,
|
||||
"permission": rule_state["action"],
|
||||
"source_node_id": self.node_to_id[rule_state["src_ip_address"]],
|
||||
"source_port": self.port_to_id[rule_state["src_port"]],
|
||||
"dest_node_id": self.node_to_id[rule_state["dst_ip_address"]],
|
||||
"dest_port": self.port_to_id[rule_state["dst_port"]],
|
||||
"protocol": self.protocol_to_id[rule_state["protocol"]],
|
||||
"source_node_id": src_node_id,
|
||||
"source_port": src_port_id,
|
||||
"dest_node_id": dst_node_ip,
|
||||
"dest_port": dst_port_id,
|
||||
"protocol": protocol_id,
|
||||
}
|
||||
i += 1
|
||||
return obs
|
||||
|
||||
@property
|
||||
|
||||
@@ -110,10 +110,17 @@ class DatabaseFileIntegrity(AbstractReward):
|
||||
:type state: Dict
|
||||
"""
|
||||
database_file_state = access_from_nested_dict(state, self.location_in_state)
|
||||
if database_file_state is NOT_PRESENT_IN_STATE:
|
||||
_LOGGER.info(
|
||||
f"Could not calculate {self.__class__} reward because "
|
||||
"simulation state did not contain enough information."
|
||||
)
|
||||
return 0.0
|
||||
|
||||
health_status = database_file_state["health_status"]
|
||||
if health_status == "corrupted":
|
||||
if health_status == 2:
|
||||
return -1
|
||||
elif health_status == "good":
|
||||
elif health_status == 1:
|
||||
return 1
|
||||
else:
|
||||
return 0
|
||||
|
||||
@@ -13,11 +13,9 @@ from primaite.game.agent.rewards import RewardFunction
|
||||
from primaite.session.io import SessionIO, SessionIOSettings
|
||||
from primaite.simulator.network.hardware.base import NIC, NodeOperatingState
|
||||
from primaite.simulator.network.hardware.nodes.computer import Computer
|
||||
from primaite.simulator.network.hardware.nodes.router import ACLAction, Router
|
||||
from primaite.simulator.network.hardware.nodes.router import Router
|
||||
from primaite.simulator.network.hardware.nodes.server import Server
|
||||
from primaite.simulator.network.hardware.nodes.switch import Switch
|
||||
from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.sim_container import Simulation
|
||||
from primaite.simulator.system.applications.database_client import DatabaseClient
|
||||
from primaite.simulator.system.applications.red_applications.data_manipulation_bot import DataManipulationBot
|
||||
@@ -117,7 +115,7 @@ class PrimaiteGame:
|
||||
self.update_agents(sim_state)
|
||||
|
||||
# Apply all actions to simulation as requests
|
||||
self.apply_agent_actions()
|
||||
agent_actions = self.apply_agent_actions() # noqa
|
||||
|
||||
# Advance timestep
|
||||
self.advance_timestep()
|
||||
@@ -135,12 +133,15 @@ class PrimaiteGame:
|
||||
|
||||
def apply_agent_actions(self) -> None:
|
||||
"""Apply all actions to simulation as requests."""
|
||||
agent_actions = {}
|
||||
for agent in self.agents:
|
||||
obs = agent.observation_manager.current_observation
|
||||
rew = agent.reward_function.current_reward
|
||||
action_choice, options = agent.get_action(obs, rew)
|
||||
agent_actions[agent.agent_name] = (action_choice, options)
|
||||
request = agent.format_request(action_choice, options)
|
||||
self.simulation.apply_request(request)
|
||||
return agent_actions
|
||||
|
||||
def advance_timestep(self) -> None:
|
||||
"""Advance timestep."""
|
||||
@@ -164,6 +165,7 @@ class PrimaiteGame:
|
||||
self.simulation.reset_component_for_episode(episode=self.episode_counter)
|
||||
for agent in self.agents:
|
||||
agent.reward_function.total_reward = 0.0
|
||||
agent.reset_agent_for_episode()
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close the game, this will close the simulation."""
|
||||
@@ -228,31 +230,7 @@ class PrimaiteGame:
|
||||
operating_state=NodeOperatingState.ON,
|
||||
)
|
||||
elif n_type == "router":
|
||||
new_node = Router(
|
||||
hostname=node_cfg["hostname"],
|
||||
num_ports=node_cfg.get("num_ports"),
|
||||
operating_state=NodeOperatingState.ON,
|
||||
)
|
||||
if "ports" in node_cfg:
|
||||
for port_num, port_cfg in node_cfg["ports"].items():
|
||||
new_node.configure_port(
|
||||
port=port_num, ip_address=port_cfg["ip_address"], subnet_mask=port_cfg["subnet_mask"]
|
||||
)
|
||||
# new_node.enable_port(port_num)
|
||||
if "acl" in node_cfg:
|
||||
for r_num, r_cfg in node_cfg["acl"].items():
|
||||
# excuse the uncommon walrus operator ` := `. It's just here as a shorthand, to avoid repeating
|
||||
# this: 'r_cfg.get('src_port')'
|
||||
# Port/IPProtocol. TODO Refactor
|
||||
new_node.acl.add_rule(
|
||||
action=ACLAction[r_cfg["action"]],
|
||||
src_port=None if not (p := r_cfg.get("src_port")) else Port[p],
|
||||
dst_port=None if not (p := r_cfg.get("dst_port")) else Port[p],
|
||||
protocol=None if not (p := r_cfg.get("protocol")) else IPProtocol[p],
|
||||
src_ip_address=r_cfg.get("ip_address"),
|
||||
dst_ip_address=r_cfg.get("ip_address"),
|
||||
position=r_num,
|
||||
)
|
||||
new_node = Router.from_config(node_cfg)
|
||||
else:
|
||||
_LOGGER.warning(f"invalid node type {n_type} in config")
|
||||
if "services" in node_cfg:
|
||||
@@ -389,6 +367,7 @@ class PrimaiteGame:
|
||||
action_space=action_space,
|
||||
observation_space=obs_space,
|
||||
reward_function=rew_function,
|
||||
agent_settings=agent_settings,
|
||||
)
|
||||
game.agents.append(new_agent)
|
||||
game.rl_agents.append(new_agent)
|
||||
|
||||
BIN
src/primaite/notebooks/_package_data/uc2_attack.png
Normal file
BIN
src/primaite/notebooks/_package_data/uc2_attack.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 110 KiB |
BIN
src/primaite/notebooks/_package_data/uc2_network.png
Normal file
BIN
src/primaite/notebooks/_package_data/uc2_network.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 69 KiB |
@@ -39,6 +39,15 @@
|
||||
"#### Create a Ray algorithm and pass it our config."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"print(cfg['agents'][2]['agent_settings'])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
@@ -76,6 +85,13 @@
|
||||
" param_space=config\n",
|
||||
").fit()\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -29,7 +29,7 @@ class PrimaiteGymEnv(gymnasium.Env):
|
||||
# make ProxyAgent store the action chosen my the RL policy
|
||||
self.agent.store_action(action)
|
||||
# apply_agent_actions accesses the action we just stored
|
||||
self.game.apply_agent_actions()
|
||||
agent_actions = self.game.apply_agent_actions()
|
||||
self.game.advance_timestep()
|
||||
state = self.game.get_sim_state()
|
||||
|
||||
@@ -39,7 +39,7 @@ class PrimaiteGymEnv(gymnasium.Env):
|
||||
reward = self.agent.reward_function.current_reward
|
||||
terminated = False
|
||||
truncated = self.game.calculate_truncated()
|
||||
info = {}
|
||||
info = {"agent_actions": agent_actions} # tell us what all the agents did for convenience.
|
||||
if self.game.save_step_metadata:
|
||||
self._write_step_metadata_json(action, state, reward)
|
||||
return next_obs, reward, terminated, truncated, info
|
||||
@@ -81,13 +81,19 @@ class PrimaiteGymEnv(gymnasium.Env):
|
||||
@property
|
||||
def observation_space(self) -> gymnasium.Space:
|
||||
"""Return the observation space of the environment."""
|
||||
return gymnasium.spaces.flatten_space(self.agent.observation_manager.space)
|
||||
if self.agent.flatten_obs:
|
||||
return gymnasium.spaces.flatten_space(self.agent.observation_manager.space)
|
||||
else:
|
||||
return self.agent.observation_manager.space
|
||||
|
||||
def _get_obs(self) -> ObsType:
|
||||
"""Return the current observation."""
|
||||
unflat_space = self.agent.observation_manager.space
|
||||
unflat_obs = self.agent.observation_manager.current_observation
|
||||
return gymnasium.spaces.flatten(unflat_space, unflat_obs)
|
||||
if not self.agent.flatten_obs:
|
||||
return self.agent.observation_manager.current_observation
|
||||
else:
|
||||
unflat_space = self.agent.observation_manager.space
|
||||
unflat_obs = self.agent.observation_manager.current_observation
|
||||
return gymnasium.spaces.flatten(unflat_space, unflat_obs)
|
||||
|
||||
|
||||
class PrimaiteRayEnv(gymnasium.Env):
|
||||
@@ -166,7 +172,7 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
|
||||
# 1. Perform actions
|
||||
for agent_name, action in actions.items():
|
||||
self.agents[agent_name].store_action(action)
|
||||
self.game.apply_agent_actions()
|
||||
agent_actions = self.game.apply_agent_actions()
|
||||
|
||||
# 2. Advance timestep
|
||||
self.game.advance_timestep()
|
||||
@@ -180,7 +186,7 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
|
||||
rewards = {name: agent.reward_function.current_reward for name, agent in self.agents.items()}
|
||||
terminateds = {name: False for name, _ in self.agents.items()}
|
||||
truncateds = {name: self.game.calculate_truncated() for name, _ in self.agents.items()}
|
||||
infos = {}
|
||||
infos = {"agent_actions": agent_actions}
|
||||
terminateds["__all__"] = len(self.terminateds) == len(self.agents)
|
||||
truncateds["__all__"] = self.game.calculate_truncated()
|
||||
if self.game.save_step_metadata:
|
||||
|
||||
@@ -24,9 +24,13 @@ class SessionIOSettings(BaseModel):
|
||||
save_transactions: bool = True
|
||||
"""Whether to save transactions, If true, the session path will have a transactions folder."""
|
||||
save_tensorboard_logs: bool = False
|
||||
"""Whether to save tensorboard logs. If true, the session path will have a tenorboard_logs folder."""
|
||||
"""Whether to save tensorboard logs. If true, the session path will have a tensorboard_logs folder."""
|
||||
save_step_metadata: bool = False
|
||||
"""Whether to save the RL agents' action, environment state, and other data at every single step."""
|
||||
save_pcap_logs: bool = False
|
||||
"""Whether to save PCAP logs."""
|
||||
save_sys_logs: bool = False
|
||||
"""Whether to save system logs."""
|
||||
|
||||
|
||||
class SessionIO:
|
||||
@@ -39,9 +43,10 @@ class SessionIO:
|
||||
def __init__(self, settings: SessionIOSettings = SessionIOSettings()) -> None:
|
||||
self.settings: SessionIOSettings = settings
|
||||
self.session_path: Path = self.generate_session_path()
|
||||
|
||||
# set global SIM_OUTPUT path
|
||||
SIM_OUTPUT.path = self.session_path / "simulation_output"
|
||||
SIM_OUTPUT.save_pcap_logs = self.settings.save_pcap_logs
|
||||
SIM_OUTPUT.save_sys_logs = self.settings.save_sys_logs
|
||||
|
||||
# warning TODO: must be careful not to re-initialise sessionIO because it will create a new path each time it's
|
||||
# possible refactor needed
|
||||
|
||||
@@ -54,7 +54,7 @@ class PrimaiteSession:
|
||||
self.policy: PolicyABC
|
||||
"""The reinforcement learning policy."""
|
||||
|
||||
self.io_manager = SessionIO()
|
||||
self.io_manager: Optional["SessionIO"] = None
|
||||
"""IO manager for the session."""
|
||||
|
||||
self.game: PrimaiteGame = game
|
||||
@@ -101,9 +101,9 @@ class PrimaiteSession:
|
||||
|
||||
# CREATE ENVIRONMENT
|
||||
if sess.training_options.rl_framework == "RLLIB_single_agent":
|
||||
sess.env = PrimaiteRayEnv(env_config={"game": game})
|
||||
sess.env = PrimaiteRayEnv(env_config={"cfg": cfg})
|
||||
elif sess.training_options.rl_framework == "RLLIB_multi_agent":
|
||||
sess.env = PrimaiteRayMARLEnv(env_config={"game": game})
|
||||
sess.env = PrimaiteRayMARLEnv(env_config={"cfg": cfg})
|
||||
elif sess.training_options.rl_framework == "SB3":
|
||||
sess.env = PrimaiteGymEnv(game=game)
|
||||
|
||||
|
||||
@@ -44,3 +44,12 @@ def run(overwrite_existing: bool = True) -> None:
|
||||
print(dst_fp)
|
||||
shutil.copy2(src_fp, dst_fp)
|
||||
_LOGGER.info(f"Reset example notebook: {dst_fp}")
|
||||
|
||||
for src_fp in primaite_root.glob("notebooks/_package_data/*"):
|
||||
dst_fp = example_notebooks_user_dir / "_package_data" / src_fp.name
|
||||
if should_copy_file(src_fp, dst_fp, overwrite_existing):
|
||||
if not Path.exists(example_notebooks_user_dir / "_package_data/"):
|
||||
Path.mkdir(example_notebooks_user_dir / "_package_data/")
|
||||
print(dst_fp)
|
||||
shutil.copy2(src_fp, dst_fp)
|
||||
_LOGGER.info(f"Copied notebook resource to: {dst_fp}")
|
||||
|
||||
@@ -7,11 +7,13 @@ from primaite import _PRIMAITE_ROOT
|
||||
__all__ = ["SIM_OUTPUT"]
|
||||
|
||||
|
||||
class __SimOutput:
|
||||
class _SimOutput:
|
||||
def __init__(self):
|
||||
self._path: Path = (
|
||||
_PRIMAITE_ROOT.parent.parent / "simulation_output" / datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
|
||||
)
|
||||
self.save_pcap_logs: bool = False
|
||||
self.save_sys_logs: bool = False
|
||||
|
||||
@property
|
||||
def path(self) -> Path:
|
||||
@@ -23,4 +25,4 @@ class __SimOutput:
|
||||
self._path.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
|
||||
SIM_OUTPUT = __SimOutput()
|
||||
SIM_OUTPUT = _SimOutput()
|
||||
|
||||
@@ -23,7 +23,6 @@ class FileSystem(SimComponent):
|
||||
"List containing all the folders in the file system."
|
||||
deleted_folders: Dict[str, Folder] = {}
|
||||
"List containing all the folders that have been deleted."
|
||||
_folders_by_name: Dict[str, Folder] = {}
|
||||
sys_log: SysLog
|
||||
"Instance of SysLog used to create system logs."
|
||||
sim_root: Path
|
||||
@@ -56,7 +55,6 @@ class FileSystem(SimComponent):
|
||||
folder = self.deleted_folders[uuid]
|
||||
self.deleted_folders.pop(uuid)
|
||||
self.folders[uuid] = folder
|
||||
self._folders_by_name[folder.name] = folder
|
||||
|
||||
# Clear any other deleted folders that aren't original (have been created by agent)
|
||||
self.deleted_folders.clear()
|
||||
@@ -67,7 +65,6 @@ class FileSystem(SimComponent):
|
||||
if uuid not in original_folder_uuids:
|
||||
folder = self.folders[uuid]
|
||||
self.folders.pop(uuid)
|
||||
self._folders_by_name.pop(folder.name)
|
||||
|
||||
# Now reset all remaining folders
|
||||
for folder in self.folders.values():
|
||||
@@ -173,7 +170,6 @@ class FileSystem(SimComponent):
|
||||
folder = Folder(name=folder_name, sys_log=self.sys_log)
|
||||
|
||||
self.folders[folder.uuid] = folder
|
||||
self._folders_by_name[folder.name] = folder
|
||||
self._folder_request_manager.add_request(
|
||||
name=folder.name, request_type=RequestType(func=folder._request_manager)
|
||||
)
|
||||
@@ -188,14 +184,13 @@ class FileSystem(SimComponent):
|
||||
if folder_name == "root":
|
||||
self.sys_log.warning("Cannot delete the root folder.")
|
||||
return
|
||||
folder = self._folders_by_name.get(folder_name)
|
||||
folder = self.get_folder(folder_name)
|
||||
if folder:
|
||||
# set folder to deleted state
|
||||
folder.delete()
|
||||
|
||||
# remove from folder list
|
||||
self.folders.pop(folder.uuid)
|
||||
self._folders_by_name.pop(folder.name)
|
||||
|
||||
# add to deleted list
|
||||
folder.remove_all_files()
|
||||
@@ -221,7 +216,10 @@ class FileSystem(SimComponent):
|
||||
:param folder_name: The folder name.
|
||||
:return: The matching Folder.
|
||||
"""
|
||||
return self._folders_by_name.get(folder_name)
|
||||
for folder in self.folders.values():
|
||||
if folder.name == folder_name:
|
||||
return folder
|
||||
return None
|
||||
|
||||
def get_folder_by_id(self, folder_uuid: str, include_deleted: bool = False) -> Optional[Folder]:
|
||||
"""
|
||||
@@ -261,13 +259,13 @@ class FileSystem(SimComponent):
|
||||
"""
|
||||
if folder_name:
|
||||
# check if file with name already exists
|
||||
folder = self._folders_by_name.get(folder_name)
|
||||
folder = self.get_folder(folder_name)
|
||||
# If not then create it
|
||||
if not folder:
|
||||
folder = self.create_folder(folder_name)
|
||||
else:
|
||||
# Use root folder if folder_name not supplied
|
||||
folder = self._folders_by_name["root"]
|
||||
folder = self.get_folder("root")
|
||||
|
||||
# Create the file and add it to the folder
|
||||
file = File(
|
||||
@@ -474,7 +472,6 @@ class FileSystem(SimComponent):
|
||||
|
||||
folder.restore()
|
||||
self.folders[folder.uuid] = folder
|
||||
self._folders_by_name[folder.name] = folder
|
||||
|
||||
if folder.deleted:
|
||||
self.deleted_folders.pop(folder.uuid)
|
||||
|
||||
@@ -87,7 +87,7 @@ class FileSystemItemABC(SimComponent):
|
||||
|
||||
def set_original_state(self):
|
||||
"""Sets the original state."""
|
||||
vals_to_keep = {"name", "health_status", "visible_health_status", "previous_hash", "revealed_to_red"}
|
||||
vals_to_keep = {"name", "health_status", "visible_health_status", "previous_hash", "revealed_to_red", "deleted"}
|
||||
self._original_state = self.model_dump(include=vals_to_keep)
|
||||
|
||||
def describe_state(self) -> Dict:
|
||||
|
||||
@@ -17,8 +17,6 @@ class Folder(FileSystemItemABC):
|
||||
|
||||
files: Dict[str, File] = {}
|
||||
"Files stored in the folder."
|
||||
_files_by_name: Dict[str, File] = {}
|
||||
"Files by their name as <file name>.<file type>."
|
||||
deleted_files: Dict[str, File] = {}
|
||||
"Files that have been deleted."
|
||||
|
||||
@@ -78,7 +76,6 @@ class Folder(FileSystemItemABC):
|
||||
file = self.deleted_files[uuid]
|
||||
self.deleted_files.pop(uuid)
|
||||
self.files[uuid] = file
|
||||
self._files_by_name[file.name] = file
|
||||
|
||||
# Clear any other deleted files that aren't original (have been created by agent)
|
||||
self.deleted_files.clear()
|
||||
@@ -89,7 +86,6 @@ class Folder(FileSystemItemABC):
|
||||
if uuid not in original_file_uuids:
|
||||
file = self.files[uuid]
|
||||
self.files.pop(uuid)
|
||||
self._files_by_name.pop(file.name)
|
||||
|
||||
# Now reset all remaining files
|
||||
for file in self.files.values():
|
||||
@@ -105,7 +101,7 @@ class Folder(FileSystemItemABC):
|
||||
self._file_request_manager = RequestManager()
|
||||
rm.add_request(
|
||||
name="file",
|
||||
request_type=RequestType(func=lambda request, context: self._file_request_manager),
|
||||
request_type=RequestType(func=self._file_request_manager),
|
||||
)
|
||||
return rm
|
||||
|
||||
@@ -219,7 +215,10 @@ class Folder(FileSystemItemABC):
|
||||
:return: The matching File.
|
||||
"""
|
||||
# TODO: Increment read count?
|
||||
return self._files_by_name.get(file_name)
|
||||
for file in self.files.values():
|
||||
if file.name == file_name:
|
||||
return file
|
||||
return None
|
||||
|
||||
def get_file_by_id(self, file_uuid: str, include_deleted: Optional[bool] = False) -> File:
|
||||
"""
|
||||
@@ -250,15 +249,14 @@ class Folder(FileSystemItemABC):
|
||||
raise Exception(f"Invalid file: {file}")
|
||||
|
||||
# check if file with id or name already exists in folder
|
||||
if (force is not True) and file.name in self._files_by_name:
|
||||
if self.get_file(file.name) is not None and not force:
|
||||
raise Exception(f"File with name {file.name} already exists in folder")
|
||||
|
||||
if (force is not True) and file.uuid in self.files:
|
||||
if (file.uuid in self.files) and not force:
|
||||
raise Exception(f"File with uuid {file.uuid} already exists in folder")
|
||||
|
||||
# add to list
|
||||
self.files[file.uuid] = file
|
||||
self._files_by_name[file.name] = file
|
||||
self._file_request_manager.add_request(file.uuid, RequestType(func=file._request_manager))
|
||||
file.folder = self
|
||||
|
||||
@@ -275,11 +273,9 @@ class Folder(FileSystemItemABC):
|
||||
|
||||
if self.files.get(file.uuid):
|
||||
self.files.pop(file.uuid)
|
||||
self._files_by_name.pop(file.name)
|
||||
self.deleted_files[file.uuid] = file
|
||||
file.delete()
|
||||
self.sys_log.info(f"Removed file {file.name} (id: {file.uuid})")
|
||||
self._file_request_manager.remove_request(file.uuid)
|
||||
else:
|
||||
_LOGGER.debug(f"File with UUID {file.uuid} was not found.")
|
||||
|
||||
@@ -300,7 +296,6 @@ class Folder(FileSystemItemABC):
|
||||
self.deleted_files[file_id] = file
|
||||
|
||||
self.files = {}
|
||||
self._files_by_name = {}
|
||||
|
||||
def restore_file(self, file_uuid: str):
|
||||
"""
|
||||
@@ -316,7 +311,6 @@ class Folder(FileSystemItemABC):
|
||||
|
||||
file.restore()
|
||||
self.files[file.uuid] = file
|
||||
self._files_by_name[file.name] = file
|
||||
|
||||
if file.deleted:
|
||||
self.deleted_files.pop(file_uuid)
|
||||
|
||||
148
src/primaite/simulator/network/creation.py
Normal file
148
src/primaite/simulator/network/creation.py
Normal file
@@ -0,0 +1,148 @@
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Optional
|
||||
|
||||
from primaite.simulator.network.container import Network
|
||||
from primaite.simulator.network.hardware.nodes.computer import Computer
|
||||
from primaite.simulator.network.hardware.nodes.router import ACLAction, Router
|
||||
from primaite.simulator.network.hardware.nodes.switch import Switch
|
||||
from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
|
||||
|
||||
def num_of_switches_required(num_nodes: int, max_switch_ports: int = 24) -> int:
|
||||
"""
|
||||
Calculate the minimum number of network switches required to connect a given number of nodes.
|
||||
|
||||
Each switch is assumed to have one port reserved for connecting to a router, reducing the effective
|
||||
number of ports available for PCs. The function calculates the total number of switches needed
|
||||
to accommodate all nodes under this constraint.
|
||||
|
||||
:param num_nodes: The total number of nodes that need to be connected in the network.
|
||||
:param max_switch_ports: The maximum number of ports available on each switch. Defaults to 24.
|
||||
|
||||
:return: The minimum number of switches required to connect all PCs.
|
||||
|
||||
Example:
|
||||
>>> num_of_switches_required(5)
|
||||
1
|
||||
>>> num_of_switches_required(24,24)
|
||||
2
|
||||
>>> num_of_switches_required(48,24)
|
||||
3
|
||||
>>> num_of_switches_required(25,10)
|
||||
3
|
||||
"""
|
||||
# Reduce the effective number of switch ports by 1 to leave space for the router
|
||||
effective_switch_ports = max_switch_ports - 1
|
||||
|
||||
# Calculate the number of fully utilised switches and any additional switch for remaining PCs
|
||||
full_switches = num_nodes // effective_switch_ports
|
||||
extra_pcs = num_nodes % effective_switch_ports
|
||||
|
||||
# Return the total number of switches required
|
||||
return full_switches + (1 if extra_pcs > 0 else 0)
|
||||
|
||||
|
||||
def create_office_lan(
|
||||
lan_name: str,
|
||||
subnet_base: int,
|
||||
pcs_ip_block_start: int,
|
||||
num_pcs: int,
|
||||
network: Optional[Network] = None,
|
||||
include_router: bool = True,
|
||||
) -> Network:
|
||||
"""
|
||||
Creates a 2-Tier or 3-Tier office local area network (LAN).
|
||||
|
||||
The LAN is configured with a specified number of personal computers (PCs), optionally including a router,
|
||||
and multiple edge switches to connect them. A core switch is added only if more than one edge switch is required.
|
||||
The network topology involves edge switches connected either directly to the router in a 2-Tier setup or
|
||||
to a core switch in a 3-Tier setup. If a router is included, it is connected to the core switch (if present)
|
||||
and configured with basic access control list (ACL) rules. PCs are distributed across the edge switches.
|
||||
|
||||
|
||||
:param str lan_name: The name to be assigned to the LAN.
|
||||
:param int subnet_base: The subnet base number to be used in the IP addresses.
|
||||
:param int pcs_ip_block_start: The starting block for assigning IP addresses to PCs.
|
||||
:param int num_pcs: The number of PCs to be added to the LAN.
|
||||
:param Optional[Network] network: The network to which the LAN components will be added. If None, a new network is
|
||||
created.
|
||||
:param bool include_router: Flag to determine if a router should be included in the LAN. Defaults to True.
|
||||
:return: The network object with the LAN components added.
|
||||
:raises ValueError: If pcs_ip_block_start is less than or equal to the number of required switches.
|
||||
"""
|
||||
# Initialise the network if not provided
|
||||
if not network:
|
||||
network = Network()
|
||||
|
||||
# Calculate the required number of switches
|
||||
num_of_switches = num_of_switches_required(num_nodes=num_pcs)
|
||||
effective_switch_ports = 23 # One port less for router connection
|
||||
if pcs_ip_block_start <= num_of_switches:
|
||||
raise ValueError(f"pcs_ip_block_start must be greater than the number of required switches {num_of_switches}")
|
||||
|
||||
# Create a core switch if more than one edge switch is needed
|
||||
if num_of_switches > 1:
|
||||
core_switch = Switch(hostname=f"switch_core_{lan_name}", start_up_duration=0)
|
||||
core_switch.power_on()
|
||||
network.add_node(core_switch)
|
||||
core_switch_port = 1
|
||||
|
||||
# Initialise the default gateway to None
|
||||
default_gateway = None
|
||||
|
||||
# Optionally include a router in the LAN
|
||||
if include_router:
|
||||
default_gateway = IPv4Address(f"192.168.{subnet_base}.1")
|
||||
router = Router(hostname=f"router_{lan_name}", start_up_duration=0)
|
||||
router.power_on()
|
||||
router.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.ARP, dst_port=Port.ARP, position=22)
|
||||
router.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol.ICMP, position=23)
|
||||
network.add_node(router)
|
||||
router.configure_port(port=1, ip_address=default_gateway, subnet_mask="255.255.255.0")
|
||||
router.enable_port(1)
|
||||
|
||||
# Initialise the first edge switch and connect to the router or core switch
|
||||
switch_port = 0
|
||||
switch_n = 1
|
||||
switch = Switch(hostname=f"switch_edge_{switch_n}_{lan_name}", start_up_duration=0)
|
||||
switch.power_on()
|
||||
network.add_node(switch)
|
||||
if num_of_switches > 1:
|
||||
network.connect(core_switch.switch_ports[core_switch_port], switch.switch_ports[24])
|
||||
else:
|
||||
network.connect(router.ethernet_ports[1], switch.switch_ports[24])
|
||||
|
||||
# Add PCs to the LAN and connect them to switches
|
||||
for i in range(1, num_pcs + 1):
|
||||
# Add a new edge switch if the current one is full
|
||||
if switch_port == effective_switch_ports:
|
||||
switch_n += 1
|
||||
switch_port = 0
|
||||
switch = Switch(hostname=f"switch_edge_{switch_n}_{lan_name}", start_up_duration=0)
|
||||
switch.power_on()
|
||||
network.add_node(switch)
|
||||
# Connect the new switch to the router or core switch
|
||||
if num_of_switches > 1:
|
||||
core_switch_port += 1
|
||||
network.connect(core_switch.switch_ports[core_switch_port], switch.switch_ports[24])
|
||||
else:
|
||||
network.connect(router.ethernet_ports[1], switch.switch_ports[24])
|
||||
|
||||
# Create and add a PC to the network
|
||||
pc = Computer(
|
||||
hostname=f"pc_{i}_{lan_name}",
|
||||
ip_address=f"192.168.{subnet_base}.{i+pcs_ip_block_start-1}",
|
||||
subnet_mask="255.255.255.0",
|
||||
default_gateway=default_gateway,
|
||||
start_up_duration=0,
|
||||
)
|
||||
pc.power_on()
|
||||
network.add_node(pc)
|
||||
|
||||
# Connect the PC to the switch
|
||||
switch_port += 1
|
||||
network.connect(switch.switch_ports[switch_port], pc.ethernet_port[1])
|
||||
switch.switch_ports[switch_port].enable()
|
||||
|
||||
return network
|
||||
@@ -4,7 +4,7 @@ import re
|
||||
import secrets
|
||||
from ipaddress import IPv4Address, IPv4Network
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Literal, Optional, Tuple, Union
|
||||
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
|
||||
|
||||
from prettytable import MARKDOWN, PrettyTable
|
||||
|
||||
@@ -274,18 +274,40 @@ class NIC(SimComponent):
|
||||
|
||||
def receive_frame(self, frame: Frame) -> bool:
|
||||
"""
|
||||
Receive a network frame from the connected link if the NIC is enabled.
|
||||
Receive a network frame from the connected link, processing it if the NIC is enabled.
|
||||
|
||||
The Frame is passed to the Node.
|
||||
This method decrements the Time To Live (TTL) of the frame, captures it using PCAP (Packet Capture), and checks
|
||||
if the frame is either a broadcast or destined for this NIC. If the frame is acceptable, it is passed to the
|
||||
connected node. The method also handles the discarding of frames with TTL expired and logs this event.
|
||||
|
||||
:param frame: The network frame being received.
|
||||
The frame's reception is based on various conditions:
|
||||
- If the NIC is disabled, the frame is not processed.
|
||||
- If the TTL of the frame reaches zero after decrement, it is discarded and logged.
|
||||
- If the frame is a broadcast or its destination MAC/IP address matches this NIC's, it is accepted.
|
||||
- All other frames are dropped and logged or printed to the console.
|
||||
|
||||
:param frame: The network frame being received. This should be an instance of the Frame class.
|
||||
:return: Returns True if the frame is processed and passed to the node, False otherwise.
|
||||
"""
|
||||
if self.enabled:
|
||||
frame.decrement_ttl()
|
||||
if frame.ip and frame.ip.ttl < 1:
|
||||
self._connected_node.sys_log.info("Frame discarded as TTL limit reached")
|
||||
return False
|
||||
frame.set_received_timestamp()
|
||||
self.pcap.capture(frame)
|
||||
# If this destination or is broadcast
|
||||
if frame.ethernet.dst_mac_addr == self.mac_address or frame.ethernet.dst_mac_addr == "ff:ff:ff:ff:ff:ff":
|
||||
accept_frame = False
|
||||
|
||||
# Check if it's a broadcast:
|
||||
if frame.ethernet.dst_mac_addr == "ff:ff:ff:ff:ff:ff":
|
||||
if frame.ip.dst_ip_address in {self.ip_address, self.ip_network.broadcast_address}:
|
||||
accept_frame = True
|
||||
else:
|
||||
if frame.ethernet.dst_mac_addr == self.mac_address:
|
||||
accept_frame = True
|
||||
|
||||
if accept_frame:
|
||||
self._connected_node.receive_frame(frame=frame, from_nic=self)
|
||||
return True
|
||||
return False
|
||||
@@ -436,6 +458,9 @@ class SwitchPort(SimComponent):
|
||||
"""
|
||||
if self.enabled:
|
||||
frame.decrement_ttl()
|
||||
if frame.ip and frame.ip.ttl < 1:
|
||||
self._connected_node.sys_log.info("Frame discarded as TTL limit reached")
|
||||
return False
|
||||
self.pcap.capture(frame)
|
||||
connected_node: Node = self._connected_node
|
||||
connected_node.forward_frame(frame=frame, incoming_port=self)
|
||||
@@ -671,17 +696,30 @@ class ARPCache:
|
||||
"""Clear the entire ARP cache, removing all stored entries."""
|
||||
self.arp.clear()
|
||||
|
||||
def send_arp_request(self, target_ip_address: Union[IPv4Address, str]):
|
||||
def send_arp_request(
|
||||
self, target_ip_address: Union[IPv4Address, str], ignore_networks: Optional[List[IPv4Address]] = None
|
||||
):
|
||||
"""
|
||||
Perform a standard ARP request for a given target IP address.
|
||||
|
||||
Broadcasts the request through all enabled NICs to determine the MAC address corresponding to the target IP
|
||||
address.
|
||||
address. This method can be configured to ignore specific networks when sending out ARP requests,
|
||||
which is useful in environments where certain addresses should not be queried.
|
||||
|
||||
:param target_ip_address: The target IP address to send an ARP request for.
|
||||
:param ignore_networks: An optional list of IPv4 addresses representing networks to be excluded from the ARP
|
||||
request broadcast. Each address in this list indicates a network which will not be queried during the ARP
|
||||
request process. This is particularly useful in complex network environments where traffic should be
|
||||
minimized or controlled to specific subnets. It is mainly used by the router to prevent ARP requests being
|
||||
sent back to their source.
|
||||
"""
|
||||
for nic in self.nics.values():
|
||||
if nic.enabled:
|
||||
use_nic = True
|
||||
if ignore_networks:
|
||||
for ipv4 in ignore_networks:
|
||||
if ipv4 in nic.ip_network:
|
||||
use_nic = False
|
||||
if nic.enabled and use_nic:
|
||||
self.sys_log.info(f"Sending ARP request from NIC {nic} for ip {target_ip_address}")
|
||||
tcp_header = TCPHeader(src_port=Port.ARP, dst_port=Port.ARP)
|
||||
|
||||
@@ -806,7 +844,6 @@ class ICMP:
|
||||
self.arp.send_arp_request(frame.ip.src_ip_address)
|
||||
self.process_icmp(frame=frame, from_nic=from_nic, is_reattempt=True)
|
||||
return
|
||||
tcp_header = TCPHeader(src_port=Port.ARP, dst_port=Port.ARP)
|
||||
|
||||
# Network Layer
|
||||
ip_packet = IPPacket(
|
||||
@@ -821,9 +858,7 @@ class ICMP:
|
||||
sequence=frame.icmp.sequence + 1,
|
||||
)
|
||||
payload = secrets.token_urlsafe(int(32 / 1.3)) # Standard ICMP 32 bytes size
|
||||
frame = Frame(
|
||||
ethernet=ethernet_header, ip=ip_packet, tcp=tcp_header, icmp=icmp_reply_packet, payload=payload
|
||||
)
|
||||
frame = Frame(ethernet=ethernet_header, ip=ip_packet, icmp=icmp_reply_packet, payload=payload)
|
||||
self.sys_log.info(f"Sending echo reply to {frame.ip.dst_ip_address}")
|
||||
|
||||
src_nic.send_frame(frame)
|
||||
@@ -1275,8 +1310,8 @@ class Node(SimComponent):
|
||||
self.start_up_countdown = self.start_up_duration
|
||||
|
||||
if self.start_up_duration <= 0:
|
||||
self._start_up_actions()
|
||||
self.operating_state = NodeOperatingState.ON
|
||||
self._start_up_actions()
|
||||
self.sys_log.info("Turned on")
|
||||
for nic in self.nics.values():
|
||||
if nic._connected_link:
|
||||
@@ -1450,7 +1485,7 @@ class Node(SimComponent):
|
||||
service.parent = self
|
||||
service.install() # Perform any additional setup, such as creating files for this service on the node.
|
||||
self.sys_log.info(f"Installed service {service.name}")
|
||||
_LOGGER.info(f"Added service {service.name} to node {self.hostname}")
|
||||
_LOGGER.debug(f"Added service {service.name} to node {self.hostname}")
|
||||
self._service_request_manager.add_request(service.name, RequestType(func=service._request_manager))
|
||||
|
||||
def uninstall_service(self, service: Service) -> None:
|
||||
@@ -1485,7 +1520,7 @@ class Node(SimComponent):
|
||||
self.applications[application.uuid] = application
|
||||
application.parent = self
|
||||
self.sys_log.info(f"Installed application {application.name}")
|
||||
_LOGGER.info(f"Added application {application.name} to node {self.hostname}")
|
||||
_LOGGER.debug(f"Added application {application.name} to node {self.hostname}")
|
||||
self._application_request_manager.add_request(application.name, RequestType(func=application._request_manager))
|
||||
|
||||
def uninstall_application(self, application: Application) -> None:
|
||||
|
||||
@@ -9,6 +9,7 @@ from prettytable import MARKDOWN, PrettyTable
|
||||
|
||||
from primaite.simulator.core import RequestManager, RequestType, SimComponent
|
||||
from primaite.simulator.network.hardware.base import ARPCache, ICMP, NIC, Node
|
||||
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
|
||||
from primaite.simulator.network.transmission.data_link_layer import EthernetHeader, Frame
|
||||
from primaite.simulator.network.transmission.network_layer import ICMPPacket, ICMPType, IPPacket, IPProtocol
|
||||
from primaite.simulator.network.transmission.transport_layer import Port, TCPHeader
|
||||
@@ -18,8 +19,8 @@ from primaite.simulator.system.core.sys_log import SysLog
|
||||
class ACLAction(Enum):
|
||||
"""Enum for defining the ACL action types."""
|
||||
|
||||
DENY = 0
|
||||
PERMIT = 1
|
||||
DENY = 2
|
||||
|
||||
|
||||
class ACLRule(SimComponent):
|
||||
@@ -65,11 +66,11 @@ class ACLRule(SimComponent):
|
||||
"""
|
||||
state = super().describe_state()
|
||||
state["action"] = self.action.value
|
||||
state["protocol"] = self.protocol.value if self.protocol else None
|
||||
state["protocol"] = self.protocol.name if self.protocol else None
|
||||
state["src_ip_address"] = str(self.src_ip_address) if self.src_ip_address else None
|
||||
state["src_port"] = self.src_port.value if self.src_port else None
|
||||
state["src_port"] = self.src_port.name if self.src_port else None
|
||||
state["dst_ip_address"] = str(self.dst_ip_address) if self.dst_ip_address else None
|
||||
state["dst_port"] = self.dst_port.value if self.dst_port else None
|
||||
state["dst_port"] = self.dst_port.name if self.dst_port else None
|
||||
return state
|
||||
|
||||
|
||||
@@ -89,6 +90,8 @@ class AccessControlList(SimComponent):
|
||||
implicit_rule: ACLRule
|
||||
max_acl_rules: int = 25
|
||||
_acl: List[Optional[ACLRule]] = [None] * 24
|
||||
_default_config: Dict[int, dict] = {}
|
||||
"""Config dict describing how the ACL list should look at episode start"""
|
||||
|
||||
def __init__(self, **kwargs) -> None:
|
||||
if not kwargs.get("implicit_action"):
|
||||
@@ -106,10 +109,40 @@ class AccessControlList(SimComponent):
|
||||
vals_to_keep = {"implicit_action", "max_acl_rules", "acl"}
|
||||
self._original_state = self.model_dump(include=vals_to_keep, exclude_none=True)
|
||||
|
||||
for i, rule in enumerate(self._acl):
|
||||
if not rule:
|
||||
continue
|
||||
self._default_config[i] = {"action": rule.action.name}
|
||||
if rule.src_ip_address:
|
||||
self._default_config[i]["src_ip"] = str(rule.src_ip_address)
|
||||
if rule.dst_ip_address:
|
||||
self._default_config[i]["dst_ip"] = str(rule.dst_ip_address)
|
||||
if rule.src_port:
|
||||
self._default_config[i]["src_port"] = rule.src_port.name
|
||||
if rule.dst_port:
|
||||
self._default_config[i]["dst_port"] = rule.dst_port.name
|
||||
if rule.protocol:
|
||||
self._default_config[i]["protocol"] = rule.protocol.name
|
||||
|
||||
def reset_component_for_episode(self, episode: int):
|
||||
"""Reset the original state of the SimComponent."""
|
||||
self.implicit_rule.reset_component_for_episode(episode)
|
||||
super().reset_component_for_episode(episode)
|
||||
self._reset_rules_to_default()
|
||||
|
||||
def _reset_rules_to_default(self) -> None:
|
||||
"""Clear all ACL rules and set them to the default rules config."""
|
||||
self._acl = [None] * (self.max_acl_rules - 1)
|
||||
for r_num, r_cfg in self._default_config.items():
|
||||
self.add_rule(
|
||||
action=ACLAction[r_cfg["action"]],
|
||||
src_port=None if not (p := r_cfg.get("src_port")) else Port[p],
|
||||
dst_port=None if not (p := r_cfg.get("dst_port")) else Port[p],
|
||||
protocol=None if not (p := r_cfg.get("protocol")) else IPProtocol[p],
|
||||
src_ip_address=r_cfg.get("src_ip"),
|
||||
dst_ip_address=r_cfg.get("dst_ip"),
|
||||
position=r_num,
|
||||
)
|
||||
|
||||
def _init_request_manager(self) -> RequestManager:
|
||||
rm = super()._init_request_manager()
|
||||
@@ -129,9 +162,9 @@ class AccessControlList(SimComponent):
|
||||
func=lambda request, context: self.add_rule(
|
||||
ACLAction[request[0]],
|
||||
None if request[1] == "ALL" else IPProtocol[request[1]],
|
||||
IPv4Address(request[2]),
|
||||
None if request[2] == "ALL" else IPv4Address(request[2]),
|
||||
None if request[3] == "ALL" else Port[request[3]],
|
||||
IPv4Address(request[4]),
|
||||
None if request[4] == "ALL" else IPv4Address(request[4]),
|
||||
None if request[5] == "ALL" else Port[request[5]],
|
||||
int(request[6]),
|
||||
)
|
||||
@@ -333,11 +366,10 @@ class RouteEntry(SimComponent):
|
||||
"""
|
||||
Represents a single entry in a routing table.
|
||||
|
||||
Attributes:
|
||||
address (IPv4Address): The destination IP address or network address.
|
||||
subnet_mask (IPv4Address): The subnet mask for the network.
|
||||
next_hop_ip_address (IPv4Address): The next hop IP address to which packets should be forwarded.
|
||||
metric (int): The cost metric for this route. Default is 0.0.
|
||||
:ivar address: The destination IP address or network address.
|
||||
:ivar subnet_mask: The subnet mask for the network.
|
||||
:ivar next_hop_ip_address: The next hop IP address to which packets should be forwarded.
|
||||
:ivar metric: The cost metric for this route. Default is 0.0.
|
||||
|
||||
Example:
|
||||
>>> entry = RouteEntry(
|
||||
@@ -357,12 +389,6 @@ class RouteEntry(SimComponent):
|
||||
metric: float = 0.0
|
||||
"The cost metric for this route. Default is 0.0."
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
for key in {"address", "subnet_mask", "next_hop_ip_address"}:
|
||||
if not isinstance(kwargs[key], IPv4Address):
|
||||
kwargs[key] = IPv4Address(kwargs[key])
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def set_original_state(self):
|
||||
"""Sets the original state."""
|
||||
vals_to_include = {"address", "subnet_mask", "next_hop_ip_address", "metric"}
|
||||
@@ -397,10 +423,10 @@ class RouteTable(SimComponent):
|
||||
"""
|
||||
|
||||
routes: List[RouteEntry] = []
|
||||
default_route: Optional[RouteEntry] = None
|
||||
sys_log: SysLog
|
||||
|
||||
def set_original_state(self):
|
||||
"""Sets the original state."""
|
||||
"""Sets the original state."""
|
||||
super().set_original_state()
|
||||
self._original_state["routes_orig"] = self.routes
|
||||
@@ -442,12 +468,35 @@ class RouteTable(SimComponent):
|
||||
)
|
||||
self.routes.append(route)
|
||||
|
||||
def set_default_route_next_hop_ip_address(self, ip_address: IPv4Address):
|
||||
"""
|
||||
Sets the next-hop IP address for the default route in a routing table.
|
||||
|
||||
This method checks if a default route (0.0.0.0/0) exists in the routing table. If it does not exist,
|
||||
the method creates a new default route with the specified next-hop IP address. If a default route already
|
||||
exists, it updates the next-hop IP address of the existing default route. After setting the next-hop
|
||||
IP address, the method logs this action.
|
||||
|
||||
:param ip_address: The next-hop IP address to be set for the default route.
|
||||
"""
|
||||
if not self.default_route:
|
||||
self.default_route = RouteEntry(
|
||||
ip_address=IPv4Address("0.0.0.0"),
|
||||
subnet_mask=IPv4Address("0.0.0.0"),
|
||||
next_hop_ip_address=ip_address,
|
||||
)
|
||||
else:
|
||||
self.default_route.next_hop_ip_address = ip_address
|
||||
self.sys_log.info(f"Default configured to use {ip_address} as the next-hop")
|
||||
|
||||
def find_best_route(self, destination_ip: Union[str, IPv4Address]) -> Optional[RouteEntry]:
|
||||
"""
|
||||
Find the best route for a given destination IP.
|
||||
|
||||
This method uses the Longest Prefix Match algorithm and considers metrics to find the best route.
|
||||
|
||||
If no dedicated route exists but a default route does, then the default route is returned as a last resort.
|
||||
|
||||
:param destination_ip: The destination IP to find the route for.
|
||||
:return: The best matching RouteEntry, or None if no route matches.
|
||||
"""
|
||||
@@ -467,6 +516,9 @@ class RouteTable(SimComponent):
|
||||
longest_prefix = prefix_len
|
||||
lowest_metric = route.metric
|
||||
|
||||
if not best_route and self.default_route:
|
||||
best_route = self.default_route
|
||||
|
||||
return best_route
|
||||
|
||||
def show(self, markdown: bool = False):
|
||||
@@ -498,12 +550,26 @@ class RouterARPCache(ARPCache):
|
||||
super().__init__(sys_log)
|
||||
self.router: Router = router
|
||||
|
||||
def process_arp_packet(self, from_nic: NIC, frame: Frame):
|
||||
def process_arp_packet(
|
||||
self, from_nic: NIC, frame: Frame, route_table: RouteTable, is_reattempt: bool = False
|
||||
) -> None:
|
||||
"""
|
||||
Overridden method to process a received ARP packet in a router-specific way.
|
||||
Processes a received ARP (Address Resolution Protocol) packet in a router-specific way.
|
||||
|
||||
This method is responsible for handling both ARP requests and responses. It processes ARP packets received on a
|
||||
Network Interface Card (NIC) and performs actions based on whether the packet is a request or a reply. This
|
||||
includes updating the ARP cache, forwarding ARP replies, sending ARP requests for unknown destinations, and
|
||||
handling packet TTL (Time To Live).
|
||||
|
||||
The method first checks if the ARP packet is a request or a reply. For ARP replies, it updates the ARP cache
|
||||
and forwards the reply if necessary. For ARP requests, it checks if the target IP matches one of the router's
|
||||
NICs and sends an ARP reply if so. If the destination is not directly connected, it consults the routing table
|
||||
to find the best route and reattempts ARP request processing if needed.
|
||||
|
||||
:param from_nic: The NIC that received the ARP packet.
|
||||
:param frame: The original ARP frame.
|
||||
:param frame: The frame containing the ARP packet.
|
||||
:param route_table: The routing table of the router.
|
||||
:param is_reattempt: Flag to indicate if this is a reattempt of processing the ARP packet, defaults to False.
|
||||
"""
|
||||
arp_packet = frame.arp
|
||||
|
||||
@@ -531,7 +597,11 @@ class RouterARPCache(ARPCache):
|
||||
)
|
||||
arp_packet.sender_mac_addr = nic.mac_address
|
||||
frame.decrement_ttl()
|
||||
if frame.ip and frame.ip.ttl < 1:
|
||||
self.sys_log.info("Frame discarded as TTL limit reached")
|
||||
return
|
||||
nic.send_frame(frame)
|
||||
return
|
||||
|
||||
# ARP Request
|
||||
self.sys_log.info(
|
||||
@@ -542,16 +612,32 @@ class RouterARPCache(ARPCache):
|
||||
self.add_arp_cache_entry(
|
||||
ip_address=arp_packet.sender_ip_address, mac_address=arp_packet.sender_mac_addr, nic=from_nic
|
||||
)
|
||||
arp_packet = arp_packet.generate_reply(from_nic.mac_address)
|
||||
self.send_arp_reply(arp_packet, from_nic)
|
||||
|
||||
# If the target IP matches one of the router's NICs
|
||||
for nic in self.nics.values():
|
||||
if nic.enabled and nic.ip_address == arp_packet.target_ip_address:
|
||||
if arp_packet.target_ip_address in nic.ip_network:
|
||||
# if nic.enabled and nic.ip_address == arp_packet.target_ip_address:
|
||||
arp_reply = arp_packet.generate_reply(from_nic.mac_address)
|
||||
self.send_arp_reply(arp_reply, from_nic)
|
||||
return
|
||||
|
||||
# Check Route Table
|
||||
route = route_table.find_best_route(arp_packet.target_ip_address)
|
||||
if route:
|
||||
nic = self.get_arp_cache_nic(route.next_hop_ip_address)
|
||||
|
||||
if not nic:
|
||||
if not is_reattempt:
|
||||
self.send_arp_request(route.next_hop_ip_address, ignore_networks=[frame.ip.src_ip_address])
|
||||
return self.process_arp_packet(from_nic, frame, route_table, is_reattempt=True)
|
||||
else:
|
||||
self.sys_log.info("Ignoring ARP request as destination unavailable/No ARP entry found")
|
||||
return
|
||||
else:
|
||||
arp_reply = arp_packet.generate_reply(from_nic.mac_address)
|
||||
self.send_arp_reply(arp_reply, from_nic)
|
||||
return
|
||||
|
||||
|
||||
class RouterICMP(ICMP):
|
||||
"""
|
||||
@@ -622,7 +708,7 @@ class RouterICMP(ICMP):
|
||||
return
|
||||
|
||||
# Route the frame
|
||||
self.router.route_frame(frame, from_nic)
|
||||
self.router.process_frame(frame, from_nic)
|
||||
|
||||
elif frame.icmp.icmp_type == ICMPType.ECHO_REPLY:
|
||||
for nic in self.router.nics.values():
|
||||
@@ -642,7 +728,48 @@ class RouterICMP(ICMP):
|
||||
|
||||
return
|
||||
# Route the frame
|
||||
self.router.route_frame(frame, from_nic)
|
||||
self.router.process_frame(frame, from_nic)
|
||||
|
||||
|
||||
class RouterNIC(NIC):
|
||||
"""
|
||||
A Router-specific Network Interface Card (NIC) that extends the standard NIC functionality.
|
||||
|
||||
This class overrides the standard Node NIC's Layer 3 (L3) broadcast/unicast checks. It is designed
|
||||
to handle network frames in a manner specific to routers, allowing them to efficiently process
|
||||
and route network traffic.
|
||||
"""
|
||||
|
||||
def receive_frame(self, frame: Frame) -> bool:
|
||||
"""
|
||||
Receive and process a network frame from the connected link, provided the NIC is enabled.
|
||||
|
||||
This method is tailored for router behavior. It decrements the frame's Time To Live (TTL), checks for TTL
|
||||
expiration, and captures the frame using PCAP (Packet Capture). The frame is accepted if it is destined for
|
||||
this NIC's MAC address or is a broadcast frame.
|
||||
|
||||
Key Differences from Standard NIC:
|
||||
- Does not perform Layer 3 (IP-based) broadcast checks.
|
||||
- Only checks for Layer 2 (Ethernet) destination MAC address and broadcast frames.
|
||||
|
||||
:param frame: The network frame being received. This should be an instance of the Frame class.
|
||||
:return: Returns True if the frame is processed and passed to the connected node, False otherwise.
|
||||
"""
|
||||
if self.enabled:
|
||||
frame.decrement_ttl()
|
||||
if frame.ip and frame.ip.ttl < 1:
|
||||
self._connected_node.sys_log.info("Frame discarded as TTL limit reached")
|
||||
return False
|
||||
frame.set_received_timestamp()
|
||||
self.pcap.capture(frame)
|
||||
# If this destination or is broadcast
|
||||
if frame.ethernet.dst_mac_addr == self.mac_address or frame.ethernet.dst_mac_addr == "ff:ff:ff:ff:ff:ff":
|
||||
self._connected_node.receive_frame(frame=frame, from_nic=self)
|
||||
return True
|
||||
return False
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"{self.mac_address}/{self.ip_address}"
|
||||
|
||||
|
||||
class Router(Node):
|
||||
@@ -655,7 +782,7 @@ class Router(Node):
|
||||
"""
|
||||
|
||||
num_ports: int
|
||||
ethernet_ports: Dict[int, NIC] = {}
|
||||
ethernet_ports: Dict[int, RouterNIC] = {}
|
||||
acl: AccessControlList
|
||||
route_table: RouteTable
|
||||
arp: RouterARPCache
|
||||
@@ -674,7 +801,7 @@ class Router(Node):
|
||||
kwargs["icmp"] = RouterICMP(sys_log=kwargs.get("sys_log"), arp_cache=kwargs.get("arp"), router=self)
|
||||
super().__init__(hostname=hostname, num_ports=num_ports, **kwargs)
|
||||
for i in range(1, self.num_ports + 1):
|
||||
nic = NIC(ip_address="127.0.0.1", subnet_mask="255.0.0.0", gateway="0.0.0.0")
|
||||
nic = RouterNIC(ip_address="127.0.0.1", subnet_mask="255.0.0.0", gateway="0.0.0.0")
|
||||
self.connect_nic(nic)
|
||||
self.ethernet_ports[i] = nic
|
||||
|
||||
@@ -725,13 +852,13 @@ class Router(Node):
|
||||
:return: A dictionary representing the current state.
|
||||
"""
|
||||
state = super().describe_state()
|
||||
state["num_ports"] = (self.num_ports,)
|
||||
state["acl"] = (self.acl.describe_state(),)
|
||||
state["num_ports"] = self.num_ports
|
||||
state["acl"] = self.acl.describe_state()
|
||||
return state
|
||||
|
||||
def route_frame(self, frame: Frame, from_nic: NIC, re_attempt: bool = False) -> None:
|
||||
def process_frame(self, frame: Frame, from_nic: NIC, re_attempt: bool = False) -> None:
|
||||
"""
|
||||
Route a given frame from a source NIC to its destination.
|
||||
Process a Frame.
|
||||
|
||||
:param frame: The frame to be routed.
|
||||
:param from_nic: The source network interface.
|
||||
@@ -746,25 +873,57 @@ class Router(Node):
|
||||
return
|
||||
|
||||
if not nic:
|
||||
self.arp.send_arp_request(frame.ip.dst_ip_address)
|
||||
return self.route_frame(frame=frame, from_nic=from_nic, re_attempt=True)
|
||||
self.arp.send_arp_request(
|
||||
frame.ip.dst_ip_address, ignore_networks=[frame.ip.src_ip_address, from_nic.ip_address]
|
||||
)
|
||||
return self.process_frame(frame=frame, from_nic=from_nic, re_attempt=True)
|
||||
|
||||
if not nic.enabled:
|
||||
# TODO: Add sys_log here
|
||||
self.sys_log.info(f"Frame dropped as NIC {nic} is not enabled")
|
||||
return
|
||||
|
||||
if frame.ip.dst_ip_address in nic.ip_network:
|
||||
from_port = self._get_port_of_nic(from_nic)
|
||||
to_port = self._get_port_of_nic(nic)
|
||||
self.sys_log.info(f"Routing frame to internally from port {from_port} to port {to_port}")
|
||||
self.sys_log.info(f"Forwarding frame to internally from port {from_port} to port {to_port}")
|
||||
frame.decrement_ttl()
|
||||
if frame.ip and frame.ip.ttl < 1:
|
||||
self.sys_log.info("Frame discarded as TTL limit reached")
|
||||
return
|
||||
frame.ethernet.src_mac_addr = nic.mac_address
|
||||
frame.ethernet.dst_mac_addr = target_mac
|
||||
nic.send_frame(frame)
|
||||
return
|
||||
else:
|
||||
pass
|
||||
# TODO: Deal with routing from route tables
|
||||
self._route_frame(frame, from_nic)
|
||||
|
||||
def _route_frame(self, frame: Frame, from_nic: NIC, re_attempt: bool = False) -> None:
|
||||
route = self.route_table.find_best_route(frame.ip.dst_ip_address)
|
||||
if route:
|
||||
nic = self.arp.get_arp_cache_nic(route.next_hop_ip_address)
|
||||
target_mac = self.arp.get_arp_cache_mac_address(route.next_hop_ip_address)
|
||||
if re_attempt and not nic:
|
||||
self.sys_log.info(f"Destination {frame.ip.dst_ip_address} is unreachable")
|
||||
return
|
||||
|
||||
if not nic:
|
||||
self.arp.send_arp_request(frame.ip.dst_ip_address, ignore_networks=[frame.ip.src_ip_address])
|
||||
return self.process_frame(frame=frame, from_nic=from_nic, re_attempt=True)
|
||||
|
||||
if not nic.enabled:
|
||||
self.sys_log.info(f"Frame dropped as NIC {nic} is not enabled")
|
||||
return
|
||||
|
||||
from_port = self._get_port_of_nic(from_nic)
|
||||
to_port = self._get_port_of_nic(nic)
|
||||
self.sys_log.info(f"Routing frame to internally from port {from_port} to port {to_port}")
|
||||
frame.decrement_ttl()
|
||||
if frame.ip and frame.ip.ttl < 1:
|
||||
self.sys_log.info("Frame discarded as TTL limit reached")
|
||||
return
|
||||
frame.ethernet.src_mac_addr = nic.mac_address
|
||||
frame.ethernet.dst_mac_addr = target_mac
|
||||
nic.send_frame(frame)
|
||||
|
||||
def receive_frame(self, frame: Frame, from_nic: NIC):
|
||||
"""
|
||||
@@ -773,7 +932,7 @@ class Router(Node):
|
||||
:param frame: The incoming frame.
|
||||
:param from_nic: The network interface where the frame is coming from.
|
||||
"""
|
||||
route_frame = False
|
||||
process_frame = False
|
||||
protocol = frame.ip.protocol
|
||||
src_ip_address = frame.ip.src_ip_address
|
||||
dst_ip_address = frame.ip.dst_ip_address
|
||||
@@ -805,12 +964,12 @@ class Router(Node):
|
||||
self.icmp.process_icmp(frame=frame, from_nic=from_nic)
|
||||
else:
|
||||
if src_port == Port.ARP:
|
||||
self.arp.process_arp_packet(from_nic=from_nic, frame=frame)
|
||||
self.arp.process_arp_packet(from_nic=from_nic, frame=frame, route_table=self.route_table)
|
||||
else:
|
||||
# All other traffic
|
||||
route_frame = True
|
||||
if route_frame:
|
||||
self.route_frame(frame, from_nic)
|
||||
process_frame = True
|
||||
if process_frame:
|
||||
self.process_frame(frame, from_nic)
|
||||
|
||||
def configure_port(self, port: int, ip_address: Union[IPv4Address, str], subnet_mask: Union[IPv4Address, str]):
|
||||
"""
|
||||
@@ -873,3 +1032,63 @@ class Router(Node):
|
||||
]
|
||||
)
|
||||
print(table)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, cfg: dict) -> "Router":
|
||||
"""Create a router based on a config dict.
|
||||
|
||||
Schema:
|
||||
- hostname (str): unique name for this router.
|
||||
- num_ports (int, optional): Number of network ports on the router. 8 by default
|
||||
- ports (dict): Dict with integers from 1 - num_ports as keys. The values should be another dict specifying
|
||||
ip_address and subnet_mask assigned to that ports (as strings)
|
||||
- acl (dict): Dict with integers from 1 - max_acl_rules as keys. The key defines the position within the ACL
|
||||
where the rule will be added (lower number is resolved first). The values should describe valid ACL
|
||||
Rules as:
|
||||
- action (str): either PERMIT or DENY
|
||||
- src_port (str, optional): the named port such as HTTP, HTTPS, or POSTGRES_SERVER
|
||||
- dst_port (str, optional): the named port such as HTTP, HTTPS, or POSTGRES_SERVER
|
||||
- protocol (str, optional): the named IP protocol such as ICMP, TCP, or UDP
|
||||
- src_ip_address (str, optional): IP address octet written in base 10
|
||||
- dst_ip_address (str, optional): IP address octet written in base 10
|
||||
|
||||
Example config:
|
||||
```
|
||||
{
|
||||
'hostname': 'router_1',
|
||||
'num_ports': 5,
|
||||
'ports': {
|
||||
1: {
|
||||
'ip_address' : '192.168.1.1',
|
||||
'subnet_mask' : '255.255.255.0',
|
||||
}
|
||||
},
|
||||
'acl' : {
|
||||
21: {'action': 'PERMIT', 'src_port': 'HTTP', dst_port: 'HTTP'},
|
||||
22: {'action': 'PERMIT', 'src_port': 'ARP', 'dst_port': 'ARP'},
|
||||
23: {'action': 'PERMIT', 'protocol': 'ICMP'},
|
||||
},
|
||||
}
|
||||
```
|
||||
|
||||
:param cfg: Router config adhering to schema described in main docstring body
|
||||
:type cfg: dict
|
||||
:return: Configured router.
|
||||
:rtype: Router
|
||||
"""
|
||||
new = Router(
|
||||
hostname=cfg["hostname"],
|
||||
num_ports=cfg.get("num_ports"),
|
||||
operating_state=NodeOperatingState.ON,
|
||||
)
|
||||
if "ports" in cfg:
|
||||
for port_num, port_cfg in cfg["ports"].items():
|
||||
new.configure_port(
|
||||
port=port_num,
|
||||
ip_address=port_cfg["ip_address"],
|
||||
subnet_mask=port_cfg["subnet_mask"],
|
||||
)
|
||||
if "acl" in cfg:
|
||||
new.acl._default_config = cfg["acl"] # save the config to allow resetting
|
||||
new.acl._reset_rules_to_default() # read the config and apply rules
|
||||
return new
|
||||
|
||||
@@ -90,12 +90,12 @@ class Switch(Node):
|
||||
self._add_mac_table_entry(src_mac, incoming_port)
|
||||
|
||||
outgoing_port = self.mac_address_table.get(dst_mac)
|
||||
if outgoing_port or dst_mac != "ff:ff:ff:ff:ff:ff":
|
||||
if outgoing_port and dst_mac.lower() != "ff:ff:ff:ff:ff:ff":
|
||||
outgoing_port.send_frame(frame)
|
||||
else:
|
||||
# If the destination MAC is not in the table, flood to all ports except incoming
|
||||
for port in self.switch_ports.values():
|
||||
if port != incoming_port:
|
||||
if port.enabled and port != incoming_port:
|
||||
port.send_frame(frame)
|
||||
|
||||
def disconnect_link_from_port(self, link: Link, port_number: int):
|
||||
|
||||
@@ -38,9 +38,6 @@ class Application(IOSoftware):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.health_state_visible = SoftwareHealthState.UNUSED
|
||||
self.health_state_actual = SoftwareHealthState.UNUSED
|
||||
|
||||
def set_original_state(self):
|
||||
"""Sets the original state."""
|
||||
super().set_original_state()
|
||||
@@ -95,6 +92,9 @@ class Application(IOSoftware):
|
||||
if self.operating_state == ApplicationOperatingState.CLOSED:
|
||||
self.sys_log.info(f"Running Application {self.name}")
|
||||
self.operating_state = ApplicationOperatingState.RUNNING
|
||||
# set software health state to GOOD if initially set to UNUSED
|
||||
if self.health_state_actual == SoftwareHealthState.UNUSED:
|
||||
self.set_health_state(SoftwareHealthState.GOOD)
|
||||
|
||||
def _application_loop(self):
|
||||
"""The main application loop."""
|
||||
|
||||
@@ -72,7 +72,7 @@ class DataManipulationBot(DatabaseClient):
|
||||
def _init_request_manager(self) -> RequestManager:
|
||||
rm = super()._init_request_manager()
|
||||
|
||||
rm.add_request(name="execute", request_type=RequestType(func=lambda request, context: self.run()))
|
||||
rm.add_request(name="execute", request_type=RequestType(func=lambda request, context: self.attack()))
|
||||
|
||||
return rm
|
||||
|
||||
@@ -83,7 +83,7 @@ class DataManipulationBot(DatabaseClient):
|
||||
payload: Optional[str] = None,
|
||||
port_scan_p_of_success: float = 0.1,
|
||||
data_manipulation_p_of_success: float = 0.1,
|
||||
repeat: bool = False,
|
||||
repeat: bool = True,
|
||||
):
|
||||
"""
|
||||
Configure the DataManipulatorBot to communicate with a DatabaseService.
|
||||
@@ -168,6 +168,12 @@ class DataManipulationBot(DatabaseClient):
|
||||
Calls the parent classes execute method before starting the application loop.
|
||||
"""
|
||||
super().run()
|
||||
|
||||
def attack(self):
|
||||
"""Perform the attack steps after opening the application."""
|
||||
if not self._can_perform_action():
|
||||
_LOGGER.debug("Data manipulation application attempted to execute but it cannot perform actions right now.")
|
||||
self.run()
|
||||
self._application_loop()
|
||||
|
||||
def _application_loop(self):
|
||||
@@ -198,4 +204,4 @@ class DataManipulationBot(DatabaseClient):
|
||||
|
||||
:param timestep: The timestep value to update the bot's state.
|
||||
"""
|
||||
self._application_loop()
|
||||
pass
|
||||
|
||||
@@ -41,6 +41,9 @@ class PacketCapture:
|
||||
|
||||
def setup_logger(self):
|
||||
"""Set up the logger configuration."""
|
||||
if not SIM_OUTPUT.save_pcap_logs:
|
||||
return
|
||||
|
||||
log_path = self._get_log_path()
|
||||
|
||||
file_handler = logging.FileHandler(filename=log_path)
|
||||
@@ -88,5 +91,6 @@ class PacketCapture:
|
||||
|
||||
:param frame: The PCAP frame to capture.
|
||||
"""
|
||||
msg = frame.model_dump_json()
|
||||
self.logger.log(level=60, msg=msg) # Log at custom log level > CRITICAL
|
||||
if SIM_OUTPUT.save_pcap_logs:
|
||||
msg = frame.model_dump_json()
|
||||
self.logger.log(level=60, msg=msg) # Log at custom log level > CRITICAL
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from ipaddress import IPv4Address
|
||||
from ipaddress import IPv4Address, IPv4Network
|
||||
from typing import Any, Dict, Optional, Tuple, TYPE_CHECKING, Union
|
||||
|
||||
from prettytable import MARKDOWN, PrettyTable
|
||||
@@ -141,41 +141,76 @@ class SessionManager:
|
||||
def receive_payload_from_software_manager(
|
||||
self,
|
||||
payload: Any,
|
||||
dst_ip_address: Optional[IPv4Address] = None,
|
||||
dst_ip_address: Optional[Union[IPv4Address, IPv4Network]] = None,
|
||||
dst_port: Optional[Port] = None,
|
||||
session_id: Optional[str] = None,
|
||||
is_reattempt: bool = False,
|
||||
) -> Union[Any, None]:
|
||||
"""
|
||||
Receive a payload from the SoftwareManager.
|
||||
Receive a payload from the SoftwareManager and send it to the appropriate NIC for transmission.
|
||||
|
||||
If no session_id, a Session is established. Once established, the payload is sent to ``send_payload_to_nic``.
|
||||
This method supports both unicast and Layer 3 broadcast transmissions. If `dst_ip_address` is an
|
||||
IPv4Network, a broadcast is initiated. For unicast, the destination MAC address is resolved via ARP.
|
||||
A new session is established if `session_id` is not provided, and an existing session is used otherwise.
|
||||
|
||||
:param payload: The payload to be sent.
|
||||
:param session_id: The Session ID the payload is to originate from. Optional. If None, one will be created.
|
||||
:param dst_ip_address: The destination IP address or network for broadcast. Optional.
|
||||
:param dst_port: The destination port for the TCP packet. Optional.
|
||||
:param session_id: The Session ID from which the payload originates. Optional.
|
||||
:param is_reattempt: Flag to indicate if this is a reattempt after an ARP request. Default is False.
|
||||
:return: The outcome of sending the frame, or None if sending was unsuccessful.
|
||||
"""
|
||||
is_broadcast = False
|
||||
outbound_nic = None
|
||||
dst_mac_address = None
|
||||
|
||||
# Use session details if session_id is provided
|
||||
if session_id:
|
||||
session = self.sessions_by_uuid[session_id]
|
||||
dst_ip_address = self.sessions_by_uuid[session_id].with_ip_address
|
||||
dst_port = self.sessions_by_uuid[session_id].dst_port
|
||||
dst_ip_address = session.with_ip_address
|
||||
dst_port = session.dst_port
|
||||
|
||||
dst_mac_address = self.arp_cache.get_arp_cache_mac_address(dst_ip_address)
|
||||
# Determine if the payload is for broadcast or unicast
|
||||
|
||||
if dst_mac_address:
|
||||
outbound_nic = self.arp_cache.get_arp_cache_nic(dst_ip_address)
|
||||
# Handle broadcast transmission
|
||||
if isinstance(dst_ip_address, IPv4Network):
|
||||
is_broadcast = True
|
||||
dst_ip_address = dst_ip_address.broadcast_address
|
||||
if dst_ip_address:
|
||||
# Find a suitable NIC for the broadcast
|
||||
for nic in self.arp_cache.nics.values():
|
||||
if dst_ip_address in nic.ip_network and nic.enabled:
|
||||
dst_mac_address = "ff:ff:ff:ff:ff:ff"
|
||||
outbound_nic = nic
|
||||
else:
|
||||
if not is_reattempt:
|
||||
self.arp_cache.send_arp_request(dst_ip_address)
|
||||
return self.receive_payload_from_software_manager(
|
||||
payload=payload,
|
||||
dst_ip_address=dst_ip_address,
|
||||
dst_port=dst_port,
|
||||
session_id=session_id,
|
||||
is_reattempt=True,
|
||||
)
|
||||
else:
|
||||
return
|
||||
# Resolve MAC address for unicast transmission
|
||||
dst_mac_address = self.arp_cache.get_arp_cache_mac_address(dst_ip_address)
|
||||
|
||||
# Resolve outbound NIC for unicast transmission
|
||||
if dst_mac_address:
|
||||
outbound_nic = self.arp_cache.get_arp_cache_nic(dst_ip_address)
|
||||
|
||||
# If MAC address not found, initiate ARP request
|
||||
else:
|
||||
if not is_reattempt:
|
||||
self.arp_cache.send_arp_request(dst_ip_address)
|
||||
# Reattempt payload transmission after ARP request
|
||||
return self.receive_payload_from_software_manager(
|
||||
payload=payload,
|
||||
dst_ip_address=dst_ip_address,
|
||||
dst_port=dst_port,
|
||||
session_id=session_id,
|
||||
is_reattempt=True,
|
||||
)
|
||||
else:
|
||||
# Return None if reattempt fails
|
||||
return
|
||||
|
||||
# Check if outbound NIC and destination MAC address are resolved
|
||||
if not outbound_nic or not dst_mac_address:
|
||||
return False
|
||||
|
||||
# Construct the frame for transmission
|
||||
frame = Frame(
|
||||
ethernet=EthernetHeader(src_mac_addr=outbound_nic.mac_address, dst_mac_addr=dst_mac_address),
|
||||
ip=IPPacket(
|
||||
@@ -189,15 +224,17 @@ class SessionManager:
|
||||
payload=payload,
|
||||
)
|
||||
|
||||
if not session_id:
|
||||
# Manage session for unicast transmission
|
||||
if not (is_broadcast and session_id):
|
||||
session_key = self._get_session_key(frame, inbound_frame=False)
|
||||
session = self.sessions_by_key.get(session_key)
|
||||
if not session:
|
||||
# Create new session
|
||||
# Create a new session if it doesn't exist
|
||||
session = Session.from_session_key(session_key)
|
||||
self.sessions_by_key[session_key] = session
|
||||
self.sessions_by_uuid[session.uuid] = session
|
||||
|
||||
# Send the frame through the NIC
|
||||
return outbound_nic.send_frame(frame)
|
||||
|
||||
def receive_frame(self, frame: Frame):
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from ipaddress import IPv4Address
|
||||
from ipaddress import IPv4Address, IPv4Network
|
||||
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING, Union
|
||||
|
||||
from prettytable import MARKDOWN, PrettyTable
|
||||
@@ -130,20 +130,28 @@ class SoftwareManager:
|
||||
def send_payload_to_session_manager(
|
||||
self,
|
||||
payload: Any,
|
||||
dest_ip_address: Optional[IPv4Address] = None,
|
||||
dest_ip_address: Optional[Union[IPv4Address, IPv4Network]] = None,
|
||||
dest_port: Optional[Port] = None,
|
||||
session_id: Optional[str] = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Send a payload to the SessionManager.
|
||||
Sends a payload to the SessionManager for network transmission.
|
||||
|
||||
This method is responsible for initiating the process of sending network payloads. It supports both
|
||||
unicast and Layer 3 broadcast transmissions. For broadcasts, the destination IP should be specified
|
||||
as an IPv4Network.
|
||||
|
||||
:param payload: The payload to be sent.
|
||||
:param dest_ip_address: The ip address of the payload destination.
|
||||
:param dest_port: The port of the payload destination.
|
||||
:param session_id: The Session ID the payload is to originate from. Optional.
|
||||
:param dest_ip_address: The IP address or network (for broadcasts) of the payload destination.
|
||||
:param dest_port: The destination port for the payload. Optional.
|
||||
:param session_id: The Session ID from which the payload originates. Optional.
|
||||
:return: True if the payload was successfully sent, False otherwise.
|
||||
"""
|
||||
return self.session_manager.receive_payload_from_software_manager(
|
||||
payload=payload, dst_ip_address=dest_ip_address, dst_port=dest_port, session_id=session_id
|
||||
payload=payload,
|
||||
dst_ip_address=dest_ip_address,
|
||||
dst_port=dest_port,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
def receive_payload_from_session_manager(self, payload: Any, port: Port, protocol: IPProtocol, session_id: str):
|
||||
|
||||
@@ -41,6 +41,9 @@ class SysLog:
|
||||
The logger is set to the DEBUG level, and is equipped with a handler that writes to a file and filters out
|
||||
JSON-like messages.
|
||||
"""
|
||||
if not SIM_OUTPUT.save_sys_logs:
|
||||
return
|
||||
|
||||
log_path = self._get_log_path()
|
||||
file_handler = logging.FileHandler(filename=log_path)
|
||||
file_handler.setLevel(logging.DEBUG)
|
||||
@@ -91,7 +94,8 @@ class SysLog:
|
||||
|
||||
:param msg: The message to be logged.
|
||||
"""
|
||||
self.logger.debug(msg)
|
||||
if SIM_OUTPUT.save_sys_logs:
|
||||
self.logger.debug(msg)
|
||||
|
||||
def info(self, msg: str):
|
||||
"""
|
||||
@@ -99,7 +103,8 @@ class SysLog:
|
||||
|
||||
:param msg: The message to be logged.
|
||||
"""
|
||||
self.logger.info(msg)
|
||||
if SIM_OUTPUT.save_sys_logs:
|
||||
self.logger.info(msg)
|
||||
|
||||
def warning(self, msg: str):
|
||||
"""
|
||||
@@ -107,7 +112,8 @@ class SysLog:
|
||||
|
||||
:param msg: The message to be logged.
|
||||
"""
|
||||
self.logger.warning(msg)
|
||||
if SIM_OUTPUT.save_sys_logs:
|
||||
self.logger.warning(msg)
|
||||
|
||||
def error(self, msg: str):
|
||||
"""
|
||||
@@ -115,7 +121,8 @@ class SysLog:
|
||||
|
||||
:param msg: The message to be logged.
|
||||
"""
|
||||
self.logger.error(msg)
|
||||
if SIM_OUTPUT.save_sys_logs:
|
||||
self.logger.error(msg)
|
||||
|
||||
def critical(self, msg: str):
|
||||
"""
|
||||
@@ -123,4 +130,5 @@ class SysLog:
|
||||
|
||||
:param msg: The message to be logged.
|
||||
"""
|
||||
self.logger.critical(msg)
|
||||
if SIM_OUTPUT.save_sys_logs:
|
||||
self.logger.critical(msg)
|
||||
|
||||
@@ -3,6 +3,8 @@ from typing import Any, Dict, List, Literal, Optional, Union
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.simulator.file_system.file_system import File
|
||||
from primaite.simulator.file_system.file_system_item_abc import FileSystemItemHealthStatus
|
||||
from primaite.simulator.file_system.folder import Folder
|
||||
from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.system.core.software_manager import SoftwareManager
|
||||
@@ -22,7 +24,7 @@ class DatabaseService(Service):
|
||||
|
||||
password: Optional[str] = None
|
||||
|
||||
backup_server: IPv4Address = None
|
||||
backup_server_ip: IPv4Address = None
|
||||
"""IP address of the backup server."""
|
||||
|
||||
latest_backup_directory: str = None
|
||||
@@ -36,7 +38,6 @@ class DatabaseService(Service):
|
||||
kwargs["port"] = Port.POSTGRES_SERVER
|
||||
kwargs["protocol"] = IPProtocol.TCP
|
||||
super().__init__(**kwargs)
|
||||
self._db_file: File
|
||||
self._create_db_file()
|
||||
|
||||
def set_original_state(self):
|
||||
@@ -45,8 +46,8 @@ class DatabaseService(Service):
|
||||
super().set_original_state()
|
||||
vals_to_include = {
|
||||
"password",
|
||||
"_connections",
|
||||
"backup_server",
|
||||
"connections",
|
||||
"backup_server_ip",
|
||||
"latest_backup_directory",
|
||||
"latest_backup_file_name",
|
||||
}
|
||||
@@ -64,7 +65,7 @@ class DatabaseService(Service):
|
||||
|
||||
:param: backup_server_ip: The IP address of the backup server
|
||||
"""
|
||||
self.backup_server = backup_server
|
||||
self.backup_server_ip = backup_server
|
||||
|
||||
def backup_database(self) -> bool:
|
||||
"""Create a backup of the database to the configured backup server."""
|
||||
@@ -73,7 +74,7 @@ class DatabaseService(Service):
|
||||
return False
|
||||
|
||||
# check if the backup server was configured
|
||||
if self.backup_server is None:
|
||||
if self.backup_server_ip is None:
|
||||
self.sys_log.error(f"{self.name} - {self.sys_log.hostname}: not configured.")
|
||||
return False
|
||||
|
||||
@@ -81,10 +82,14 @@ class DatabaseService(Service):
|
||||
ftp_client_service: FTPClient = software_manager.software.get("FTPClient")
|
||||
|
||||
# send backup copy of database file to FTP server
|
||||
if not self.db_file:
|
||||
self.sys_log.error("Attempted to backup database file but it doesn't exist.")
|
||||
return False
|
||||
|
||||
response = ftp_client_service.send_file(
|
||||
dest_ip_address=self.backup_server,
|
||||
src_file_name=self._db_file.name,
|
||||
src_folder_name=self.folder.name,
|
||||
dest_ip_address=self.backup_server_ip,
|
||||
src_file_name=self.db_file.name,
|
||||
src_folder_name="database",
|
||||
dest_folder_name=str(self.uuid),
|
||||
dest_file_name="database.db",
|
||||
)
|
||||
@@ -110,7 +115,7 @@ class DatabaseService(Service):
|
||||
src_file_name="database.db",
|
||||
dest_folder_name="downloads",
|
||||
dest_file_name="database.db",
|
||||
dest_ip_address=self.backup_server,
|
||||
dest_ip_address=self.backup_server_ip,
|
||||
)
|
||||
|
||||
if not response:
|
||||
@@ -118,13 +123,10 @@ class DatabaseService(Service):
|
||||
return False
|
||||
|
||||
# replace db file
|
||||
self.file_system.delete_file(folder_name=self.folder.name, file_name="downloads.db")
|
||||
self.file_system.copy_file(
|
||||
src_folder_name="downloads", src_file_name="database.db", dst_folder_name=self.folder.name
|
||||
)
|
||||
self._db_file = self.file_system.get_file(folder_name=self.folder.name, file_name="database.db")
|
||||
self.file_system.delete_file(folder_name="database", file_name="database.db")
|
||||
self.file_system.copy_file(src_folder_name="downloads", src_file_name="database.db", dst_folder_name="database")
|
||||
|
||||
if self._db_file is None:
|
||||
if self.db_file is None:
|
||||
self.sys_log.error("Copying database backup failed.")
|
||||
return False
|
||||
|
||||
@@ -134,12 +136,30 @@ class DatabaseService(Service):
|
||||
|
||||
def _create_db_file(self):
|
||||
"""Creates the Simulation File and sqlite file in the file system."""
|
||||
self._db_file: File = self.file_system.create_file(folder_name="database", file_name="database.db")
|
||||
self.folder = self.file_system.get_folder_by_id(self._db_file.folder_id)
|
||||
self.file_system.create_file(folder_name="database", file_name="database.db")
|
||||
|
||||
@property
|
||||
def db_file(self) -> File:
|
||||
"""Returns the database file."""
|
||||
return self.file_system.get_file(folder_name="database", file_name="database.db")
|
||||
|
||||
@property
|
||||
def folder(self) -> Folder:
|
||||
"""Returns the database folder."""
|
||||
return self.file_system.get_folder_by_id(self.db_file.folder_id)
|
||||
|
||||
def _process_connect(
|
||||
self, connection_id: str, password: Optional[str] = None
|
||||
) -> Dict[str, Union[int, Dict[str, bool]]]:
|
||||
"""Process an incoming connection request.
|
||||
|
||||
:param connection_id: A unique identifier for the connection
|
||||
:type connection_id: str
|
||||
:param password: Supplied password. It must match self.password for connection success, defaults to None
|
||||
:type password: Optional[str], optional
|
||||
:return: Response to connection request containing success info.
|
||||
:rtype: Dict[str, Union[int, Dict[str, bool]]]
|
||||
"""
|
||||
status_code = 500 # Default internal server error
|
||||
if self.operating_state == ServiceOperatingState.RUNNING:
|
||||
status_code = 503 # service unavailable
|
||||
@@ -184,7 +204,7 @@ class DatabaseService(Service):
|
||||
self.sys_log.info(f"{self.name}: Running {query}")
|
||||
|
||||
if query == "SELECT":
|
||||
if self.health_state_actual == SoftwareHealthState.GOOD:
|
||||
if self.db_file.health_status == FileSystemItemHealthStatus.GOOD:
|
||||
return {
|
||||
"status_code": 200,
|
||||
"type": "sql",
|
||||
@@ -195,17 +215,8 @@ class DatabaseService(Service):
|
||||
else:
|
||||
return {"status_code": 404, "data": False}
|
||||
elif query == "DELETE":
|
||||
if self.health_state_actual == SoftwareHealthState.GOOD:
|
||||
self.health_state_actual = SoftwareHealthState.COMPROMISED
|
||||
return {
|
||||
"status_code": 200,
|
||||
"type": "sql",
|
||||
"data": False,
|
||||
"uuid": query_id,
|
||||
"connection_id": connection_id,
|
||||
}
|
||||
else:
|
||||
return {"status_code": 404, "data": False}
|
||||
self.db_file.health_status = FileSystemItemHealthStatus.COMPROMISED
|
||||
return {"status_code": 200, "type": "sql", "data": False, "uuid": query_id, "connection_id": connection_id}
|
||||
else:
|
||||
# Invalid query
|
||||
return {"status_code": 500, "data": False}
|
||||
@@ -265,3 +276,19 @@ class DatabaseService(Service):
|
||||
software_manager.send_payload_to_session_manager(payload=payload, session_id=session_id)
|
||||
|
||||
return payload["status_code"] == 200
|
||||
|
||||
def apply_timestep(self, timestep: int) -> None:
|
||||
"""
|
||||
Apply a single timestep of simulation dynamics to this service.
|
||||
|
||||
Here at the first step, the database backup is created, in addition to normal service update logic.
|
||||
"""
|
||||
if timestep == 1:
|
||||
self.backup_database()
|
||||
return super().apply_timestep(timestep)
|
||||
|
||||
def _update_patch_status(self) -> None:
|
||||
"""Perform a database restore when the patching countdown is finished."""
|
||||
super()._update_patch_status()
|
||||
if self._patching_countdown is None:
|
||||
self.restore_backup()
|
||||
|
||||
@@ -89,6 +89,7 @@ class FTPClient(FTPServiceABC):
|
||||
f"{self.name}: Successfully connected to FTP Server "
|
||||
f"{dest_ip_address} via port {payload.ftp_command_args.value}"
|
||||
)
|
||||
self.add_connection(connection_id="server_connection", session_id=session_id)
|
||||
return True
|
||||
else:
|
||||
if is_reattempt:
|
||||
|
||||
@@ -99,5 +99,5 @@ class FTPServer(FTPServiceABC):
|
||||
if payload.status_code is not None:
|
||||
return False
|
||||
|
||||
self.send(self._process_ftp_command(payload=payload, session_id=session_id), session_id)
|
||||
self._process_ftp_command(payload=payload, session_id=session_id)
|
||||
return True
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import shutil
|
||||
from abc import ABC
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Optional
|
||||
from typing import Dict, Optional
|
||||
|
||||
from primaite.simulator.file_system.file_system import File
|
||||
from primaite.simulator.network.protocols.ftp import FTPCommand, FTPPacket, FTPStatusCode
|
||||
@@ -16,6 +16,10 @@ class FTPServiceABC(Service, ABC):
|
||||
Contains shared methods between both classes.
|
||||
"""
|
||||
|
||||
def describe_state(self) -> Dict:
|
||||
"""Returns a Dict of the FTPService state."""
|
||||
return super().describe_state()
|
||||
|
||||
def _process_ftp_command(self, payload: FTPPacket, session_id: Optional[str] = None, **kwargs) -> FTPPacket:
|
||||
"""
|
||||
Process the command in the FTP Packet.
|
||||
@@ -52,10 +56,12 @@ class FTPServiceABC(Service, ABC):
|
||||
folder_name = payload.ftp_command_args["dest_folder_name"]
|
||||
file_size = payload.ftp_command_args["file_size"]
|
||||
real_file_path = payload.ftp_command_args.get("real_file_path")
|
||||
health_status = payload.ftp_command_args["health_status"]
|
||||
is_real = real_file_path is not None
|
||||
file = self.file_system.create_file(
|
||||
file_name=file_name, folder_name=folder_name, size=file_size, real=is_real
|
||||
)
|
||||
file.health_status = health_status
|
||||
self.sys_log.info(
|
||||
f"{self.name}: Created item in {self.sys_log.hostname}: {payload.ftp_command_args['dest_folder_name']}/"
|
||||
f"{payload.ftp_command_args['dest_file_name']}"
|
||||
@@ -110,6 +116,7 @@ class FTPServiceABC(Service, ABC):
|
||||
"dest_file_name": dest_file_name,
|
||||
"file_size": file.sim_size,
|
||||
"real_file_path": file.sim_path if file.real else None,
|
||||
"health_status": file.health_status,
|
||||
},
|
||||
packet_payload_size=file.sim_size,
|
||||
status_code=FTPStatusCode.OK if is_response else None,
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from abc import abstractmethod
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
@@ -43,9 +44,6 @@ class Service(IOSoftware):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.health_state_visible = SoftwareHealthState.UNUSED
|
||||
self.health_state_actual = SoftwareHealthState.UNUSED
|
||||
|
||||
def _can_perform_action(self) -> bool:
|
||||
"""
|
||||
Checks if the service can perform actions.
|
||||
@@ -98,6 +96,7 @@ class Service(IOSoftware):
|
||||
rm.add_request("enable", RequestType(func=lambda request, context: self.enable()))
|
||||
return rm
|
||||
|
||||
@abstractmethod
|
||||
def describe_state(self) -> Dict:
|
||||
"""
|
||||
Produce a dictionary describing the current state of this object.
|
||||
@@ -118,7 +117,6 @@ class Service(IOSoftware):
|
||||
if self.operating_state in [ServiceOperatingState.RUNNING, ServiceOperatingState.PAUSED]:
|
||||
self.sys_log.info(f"Stopping service {self.name}")
|
||||
self.operating_state = ServiceOperatingState.STOPPED
|
||||
self.health_state_actual = SoftwareHealthState.UNUSED
|
||||
|
||||
def start(self, **kwargs) -> None:
|
||||
"""Start the service."""
|
||||
@@ -129,42 +127,39 @@ class Service(IOSoftware):
|
||||
if self.operating_state == ServiceOperatingState.STOPPED:
|
||||
self.sys_log.info(f"Starting service {self.name}")
|
||||
self.operating_state = ServiceOperatingState.RUNNING
|
||||
self.health_state_actual = SoftwareHealthState.GOOD
|
||||
# set software health state to GOOD if initially set to UNUSED
|
||||
if self.health_state_actual == SoftwareHealthState.UNUSED:
|
||||
self.set_health_state(SoftwareHealthState.GOOD)
|
||||
|
||||
def pause(self) -> None:
|
||||
"""Pause the service."""
|
||||
if self.operating_state == ServiceOperatingState.RUNNING:
|
||||
self.sys_log.info(f"Pausing service {self.name}")
|
||||
self.operating_state = ServiceOperatingState.PAUSED
|
||||
self.health_state_actual = SoftwareHealthState.OVERWHELMED
|
||||
|
||||
def resume(self) -> None:
|
||||
"""Resume paused service."""
|
||||
if self.operating_state == ServiceOperatingState.PAUSED:
|
||||
self.sys_log.info(f"Resuming service {self.name}")
|
||||
self.operating_state = ServiceOperatingState.RUNNING
|
||||
self.health_state_actual = SoftwareHealthState.GOOD
|
||||
|
||||
def restart(self) -> None:
|
||||
"""Restart running service."""
|
||||
if self.operating_state in [ServiceOperatingState.RUNNING, ServiceOperatingState.PAUSED]:
|
||||
self.sys_log.info(f"Pausing service {self.name}")
|
||||
self.operating_state = ServiceOperatingState.RESTARTING
|
||||
self.health_state_actual = SoftwareHealthState.OVERWHELMED
|
||||
self.restart_countdown = self.restart_duration
|
||||
|
||||
def disable(self) -> None:
|
||||
"""Disable the service."""
|
||||
self.sys_log.info(f"Disabling Application {self.name}")
|
||||
self.operating_state = ServiceOperatingState.DISABLED
|
||||
self.health_state_actual = SoftwareHealthState.OVERWHELMED
|
||||
|
||||
def enable(self) -> None:
|
||||
"""Enable the disabled service."""
|
||||
if self.operating_state == ServiceOperatingState.DISABLED:
|
||||
self.sys_log.info(f"Enabling Application {self.name}")
|
||||
self.operating_state = ServiceOperatingState.STOPPED
|
||||
self.health_state_actual = SoftwareHealthState.OVERWHELMED
|
||||
|
||||
def apply_timestep(self, timestep: int) -> None:
|
||||
"""
|
||||
@@ -181,5 +176,4 @@ class Service(IOSoftware):
|
||||
if self.restart_countdown <= 0:
|
||||
_LOGGER.debug(f"Restarting finished for service {self.name}")
|
||||
self.operating_state = ServiceOperatingState.RUNNING
|
||||
self.health_state_actual = SoftwareHealthState.GOOD
|
||||
self.restart_countdown -= 1
|
||||
|
||||
@@ -13,6 +13,7 @@ from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.system.applications.database_client import DatabaseClient
|
||||
from primaite.simulator.system.services.service import Service
|
||||
from primaite.simulator.system.software import SoftwareHealthState
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
@@ -123,7 +124,10 @@ class WebServer(Service):
|
||||
# get all users
|
||||
if db_client.query("SELECT"):
|
||||
# query succeeded
|
||||
self.set_health_state(SoftwareHealthState.GOOD)
|
||||
response.status_code = HttpStatusCode.OK
|
||||
else:
|
||||
self.set_health_state(SoftwareHealthState.COMPROMISED)
|
||||
|
||||
return response
|
||||
except Exception:
|
||||
|
||||
@@ -2,8 +2,8 @@ import copy
|
||||
from abc import abstractmethod
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Any, Dict, Optional
|
||||
from ipaddress import IPv4Address, IPv4Network
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
from primaite.simulator.core import _LOGGER, RequestManager, RequestType, SimComponent
|
||||
from primaite.simulator.file_system.file_system import FileSystem, Folder
|
||||
@@ -38,12 +38,12 @@ class SoftwareHealthState(Enum):
|
||||
"Unused state."
|
||||
GOOD = 1
|
||||
"The software is in a good and healthy condition."
|
||||
COMPROMISED = 2
|
||||
"The software's security has been compromised."
|
||||
OVERWHELMED = 3
|
||||
"he software is overwhelmed and not functioning properly."
|
||||
PATCHING = 4
|
||||
PATCHING = 2
|
||||
"The software is undergoing patching or updates."
|
||||
COMPROMISED = 3
|
||||
"The software's security has been compromised."
|
||||
OVERWHELMED = 4
|
||||
"he software is overwhelmed and not functioning properly."
|
||||
|
||||
|
||||
class SoftwareCriticality(Enum):
|
||||
@@ -71,9 +71,9 @@ class Software(SimComponent):
|
||||
|
||||
name: str
|
||||
"The name of the software."
|
||||
health_state_actual: SoftwareHealthState = SoftwareHealthState.GOOD
|
||||
health_state_actual: SoftwareHealthState = SoftwareHealthState.UNUSED
|
||||
"The actual health state of the software."
|
||||
health_state_visible: SoftwareHealthState = SoftwareHealthState.GOOD
|
||||
health_state_visible: SoftwareHealthState = SoftwareHealthState.UNUSED
|
||||
"The health state of the software visible to the red agent."
|
||||
criticality: SoftwareCriticality = SoftwareCriticality.LOWEST
|
||||
"The criticality level of the software."
|
||||
@@ -195,8 +195,9 @@ class Software(SimComponent):
|
||||
|
||||
def patch(self) -> None:
|
||||
"""Perform a patch on the software."""
|
||||
self._patching_countdown = self.patching_duration
|
||||
self.set_health_state(SoftwareHealthState.PATCHING)
|
||||
if self.health_state_actual in (SoftwareHealthState.COMPROMISED, SoftwareHealthState.GOOD):
|
||||
self._patching_countdown = self.patching_duration
|
||||
self.set_health_state(SoftwareHealthState.PATCHING)
|
||||
|
||||
def _update_patch_status(self) -> None:
|
||||
"""Update the patch status of the software."""
|
||||
@@ -282,7 +283,7 @@ class IOSoftware(Software):
|
||||
|
||||
Returns true if the software can perform actions.
|
||||
"""
|
||||
if self.software_manager and self.software_manager.node.operating_state is NodeOperatingState.OFF:
|
||||
if self.software_manager and self.software_manager.node.operating_state != NodeOperatingState.ON:
|
||||
_LOGGER.debug(f"{self.name} Error: {self.software_manager.node.hostname} is not online.")
|
||||
return False
|
||||
return True
|
||||
@@ -303,13 +304,13 @@ class IOSoftware(Software):
|
||||
"""
|
||||
# if over or at capacity, set to overwhelmed
|
||||
if len(self._connections) >= self.max_sessions:
|
||||
self.health_state_actual = SoftwareHealthState.OVERWHELMED
|
||||
self.set_health_state(SoftwareHealthState.OVERWHELMED)
|
||||
self.sys_log.error(f"{self.name}: Connect request for {connection_id=} declined. Service is at capacity.")
|
||||
return False
|
||||
else:
|
||||
# if service was previously overwhelmed, set to good because there is enough space for connections
|
||||
if self.health_state_actual == SoftwareHealthState.OVERWHELMED:
|
||||
self.health_state_actual = SoftwareHealthState.GOOD
|
||||
self.set_health_state(SoftwareHealthState.GOOD)
|
||||
|
||||
# check that connection already doesn't exist
|
||||
if not self._connections.get(connection_id):
|
||||
@@ -350,19 +351,22 @@ class IOSoftware(Software):
|
||||
self,
|
||||
payload: Any,
|
||||
session_id: Optional[str] = None,
|
||||
dest_ip_address: Optional[IPv4Address] = None,
|
||||
dest_ip_address: Optional[Union[IPv4Address, IPv4Network]] = None,
|
||||
dest_port: Optional[Port] = None,
|
||||
**kwargs,
|
||||
) -> bool:
|
||||
"""
|
||||
Sends a payload to the SessionManager.
|
||||
Sends a payload to the SessionManager for network transmission.
|
||||
|
||||
This method is responsible for initiating the process of sending network payloads. It supports both
|
||||
unicast and Layer 3 broadcast transmissions. For broadcasts, the destination IP should be specified
|
||||
as an IPv4Network. It delegates the actual sending process to the SoftwareManager.
|
||||
|
||||
:param payload: The payload to be sent.
|
||||
:param dest_ip_address: The ip address of the payload destination.
|
||||
:param dest_port: The port of the payload destination.
|
||||
:param session_id: The Session ID the payload is to originate from. Optional.
|
||||
|
||||
:return: True if successful, False otherwise.
|
||||
:param dest_ip_address: The IP address or network (for broadcasts) of the payload destination.
|
||||
:param dest_port: The destination port for the payload. Optional.
|
||||
:param session_id: The Session ID from which the payload originates. Optional.
|
||||
:return: True if the payload was successfully sent, False otherwise.
|
||||
"""
|
||||
if not self._can_perform_action():
|
||||
return False
|
||||
|
||||
@@ -19,7 +19,7 @@ game:
|
||||
- UDP
|
||||
|
||||
agents:
|
||||
- ref: client_1_green_user
|
||||
- ref: client_2_green_user
|
||||
team: GREEN
|
||||
type: GreenWebBrowsingAgent
|
||||
observation_space:
|
||||
@@ -489,6 +489,23 @@ agents:
|
||||
max_services_per_node: 2
|
||||
max_nics_per_node: 8
|
||||
max_acl_rules: 10
|
||||
ip_address_order:
|
||||
- node_ref: domain_controller
|
||||
nic_num: 1
|
||||
- node_ref: web_server
|
||||
nic_num: 1
|
||||
- node_ref: database_server
|
||||
nic_num: 1
|
||||
- node_ref: backup_server
|
||||
nic_num: 1
|
||||
- node_ref: security_suite
|
||||
nic_num: 1
|
||||
- node_ref: client_1
|
||||
nic_num: 1
|
||||
- node_ref: client_2
|
||||
nic_num: 1
|
||||
- node_ref: security_suite
|
||||
nic_num: 2
|
||||
|
||||
reward_function:
|
||||
reward_components:
|
||||
|
||||
@@ -23,7 +23,7 @@ game:
|
||||
- UDP
|
||||
|
||||
agents:
|
||||
- ref: client_1_green_user
|
||||
- ref: client_2_green_user
|
||||
team: GREEN
|
||||
type: GreenWebBrowsingAgent
|
||||
observation_space:
|
||||
@@ -493,6 +493,23 @@ agents:
|
||||
max_services_per_node: 2
|
||||
max_nics_per_node: 8
|
||||
max_acl_rules: 10
|
||||
ip_address_order:
|
||||
- node_ref: domain_controller
|
||||
nic_num: 1
|
||||
- node_ref: web_server
|
||||
nic_num: 1
|
||||
- node_ref: database_server
|
||||
nic_num: 1
|
||||
- node_ref: backup_server
|
||||
nic_num: 1
|
||||
- node_ref: security_suite
|
||||
nic_num: 1
|
||||
- node_ref: client_1
|
||||
nic_num: 1
|
||||
- node_ref: client_2
|
||||
nic_num: 1
|
||||
- node_ref: security_suite
|
||||
nic_num: 2
|
||||
|
||||
reward_function:
|
||||
reward_components:
|
||||
|
||||
@@ -29,7 +29,7 @@ game:
|
||||
- UDP
|
||||
|
||||
agents:
|
||||
- ref: client_1_green_user
|
||||
- ref: client_2_green_user
|
||||
team: GREEN
|
||||
type: GreenWebBrowsingAgent
|
||||
observation_space:
|
||||
@@ -500,6 +500,23 @@ agents:
|
||||
max_services_per_node: 2
|
||||
max_nics_per_node: 8
|
||||
max_acl_rules: 10
|
||||
ip_address_order:
|
||||
- node_ref: domain_controller
|
||||
nic_num: 1
|
||||
- node_ref: web_server
|
||||
nic_num: 1
|
||||
- node_ref: database_server
|
||||
nic_num: 1
|
||||
- node_ref: backup_server
|
||||
nic_num: 1
|
||||
- node_ref: security_suite
|
||||
nic_num: 1
|
||||
- node_ref: client_1
|
||||
nic_num: 1
|
||||
- node_ref: client_2
|
||||
nic_num: 1
|
||||
- node_ref: security_suite
|
||||
nic_num: 2
|
||||
|
||||
reward_function:
|
||||
reward_components:
|
||||
@@ -929,6 +946,23 @@ agents:
|
||||
max_services_per_node: 2
|
||||
max_nics_per_node: 8
|
||||
max_acl_rules: 10
|
||||
ip_address_order:
|
||||
- node_ref: domain_controller
|
||||
nic_num: 1
|
||||
- node_ref: web_server
|
||||
nic_num: 1
|
||||
- node_ref: database_server
|
||||
nic_num: 1
|
||||
- node_ref: backup_server
|
||||
nic_num: 1
|
||||
- node_ref: security_suite
|
||||
nic_num: 1
|
||||
- node_ref: client_1
|
||||
nic_num: 1
|
||||
- node_ref: client_2
|
||||
nic_num: 1
|
||||
- node_ref: security_suite
|
||||
nic_num: 2
|
||||
|
||||
reward_function:
|
||||
reward_components:
|
||||
|
||||
@@ -27,7 +27,7 @@ game:
|
||||
- UDP
|
||||
|
||||
agents:
|
||||
- ref: client_1_green_user
|
||||
- ref: client_2_green_user
|
||||
team: GREEN
|
||||
type: GreenWebBrowsingAgent
|
||||
observation_space:
|
||||
@@ -500,6 +500,23 @@ agents:
|
||||
max_services_per_node: 2
|
||||
max_nics_per_node: 8
|
||||
max_acl_rules: 10
|
||||
ip_address_order:
|
||||
- node_ref: domain_controller
|
||||
nic_num: 1
|
||||
- node_ref: web_server
|
||||
nic_num: 1
|
||||
- node_ref: database_server
|
||||
nic_num: 1
|
||||
- node_ref: backup_server
|
||||
nic_num: 1
|
||||
- node_ref: security_suite
|
||||
nic_num: 1
|
||||
- node_ref: client_1
|
||||
nic_num: 1
|
||||
- node_ref: client_2
|
||||
nic_num: 1
|
||||
- node_ref: security_suite
|
||||
nic_num: 2
|
||||
|
||||
reward_function:
|
||||
reward_components:
|
||||
|
||||
@@ -23,7 +23,7 @@ game:
|
||||
- UDP
|
||||
|
||||
agents:
|
||||
- ref: client_1_green_user
|
||||
- ref: client_2_green_user
|
||||
team: GREEN
|
||||
type: GreenWebBrowsingAgent
|
||||
observation_space:
|
||||
@@ -501,6 +501,23 @@ agents:
|
||||
max_services_per_node: 2
|
||||
max_nics_per_node: 8
|
||||
max_acl_rules: 10
|
||||
ip_address_order:
|
||||
- node_ref: domain_controller
|
||||
nic_num: 1
|
||||
- node_ref: web_server
|
||||
nic_num: 1
|
||||
- node_ref: database_server
|
||||
nic_num: 1
|
||||
- node_ref: backup_server
|
||||
nic_num: 1
|
||||
- node_ref: security_suite
|
||||
nic_num: 1
|
||||
- node_ref: client_1
|
||||
nic_num: 1
|
||||
- node_ref: client_2
|
||||
nic_num: 1
|
||||
- node_ref: security_suite
|
||||
nic_num: 2
|
||||
|
||||
reward_function:
|
||||
reward_components:
|
||||
|
||||
@@ -40,6 +40,9 @@ from primaite.simulator.network.hardware.base import Link, Node
|
||||
class TestService(Service):
|
||||
"""Test Service class"""
|
||||
|
||||
def describe_state(self) -> Dict:
|
||||
return super().describe_state()
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
kwargs["name"] = "TestService"
|
||||
kwargs["port"] = Port.HTTP
|
||||
@@ -60,7 +63,7 @@ class TestApplication(Application):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def describe_state(self) -> Dict:
|
||||
pass
|
||||
return super().describe_state()
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
@@ -167,7 +170,7 @@ def example_network() -> Network:
|
||||
-------------- --------------
|
||||
| client_1 |----- ----| server_1 |
|
||||
-------------- | -------------- -------------- -------------- | --------------
|
||||
------| switch_1 |------| router_1 |------| switch_2 |------
|
||||
------| switch_2 |------| router_1 |------| switch_1 |------
|
||||
-------------- | -------------- -------------- -------------- | --------------
|
||||
| client_2 |---- ----| server_2 |
|
||||
-------------- --------------
|
||||
|
||||
@@ -22,7 +22,7 @@ def test_data_manipulation(uc2_network):
|
||||
assert db_client.query("SELECT")
|
||||
|
||||
# Now we run the DataManipulationBot
|
||||
db_manipulation_bot.run()
|
||||
db_manipulation_bot.attack()
|
||||
|
||||
# Now check that the DB client on the web_server cannot query the users table on the database
|
||||
assert not db_client.query("SELECT")
|
||||
|
||||
180
tests/integration_tests/network/test_broadcast.py
Normal file
180
tests/integration_tests/network/test_broadcast.py
Normal file
@@ -0,0 +1,180 @@
|
||||
from ipaddress import IPv4Address, IPv4Network
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
import pytest
|
||||
|
||||
from primaite.simulator.network.container import Network
|
||||
from primaite.simulator.network.hardware.nodes.computer import Computer
|
||||
from primaite.simulator.network.hardware.nodes.server import Server
|
||||
from primaite.simulator.network.hardware.nodes.switch import Switch
|
||||
from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.system.applications.application import Application
|
||||
from primaite.simulator.system.services.service import Service
|
||||
|
||||
|
||||
class BroadcastService(Service):
|
||||
"""A service for sending broadcast and unicast messages over a network."""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
# Set default service properties for broadcasting
|
||||
kwargs["name"] = "BroadcastService"
|
||||
kwargs["port"] = Port.HTTP
|
||||
kwargs["protocol"] = IPProtocol.TCP
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def describe_state(self) -> Dict:
|
||||
# Implement state description for the service
|
||||
pass
|
||||
|
||||
def unicast(self, ip_address: IPv4Address):
|
||||
# Send a unicast payload to a specific IP address
|
||||
super().send(
|
||||
payload="unicast",
|
||||
dest_ip_address=ip_address,
|
||||
dest_port=Port.HTTP,
|
||||
)
|
||||
|
||||
def broadcast(self, ip_network: IPv4Network):
|
||||
# Send a broadcast payload to an entire IP network
|
||||
super().send(
|
||||
payload="broadcast",
|
||||
dest_ip_address=ip_network,
|
||||
dest_port=Port.HTTP,
|
||||
)
|
||||
|
||||
|
||||
class BroadcastClient(Application):
|
||||
"""A client application to receive broadcast and unicast messages."""
|
||||
|
||||
payloads_received: List = []
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
# Set default client properties
|
||||
kwargs["name"] = "BroadcastClient"
|
||||
kwargs["port"] = Port.HTTP
|
||||
kwargs["protocol"] = IPProtocol.TCP
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def describe_state(self) -> Dict:
|
||||
# Implement state description for the application
|
||||
pass
|
||||
|
||||
def receive(self, payload: Any, session_id: str, **kwargs) -> bool:
|
||||
# Append received payloads to the list and print a message
|
||||
self.payloads_received.append(payload)
|
||||
print(f"Payload: {payload} received on node {self.sys_log.hostname}")
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def broadcast_network() -> Network:
|
||||
network = Network()
|
||||
|
||||
client_1 = Computer(
|
||||
hostname="client_1",
|
||||
ip_address="192.168.1.2",
|
||||
subnet_mask="255.255.255.0",
|
||||
default_gateway="192.168.1.1",
|
||||
start_up_duration=0,
|
||||
)
|
||||
client_1.power_on()
|
||||
client_1.software_manager.install(BroadcastClient)
|
||||
application_1 = client_1.software_manager.software["BroadcastClient"]
|
||||
application_1.run()
|
||||
|
||||
client_2 = Computer(
|
||||
hostname="client_2",
|
||||
ip_address="192.168.1.3",
|
||||
subnet_mask="255.255.255.0",
|
||||
default_gateway="192.168.1.1",
|
||||
start_up_duration=0,
|
||||
)
|
||||
client_2.power_on()
|
||||
client_2.software_manager.install(BroadcastClient)
|
||||
application_2 = client_2.software_manager.software["BroadcastClient"]
|
||||
application_2.run()
|
||||
|
||||
server_1 = Server(
|
||||
hostname="server_1",
|
||||
ip_address="192.168.1.1",
|
||||
subnet_mask="255.255.255.0",
|
||||
default_gateway="192.168.1.1",
|
||||
start_up_duration=0,
|
||||
)
|
||||
server_1.power_on()
|
||||
|
||||
server_1.software_manager.install(BroadcastService)
|
||||
service: BroadcastService = server_1.software_manager.software["BroadcastService"]
|
||||
service.start()
|
||||
|
||||
switch_1 = Switch(hostname="switch_1", num_ports=6, start_up_duration=0)
|
||||
switch_1.power_on()
|
||||
|
||||
network.connect(endpoint_a=client_1.ethernet_port[1], endpoint_b=switch_1.switch_ports[1])
|
||||
network.connect(endpoint_a=client_2.ethernet_port[1], endpoint_b=switch_1.switch_ports[2])
|
||||
network.connect(endpoint_a=server_1.ethernet_port[1], endpoint_b=switch_1.switch_ports[3])
|
||||
|
||||
return network
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def broadcast_service_and_clients(broadcast_network) -> Tuple[BroadcastService, BroadcastClient, BroadcastClient]:
|
||||
client_1: BroadcastClient = broadcast_network.get_node_by_hostname("client_1").software_manager.software[
|
||||
"BroadcastClient"
|
||||
]
|
||||
client_2: BroadcastClient = broadcast_network.get_node_by_hostname("client_2").software_manager.software[
|
||||
"BroadcastClient"
|
||||
]
|
||||
service: BroadcastService = broadcast_network.get_node_by_hostname("server_1").software_manager.software[
|
||||
"BroadcastService"
|
||||
]
|
||||
|
||||
return service, client_1, client_2
|
||||
|
||||
|
||||
def test_broadcast_correct_subnet(broadcast_service_and_clients):
|
||||
service, client_1, client_2 = broadcast_service_and_clients
|
||||
|
||||
assert not client_1.payloads_received
|
||||
assert not client_2.payloads_received
|
||||
|
||||
service.broadcast(IPv4Network("192.168.1.0/24"))
|
||||
|
||||
assert client_1.payloads_received == ["broadcast"]
|
||||
assert client_2.payloads_received == ["broadcast"]
|
||||
|
||||
|
||||
def test_broadcast_incorrect_subnet(broadcast_service_and_clients):
|
||||
service, client_1, client_2 = broadcast_service_and_clients
|
||||
|
||||
assert not client_1.payloads_received
|
||||
assert not client_2.payloads_received
|
||||
|
||||
service.broadcast(IPv4Network("192.168.2.0/24"))
|
||||
|
||||
assert not client_1.payloads_received
|
||||
assert not client_2.payloads_received
|
||||
|
||||
|
||||
def test_unicast_correct_address(broadcast_service_and_clients):
|
||||
service, client_1, client_2 = broadcast_service_and_clients
|
||||
|
||||
assert not client_1.payloads_received
|
||||
assert not client_2.payloads_received
|
||||
|
||||
service.unicast(IPv4Address("192.168.1.2"))
|
||||
|
||||
assert client_1.payloads_received == ["unicast"]
|
||||
assert not client_2.payloads_received
|
||||
|
||||
|
||||
def test_unicast_incorrect_address(broadcast_service_and_clients):
|
||||
service, client_1, client_2 = broadcast_service_and_clients
|
||||
|
||||
assert not client_1.payloads_received
|
||||
assert not client_2.payloads_received
|
||||
|
||||
service.unicast(IPv4Address("192.168.2.2"))
|
||||
|
||||
assert not client_1.payloads_received
|
||||
assert not client_2.payloads_received
|
||||
@@ -1,11 +1,16 @@
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Tuple
|
||||
|
||||
import pytest
|
||||
|
||||
from primaite.simulator.network.container import Network
|
||||
from primaite.simulator.network.hardware.base import Link, NIC, Node, NodeOperatingState
|
||||
from primaite.simulator.network.hardware.nodes.computer import Computer
|
||||
from primaite.simulator.network.hardware.nodes.router import ACLAction, Router
|
||||
from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.system.services.ntp.ntp_client import NTPClient
|
||||
from primaite.simulator.system.services.ntp.ntp_server import NTPServer
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
@@ -34,6 +39,69 @@ def pc_a_pc_b_router_1() -> Tuple[Node, Node, Router]:
|
||||
return pc_a, pc_b, router_1
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def multi_hop_network() -> Network:
|
||||
network = Network()
|
||||
|
||||
# Configure PC A
|
||||
pc_a = Computer(
|
||||
hostname="pc_a",
|
||||
ip_address="192.168.0.2",
|
||||
subnet_mask="255.255.255.0",
|
||||
default_gateway="192.168.0.1",
|
||||
start_up_duration=0,
|
||||
)
|
||||
pc_a.power_on()
|
||||
network.add_node(pc_a)
|
||||
|
||||
# Configure Router 1
|
||||
router_1 = Router(hostname="router_1", start_up_duration=0)
|
||||
router_1.power_on()
|
||||
network.add_node(router_1)
|
||||
|
||||
# Configure the connection between PC A and Router 1 port 2
|
||||
router_1.configure_port(2, "192.168.0.1", "255.255.255.0")
|
||||
network.connect(pc_a.ethernet_port[1], router_1.ethernet_ports[2])
|
||||
router_1.enable_port(2)
|
||||
|
||||
# Configure Router 1 ACLs
|
||||
router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.ARP, dst_port=Port.ARP, position=22)
|
||||
router_1.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol.ICMP, position=23)
|
||||
|
||||
# Configure PC B
|
||||
pc_b = Computer(
|
||||
hostname="pc_b",
|
||||
ip_address="192.168.2.2",
|
||||
subnet_mask="255.255.255.0",
|
||||
default_gateway="192.168.2.1",
|
||||
start_up_duration=0,
|
||||
)
|
||||
pc_b.power_on()
|
||||
network.add_node(pc_b)
|
||||
|
||||
# Configure Router 2
|
||||
router_2 = Router(hostname="router_2", start_up_duration=0)
|
||||
router_2.power_on()
|
||||
network.add_node(router_2)
|
||||
|
||||
# Configure the connection between PC B and Router 2 port 2
|
||||
router_2.configure_port(2, "192.168.2.1", "255.255.255.0")
|
||||
network.connect(pc_b.ethernet_port[1], router_2.ethernet_ports[2])
|
||||
router_2.enable_port(2)
|
||||
|
||||
# Configure Router 2 ACLs
|
||||
router_2.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.ARP, dst_port=Port.ARP, position=22)
|
||||
router_2.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol.ICMP, position=23)
|
||||
|
||||
# Configure the connection between Router 1 port 1 and Router 2 port 1
|
||||
router_2.configure_port(1, "192.168.1.2", "255.255.255.252")
|
||||
router_1.configure_port(1, "192.168.1.1", "255.255.255.252")
|
||||
network.connect(router_1.ethernet_ports[1], router_2.ethernet_ports[1])
|
||||
router_1.enable_port(1)
|
||||
router_2.enable_port(1)
|
||||
return network
|
||||
|
||||
|
||||
def test_ping_default_gateway(pc_a_pc_b_router_1):
|
||||
pc_a, pc_b, router_1 = pc_a_pc_b_router_1
|
||||
|
||||
@@ -50,3 +118,68 @@ def test_host_on_other_subnet(pc_a_pc_b_router_1):
|
||||
pc_a, pc_b, router_1 = pc_a_pc_b_router_1
|
||||
|
||||
assert pc_a.ping("192.168.1.10")
|
||||
|
||||
|
||||
def test_no_route_no_ping(multi_hop_network):
|
||||
pc_a = multi_hop_network.get_node_by_hostname("pc_a")
|
||||
pc_b = multi_hop_network.get_node_by_hostname("pc_b")
|
||||
|
||||
assert not pc_a.ping(pc_b.ethernet_port[1].ip_address)
|
||||
|
||||
|
||||
def test_with_routes_can_ping(multi_hop_network):
|
||||
pc_a = multi_hop_network.get_node_by_hostname("pc_a")
|
||||
pc_b = multi_hop_network.get_node_by_hostname("pc_b")
|
||||
|
||||
router_1: Router = multi_hop_network.get_node_by_hostname("router_1") # noqa
|
||||
router_2: Router = multi_hop_network.get_node_by_hostname("router_2") # noqa
|
||||
|
||||
# Configure Route from Router 1 to PC B subnet
|
||||
router_1.route_table.add_route(
|
||||
address="192.168.2.0", subnet_mask="255.255.255.0", next_hop_ip_address="192.168.1.2"
|
||||
)
|
||||
|
||||
# Configure Route from Router 2 to PC A subnet
|
||||
router_2.route_table.add_route(
|
||||
address="192.168.0.2", subnet_mask="255.255.255.0", next_hop_ip_address="192.168.1.1"
|
||||
)
|
||||
|
||||
assert pc_a.ping(pc_b.ethernet_port[1].ip_address)
|
||||
|
||||
|
||||
def test_routing_services(multi_hop_network):
|
||||
pc_a = multi_hop_network.get_node_by_hostname("pc_a")
|
||||
|
||||
pc_b = multi_hop_network.get_node_by_hostname("pc_b")
|
||||
|
||||
pc_a.software_manager.install(NTPClient)
|
||||
ntp_client = pc_a.software_manager.software["NTPClient"]
|
||||
ntp_client.start()
|
||||
|
||||
pc_b.software_manager.install(NTPServer)
|
||||
pc_b.software_manager.software["NTPServer"].start()
|
||||
|
||||
ntp_client.configure(ntp_server_ip_address=pc_b.ethernet_port[1].ip_address)
|
||||
|
||||
router_1: Router = multi_hop_network.get_node_by_hostname("router_1") # noqa
|
||||
router_2: Router = multi_hop_network.get_node_by_hostname("router_2") # noqa
|
||||
|
||||
router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.NTP, dst_port=Port.NTP, position=21)
|
||||
router_2.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.NTP, dst_port=Port.NTP, position=21)
|
||||
|
||||
assert ntp_client.time is None
|
||||
ntp_client.request_time()
|
||||
assert ntp_client.time is None
|
||||
|
||||
# Configure Route from Router 1 to PC B subnet
|
||||
router_1.route_table.add_route(
|
||||
address="192.168.2.0", subnet_mask="255.255.255.0", next_hop_ip_address="192.168.1.2"
|
||||
)
|
||||
|
||||
# Configure Route from Router 2 to PC A subnet
|
||||
router_2.route_table.add_route(
|
||||
address="192.168.0.2", subnet_mask="255.255.255.0", next_hop_ip_address="192.168.1.1"
|
||||
)
|
||||
|
||||
ntp_client.request_time()
|
||||
assert ntp_client.time is not None
|
||||
|
||||
@@ -90,7 +90,7 @@ def test_repeating_dos_attack(dos_bot_and_db_server):
|
||||
assert db_server_service.health_state_actual is SoftwareHealthState.OVERWHELMED
|
||||
|
||||
db_server_service.clear_connections()
|
||||
db_server_service.health_state_actual = SoftwareHealthState.GOOD
|
||||
db_server_service.set_health_state(SoftwareHealthState.GOOD)
|
||||
assert len(db_server_service.connections) == 0
|
||||
|
||||
computer.apply_timestep(timestep=1)
|
||||
@@ -121,7 +121,7 @@ def test_non_repeating_dos_attack(dos_bot_and_db_server):
|
||||
assert db_server_service.health_state_actual is SoftwareHealthState.OVERWHELMED
|
||||
|
||||
db_server_service.clear_connections()
|
||||
db_server_service.health_state_actual = SoftwareHealthState.GOOD
|
||||
db_server_service.set_health_state(SoftwareHealthState.GOOD)
|
||||
assert len(db_server_service.connections) == 0
|
||||
|
||||
computer.apply_timestep(timestep=1)
|
||||
|
||||
@@ -24,8 +24,8 @@ def populated_node(application_class) -> Tuple[Application, Computer]:
|
||||
return app, computer
|
||||
|
||||
|
||||
def test_service_on_offline_node(application_class):
|
||||
"""Test to check that the service cannot be interacted with when node it is on is off."""
|
||||
def test_application_on_offline_node(application_class):
|
||||
"""Test to check that the application cannot be interacted with when node it is on is off."""
|
||||
computer: Computer = Computer(
|
||||
hostname="test_computer",
|
||||
ip_address="192.168.1.2",
|
||||
@@ -49,8 +49,8 @@ def test_service_on_offline_node(application_class):
|
||||
assert app.operating_state is ApplicationOperatingState.CLOSED
|
||||
|
||||
|
||||
def test_server_turns_off_service(populated_node):
|
||||
"""Check that the service is turned off when the server is turned off"""
|
||||
def test_server_turns_off_application(populated_node):
|
||||
"""Check that the application is turned off when the server is turned off"""
|
||||
app, computer = populated_node
|
||||
|
||||
assert computer.operating_state is NodeOperatingState.ON
|
||||
@@ -65,8 +65,8 @@ def test_server_turns_off_service(populated_node):
|
||||
assert app.operating_state is ApplicationOperatingState.CLOSED
|
||||
|
||||
|
||||
def test_service_cannot_be_turned_on_when_server_is_off(populated_node):
|
||||
"""Check that the service cannot be started when the server is off."""
|
||||
def test_application_cannot_be_turned_on_when_computer_is_off(populated_node):
|
||||
"""Check that the application cannot be started when the computer is off."""
|
||||
app, computer = populated_node
|
||||
|
||||
assert computer.operating_state is NodeOperatingState.ON
|
||||
@@ -86,8 +86,8 @@ def test_service_cannot_be_turned_on_when_server_is_off(populated_node):
|
||||
assert app.operating_state is ApplicationOperatingState.CLOSED
|
||||
|
||||
|
||||
def test_server_turns_on_service(populated_node):
|
||||
"""Check that turning on the server turns on service."""
|
||||
def test_computer_runs_applications(populated_node):
|
||||
"""Check that turning on the computer will turn on applications."""
|
||||
app, computer = populated_node
|
||||
|
||||
assert computer.operating_state is NodeOperatingState.ON
|
||||
@@ -109,13 +109,14 @@ def test_server_turns_on_service(populated_node):
|
||||
assert computer.operating_state is NodeOperatingState.ON
|
||||
assert app.operating_state is ApplicationOperatingState.RUNNING
|
||||
|
||||
computer.start_up_duration = 0
|
||||
computer.shut_down_duration = 0
|
||||
|
||||
computer.power_off()
|
||||
for i in range(computer.start_up_duration + 1):
|
||||
computer.apply_timestep(timestep=i)
|
||||
assert computer.operating_state is NodeOperatingState.OFF
|
||||
assert app.operating_state is ApplicationOperatingState.CLOSED
|
||||
|
||||
computer.power_on()
|
||||
for i in range(computer.start_up_duration + 1):
|
||||
computer.apply_timestep(timestep=i)
|
||||
assert computer.operating_state is NodeOperatingState.ON
|
||||
assert app.operating_state is ApplicationOperatingState.RUNNING
|
||||
|
||||
@@ -117,13 +117,14 @@ def test_server_turns_on_service(populated_node):
|
||||
assert server.operating_state is NodeOperatingState.ON
|
||||
assert service.operating_state is ServiceOperatingState.RUNNING
|
||||
|
||||
server.start_up_duration = 0
|
||||
server.shut_down_duration = 0
|
||||
|
||||
server.power_off()
|
||||
for i in range(server.start_up_duration + 1):
|
||||
server.apply_timestep(timestep=i)
|
||||
assert server.operating_state is NodeOperatingState.OFF
|
||||
assert service.operating_state is ServiceOperatingState.STOPPED
|
||||
|
||||
server.power_on()
|
||||
for i in range(server.start_up_duration + 1):
|
||||
server.apply_timestep(timestep=i)
|
||||
assert server.operating_state is NodeOperatingState.ON
|
||||
assert service.operating_state is ServiceOperatingState.RUNNING
|
||||
|
||||
@@ -53,12 +53,12 @@ def test_node_os_scan(node, service, application):
|
||||
# TODO implement processes
|
||||
|
||||
# add services to node
|
||||
service.health_state_actual = SoftwareHealthState.COMPROMISED
|
||||
service.set_health_state(SoftwareHealthState.COMPROMISED)
|
||||
node.install_service(service=service)
|
||||
assert service.health_state_visible == SoftwareHealthState.UNUSED
|
||||
|
||||
# add application to node
|
||||
application.health_state_actual = SoftwareHealthState.COMPROMISED
|
||||
application.set_health_state(SoftwareHealthState.COMPROMISED)
|
||||
node.install_application(application=application)
|
||||
assert application.health_state_visible == SoftwareHealthState.UNUSED
|
||||
|
||||
@@ -101,7 +101,7 @@ def test_node_red_scan(node, service, application):
|
||||
assert service.revealed_to_red is False
|
||||
|
||||
# add application to node
|
||||
application.health_state_actual = SoftwareHealthState.COMPROMISED
|
||||
application.set_health_state(SoftwareHealthState.COMPROMISED)
|
||||
node.install_application(application=application)
|
||||
assert application.revealed_to_red is False
|
||||
|
||||
|
||||
@@ -10,6 +10,22 @@ from primaite.simulator.system.applications.database_client import DatabaseClien
|
||||
from primaite.simulator.system.services.database.database_service import DatabaseService
|
||||
|
||||
|
||||
def filter_keys_nested_item(data, keys):
|
||||
stack = [(data, {})]
|
||||
while stack:
|
||||
current, filtered = stack.pop()
|
||||
if isinstance(current, dict):
|
||||
for k, v in current.items():
|
||||
if k in keys:
|
||||
filtered[k] = filter_keys_nested_item(v, keys)
|
||||
elif isinstance(v, (dict, list)):
|
||||
stack.append((v, {}))
|
||||
elif isinstance(current, list):
|
||||
for item in current:
|
||||
stack.append((item, {}))
|
||||
return filtered
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def network(example_network) -> Network:
|
||||
assert len(example_network.routers) is 1
|
||||
@@ -59,10 +75,10 @@ def test_reset_network(network):
|
||||
|
||||
assert client_1.operating_state is NodeOperatingState.ON
|
||||
assert server_1.operating_state is NodeOperatingState.ON
|
||||
|
||||
assert json.dumps(network.describe_state(), sort_keys=True, indent=2) == json.dumps(
|
||||
state_before, sort_keys=True, indent=2
|
||||
)
|
||||
# don't worry if UUIDs change
|
||||
a = filter_keys_nested_item(json.dumps(network.describe_state(), sort_keys=True, indent=2), ["uuid"])
|
||||
b = filter_keys_nested_item(json.dumps(state_before, sort_keys=True, indent=2), ["uuid"])
|
||||
assert a == b
|
||||
|
||||
|
||||
def test_creating_container():
|
||||
|
||||
@@ -0,0 +1,50 @@
|
||||
from primaite.simulator.system.applications.application import ApplicationOperatingState
|
||||
from primaite.simulator.system.software import SoftwareHealthState
|
||||
|
||||
|
||||
def test_scan(application):
|
||||
assert application.operating_state == ApplicationOperatingState.CLOSED
|
||||
assert application.health_state_visible == SoftwareHealthState.UNUSED
|
||||
|
||||
application.run()
|
||||
assert application.operating_state == ApplicationOperatingState.RUNNING
|
||||
assert application.health_state_visible == SoftwareHealthState.UNUSED
|
||||
|
||||
application.scan()
|
||||
assert application.operating_state == ApplicationOperatingState.RUNNING
|
||||
assert application.health_state_visible == SoftwareHealthState.GOOD
|
||||
|
||||
|
||||
def test_run_application(application):
|
||||
assert application.operating_state == ApplicationOperatingState.CLOSED
|
||||
assert application.health_state_actual == SoftwareHealthState.UNUSED
|
||||
|
||||
application.run()
|
||||
assert application.operating_state == ApplicationOperatingState.RUNNING
|
||||
assert application.health_state_actual == SoftwareHealthState.GOOD
|
||||
|
||||
|
||||
def test_close_application(application):
|
||||
application.run()
|
||||
assert application.operating_state == ApplicationOperatingState.RUNNING
|
||||
assert application.health_state_actual == SoftwareHealthState.GOOD
|
||||
|
||||
application.close()
|
||||
assert application.operating_state == ApplicationOperatingState.CLOSED
|
||||
assert application.health_state_actual == SoftwareHealthState.GOOD
|
||||
|
||||
|
||||
def test_application_describe_states(application):
|
||||
assert application.operating_state == ApplicationOperatingState.CLOSED
|
||||
assert application.health_state_actual == SoftwareHealthState.UNUSED
|
||||
|
||||
assert SoftwareHealthState.UNUSED.value == application.describe_state().get("health_state_actual")
|
||||
|
||||
application.run()
|
||||
assert SoftwareHealthState.GOOD.value == application.describe_state().get("health_state_actual")
|
||||
|
||||
application.set_health_state(SoftwareHealthState.COMPROMISED)
|
||||
assert SoftwareHealthState.COMPROMISED.value == application.describe_state().get("health_state_actual")
|
||||
|
||||
application.patch()
|
||||
assert SoftwareHealthState.PATCHING.value == application.describe_state().get("health_state_actual")
|
||||
@@ -2,6 +2,7 @@ from ipaddress import IPv4Address
|
||||
|
||||
import pytest
|
||||
|
||||
from primaite.simulator.file_system.file_system_item_abc import FileSystemItemHealthStatus
|
||||
from primaite.simulator.network.hardware.base import Node
|
||||
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
|
||||
from primaite.simulator.network.hardware.nodes.computer import Computer
|
||||
@@ -42,6 +43,7 @@ def test_ftp_client_store_file(ftp_client):
|
||||
"dest_folder_name": "downloads",
|
||||
"dest_file_name": "file.txt",
|
||||
"file_size": 24,
|
||||
"health_status": FileSystemItemHealthStatus.GOOD,
|
||||
},
|
||||
packet_payload_size=24,
|
||||
status_code=FTPStatusCode.OK,
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import pytest
|
||||
|
||||
from primaite.simulator.file_system.file_system_item_abc import FileSystemItemHealthStatus
|
||||
from primaite.simulator.network.hardware.base import Node
|
||||
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
|
||||
from primaite.simulator.network.hardware.nodes.server import Server
|
||||
@@ -41,6 +42,7 @@ def test_ftp_server_store_file(ftp_server):
|
||||
"dest_folder_name": "downloads",
|
||||
"dest_file_name": "file.txt",
|
||||
"file_size": 24,
|
||||
"health_status": FileSystemItemHealthStatus.GOOD,
|
||||
},
|
||||
packet_payload_size=24,
|
||||
)
|
||||
|
||||
@@ -19,55 +19,146 @@ def test_scan(service):
|
||||
|
||||
def test_start_service(service):
|
||||
assert service.operating_state == ServiceOperatingState.STOPPED
|
||||
assert service.health_state_actual == SoftwareHealthState.UNUSED
|
||||
service.start()
|
||||
|
||||
assert service.operating_state == ServiceOperatingState.RUNNING
|
||||
assert service.health_state_actual == SoftwareHealthState.GOOD
|
||||
|
||||
|
||||
def test_stop_service(service):
|
||||
service.start()
|
||||
assert service.operating_state == ServiceOperatingState.RUNNING
|
||||
assert service.health_state_actual == SoftwareHealthState.GOOD
|
||||
|
||||
service.stop()
|
||||
assert service.operating_state == ServiceOperatingState.STOPPED
|
||||
assert service.health_state_actual == SoftwareHealthState.GOOD
|
||||
|
||||
|
||||
def test_pause_and_resume_service(service):
|
||||
assert service.operating_state == ServiceOperatingState.STOPPED
|
||||
service.resume()
|
||||
assert service.operating_state == ServiceOperatingState.STOPPED
|
||||
assert service.health_state_actual == SoftwareHealthState.UNUSED
|
||||
|
||||
service.start()
|
||||
assert service.health_state_actual == SoftwareHealthState.GOOD
|
||||
service.pause()
|
||||
assert service.operating_state == ServiceOperatingState.PAUSED
|
||||
assert service.health_state_actual == SoftwareHealthState.GOOD
|
||||
|
||||
service.resume()
|
||||
assert service.operating_state == ServiceOperatingState.RUNNING
|
||||
assert service.health_state_actual == SoftwareHealthState.GOOD
|
||||
|
||||
|
||||
def test_restart(service):
|
||||
assert service.operating_state == ServiceOperatingState.STOPPED
|
||||
assert service.health_state_actual == SoftwareHealthState.UNUSED
|
||||
service.restart()
|
||||
# Service is STOPPED. Restart will only work if the service was PAUSED or RUNNING
|
||||
assert service.operating_state == ServiceOperatingState.STOPPED
|
||||
assert service.health_state_actual == SoftwareHealthState.UNUSED
|
||||
|
||||
service.start()
|
||||
assert service.operating_state == ServiceOperatingState.RUNNING
|
||||
assert service.health_state_actual == SoftwareHealthState.GOOD
|
||||
service.restart()
|
||||
# Service is RUNNING. Restart should work
|
||||
assert service.operating_state == ServiceOperatingState.RESTARTING
|
||||
assert service.health_state_actual == SoftwareHealthState.GOOD
|
||||
|
||||
timestep = 0
|
||||
while service.operating_state == ServiceOperatingState.RESTARTING:
|
||||
service.apply_timestep(timestep)
|
||||
assert service.health_state_actual == SoftwareHealthState.GOOD
|
||||
timestep += 1
|
||||
|
||||
assert service.operating_state == ServiceOperatingState.RUNNING
|
||||
assert service.health_state_actual == SoftwareHealthState.GOOD
|
||||
|
||||
|
||||
def test_restart_compromised(service):
|
||||
service.start()
|
||||
assert service.health_state_actual == SoftwareHealthState.GOOD
|
||||
|
||||
# compromise the service
|
||||
service.set_health_state(SoftwareHealthState.COMPROMISED)
|
||||
|
||||
service.restart()
|
||||
assert service.operating_state == ServiceOperatingState.RESTARTING
|
||||
assert service.health_state_actual == SoftwareHealthState.COMPROMISED
|
||||
|
||||
"""
|
||||
Service should be compromised even after reset.
|
||||
|
||||
Only way to remove compromised status is via patching.
|
||||
"""
|
||||
|
||||
timestep = 0
|
||||
while service.operating_state == ServiceOperatingState.RESTARTING:
|
||||
service.apply_timestep(timestep)
|
||||
assert service.health_state_actual == SoftwareHealthState.COMPROMISED
|
||||
timestep += 1
|
||||
|
||||
assert service.operating_state == ServiceOperatingState.RUNNING
|
||||
assert service.health_state_actual == SoftwareHealthState.COMPROMISED
|
||||
|
||||
|
||||
def test_compromised_service_remains_compromised(service):
|
||||
"""
|
||||
Tests that a compromised service stays compromised.
|
||||
|
||||
The only way that the service can be uncompromised is by running patch.
|
||||
"""
|
||||
service.start()
|
||||
assert service.health_state_actual == SoftwareHealthState.GOOD
|
||||
|
||||
service.set_health_state(SoftwareHealthState.COMPROMISED)
|
||||
|
||||
service.stop()
|
||||
assert service.health_state_actual == SoftwareHealthState.COMPROMISED
|
||||
|
||||
service.start()
|
||||
assert service.health_state_actual == SoftwareHealthState.COMPROMISED
|
||||
|
||||
service.disable()
|
||||
assert service.health_state_actual == SoftwareHealthState.COMPROMISED
|
||||
|
||||
service.enable()
|
||||
assert service.health_state_actual == SoftwareHealthState.COMPROMISED
|
||||
|
||||
service.pause()
|
||||
assert service.health_state_actual == SoftwareHealthState.COMPROMISED
|
||||
|
||||
service.resume()
|
||||
assert service.health_state_actual == SoftwareHealthState.COMPROMISED
|
||||
|
||||
|
||||
def test_service_patching(service):
|
||||
service.start()
|
||||
assert service.health_state_actual == SoftwareHealthState.GOOD
|
||||
|
||||
service.set_health_state(SoftwareHealthState.COMPROMISED)
|
||||
|
||||
service.patch()
|
||||
assert service.health_state_actual == SoftwareHealthState.PATCHING
|
||||
|
||||
for i in range(service.patching_duration + 1):
|
||||
service.apply_timestep(i)
|
||||
|
||||
assert service.health_state_actual == SoftwareHealthState.GOOD
|
||||
|
||||
|
||||
def test_enable_disable(service):
|
||||
service.disable()
|
||||
assert service.operating_state == ServiceOperatingState.DISABLED
|
||||
assert service.health_state_actual == SoftwareHealthState.UNUSED
|
||||
|
||||
service.enable()
|
||||
assert service.operating_state == ServiceOperatingState.STOPPED
|
||||
assert service.health_state_actual == SoftwareHealthState.UNUSED
|
||||
|
||||
|
||||
def test_overwhelm_service(service):
|
||||
@@ -76,13 +167,13 @@ def test_overwhelm_service(service):
|
||||
|
||||
uuid = str(uuid4())
|
||||
assert service.add_connection(connection_id=uuid) # should be true
|
||||
assert service.health_state_actual is SoftwareHealthState.GOOD
|
||||
assert service.health_state_actual == SoftwareHealthState.GOOD
|
||||
|
||||
assert not service.add_connection(connection_id=uuid) # fails because connection already exists
|
||||
assert service.health_state_actual is SoftwareHealthState.GOOD
|
||||
assert service.health_state_actual == SoftwareHealthState.GOOD
|
||||
|
||||
assert service.add_connection(connection_id=str(uuid4())) # succeed
|
||||
assert service.health_state_actual is SoftwareHealthState.GOOD
|
||||
assert service.health_state_actual == SoftwareHealthState.GOOD
|
||||
|
||||
assert not service.add_connection(connection_id=str(uuid4())) # fail because at capacity
|
||||
assert service.health_state_actual is SoftwareHealthState.OVERWHELMED
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import pytest
|
||||
|
||||
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
|
||||
from primaite.simulator.network.hardware.nodes.server import Server
|
||||
from primaite.simulator.network.protocols.http import (
|
||||
HttpRequestMethod,
|
||||
@@ -15,7 +16,11 @@ from primaite.simulator.system.services.web_server.web_server import WebServer
|
||||
@pytest.fixture(scope="function")
|
||||
def web_server() -> Server:
|
||||
node = Server(
|
||||
hostname="web_server", ip_address="192.168.1.10", subnet_mask="255.255.255.0", default_gateway="192.168.1.1"
|
||||
hostname="web_server",
|
||||
ip_address="192.168.1.10",
|
||||
subnet_mask="255.255.255.0",
|
||||
default_gateway="192.168.1.1",
|
||||
operating_state=NodeOperatingState.ON,
|
||||
)
|
||||
node.software_manager.install(software_class=WebServer)
|
||||
node.software_manager.software.get("WebServer").start()
|
||||
|
||||
@@ -0,0 +1,29 @@
|
||||
from typing import Dict
|
||||
|
||||
import pytest
|
||||
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.system.core.sys_log import SysLog
|
||||
from primaite.simulator.system.software import Software, SoftwareHealthState
|
||||
|
||||
|
||||
class TestSoftware(Software):
|
||||
def describe_state(self) -> Dict:
|
||||
pass
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def software(file_system):
|
||||
return TestSoftware(
|
||||
name="TestSoftware", port=Port.ARP, file_system=file_system, sys_log=SysLog(hostname="test_service")
|
||||
)
|
||||
|
||||
|
||||
def test_software_creation(software):
|
||||
assert software is not None
|
||||
|
||||
|
||||
def test_software_set_health_state(software):
|
||||
assert software.health_state_actual == SoftwareHealthState.UNUSED
|
||||
software.set_health_state(SoftwareHealthState.GOOD)
|
||||
assert software.health_state_actual == SoftwareHealthState.GOOD
|
||||
Reference in New Issue
Block a user