Files
PrimAITE/tests/conftest.py

514 lines
18 KiB
Python
Raw Normal View History

# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
from typing import Any, Dict, Optional, Tuple
import pytest
2024-05-02 17:00:29 +01:00
import yaml
2024-07-12 11:23:41 +01:00
from ray import init as rayinit
2024-05-02 17:00:29 +01:00
from primaite import getLogger, PRIMAITE_PATHS
2024-02-06 15:05:44 +00:00
from primaite.game.agent.actions import ActionManager
from primaite.game.agent.observations.observation_manager import NestedObservation, ObservationManager
2024-02-06 15:05:44 +00:00
from primaite.game.agent.rewards import RewardFunction
from primaite.game.agent.interface import AbstractAgent
from primaite.game.agent.scripted_agents.probabilistic_agent import ProbabilisticAgent
from primaite.game.game import PrimaiteGame
from primaite.simulator.file_system.file_system import FileSystem
from primaite.simulator.network.container import Network
from primaite.simulator.network.hardware.nodes.host.computer import Computer
from primaite.simulator.network.hardware.nodes.host.server import Server
from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router
from primaite.simulator.network.hardware.nodes.network.switch import Switch
from primaite.simulator.network.networks import arcd_uc2_network
2024-02-06 15:05:44 +00:00
from primaite.simulator.sim_container import Simulation
from primaite.simulator.system.applications.application import Application
2024-02-06 15:05:44 +00:00
from primaite.simulator.system.applications.web_browser import WebBrowser
from primaite.simulator.system.core.sys_log import SysLog
2024-02-06 15:05:44 +00:00
from primaite.simulator.system.services.dns.dns_client import DNSClient
from primaite.simulator.system.services.dns.dns_server import DNSServer
from primaite.simulator.system.services.service import Service
2024-02-06 15:05:44 +00:00
from primaite.simulator.system.services.web_server.web_server import WebServer
from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP
from primaite.utils.validation.port import PORT_LOOKUP
from tests import TEST_ASSETS_ROOT
# Initialise Ray (``ray.init`` with its default arguments) as a side effect of importing this
# conftest, so Ray is set up once when pytest loads this module.
rayinit()
# NOTE(review): neither constant is used in the visible part of this file; presumably individual
# test modules import them - confirm before changing or removing.
ACTION_SPACE_NODE_VALUES = 1
ACTION_SPACE_NODE_ACTION_VALUES = 1
# Module-level logger obtained through PrimAITE's ``getLogger`` helper.
_LOGGER = getLogger(__name__)
class DummyService(Service):
    """Minimal concrete Service used by the test fixtures.

    Whatever ``name``, ``port`` or ``protocol`` the caller passes is overwritten: the service is
    always named ``"DummyService"`` and uses the HTTP port over TCP.
    """

    def __init__(self, **kwargs):
        # Pin the identity of the dummy service regardless of what the caller supplied.
        kwargs.update(
            name="DummyService",
            port=PORT_LOOKUP["HTTP"],
            protocol=PROTOCOL_LOOKUP["TCP"],
        )
        super().__init__(**kwargs)

    def describe_state(self) -> Dict:
        """Return the state description produced by the base ``Service`` unchanged."""
        return super().describe_state()

    def receive(self, payload: Any, session_id: str, **kwargs) -> bool:
        """Discard any incoming payload.

        NOTE(review): despite the ``bool`` annotation this returns ``None`` (falsy).
        """
        return None
2024-07-08 15:26:30 +01:00
class DummyApplication(Application, identifier="DummyApplication"):
    """Minimal concrete Application used by the test fixtures.

    The constructor pins ``name``, ``port`` and ``protocol`` to ``"DummyApplication"``, the HTTP
    port and TCP respectively, overriding any values the caller passes in.
    """

    def describe_state(self) -> Dict:
        """Return the state description produced by the base ``Application`` unchanged."""
        return super().describe_state()

    def __init__(self, **kwargs):
        # Force the dummy application's identity before delegating to Application.__init__.
        kwargs.update(
            name="DummyApplication",
            port=PORT_LOOKUP["HTTP"],
            protocol=PROTOCOL_LOOKUP["TCP"],
        )
        super().__init__(**kwargs)
