Merge '2887-Align_Node_Types' into 3062-discriminators

This commit is contained in:
Marek Wolan
2025-02-04 14:04:40 +00:00
78 changed files with 1429 additions and 832 deletions

View File

@@ -25,6 +25,7 @@ def test_node_startup_shutdown(game_and_agent_fixture: Tuple[PrimaiteGame, Proxy
game, agent = game_and_agent_fixture
client_1 = game.simulation.network.get_node_by_hostname("client_1")
client_1.config.shut_down_duration = 3
assert client_1.operating_state == NodeOperatingState.ON
@@ -35,13 +36,15 @@ def test_node_startup_shutdown(game_and_agent_fixture: Tuple[PrimaiteGame, Proxy
assert client_1.operating_state == NodeOperatingState.SHUTTING_DOWN
for i in range(client_1.shut_down_duration + 1):
for i in range(client_1.config.shut_down_duration + 1):
action = ("do-nothing", {})
agent.store_action(action)
game.step()
assert client_1.operating_state == NodeOperatingState.OFF
client_1.config.start_up_duration = 3
# turn it on
action = ("node-startup", {"node_name": "client_1"})
agent.store_action(action)
@@ -49,7 +52,7 @@ def test_node_startup_shutdown(game_and_agent_fixture: Tuple[PrimaiteGame, Proxy
assert client_1.operating_state == NodeOperatingState.BOOTING
for i in range(client_1.start_up_duration + 1):
for i in range(client_1.config.start_up_duration + 1):
action = ("do-nothing", {})
agent.store_action(action)
game.step()
@@ -79,7 +82,7 @@ def test_node_cannot_be_shut_down_if_node_is_already_off(game_and_agent_fixture:
client_1 = game.simulation.network.get_node_by_hostname("client_1")
client_1.power_off()
for i in range(client_1.shut_down_duration + 1):
for i in range(client_1.config.shut_down_duration + 1):
action = ("do-nothing", {})
agent.store_action(action)
game.step()

View File

@@ -36,7 +36,7 @@ def test_acl_observations(simulation):
router.acl.add_rule(action=ACLAction.PERMIT, dst_port=PORT_LOOKUP["NTP"], src_port=PORT_LOOKUP["NTP"], position=1)
acl_obs = ACLObservation(
where=["network", "nodes", router.hostname, "acl", "acl"],
where=["network", "nodes", router.config.hostname, "acl", "acl"],
ip_list=[],
port_list=[123, 80, 5432],
protocol_list=["tcp", "udp", "icmp"],

View File

@@ -24,7 +24,7 @@ def test_file_observation(simulation):
file = pc.file_system.create_file(file_name="dog.png")
dog_file_obs = FileObservation(
where=["network", "nodes", pc.hostname, "file_system", "folders", "root", "files", "dog.png"],
where=["network", "nodes", pc.config.hostname, "file_system", "folders", "root", "files", "dog.png"],
include_num_access=False,
file_system_requires_scan=True,
)
@@ -52,7 +52,7 @@ def test_folder_observation(simulation):
file = pc.file_system.create_file(file_name="dog.png", folder_name="test_folder")
root_folder_obs = FolderObservation(
where=["network", "nodes", pc.hostname, "file_system", "folders", "test_folder"],
where=["network", "nodes", pc.config.hostname, "file_system", "folders", "test_folder"],
include_num_access=False,
file_system_requires_scan=True,
num_files=1,

View File

@@ -25,7 +25,8 @@ def check_default_rules(acl_obs):
def test_firewall_observation():
"""Test adding/removing acl rules and enabling/disabling ports."""
net = Network()
firewall = Firewall(hostname="firewall", operating_state=NodeOperatingState.ON)
firewall_cfg = {"type": "firewall", "hostname": "firewall"}
firewall = Firewall.from_config(config=firewall_cfg)
firewall_observation = FirewallObservation(
where=[],
num_rules=7,
@@ -116,7 +117,9 @@ def test_firewall_observation():
assert all(observation["PORTS"][i]["operating_status"] == 2 for i in range(1, 4))
# connect a switch to the firewall and check that only the correct port is updated
switch = Switch(hostname="switch", num_ports=1, operating_state=NodeOperatingState.ON)
switch: Switch = Switch.from_config(
config={"type": "switch", "hostname": "switch", "num_ports": 1, "operating_state": "ON"}
)
link = net.connect(firewall.network_interface[1], switch.network_interface[1])
assert firewall.network_interface[1].enabled
observation = firewall_observation.observe(firewall.describe_state())

View File

@@ -56,12 +56,26 @@ def test_link_observation():
"""Check the shape and contents of the link observation."""
net = Network()
sim = Simulation(network=net)
switch = Switch(hostname="switch", num_ports=5, operating_state=NodeOperatingState.ON)
computer_1 = Computer(
hostname="computer_1", ip_address="10.0.0.1", subnet_mask="255.255.255.0", start_up_duration=0
switch: Switch = Switch.from_config(
config={"type": "switch", "hostname": "switch", "num_ports": 5, "operating_state": "ON"}
)
computer_2 = Computer(
hostname="computer_2", ip_address="10.0.0.2", subnet_mask="255.255.255.0", start_up_duration=0
computer_1: Computer = Computer.from_config(
config={
"type": "computer",
"hostname": "computer_1",
"ip_address": "10.0.0.1",
"subnet_mask": "255.255.255.0",
"start_up_duration": 0,
}
)
computer_2: Computer = Computer.from_config(
config={
"type": "computer",
"hostname": "computer_2",
"ip_address": "10.0.0.2",
"subnet_mask": "255.255.255.0",
"start_up_duration": 0,
}
)
computer_1.power_on()
computer_2.power_on()

View File

@@ -75,7 +75,7 @@ def test_nic(simulation):
nic: NIC = pc.network_interface[1]
nic_obs = NICObservation(where=["network", "nodes", pc.hostname, "NICs", 1], include_nmne=True)
nic_obs = NICObservation(where=["network", "nodes", pc.config.hostname, "NICs", 1], include_nmne=True)
# Set the NMNE configuration to capture DELETE/ENCRYPT queries as MNEs
nmne_config = {
@@ -108,7 +108,7 @@ def test_nic_categories(simulation):
"""Test the NIC observation nmne count categories."""
pc: Computer = simulation.network.get_node_by_hostname("client_1")
nic_obs = NICObservation(where=["network", "nodes", pc.hostname, "NICs", 1], include_nmne=True)
nic_obs = NICObservation(where=["network", "nodes", pc.config.hostname, "NICs", 1], include_nmne=True)
assert nic_obs.high_nmne_threshold == 10 # default
assert nic_obs.med_nmne_threshold == 5 # default
@@ -163,7 +163,9 @@ def test_nic_monitored_traffic(simulation):
pc2: Computer = simulation.network.get_node_by_hostname("client_2")
nic_obs = NICObservation(
where=["network", "nodes", pc.hostname, "NICs", 1], include_nmne=False, monitored_traffic=monitored_traffic
where=["network", "nodes", pc.config.hostname, "NICs", 1],
include_nmne=False,
monitored_traffic=monitored_traffic,
)
simulation.pre_timestep(0) # apply timestep to whole sim

View File

@@ -25,7 +25,7 @@ def test_host_observation(simulation):
pc: Computer = simulation.network.get_node_by_hostname("client_1")
host_obs = HostObservation(
where=["network", "nodes", pc.hostname],
where=["network", "nodes", pc.config.hostname],
num_applications=0,
num_files=1,
num_folders=1,
@@ -56,7 +56,7 @@ def test_host_observation(simulation):
observation_state = host_obs.observe(simulation.describe_state())
assert observation_state.get("operating_status") == 4 # shutting down
for i in range(pc.shut_down_duration + 1):
for i in range(pc.config.shut_down_duration + 1):
pc.apply_timestep(i)
observation_state = host_obs.observe(simulation.describe_state())

View File

@@ -16,7 +16,9 @@ from primaite.utils.validation.port import PORT_LOOKUP
def test_router_observation():
"""Test adding/removing acl rules and enabling/disabling ports."""
net = Network()
router = Router(hostname="router", num_ports=5, operating_state=NodeOperatingState.ON)
router = Router.from_config(
config={"type": "router", "hostname": "router", "num_ports": 5, "operating_state": "ON"}
)
ports = [PortObservation(where=["NICs", i]) for i in range(1, 6)]
acl = ACLObservation(
@@ -89,7 +91,9 @@ def test_router_observation():
assert all(observed_output["PORTS"][i]["operating_status"] == 2 for i in range(1, 6))
# connect a switch to the router and check that only the correct port is updated
switch = Switch(hostname="switch", num_ports=1, operating_state=NodeOperatingState.ON)
switch: Switch = Switch.from_config(
config={"type": "switch", "hostname": "switch", "num_ports": 1, "operating_state": "ON"}
)
link = net.connect(router.network_interface[1], switch.network_interface[1])
assert router.network_interface[1].enabled
observed_output = router_observation.observe(router.describe_state())

View File

@@ -29,7 +29,7 @@ def test_service_observation(simulation):
ntp_server = pc.software_manager.software.get("ntp-server")
assert ntp_server
service_obs = ServiceObservation(where=["network", "nodes", pc.hostname, "services", "ntp-server"])
service_obs = ServiceObservation(where=["network", "nodes", pc.config.hostname, "services", "ntp-server"])
assert service_obs.space["operating_status"] == spaces.Discrete(7)
assert service_obs.space["health_status"] == spaces.Discrete(5)
@@ -54,7 +54,7 @@ def test_application_observation(simulation):
web_browser: WebBrowser = pc.software_manager.software.get("web-browser")
assert web_browser
app_obs = ApplicationObservation(where=["network", "nodes", pc.hostname, "applications", "web-browser"])
app_obs = ApplicationObservation(where=["network", "nodes", pc.config.hostname, "applications", "web-browser"])
web_browser.close()
observation_state = app_obs.observe(simulation.describe_state())

View File

@@ -2,6 +2,7 @@
from primaite.session.environment import PrimaiteGymEnv
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
from primaite.simulator.network.hardware.nodes.host.host_node import HostNode
from primaite.simulator.network.hardware.nodes.network.wireless_router import WirelessRouter
from primaite.simulator.system.services.service import ServiceOperatingState
from tests.conftest import TEST_ASSETS_ROOT

View File

@@ -21,6 +21,7 @@ from primaite.game.agent.interface import ProxyAgent
from primaite.game.game import PrimaiteGame
from primaite.session.environment import PrimaiteGymEnv
from primaite.simulator.file_system.file_system_item_abc import FileSystemItemHealthStatus
from primaite.simulator.network.hardware.nodes.network.firewall import Firewall
from primaite.simulator.system.applications.application import ApplicationOperatingState
from primaite.simulator.system.applications.web_browser import WebBrowser
from primaite.simulator.system.software import SoftwareHealthState

View File

@@ -8,14 +8,16 @@ from primaite.simulator.sim_container import Simulation
def test_file_observation():
sim = Simulation()
pc = Computer(hostname="beep", ip_address="123.123.123.123", subnet_mask="255.255.255.0")
pc: Computer = Computer.from_config(
config={"type": "computer", "hostname": "beep", "ip_address": "123.123.123.123", "subnet_mask": "255.255.255.0"}
)
sim.network.add_node(pc)
f = pc.file_system.create_file(file_name="dog.png")
state = sim.describe_state()
dog_file_obs = FileObservation(
where=["network", "nodes", pc.hostname, "file_system", "folders", "root", "files", "dog.png"],
where=["network", "nodes", pc.config.hostname, "file_system", "folders", "root", "files", "dog.png"],
include_num_access=False,
file_system_requires_scan=False,
)