Finalise the refactor. It works well now.

This commit is contained in:
Marek Wolan
2024-02-25 17:44:41 +00:00
parent a34cf08209
commit 994dbc3501
10 changed files with 65 additions and 54 deletions

View File

@@ -652,12 +652,13 @@ simulation:
default_gateway: 192.168.1.1
dns_server: 192.168.1.10
services:
- ref: web_server_web_service
type: WebServer
applications:
- ref: web_server_database_client
type: DatabaseClient
options:
db_server_ip: 192.168.1.14
- ref: web_server_web_service
type: WebServer
- ref: database_server

View File

@@ -185,6 +185,10 @@ class PrimaiteGame:
"""Close the game, this will close the simulation."""
return NotImplemented
def setup_for_episode(self, episode: int) -> None:
"""Perform any final configuration of components to make them ready for the game to start."""
self.simulation.setup_for_episode(episode=episode)
@classmethod
def from_config(cls, cfg: Dict) -> "PrimaiteGame":
"""Create a PrimaiteGame object from a config dictionary.
@@ -258,7 +262,9 @@ class PrimaiteGame:
new_service = new_node.software_manager.software[service_type]
game.ref_map_services[service_ref] = new_service.uuid
else:
_LOGGER.warning(f"service type not found {service_type}")
msg = f"Configuration contains an invalid service type: {service_type}"
_LOGGER.error(msg)
raise ValueError(msg)
# service-dependent options
if service_type == "DNSClient":
if "options" in service_cfg:
@@ -297,7 +303,9 @@ class PrimaiteGame:
new_application = new_node.software_manager.software[application_type]
game.ref_map_applications[application_ref] = new_application.uuid
else:
_LOGGER.warning(f"application type not found {application_type}")
msg = f"Configuration contains an invalid application type: {application_type}"
_LOGGER.error(msg)
raise ValueError(msg)
if application_type == "DataManipulationBot":
if "options" in application_cfg:

View File

@@ -335,9 +335,7 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"metadata": {},
"outputs": [],
"source": [
"%load_ext autoreload\n",
@@ -347,9 +345,7 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"metadata": {},
"outputs": [],
"source": [
"# Imports\n",
@@ -372,9 +368,7 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"metadata": {},
"outputs": [],
"source": [
"# create the env\n",
@@ -385,10 +379,10 @@
" cfg['simulation']['network']['nodes'][9]['applications'][0]['options']['data_manipulation_p_of_success'] = 1.0\n",
" cfg['simulation']['network']['nodes'][8]['applications'][0]['options']['port_scan_p_of_success'] = 1.0\n",
" cfg['simulation']['network']['nodes'][9]['applications'][0]['options']['port_scan_p_of_success'] = 1.0\n",
"game = PrimaiteGame.from_config(cfg)\n",
"env = PrimaiteGymEnv(game = game)\n",
"# Don't flatten obs as we are not training an agent and we wish to see the dict-formatted observations\n",
"env.agent.flatten_obs = False\n",
" # don't flatten observations so that we can see what is going on\n",
" cfg['agents'][3]['agent_settings']['flatten_obs'] = False\n",
"\n",
"env = PrimaiteGymEnv(game_config = cfg)\n",
"obs, info = env.reset()\n",
"print('env created successfully')\n",
"pprint(obs)"
@@ -422,9 +416,7 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"metadata": {},
"outputs": [],
"source": [
"for step in range(35):\n",
@@ -442,9 +434,7 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"metadata": {},
"outputs": [],
"source": [
"pprint(obs['NODES'])"
@@ -460,9 +450,7 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"metadata": {},
"outputs": [],
"source": [
"obs, reward, terminated, truncated, info = env.step(9) # scan database file\n",
@@ -488,9 +476,7 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"metadata": {},
"outputs": [],
"source": [
"obs, reward, terminated, truncated, info = env.step(13) # patch the database\n",
@@ -515,9 +501,7 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"metadata": {},
"outputs": [],
"source": [
"obs, reward, terminated, truncated, info = env.step(0) # patch the database\n",
@@ -540,9 +524,7 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"metadata": {},
"outputs": [],
"source": [
"env.step(13) # Patch the database\n",
@@ -582,6 +564,22 @@
"obs['ACL']"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Reset the cell, you can rerun the other cells to verify that the attack works the same every episode."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"env.reset()"
]
},
{
"cell_type": "code",
"execution_count": null,
@@ -592,7 +590,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"display_name": "venv",
"language": "python",
"name": "python3"
},
@@ -606,9 +604,9 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 4
"nbformat_minor": 2
}

View File

@@ -74,6 +74,7 @@ class PrimaiteGymEnv(gymnasium.Env):
f"avg. reward: {self.game.rl_agents[0].reward_function.total_reward}"
)
self.game: PrimaiteGame = PrimaiteGame.from_config(cfg=self.game_config)
self.game.setup_for_episode(episode=self.episode_counter)
self.agent = self.game.rl_agents[0]
self.episode_counter += 1
state = self.game.get_sim_state()
@@ -97,12 +98,12 @@ class PrimaiteGymEnv(gymnasium.Env):
def _get_obs(self) -> ObsType:
"""Return the current observation."""
if not self.agent.flatten_obs:
return self.agent.observation_manager.current_observation
else:
if self.agent.flatten_obs:
unflat_space = self.agent.observation_manager.space
unflat_obs = self.agent.observation_manager.current_observation
return gymnasium.spaces.flatten(unflat_space, unflat_obs)
else:
return self.agent.observation_manager.current_observation
class PrimaiteRayEnv(gymnasium.Env):

View File

@@ -48,9 +48,9 @@ class Network(SimComponent):
def setup_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
for node in self.nodes.values():
node.setup_for_episode(episode)
node.setup_for_episode(episode=episode)
for link in self.links.values():
link.setup_for_episode(episode)
link.setup_for_episode(episode=episode)
for node in self.nodes.values():
node.power_on()
@@ -172,7 +172,7 @@ class Network(SimComponent):
def clear_links(self):
"""Clear all the links in the network by resetting their component state for the episode."""
for link in self.links.values():
link.setup_for_episode()
link.setup_for_episode(episode=0) # TODO: shouldn't be using this method here.
def draw(self, seed: int = 123):
"""

