# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
|
|
from typing import Any, Dict, Optional, Tuple
|
|
|
|
import pytest
|
|
import yaml
|
|
from pydantic import Field
|
|
from ray import init as rayinit
|
|
|
|
from primaite import getLogger, PRIMAITE_PATHS
|
|
from primaite.game.agent.actions import ActionManager
|
|
from primaite.game.agent.interface import AbstractAgent
|
|
from primaite.game.agent.observations.observation_manager import NestedObservation, ObservationManager
|
|
from primaite.game.agent.rewards import RewardFunction
|
|
from primaite.game.agent.scripted_agents.probabilistic_agent import ProbabilisticAgent
|
|
from primaite.game.game import PrimaiteGame
|
|
from primaite.simulator.file_system.file_system import FileSystem
|
|
from primaite.simulator.network.container import Network
|
|
from primaite.simulator.network.hardware.nodes.host.computer import Computer
|
|
from primaite.simulator.network.hardware.nodes.host.server import Server
|
|
from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router
|
|
from primaite.simulator.network.hardware.nodes.network.switch import Switch
|
|
from primaite.simulator.network.networks import arcd_uc2_network
|
|
from primaite.simulator.sim_container import Simulation
|
|
from primaite.simulator.system.applications.application import Application
|
|
from primaite.simulator.system.applications.web_browser import WebBrowser
|
|
from primaite.simulator.system.core.sys_log import SysLog
|
|
from primaite.simulator.system.services.dns.dns_client import DNSClient
|
|
from primaite.simulator.system.services.dns.dns_server import DNSServer
|
|
from primaite.simulator.system.services.service import Service
|
|
from primaite.simulator.system.services.web_server.web_server import WebServer
|
|
from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP
|
|
from primaite.utils.validation.port import PORT_LOOKUP
|
|
from tests import TEST_ASSETS_ROOT
|
|
|
|
# Initialise Ray as soon as this module is imported (``rayinit`` is ``ray.init``),
# i.e. before any fixture defined below is evaluated.
rayinit()

# Module-level constants; not referenced by anything in the visible part of this file.
ACTION_SPACE_NODE_VALUES = 1
ACTION_SPACE_NODE_ACTION_VALUES = 1

# Logger named after this module.
_LOGGER = getLogger(__name__)
|
|
|
|
|
|
class DummyService(Service, identifier="DummyService"):
    """Minimal concrete ``Service`` that tests can instantiate as a stand-in."""

    class ConfigSchema(Service.ConfigSchema):
        """Configuration schema for DummyService."""

        type: str = "DummyService"

    config: "DummyService.ConfigSchema" = Field(default_factory=lambda: DummyService.ConfigSchema())

    def __init__(self, **kwargs):
        """
        Build the service with a fixed name, port and protocol.

        Any ``name``, ``port`` or ``protocol`` passed by the caller is overwritten before the
        keyword arguments are forwarded to the parent constructor.
        """
        kwargs.update(
            name="DummyService",
            port=PORT_LOOKUP["HTTP"],
            protocol=PROTOCOL_LOOKUP["TCP"],
        )
        super().__init__(**kwargs)

    def describe_state(self) -> Dict:
        """Return the parent class's state description without modification."""
        state = super().describe_state()
        return state

    def receive(self, payload: Any, session_id: str, **kwargs) -> bool:
        """Ignore the payload; nothing is processed and ``None`` is returned."""
        return None
|
|
|
|
|
|
class DummyApplication(Application, identifier="DummyApplication"):
    """Minimal concrete ``Application`` that tests can instantiate as a stand-in."""

    class ConfigSchema(Application.ConfigSchema):
        """Configuration schema for DummyApplication."""

        type: str = "DummyApplication"

    config: "DummyApplication.ConfigSchema" = Field(default_factory=lambda: DummyApplication.ConfigSchema())

    def describe_state(self) -> Dict:
        """Return the parent class's state description without modification."""
        state = super().describe_state()
        return state

    def __init__(self, **kwargs):
        """
        Build the application with a fixed name, port and protocol.

        Any ``name``, ``port`` or ``protocol`` passed by the caller is overwritten before the
        keyword arguments are forwarded to the parent constructor.
        """
        kwargs.update(
            name="DummyApplication",
            port=PORT_LOOKUP["HTTP"],
            protocol=PROTOCOL_LOOKUP["TCP"],
        )
        super().__init__(**kwargs)
