Files
PrimAITE/tests/conftest.py
2024-02-27 13:30:16 +00:00

515 lines
18 KiB
Python

# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
from pathlib import Path
from typing import Any, Dict, Optional, Tuple, Union
import pytest
import yaml
from primaite import getLogger, PRIMAITE_PATHS
from primaite.game.agent.actions import ActionManager
from primaite.game.agent.interface import AbstractAgent
from primaite.game.agent.observations import ICSObservation, ObservationManager
from primaite.game.agent.rewards import RewardFunction
from primaite.game.game import PrimaiteGame
from primaite.session.session import PrimaiteSession
from primaite.simulator.file_system.file_system import FileSystem
from primaite.simulator.network.container import Network
from primaite.simulator.network.hardware.nodes.host.computer import Computer
from primaite.simulator.network.hardware.nodes.host.server import Server
from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router
from primaite.simulator.network.hardware.nodes.network.switch import Switch
from primaite.simulator.network.networks import arcd_uc2_network
from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.sim_container import Simulation
from primaite.simulator.system.applications.application import Application
from primaite.simulator.system.applications.web_browser import WebBrowser
from primaite.simulator.system.core.sys_log import SysLog
from primaite.simulator.system.services.dns.dns_client import DNSClient
from primaite.simulator.system.services.dns.dns_server import DNSServer
from primaite.simulator.system.services.service import Service
from primaite.simulator.system.services.web_server.web_server import WebServer
from tests.mock_and_patch.get_session_path_mock import temp_user_sessions_path
# NOTE(review): presumably action-space constants consumed by tests; neither is referenced in the visible
# part of this file — confirm they are still needed.
ACTION_SPACE_NODE_VALUES = 1
ACTION_SPACE_NODE_ACTION_VALUES = 1
# Module-level logger obtained from primaite's getLogger, keyed by this module's name.
_LOGGER = getLogger(__name__)
class TestService(Service):
    """
    Minimal concrete ``Service`` used as a stand-in by the tests.

    The ``name``, ``port`` and ``protocol`` keyword arguments are always replaced with fixed values,
    whatever the caller supplies.
    """

    def __init__(self, **kwargs):
        """Force the fixed name/port/protocol into ``kwargs`` and hand everything to the parent class."""
        kwargs.update(name="TestService", port=Port.HTTP, protocol=IPProtocol.TCP)
        super().__init__(**kwargs)

    def describe_state(self) -> Dict:
        """Return the parent class's state description unchanged."""
        state = super().describe_state()
        return state

    def receive(self, payload: Any, session_id: str, **kwargs) -> bool:
        """Ignore the payload; nothing is processed and ``None`` is returned."""
        return None
class TestApplication(Application):
    """
    Minimal concrete ``Application`` used as a stand-in by the tests.

    The ``name``, ``port`` and ``protocol`` keyword arguments are always replaced with fixed values,
    whatever the caller supplies.
    """

    def describe_state(self) -> Dict:
        """Return the parent class's state description unchanged."""
        state = super().describe_state()
        return state

    def __init__(self, **kwargs):
        """Force the fixed name/port/protocol into ``kwargs`` and hand everything to the parent class."""
        kwargs.update(name="TestApplication", port=Port.HTTP, protocol=IPProtocol.TCP)
        super().__init__(**kwargs)
@pytest.fixture(scope="function")
def uc2_network() -> Network:
    """Provide a new network built by ``arcd_uc2_network`` for every test function."""
    network = arcd_uc2_network()
    return network
@pytest.fixture(scope="function")
def service(file_system) -> TestService:
    """Provide a ``TestService`` instance that uses the ``file_system`` fixture."""
    sys_log = SysLog(hostname="test_service")
    # ``name`` and ``port`` are still passed, but TestService.__init__ overwrites both with its fixed values.
    return TestService(name="TestService", port=Port.ARP, file_system=file_system, sys_log=sys_log)
@pytest.fixture(scope="function")
def service_class():
    """Provide the ``TestService`` class itself (not an instance)."""
    cls = TestService
    return cls
@pytest.fixture(scope="function")
def application(file_system) -> TestApplication:
    """Provide a ``TestApplication`` instance that uses the ``file_system`` fixture."""
    sys_log = SysLog(hostname="test_application")
    # ``name`` and ``port`` are still passed, but TestApplication.__init__ overwrites both with its fixed values.
    return TestApplication(name="TestApplication", port=Port.ARP, file_system=file_system, sys_log=sys_log)
