Minor changes to rewards and services.
This commit is contained in:
@@ -2,10 +2,10 @@ training_config:
|
||||
rl_framework: SB3
|
||||
rl_algorithm: PPO
|
||||
seed: 333
|
||||
n_learn_episodes: 1
|
||||
n_learn_steps: 8
|
||||
n_eval_episodes: 0
|
||||
n_eval_steps: 8
|
||||
n_learn_episodes: 4
|
||||
n_learn_steps: 128
|
||||
n_eval_episodes: 1
|
||||
n_eval_steps: 128
|
||||
|
||||
|
||||
game_config:
|
||||
@@ -534,6 +534,9 @@ simulation:
|
||||
type: DatabaseClient
|
||||
options:
|
||||
db_server_ip: 192.168.1.14
|
||||
- ref: web_server_web_service
|
||||
type: WebServer
|
||||
|
||||
|
||||
- ref: database_server
|
||||
type: server
|
||||
@@ -589,9 +592,10 @@ simulation:
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.10.1
|
||||
dns_server: 192.168.1.10
|
||||
services:
|
||||
applications:
|
||||
- ref: client_2_web_browser
|
||||
type: WebBrowser
|
||||
services:
|
||||
- ref: client_2_dns_client
|
||||
type: DNSClient
|
||||
|
||||
|
||||
@@ -29,7 +29,7 @@ from abc import abstractmethod
|
||||
from typing import Dict, List, Tuple, TYPE_CHECKING
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.game.agent.utils import access_from_nested_dict
|
||||
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
@@ -180,14 +180,17 @@ class WebServer404Penalty(AbstractReward):
|
||||
:type state: Dict
|
||||
"""
|
||||
web_service_state = access_from_nested_dict(state, self.location_in_state)
|
||||
most_recent_return_code = web_service_state["most_recent_return_code"]
|
||||
if web_service_state is NOT_PRESENT_IN_STATE:
|
||||
print("error getting web service state")
|
||||
return 0.0
|
||||
most_recent_return_code = web_service_state["last_response_status_code"]
|
||||
# TODO: reward needs to use the current web state. Observation should return web state at the time of last scan.
|
||||
if most_recent_return_code == 200:
|
||||
return 1
|
||||
return 1.0
|
||||
elif most_recent_return_code == 404:
|
||||
return -1
|
||||
return -1.0
|
||||
else:
|
||||
return 0
|
||||
return 0.0
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict, session: "PrimaiteSession") -> "WebServer404Penalty":
|
||||
|
||||
@@ -21,12 +21,15 @@ from primaite.simulator.network.hardware.nodes.switch import Switch
|
||||
from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.sim_container import Simulation
|
||||
from primaite.simulator.system.applications.application import Application
|
||||
from primaite.simulator.system.applications.database_client import DatabaseClient
|
||||
from primaite.simulator.system.applications.web_browser import WebBrowser
|
||||
from primaite.simulator.system.services.database.database_service import DatabaseService
|
||||
from primaite.simulator.system.services.dns.dns_client import DNSClient
|
||||
from primaite.simulator.system.services.dns.dns_server import DNSServer
|
||||
from primaite.simulator.system.services.red_services.data_manipulation_bot import DataManipulationBot
|
||||
from primaite.simulator.system.services.service import Service
|
||||
from primaite.simulator.system.services.web_server.web_server import WebServer
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
@@ -182,6 +185,8 @@ class PrimaiteSession:
|
||||
"""Mapping from unique node reference name to node object. Used when parsing config files."""
|
||||
self.ref_map_services: Dict[str, Service] = {}
|
||||
"""Mapping from human-readable service reference to service object. Used for parsing config files."""
|
||||
self.ref_map_applications: Dict[str, Application] = {}
|
||||
"""Mapping from human-readable application reference to application object. Used for parsing config files."""
|
||||
self.ref_map_links: Dict[str, Link] = {}
|
||||
"""Mapping from human-readable link reference to link object. Used when parsing config files."""
|
||||
self.gate_client: PrimaiteGATEClient = PrimaiteGATEClient(self)
|
||||
@@ -333,11 +338,11 @@ class PrimaiteSession:
|
||||
"DNSServer": DNSServer,
|
||||
"DatabaseClient": DatabaseClient,
|
||||
"DatabaseService": DatabaseService,
|
||||
# 'database_backup': ,
|
||||
"WebServer": WebServer,
|
||||
"DataManipulationBot": DataManipulationBot,
|
||||
# 'web_browser'
|
||||
}
|
||||
if service_type in service_types_mapping:
|
||||
print(f"installing {service_type} on node {new_node.hostname}")
|
||||
new_node.software_manager.install(service_types_mapping[service_type])
|
||||
new_service = new_node.software_manager.software[service_type]
|
||||
sess.ref_map_services[service_ref] = new_service
|
||||
@@ -355,6 +360,19 @@ class PrimaiteSession:
|
||||
if "domain_mapping" in opt:
|
||||
for domain, ip in opt["domain_mapping"].items():
|
||||
new_service.dns_register(domain, ip)
|
||||
if "applications" in node_cfg:
|
||||
for application_cfg in node_cfg["applications"]:
|
||||
application_ref = application_cfg["ref"]
|
||||
application_type = application_cfg["type"]
|
||||
application_types_mapping = {
|
||||
"WebBrowser": WebBrowser,
|
||||
}
|
||||
if application_type in application_types_mapping:
|
||||
new_node.software_manager.install(application_types_mapping[application_type])
|
||||
new_application = new_node.software_manager.software[application_type]
|
||||
sess.ref_map_applications[application_ref] = new_application
|
||||
else:
|
||||
print(f"application type not found {application_type}")
|
||||
if "nics" in node_cfg:
|
||||
for nic_num, nic_cfg in node_cfg["nics"].items():
|
||||
new_node.connect_nic(NIC(ip_address=nic_cfg["ip_address"], subnet_mask=nic_cfg["subnet_mask"]))
|
||||
|
||||
@@ -38,7 +38,8 @@ class WebBrowser(Application):
|
||||
|
||||
:return: A dictionary capturing the current state of the WebBrowser and its child objects.
|
||||
"""
|
||||
return super().describe_state()
|
||||
state = super().describe_state()
|
||||
state["last_response_status_code"] = self.latest_response.status_code if self.latest_response else None
|
||||
|
||||
def reset_component_for_episode(self, episode: int):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user