#2350: add tests to check spaces + acl obs test + nmne space changes

This commit is contained in:
Czar Echavez
2024-03-11 17:47:33 +00:00
parent a228a09917
commit cd6d6325db
9 changed files with 193 additions and 12 deletions

View File

@@ -22,8 +22,6 @@ io_settings:
game:
max_episode_length: 256
ports:
- ARP
- DNS
- HTTP
- POSTGRES_SERVER
protocols:

View File

@@ -20,6 +20,8 @@ class NicObservation(AbstractObservation):
high_nmne_threshold: int = 10
"""The minimum number of malicious network events to be considered high."""
global CAPTURE_NMNE
@property
def default_observation(self) -> Dict:
"""The default NIC observation dict."""
@@ -47,6 +49,15 @@ class NicObservation(AbstractObservation):
super().__init__()
self.where: Optional[Tuple[str]] = where
global CAPTURE_NMNE
if CAPTURE_NMNE:
self.nmne_inbound_last_step: int = 0
"""NMNEs persist for the whole episode, but we want to count per step. Keeping track of last step count lets
us find the difference."""
self.nmne_outbound_last_step: int = 0
"""NMNEs persist for the whole episode, but we want to count per step. Keeping track of last step count lets
us find the difference."""
if low_nmne_threshold or med_nmne_threshold or high_nmne_threshold:
self._validate_nmne_categories(
low_nmne_threshold=low_nmne_threshold,
@@ -128,19 +139,21 @@ class NicObservation(AbstractObservation):
inbound_count = inbound_keywords.get("*", 0)
outbound_keywords = direction_dict.get("outbound", {}).get("keywords", {})
outbound_count = outbound_keywords.get("*", 0)
obs_dict["nmne"]["inbound"] = self._categorise_mne_count(inbound_count)
obs_dict["nmne"]["outbound"] = self._categorise_mne_count(outbound_count)
obs_dict["nmne"]["inbound"] = self._categorise_mne_count(inbound_count - self.nmne_inbound_last_step)
obs_dict["nmne"]["outbound"] = self._categorise_mne_count(outbound_count - self.nmne_outbound_last_step)
self.nmne_inbound_last_step = inbound_count
self.nmne_outbound_last_step = outbound_count
return obs_dict
@property
def space(self) -> spaces.Space:
"""Gymnasium space object describing the observation space shape."""
return spaces.Dict(
{
"nic_status": spaces.Discrete(3),
"nmne": spaces.Dict({"inbound": spaces.Discrete(6), "outbound": spaces.Discrete(6)}),
}
)
space = spaces.Dict({"nic_status": spaces.Discrete(3)})
if CAPTURE_NMNE:
space["nmne"] = spaces.Dict({"inbound": spaces.Discrete(4), "outbound": spaces.Discrete(4)})
return space
@classmethod
def from_config(cls, config: Dict, game: "PrimaiteGame", parent_where: Optional[List[str]]) -> "NicObservation":

View File

@@ -51,7 +51,7 @@ class ServiceObservation(AbstractObservation):
@property
def space(self) -> spaces.Space:
"""Gymnasium space object describing the observation space shape."""
return spaces.Dict({"operating_status": spaces.Discrete(7), "health_status": spaces.Discrete(6)})
return spaces.Dict({"operating_status": spaces.Discrete(7), "health_status": spaces.Discrete(5)})
@classmethod
def from_config(

View File

@@ -0,0 +1,66 @@
import pytest
from primaite.game.agent.observations.observations import AclObservation
from primaite.simulator.network.hardware.nodes.host.computer import Computer
from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.sim_container import Simulation
from primaite.simulator.system.services.ntp.ntp_client import NTPClient
from primaite.simulator.system.services.ntp.ntp_server import NTPServer
@pytest.fixture(scope="function")
def simulation(example_network) -> Simulation:
sim = Simulation()
# set simulation network as example network
sim.network = example_network
return sim
def test_acl_observations(simulation):
"""Test the ACL rule observations."""
router: Router = simulation.network.get_node_by_hostname("router_1")
client_1: Computer = simulation.network.get_node_by_hostname("client_1")
server: Computer = simulation.network.get_node_by_hostname("server_1")
# quick set up of ntp
client_1.software_manager.install(NTPClient)
ntp_client: NTPClient = client_1.software_manager.software.get("NTPClient")
ntp_client.configure(server.network_interface.get(1).ip_address)
server.software_manager.install(NTPServer)
# add router acl rule
router.acl.add_rule(action=ACLAction.PERMIT, dst_port=Port.NTP, src_port=Port.NTP, position=1)
acl_obs = AclObservation(
where=["network", "nodes", router.hostname, "acl", "acl"],
node_ip_to_id={},
ports=["NTP", "HTTP", "POSTGRES_SERVER"],
protocols=["TCP", "UDP", "ICMP"],
)
observation_space = acl_obs.observe(simulation.describe_state())
assert observation_space.get(1) is not None
rule_obs = observation_space.get(1) # this is the ACL Rule added to allow NTP
assert rule_obs.get("position") == 0 # rule was put at position 1 (0 because counting from 1 instead of 1)
assert rule_obs.get("permission") == 1 # permit = 1 deny = 2
assert rule_obs.get("source_node_id") == 1 # applies to all source nodes
assert rule_obs.get("dest_node_id") == 1 # applies to all destination nodes
assert rule_obs.get("source_port") == 2 # NTP port is mapped to value 2 (1 = ALL, so 1+1 = 2 quik mafs)
assert rule_obs.get("dest_port") == 2 # NTP port is mapped to value 2
assert rule_obs.get("protocol") == 1 # 1 = No Protocol
router.acl.remove_rule(1)
observation_space = acl_obs.observe(simulation.describe_state())
assert observation_space.get(1) is not None
rule_obs = observation_space.get(1) # this is the ACL Rule added to allow NTP
assert rule_obs.get("position") == 0
assert rule_obs.get("permission") == 0
assert rule_obs.get("source_node_id") == 0
assert rule_obs.get("dest_node_id") == 0
assert rule_obs.get("source_port") == 0
assert rule_obs.get("dest_port") == 0
assert rule_obs.get("protocol") == 0

View File

@@ -26,7 +26,7 @@ def test_file_observation(simulation):
where=["network", "nodes", pc.hostname, "file_system", "folders", "root", "files", "dog.png"]
)
assert dog_file_obs.space == spaces.Dict({"health_status": spaces.Discrete(6)})
assert dog_file_obs.space["health_status"] == spaces.Discrete(6)
observation_state = dog_file_obs.observe(simulation.describe_state())
assert observation_state.get("health_status") == 1 # good initial
@@ -52,6 +52,8 @@ def test_folder_observation(simulation):
where=["network", "nodes", pc.hostname, "file_system", "folders", "test_folder"]
)
assert root_folder_obs.space["health_status"] == spaces.Discrete(6)
observation_state = root_folder_obs.observe(simulation.describe_state())
assert observation_state.get("FILES") is not None
assert observation_state.get("health_status") == 1