@pytest.fixture(scope="function")
def application_class():
    """Provide the ``TestApplication`` class itself (not an instance)."""
    cls = TestApplication
    return cls
@pytest.fixture(scope="function")
def file_system() -> FileSystem:
    """Provide the file system belonging to a throwaway ``Computer`` that has been powered on."""
    node_settings = {
        "hostname": "fs_node",
        "ip_address": "192.168.1.2",
        "subnet_mask": "255.255.255.0",
        "start_up_duration": 0,
    }
    fs_node = Computer(**node_settings)
    fs_node.power_on()
    return fs_node.file_system
# PrimAITE v2 stuff
class TempPrimaiteSession(PrimaiteSession):
    """
    A PrimaiteSession built from a YAML config file that can be used in a ``with`` statement.

    Entering the context returns the session itself. Exiting currently performs no clean-up.
    """

    @classmethod
    def from_config(cls, config_path: Union[str, Path]) -> "TempPrimaiteSession":
        """
        Create a temporary PrimaiteSession object from a config file.

        :param config_path: Location of the YAML config file.
        :return: The object the parent ``from_config`` builds from the parsed config.
        """
        config_path = Path(config_path)
        # Decode explicitly as UTF-8 so that parsing does not depend on the platform's locale encoding.
        with open(config_path, "r", encoding="utf-8") as f:
            config = yaml.safe_load(f)
        return super().from_config(cfg=config)

    def __enter__(self):
        """Return the session so it can be bound by ``with ... as``."""
        return self

    def __exit__(self, type, value, tb):
        """Do nothing on exit; exceptions raised inside the ``with`` body are not suppressed."""
        pass
@pytest.fixture
def temp_primaite_session(request, monkeypatch) -> TempPrimaiteSession:
    """
    Build a ``TempPrimaiteSession`` from the config path given as the first element of ``request.param``.

    ``PRIMAITE_PATHS.user_sessions_path`` is monkeypatched to the value of ``temp_user_sessions_path()``
    before the session is created.
    """
    sessions_dir = temp_user_sessions_path()
    monkeypatch.setattr(PRIMAITE_PATHS, "user_sessions_path", sessions_dir)
    session = TempPrimaiteSession.from_config(config_path=request.param[0])
    return session
@pytest.fixture(scope="function")
def client_server() -> Tuple[Computer, Server]:
    """Provide a powered-on computer and server connected to each other by one direct link."""
    net = Network()
    # Settings that are identical for both hosts.
    shared_settings = {
        "subnet_mask": "255.255.255.0",
        "default_gateway": "192.168.1.1",
        "start_up_duration": 0,
    }

    client = Computer(hostname="computer", ip_address="192.168.1.2", **shared_settings)
    client.power_on()

    host = Server(hostname="server", ip_address="192.168.1.3", **shared_settings)
    host.power_on()

    # Wire the first interface of each host together.
    net.connect(client.network_interface[1], host.network_interface[1])

    # Sanity check: the first link registered on the network must be up.
    first_link = next(iter(net.links.values()))
    assert first_link.is_up

    return client, host
@pytest.fixture(scope="function")
def client_switch_server() -> Tuple[Computer, Switch, Server]:
    """Provide a powered-on computer and server that are both connected to one powered-on switch."""
    net = Network()
    # Settings that are identical for both hosts.
    shared_settings = {
        "subnet_mask": "255.255.255.0",
        "default_gateway": "192.168.1.1",
        "start_up_duration": 0,
    }

    client = Computer(hostname="computer", ip_address="192.168.1.2", **shared_settings)
    client.power_on()

    host = Server(hostname="server", ip_address="192.168.1.3", **shared_settings)
    host.power_on()

    hub = Switch(hostname="switch", start_up_duration=0)
    hub.power_on()

    # Computer goes to switch port 1, server to switch port 2.
    net.connect(endpoint_a=client.network_interface[1], endpoint_b=hub.network_interface[1])
    net.connect(endpoint_a=host.network_interface[1], endpoint_b=hub.network_interface[2])

    # Sanity check: every link must be up.
    assert all(link.is_up for link in net.links.values())

    return client, hub, host
