Minor changes to rewards and services.

This commit is contained in:
Marek Wolan
2023-10-23 16:26:34 +01:00
parent 0f24b4a646
commit 975aa9ffc2
4 changed files with 39 additions and 13 deletions

View File

@@ -2,10 +2,10 @@ training_config:
rl_framework: SB3
rl_algorithm: PPO
seed: 333
n_learn_episodes: 1
n_learn_steps: 8
n_eval_episodes: 0
n_eval_steps: 8
n_learn_episodes: 4
n_learn_steps: 128
n_eval_episodes: 1
n_eval_steps: 128
game_config:
@@ -534,6 +534,9 @@ simulation:
type: DatabaseClient
options:
db_server_ip: 192.168.1.14
- ref: web_server_web_service
type: WebServer
- ref: database_server
type: server
@@ -589,9 +592,10 @@ simulation:
subnet_mask: 255.255.255.0
default_gateway: 192.168.10.1
dns_server: 192.168.1.10
services:
applications:
- ref: client_2_web_browser
type: WebBrowser
services:
- ref: client_2_dns_client
type: DNSClient

View File

@@ -29,7 +29,7 @@ from abc import abstractmethod
from typing import Dict, List, Tuple, TYPE_CHECKING
from primaite import getLogger
from primaite.game.agent.utils import access_from_nested_dict
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
_LOGGER = getLogger(__name__)
@@ -180,14 +180,17 @@ class WebServer404Penalty(AbstractReward):
:type state: Dict
"""
web_service_state = access_from_nested_dict(state, self.location_in_state)
most_recent_return_code = web_service_state["most_recent_return_code"]
if web_service_state is NOT_PRESENT_IN_STATE:
print("error getting web service state")
return 0.0
most_recent_return_code = web_service_state["last_response_status_code"]
# TODO: reward needs to use the current web state. Observation should return web state at the time of last scan.
if most_recent_return_code == 200:
return 1
return 1.0
elif most_recent_return_code == 404:
return -1
return -1.0
else:
return 0
return 0.0
@classmethod
def from_config(cls, config: Dict, session: "PrimaiteSession") -> "WebServer404Penalty":

View File

@@ -21,12 +21,15 @@ from primaite.simulator.network.hardware.nodes.switch import Switch
from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.sim_container import Simulation
from primaite.simulator.system.applications.application import Application
from primaite.simulator.system.applications.database_client import DatabaseClient
from primaite.simulator.system.applications.web_browser import WebBrowser
from primaite.simulator.system.services.database.database_service import DatabaseService
from primaite.simulator.system.services.dns.dns_client import DNSClient
from primaite.simulator.system.services.dns.dns_server import DNSServer
from primaite.simulator.system.services.red_services.data_manipulation_bot import DataManipulationBot
from primaite.simulator.system.services.service import Service
from primaite.simulator.system.services.web_server.web_server import WebServer
_LOGGER = getLogger(__name__)
@@ -182,6 +185,8 @@ class PrimaiteSession:
"""Mapping from unique node reference name to node object. Used when parsing config files."""
self.ref_map_services: Dict[str, Service] = {}
"""Mapping from human-readable service reference to service object. Used for parsing config files."""
self.ref_map_applications: Dict[str, Application] = {}
"""Mapping from human-readable application reference to application object. Used for parsing config files."""
self.ref_map_links: Dict[str, Link] = {}
"""Mapping from human-readable link reference to link object. Used when parsing config files."""
self.gate_client: PrimaiteGATEClient = PrimaiteGATEClient(self)
@@ -333,11 +338,11 @@ class PrimaiteSession:
"DNSServer": DNSServer,
"DatabaseClient": DatabaseClient,
"DatabaseService": DatabaseService,
# 'database_backup': ,
"WebServer": WebServer,
"DataManipulationBot": DataManipulationBot,
# 'web_browser'
}
if service_type in service_types_mapping:
print(f"installing {service_type} on node {new_node.hostname}")
new_node.software_manager.install(service_types_mapping[service_type])
new_service = new_node.software_manager.software[service_type]
sess.ref_map_services[service_ref] = new_service
@@ -355,6 +360,19 @@ class PrimaiteSession:
if "domain_mapping" in opt:
for domain, ip in opt["domain_mapping"].items():
new_service.dns_register(domain, ip)
if "applications" in node_cfg:
for application_cfg in node_cfg["applications"]:
application_ref = application_cfg["ref"]
application_type = application_cfg["type"]
application_types_mapping = {
"WebBrowser": WebBrowser,
}
if application_type in application_types_mapping:
new_node.software_manager.install(application_types_mapping[application_type])
new_application = new_node.software_manager.software[application_type]
sess.ref_map_applications[application_ref] = new_application
else:
print(f"application type not found {application_type}")
if "nics" in node_cfg:
for nic_num, nic_cfg in node_cfg["nics"].items():
new_node.connect_nic(NIC(ip_address=nic_cfg["ip_address"], subnet_mask=nic_cfg["subnet_mask"]))

View File

@@ -38,7 +38,8 @@ class WebBrowser(Application):
:return: A dictionary capturing the current state of the WebBrowser and its child objects.
"""
return super().describe_state()
state = super().describe_state()
state["last_response_status_code"] = self.latest_response.status_code if self.latest_response else None
def reset_component_for_episode(self, episode: int):
"""