|
|
|
|
|
|
@pytest.fixture(scope="function")
def uc2_network() -> Network:
    """Build a game from the ``data_manipulation.yaml`` example config and return its network."""
    config_path = PRIMAITE_PATHS.user_config_path / "example_config" / "data_manipulation.yaml"
    with open(config_path) as config_file:
        game_config = yaml.safe_load(config_file)
    return PrimaiteGame.from_config(game_config).simulation.network
|
|
|
|
|
|
@pytest.fixture(scope="function")
def service(file_system) -> DummyService:
    """
    Provide a ``DummyService`` bound to the ``file_system`` fixture.

    :param file_system: File system the service is attached to.
    :return: A new ``DummyService`` with its own ``SysLog``.
    """
    # NOTE: DummyService.__init__ overwrites ``name`` and ``port``, so the values given here
    # are replaced before the parent constructor sees them.
    service_kwargs = {
        "name": "DummyService",
        "port": PORT_LOOKUP["ARP"],
        "file_system": file_system,
        "sys_log": SysLog(hostname="dummy_service"),
    }
    return DummyService(**service_kwargs)
|
|
|
|
|
|
@pytest.fixture(scope="function")
def service_class():
    """Provide the ``DummyService`` class itself (not an instance)."""
    dummy_service_type = DummyService
    return dummy_service_type
|
|
|
|
|
|
@pytest.fixture(scope="function")
def application(file_system) -> DummyApplication:
    """
    Provide a ``DummyApplication`` bound to the ``file_system`` fixture.

    :param file_system: File system the application is attached to.
    :return: A new ``DummyApplication`` with its own ``SysLog``.
    """
    # NOTE: DummyApplication.__init__ overwrites ``name`` and ``port``, so the values given
    # here are replaced before the parent constructor sees them.
    application_kwargs = {
        "name": "DummyApplication",
        "port": PORT_LOOKUP["ARP"],
        "file_system": file_system,
        "sys_log": SysLog(hostname="dummy_application"),
    }
    return DummyApplication(**application_kwargs)
|
|
|
|
|
|
@pytest.fixture(scope="function")
def application_class():
    """Provide the ``DummyApplication`` class itself (not an instance)."""
    dummy_application_type = DummyApplication
    return dummy_application_type
|
|
|
|
|
|
@pytest.fixture(scope="function")
def file_system() -> FileSystem:
    """Provide the file system of a freshly configured, powered-on computer named ``fs_node``."""
    node = Computer.from_config(
        config={
            "type": "computer",
            "hostname": "fs_node",
            "ip_address": "192.168.1.2",
            "subnet_mask": "255.255.255.0",
            "start_up_duration": 0,
        }
    )
    node.power_on()
    return node.file_system
|
|
|
|
|
|
@pytest.fixture(scope="function")
def client_server() -> Tuple[Computer, Server]:
    """
    Provide a computer and a server joined by a single direct link.

    Both nodes are built through ``from_config`` so that this fixture creates nodes the same
    way as the other network builders in this module, rather than through the legacy keyword
    constructor.

    :return: ``(computer, server)``, both powered on and connected NIC 1 to NIC 1.
    """
    network = Network()

    # Create Computer
    computer: Computer = Computer.from_config(
        config={
            "type": "computer",
            "hostname": "computer",
            "ip_address": "192.168.1.2",
            "subnet_mask": "255.255.255.0",
            "default_gateway": "192.168.1.1",
            "start_up_duration": 0,
        }
    )
    computer.power_on()

    # Create Server
    server: Server = Server.from_config(
        config={
            "type": "server",
            "hostname": "server",
            "ip_address": "192.168.1.3",
            "subnet_mask": "255.255.255.0",
            "default_gateway": "192.168.1.1",
            "start_up_duration": 0,
        }
    )
    server.power_on()

    # Connect Computer and Server
    network.connect(computer.network_interface[1], server.network_interface[1])

    # Fail fast in the fixture itself if the only link did not come up.
    assert next(iter(network.links.values())).is_up

    return computer, server