@pytest.fixture(scope="function")
def example_network() -> Network:
    """
    Create the base network used for testing.

    Only nodes and links are created here; services and applications are expected to be installed by
    the individual test files.

        client_1 --+                                       +-- server_1
                   +-- switch_2 -- router_1 -- switch_1 ---+
        client_2 --+                                       +-- server_2

    :return: The populated network, with every link asserted to be up.
    """
    net = Network()
    netmask = "255.255.255.0"

    # Router: port 1 is on 192.168.1.0/24 (servers), port 2 is on 192.168.10.0/24 (clients).
    router = Router(hostname="router_1", start_up_duration=0)
    router.power_on()
    router.configure_port(port=1, ip_address="192.168.1.1", subnet_mask=netmask)
    router.configure_port(port=2, ip_address="192.168.10.1", subnet_mask=netmask)

    # switch_1 uplinks to router port 1 through its own port 8.
    server_switch = Switch(hostname="switch_1", num_ports=8, start_up_duration=0)
    server_switch.power_on()
    net.connect(endpoint_a=router.network_interface[1], endpoint_b=server_switch.network_interface[8])
    router.enable_port(1)

    # switch_2 uplinks to router port 2 through its own port 8.
    client_switch = Switch(hostname="switch_2", num_ports=8, start_up_duration=0)
    client_switch.power_on()
    net.connect(endpoint_a=router.network_interface[2], endpoint_b=client_switch.network_interface[8])
    router.enable_port(2)

    # Clients occupy ports 1 and 2 of switch_2.
    client_specs = [("client_1", "192.168.10.21"), ("client_2", "192.168.10.22")]
    for switch_port, (hostname, ip_address) in enumerate(client_specs, start=1):
        client = Computer(
            hostname=hostname,
            ip_address=ip_address,
            subnet_mask=netmask,
            default_gateway="192.168.10.1",
            start_up_duration=0,
        )
        client.power_on()
        net.connect(endpoint_b=client.network_interface[1], endpoint_a=client_switch.network_interface[switch_port])

    # Servers occupy ports 1 and 2 of switch_1.
    server_specs = [("server_1", "192.168.1.10"), ("server_2", "192.168.1.14")]
    for switch_port, (hostname, ip_address) in enumerate(server_specs, start=1):
        server = Server(
            hostname=hostname,
            ip_address=ip_address,
            subnet_mask=netmask,
            default_gateway="192.168.1.1",
            start_up_duration=0,
        )
        server.power_on()
        net.connect(endpoint_b=server.network_interface[1], endpoint_a=server_switch.network_interface[switch_port])

    # Permit ARP and ICMP through the router.
    router.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.ARP, dst_port=Port.ARP, position=22)
    router.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol.ICMP, position=23)

    # Sanity check: every link must be up.
    assert all(link.is_up for link in net.links.values())

    return net
class ControlledAgent(AbstractAgent):
    """
    Agent whose next action is dictated by the test.

    A test calls ``store_action`` with the desired action; ``get_action`` then hands that action back.
    """

    def __init__(
        self,
        agent_name: str,
        action_space: ActionManager,
        observation_space: ObservationManager,
        reward_function: RewardFunction,
    ) -> None:
        """Forward all arguments unchanged to the parent class."""
        base_kwargs = {
            "agent_name": agent_name,
            "action_space": action_space,
            "observation_space": observation_space,
            "reward_function": reward_function,
        }
        super().__init__(**base_kwargs)
        # Declaration only: no value is assigned here, so (unless the parent class sets it) calling
        # get_action before store_action raises AttributeError.
        self.most_recent_action: Tuple[str, Dict]

    def store_action(self, action: Tuple[str, Dict]):
        """Remember ``action`` so that it is returned by the next ``get_action`` call."""
        self.most_recent_action = action

    def get_action(self, obs: None, reward: float = 0.0, timestep: Optional[int] = None) -> Tuple[str, Dict]:
        """Return the stored action; ``obs``, ``reward`` and ``timestep`` are ignored."""
        stored = self.most_recent_action
        return stored
