Finalise the refactor. It works well now.
This commit is contained in:
@@ -652,12 +652,13 @@ simulation:
|
||||
default_gateway: 192.168.1.1
|
||||
dns_server: 192.168.1.10
|
||||
services:
|
||||
- ref: web_server_web_service
|
||||
type: WebServer
|
||||
applications:
|
||||
- ref: web_server_database_client
|
||||
type: DatabaseClient
|
||||
options:
|
||||
db_server_ip: 192.168.1.14
|
||||
- ref: web_server_web_service
|
||||
type: WebServer
|
||||
|
||||
|
||||
- ref: database_server
|
||||
|
||||
@@ -185,6 +185,10 @@ class PrimaiteGame:
|
||||
"""Close the game, this will close the simulation."""
|
||||
return NotImplemented
|
||||
|
||||
def setup_for_episode(self, episode: int) -> None:
|
||||
"""Perform any final configuration of components to make them ready for the game to start."""
|
||||
self.simulation.setup_for_episode(episode=episode)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, cfg: Dict) -> "PrimaiteGame":
|
||||
"""Create a PrimaiteGame object from a config dictionary.
|
||||
@@ -258,7 +262,9 @@ class PrimaiteGame:
|
||||
new_service = new_node.software_manager.software[service_type]
|
||||
game.ref_map_services[service_ref] = new_service.uuid
|
||||
else:
|
||||
_LOGGER.warning(f"service type not found {service_type}")
|
||||
msg = f"Configuration contains an invalid service type: {service_type}"
|
||||
_LOGGER.error(msg)
|
||||
raise ValueError(msg)
|
||||
# service-dependent options
|
||||
if service_type == "DNSClient":
|
||||
if "options" in service_cfg:
|
||||
@@ -297,7 +303,9 @@ class PrimaiteGame:
|
||||
new_application = new_node.software_manager.software[application_type]
|
||||
game.ref_map_applications[application_ref] = new_application.uuid
|
||||
else:
|
||||
_LOGGER.warning(f"application type not found {application_type}")
|
||||
msg = f"Configuration contains an invalid application type: {application_type}"
|
||||
_LOGGER.error(msg)
|
||||
raise ValueError(msg)
|
||||
|
||||
if application_type == "DataManipulationBot":
|
||||
if "options" in application_cfg:
|
||||
|
||||
@@ -335,9 +335,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%load_ext autoreload\n",
|
||||
@@ -347,9 +345,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Imports\n",
|
||||
@@ -372,9 +368,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# create the env\n",
|
||||
@@ -385,10 +379,10 @@
|
||||
" cfg['simulation']['network']['nodes'][9]['applications'][0]['options']['data_manipulation_p_of_success'] = 1.0\n",
|
||||
" cfg['simulation']['network']['nodes'][8]['applications'][0]['options']['port_scan_p_of_success'] = 1.0\n",
|
||||
" cfg['simulation']['network']['nodes'][9]['applications'][0]['options']['port_scan_p_of_success'] = 1.0\n",
|
||||
"game = PrimaiteGame.from_config(cfg)\n",
|
||||
"env = PrimaiteGymEnv(game = game)\n",
|
||||
"# Don't flatten obs as we are not training an agent and we wish to see the dict-formatted observations\n",
|
||||
"env.agent.flatten_obs = False\n",
|
||||
" # don't flatten observations so that we can see what is going on\n",
|
||||
" cfg['agents'][3]['agent_settings']['flatten_obs'] = False\n",
|
||||
"\n",
|
||||
"env = PrimaiteGymEnv(game_config = cfg)\n",
|
||||
"obs, info = env.reset()\n",
|
||||
"print('env created successfully')\n",
|
||||
"pprint(obs)"
|
||||
@@ -422,9 +416,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"for step in range(35):\n",
|
||||
@@ -442,9 +434,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"pprint(obs['NODES'])"
|
||||
@@ -460,9 +450,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"obs, reward, terminated, truncated, info = env.step(9) # scan database file\n",
|
||||
@@ -488,9 +476,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"obs, reward, terminated, truncated, info = env.step(13) # patch the database\n",
|
||||
@@ -515,9 +501,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"obs, reward, terminated, truncated, info = env.step(0) # patch the database\n",
|
||||
@@ -540,9 +524,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"env.step(13) # Patch the database\n",
|
||||
@@ -582,6 +564,22 @@
|
||||
"obs['ACL']"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Reset the cell, you can rerun the other cells to verify that the attack works the same every episode."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"env.reset()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
@@ -592,7 +590,7 @@
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"display_name": "venv",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
@@ -606,9 +604,9 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.8.10"
|
||||
"version": "3.10.12"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
|
||||
@@ -74,6 +74,7 @@ class PrimaiteGymEnv(gymnasium.Env):
|
||||
f"avg. reward: {self.game.rl_agents[0].reward_function.total_reward}"
|
||||
)
|
||||
self.game: PrimaiteGame = PrimaiteGame.from_config(cfg=self.game_config)
|
||||
self.game.setup_for_episode(episode=self.episode_counter)
|
||||
self.agent = self.game.rl_agents[0]
|
||||
self.episode_counter += 1
|
||||
state = self.game.get_sim_state()
|
||||
@@ -97,12 +98,12 @@ class PrimaiteGymEnv(gymnasium.Env):
|
||||
|
||||
def _get_obs(self) -> ObsType:
|
||||
"""Return the current observation."""
|
||||
if not self.agent.flatten_obs:
|
||||
return self.agent.observation_manager.current_observation
|
||||
else:
|
||||
if self.agent.flatten_obs:
|
||||
unflat_space = self.agent.observation_manager.space
|
||||
unflat_obs = self.agent.observation_manager.current_observation
|
||||
return gymnasium.spaces.flatten(unflat_space, unflat_obs)
|
||||
else:
|
||||
return self.agent.observation_manager.current_observation
|
||||
|
||||
|
||||
class PrimaiteRayEnv(gymnasium.Env):
|
||||
|
||||
@@ -48,9 +48,9 @@ class Network(SimComponent):
|
||||
def setup_for_episode(self, episode: int):
|
||||
"""Reset the original state of the SimComponent."""
|
||||
for node in self.nodes.values():
|
||||
node.setup_for_episode(episode)
|
||||
node.setup_for_episode(episode=episode)
|
||||
for link in self.links.values():
|
||||
link.setup_for_episode(episode)
|
||||
link.setup_for_episode(episode=episode)
|
||||
|
||||
for node in self.nodes.values():
|
||||
node.power_on()
|
||||
@@ -172,7 +172,7 @@ class Network(SimComponent):
|
||||
def clear_links(self):
|
||||
"""Clear all the links in the network by resetting their component state for the episode."""
|
||||
for link in self.links.values():
|
||||
link.setup_for_episode()
|
||||
link.setup_for_episode(episode=0) # TODO: shouldn't be using this method here.
|
||||
|
||||
def draw(self, seed: int = 123):
|
||||
"""
|
||||
|
||||
@@ -90,7 +90,7 @@ class NetworkInterface(SimComponent, ABC):
|
||||
|
||||
def setup_for_episode(self, episode: int):
|
||||
"""Reset the original state of the SimComponent."""
|
||||
super().setup_for_episode(episode)
|
||||
super().setup_for_episode(episode=episode)
|
||||
if episode and self.pcap:
|
||||
self.pcap.current_episode = episode
|
||||
self.pcap.setup_logger()
|
||||
@@ -643,17 +643,17 @@ class Node(SimComponent):
|
||||
|
||||
def setup_for_episode(self, episode: int):
|
||||
"""Reset the original state of the SimComponent."""
|
||||
super().setup_for_episode(episode)
|
||||
super().setup_for_episode(episode=episode)
|
||||
|
||||
# Reset File System
|
||||
self.file_system.setup_for_episode(episode)
|
||||
self.file_system.setup_for_episode(episode=episode)
|
||||
|
||||
# Reset all Nics
|
||||
for network_interface in self.network_interfaces.values():
|
||||
network_interface.setup_for_episode(episode)
|
||||
network_interface.setup_for_episode(episode=episode)
|
||||
|
||||
for software in self.software_manager.software.values():
|
||||
software.setup_for_episode(episode)
|
||||
software.setup_for_episode(episode=episode)
|
||||
|
||||
if episode and self.sys_log:
|
||||
self.sys_log.current_episode = episode
|
||||
|
||||
@@ -1078,7 +1078,7 @@ class Router(NetworkNode):
|
||||
for i, _ in self.network_interface.items():
|
||||
self.enable_port(i)
|
||||
|
||||
super().setup_for_episode(episode)
|
||||
super().setup_for_episode(episode=episode)
|
||||
|
||||
def _init_request_manager(self) -> RequestManager:
|
||||
rm = super()._init_request_manager()
|
||||
|
||||
@@ -23,7 +23,7 @@ class Simulation(SimComponent):
|
||||
|
||||
def setup_for_episode(self, episode: int):
|
||||
"""Reset the original state of the SimComponent."""
|
||||
self.network.setup_for_episode(episode)
|
||||
self.network.setup_for_episode(episode=episode)
|
||||
|
||||
def _init_request_manager(self) -> RequestManager:
|
||||
rm = super()._init_request_manager()
|
||||
|
||||
@@ -118,7 +118,7 @@ class WebServer(Service):
|
||||
self.set_health_state(SoftwareHealthState.COMPROMISED)
|
||||
|
||||
return response
|
||||
except Exception:
|
||||
except Exception: # TODO: refactor this. Likely to cause silent bugs.
|
||||
# something went wrong on the server
|
||||
response.status_code = HttpStatusCode.INTERNAL_SERVER_ERROR
|
||||
return response
|
||||
|
||||
@@ -3,7 +3,7 @@ from abc import abstractmethod
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from ipaddress import IPv4Address, IPv4Network
|
||||
from typing import Any, Dict, Optional, Union
|
||||
from typing import Any, Dict, Optional, TYPE_CHECKING, Union
|
||||
|
||||
from primaite.simulator.core import _LOGGER, RequestManager, RequestType, SimComponent
|
||||
from primaite.simulator.file_system.file_system import FileSystem, Folder
|
||||
@@ -13,6 +13,9 @@ from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.system.core.session_manager import Session
|
||||
from primaite.simulator.system.core.sys_log import SysLog
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from primaite.simulator.system.core.software_manager import SoftwareManager
|
||||
|
||||
|
||||
class SoftwareType(Enum):
|
||||
"""
|
||||
@@ -84,7 +87,7 @@ class Software(SimComponent):
|
||||
"The count of times the software has been scanned, defaults to 0."
|
||||
revealed_to_red: bool = False
|
||||
"Indicates if the software has been revealed to red agent, defaults is False."
|
||||
software_manager: Any = None
|
||||
software_manager: "SoftwareManager" = None
|
||||
"An instance of Software Manager that is used by the parent node."
|
||||
sys_log: SysLog = None
|
||||
"An instance of SysLog that is used by the parent node."
|
||||
|
||||
Reference in New Issue
Block a user