|
|
|
|
|
|
@pytest.fixture(scope="function")
def client_switch_server() -> Tuple[Computer, Switch, Server]:
    """
    Provide a computer and a server that are both attached to one switch.

    All three nodes are built through ``from_config`` so that this fixture creates nodes the
    same way as the other network builders in this module, rather than through the legacy
    keyword constructor.

    :return: ``(computer, switch, server)``; the computer uses switch port 1, the server port 2.
    """
    network = Network()

    # Create Computer
    computer: Computer = Computer.from_config(
        config={
            "type": "computer",
            "hostname": "computer",
            "ip_address": "192.168.1.2",
            "subnet_mask": "255.255.255.0",
            "default_gateway": "192.168.1.1",
            "start_up_duration": 0,
        }
    )
    computer.power_on()

    # Create Server
    server: Server = Server.from_config(
        config={
            "type": "server",
            "hostname": "server",
            "ip_address": "192.168.1.3",
            "subnet_mask": "255.255.255.0",
            "default_gateway": "192.168.1.1",
            "start_up_duration": 0,
        }
    )
    server.power_on()

    # Create Switch
    switch: Switch = Switch.from_config(config={"type": "switch", "hostname": "switch", "start_up_duration": 0})
    switch.power_on()

    network.connect(endpoint_a=computer.network_interface[1], endpoint_b=switch.network_interface[1])
    network.connect(endpoint_a=server.network_interface[1], endpoint_b=switch.network_interface[2])

    # Fail fast in the fixture itself if any link did not come up.
    assert all(link.is_up for link in network.links.values())

    return computer, switch, server
|
|
|
|
|
|
@pytest.fixture(scope="function")
def example_network() -> Network:
    """
    Create the network used for testing.

    Should only contain the nodes and links.
    This would act as the base network; services and applications are installed in the relevant test file.

    ------------                                                          ------------
    | client_1 |----                                                  ----| server_1 |
    ------------   |   ------------    ------------    ------------   |   ------------
                   ----| switch_2 |----| router_1 |----| switch_1 |----
    ------------   |   ------------    ------------    ------------   |   ------------
    | client_2 |----                                                  ----| server_2 |
    ------------                                                          ------------

    :return: The populated ``Network``; every link has been asserted to be up.
    """
    network = Network()

    # Router 1
    # ``start_up_duration`` is set to 0 like every other node in this module (it was present on
    # the pre-``from_config`` constructor call); the router is powered on and used immediately,
    # and the link-up assertion at the end presumably depends on it being fully on.
    router_1_cfg = {"hostname": "router_1", "type": "router", "start_up_duration": 0}
    router_1 = Router.from_config(config=router_1_cfg)
    router_1.power_on()
    router_1.configure_port(port=1, ip_address="192.168.1.1", subnet_mask="255.255.255.0")
    router_1.configure_port(port=2, ip_address="192.168.10.1", subnet_mask="255.255.255.0")

    # Switch 1 (server side). Port 8 carries the router uplink; no ``num_ports`` is given, so
    # this relies on the Switch default providing at least 8 ports - TODO confirm.
    switch_1_cfg = {"hostname": "switch_1", "type": "switch", "start_up_duration": 0}
    switch_1 = Switch.from_config(config=switch_1_cfg)
    switch_1.power_on()
    network.connect(endpoint_a=router_1.network_interface[1], endpoint_b=switch_1.network_interface[8])
    router_1.enable_port(1)

    # Switch 2 (client side)
    switch_2_config = {"hostname": "switch_2", "type": "switch", "num_ports": 8, "start_up_duration": 0}
    switch_2 = Switch.from_config(config=switch_2_config)
    switch_2.power_on()
    network.connect(endpoint_a=router_1.network_interface[2], endpoint_b=switch_2.network_interface[8])
    router_1.enable_port(2)

    # Client 1
    client_1_cfg = {
        "type": "computer",
        "hostname": "client_1",
        "ip_address": "192.168.10.21",
        "subnet_mask": "255.255.255.0",
        "default_gateway": "192.168.10.1",
        "start_up_duration": 0,
    }
    client_1 = Computer.from_config(config=client_1_cfg)
    client_1.power_on()
    network.connect(endpoint_b=client_1.network_interface[1], endpoint_a=switch_2.network_interface[1])

    # Client 2
    client_2_cfg = {
        "type": "computer",
        "hostname": "client_2",
        "ip_address": "192.168.10.22",
        "subnet_mask": "255.255.255.0",
        "default_gateway": "192.168.10.1",
        "start_up_duration": 0,
    }
    client_2 = Computer.from_config(config=client_2_cfg)
    client_2.power_on()
    network.connect(endpoint_b=client_2.network_interface[1], endpoint_a=switch_2.network_interface[2])

    # Server 1
    server_1_cfg = {
        "type": "server",
        "hostname": "server_1",
        "ip_address": "192.168.1.10",
        "subnet_mask": "255.255.255.0",
        "default_gateway": "192.168.1.1",
        "start_up_duration": 0,
    }
    server_1 = Server.from_config(config=server_1_cfg)
    server_1.power_on()
    network.connect(endpoint_b=server_1.network_interface[1], endpoint_a=switch_1.network_interface[1])

    # Server 2
    server_2_cfg = {
        "type": "server",
        "hostname": "server_2",
        "ip_address": "192.168.1.14",
        "subnet_mask": "255.255.255.0",
        "default_gateway": "192.168.1.1",
        "start_up_duration": 0,
    }
    server_2 = Server.from_config(config=server_2_cfg)
    server_2.power_on()
    network.connect(endpoint_b=server_2.network_interface[1], endpoint_a=switch_1.network_interface[2])

    # PERMIT rule with no match criteria at position 1 (presumably lets all traffic through
    # the router so tests are not blocked by the ACL - confirm against the ACL implementation).
    router_1.acl.add_rule(action=ACLAction.PERMIT, position=1)

    # Fail fast in the fixture itself if any link did not come up.
    assert all(link.is_up for link in network.links.values())

    # Shows client_1's software table; kept from the original, presumably as a debugging aid.
    client_1.software_manager.show()

    return network
