#2350: add tests to check spaces + acl obs test + nmne space changes
This commit is contained in:
@@ -22,8 +22,6 @@ io_settings:
|
||||
game:
|
||||
max_episode_length: 256
|
||||
ports:
|
||||
- ARP
|
||||
- DNS
|
||||
- HTTP
|
||||
- POSTGRES_SERVER
|
||||
protocols:
|
||||
|
||||
@@ -20,6 +20,8 @@ class NicObservation(AbstractObservation):
|
||||
high_nmne_threshold: int = 10
|
||||
"""The minimum number of malicious network events to be considered high."""
|
||||
|
||||
global CAPTURE_NMNE
|
||||
|
||||
@property
|
||||
def default_observation(self) -> Dict:
|
||||
"""The default NIC observation dict."""
|
||||
@@ -47,6 +49,15 @@ class NicObservation(AbstractObservation):
|
||||
super().__init__()
|
||||
self.where: Optional[Tuple[str]] = where
|
||||
|
||||
global CAPTURE_NMNE
|
||||
if CAPTURE_NMNE:
|
||||
self.nmne_inbound_last_step: int = 0
|
||||
"""NMNEs persist for the whole episode, but we want to count per step. Keeping track of last step count lets
|
||||
us find the difference."""
|
||||
self.nmne_outbound_last_step: int = 0
|
||||
"""NMNEs persist for the whole episode, but we want to count per step. Keeping track of last step count lets
|
||||
us find the difference."""
|
||||
|
||||
if low_nmne_threshold or med_nmne_threshold or high_nmne_threshold:
|
||||
self._validate_nmne_categories(
|
||||
low_nmne_threshold=low_nmne_threshold,
|
||||
@@ -128,19 +139,21 @@ class NicObservation(AbstractObservation):
|
||||
inbound_count = inbound_keywords.get("*", 0)
|
||||
outbound_keywords = direction_dict.get("outbound", {}).get("keywords", {})
|
||||
outbound_count = outbound_keywords.get("*", 0)
|
||||
obs_dict["nmne"]["inbound"] = self._categorise_mne_count(inbound_count)
|
||||
obs_dict["nmne"]["outbound"] = self._categorise_mne_count(outbound_count)
|
||||
obs_dict["nmne"]["inbound"] = self._categorise_mne_count(inbound_count - self.nmne_inbound_last_step)
|
||||
obs_dict["nmne"]["outbound"] = self._categorise_mne_count(outbound_count - self.nmne_outbound_last_step)
|
||||
self.nmne_inbound_last_step = inbound_count
|
||||
self.nmne_outbound_last_step = outbound_count
|
||||
return obs_dict
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
"""Gymnasium space object describing the observation space shape."""
|
||||
return spaces.Dict(
|
||||
{
|
||||
"nic_status": spaces.Discrete(3),
|
||||
"nmne": spaces.Dict({"inbound": spaces.Discrete(6), "outbound": spaces.Discrete(6)}),
|
||||
}
|
||||
)
|
||||
space = spaces.Dict({"nic_status": spaces.Discrete(3)})
|
||||
|
||||
if CAPTURE_NMNE:
|
||||
space["nmne"] = spaces.Dict({"inbound": spaces.Discrete(4), "outbound": spaces.Discrete(4)})
|
||||
|
||||
return space
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict, game: "PrimaiteGame", parent_where: Optional[List[str]]) -> "NicObservation":
|
||||
|
||||
@@ -51,7 +51,7 @@ class ServiceObservation(AbstractObservation):
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
"""Gymnasium space object describing the observation space shape."""
|
||||
return spaces.Dict({"operating_status": spaces.Discrete(7), "health_status": spaces.Discrete(6)})
|
||||
return spaces.Dict({"operating_status": spaces.Discrete(7), "health_status": spaces.Discrete(5)})
|
||||
|
||||
@classmethod
|
||||
def from_config(
|
||||
|
||||
@@ -0,0 +1,66 @@
|
||||
import pytest
|
||||
|
||||
from primaite.game.agent.observations.observations import AclObservation
|
||||
from primaite.simulator.network.hardware.nodes.host.computer import Computer
|
||||
from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.sim_container import Simulation
|
||||
from primaite.simulator.system.services.ntp.ntp_client import NTPClient
|
||||
from primaite.simulator.system.services.ntp.ntp_server import NTPServer
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def simulation(example_network) -> Simulation:
|
||||
sim = Simulation()
|
||||
|
||||
# set simulation network as example network
|
||||
sim.network = example_network
|
||||
|
||||
return sim
|
||||
|
||||
|
||||
def test_acl_observations(simulation):
|
||||
"""Test the ACL rule observations."""
|
||||
router: Router = simulation.network.get_node_by_hostname("router_1")
|
||||
client_1: Computer = simulation.network.get_node_by_hostname("client_1")
|
||||
server: Computer = simulation.network.get_node_by_hostname("server_1")
|
||||
|
||||
# quick set up of ntp
|
||||
client_1.software_manager.install(NTPClient)
|
||||
ntp_client: NTPClient = client_1.software_manager.software.get("NTPClient")
|
||||
ntp_client.configure(server.network_interface.get(1).ip_address)
|
||||
server.software_manager.install(NTPServer)
|
||||
|
||||
# add router acl rule
|
||||
router.acl.add_rule(action=ACLAction.PERMIT, dst_port=Port.NTP, src_port=Port.NTP, position=1)
|
||||
|
||||
acl_obs = AclObservation(
|
||||
where=["network", "nodes", router.hostname, "acl", "acl"],
|
||||
node_ip_to_id={},
|
||||
ports=["NTP", "HTTP", "POSTGRES_SERVER"],
|
||||
protocols=["TCP", "UDP", "ICMP"],
|
||||
)
|
||||
|
||||
observation_space = acl_obs.observe(simulation.describe_state())
|
||||
assert observation_space.get(1) is not None
|
||||
rule_obs = observation_space.get(1) # this is the ACL Rule added to allow NTP
|
||||
assert rule_obs.get("position") == 0 # rule was put at position 1 (0 because counting from 1 instead of 1)
|
||||
assert rule_obs.get("permission") == 1 # permit = 1 deny = 2
|
||||
assert rule_obs.get("source_node_id") == 1 # applies to all source nodes
|
||||
assert rule_obs.get("dest_node_id") == 1 # applies to all destination nodes
|
||||
assert rule_obs.get("source_port") == 2 # NTP port is mapped to value 2 (1 = ALL, so 1+1 = 2 quik mafs)
|
||||
assert rule_obs.get("dest_port") == 2 # NTP port is mapped to value 2
|
||||
assert rule_obs.get("protocol") == 1 # 1 = No Protocol
|
||||
|
||||
router.acl.remove_rule(1)
|
||||
|
||||
observation_space = acl_obs.observe(simulation.describe_state())
|
||||
assert observation_space.get(1) is not None
|
||||
rule_obs = observation_space.get(1) # this is the ACL Rule added to allow NTP
|
||||
assert rule_obs.get("position") == 0
|
||||
assert rule_obs.get("permission") == 0
|
||||
assert rule_obs.get("source_node_id") == 0
|
||||
assert rule_obs.get("dest_node_id") == 0
|
||||
assert rule_obs.get("source_port") == 0
|
||||
assert rule_obs.get("dest_port") == 0
|
||||
assert rule_obs.get("protocol") == 0
|
||||
@@ -26,7 +26,7 @@ def test_file_observation(simulation):
|
||||
where=["network", "nodes", pc.hostname, "file_system", "folders", "root", "files", "dog.png"]
|
||||
)
|
||||
|
||||
assert dog_file_obs.space == spaces.Dict({"health_status": spaces.Discrete(6)})
|
||||
assert dog_file_obs.space["health_status"] == spaces.Discrete(6)
|
||||
|
||||
observation_state = dog_file_obs.observe(simulation.describe_state())
|
||||
assert observation_state.get("health_status") == 1 # good initial
|
||||
@@ -52,6 +52,8 @@ def test_folder_observation(simulation):
|
||||
where=["network", "nodes", pc.hostname, "file_system", "folders", "test_folder"]
|
||||
)
|
||||
|
||||
assert root_folder_obs.space["health_status"] == spaces.Discrete(6)
|
||||
|
||||
observation_state = root_folder_obs.observe(simulation.describe_state())
|
||||
assert observation_state.get("FILES") is not None
|
||||
assert observation_state.get("health_status") == 1
|
||||
|
||||
@@ -0,0 +1,73 @@
|
||||
import pytest
|
||||
from gymnasium import spaces
|
||||
|
||||
from primaite.game.agent.observations.observations import LinkObservation
|
||||
from primaite.simulator.network.container import Network
|
||||
from primaite.simulator.network.hardware.base import Link, Node
|
||||
from primaite.simulator.network.hardware.nodes.host.computer import Computer
|
||||
from primaite.simulator.network.hardware.nodes.host.server import Server
|
||||
from primaite.simulator.sim_container import Simulation
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def simulation() -> Simulation:
|
||||
sim = Simulation()
|
||||
|
||||
network = Network()
|
||||
|
||||
# Create Computer
|
||||
computer = Computer(
|
||||
hostname="computer",
|
||||
ip_address="192.168.1.2",
|
||||
subnet_mask="255.255.255.0",
|
||||
default_gateway="192.168.1.1",
|
||||
start_up_duration=0,
|
||||
)
|
||||
computer.power_on()
|
||||
|
||||
# Create Server
|
||||
server = Server(
|
||||
hostname="server",
|
||||
ip_address="192.168.1.3",
|
||||
subnet_mask="255.255.255.0",
|
||||
default_gateway="192.168.1.1",
|
||||
start_up_duration=0,
|
||||
)
|
||||
server.power_on()
|
||||
|
||||
# Connect Computer and Server
|
||||
network.connect(computer.network_interface[1], server.network_interface[1])
|
||||
|
||||
# Should be linked
|
||||
assert next(iter(network.links.values())).is_up
|
||||
|
||||
assert computer.ping(server.network_interface.get(1).ip_address)
|
||||
|
||||
# set simulation network as example network
|
||||
sim.network = network
|
||||
|
||||
return sim
|
||||
|
||||
|
||||
def test_link_observation(simulation):
|
||||
"""Test the link observation."""
|
||||
# get a link
|
||||
link: Link = next(iter(simulation.network.links.values()))
|
||||
|
||||
computer: Computer = simulation.network.get_node_by_hostname("computer")
|
||||
server: Server = simulation.network.get_node_by_hostname("server")
|
||||
|
||||
simulation.apply_timestep(0) # some pings when network was made - reset with apply timestep
|
||||
|
||||
link_obs = LinkObservation(where=["network", "links", link.uuid])
|
||||
|
||||
assert link_obs.space["PROTOCOLS"]["ALL"] == spaces.Discrete(11) # test that the spaces are 0-10 including 0 and 10
|
||||
|
||||
observation_state = link_obs.observe(simulation.describe_state())
|
||||
assert observation_state.get("PROTOCOLS") is not None
|
||||
assert observation_state["PROTOCOLS"]["ALL"] == 0
|
||||
|
||||
computer.ping(server.network_interface.get(1).ip_address)
|
||||
|
||||
observation_state = link_obs.observe(simulation.describe_state())
|
||||
assert observation_state["PROTOCOLS"]["ALL"] == 1
|
||||
@@ -1,9 +1,27 @@
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
from gymnasium import spaces
|
||||
|
||||
from primaite.game.agent.observations.nic_observations import NicObservation
|
||||
from primaite.game.game import PrimaiteGame
|
||||
from primaite.simulator.network.hardware.nodes.host.computer import Computer
|
||||
from primaite.simulator.network.hardware.nodes.host.host_node import NIC
|
||||
from primaite.simulator.network.nmne import CAPTURE_NMNE
|
||||
from primaite.simulator.sim_container import Simulation
|
||||
from tests import TEST_ASSETS_ROOT
|
||||
|
||||
BASIC_CONFIG = TEST_ASSETS_ROOT / "configs/basic_switched_network.yaml"
|
||||
|
||||
|
||||
def load_config(config_path: Union[str, Path]) -> PrimaiteGame:
|
||||
"""Returns a PrimaiteGame object which loads the contents of a given yaml path."""
|
||||
with open(config_path, "r") as f:
|
||||
cfg = yaml.safe_load(f)
|
||||
|
||||
return PrimaiteGame.from_config(cfg)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
@@ -24,6 +42,10 @@ def test_nic(simulation):
|
||||
|
||||
nic_obs = NicObservation(where=["network", "nodes", pc.hostname, "NICs", 1])
|
||||
|
||||
assert nic_obs.space["nic_status"] == spaces.Discrete(3)
|
||||
assert nic_obs.space["nmne"]["inbound"] == spaces.Discrete(4)
|
||||
assert nic_obs.space["nmne"]["outbound"] == spaces.Discrete(4)
|
||||
|
||||
observation_state = nic_obs.observe(simulation.describe_state())
|
||||
assert observation_state.get("nic_status") == 1 # enabled
|
||||
assert observation_state.get("nmne") is not None
|
||||
@@ -2,6 +2,7 @@ import copy
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from gymnasium import spaces
|
||||
|
||||
from primaite.game.agent.observations.node_observations import NodeObservation
|
||||
from primaite.simulator.network.hardware.nodes.host.computer import Computer
|
||||
@@ -24,6 +25,8 @@ def test_node_observation(simulation):
|
||||
|
||||
node_obs = NodeObservation(where=["network", "nodes", pc.hostname])
|
||||
|
||||
assert node_obs.space["operating_status"] == spaces.Discrete(5)
|
||||
|
||||
observation_state = node_obs.observe(simulation.describe_state())
|
||||
assert observation_state.get("operating_status") == 1 # computer is on
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import pytest
|
||||
from gymnasium import spaces
|
||||
|
||||
from primaite.game.agent.observations.software_observation import ApplicationObservation, ServiceObservation
|
||||
from primaite.simulator.network.hardware.nodes.host.computer import Computer
|
||||
@@ -29,6 +30,9 @@ def test_service_observation(simulation):
|
||||
|
||||
service_obs = ServiceObservation(where=["network", "nodes", pc.hostname, "services", "NTPServer"])
|
||||
|
||||
assert service_obs.space["operating_status"] == spaces.Discrete(7)
|
||||
assert service_obs.space["health_status"] == spaces.Discrete(5)
|
||||
|
||||
observation_state = service_obs.observe(simulation.describe_state())
|
||||
|
||||
assert observation_state.get("health_status") == 0
|
||||
|
||||
Reference in New Issue
Block a user