#2658: added monitored traffic into config + default obs space
This commit is contained in:
@@ -188,6 +188,11 @@ agents:
|
|||||||
num_nics: 2
|
num_nics: 2
|
||||||
include_num_access: false
|
include_num_access: false
|
||||||
include_nmne: true
|
include_nmne: true
|
||||||
|
monitored_traffic:
|
||||||
|
icmp:
|
||||||
|
- NONE
|
||||||
|
tcp:
|
||||||
|
- DNS
|
||||||
routers:
|
routers:
|
||||||
- hostname: router_1
|
- hostname: router_1
|
||||||
num_ports: 0
|
num_ports: 0
|
||||||
|
|||||||
@@ -43,6 +43,8 @@ class HostObservation(AbstractObservation, identifier="HOST"):
|
|||||||
"""Number of spaces for network interface observations on this host."""
|
"""Number of spaces for network interface observations on this host."""
|
||||||
include_nmne: Optional[bool] = None
|
include_nmne: Optional[bool] = None
|
||||||
"""Whether network interface observations should include number of malicious network events."""
|
"""Whether network interface observations should include number of malicious network events."""
|
||||||
|
monitored_traffic: Optional[Dict] = None
|
||||||
|
"""A dict containing which traffic types are to be included in the observation."""
|
||||||
include_num_access: Optional[bool] = None
|
include_num_access: Optional[bool] = None
|
||||||
"""Whether to include the number of accesses to files observations on this host."""
|
"""Whether to include the number of accesses to files observations on this host."""
|
||||||
|
|
||||||
@@ -59,6 +61,7 @@ class HostObservation(AbstractObservation, identifier="HOST"):
|
|||||||
num_files: int,
|
num_files: int,
|
||||||
num_nics: int,
|
num_nics: int,
|
||||||
include_nmne: bool,
|
include_nmne: bool,
|
||||||
|
monitored_traffic: Optional[Dict],
|
||||||
include_num_access: bool,
|
include_num_access: bool,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
@@ -87,6 +90,8 @@ class HostObservation(AbstractObservation, identifier="HOST"):
|
|||||||
:type num_nics: int
|
:type num_nics: int
|
||||||
:param include_nmne: Flag to include network metrics and errors.
|
:param include_nmne: Flag to include network metrics and errors.
|
||||||
:type include_nmne: bool
|
:type include_nmne: bool
|
||||||
|
:param monitored_traffic: Dict which contains the protocol and ports to observe
|
||||||
|
:type monitored_traffic: Dict
|
||||||
:param include_num_access: Flag to include the number of accesses to files.
|
:param include_num_access: Flag to include the number of accesses to files.
|
||||||
:type include_num_access: bool
|
:type include_num_access: bool
|
||||||
"""
|
"""
|
||||||
@@ -123,7 +128,7 @@ class HostObservation(AbstractObservation, identifier="HOST"):
|
|||||||
|
|
||||||
self.nics: List[NICObservation] = network_interfaces
|
self.nics: List[NICObservation] = network_interfaces
|
||||||
while len(self.nics) < num_nics:
|
while len(self.nics) < num_nics:
|
||||||
self.nics.append(NICObservation(where=None, include_nmne=include_nmne))
|
self.nics.append(NICObservation(where=None, include_nmne=include_nmne, monitored_traffic=monitored_traffic))
|
||||||
while len(self.nics) > num_nics:
|
while len(self.nics) > num_nics:
|
||||||
truncated_nic = self.nics.pop()
|
truncated_nic = self.nics.pop()
|
||||||
msg = f"Too many network_interfaces in Node observation space for node. Truncating {truncated_nic.where}"
|
msg = f"Too many network_interfaces in Node observation space for node. Truncating {truncated_nic.where}"
|
||||||
@@ -231,7 +236,9 @@ class HostObservation(AbstractObservation, identifier="HOST"):
|
|||||||
# monitor the first N interfaces. Network interface numbering starts at 1.
|
# monitor the first N interfaces. Network interface numbering starts at 1.
|
||||||
count = 1
|
count = 1
|
||||||
while len(nics) < config.num_nics:
|
while len(nics) < config.num_nics:
|
||||||
nic_config = NICObservation.ConfigSchema(nic_num=count, include_nmne=config.include_nmne)
|
nic_config = NICObservation.ConfigSchema(
|
||||||
|
nic_num=count, include_nmne=config.include_nmne, monitored_traffic=config.monitored_traffic
|
||||||
|
)
|
||||||
nics.append(NICObservation.from_config(config=nic_config, parent_where=where))
|
nics.append(NICObservation.from_config(config=nic_config, parent_where=where))
|
||||||
count += 1
|
count += 1
|
||||||
|
|
||||||
@@ -247,5 +254,6 @@ class HostObservation(AbstractObservation, identifier="HOST"):
|
|||||||
num_files=config.num_files,
|
num_files=config.num_files,
|
||||||
num_nics=config.num_nics,
|
num_nics=config.num_nics,
|
||||||
include_nmne=config.include_nmne,
|
include_nmne=config.include_nmne,
|
||||||
|
monitored_traffic=config.monitored_traffic,
|
||||||
include_num_access=config.include_num_access,
|
include_num_access=config.include_num_access,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ from gymnasium.core import ObsType
|
|||||||
|
|
||||||
from primaite.game.agent.observations.observations import AbstractObservation, WhereType
|
from primaite.game.agent.observations.observations import AbstractObservation, WhereType
|
||||||
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
|
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
|
||||||
|
from primaite.simulator.network.transmission.transport_layer import Port
|
||||||
|
|
||||||
|
|
||||||
class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"):
|
class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"):
|
||||||
@@ -19,12 +20,10 @@ class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"):
|
|||||||
"""Number of the network interface."""
|
"""Number of the network interface."""
|
||||||
include_nmne: Optional[bool] = None
|
include_nmne: Optional[bool] = None
|
||||||
"""Whether to include number of malicious network events (NMNE) in the observation."""
|
"""Whether to include number of malicious network events (NMNE) in the observation."""
|
||||||
|
monitored_traffic: Optional[Dict] = None
|
||||||
|
"""A dict containing which traffic types are to be included in the observation."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(self, where: WhereType, include_nmne: bool, monitored_traffic: Optional[Dict] = None) -> None:
|
||||||
self,
|
|
||||||
where: WhereType,
|
|
||||||
include_nmne: bool,
|
|
||||||
) -> None:
|
|
||||||
"""
|
"""
|
||||||
Initialise a network interface observation instance.
|
Initialise a network interface observation instance.
|
||||||
|
|
||||||
@@ -49,6 +48,26 @@ class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"):
|
|||||||
self.med_nmne_threshold = 5
|
self.med_nmne_threshold = 5
|
||||||
self.low_nmne_threshold = 0
|
self.low_nmne_threshold = 0
|
||||||
|
|
||||||
|
self.monitored_traffic = monitored_traffic
|
||||||
|
if self.monitored_traffic:
|
||||||
|
self.default_observation.update(
|
||||||
|
self._default_monitored_traffic_observation(monitored_traffic_config=monitored_traffic)
|
||||||
|
)
|
||||||
|
|
||||||
|
def _default_monitored_traffic_observation(self, monitored_traffic_config: Dict) -> Dict:
|
||||||
|
default_traffic_obs = {"monitored_traffic": {}}
|
||||||
|
|
||||||
|
for protocol in monitored_traffic_config:
|
||||||
|
default_traffic_obs["monitored_traffic"][str(protocol).lower()] = {}
|
||||||
|
|
||||||
|
if protocol == "icmp":
|
||||||
|
default_traffic_obs["monitored_traffic"]["icmp"] = {"inbound": 0, "outbound": 0}
|
||||||
|
else:
|
||||||
|
for port in monitored_traffic_config[protocol]:
|
||||||
|
default_traffic_obs["monitored_traffic"][protocol][Port[port].value] = {"inbound": 0, "outbound": 0}
|
||||||
|
|
||||||
|
return default_traffic_obs
|
||||||
|
|
||||||
def _categorise_mne_count(self, nmne_count: int) -> int:
|
def _categorise_mne_count(self, nmne_count: int) -> int:
|
||||||
"""
|
"""
|
||||||
Categorise the number of Malicious Network Events (NMNEs) into discrete bins.
|
Categorise the number of Malicious Network Events (NMNEs) into discrete bins.
|
||||||
@@ -72,6 +91,16 @@ class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"):
|
|||||||
return 1
|
return 1
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
def _categorise_traffic(self, traffic_value: float, nic_state: Dict) -> int:
|
||||||
|
"""Categorise the traffic into discrete categories."""
|
||||||
|
if traffic_value == 0:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
nic_max_bandwidth = nic_state.get("speed")
|
||||||
|
|
||||||
|
bandwidth_utilisation = traffic_value / nic_max_bandwidth
|
||||||
|
return int(bandwidth_utilisation * 9) + 1
|
||||||
|
|
||||||
def observe(self, state: Dict) -> ObsType:
|
def observe(self, state: Dict) -> ObsType:
|
||||||
"""
|
"""
|
||||||
Generate observation based on the current state of the simulation.
|
Generate observation based on the current state of the simulation.
|
||||||
@@ -87,6 +116,50 @@ class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"):
|
|||||||
return self.default_observation
|
return self.default_observation
|
||||||
|
|
||||||
obs = {"nic_status": 1 if nic_state["enabled"] else 2}
|
obs = {"nic_status": 1 if nic_state["enabled"] else 2}
|
||||||
|
|
||||||
|
# if the observation was configured to monitor traffic from ports/protocols
|
||||||
|
if self.monitored_traffic:
|
||||||
|
obs["monitored_traffic"] = {}
|
||||||
|
|
||||||
|
# iterate through the protocols
|
||||||
|
for protocol in self.monitored_traffic:
|
||||||
|
obs["monitored_traffic"][str(protocol).lower()] = {}
|
||||||
|
# check if the nic has seen traffic with this protocol
|
||||||
|
if nic_state["traffic"].get(protocol):
|
||||||
|
# deal with icmp
|
||||||
|
if str(protocol).lower() == "icmp":
|
||||||
|
obs["monitored_traffic"][protocol] = {
|
||||||
|
"inbound": self._categorise_traffic(
|
||||||
|
traffic_value=nic_state["traffic"]["icmp"]["inbound"], nic_state=nic_state
|
||||||
|
),
|
||||||
|
"outbound": self._categorise_traffic(
|
||||||
|
traffic_value=nic_state["traffic"]["icmp"]["outbound"], nic_state=nic_state
|
||||||
|
),
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
for port in self.monitored_traffic[protocol]:
|
||||||
|
port_enum = Port[port]
|
||||||
|
obs["monitored_traffic"][str(protocol).lower()][port_enum.value] = {}
|
||||||
|
traffic = {"inbound": 0, "outbound": 0}
|
||||||
|
|
||||||
|
if nic_state["traffic"][protocol].get(port_enum.value) is not None:
|
||||||
|
traffic = nic_state["traffic"][protocol][port_enum.value]
|
||||||
|
|
||||||
|
obs["monitored_traffic"][protocol][port_enum.value]["inbound"] = self._categorise_traffic(
|
||||||
|
traffic_value=traffic["inbound"], nic_state=nic_state
|
||||||
|
)
|
||||||
|
obs["monitored_traffic"][protocol][port_enum.value]["outbound"] = self._categorise_traffic(
|
||||||
|
traffic_value=traffic["outbound"], nic_state=nic_state
|
||||||
|
)
|
||||||
|
|
||||||
|
# set all the ports under the protocol to 0
|
||||||
|
else:
|
||||||
|
if str(protocol).lower() == "icmp":
|
||||||
|
obs["monitored_traffic"]["icmp"] = {"inbound": 0, "outbound": 0}
|
||||||
|
else:
|
||||||
|
for port in self.monitored_traffic[protocol]:
|
||||||
|
obs["monitored_traffic"][protocol][Port[port].value] = {"inbound": 0, "outbound": 0}
|
||||||
|
|
||||||
if self.include_nmne:
|
if self.include_nmne:
|
||||||
obs.update({"NMNE": {}})
|
obs.update({"NMNE": {}})
|
||||||
direction_dict = nic_state["nmne"].get("direction", {})
|
direction_dict = nic_state["nmne"].get("direction", {})
|
||||||
@@ -128,7 +201,11 @@ class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"):
|
|||||||
:return: Constructed network interface observation instance.
|
:return: Constructed network interface observation instance.
|
||||||
:rtype: NICObservation
|
:rtype: NICObservation
|
||||||
"""
|
"""
|
||||||
return cls(where=parent_where + ["NICs", config.nic_num], include_nmne=config.include_nmne)
|
return cls(
|
||||||
|
where=parent_where + ["NICs", config.nic_num],
|
||||||
|
include_nmne=config.include_nmne,
|
||||||
|
monitored_traffic=config.monitored_traffic,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class PortObservation(AbstractObservation, identifier="PORT"):
|
class PortObservation(AbstractObservation, identifier="PORT"):
|
||||||
|
|||||||
@@ -39,6 +39,8 @@ class NodesObservation(AbstractObservation, identifier="NODES"):
|
|||||||
"""Number of network interface cards (NICs)."""
|
"""Number of network interface cards (NICs)."""
|
||||||
include_nmne: Optional[bool] = None
|
include_nmne: Optional[bool] = None
|
||||||
"""Flag to include nmne."""
|
"""Flag to include nmne."""
|
||||||
|
monitored_traffic: Optional[Dict] = None
|
||||||
|
"""A dict containing which traffic types are to be included in the observation."""
|
||||||
include_num_access: Optional[bool] = None
|
include_num_access: Optional[bool] = None
|
||||||
"""Flag to include the number of accesses."""
|
"""Flag to include the number of accesses."""
|
||||||
num_ports: Optional[int] = None
|
num_ports: Optional[int] = None
|
||||||
@@ -180,6 +182,8 @@ class NodesObservation(AbstractObservation, identifier="NODES"):
|
|||||||
host_config.num_nics = config.num_nics
|
host_config.num_nics = config.num_nics
|
||||||
if host_config.include_nmne is None:
|
if host_config.include_nmne is None:
|
||||||
host_config.include_nmne = config.include_nmne
|
host_config.include_nmne = config.include_nmne
|
||||||
|
if host_config.monitored_traffic is None:
|
||||||
|
host_config.monitored_traffic = config.monitored_traffic
|
||||||
if host_config.include_num_access is None:
|
if host_config.include_num_access is None:
|
||||||
host_config.include_num_access = config.include_num_access
|
host_config.include_num_access = config.include_num_access
|
||||||
|
|
||||||
|
|||||||
@@ -277,12 +277,12 @@ class NetworkInterface(SimComponent, ABC):
|
|||||||
if protocol != IPProtocol.ICMP:
|
if protocol != IPProtocol.ICMP:
|
||||||
if port not in self.traffic[protocol]:
|
if port not in self.traffic[protocol]:
|
||||||
self.traffic[protocol][port] = {"inbound": 0, "outbound": 0}
|
self.traffic[protocol][port] = {"inbound": 0, "outbound": 0}
|
||||||
self.traffic[protocol][port][direction] += frame.size
|
self.traffic[protocol][port][direction] += frame.size_Mbits
|
||||||
else:
|
else:
|
||||||
# Handle ICMP protocol separately (ICMP does not use ports)
|
# Handle ICMP protocol separately (ICMP does not use ports)
|
||||||
if not self.traffic[protocol]:
|
if not self.traffic[protocol]:
|
||||||
self.traffic[protocol] = {"inbound": 0, "outbound": 0}
|
self.traffic[protocol] = {"inbound": 0, "outbound": 0}
|
||||||
self.traffic[protocol][direction] += frame.size
|
self.traffic[protocol][direction] += frame.size_Mbits
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def send_frame(self, frame: Frame) -> bool:
|
def send_frame(self, frame: Frame) -> bool:
|
||||||
@@ -325,6 +325,11 @@ class NetworkInterface(SimComponent, ABC):
|
|||||||
"""
|
"""
|
||||||
super().apply_timestep(timestep=timestep)
|
super().apply_timestep(timestep=timestep)
|
||||||
|
|
||||||
|
def pre_timestep(self, timestep: int) -> None:
|
||||||
|
"""Apply pre-timestep logic."""
|
||||||
|
super().pre_timestep(timestep)
|
||||||
|
self.traffic = {}
|
||||||
|
|
||||||
|
|
||||||
class WiredNetworkInterface(NetworkInterface, ABC):
|
class WiredNetworkInterface(NetworkInterface, ABC):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -84,6 +84,11 @@ agents:
|
|||||||
num_files: 1
|
num_files: 1
|
||||||
num_nics: 2
|
num_nics: 2
|
||||||
include_num_access: false
|
include_num_access: false
|
||||||
|
monitored_traffic:
|
||||||
|
icmp:
|
||||||
|
- NONE
|
||||||
|
tcp:
|
||||||
|
- DNS
|
||||||
include_nmne: true
|
include_nmne: true
|
||||||
routers:
|
routers:
|
||||||
- hostname: router_1
|
- hostname: router_1
|
||||||
|
|||||||
@@ -5,12 +5,18 @@ import pytest
|
|||||||
import yaml
|
import yaml
|
||||||
from gymnasium import spaces
|
from gymnasium import spaces
|
||||||
|
|
||||||
|
from primaite.game.agent.interface import ProxyAgent
|
||||||
from primaite.game.agent.observations.nic_observations import NICObservation
|
from primaite.game.agent.observations.nic_observations import NICObservation
|
||||||
from primaite.game.game import PrimaiteGame
|
from primaite.game.game import PrimaiteGame
|
||||||
from primaite.simulator.network.hardware.nodes.host.computer import Computer
|
from primaite.simulator.network.hardware.nodes.host.computer import Computer
|
||||||
from primaite.simulator.network.hardware.nodes.host.host_node import NIC
|
from primaite.simulator.network.hardware.nodes.host.host_node import NIC
|
||||||
from primaite.simulator.network.nmne import CAPTURE_NMNE
|
from primaite.simulator.network.hardware.nodes.host.server import Server
|
||||||
from primaite.simulator.sim_container import Simulation
|
from primaite.simulator.sim_container import Simulation
|
||||||
|
from primaite.simulator.system.applications.database_client import DatabaseClient
|
||||||
|
from primaite.simulator.system.applications.web_browser import WebBrowser
|
||||||
|
from primaite.simulator.system.services.dns.dns_client import DNSClient
|
||||||
|
from primaite.simulator.system.services.dns.dns_server import DNSServer
|
||||||
|
from primaite.simulator.system.services.web_server.web_server import WebServer
|
||||||
from tests import TEST_ASSETS_ROOT
|
from tests import TEST_ASSETS_ROOT
|
||||||
|
|
||||||
BASIC_CONFIG = TEST_ASSETS_ROOT / "configs/basic_switched_network.yaml"
|
BASIC_CONFIG = TEST_ASSETS_ROOT / "configs/basic_switched_network.yaml"
|
||||||
@@ -31,6 +37,32 @@ def simulation(example_network) -> Simulation:
|
|||||||
# set simulation network as example network
|
# set simulation network as example network
|
||||||
sim.network = example_network
|
sim.network = example_network
|
||||||
|
|
||||||
|
computer: Computer = example_network.get_node_by_hostname("client_1")
|
||||||
|
server: Server = example_network.get_node_by_hostname("server_1")
|
||||||
|
|
||||||
|
web_browser: WebBrowser = computer.software_manager.software.get("WebBrowser")
|
||||||
|
web_browser.run()
|
||||||
|
|
||||||
|
# Install DNS Client service on computer
|
||||||
|
computer.software_manager.install(DNSClient)
|
||||||
|
dns_client: DNSClient = computer.software_manager.software.get("DNSClient")
|
||||||
|
# set dns server
|
||||||
|
dns_client.dns_server = server.network_interface[1].ip_address
|
||||||
|
|
||||||
|
# Install Web Server service on server
|
||||||
|
server.software_manager.install(WebServer)
|
||||||
|
web_server_service: WebServer = server.software_manager.software.get("WebServer")
|
||||||
|
web_server_service.start()
|
||||||
|
|
||||||
|
# Install DNS Server service on server
|
||||||
|
server.software_manager.install(DNSServer)
|
||||||
|
dns_server: DNSServer = server.software_manager.software.get("DNSServer")
|
||||||
|
# register arcd.com to DNS
|
||||||
|
dns_server.dns_register(
|
||||||
|
domain_name="arcd.com",
|
||||||
|
domain_ip_address=server.network_interfaces[next(iter(server.network_interfaces))].ip_address,
|
||||||
|
)
|
||||||
|
|
||||||
return sim
|
return sim
|
||||||
|
|
||||||
|
|
||||||
@@ -102,3 +134,71 @@ def test_config_nic_categories(simulation):
|
|||||||
high_nmne_threshold=9,
|
high_nmne_threshold=9,
|
||||||
include_nmne=True,
|
include_nmne=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_nic_monitored_traffic(simulation):
|
||||||
|
monitored_traffic = {"icmp": ["NONE"], "tcp": ["DNS"]}
|
||||||
|
|
||||||
|
pc: Computer = simulation.network.get_node_by_hostname("client_1")
|
||||||
|
pc2: Computer = simulation.network.get_node_by_hostname("client_2")
|
||||||
|
|
||||||
|
nic_obs = NICObservation(
|
||||||
|
where=["network", "nodes", pc.hostname, "NICs", 1], include_nmne=True, monitored_traffic=monitored_traffic
|
||||||
|
)
|
||||||
|
|
||||||
|
simulation.pre_timestep(0) # apply timestep to whole sim
|
||||||
|
simulation.apply_timestep(0) # apply timestep to whole sim
|
||||||
|
traffic_obs = nic_obs.observe(simulation.describe_state()).get("monitored_traffic")
|
||||||
|
|
||||||
|
assert traffic_obs["icmp"]["inbound"] == 0
|
||||||
|
assert traffic_obs["icmp"]["outbound"] == 0
|
||||||
|
|
||||||
|
# send a ping
|
||||||
|
pc.ping(target_ip_address=pc2.network_interface[1].ip_address)
|
||||||
|
traffic_obs = nic_obs.observe(simulation.describe_state()).get("monitored_traffic")
|
||||||
|
|
||||||
|
assert traffic_obs["icmp"]["inbound"] == 1
|
||||||
|
assert traffic_obs["icmp"]["outbound"] == 1
|
||||||
|
|
||||||
|
simulation.pre_timestep(1) # apply timestep to whole sim
|
||||||
|
simulation.apply_timestep(1) # apply timestep to whole sim
|
||||||
|
traffic_obs = nic_obs.observe(simulation.describe_state()).get("monitored_traffic")
|
||||||
|
|
||||||
|
assert traffic_obs["icmp"]["inbound"] == 0
|
||||||
|
assert traffic_obs["icmp"]["outbound"] == 0
|
||||||
|
assert traffic_obs["tcp"][53]["inbound"] == 0
|
||||||
|
assert traffic_obs["tcp"][53]["outbound"] == 0
|
||||||
|
|
||||||
|
# send a database query
|
||||||
|
browser: WebBrowser = pc.software_manager.software.get("WebBrowser")
|
||||||
|
browser.target_url = f"http://arcd.com/"
|
||||||
|
browser.get_webpage()
|
||||||
|
|
||||||
|
traffic_obs = nic_obs.observe(simulation.describe_state()).get("monitored_traffic")
|
||||||
|
assert traffic_obs["icmp"]["inbound"] == 0
|
||||||
|
assert traffic_obs["icmp"]["outbound"] == 0
|
||||||
|
assert traffic_obs["tcp"][53]["inbound"] == 0
|
||||||
|
assert traffic_obs["tcp"][53]["outbound"] == 1 # getting a webpage sent a dns request out
|
||||||
|
|
||||||
|
simulation.pre_timestep(2) # apply timestep to whole sim
|
||||||
|
simulation.apply_timestep(2) # apply timestep to whole sim
|
||||||
|
traffic_obs = nic_obs.observe(simulation.describe_state()).get("monitored_traffic")
|
||||||
|
|
||||||
|
assert traffic_obs["icmp"]["inbound"] == 0
|
||||||
|
assert traffic_obs["icmp"]["outbound"] == 0
|
||||||
|
assert traffic_obs["tcp"][53]["inbound"] == 0
|
||||||
|
assert traffic_obs["tcp"][53]["outbound"] == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_nic_monitored_traffic_config():
|
||||||
|
"""Test that the config loads the monitored traffic config correctly."""
|
||||||
|
game: PrimaiteGame = load_config(BASIC_CONFIG)
|
||||||
|
|
||||||
|
# should have icmp and DNS
|
||||||
|
defender_agent: ProxyAgent = game.agents.get("defender")
|
||||||
|
cur_obs = defender_agent.observation_manager.current_observation
|
||||||
|
|
||||||
|
assert cur_obs["NODES"]["HOST0"]["NICS"][1]["monitored_traffic"] == {
|
||||||
|
"icmp": {"inbound": 0, "outbound": 0},
|
||||||
|
"tcp": {53: {"inbound": 0, "outbound": 0}},
|
||||||
|
}
|
||||||
|
|||||||
@@ -32,6 +32,7 @@ def test_host_observation(simulation):
|
|||||||
num_services=1,
|
num_services=1,
|
||||||
include_num_access=False,
|
include_num_access=False,
|
||||||
include_nmne=False,
|
include_nmne=False,
|
||||||
|
monitored_traffic=None,
|
||||||
services=[],
|
services=[],
|
||||||
applications=[],
|
applications=[],
|
||||||
folders=[],
|
folders=[],
|
||||||
|
|||||||
Reference in New Issue
Block a user