View File

@@ -0,0 +1,73 @@
import pytest
from gymnasium import spaces
from primaite.game.agent.observations.observations import LinkObservation
from primaite.simulator.network.container import Network
from primaite.simulator.network.hardware.base import Link, Node
from primaite.simulator.network.hardware.nodes.host.computer import Computer
from primaite.simulator.network.hardware.nodes.host.server import Server
from primaite.simulator.sim_container import Simulation
@pytest.fixture(scope="function")
def simulation() -> Simulation:
sim = Simulation()
network = Network()
# Create Computer
computer = Computer(
hostname="computer",
ip_address="192.168.1.2",
subnet_mask="255.255.255.0",
default_gateway="192.168.1.1",
start_up_duration=0,
)
computer.power_on()
# Create Server
server = Server(
hostname="server",
ip_address="192.168.1.3",
subnet_mask="255.255.255.0",
default_gateway="192.168.1.1",
start_up_duration=0,
)
server.power_on()
# Connect Computer and Server
network.connect(computer.network_interface[1], server.network_interface[1])
# Should be linked
assert next(iter(network.links.values())).is_up
assert computer.ping(server.network_interface.get(1).ip_address)
# set simulation network as example network
sim.network = network
return sim
def test_link_observation(simulation):
"""Test the link observation."""
# get a link
link: Link = next(iter(simulation.network.links.values()))
computer: Computer = simulation.network.get_node_by_hostname("computer")
server: Server = simulation.network.get_node_by_hostname("server")
simulation.apply_timestep(0) # some pings when network was made - reset with apply timestep
link_obs = LinkObservation(where=["network", "links", link.uuid])
assert link_obs.space["PROTOCOLS"]["ALL"] == spaces.Discrete(11) # test that the spaces are 0-10 including 0 and 10
observation_state = link_obs.observe(simulation.describe_state())
assert observation_state.get("PROTOCOLS") is not None
assert observation_state["PROTOCOLS"]["ALL"] == 0
computer.ping(server.network_interface.get(1).ip_address)
observation_state = link_obs.observe(simulation.describe_state())
assert observation_state["PROTOCOLS"]["ALL"] == 1