@pytest.fixture(scope="function")
def uc2_network() -> Network:
    """Build the network described by the ``data_manipulation.yaml`` example config.

    The YAML is read from the user's ``example_config`` directory and turned into a full
    ``PrimaiteGame``; only the game's simulated network is handed to the test.
    """
    config_file = PRIMAITE_PATHS.user_config_path / "example_config" / "data_manipulation.yaml"
    with open(config_file) as stream:
        game_config = yaml.safe_load(stream)
    return PrimaiteGame.from_config(game_config).simulation.network
@pytest.fixture(scope="function")
def service(file_system) -> DummyService:
    """Provide a fresh ``DummyService`` backed by the ``file_system`` fixture.

    NOTE: the ``name`` and ``port`` passed here are overwritten by ``DummyService.__init__``,
    which pins them to ``"DummyService"`` and the HTTP port.
    """
    dummy_log = SysLog(hostname="dummy_service")
    return DummyService(
        name="DummyService",
        port=PORT_LOOKUP["ARP"],
        file_system=file_system,
        sys_log=dummy_log,
    )
@pytest.fixture(scope="function")
def service_class():
    """Expose the ``DummyService`` class itself (not an instance) to tests."""
    dummy_cls = DummyService
    return dummy_cls
@pytest.fixture(scope="function")
def application(file_system) -> DummyApplication:
    """Provide a fresh ``DummyApplication`` backed by the ``file_system`` fixture.

    NOTE: the ``name`` and ``port`` given here are overwritten by ``DummyApplication.__init__``,
    which pins them to ``"DummyApplication"`` and the HTTP port.
    """
    app_kwargs = dict(
        name="DummyApplication",
        port=PORT_LOOKUP["ARP"],
        file_system=file_system,
        sys_log=SysLog(hostname="dummy_application"),
    )
    return DummyApplication(**app_kwargs)
@pytest.fixture(scope="function")
def application_class():
    """Expose the ``DummyApplication`` class itself (not an instance) to tests."""
    dummy_cls = DummyApplication
    return dummy_cls
@pytest.fixture(scope="function")
def file_system() -> FileSystem:
    """Return the file system of a freshly powered-on, standalone computer (``fs_node``)."""
    host = Computer(
        hostname="fs_node",
        ip_address="192.168.1.2",
        subnet_mask="255.255.255.0",
        start_up_duration=0,
    )
    host.power_on()
    return host.file_system
@pytest.fixture(scope="function")
def client_server() -> Tuple[Computer, Server]:
    """Provide a computer and a server cabled directly to each other.

    Both hosts sit on 192.168.1.0/24 (computer .2, server .3, gateway .1) and are powered on
    before being connected NIC 1 to NIC 1. The resulting link is asserted to be up.
    """
    network = Network()

    hosts = []
    for host_cls, name, address in (
        (Computer, "computer", "192.168.1.2"),
        (Server, "server", "192.168.1.3"),
    ):
        host = host_cls(
            hostname=name,
            ip_address=address,
            subnet_mask="255.255.255.0",
            default_gateway="192.168.1.1",
            start_up_duration=0,
        )
        host.power_on()
        hosts.append(host)
    computer, server = hosts

    network.connect(computer.network_interface[1], server.network_interface[1])

    # The direct cable should leave the first link in the network up.
    first_link = next(iter(network.links.values()))
    assert first_link.is_up
    return computer, server
@pytest.fixture(scope="function")
def client_switch_server() -> Tuple[Computer, Switch, Server]:
    """Provide a computer and a server on one /24, both cabled to a single switch.

    Switch port 1 goes to the computer (192.168.1.2) and port 2 to the server (192.168.1.3).
    Every link in the network is asserted to be up before the nodes are returned.
    """
    network = Network()

    hosts = []
    for host_cls, name, address in (
        (Computer, "computer", "192.168.1.2"),
        (Server, "server", "192.168.1.3"),
    ):
        host = host_cls(
            hostname=name,
            ip_address=address,
            subnet_mask="255.255.255.0",
            default_gateway="192.168.1.1",
            start_up_duration=0,
        )
        host.power_on()
        hosts.append(host)
    computer, server = hosts

    switch = Switch(hostname="switch", start_up_duration=0)
    switch.power_on()

    # Computer -> switch port 1, server -> switch port 2.
    for switch_port, host in enumerate(hosts, start=1):
        network.connect(endpoint_a=host.network_interface[1], endpoint_b=switch.network_interface[switch_port])

    assert all(link.is_up for link in network.links.values())
    return computer, switch, server
