#2735 - synced with dev and resolved merge conflicts

This commit is contained in:
Chris McCarthy
2024-08-02 12:55:09 +01:00
16 changed files with 78 additions and 88 deletions

View File

@@ -129,6 +129,10 @@ agents:
simulation:
network:
nmne_config:
capture_nmne: true
nmne_capture_keywords:
- DELETE
nodes:
- hostname: client
type: computer

View File

@@ -18,7 +18,7 @@ from primaite.game.agent.scripted_agents.tap001 import TAP001
from primaite.game.science import graph_has_cycle, topological_sort
from primaite.simulator import SIM_OUTPUT
from primaite.simulator.network.airspace import AirSpaceFrequency
from primaite.simulator.network.hardware.base import NodeOperatingState, UserManager
from primaite.simulator.network.hardware.base import NetworkInterface, NodeOperatingState, UserManager
from primaite.simulator.network.hardware.nodes.host.computer import Computer
from primaite.simulator.network.hardware.nodes.host.host_node import NIC
from primaite.simulator.network.hardware.nodes.host.server import Printer, Server
@@ -26,7 +26,7 @@ from primaite.simulator.network.hardware.nodes.network.firewall import Firewall
from primaite.simulator.network.hardware.nodes.network.router import Router
from primaite.simulator.network.hardware.nodes.network.switch import Switch
from primaite.simulator.network.hardware.nodes.network.wireless_router import WirelessRouter
from primaite.simulator.network.nmne import set_nmne_config
from primaite.simulator.network.nmne import NMNEConfig
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.sim_container import Simulation
from primaite.simulator.system.applications.application import Application
@@ -264,6 +264,8 @@ class PrimaiteGame:
nodes_cfg = network_config.get("nodes", [])
links_cfg = network_config.get("links", [])
# Set the NMNE capture config
NetworkInterface.nmne_config = NMNEConfig(**network_config.get("nmne_config", {}))
for node_cfg in nodes_cfg:
n_type = node_cfg["type"]
@@ -539,10 +541,7 @@ class PrimaiteGame:
# Validate that if any agents are sharing rewards, they aren't forming an infinite loop.
game.setup_reward_sharing()
# Set the NMNE capture config
set_nmne_config(network_config.get("nmne_config", {}))
game.update_agents(game.get_sim_state())
return game
def setup_reward_sharing(self):

View File

