diff --git a/src/primaite/config/_package_data/example_config.yaml b/src/primaite/config/_package_data/example_config.yaml index 3cea2f29..b68861e1 100644 --- a/src/primaite/config/_package_data/example_config.yaml +++ b/src/primaite/config/_package_data/example_config.yaml @@ -1,5 +1,5 @@ training_config: - rl_framework: RLLIB_single_agent + rl_framework: SB3 rl_algorithm: PPO seed: 333 n_learn_episodes: 1 @@ -36,22 +36,16 @@ agents: action_space: action_list: - type: DONOTHING - # - # - type: NODE_LOGON - # - type: NODE_LOGOFF - # - type: NODE_APPLICATION_EXECUTE - # options: - # execution_definition: - # target_address: arcd.com - + - type: NODE_APPLICATION_EXECUTE options: nodes: - node_ref: client_2 + applications: + - application_ref: client_2_web_browser max_folders_per_node: 1 max_files_per_folder: 1 max_services_per_node: 1 - max_nics_per_node: 2 - max_acl_rules: 10 + max_applications_per_node: 1 reward_function: reward_components: @@ -549,19 +543,19 @@ simulation: ip_address: 192.168.10.1 subnet_mask: 255.255.255.0 acl: - 0: + 18: action: PERMIT src_port: POSTGRES_SERVER dst_port: POSTGRES_SERVER - 1: + 19: action: PERMIT src_port: DNS dst_port: DNS - 2: + 20: action: PERMIT src_port: FTP dst_port: FTP - 3: + 21: action: PERMIT src_port: HTTP dst_port: HTTP @@ -679,10 +673,14 @@ simulation: applications: - ref: client_2_web_browser type: WebBrowser + options: + target_url: http://arcd.com/users/ services: - ref: client_2_dns_client type: DNSClient + + links: - ref: router_1___switch_1 endpoint_a_ref: router_1 diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index ae60bbc1..48615ca6 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -316,6 +316,10 @@ class PrimaiteGame: port_scan_p_of_success=float(opt.get("port_scan_p_of_success", "0.1")), data_manipulation_p_of_success=float(opt.get("data_manipulation_p_of_success", "0.1")), ) + elif application_type == "WebBrowser": + if "options" in application_cfg: + opt = application_cfg["options"] + new_application.target_url = opt.get("target_url") if "nics" in node_cfg: for nic_num, nic_cfg in node_cfg["nics"].items(): new_node.connect_nic(NIC(ip_address=nic_cfg["ip_address"], subnet_mask=nic_cfg["subnet_mask"])) @@ -377,7 +381,6 @@ class PrimaiteGame: action_space_cfg["options"]["application_uuids"].append(node_application_uuids) else: action_space_cfg["options"]["application_uuids"].append([]) - # Each action space can potentially have a different list of nodes that it can apply to. Therefore, # we will pass node_uuids as a part of the action space config. # However, it's not possible to specify the node uuids directly in the config, as they are generated diff --git a/src/primaite/session/environment.py b/src/primaite/session/environment.py index db24db60..a5fdade9 100644 --- a/src/primaite/session/environment.py +++ b/src/primaite/session/environment.py @@ -37,7 +37,7 @@ class PrimaiteGymEnv(gymnasium.Env): terminated = False truncated = self.game.calculate_truncated() info = {} - + print(f"Episode: {self.game.episode_counter}, Step: {self.game.step_counter}, Reward: {reward}") return next_obs, reward, terminated, truncated, info def reset(self, seed: Optional[int] = None) -> Tuple[ObsType, Dict[str, Any]]: diff --git a/src/primaite/simulator/network/networks.py b/src/primaite/simulator/network/networks.py index ea767b54..446e5649 100644 --- a/src/primaite/simulator/network/networks.py +++ b/src/primaite/simulator/network/networks.py @@ -157,6 +157,8 @@ def arcd_uc2_network() -> Network: operating_state=NodeOperatingState.ON, ) client_2.power_on() + web_browser = client_2.software_manager["WebBrowser"] + web_browser.target_url = "http://arcd.com/users/" network.connect(endpoint_b=client_2.ethernet_port[1], endpoint_a=switch_2.switch_ports[2]) # Domain Controller diff --git a/src/primaite/simulator/network/protocols/http.py b/src/primaite/simulator/network/protocols/http.py index 2dba2614..b88916a9 100644 --- a/src/primaite/simulator/network/protocols/http.py +++ b/src/primaite/simulator/network/protocols/http.py @@ -1,4 +1,4 @@ -from enum import Enum +from enum import Enum, IntEnum from primaite.simulator.network.protocols.packet import DataPacket @@ -25,7 +25,7 @@ class HttpRequestMethod(Enum): """Apply partial modifications to a resource.""" -class HttpStatusCode(Enum): +class HttpStatusCode(IntEnum): """List of available HTTP Statuses.""" OK = 200 diff --git a/src/primaite/simulator/system/applications/database_client.py b/src/primaite/simulator/system/applications/database_client.py index b24b6062..37236e69 100644 --- a/src/primaite/simulator/system/applications/database_client.py +++ b/src/primaite/simulator/system/applications/database_client.py @@ -75,11 +75,11 @@ class DatabaseClient(Application): """ if is_reattempt: if self.connected: - self.sys_log.info(f"{self.name}: DatabaseClient connected to {server_ip_address} authorised") + self.sys_log.info(f"{self.name}: DatabaseClient connection to {server_ip_address} authorised") self.server_ip_address = server_ip_address return self.connected else: - self.sys_log.info(f"{self.name}: DatabaseClient connected to {server_ip_address} declined") + self.sys_log.info(f"{self.name}: DatabaseClient connection to {server_ip_address} declined") return False payload = {"type": "connect_request", "password": password} software_manager: SoftwareManager = self.software_manager diff --git a/src/primaite/simulator/system/applications/web_browser.py b/src/primaite/simulator/system/applications/web_browser.py index ea9c3ac3..0a9c7fc3 100644 --- a/src/primaite/simulator/system/applications/web_browser.py +++ b/src/primaite/simulator/system/applications/web_browser.py @@ -2,6 +2,7 @@ from ipaddress import IPv4Address from typing import Dict, Optional from urllib.parse import urlparse +from primaite.simulator.core import RequestManager, RequestType from primaite.simulator.network.protocols.http import HttpRequestMethod, HttpRequestPacket, HttpResponsePacket from primaite.simulator.network.transmission.network_layer import IPProtocol from primaite.simulator.network.transmission.transport_layer import Port @@ -16,6 +17,8 @@ class WebBrowser(Application): The application requests and loads web pages using its domain name and requesting IP addresses using DNS. """ + target_url: Optional[str] = None + domain_name_ip_address: Optional[IPv4Address] = None "The IP address of the domain name for the webpage." @@ -32,6 +35,14 @@ class WebBrowser(Application): super().__init__(**kwargs) self.run() + def _init_request_manager(self) -> RequestManager: + rm = super()._init_request_manager() + rm.add_request( + name="execute", request_type=RequestType(func=lambda request, context: self.get_webpage()) # noqa + ) + + return rm + def describe_state(self) -> Dict: """ Produce a dictionary describing the current state of the WebBrowser. @@ -51,7 +62,7 @@ class WebBrowser(Application): self.domain_name_ip_address = None self.latest_response = None - def get_webpage(self, url: str) -> bool: + def get_webpage(self) -> bool: """ Retrieve the webpage. @@ -60,6 +71,7 @@ class WebBrowser(Application): :param: url: The address of the web page the browser requests :type: url: str """ + url = self.target_url # reset latest response self.latest_response = None @@ -71,7 +83,6 @@ class WebBrowser(Application): # get the IP address of the domain name via DNS dns_client: DNSClient = self.software_manager.software["DNSClient"] - domain_exists = dns_client.check_domain_exists(target_domain=parsed_url.hostname) # if domain does not exist, the request fails diff --git a/src/primaite/simulator/system/services/web_server/web_server.py b/src/primaite/simulator/system/services/web_server/web_server.py index cb1a4738..5dda82d5 100644 --- a/src/primaite/simulator/system/services/web_server/web_server.py +++ b/src/primaite/simulator/system/services/web_server/web_server.py @@ -29,8 +29,9 @@ class WebServer(Service): :rtype: Dict """ state = super().describe_state() + state["last_response_status_code"] = ( - self.last_response_status_code.value if self.last_response_status_code else None + self.last_response_status_code.value if isinstance(self.last_response_status_code, HttpStatusCode) else None ) return state @@ -84,6 +85,7 @@ class WebServer(Service): # return true if response is OK self.last_response_status_code = response.status_code + print(self.last_response_status_code) return response.status_code == HttpStatusCode.OK def _handle_get_request(self, payload: HttpRequestPacket) -> HttpResponsePacket: diff --git a/tests/integration_tests/system/test_web_client_server.py b/tests/integration_tests/system/test_web_client_server.py index f4546cbf..991d6282 100644 --- a/tests/integration_tests/system/test_web_client_server.py +++ b/tests/integration_tests/system/test_web_client_server.py @@ -3,7 +3,6 @@ from primaite.simulator.network.hardware.nodes.server import Server from primaite.simulator.network.protocols.http import HttpStatusCode from primaite.simulator.system.applications.application import ApplicationOperatingState from primaite.simulator.system.applications.web_browser import WebBrowser -from primaite.simulator.system.services.service import ServiceOperatingState def test_web_page_home_page(uc2_network): @@ -11,9 +10,10 @@ def test_web_page_home_page(uc2_network): client_1: Computer = uc2_network.get_node_by_hostname("client_1") web_client: WebBrowser = client_1.software_manager.software["WebBrowser"] web_client.run() + web_client.target_url = "http://arcd.com/" assert web_client.operating_state == ApplicationOperatingState.RUNNING - assert web_client.get_webpage("http://arcd.com/") is True + assert web_client.get_webpage() is True # latest reponse should have status code 200 assert web_client.latest_response is not None @@ -27,7 +27,7 @@ def test_web_page_get_users_page_request_with_domain_name(uc2_network): web_client.run() assert web_client.operating_state == ApplicationOperatingState.RUNNING - assert web_client.get_webpage("http://arcd.com/users/") is True + assert web_client.get_webpage() is True # latest reponse should have status code 200 assert web_client.latest_response is not None @@ -41,11 +41,12 @@ def test_web_page_get_users_page_request_with_ip_address(uc2_network): web_client.run() web_server: Server = uc2_network.get_node_by_hostname("web_server") - web_server_ip = web_server.nics.get(next(iter(web_server.nics))).ip_address + web_server_ip = web_server.nics.get(next(iter(web_server.nics))).ip_address + web_client.target_url = f"http://{web_server_ip}/users/" assert web_client.operating_state == ApplicationOperatingState.RUNNING - assert web_client.get_webpage(f"http://{web_server_ip}/users/") is True + assert web_client.get_webpage() is True # latest reponse should have status code 200 assert web_client.latest_response is not None