@pytest.fixture(scope="function")
def example_network() -> Network:
    """
    Create the network used for testing.

    The network contains only nodes and links; tests install whatever services and applications
    they need themselves::

        +----------+                                                        +----------+
        | client_1 |---+                                                +---| server_1 |
        +----------+   |   +----------+   +----------+   +----------+   |   +----------+
                       +---| switch_2 |---| router_1 |---| switch_1 |---+
        +----------+   |   +----------+   +----------+   +----------+   |   +----------+
        | client_2 |---+                                                +---| server_2 |
        +----------+                                                        +----------+

    - router_1 port 1 (192.168.1.1/24) uplinks to switch_1 port 8; servers .10 and .14 sit on
      switch_1 ports 1 and 2.
    - router_1 port 2 (192.168.10.1/24) uplinks to switch_2 port 8; clients .21 and .22 sit on
      switch_2 ports 1 and 2.
    - router_1 gets a single PERMIT ACL rule at position 1 (no other match criteria given).
    """
    network = Network()

    router_1 = Router(hostname="router_1", start_up_duration=0)
    router_1.power_on()
    router_1.configure_port(port=1, ip_address="192.168.1.1", subnet_mask="255.255.255.0")
    router_1.configure_port(port=2, ip_address="192.168.10.1", subnet_mask="255.255.255.0")

    # Each switch hangs off the router port with the same number, using switch port 8 as uplink.
    switches = {}
    for router_port, switch_name in ((1, "switch_1"), (2, "switch_2")):
        new_switch = Switch(hostname=switch_name, num_ports=8, start_up_duration=0)
        new_switch.power_on()
        network.connect(
            endpoint_a=router_1.network_interface[router_port],
            endpoint_b=new_switch.network_interface[8],
        )
        router_1.enable_port(router_port)
        switches[switch_name] = new_switch

    # (class, hostname, ip, gateway, switch, switch port) - created and cabled in this order.
    host_specs = (
        (Computer, "client_1", "192.168.10.21", "192.168.10.1", "switch_2", 1),
        (Computer, "client_2", "192.168.10.22", "192.168.10.1", "switch_2", 2),
        (Server, "server_1", "192.168.1.10", "192.168.1.1", "switch_1", 1),
        (Server, "server_2", "192.168.1.14", "192.168.1.1", "switch_1", 2),
    )
    for host_cls, hostname, address, gateway, switch_name, switch_port in host_specs:
        host = host_cls(
            hostname=hostname,
            ip_address=address,
            subnet_mask="255.255.255.0",
            default_gateway=gateway,
            start_up_duration=0,
        )
        host.power_on()
        network.connect(
            endpoint_a=switches[switch_name].network_interface[switch_port],
            endpoint_b=host.network_interface[1],
        )

    router_1.acl.add_rule(action=ACLAction.PERMIT, position=1)

    assert all(link.is_up for link in network.links.values())
    return network
2024-02-06 15:05:44 +00:00
class ControlledAgent(AbstractAgent, identifier="Controlled_Agent"):
    """Agent that can be controlled by the tests.

    A test hands the agent an action through ``store_action``; ``get_action`` then returns that
    stored action on every call, ignoring the observation, until a new action is stored.
    """

    config: "ControlledAgent.ConfigSchema"
    # Last action given to store_action; None until a test stores one.
    most_recent_action: Optional[Tuple[str, Dict]] = None

    class ConfigSchema(AbstractAgent.ConfigSchema):
        """Configuration schema for ControlledAgent, the test-controlled agent."""

        agent_name: str = "Controlled_Agent"

    def get_action(self, obs: Any, timestep: int = 0) -> Optional[Tuple[str, Dict]]:
        """Return the agent's most recent action, formatted in CAOS format.

        :param obs: Ignored.
        :param timestep: Ignored.
        :return: The action last passed to ``store_action``, or ``None`` if none has been stored.
        """
        return self.most_recent_action

    def store_action(self, action: Tuple[str, Dict]) -> None:
        """Store ``action`` so that subsequent ``get_action`` calls return it."""
        self.most_recent_action = action