def install_stuff_to_sim(sim: Simulation):
    """
    Populate ``sim``'s network with a router, two switches, one client and two servers, then verify it.

    The client sits behind switch_1 on 10.0.1.0/24; both servers sit behind switch_2 on 10.0.2.0/24.
    A DNS server is installed on server_1 and a web server on server_2. The function finishes with
    assertions on the node/link counts, the router ACL and the hosts' software and addresses.

    :param sim: Simulation whose network is populated in place.
    :return: The same ``sim`` object that was passed in.
    """
    net = sim.network
    netmask = "255.255.255.0"

    # ---- Hardware ----------------------------------------------------------------------------------------------
    rtr = Router(hostname="router", num_ports=3, start_up_duration=0)
    rtr.power_on()
    rtr.configure_port(port=1, ip_address="10.0.1.1", subnet_mask=netmask)
    rtr.configure_port(port=2, ip_address="10.0.2.1", subnet_mask=netmask)

    # Each switch uplinks to the router through its own port 6.
    client_switch = Switch(hostname="switch_1", num_ports=6, start_up_duration=0)
    client_switch.power_on()
    net.connect(endpoint_a=rtr.network_interface[1], endpoint_b=client_switch.network_interface[6])
    rtr.enable_port(1)

    server_switch = Switch(hostname="switch_2", num_ports=6, start_up_duration=0)
    server_switch.power_on()
    net.connect(endpoint_a=rtr.network_interface[2], endpoint_b=server_switch.network_interface[6])
    rtr.enable_port(2)

    pc = Computer(
        hostname="client_1",
        ip_address="10.0.1.2",
        subnet_mask=netmask,
        default_gateway="10.0.1.1",
        start_up_duration=0,
    )
    pc.power_on()
    net.connect(endpoint_a=pc.network_interface[1], endpoint_b=client_switch.network_interface[1])

    dns_host = Server(
        hostname="server_1",
        ip_address="10.0.2.2",
        subnet_mask=netmask,
        default_gateway="10.0.2.1",
        start_up_duration=0,
    )
    dns_host.power_on()
    net.connect(endpoint_a=dns_host.network_interface[1], endpoint_b=server_switch.network_interface[1])

    web_host = Server(
        hostname="server_2",
        ip_address="10.0.2.3",
        subnet_mask=netmask,
        default_gateway="10.0.2.1",
        start_up_duration=0,
    )
    web_host.power_on()
    net.connect(endpoint_a=web_host.network_interface[1], endpoint_b=server_switch.network_interface[2])

    # ---- Base ACL (rules are added in this order) --------------------------------------------------------------
    base_rules = (
        {"src_port": Port.ARP, "dst_port": Port.ARP, "position": 22},
        {"protocol": IPProtocol.ICMP, "position": 23},
        {"src_port": Port.DNS, "dst_port": Port.DNS, "position": 1},
        {"src_port": Port.HTTP, "dst_port": Port.HTTP, "position": 3},
    )
    for rule_kwargs in base_rules:
        rtr.acl.add_rule(action=ACLAction.PERMIT, **rule_kwargs)

    # ---- Server software ---------------------------------------------------------------------------------------
    dns_host.software_manager.install(DNSServer)
    dns_server: DNSServer = dns_host.software_manager.software.get("DNSServer")  # noqa
    dns_server.dns_register("www.example.com", web_host.network_interface[1].ip_address)
    web_host.software_manager.install(WebServer)

    # Point the DNS clients on the client and on server_2 at server_1.
    for host in (pc, web_host):
        host.software_manager.software.get("DNSClient").dns_server = dns_host.network_interface[1].ip_address

    # The client is expected to already have a web browser and a DNS client installed.
    preinstalled = pc.software_manager.software
    assert isinstance(preinstalled.get("WebBrowser"), WebBrowser)
    assert isinstance(preinstalled.get("DNSClient"), DNSClient)

    # Give the client a file to work with.
    pc.file_system.create_file("cat.png", 300, folder_name="downloads")

    # ---- Verify the starting state -----------------------------------------------------------------------------
    assert len(sim.network.nodes) == 6
    assert len(sim.network.links) == 5

    # Router ACL: port-matching rules at the positions set above, ICMP at 23, everything else empty.
    expected_ports = {1: Port.DNS, 3: Port.HTTP, 22: Port.ARP}
    for index, rule in enumerate(sim.network.routers[0].acl.acl):
        if index in expected_ports:
            assert rule.src_port == rule.dst_port == expected_ports[index]
        elif index == 23:
            assert rule.protocol == IPProtocol.ICMP
        elif index != 24:
            # Index 24 is deliberately left unchecked.
            assert rule is None

    def _node_named(hostname):
        """Return the first node in the simulation's network whose hostname matches."""
        return [node for node in sim.network.nodes.values() if node.hostname == hostname][0]

    # Client
    found_client: Computer = _node_named("client_1")
    assert found_client.software_manager.software.get("WebBrowser") is not None
    assert found_client.software_manager.software.get("DNSClient") is not None
    assert str(found_client.network_interface[1].ip_address) == "10.0.1.2"

    # server_1
    found_dns_host: Server = _node_named("server_1")
    assert str(found_dns_host.network_interface[1].ip_address) == "10.0.2.2"
    assert found_dns_host.software_manager.software.get("DNSServer") is not None

    # server_2
    found_web_host: Server = _node_named("server_2")
    assert str(found_web_host.network_interface[1].ip_address) == "10.0.2.3"
    assert found_web_host.software_manager.software.get("WebServer") is not None

    return sim
