diff --git a/src/primaite/game/agent/actions.py b/src/primaite/game/agent/actions.py index 9b6b63cc..d65cd8d0 100644 --- a/src/primaite/game/agent/actions.py +++ b/src/primaite/game/agent/actions.py @@ -549,7 +549,7 @@ class NetworkNICDisableAction(NetworkNICAbstractAction): class ActionManager: """Class which manages the action space for an agent.""" - __act_class_identifiers: Dict[str, type] = { + _act_class_identifiers: Dict[str, type] = { "DONOTHING": DoNothingAction, "NODE_SERVICE_SCAN": NodeServiceScanAction, "NODE_SERVICE_STOP": NodeServiceStopAction, @@ -584,7 +584,7 @@ class ActionManager: def __init__( self, game: "PrimaiteGame", # reference to game for information lookup - actions: List[str], # stores list of actions available to agent + actions: List[Dict], # stores list of actions available to agent nodes: List[Dict], # extra configuration for each node max_folders_per_node: int = 2, # allows calculating shape max_files_per_folder: int = 2, # allows calculating shape @@ -601,8 +601,9 @@ class ActionManager: :param game: Reference to the game to which the agent belongs. :type game: PrimaiteGame - :param actions: List of action types which should be made available to the agent. - :type actions: List[str] + :param actions: List of action specs which should be made available to the agent. The keys of each spec are: + 'type' and 'options' for passing any options to the action class's init method + :type actions: List[dict] :param nodes: Extra configuration for each node. :type nodes: Dict :param max_folders_per_node: Maximum number of folders per node. Used for calculating action shape. @@ -728,7 +729,7 @@ class ActionManager: # and `options` is an optional dict of options to pass to the init method of the action class act_type = act_spec.get("type") act_options = act_spec.get("options", {}) - self.actions[act_type] = self.__act_class_identifiers[act_type](self, **global_action_args, **act_options) + self.actions[act_type] = self._act_class_identifiers[act_type](self, **global_action_args, **act_options) self.action_map: Dict[int, Tuple[str, Dict]] = {} """ diff --git a/tests/integration_tests/game_layer/test_actions.py b/tests/integration_tests/game_layer/test_actions.py new file mode 100644 index 00000000..37a680c8 --- /dev/null +++ b/tests/integration_tests/game_layer/test_actions.py @@ -0,0 +1,185 @@ +# Plan for creating integration tests for the actions: +# I need to test that the requests coming out of the actions have the intended effect on the simulation. +# I can do this by creating a simulation, and then running the action on the simulation, and then checking +# the state of the simulation. + +# Steps for creating the integration tests: +# 1. Create a fixture which creates a simulation. +# 2. Create a fixture which creates a game, including a simple agent with some actions. +# 3. Get the agent to perform an action of my choosing. +# 4. Check that the simulation has changed in the way that I expect. +# 5. Repeat for all actions. + +import pytest + +from primaite.game.agent.actions import ActionManager +from primaite.game.agent.interface import ProxyAgent +from primaite.game.agent.observations import ObservationManager +from primaite.game.agent.rewards import RewardFunction +from primaite.game.game import PrimaiteGame +from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState +from primaite.simulator.network.hardware.nodes.computer import Computer +from primaite.simulator.network.hardware.nodes.router import ACLAction, Router +from primaite.simulator.network.hardware.nodes.server import Server +from primaite.simulator.network.hardware.nodes.switch import Switch +from primaite.simulator.network.transmission.network_layer import IPProtocol +from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.sim_container import Simulation +from primaite.simulator.system.applications.web_browser import WebBrowser +from primaite.simulator.system.services.dns.dns_client import DNSClient +from primaite.simulator.system.services.dns.dns_server import DNSServer +from primaite.simulator.system.services.web_server.web_server import WebServer + + +def install_stuff_to_sim(sim: Simulation): + """Create a simulation with a three computers, two switches, and a router.""" + + # 0: Pull out the network + network = sim.network + + # 1: Set up network hardware + # 1.1: Configure the router + router = Router(hostname="router", num_ports=3) + router.power_on() + router.configure_port(port=1, ip_address="10.0.1.1", subnet_mask="255.255.255.0") + router.configure_port(port=2, ip_address="10.0.2.1", subnet_mask="255.255.255.0") + + # 1.2: Create and connect switches + switch_1 = Switch(hostname="switch_1", num_ports=6) + switch_1.power_on() + network.connect(endpoint_a=router.ethernet_ports[1], endpoint_b=switch_1.switch_ports[6]) + router.enable_port(1) + switch_2 = Switch(hostname="switch_2", num_ports=6) + switch_2.power_on() + network.connect(endpoint_a=router.ethernet_ports[2], endpoint_b=switch_2.switch_ports[6]) + router.enable_port(2) + + # 1.3: Create and connect computer + client_1 = Computer( + hostname="client_1", + ip_address="10.0.1.2", + subnet_mask="255.255.255.0", + default_gateway="10.0.1.1", + operating_state=NodeOperatingState.ON, + ) + client_1.power_on() + network.connect( + endpoint_a=client_1.ethernet_port[1], + endpoint_b=switch_1.switch_ports[1], + ) + + # 1.4: Create and connect servers + server_1 = Server( + hostname="server_1", + ip_address="10.0.2.2", + subnet_mask="255.255.255.0", + default_gateway="10.0.2.1", + operating_state=NodeOperatingState.ON, + ) + server_1.power_on() + network.connect(endpoint_a=server_1.ethernet_port[1], endpoint_b=switch_2.switch_ports[1]) + + server_2 = Server( + hostname="server_2", + ip_address="10.0.2.3", + subnet_mask="255.255.255.0", + default_gateway="10.0.2.1", + operating_state=NodeOperatingState.ON, + ) + server_2.power_on() + network.connect(endpoint_a=server_2.ethernet_port[1], endpoint_b=switch_2.switch_ports[2]) + + # 2: Configure base ACL + router.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.ARP, dst_port=Port.ARP, position=22) + router.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol.ICMP, position=23) + router.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.DNS, dst_port=Port.DNS, position=1) + router.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.HTTP, dst_port=Port.HTTP, position=3) + + # 3: Install server software + server_1.software_manager.install(DNSServer) + dns_service: DNSServer = server_1.software_manager.software.get("DNSServer") # noqa + dns_service.dns_register("example.com", server_2.ip_address) + server_2.software_manager.install(WebServer) + + # 4: Check that client came pre-installed with web browser and dns client + assert isinstance(client_1.software_manager.software.get("WebBrowser"), WebBrowser) + assert isinstance(client_1.software_manager.software.get("DNSClient"), DNSClient) + + # 5: Return the simulation + return sim + + +@pytest.fixture +def game(): + """Create a game with a simple agent that can be controlled by the tests.""" + game = PrimaiteGame() + sim = game.simulation + install_stuff_to_sim(sim) + + actions = [ + {"type": "DONOTHING"}, + {"type": "NODE_SERVICE_SCAN"}, + {"type": "NODE_SERVICE_STOP"}, + {"type": "NODE_SERVICE_START"}, + {"type": "NODE_SERVICE_PAUSE"}, + {"type": "NODE_SERVICE_RESUME"}, + {"type": "NODE_SERVICE_RESTART"}, + {"type": "NODE_SERVICE_DISABLE"}, + {"type": "NODE_SERVICE_ENABLE"}, + {"type": "NODE_APPLICATION_EXECUTE"}, + {"type": "NODE_FILE_SCAN"}, + {"type": "NODE_FILE_CHECKHASH"}, + {"type": "NODE_FILE_DELETE"}, + {"type": "NODE_FILE_REPAIR"}, + {"type": "NODE_FILE_RESTORE"}, + {"type": "NODE_FILE_CORRUPT"}, + {"type": "NODE_FOLDER_SCAN"}, + {"type": "NODE_FOLDER_CHECKHASH"}, + {"type": "NODE_FOLDER_REPAIR"}, + {"type": "NODE_FOLDER_RESTORE"}, + {"type": "NODE_OS_SCAN"}, + {"type": "NODE_SHUTDOWN"}, + {"type": "NODE_STARTUP"}, + {"type": "NODE_RESET"}, + {"type": "NETWORK_ACL_ADDRULE", "options": {"target_router_hostname": "router"}}, + {"type": "NETWORK_ACL_REMOVERULE", "options": {"target_router_hostname": "router"}}, + {"type": "NETWORK_NIC_ENABLE"}, + {"type": "NETWORK_NIC_DISABLE"}, + ] + + action_space = ActionManager( + game=game, + actions=actions, # ALL POSSIBLE ACTIONS + nodes=[ + {"node_name": "client_1", "applications": [{"application_name": "WebBrowser"}]}, + {"node_name": "server_1", "services": [{"service_name": "DNSServer"}]}, + {"node_name": "server_2", "services": [{"service_name": "WebServer"}]}, + ], + max_folders_per_node=2, + max_files_per_folder=2, + max_services_per_node=2, + max_applications_per_node=2, + max_nics_per_node=2, + max_acl_rules=10, + protocols=["TCP", "UDP", "ICMP"], + ports=["HTTP", "DNS", "ARP"], + ip_address_list=["10.0.1.1", "10.0.1.2", "10.0.2.1", "10.0.2.2", "10.0.2.3"], + act_map={}, + ) + observation_space = None + reward_function = None + + test_agent = ProxyAgent( + agent_name="test_agent", + action_space=action_space, + observation_space=observation_space, + reward_function=reward_function, + ) + + game.agents.append(test_agent) + + return game, test_agent + + +def test_test(game): + assert True diff --git a/tests/unit_tests/_primaite/_game/__init__.py b/tests/unit_tests/_primaite/_game/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit_tests/_primaite/_game/_agent/__init__.py b/tests/unit_tests/_primaite/_game/_agent/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit_tests/_primaite/_game/_agent/test_actions.py b/tests/unit_tests/_primaite/_game/_agent/test_actions.py new file mode 100644 index 00000000..9b641fe2 --- /dev/null +++ b/tests/unit_tests/_primaite/_game/_agent/test_actions.py @@ -0,0 +1,90 @@ +from unittest.mock import Mock + +import pytest + +from primaite.game.agent.actions import ( + ActionManager, + DoNothingAction, + NodeServiceDisableAction, + NodeServiceEnableAction, + NodeServicePauseAction, + NodeServiceRestartAction, + NodeServiceResumeAction, + NodeServiceScanAction, + NodeServiceStartAction, + NodeServiceStopAction, +) + + +def test_do_nothing_action_form_request(): + """Test that the DoNothingAction can form a request and that it is correct.""" + manager = Mock() + + action = DoNothingAction(manager=manager) + + request = action.form_request() + + assert request == ["do_nothing"] + + +@pytest.mark.parametrize( + "action_class, action_verb", + [ + (NodeServiceScanAction, "scan"), + (NodeServiceStopAction, "stop"), + (NodeServiceStartAction, "start"), + (NodeServicePauseAction, "pause"), + (NodeServiceResumeAction, "resume"), + (NodeServiceRestartAction, "restart"), + (NodeServiceDisableAction, "disable"), + (NodeServiceEnableAction, "enable"), + ], +) # flake8: noqa +@pytest.mark.parametrize( + "node_name, service_name, expect_to_do_nothing", + [ + ("pc_1", "chrome", False), + (None, "chrome", True), + ("pc_1", None, True), + (None, None, True), + ], +) # flake8: noqa +def test_service_action_form_request(node_name, service_name, expect_to_do_nothing, action_class, action_verb): + """Test that the ServiceScanAction can form a request and that it is correct.""" + manager: ActionManager = Mock() + manager.get_node_name_by_idx.return_value = node_name + manager.get_service_name_by_idx.return_value = service_name + + action = action_class(manager=manager, num_nodes=1, num_services=1) + + request = action.form_request(node_id=0, service_id=0) + + if expect_to_do_nothing: + assert request == ["do_nothing"] + else: + assert request == ["network", "node", node_name, "services", service_name, action_verb] + + +@pytest.mark.parametrize( + "node_name, service_name, expect_to_do_nothing", + [ + ("pc_1", "chrome", False), + (None, "chrome", True), + ("pc_1", None, True), + (None, None, True), + ], +) # flake8: noqa +def test_service_scan_form_request(node_name, service_name, expect_to_do_nothing): + """Test that the ServiceScanAction can form a request and that it is correct.""" + manager: ActionManager = Mock() + manager.get_node_name_by_idx.return_value = node_name + manager.get_service_name_by_idx.return_value = service_name + + action = NodeServiceScanAction(manager=manager, num_nodes=1, num_services=1) + + request = action.form_request(node_id=0, service_id=0) + + if expect_to_do_nothing: + assert request == ["do_nothing"] + else: + assert request == ["network", "node", node_name, "services", service_name, "scan"]