def install_stuff_to_sim(sim: Simulation) -> Simulation:
    """Populate ``sim`` with a router, two switches, a computer and two servers, then verify it.

    Layout built into ``sim.network`` (all masks 255.255.255.0):

    - ``router`` (3 ports): port 1 = 10.0.1.1, port 2 = 10.0.2.1.
    - ``switch_1`` / ``switch_2`` (6 ports each) uplink from their port 6 to router port 1 / 2.
    - ``client_1`` (10.0.1.2) on switch_1 port 1, given a ``downloads/cat.png`` file (size 300).
    - ``server_1`` (10.0.2.2) on switch_2 port 1, running a DNSServer that maps
      ``www.example.com`` to server_2.
    - ``server_2`` (10.0.2.3) on switch_2 port 2, running a WebServer.
    - Router ACL PERMIT rules: DNS ports at position 1, HTTP ports at 3, ARP ports at 22,
      ICMP protocol at 23.
    - The DNSClients on client_1 and server_2 are pointed at server_1.

    :param sim: Simulation whose network is populated in place.
    :return: The same ``sim`` object.
    :raises AssertionError: If pre-installed software is missing or the resulting state differs
        from the layout above.
    """
    # 0: Pull out the network
    network = sim.network

    def _find_node(hostname: str):
        """Return the first node in ``sim.network`` with this hostname, asserting one exists."""
        match = next((node for node in sim.network.nodes.values() if node.hostname == hostname), None)
        assert match is not None, f"no node with hostname {hostname!r} in the simulation"
        return match

    # 1: Set up network hardware
    # 1.1: Configure the router
    router = Router(hostname="router", num_ports=3, start_up_duration=0)
    router.power_on()
    router.configure_port(port=1, ip_address="10.0.1.1", subnet_mask="255.255.255.0")
    router.configure_port(port=2, ip_address="10.0.2.1", subnet_mask="255.255.255.0")

    # 1.2: Create and connect switches (switch port 6 is the uplink to the router)
    switch_1 = Switch(hostname="switch_1", num_ports=6, start_up_duration=0)
    switch_1.power_on()
    network.connect(endpoint_a=router.network_interface[1], endpoint_b=switch_1.network_interface[6])
    router.enable_port(1)

    switch_2 = Switch(hostname="switch_2", num_ports=6, start_up_duration=0)
    switch_2.power_on()
    network.connect(endpoint_a=router.network_interface[2], endpoint_b=switch_2.network_interface[6])
    router.enable_port(2)

    # 1.3: Create and connect computer
    client_1 = Computer(
        hostname="client_1",
        ip_address="10.0.1.2",
        subnet_mask="255.255.255.0",
        default_gateway="10.0.1.1",
        start_up_duration=0,
    )
    client_1.power_on()
    network.connect(
        endpoint_a=client_1.network_interface[1],
        endpoint_b=switch_1.network_interface[1],
    )

    # 1.4: Create and connect servers
    server_1 = Server(
        hostname="server_1",
        ip_address="10.0.2.2",
        subnet_mask="255.255.255.0",
        default_gateway="10.0.2.1",
        start_up_duration=0,
    )
    server_1.power_on()
    network.connect(endpoint_a=server_1.network_interface[1], endpoint_b=switch_2.network_interface[1])

    server_2 = Server(
        hostname="server_2",
        ip_address="10.0.2.3",
        subnet_mask="255.255.255.0",
        default_gateway="10.0.2.1",
        start_up_duration=0,
    )
    server_2.power_on()
    network.connect(endpoint_a=server_2.network_interface[1], endpoint_b=switch_2.network_interface[2])

    # 2: Configure base ACL. Rules are added out of order; the check in 5.1 expects each rule
    # to end up at the index given by its ``position``.
    router.acl.add_rule(action=ACLAction.PERMIT, src_port=PORT_LOOKUP["ARP"], dst_port=PORT_LOOKUP["ARP"], position=22)
    router.acl.add_rule(action=ACLAction.PERMIT, protocol=PROTOCOL_LOOKUP["ICMP"], position=23)
    router.acl.add_rule(action=ACLAction.PERMIT, src_port=PORT_LOOKUP["DNS"], dst_port=PORT_LOOKUP["DNS"], position=1)
    router.acl.add_rule(action=ACLAction.PERMIT, src_port=PORT_LOOKUP["HTTP"], dst_port=PORT_LOOKUP["HTTP"], position=3)

    # 3: Install server software
    server_1.software_manager.install(DNSServer)
    dns_service: DNSServer = server_1.software_manager.software.get("DNSServer")  # noqa
    dns_service.dns_register("www.example.com", server_2.network_interface[1].ip_address)
    server_2.software_manager.install(WebServer)

    # 4: Check that the hosts came with the expected pre-installed software. This happens BEFORE
    # the DNS clients are configured, so a missing client fails with a clear assertion rather
    # than an AttributeError on ``None``.
    client_dns_client = client_1.software_manager.software.get("DNSClient")
    assert isinstance(client_1.software_manager.software.get("WebBrowser"), WebBrowser)
    assert isinstance(client_dns_client, DNSClient)
    server_2_dns_client = server_2.software_manager.software.get("DNSClient")
    assert server_2_dns_client is not None, "server_2 has no pre-installed DNSClient"

    # 4.1: Ensure that the dns clients are configured to use server_1
    dns_server_ip = server_1.network_interface[1].ip_address
    client_dns_client.dns_server = dns_server_ip
    server_2_dns_client.dns_server = dns_server_ip

    # 4.2: Create a file on the computer
    client_1.file_system.create_file("cat.png", 300, folder_name="downloads")

    # 5: Assert that the simulation starts off in the state that we expect
    assert len(sim.network.nodes) == 6
    assert len(sim.network.links) == 5

    # 5.1: Assert the router is correctly configured
    r = sim.network.router_nodes[0]
    for i, acl_rule in enumerate(r.acl.acl):
        if i == 1:
            assert acl_rule.src_port == acl_rule.dst_port == PORT_LOOKUP["DNS"]
        elif i == 3:
            assert acl_rule.src_port == acl_rule.dst_port == PORT_LOOKUP["HTTP"]
        elif i == 22:
            assert acl_rule.src_port == acl_rule.dst_port == PORT_LOOKUP["ARP"]
        elif i == 23:
            assert acl_rule.protocol == PROTOCOL_LOOKUP["ICMP"]
        elif i == 24:
            # NOTE(review): position 24 is deliberately left unchecked; presumably it is reserved
            # by the router's ACL - confirm.
            ...
        else:
            assert acl_rule is None

    # 5.2: Assert the client is correctly configured
    c: Computer = _find_node("client_1")
    assert c.software_manager.software.get("WebBrowser") is not None
    assert c.software_manager.software.get("DNSClient") is not None
    assert str(c.network_interface[1].ip_address) == "10.0.1.2"

    # 5.3: Assert that server_1 is correctly configured
    s1: Server = _find_node("server_1")
    assert str(s1.network_interface[1].ip_address) == "10.0.2.2"
    assert s1.software_manager.software.get("DNSServer") is not None

    # 5.4: Assert that server_2 is correctly configured
    s2: Server = _find_node("server_2")
    assert str(s2.network_interface[1].ip_address) == "10.0.2.3"
    assert s2.software_manager.software.get("WebServer") is not None

    # 6: Return the simulation
    return sim