@pytest.fixture
def game_and_agent():
    """
    Provide a ``PrimaiteGame`` plus a ``ControlledAgent`` registered with it.

    The game's simulation is populated by ``install_stuff_to_sim``. The agent's action space lists
    every action type below and targets client_1, server_1 and server_2.

    :return: Tuple of ``(game, agent)``.
    """
    game = PrimaiteGame()
    install_stuff_to_sim(game.simulation)

    # Action types that take no options and come before the ACL actions in the list.
    leading_types = (
        "DONOTHING",
        "NODE_SERVICE_SCAN",
        "NODE_SERVICE_STOP",
        "NODE_SERVICE_START",
        "NODE_SERVICE_PAUSE",
        "NODE_SERVICE_RESUME",
        "NODE_SERVICE_RESTART",
        "NODE_SERVICE_DISABLE",
        "NODE_SERVICE_ENABLE",
        "NODE_SERVICE_PATCH",
        "NODE_APPLICATION_EXECUTE",
        "NODE_FILE_SCAN",
        "NODE_FILE_CHECKHASH",
        "NODE_FILE_DELETE",
        "NODE_FILE_REPAIR",
        "NODE_FILE_RESTORE",
        "NODE_FILE_CORRUPT",
        "NODE_FOLDER_SCAN",
        "NODE_FOLDER_CHECKHASH",
        "NODE_FOLDER_REPAIR",
        "NODE_FOLDER_RESTORE",
        "NODE_OS_SCAN",
        "NODE_SHUTDOWN",
        "NODE_STARTUP",
        "NODE_RESET",
    )
    # ACL actions carry the hostname of the router they target.
    acl_types = ("NETWORK_ACL_ADDRULE", "NETWORK_ACL_REMOVERULE")
    trailing_types = ("NETWORK_NIC_ENABLE", "NETWORK_NIC_DISABLE")

    actions = [{"type": action_type} for action_type in leading_types]
    actions += [{"type": action_type, "options": {"target_router_hostname": "router"}} for action_type in acl_types]
    actions += [{"type": action_type} for action_type in trailing_types]

    target_nodes = [
        {
            "node_name": "client_1",
            "applications": [{"application_name": "WebBrowser"}],
            "folders": [{"folder_name": "downloads", "files": [{"file_name": "cat.png"}]}],
        },
        {"node_name": "server_1", "services": [{"service_name": "DNSServer"}]},
        {"node_name": "server_2", "services": [{"service_name": "WebServer"}]},
    ]

    action_space = ActionManager(
        actions=actions,
        nodes=target_nodes,
        max_folders_per_node=2,
        max_files_per_folder=2,
        max_services_per_node=2,
        max_applications_per_node=2,
        max_nics_per_node=2,
        max_acl_rules=10,
        protocols=["TCP", "UDP", "ICMP"],
        ports=["HTTP", "DNS", "ARP"],
        ip_address_list=["10.0.1.1", "10.0.1.2", "10.0.2.1", "10.0.2.2", "10.0.2.3"],
        act_map={},
    )

    agent = ControlledAgent(
        agent_name="test_agent",
        action_space=action_space,
        observation_space=ObservationManager(ICSObservation()),
        reward_function=RewardFunction(),
    )
    game.agents.append(agent)

    return game, agent