From 975aa9ffc289271ead984b4b94b06adf7d2011d6 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Mon, 23 Oct 2023 16:26:34 +0100 Subject: [PATCH] Minor changes to rewards and services. --- example_config.yaml | 14 +++++++----- src/primaite/game/agent/rewards.py | 13 ++++++----- src/primaite/game/session.py | 22 +++++++++++++++++-- .../system/applications/web_browser.py | 3 ++- 4 files changed, 39 insertions(+), 13 deletions(-) diff --git a/example_config.yaml b/example_config.yaml index e16411fa..f3d8dc10 100644 --- a/example_config.yaml +++ b/example_config.yaml @@ -2,10 +2,10 @@ training_config: rl_framework: SB3 rl_algorithm: PPO seed: 333 - n_learn_episodes: 1 - n_learn_steps: 8 - n_eval_episodes: 0 - n_eval_steps: 8 + n_learn_episodes: 4 + n_learn_steps: 128 + n_eval_episodes: 1 + n_eval_steps: 128 game_config: @@ -534,6 +534,9 @@ simulation: type: DatabaseClient options: db_server_ip: 192.168.1.14 + - ref: web_server_web_service + type: WebServer + - ref: database_server type: server @@ -589,9 +592,10 @@ simulation: subnet_mask: 255.255.255.0 default_gateway: 192.168.10.1 dns_server: 192.168.1.10 - services: + applications: - ref: client_2_web_browser type: WebBrowser + services: - ref: client_2_dns_client type: DNSClient diff --git a/src/primaite/game/agent/rewards.py b/src/primaite/game/agent/rewards.py index 03c4e2d3..6c408ff9 100644 --- a/src/primaite/game/agent/rewards.py +++ b/src/primaite/game/agent/rewards.py @@ -29,7 +29,7 @@ from abc import abstractmethod from typing import Dict, List, Tuple, TYPE_CHECKING from primaite import getLogger -from primaite.game.agent.utils import access_from_nested_dict +from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE _LOGGER = getLogger(__name__) @@ -180,14 +180,17 @@ class WebServer404Penalty(AbstractReward): :type state: Dict """ web_service_state = access_from_nested_dict(state, self.location_in_state) - most_recent_return_code = web_service_state["most_recent_return_code"] + if web_service_state is NOT_PRESENT_IN_STATE: + print("error getting web service state") + return 0.0 + most_recent_return_code = web_service_state["last_response_status_code"] # TODO: reward needs to use the current web state. Observation should return web state at the time of last scan. if most_recent_return_code == 200: - return 1 + return 1.0 elif most_recent_return_code == 404: - return -1 + return -1.0 else: - return 0 + return 0.0 @classmethod def from_config(cls, config: Dict, session: "PrimaiteSession") -> "WebServer404Penalty": diff --git a/src/primaite/game/session.py b/src/primaite/game/session.py index adb9f7b5..d40d0754 100644 --- a/src/primaite/game/session.py +++ b/src/primaite/game/session.py @@ -21,12 +21,15 @@ from primaite.simulator.network.hardware.nodes.switch import Switch from primaite.simulator.network.transmission.network_layer import IPProtocol from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.sim_container import Simulation +from primaite.simulator.system.applications.application import Application from primaite.simulator.system.applications.database_client import DatabaseClient +from primaite.simulator.system.applications.web_browser import WebBrowser from primaite.simulator.system.services.database.database_service import DatabaseService from primaite.simulator.system.services.dns.dns_client import DNSClient from primaite.simulator.system.services.dns.dns_server import DNSServer from primaite.simulator.system.services.red_services.data_manipulation_bot import DataManipulationBot from primaite.simulator.system.services.service import Service +from primaite.simulator.system.services.web_server.web_server import WebServer _LOGGER = getLogger(__name__) @@ -182,6 +185,8 @@ class PrimaiteSession: """Mapping from unique node reference name to node object. Used when parsing config files.""" self.ref_map_services: Dict[str, Service] = {} """Mapping from human-readable service reference to service object. Used for parsing config files.""" + self.ref_map_applications: Dict[str, Application] = {} + """Mapping from human-readable application reference to application object. Used for parsing config files.""" self.ref_map_links: Dict[str, Link] = {} """Mapping from human-readable link reference to link object. Used when parsing config files.""" self.gate_client: PrimaiteGATEClient = PrimaiteGATEClient(self) @@ -333,11 +338,11 @@ class PrimaiteSession: "DNSServer": DNSServer, "DatabaseClient": DatabaseClient, "DatabaseService": DatabaseService, - # 'database_backup': , + "WebServer": WebServer, "DataManipulationBot": DataManipulationBot, - # 'web_browser' } if service_type in service_types_mapping: + print(f"installing {service_type} on node {new_node.hostname}") new_node.software_manager.install(service_types_mapping[service_type]) new_service = new_node.software_manager.software[service_type] sess.ref_map_services[service_ref] = new_service @@ -355,6 +360,19 @@ class PrimaiteSession: if "domain_mapping" in opt: for domain, ip in opt["domain_mapping"].items(): new_service.dns_register(domain, ip) + if "applications" in node_cfg: + for application_cfg in node_cfg["applications"]: + application_ref = application_cfg["ref"] + application_type = application_cfg["type"] + application_types_mapping = { + "WebBrowser": WebBrowser, + } + if application_type in application_types_mapping: + new_node.software_manager.install(application_types_mapping[application_type]) + new_application = new_node.software_manager.software[application_type] + sess.ref_map_applications[application_ref] = new_application + else: + print(f"application type not found {application_type}") if "nics" in node_cfg: for nic_num, nic_cfg in node_cfg["nics"].items(): new_node.connect_nic(NIC(ip_address=nic_cfg["ip_address"], subnet_mask=nic_cfg["subnet_mask"])) diff --git a/src/primaite/simulator/system/applications/web_browser.py b/src/primaite/simulator/system/applications/web_browser.py index c48b785e..ea9c3ac3 100644 --- a/src/primaite/simulator/system/applications/web_browser.py +++ b/src/primaite/simulator/system/applications/web_browser.py @@ -38,7 +38,8 @@ class WebBrowser(Application): :return: A dictionary capturing the current state of the WebBrowser and its child objects. """ - return super().describe_state() + state = super().describe_state() + state["last_response_status_code"] = self.latest_response.status_code if self.latest_response else None def reset_component_for_episode(self, episode: int): """