@pytest.fixture
def game_and_agent():
    """Create a game with a simple agent that can be controlled by the tests.

    The game's simulation is populated by ``install_stuff_to_sim``. A ``ControlledAgent`` named
    ``"test_agent"`` is registered on the game. The agent's action manager offers every action
    type listed below over client_1, server_1, server_2 and the router. Its observation manager
    wraps an empty ``NestedObservation``, and its ``RewardFunction`` is built with no arguments.

    :return: ``(game, test_agent)``.
    """
    game = PrimaiteGame()
    install_stuff_to_sim(game.simulation)

    action_types = (
        "DONOTHING",
        "NODE_SERVICE_SCAN",
        "NODE_SERVICE_STOP",
        "NODE_SERVICE_START",
        "NODE_SERVICE_PAUSE",
        "NODE_SERVICE_RESUME",
        "NODE_SERVICE_RESTART",
        "NODE_SERVICE_DISABLE",
        "NODE_SERVICE_ENABLE",
        "NODE_SERVICE_FIX",
        "NODE_APPLICATION_EXECUTE",
        "NODE_APPLICATION_SCAN",
        "NODE_APPLICATION_CLOSE",
        "NODE_APPLICATION_FIX",
        "NODE_APPLICATION_INSTALL",
        "NODE_APPLICATION_REMOVE",
        "NODE_FILE_CREATE",
        "NODE_FILE_SCAN",
        "NODE_FILE_CHECKHASH",
        "NODE_FILE_DELETE",
        "NODE_FILE_REPAIR",
        "NODE_FILE_RESTORE",
        "NODE_FILE_CORRUPT",
        "NODE_FILE_ACCESS",
        "NODE_FOLDER_CREATE",
        "NODE_FOLDER_SCAN",
        "NODE_FOLDER_CHECKHASH",
        "NODE_FOLDER_REPAIR",
        "NODE_FOLDER_RESTORE",
        "NODE_OS_SCAN",
        "NODE_SHUTDOWN",
        "NODE_STARTUP",
        "NODE_RESET",
        "ROUTER_ACL_ADDRULE",
        "ROUTER_ACL_REMOVERULE",
        "HOST_NIC_ENABLE",
        "HOST_NIC_DISABLE",
        "NETWORK_PORT_ENABLE",
        "NETWORK_PORT_DISABLE",
        "CONFIGURE_C2_BEACON",
        "C2_SERVER_RANSOMWARE_LAUNCH",
        "C2_SERVER_RANSOMWARE_CONFIGURE",
        "C2_SERVER_TERMINAL_COMMAND",
        "C2_SERVER_DATA_EXFILTRATE",
        "NODE_ACCOUNTS_CHANGE_PASSWORD",
        "SSH_TO_REMOTE",
        "SESSIONS_REMOTE_LOGOFF",
        "NODE_SEND_REMOTE_COMMAND",
    )
    actions = [{"type": action_type} for action_type in action_types]  # ALL POSSIBLE ACTIONS

    # Nodes (and their software/folders) that the action manager can address.
    node_specs = [
        {
            "node_name": "client_1",
            "applications": [
                {"application_name": "WebBrowser"},
                {"application_name": "DoSBot"},
                {"application_name": "C2Server"},
            ],
            "folders": [{"folder_name": "downloads", "files": [{"file_name": "cat.png"}]}],
        },
        {
            "node_name": "server_1",
            "services": [{"service_name": "DNSServer"}],
            "applications": [{"application_name": "C2Beacon"}],
        },
        {"node_name": "server_2", "services": [{"service_name": "WebServer"}]},
        {"node_name": "router"},
    ]

    action_space = ActionManager(
        actions=actions,
        nodes=node_specs,
        max_folders_per_node=2,
        max_files_per_folder=2,
        max_services_per_node=2,
        max_applications_per_node=3,
        max_nics_per_node=2,
        max_acl_rules=10,
        protocols=["TCP", "UDP", "ICMP"],
        ports=["HTTP", "DNS", "ARP"],
        ip_list=["10.0.1.1", "10.0.1.2", "10.0.2.1", "10.0.2.2", "10.0.2.3"],
        act_map={},
    )
    observation_space = ObservationManager(NestedObservation(components={}))
    reward_function = RewardFunction()

    agent_config = {
        "agent_name": "test_agent",
        "action_manager": action_space,
        "observation_manager": observation_space,
        "reward_function": reward_function,
    }
    test_agent = ControlledAgent.from_config(config=agent_config)
    game.agents["test_agent"] = test_agent

    game.setup_reward_sharing()
    return game, test_agent