|
|
|
|
|
|
class ControlledAgent(AbstractAgent, identifier="ControlledAgent"):
    """Agent whose next action is supplied by the test through ``store_action``."""

    class ConfigSchema(AbstractAgent.ConfigSchema):
        """Configuration schema for the test-only ControlledAgent."""

        type: str = "ControlledAgent"

    config: "ControlledAgent.ConfigSchema" = Field(default_factory=lambda: ControlledAgent.ConfigSchema())
    # Last action passed to store_action(); stays None until a test stores one.
    most_recent_action: Optional[Tuple[str, Dict]] = None

    def store_action(self, action: Tuple[str, Dict]):
        """
        Remember ``action`` so that later ``get_action`` calls return it.

        :param action: Action tuple of ``(name, parameters)``.
        """
        self.most_recent_action = action

    def get_action(self, obs: None, timestep: int = 0) -> Tuple[str, Dict]:
        """
        Return the most recently stored action, formatted in CAOS format.

        ``obs`` and ``timestep`` are accepted but not used. The result is ``None`` if no
        action has been stored yet.
        """
        stored_action = self.most_recent_action
        return stored_action
|
|
|
|
|
|
def install_stuff_to_sim(sim: Simulation):
    """
    Populate ``sim`` with a computer, two servers, two switches and a router, then verify it.

    The function wires the hardware together, adds four PERMIT rules to the router ACL,
    installs a DNS server on ``server_1`` and a web server on ``server_2``, points the DNS
    clients of ``client_1`` and ``server_2`` at ``server_1``, creates ``cat.png`` on the client
    and finally asserts that the resulting simulation looks as expected.

    :param sim: Simulation whose network receives the nodes and links.
    :return: The same ``sim`` object that was passed in.
    """

    def _host_config(node_type, hostname, ip_address, default_gateway):
        """Return a host config dict on a /24 subnet with zero start-up duration."""
        return {
            "type": node_type,
            "hostname": hostname,
            "ip_address": ip_address,
            "subnet_mask": "255.255.255.0",
            "default_gateway": default_gateway,
            "start_up_duration": 0,
        }

    def _node_named(hostname):
        """Return the first node in the simulation's network with the given hostname."""
        return [node for node in sim.network.nodes.values() if node.config.hostname == hostname][0]

    # 0: Pull out the network
    net = sim.network

    # 1: Set up network hardware
    # 1.1: Configure the router
    router = Router.from_config(
        config={"type": "router", "hostname": "router", "num_ports": 3, "start_up_duration": 0}
    )
    router.power_on()
    router.configure_port(port=1, ip_address="10.0.1.1", subnet_mask="255.255.255.0")
    router.configure_port(port=2, ip_address="10.0.2.1", subnet_mask="255.255.255.0")

    # 1.2: Create and connect switches (router port N <-> port 6 of switch_N)
    switches = {}
    for router_port, switch_name in ((1, "switch_1"), (2, "switch_2")):
        new_switch = Switch.from_config(
            config={"type": "switch", "hostname": switch_name, "num_ports": 6, "start_up_duration": 0}
        )
        new_switch.power_on()
        net.connect(endpoint_a=router.network_interface[router_port], endpoint_b=new_switch.network_interface[6])
        router.enable_port(router_port)
        switches[router_port] = new_switch
    switch_1, switch_2 = switches[1], switches[2]

    # 1.3: Create and connect computer
    client_1: Computer = Computer.from_config(config=_host_config("computer", "client_1", "10.0.1.2", "10.0.1.1"))
    client_1.power_on()
    net.connect(endpoint_a=client_1.network_interface[1], endpoint_b=switch_1.network_interface[1])

    # 1.4: Create and connect servers
    server_1: Server = Server.from_config(config=_host_config("server", "server_1", "10.0.2.2", "10.0.2.1"))
    server_1.power_on()
    net.connect(endpoint_a=server_1.network_interface[1], endpoint_b=switch_2.network_interface[1])

    server_2: Server = Server.from_config(config=_host_config("server", "server_2", "10.0.2.3", "10.0.2.1"))
    server_2.power_on()
    net.connect(endpoint_a=server_2.network_interface[1], endpoint_b=switch_2.network_interface[2])

    # 2: Configure base ACL
    arp_port, dns_port, http_port = PORT_LOOKUP["ARP"], PORT_LOOKUP["DNS"], PORT_LOOKUP["HTTP"]
    router.acl.add_rule(action=ACLAction.PERMIT, src_port=arp_port, dst_port=arp_port, position=22)
    router.acl.add_rule(action=ACLAction.PERMIT, protocol=PROTOCOL_LOOKUP["ICMP"], position=23)
    router.acl.add_rule(action=ACLAction.PERMIT, src_port=dns_port, dst_port=dns_port, position=1)
    router.acl.add_rule(action=ACLAction.PERMIT, src_port=http_port, dst_port=http_port, position=3)

    # 3: Install server software
    server_1.software_manager.install(DNSServer)
    dns_server: DNSServer = server_1.software_manager.software.get("DNSServer")  # noqa
    dns_server.dns_register("www.example.com", server_2.network_interface[1].ip_address)
    server_2.software_manager.install(WebServer)

    # 3.1: Ensure that the dns clients are configured correctly
    for dns_user in (client_1, server_2):
        dns_user.software_manager.software.get("DNSClient").dns_server = server_1.network_interface[1].ip_address

    # 4: Check that client came pre-installed with web browser and dns client
    client_software = client_1.software_manager.software
    assert isinstance(client_software.get("WebBrowser"), WebBrowser)
    assert isinstance(client_software.get("DNSClient"), DNSClient)

    # 4.1: Create a file on the computer
    client_1.file_system.create_file("cat.png", 300, folder_name="downloads")

    # 5: Assert that the simulation starts off in the state that we expect
    assert len(sim.network.nodes) == 6
    assert len(sim.network.links) == 5

    # 5.1: Assert the router is correctly configured
    router_node = sim.network.router_nodes[0]
    port_rule_positions = {1: "DNS", 3: "HTTP", 22: "ARP"}
    for position, rule in enumerate(router_node.acl.acl):
        if position in port_rule_positions:
            assert rule.src_port == rule.dst_port == PORT_LOOKUP[port_rule_positions[position]]
        elif position == 23:
            assert rule.protocol == PROTOCOL_LOOKUP["ICMP"]
        elif position != 24:
            # Position 24 is deliberately left unchecked; every other slot must be empty.
            assert rule is None

    # 5.2: Assert the client is correctly configured
    found_client: Computer = _node_named("client_1")
    assert found_client.software_manager.software.get("WebBrowser") is not None
    assert found_client.software_manager.software.get("DNSClient") is not None
    assert str(found_client.network_interface[1].ip_address) == "10.0.1.2"

    # 5.3: Assert that server_1 is correctly configured
    found_server_1: Server = _node_named("server_1")
    assert str(found_server_1.network_interface[1].ip_address) == "10.0.2.2"
    assert found_server_1.software_manager.software.get("DNSServer") is not None

    # 5.4: Assert that server_2 is correctly configured
    found_server_2: Server = _node_named("server_2")
    assert str(found_server_2.network_interface[1].ip_address) == "10.0.2.3"
    assert found_server_2.software_manager.software.get("WebServer") is not None

    # 6: Return the simulation
    return sim
|
|
|
|
|
|
@pytest.fixture
def game_and_agent():
    """
    Provide a game whose simulation is pre-populated, plus a test-controllable agent.

    :return: ``(game, agent)`` where ``agent`` is a ``ControlledAgent`` registered in
        ``game.agents`` under the key ``"test_agent"``.
    """
    game = PrimaiteGame()
    install_stuff_to_sim(game.simulation)

    agent_config = {"type": "ControlledAgent", "ref": "test_agent", "team": "BLUE"}
    agent = ControlledAgent(config=agent_config)
    game.agents["test_agent"] = agent

    # Called after the agent is registered, as in the original ordering.
    game.setup_reward_sharing()

    return (game, agent)
|