View File

@@ -90,7 +90,7 @@ class NetworkInterface(SimComponent, ABC):
def setup_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
super().setup_for_episode(episode)
super().setup_for_episode(episode=episode)
if episode and self.pcap:
self.pcap.current_episode = episode
self.pcap.setup_logger()
@@ -643,17 +643,17 @@ class Node(SimComponent):
def setup_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
super().setup_for_episode(episode)
super().setup_for_episode(episode=episode)
# Reset File System
self.file_system.setup_for_episode(episode)
self.file_system.setup_for_episode(episode=episode)
# Reset all Nics
for network_interface in self.network_interfaces.values():
network_interface.setup_for_episode(episode)
network_interface.setup_for_episode(episode=episode)
for software in self.software_manager.software.values():
software.setup_for_episode(episode)
software.setup_for_episode(episode=episode)
if episode and self.sys_log:
self.sys_log.current_episode = episode

View File

@@ -1078,7 +1078,7 @@ class Router(NetworkNode):
for i, _ in self.network_interface.items():
self.enable_port(i)
super().setup_for_episode(episode)
super().setup_for_episode(episode=episode)
def _init_request_manager(self) -> RequestManager:
rm = super()._init_request_manager()

View File

@@ -23,7 +23,7 @@ class Simulation(SimComponent):
def setup_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
self.network.setup_for_episode(episode)
self.network.setup_for_episode(episode=episode)
def _init_request_manager(self) -> RequestManager:
rm = super()._init_request_manager()

View File

@@ -118,7 +118,7 @@ class WebServer(Service):
self.set_health_state(SoftwareHealthState.COMPROMISED)
return response
except Exception:
except Exception: # TODO: refactor this. Likely to cause silent bugs.
# something went wrong on the server
response.status_code = HttpStatusCode.INTERNAL_SERVER_ERROR
return response

View File

@@ -3,7 +3,7 @@ from abc import abstractmethod
from datetime import datetime
from enum import Enum
from ipaddress import IPv4Address, IPv4Network
from typing import Any, Dict, Optional, Union
from typing import Any, Dict, Optional, TYPE_CHECKING, Union
from primaite.simulator.core import _LOGGER, RequestManager, RequestType, SimComponent
from primaite.simulator.file_system.file_system import FileSystem, Folder
@@ -13,6 +13,9 @@ from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.core.session_manager import Session
from primaite.simulator.system.core.sys_log import SysLog
if TYPE_CHECKING:
from primaite.simulator.system.core.software_manager import SoftwareManager
class SoftwareType(Enum):
"""
@@ -84,7 +87,7 @@ class Software(SimComponent):
"The count of times the software has been scanned, defaults to 0."
revealed_to_red: bool = False
"Indicates if the software has been revealed to red agent, defaults is False."
software_manager: Any = None
software_manager: "SoftwareManager" = None
"An instance of Software Manager that is used by the parent node."
sys_log: SysLog = None
"An instance of SysLog that is used by the parent node."