@@ -11,7 +11,6 @@ from typing import Any, ClassVar, Dict, List, Optional, Type, TypeVar, Union
from prettytable import MARKDOWN, PrettyTable
from pydantic import BaseModel, Field, validate_call
import primaite.simulator.network.nmne
from primaite import getLogger
from primaite.exceptions import NetworkError
from primaite.interface.request import RequestResponse
@@ -20,15 +19,7 @@ from primaite.simulator.core import RequestFormat, RequestManager, RequestPermis
from primaite.simulator.domain.account import Account
from primaite.simulator.file_system.file_system import FileSystem
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
from primaite.simulator.network.nmne import (
CAPTURE_BY_DIRECTION,
CAPTURE_BY_IP_ADDRESS,
CAPTURE_BY_KEYWORD,
CAPTURE_BY_PORT,
CAPTURE_BY_PROTOCOL,
CAPTURE_NMNE,
NMNE_CAPTURE_KEYWORDS,
)
from primaite.simulator.network.nmne import NMNEConfig
from primaite.simulator.network.transmission.data_link_layer import Frame
from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
@@ -109,8 +100,11 @@ class NetworkInterface(SimComponent, ABC):
pcap: Optional[PacketCapture] = None
"A PacketCapture instance for capturing and analysing packets passing through this interface."
nmne_config: ClassVar[NMNEConfig] = NMNEConfig()
"A dataclass defining malicious network events to be captured."
nmne: Dict = Field(default_factory=lambda: {})
"A dict containing details of the number of malicious network events captured."
"A dict containing details of the number of malicious events captured."
traffic: Dict = Field(default_factory=lambda: {})
"A dict containing details of the inbound and outbound traffic by port and protocol."
@@ -168,8 +162,8 @@ class NetworkInterface(SimComponent, ABC):
"enabled": self.enabled,
}
)
if CAPTURE_NMNE:
state.update({"nmne": {k: v for k, v in self.nmne.items()}})
if self.nmne_config and self.nmne_config.capture_nmne:
state.update({"nmne": self.nmne})
state.update({"traffic": convert_dict_enum_keys_to_enum_values(self.traffic)})
return state
@@ -202,7 +196,7 @@ class NetworkInterface(SimComponent, ABC):
:param inbound: Boolean indicating if the frame direction is inbound. Defaults to True.
"""
# Exit function if NMNE capturing is disabled
if not CAPTURE_NMNE:
if not (self.nmne_config and self.nmne_config.capture_nmne):
return
# Initialise basic frame data variables
@@ -223,27 +217,27 @@ class NetworkInterface(SimComponent, ABC):
frame_str = str(frame.payload)
# Proceed only if any NMNE keyword is present in the frame payload
if any(keyword in frame_str for keyword in NMNE_CAPTURE_KEYWORDS):
if any(keyword in frame_str for keyword in self.nmne_config.nmne_capture_keywords):
# Start with the root of the NMNE capture structure
current_level = self.nmne
# Update NMNE structure based on enabled settings
if CAPTURE_BY_DIRECTION:
if self.nmne_config.capture_by_direction:
# Set or get the dictionary for the current direction
current_level = current_level.setdefault("direction", {})
current_level = current_level.setdefault(direction, {})
if CAPTURE_BY_IP_ADDRESS:
if self.nmne_config.capture_by_ip_address:
# Set or get the dictionary for the current IP address
current_level = current_level.setdefault("ip_address", {})
current_level = current_level.setdefault(ip_address, {})
if CAPTURE_BY_PROTOCOL:
if self.nmne_config.capture_by_protocol:
# Set or get the dictionary for the current protocol
current_level = current_level.setdefault("protocol", {})
current_level = current_level.setdefault(protocol, {})
if CAPTURE_BY_PORT:
if self.nmne_config.capture_by_port:
# Set or get the dictionary for the current port
current_level = current_level.setdefault("port", {})
current_level = current_level.setdefault(port, {})
@@ -252,8 +246,8 @@ class NetworkInterface(SimComponent, ABC):
keyword_level = current_level.setdefault("keywords", {})
# Increment the count for detected keywords in the payload
if CAPTURE_BY_KEYWORD:
for keyword in NMNE_CAPTURE_KEYWORDS:
if self.nmne_config.capture_by_keyword:
for keyword in self.nmne_config.nmne_capture_keywords:
if keyword in frame_str:
# Update the count for each keyword found
keyword_level[keyword] = keyword_level.get(keyword, 0) + 1
@@ -1848,7 +1842,7 @@ class Node(SimComponent):
ip_address,
network_interface.speed,
"Enabled" if network_interface.enabled else "Disabled",
network_interface.nmne if primaite.simulator.network.nmne.CAPTURE_NMNE else "Disabled",
network_interface.nmne if network_interface.nmne_config.capture_nmne else "Disabled",
]
)
print(table)

View File

@@ -1,48 +1,25 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
from typing import Dict, Final, List
from typing import List
CAPTURE_NMNE: bool = True
"""Indicates whether Malicious Network Events (MNEs) should be captured. Default is True."""
NMNE_CAPTURE_KEYWORDS: List[str] = []
"""List of keywords to identify malicious network events."""
# TODO: Remove final and make configurable after example layout when the NICObservation creates nmne structure dynamically
CAPTURE_BY_DIRECTION: Final[bool] = True
"""Flag to determine if captures should be organized by traffic direction (inbound/outbound)."""
CAPTURE_BY_IP_ADDRESS: Final[bool] = False
"""Flag to determine if captures should be organized by source or destination IP address."""
CAPTURE_BY_PROTOCOL: Final[bool] = False
"""Flag to determine if captures should be organized by network protocol (e.g., TCP, UDP)."""
CAPTURE_BY_PORT: Final[bool] = False
"""Flag to determine if captures should be organized by source or destination port."""
CAPTURE_BY_KEYWORD: Final[bool] = False
"""Flag to determine if captures should be filtered and categorised based on specific keywords."""
from pydantic import BaseModel, ConfigDict
def set_nmne_config(nmne_config: Dict):
"""
Sets the configuration for capturing Malicious Network Events (MNEs) based on a provided dictionary.
class NMNEConfig(BaseModel):
"""Store all the information to perform NMNE operations."""
This function updates global settings related to NMNE capture, including whether to capture NMNEs and what
keywords to use for identifying NMNEs.
model_config = ConfigDict(extra="forbid")
The function ensures that the settings are updated only if they are provided in the `nmne_config` dictionary,
and maintains type integrity by checking the types of the provided values.
:param nmne_config: A dictionary containing the NMNE configuration settings. Possible keys include:
"capture_nmne" (bool) to indicate whether NMNEs should be captured, "nmne_capture_keywords" (list of strings)
to specify keywords for NMNE identification.
"""
global NMNE_CAPTURE_KEYWORDS
global CAPTURE_NMNE
# Update the NMNE capture flag, defaulting to False if not specified or if the type is incorrect
CAPTURE_NMNE = nmne_config.get("capture_nmne", False)
if not isinstance(CAPTURE_NMNE, bool):
CAPTURE_NMNE = True # Revert to default True if the provided value is not a boolean
# Update the NMNE capture keywords, appending new keywords if provided
NMNE_CAPTURE_KEYWORDS += nmne_config.get("nmne_capture_keywords", [])
if not isinstance(NMNE_CAPTURE_KEYWORDS, list):
NMNE_CAPTURE_KEYWORDS = [] # Reset to empty list if the provided value is not a list
capture_nmne: bool = False
"""Indicates whether Malicious Network Events (MNEs) should be captured."""
nmne_capture_keywords: List[str] = []
"""List of keywords to identify malicious network events."""
capture_by_direction: bool = True
"""Captures should be organized by traffic direction (inbound/outbound)."""
capture_by_ip_address: bool = False
"""Captures should be organized by source or destination IP address."""
capture_by_protocol: bool = False
"""Captures should be organized by network protocol (e.g., TCP, UDP)."""
capture_by_port: bool = False
"""Captures should be organized by source or destination port."""
capture_by_keyword: bool = False
"""Captures should be filtered and categorised based on specific keywords."""

View File

@@ -4,7 +4,7 @@ from enum import Enum
from typing import Union
from pydantic import BaseModel, field_validator, validate_call
from pydantic_core.core_schema import FieldValidationInfo
from pydantic_core.core_schema import ValidationInfo
from primaite import getLogger
@@ -96,7 +96,7 @@ class ICMPPacket(BaseModel):
@field_validator("icmp_code") # noqa
@classmethod
def _icmp_type_must_have_icmp_code(cls, v: int, info: FieldValidationInfo) -> int:
def _icmp_type_must_have_icmp_code(cls, v: int, info: ValidationInfo) -> int:
"""Validates the icmp_type and icmp_code."""
icmp_type = info.data["icmp_type"]
if get_icmp_type_code_description(icmp_type, v):

View File

@@ -99,7 +99,7 @@ agents:
num_files: 1
num_nics: 2
include_num_access: false
include_nmne: true
include_nmne: false
routers:
- hostname: router_1
num_ports: 0

View File

@@ -92,7 +92,7 @@ agents:
- NONE
tcp:
- DNS
include_nmne: true
include_nmne: false
routers:
- hostname: router_1
num_ports: 0

View File

@@ -111,7 +111,7 @@ agents:
num_files: 1
num_nics: 2
include_num_access: false
include_nmne: true
include_nmne: false
routers:
- hostname: router_1
num_ports: 0

View File

@@ -68,7 +68,7 @@ agents:
num_files: 1
num_nics: 2
include_num_access: false
include_nmne: true
include_nmne: false
routers:
- hostname: router_1
num_ports: 0

View File

@@ -89,7 +89,7 @@ agents:
- NONE
tcp:
- DNS
include_nmne: true
include_nmne: false
routers:
- hostname: router_1
num_ports: 0

View File

@@ -44,7 +44,7 @@ agents:
num_files: 1
num_nics: 1
include_num_access: false
include_nmne: true
include_nmne: false
- type: LINKS
label: LINKS

View File

@@ -89,7 +89,7 @@ agents:
- NONE
tcp:
- DNS
include_nmne: true
include_nmne: false
routers:
- hostname: router_1
num_ports: 0

View File

@@ -120,7 +120,7 @@ agents:
num_files: 1
num_nics: 2
include_num_access: false
include_nmne: true
include_nmne: false
routers:
- hostname: router_1
num_ports: 0

View File

@@ -30,7 +30,7 @@ from primaite.simulator.system.services.service import Service
from primaite.simulator.system.services.web_server.web_server import WebServer
from tests import TEST_ASSETS_ROOT
rayinit(local_mode=True)
rayinit()
ACTION_SPACE_NODE_VALUES = 1
ACTION_SPACE_NODE_ACTION_VALUES = 1

View File

@@ -9,9 +9,11 @@ from gymnasium import spaces
from primaite.game.agent.interface import ProxyAgent
from primaite.game.agent.observations.nic_observations import NICObservation
from primaite.game.game import PrimaiteGame
from primaite.simulator.network.hardware.base import NetworkInterface
from primaite.simulator.network.hardware.nodes.host.computer import Computer
from primaite.simulator.network.hardware.nodes.host.host_node import NIC
from primaite.simulator.network.hardware.nodes.host.server import Server
from primaite.simulator.network.nmne import NMNEConfig
from primaite.simulator.sim_container import Simulation
from primaite.simulator.system.applications.database_client import DatabaseClient
from primaite.simulator.system.applications.web_browser import WebBrowser
@@ -75,6 +77,18 @@ def test_nic(simulation):
nic_obs = NICObservation(where=["network", "nodes", pc.hostname, "NICs", 1], include_nmne=True)
# Set the NMNE configuration to capture DELETE/ENCRYPT queries as MNEs
nmne_config = {
"capture_nmne": True, # Enable the capture of MNEs
"nmne_capture_keywords": [
"DELETE",
"ENCRYPT",
], # Specify "DELETE/ENCRYPT" SQL command as a keyword for MNE detection
}
# Apply the NMNE configuration settings
NetworkInterface.nmne_config = NMNEConfig(**nmne_config)
assert nic_obs.space["nic_status"] == spaces.Discrete(3)
assert nic_obs.space["NMNE"]["inbound"] == spaces.Discrete(4)
assert nic_obs.space["NMNE"]["outbound"] == spaces.Discrete(4)
@@ -144,7 +158,7 @@ def test_nic_monitored_traffic(simulation):
pc2: Computer = simulation.network.get_node_by_hostname("client_2")
nic_obs = NICObservation(
where=["network", "nodes", pc.hostname, "NICs", 1], include_nmne=True, monitored_traffic=monitored_traffic
where=["network", "nodes", pc.hostname, "NICs", 1], include_nmne=False, monitored_traffic=monitored_traffic
)
simulation.pre_timestep(0) # apply timestep to whole sim

View File

@@ -1,12 +1,14 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
from primaite.game.agent.observations.nic_observations import NICObservation
from primaite.simulator.network.container import Network
from primaite.simulator.network.hardware.nodes.host.host_node import NIC
from primaite.simulator.network.hardware.nodes.host.server import Server
from primaite.simulator.network.nmne import set_nmne_config
from primaite.simulator.network.nmne import NMNEConfig
from primaite.simulator.sim_container import Simulation
from primaite.simulator.system.applications.database_client import DatabaseClient, DatabaseClientConnection
def test_capture_nmne(uc2_network):
def test_capture_nmne(uc2_network: Network):
"""
Conducts a test to verify that Malicious Network Events (MNEs) are correctly captured.
@@ -33,7 +35,7 @@ def test_capture_nmne(uc2_network):
}
# Apply the NMNE configuration settings
set_nmne_config(nmne_config)
NIC.nmne_config = NMNEConfig(**nmne_config)
# Assert that initially, there are no captured MNEs on both web and database servers
assert web_server_nic.nmne == {}
@@ -82,7 +84,7 @@ def test_capture_nmne(uc2_network):
assert db_server_nic.nmne == {"direction": {"inbound": {"keywords": {"*": 3}}}}
def test_describe_state_nmne(uc2_network):
def test_describe_state_nmne(uc2_network: Network):
"""
Conducts a test to verify that Malicious Network Events (MNEs) are correctly represented in the nic state.
@@ -110,7 +112,7 @@ def test_describe_state_nmne(uc2_network):
}
# Apply the NMNE configuration settings
set_nmne_config(nmne_config)
NIC.nmne_config = NMNEConfig(**nmne_config)
# Assert that initially, there are no captured MNEs on both web and database servers
web_server_nic_state = web_server_nic.describe_state()
@@ -190,7 +192,7 @@ def test_describe_state_nmne(uc2_network):
assert db_server_nic_state["nmne"] == {"direction": {"inbound": {"keywords": {"*": 4}}}}
def test_capture_nmne_observations(uc2_network):
def test_capture_nmne_observations(uc2_network: Network):
"""
Tests the NICObservation class's functionality within a simulated network environment.
@@ -219,7 +221,7 @@ def test_capture_nmne_observations(uc2_network):
}
# Apply the NMNE configuration settings
set_nmne_config(nmne_config)
NIC.nmne_config = NMNEConfig(**nmne_config)
# Define observations for the NICs of the database and web servers
db_server_nic_obs = NICObservation(where=["network", "nodes", "database_server", "NICs", 1], include_nmne=True)