#2735 - synced with dev and resolved merge conflicts
This commit is contained in:
@@ -129,6 +129,10 @@ agents:
|
||||
|
||||
simulation:
|
||||
network:
|
||||
nmne_config:
|
||||
capture_nmne: true
|
||||
nmne_capture_keywords:
|
||||
- DELETE
|
||||
nodes:
|
||||
- hostname: client
|
||||
type: computer
|
||||
|
||||
@@ -18,7 +18,7 @@ from primaite.game.agent.scripted_agents.tap001 import TAP001
|
||||
from primaite.game.science import graph_has_cycle, topological_sort
|
||||
from primaite.simulator import SIM_OUTPUT
|
||||
from primaite.simulator.network.airspace import AirSpaceFrequency
|
||||
from primaite.simulator.network.hardware.base import NodeOperatingState, UserManager
|
||||
from primaite.simulator.network.hardware.base import NetworkInterface, NodeOperatingState, UserManager
|
||||
from primaite.simulator.network.hardware.nodes.host.computer import Computer
|
||||
from primaite.simulator.network.hardware.nodes.host.host_node import NIC
|
||||
from primaite.simulator.network.hardware.nodes.host.server import Printer, Server
|
||||
@@ -26,7 +26,7 @@ from primaite.simulator.network.hardware.nodes.network.firewall import Firewall
|
||||
from primaite.simulator.network.hardware.nodes.network.router import Router
|
||||
from primaite.simulator.network.hardware.nodes.network.switch import Switch
|
||||
from primaite.simulator.network.hardware.nodes.network.wireless_router import WirelessRouter
|
||||
from primaite.simulator.network.nmne import set_nmne_config
|
||||
from primaite.simulator.network.nmne import NMNEConfig
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.sim_container import Simulation
|
||||
from primaite.simulator.system.applications.application import Application
|
||||
@@ -264,6 +264,8 @@ class PrimaiteGame:
|
||||
|
||||
nodes_cfg = network_config.get("nodes", [])
|
||||
links_cfg = network_config.get("links", [])
|
||||
# Set the NMNE capture config
|
||||
NetworkInterface.nmne_config = NMNEConfig(**network_config.get("nmne_config", {}))
|
||||
|
||||
for node_cfg in nodes_cfg:
|
||||
n_type = node_cfg["type"]
|
||||
@@ -539,10 +541,7 @@ class PrimaiteGame:
|
||||
# Validate that if any agents are sharing rewards, they aren't forming an infinite loop.
|
||||
game.setup_reward_sharing()
|
||||
|
||||
# Set the NMNE capture config
|
||||
set_nmne_config(network_config.get("nmne_config", {}))
|
||||
game.update_agents(game.get_sim_state())
|
||||
|
||||
return game
|
||||
|
||||
def setup_reward_sharing(self):
|
||||
|
||||
@@ -11,7 +11,6 @@ from typing import Any, ClassVar, Dict, List, Optional, Type, TypeVar, Union
|
||||
from prettytable import MARKDOWN, PrettyTable
|
||||
from pydantic import BaseModel, Field, validate_call
|
||||
|
||||
import primaite.simulator.network.nmne
|
||||
from primaite import getLogger
|
||||
from primaite.exceptions import NetworkError
|
||||
from primaite.interface.request import RequestResponse
|
||||
@@ -20,15 +19,7 @@ from primaite.simulator.core import RequestFormat, RequestManager, RequestPermis
|
||||
from primaite.simulator.domain.account import Account
|
||||
from primaite.simulator.file_system.file_system import FileSystem
|
||||
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
|
||||
from primaite.simulator.network.nmne import (
|
||||
CAPTURE_BY_DIRECTION,
|
||||
CAPTURE_BY_IP_ADDRESS,
|
||||
CAPTURE_BY_KEYWORD,
|
||||
CAPTURE_BY_PORT,
|
||||
CAPTURE_BY_PROTOCOL,
|
||||
CAPTURE_NMNE,
|
||||
NMNE_CAPTURE_KEYWORDS,
|
||||
)
|
||||
from primaite.simulator.network.nmne import NMNEConfig
|
||||
from primaite.simulator.network.transmission.data_link_layer import Frame
|
||||
from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
@@ -109,8 +100,11 @@ class NetworkInterface(SimComponent, ABC):
|
||||
pcap: Optional[PacketCapture] = None
|
||||
"A PacketCapture instance for capturing and analysing packets passing through this interface."
|
||||
|
||||
nmne_config: ClassVar[NMNEConfig] = NMNEConfig()
|
||||
"A dataclass defining malicious network events to be captured."
|
||||
|
||||
nmne: Dict = Field(default_factory=lambda: {})
|
||||
"A dict containing details of the number of malicious network events captured."
|
||||
"A dict containing details of the number of malicious events captured."
|
||||
|
||||
traffic: Dict = Field(default_factory=lambda: {})
|
||||
"A dict containing details of the inbound and outbound traffic by port and protocol."
|
||||
@@ -168,8 +162,8 @@ class NetworkInterface(SimComponent, ABC):
|
||||
"enabled": self.enabled,
|
||||
}
|
||||
)
|
||||
if CAPTURE_NMNE:
|
||||
state.update({"nmne": {k: v for k, v in self.nmne.items()}})
|
||||
if self.nmne_config and self.nmne_config.capture_nmne:
|
||||
state.update({"nmne": self.nmne})
|
||||
state.update({"traffic": convert_dict_enum_keys_to_enum_values(self.traffic)})
|
||||
return state
|
||||
|
||||
@@ -202,7 +196,7 @@ class NetworkInterface(SimComponent, ABC):
|
||||
:param inbound: Boolean indicating if the frame direction is inbound. Defaults to True.
|
||||
"""
|
||||
# Exit function if NMNE capturing is disabled
|
||||
if not CAPTURE_NMNE:
|
||||
if not (self.nmne_config and self.nmne_config.capture_nmne):
|
||||
return
|
||||
|
||||
# Initialise basic frame data variables
|
||||
@@ -223,27 +217,27 @@ class NetworkInterface(SimComponent, ABC):
|
||||
frame_str = str(frame.payload)
|
||||
|
||||
# Proceed only if any NMNE keyword is present in the frame payload
|
||||
if any(keyword in frame_str for keyword in NMNE_CAPTURE_KEYWORDS):
|
||||
if any(keyword in frame_str for keyword in self.nmne_config.nmne_capture_keywords):
|
||||
# Start with the root of the NMNE capture structure
|
||||
current_level = self.nmne
|
||||
|
||||
# Update NMNE structure based on enabled settings
|
||||
if CAPTURE_BY_DIRECTION:
|
||||
if self.nmne_config.capture_by_direction:
|
||||
# Set or get the dictionary for the current direction
|
||||
current_level = current_level.setdefault("direction", {})
|
||||
current_level = current_level.setdefault(direction, {})
|
||||
|
||||
if CAPTURE_BY_IP_ADDRESS:
|
||||
if self.nmne_config.capture_by_ip_address:
|
||||
# Set or get the dictionary for the current IP address
|
||||
current_level = current_level.setdefault("ip_address", {})
|
||||
current_level = current_level.setdefault(ip_address, {})
|
||||
|
||||
if CAPTURE_BY_PROTOCOL:
|
||||
if self.nmne_config.capture_by_protocol:
|
||||
# Set or get the dictionary for the current protocol
|
||||
current_level = current_level.setdefault("protocol", {})
|
||||
current_level = current_level.setdefault(protocol, {})
|
||||
|
||||
if CAPTURE_BY_PORT:
|
||||
if self.nmne_config.capture_by_port:
|
||||
# Set or get the dictionary for the current port
|
||||
current_level = current_level.setdefault("port", {})
|
||||
current_level = current_level.setdefault(port, {})
|
||||
@@ -252,8 +246,8 @@ class NetworkInterface(SimComponent, ABC):
|
||||
keyword_level = current_level.setdefault("keywords", {})
|
||||
|
||||
# Increment the count for detected keywords in the payload
|
||||
if CAPTURE_BY_KEYWORD:
|
||||
for keyword in NMNE_CAPTURE_KEYWORDS:
|
||||
if self.nmne_config.capture_by_keyword:
|
||||
for keyword in self.nmne_config.nmne_capture_keywords:
|
||||
if keyword in frame_str:
|
||||
# Update the count for each keyword found
|
||||
keyword_level[keyword] = keyword_level.get(keyword, 0) + 1
|
||||
@@ -1848,7 +1842,7 @@ class Node(SimComponent):
|
||||
ip_address,
|
||||
network_interface.speed,
|
||||
"Enabled" if network_interface.enabled else "Disabled",
|
||||
network_interface.nmne if primaite.simulator.network.nmne.CAPTURE_NMNE else "Disabled",
|
||||
network_interface.nmne if network_interface.nmne_config.capture_nmne else "Disabled",
|
||||
]
|
||||
)
|
||||
print(table)
|
||||
|
||||
@@ -1,48 +1,25 @@
|
||||
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||
from typing import Dict, Final, List
|
||||
from typing import List
|
||||
|
||||
CAPTURE_NMNE: bool = True
|
||||
"""Indicates whether Malicious Network Events (MNEs) should be captured. Default is True."""
|
||||
|
||||
NMNE_CAPTURE_KEYWORDS: List[str] = []
|
||||
"""List of keywords to identify malicious network events."""
|
||||
|
||||
# TODO: Remove final and make configurable after example layout when the NICObservation creates nmne structure dynamically
|
||||
CAPTURE_BY_DIRECTION: Final[bool] = True
|
||||
"""Flag to determine if captures should be organized by traffic direction (inbound/outbound)."""
|
||||
CAPTURE_BY_IP_ADDRESS: Final[bool] = False
|
||||
"""Flag to determine if captures should be organized by source or destination IP address."""
|
||||
CAPTURE_BY_PROTOCOL: Final[bool] = False
|
||||
"""Flag to determine if captures should be organized by network protocol (e.g., TCP, UDP)."""
|
||||
CAPTURE_BY_PORT: Final[bool] = False
|
||||
"""Flag to determine if captures should be organized by source or destination port."""
|
||||
CAPTURE_BY_KEYWORD: Final[bool] = False
|
||||
"""Flag to determine if captures should be filtered and categorised based on specific keywords."""
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
|
||||
def set_nmne_config(nmne_config: Dict):
|
||||
"""
|
||||
Sets the configuration for capturing Malicious Network Events (MNEs) based on a provided dictionary.
|
||||
class NMNEConfig(BaseModel):
|
||||
"""Store all the information to perform NMNE operations."""
|
||||
|
||||
This function updates global settings related to NMNE capture, including whether to capture NMNEs and what
|
||||
keywords to use for identifying NMNEs.
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
The function ensures that the settings are updated only if they are provided in the `nmne_config` dictionary,
|
||||
and maintains type integrity by checking the types of the provided values.
|
||||
|
||||
:param nmne_config: A dictionary containing the NMNE configuration settings. Possible keys include:
|
||||
"capture_nmne" (bool) to indicate whether NMNEs should be captured, "nmne_capture_keywords" (list of strings)
|
||||
to specify keywords for NMNE identification.
|
||||
"""
|
||||
global NMNE_CAPTURE_KEYWORDS
|
||||
global CAPTURE_NMNE
|
||||
|
||||
# Update the NMNE capture flag, defaulting to False if not specified or if the type is incorrect
|
||||
CAPTURE_NMNE = nmne_config.get("capture_nmne", False)
|
||||
if not isinstance(CAPTURE_NMNE, bool):
|
||||
CAPTURE_NMNE = True # Revert to default True if the provided value is not a boolean
|
||||
|
||||
# Update the NMNE capture keywords, appending new keywords if provided
|
||||
NMNE_CAPTURE_KEYWORDS += nmne_config.get("nmne_capture_keywords", [])
|
||||
if not isinstance(NMNE_CAPTURE_KEYWORDS, list):
|
||||
NMNE_CAPTURE_KEYWORDS = [] # Reset to empty list if the provided value is not a list
|
||||
capture_nmne: bool = False
|
||||
"""Indicates whether Malicious Network Events (MNEs) should be captured."""
|
||||
nmne_capture_keywords: List[str] = []
|
||||
"""List of keywords to identify malicious network events."""
|
||||
capture_by_direction: bool = True
|
||||
"""Captures should be organized by traffic direction (inbound/outbound)."""
|
||||
capture_by_ip_address: bool = False
|
||||
"""Captures should be organized by source or destination IP address."""
|
||||
capture_by_protocol: bool = False
|
||||
"""Captures should be organized by network protocol (e.g., TCP, UDP)."""
|
||||
capture_by_port: bool = False
|
||||
"""Captures should be organized by source or destination port."""
|
||||
capture_by_keyword: bool = False
|
||||
"""Captures should be filtered and categorised based on specific keywords."""
|
||||
|
||||
@@ -4,7 +4,7 @@ from enum import Enum
|
||||
from typing import Union
|
||||
|
||||
from pydantic import BaseModel, field_validator, validate_call
|
||||
from pydantic_core.core_schema import FieldValidationInfo
|
||||
from pydantic_core.core_schema import ValidationInfo
|
||||
|
||||
from primaite import getLogger
|
||||
|
||||
@@ -96,7 +96,7 @@ class ICMPPacket(BaseModel):
|
||||
|
||||
@field_validator("icmp_code") # noqa
|
||||
@classmethod
|
||||
def _icmp_type_must_have_icmp_code(cls, v: int, info: FieldValidationInfo) -> int:
|
||||
def _icmp_type_must_have_icmp_code(cls, v: int, info: ValidationInfo) -> int:
|
||||
"""Validates the icmp_type and icmp_code."""
|
||||
icmp_type = info.data["icmp_type"]
|
||||
if get_icmp_type_code_description(icmp_type, v):
|
||||
|
||||
@@ -99,7 +99,7 @@ agents:
|
||||
num_files: 1
|
||||
num_nics: 2
|
||||
include_num_access: false
|
||||
include_nmne: true
|
||||
include_nmne: false
|
||||
routers:
|
||||
- hostname: router_1
|
||||
num_ports: 0
|
||||
|
||||
@@ -92,7 +92,7 @@ agents:
|
||||
- NONE
|
||||
tcp:
|
||||
- DNS
|
||||
include_nmne: true
|
||||
include_nmne: false
|
||||
routers:
|
||||
- hostname: router_1
|
||||
num_ports: 0
|
||||
|
||||
@@ -111,7 +111,7 @@ agents:
|
||||
num_files: 1
|
||||
num_nics: 2
|
||||
include_num_access: false
|
||||
include_nmne: true
|
||||
include_nmne: false
|
||||
routers:
|
||||
- hostname: router_1
|
||||
num_ports: 0
|
||||
|
||||
@@ -68,7 +68,7 @@ agents:
|
||||
num_files: 1
|
||||
num_nics: 2
|
||||
include_num_access: false
|
||||
include_nmne: true
|
||||
include_nmne: false
|
||||
routers:
|
||||
- hostname: router_1
|
||||
num_ports: 0
|
||||
|
||||
@@ -89,7 +89,7 @@ agents:
|
||||
- NONE
|
||||
tcp:
|
||||
- DNS
|
||||
include_nmne: true
|
||||
include_nmne: false
|
||||
routers:
|
||||
- hostname: router_1
|
||||
num_ports: 0
|
||||
|
||||
@@ -44,7 +44,7 @@ agents:
|
||||
num_files: 1
|
||||
num_nics: 1
|
||||
include_num_access: false
|
||||
include_nmne: true
|
||||
include_nmne: false
|
||||
|
||||
- type: LINKS
|
||||
label: LINKS
|
||||
|
||||
@@ -89,7 +89,7 @@ agents:
|
||||
- NONE
|
||||
tcp:
|
||||
- DNS
|
||||
include_nmne: true
|
||||
include_nmne: false
|
||||
routers:
|
||||
- hostname: router_1
|
||||
num_ports: 0
|
||||
|
||||
@@ -120,7 +120,7 @@ agents:
|
||||
num_files: 1
|
||||
num_nics: 2
|
||||
include_num_access: false
|
||||
include_nmne: true
|
||||
include_nmne: false
|
||||
routers:
|
||||
- hostname: router_1
|
||||
num_ports: 0
|
||||
|
||||
@@ -30,7 +30,7 @@ from primaite.simulator.system.services.service import Service
|
||||
from primaite.simulator.system.services.web_server.web_server import WebServer
|
||||
from tests import TEST_ASSETS_ROOT
|
||||
|
||||
rayinit(local_mode=True)
|
||||
rayinit()
|
||||
ACTION_SPACE_NODE_VALUES = 1
|
||||
ACTION_SPACE_NODE_ACTION_VALUES = 1
|
||||
|
||||
|
||||
@@ -9,9 +9,11 @@ from gymnasium import spaces
|
||||
from primaite.game.agent.interface import ProxyAgent
|
||||
from primaite.game.agent.observations.nic_observations import NICObservation
|
||||
from primaite.game.game import PrimaiteGame
|
||||
from primaite.simulator.network.hardware.base import NetworkInterface
|
||||
from primaite.simulator.network.hardware.nodes.host.computer import Computer
|
||||
from primaite.simulator.network.hardware.nodes.host.host_node import NIC
|
||||
from primaite.simulator.network.hardware.nodes.host.server import Server
|
||||
from primaite.simulator.network.nmne import NMNEConfig
|
||||
from primaite.simulator.sim_container import Simulation
|
||||
from primaite.simulator.system.applications.database_client import DatabaseClient
|
||||
from primaite.simulator.system.applications.web_browser import WebBrowser
|
||||
@@ -75,6 +77,18 @@ def test_nic(simulation):
|
||||
|
||||
nic_obs = NICObservation(where=["network", "nodes", pc.hostname, "NICs", 1], include_nmne=True)
|
||||
|
||||
# Set the NMNE configuration to capture DELETE/ENCRYPT queries as MNEs
|
||||
nmne_config = {
|
||||
"capture_nmne": True, # Enable the capture of MNEs
|
||||
"nmne_capture_keywords": [
|
||||
"DELETE",
|
||||
"ENCRYPT",
|
||||
], # Specify "DELETE/ENCRYPT" SQL command as a keyword for MNE detection
|
||||
}
|
||||
|
||||
# Apply the NMNE configuration settings
|
||||
NetworkInterface.nmne_config = NMNEConfig(**nmne_config)
|
||||
|
||||
assert nic_obs.space["nic_status"] == spaces.Discrete(3)
|
||||
assert nic_obs.space["NMNE"]["inbound"] == spaces.Discrete(4)
|
||||
assert nic_obs.space["NMNE"]["outbound"] == spaces.Discrete(4)
|
||||
@@ -144,7 +158,7 @@ def test_nic_monitored_traffic(simulation):
|
||||
pc2: Computer = simulation.network.get_node_by_hostname("client_2")
|
||||
|
||||
nic_obs = NICObservation(
|
||||
where=["network", "nodes", pc.hostname, "NICs", 1], include_nmne=True, monitored_traffic=monitored_traffic
|
||||
where=["network", "nodes", pc.hostname, "NICs", 1], include_nmne=False, monitored_traffic=monitored_traffic
|
||||
)
|
||||
|
||||
simulation.pre_timestep(0) # apply timestep to whole sim
|
||||
|
||||
@@ -1,12 +1,14 @@
|
||||
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||
from primaite.game.agent.observations.nic_observations import NICObservation
|
||||
from primaite.simulator.network.container import Network
|
||||
from primaite.simulator.network.hardware.nodes.host.host_node import NIC
|
||||
from primaite.simulator.network.hardware.nodes.host.server import Server
|
||||
from primaite.simulator.network.nmne import set_nmne_config
|
||||
from primaite.simulator.network.nmne import NMNEConfig
|
||||
from primaite.simulator.sim_container import Simulation
|
||||
from primaite.simulator.system.applications.database_client import DatabaseClient, DatabaseClientConnection
|
||||
|
||||
|
||||
def test_capture_nmne(uc2_network):
|
||||
def test_capture_nmne(uc2_network: Network):
|
||||
"""
|
||||
Conducts a test to verify that Malicious Network Events (MNEs) are correctly captured.
|
||||
|
||||
@@ -33,7 +35,7 @@ def test_capture_nmne(uc2_network):
|
||||
}
|
||||
|
||||
# Apply the NMNE configuration settings
|
||||
set_nmne_config(nmne_config)
|
||||
NIC.nmne_config = NMNEConfig(**nmne_config)
|
||||
|
||||
# Assert that initially, there are no captured MNEs on both web and database servers
|
||||
assert web_server_nic.nmne == {}
|
||||
@@ -82,7 +84,7 @@ def test_capture_nmne(uc2_network):
|
||||
assert db_server_nic.nmne == {"direction": {"inbound": {"keywords": {"*": 3}}}}
|
||||
|
||||
|
||||
def test_describe_state_nmne(uc2_network):
|
||||
def test_describe_state_nmne(uc2_network: Network):
|
||||
"""
|
||||
Conducts a test to verify that Malicious Network Events (MNEs) are correctly represented in the nic state.
|
||||
|
||||
@@ -110,7 +112,7 @@ def test_describe_state_nmne(uc2_network):
|
||||
}
|
||||
|
||||
# Apply the NMNE configuration settings
|
||||
set_nmne_config(nmne_config)
|
||||
NIC.nmne_config = NMNEConfig(**nmne_config)
|
||||
|
||||
# Assert that initially, there are no captured MNEs on both web and database servers
|
||||
web_server_nic_state = web_server_nic.describe_state()
|
||||
@@ -190,7 +192,7 @@ def test_describe_state_nmne(uc2_network):
|
||||
assert db_server_nic_state["nmne"] == {"direction": {"inbound": {"keywords": {"*": 4}}}}
|
||||
|
||||
|
||||
def test_capture_nmne_observations(uc2_network):
|
||||
def test_capture_nmne_observations(uc2_network: Network):
|
||||
"""
|
||||
Tests the NICObservation class's functionality within a simulated network environment.
|
||||
|
||||
@@ -219,7 +221,7 @@ def test_capture_nmne_observations(uc2_network):
|
||||
}
|
||||
|
||||
# Apply the NMNE configuration settings
|
||||
set_nmne_config(nmne_config)
|
||||
NIC.nmne_config = NMNEConfig(**nmne_config)
|
||||
|
||||
# Define observations for the NICs of the database and web servers
|
||||
db_server_nic_obs = NICObservation(where=["network", "nodes", "database_server", "NICs", 1], include_nmne=True)
|
||||
|
||||
Reference in New Issue
Block a user