#2248 - synced wth dev

This commit is contained in:
Chris McCarthy
2024-02-08 16:15:57 +00:00
20 changed files with 607 additions and 804 deletions

View File

@@ -35,238 +35,6 @@ from primaite.simulator.system.services.web_server.web_server import WebServer
from primaite.simulator.system.software import SoftwareHealthState
class ControlledAgent(AbstractAgent):
"""Agent that can be controlled by the tests."""
def __init__(
self,
agent_name: str,
action_space: ActionManager,
observation_space: ObservationManager,
reward_function: RewardFunction,
) -> None:
super().__init__(
agent_name=agent_name,
action_space=action_space,
observation_space=observation_space,
reward_function=reward_function,
)
self.most_recent_action: Tuple[str, Dict]
def get_action(self, obs: None, reward: float = 0.0) -> Tuple[str, Dict]:
"""Return the agent's most recent action, formatted in CAOS format."""
return self.most_recent_action
def store_action(self, action: Tuple[str, Dict]):
"""Store the most recent action."""
self.most_recent_action = action
def install_stuff_to_sim(sim: Simulation):
"""Create a simulation with a computer, two servers, two switches, and a router."""
# 0: Pull out the network
network = sim.network
# 1: Set up network hardware
# 1.1: Configure the router
router = Router(hostname="router", num_ports=3, start_up_duration=0)
router.power_on()
router.configure_port(port=1, ip_address="10.0.1.1", subnet_mask="255.255.255.0")
router.configure_port(port=2, ip_address="10.0.2.1", subnet_mask="255.255.255.0")
# 1.2: Create and connect switches
switch_1 = Switch(hostname="switch_1", num_ports=6, start_up_duration=0)
switch_1.power_on()
network.connect(endpoint_a=router.network_interface[1], endpoint_b=switch_1.network_interface[6])
router.enable_port(1)
switch_2 = Switch(
hostname="switch_2",
num_ports=6,
start_up_duration=0,
)
switch_2.power_on()
network.connect(endpoint_a=router.network_interface[2], endpoint_b=switch_2.network_interface[6])
router.enable_port(2)
# 1.3: Create and connect computer
client_1 = Computer(
hostname="client_1",
ip_address="10.0.1.2",
subnet_mask="255.255.255.0",
default_gateway="10.0.1.1",
start_up_duration=0,
)
client_1.power_on()
network.connect(
endpoint_a=client_1.network_interface[1],
endpoint_b=switch_1.network_interface[1],
)
# 1.4: Create and connect servers
server_1 = Server(
hostname="server_1",
ip_address="10.0.2.2",
subnet_mask="255.255.255.0",
default_gateway="10.0.2.1",
start_up_duration=0,
)
server_1.power_on()
network.connect(endpoint_a=server_1.network_interface[1], endpoint_b=switch_2.network_interface[1])
server_2 = Server(
hostname="server_2",
ip_address="10.0.2.3",
subnet_mask="255.255.255.0",
default_gateway="10.0.2.1",
start_up_duration=0,
)
server_2.power_on()
network.connect(endpoint_a=server_2.network_interface[1], endpoint_b=switch_2.network_interface[2])
# 2: Configure base ACL
router.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.ARP, dst_port=Port.ARP, position=22)
router.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol.ICMP, position=23)
router.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.DNS, dst_port=Port.DNS, position=1)
router.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.HTTP, dst_port=Port.HTTP, position=3)
# 3: Install server software
server_1.software_manager.install(DNSServer)
dns_service: DNSServer = server_1.software_manager.software.get("DNSServer") # noqa
dns_service.dns_register("www.example.com", server_2.network_interface[1].ip_address)
server_2.software_manager.install(WebServer)
# 3.1: Ensure that the dns clients are configured correctly
client_1.software_manager.software.get("DNSClient").dns_server = server_1.network_interface[1].ip_address
server_2.software_manager.software.get("DNSClient").dns_server = server_1.network_interface[1].ip_address
# 4: Check that client came pre-installed with web browser and dns client
assert isinstance(client_1.software_manager.software.get("WebBrowser"), WebBrowser)
assert isinstance(client_1.software_manager.software.get("DNSClient"), DNSClient)
# 4.1: Create a file on the computer
client_1.file_system.create_file("cat.png", 300, folder_name="downloads")
# 5: Assert that the simulation starts off in the state that we expect
assert len(sim.network.nodes) == 6
assert len(sim.network.links) == 5
# 5.1: Assert the router is correctly configured
r = sim.network.routers[0]
for i, acl_rule in enumerate(r.acl.acl):
if i == 1:
assert acl_rule.src_port == acl_rule.dst_port == Port.DNS
elif i == 3:
assert acl_rule.src_port == acl_rule.dst_port == Port.HTTP
elif i == 22:
assert acl_rule.src_port == acl_rule.dst_port == Port.ARP
elif i == 23:
assert acl_rule.protocol == IPProtocol.ICMP
elif i == 24:
...
else:
assert acl_rule is None
# 5.2: Assert the client is correctly configured
c: Computer = [node for node in sim.network.nodes.values() if node.hostname == "client_1"][0]
assert c.software_manager.software.get("WebBrowser") is not None
assert c.software_manager.software.get("DNSClient") is not None
assert str(c.network_interface[1].ip_address) == "10.0.1.2"
# 5.3: Assert that server_1 is correctly configured
s1: Server = [node for node in sim.network.nodes.values() if node.hostname == "server_1"][0]
assert str(s1.network_interface[1].ip_address) == "10.0.2.2"
assert s1.software_manager.software.get("DNSServer") is not None
# 5.4: Assert that server_2 is correctly configured
s2: Server = [node for node in sim.network.nodes.values() if node.hostname == "server_2"][0]
assert str(s2.network_interface[1].ip_address) == "10.0.2.3"
assert s2.software_manager.software.get("WebServer") is not None
# 6: Return the simulation
return sim
@pytest.fixture
def game_and_agent():
"""Create a game with a simple agent that can be controlled by the tests."""
game = PrimaiteGame()
sim = game.simulation
install_stuff_to_sim(sim)
actions = [
{"type": "DONOTHING"},
{"type": "NODE_SERVICE_SCAN"},
{"type": "NODE_SERVICE_STOP"},
{"type": "NODE_SERVICE_START"},
{"type": "NODE_SERVICE_PAUSE"},
{"type": "NODE_SERVICE_RESUME"},
{"type": "NODE_SERVICE_RESTART"},
{"type": "NODE_SERVICE_DISABLE"},
{"type": "NODE_SERVICE_ENABLE"},
{"type": "NODE_SERVICE_PATCH"},
{"type": "NODE_APPLICATION_EXECUTE"},
{"type": "NODE_FILE_SCAN"},
{"type": "NODE_FILE_CHECKHASH"},
{"type": "NODE_FILE_DELETE"},
{"type": "NODE_FILE_REPAIR"},
{"type": "NODE_FILE_RESTORE"},
{"type": "NODE_FILE_CORRUPT"},
{"type": "NODE_FOLDER_SCAN"},
{"type": "NODE_FOLDER_CHECKHASH"},
{"type": "NODE_FOLDER_REPAIR"},
{"type": "NODE_FOLDER_RESTORE"},
{"type": "NODE_OS_SCAN"},
{"type": "NODE_SHUTDOWN"},
{"type": "NODE_STARTUP"},
{"type": "NODE_RESET"},
{"type": "NETWORK_ACL_ADDRULE", "options": {"target_router_hostname": "router"}},
{"type": "NETWORK_ACL_REMOVERULE", "options": {"target_router_hostname": "router"}},
{"type": "NETWORK_NIC_ENABLE"},
{"type": "NETWORK_NIC_DISABLE"},
]
action_space = ActionManager(
game=game,
actions=actions, # ALL POSSIBLE ACTIONS
nodes=[
{
"node_name": "client_1",
"applications": [{"application_name": "WebBrowser"}],
"folders": [{"folder_name": "downloads", "files": [{"file_name": "cat.png"}]}],
},
{"node_name": "server_1", "services": [{"service_name": "DNSServer"}]},
{"node_name": "server_2", "services": [{"service_name": "WebServer"}]},
],
max_folders_per_node=2,
max_files_per_folder=2,
max_services_per_node=2,
max_applications_per_node=2,
max_nics_per_node=2,
max_acl_rules=10,
protocols=["TCP", "UDP", "ICMP"],
ports=["HTTP", "DNS", "ARP"],
ip_address_list=["10.0.1.1", "10.0.1.2", "10.0.2.1", "10.0.2.2", "10.0.2.3"],
act_map={},
)
observation_space = ObservationManager(ICSObservation())
reward_function = RewardFunction()
test_agent = ControlledAgent(
agent_name="test_agent",
action_space=action_space,
observation_space=observation_space,
reward_function=reward_function,
)
game.agents.append(test_agent)
return (game, test_agent)
# def test_test(game_and_agent:Tuple[PrimaiteGame, ProxyAgent]):
# game, agent = game_and_agent
def test_do_nothing_integration(game_and_agent: Tuple[PrimaiteGame, ProxyAgent]):
"""Test that the DoNothingAction can form a request and that it is accepted by the simulation."""
game, agent = game_and_agent