View File

@@ -1,9 +1,27 @@
from pathlib import Path
from typing import Union
import pytest
import yaml
from gymnasium import spaces
from primaite.game.agent.observations.nic_observations import NicObservation
from primaite.game.game import PrimaiteGame
from primaite.simulator.network.hardware.nodes.host.computer import Computer
from primaite.simulator.network.hardware.nodes.host.host_node import NIC
from primaite.simulator.network.nmne import CAPTURE_NMNE
from primaite.simulator.sim_container import Simulation
from tests import TEST_ASSETS_ROOT
BASIC_CONFIG = TEST_ASSETS_ROOT / "configs/basic_switched_network.yaml"
def load_config(config_path: Union[str, Path]) -> PrimaiteGame:
"""Returns a PrimaiteGame object which loads the contents of a given yaml path."""
with open(config_path, "r") as f:
cfg = yaml.safe_load(f)
return PrimaiteGame.from_config(cfg)
@pytest.fixture(scope="function")
@@ -24,6 +42,10 @@ def test_nic(simulation):
nic_obs = NicObservation(where=["network", "nodes", pc.hostname, "NICs", 1])
assert nic_obs.space["nic_status"] == spaces.Discrete(3)
assert nic_obs.space["nmne"]["inbound"] == spaces.Discrete(4)
assert nic_obs.space["nmne"]["outbound"] == spaces.Discrete(4)
observation_state = nic_obs.observe(simulation.describe_state())
assert observation_state.get("nic_status") == 1 # enabled
assert observation_state.get("nmne") is not None

View File

@@ -2,6 +2,7 @@ import copy
from uuid import uuid4
import pytest
from gymnasium import spaces
from primaite.game.agent.observations.node_observations import NodeObservation
from primaite.simulator.network.hardware.nodes.host.computer import Computer
@@ -24,6 +25,8 @@ def test_node_observation(simulation):
node_obs = NodeObservation(where=["network", "nodes", pc.hostname])
assert node_obs.space["operating_status"] == spaces.Discrete(5)
observation_state = node_obs.observe(simulation.describe_state())
assert observation_state.get("operating_status") == 1 # computer is on

View File

@@ -1,4 +1,5 @@
import pytest
from gymnasium import spaces
from primaite.game.agent.observations.software_observation import ApplicationObservation, ServiceObservation
from primaite.simulator.network.hardware.nodes.host.computer import Computer
@@ -29,6 +30,9 @@ def test_service_observation(simulation):
service_obs = ServiceObservation(where=["network", "nodes", pc.hostname, "services", "NTPServer"])
assert service_obs.space["operating_status"] == spaces.Discrete(7)
assert service_obs.space["health_status"] == spaces.Discrete(5)
observation_state = service_obs.observe(simulation.describe_state())
assert observation_state.get("health_status") == 0