#1859 - First pass at an implementation of the full reset method. Will now start testing...
This commit is contained in:
@@ -15,7 +15,6 @@ from typing import Dict, List, Optional, Tuple, TYPE_CHECKING
|
||||
from gymnasium import spaces
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.simulator.sim_container import Simulation
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
@@ -25,7 +25,6 @@ the structure:
|
||||
service_ref: web_server_database_client
|
||||
```
|
||||
"""
|
||||
import json
|
||||
from abc import abstractmethod
|
||||
from typing import Dict, List, Tuple, Type, TYPE_CHECKING
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
"""PrimAITE game - Encapsulates the simulation and agents."""
|
||||
from copy import deepcopy
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Dict, List
|
||||
|
||||
@@ -11,7 +10,7 @@ from primaite.game.agent.data_manipulation_bot import DataManipulationAgent
|
||||
from primaite.game.agent.interface import AbstractAgent, AgentSettings, ProxyAgent, RandomAgent
|
||||
from primaite.game.agent.observations import ObservationManager
|
||||
from primaite.game.agent.rewards import RewardFunction
|
||||
from primaite.simulator.network.hardware.base import Link, NIC, Node, NodeOperatingState
|
||||
from primaite.simulator.network.hardware.base import NIC, NodeOperatingState
|
||||
from primaite.simulator.network.hardware.nodes.computer import Computer
|
||||
from primaite.simulator.network.hardware.nodes.router import ACLAction, Router
|
||||
from primaite.simulator.network.hardware.nodes.server import Server
|
||||
@@ -19,7 +18,6 @@ from primaite.simulator.network.hardware.nodes.switch import Switch
|
||||
from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.sim_container import Simulation
|
||||
from primaite.simulator.system.applications.application import Application
|
||||
from primaite.simulator.system.applications.database_client import DatabaseClient
|
||||
from primaite.simulator.system.applications.web_browser import WebBrowser
|
||||
from primaite.simulator.system.services.database.database_service import DatabaseService
|
||||
@@ -28,7 +26,6 @@ from primaite.simulator.system.services.dns.dns_server import DNSServer
|
||||
from primaite.simulator.system.services.ftp.ftp_client import FTPClient
|
||||
from primaite.simulator.system.services.ftp.ftp_server import FTPServer
|
||||
from primaite.simulator.system.services.red_services.data_manipulation_bot import DataManipulationBot
|
||||
from primaite.simulator.system.services.service import Service
|
||||
from primaite.simulator.system.services.web_server.web_server import WebServer
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
@@ -59,10 +56,6 @@ class PrimaiteGame:
|
||||
"""Initialise a PrimaiteGame object."""
|
||||
self.simulation: Simulation = Simulation()
|
||||
"""Simulation object with which the agents will interact."""
|
||||
print(f"Hello, welcome to PrimaiteGame. This is the ID of the ORIGINAL simulation {id(self.simulation)}")
|
||||
|
||||
self._simulation_initial_state = None
|
||||
"""The Simulation original state (deepcopy of the original Simulation)."""
|
||||
|
||||
self.agents: List[AbstractAgent] = []
|
||||
"""List of agents."""
|
||||
@@ -161,34 +154,7 @@ class PrimaiteGame:
|
||||
self.episode_counter += 1
|
||||
self.step_counter = 0
|
||||
_LOGGER.debug(f"Resetting primaite game, episode = {self.episode_counter}")
|
||||
self.simulation = deepcopy(self._simulation_initial_state)
|
||||
self._reset_components_for_episode()
|
||||
print("Reset")
|
||||
|
||||
def _reset_components_for_episode(self):
|
||||
print("Performing full reset for episode")
|
||||
for node in self.simulation.network.nodes.values():
|
||||
print(f"Resetting Node: {node.hostname}")
|
||||
node.reset_component_for_episode(self.episode_counter)
|
||||
|
||||
# reset Node NIC
|
||||
|
||||
# Reset Node Services
|
||||
|
||||
# Reset Node Applications
|
||||
print(f"Resetting Software...")
|
||||
for application in node.software_manager.software.values():
|
||||
print(f"Resetting {application.name}")
|
||||
if isinstance(application, WebBrowser):
|
||||
application.do_this()
|
||||
|
||||
# Reset Node FileSystem
|
||||
# Reset Node FileSystemFolder's
|
||||
# Reset Node FileSystemFile's
|
||||
|
||||
# Reset Router
|
||||
|
||||
# Reset Links
|
||||
self.simulation.reset_component_for_episode(episode=self.episode_counter)
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close the game, this will close the simulation."""
|
||||
@@ -452,8 +418,6 @@ class PrimaiteGame:
|
||||
else:
|
||||
print("agent type not found")
|
||||
|
||||
game._simulation_initial_state = deepcopy(game.simulation) # noqa
|
||||
web_server = game.simulation.network.get_node_by_hostname("web_server").software_manager.software["WebServer"]
|
||||
print(f"And this is the ID of the original WebServer {id(web_server)}")
|
||||
game.simulation.set_original_state()
|
||||
|
||||
return game
|
||||
|
||||
@@ -153,6 +153,8 @@ class SimComponent(BaseModel):
|
||||
uuid: str
|
||||
"""The component UUID."""
|
||||
|
||||
_original_state: Dict = {}
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
if not kwargs.get("uuid"):
|
||||
kwargs["uuid"] = str(uuid4())
|
||||
@@ -160,6 +162,16 @@ class SimComponent(BaseModel):
|
||||
self._request_manager: RequestManager = self._init_request_manager()
|
||||
self._parent: Optional["SimComponent"] = None
|
||||
|
||||
# @abstractmethod
|
||||
def set_original_state(self):
|
||||
"""Sets the original state."""
|
||||
pass
|
||||
|
||||
def reset_component_for_episode(self, episode: int):
|
||||
"""Reset the original state of the SimComponent."""
|
||||
for key, value in self._original_state.items():
|
||||
self.__setattr__(key, value)
|
||||
|
||||
def _init_request_manager(self) -> RequestManager:
|
||||
"""
|
||||
Initialise the request manager for this component.
|
||||
@@ -227,14 +239,6 @@ class SimComponent(BaseModel):
|
||||
"""
|
||||
pass
|
||||
|
||||
def reset_component_for_episode(self, episode: int):
|
||||
"""
|
||||
Reset this component to its original state for a new episode.
|
||||
|
||||
Override this method with anything that needs to happen within the component for it to be reset.
|
||||
"""
|
||||
pass
|
||||
|
||||
@property
|
||||
def parent(self) -> "SimComponent":
|
||||
"""Reference to the parent object which manages this object.
|
||||
|
||||
@@ -42,6 +42,19 @@ class Account(SimComponent):
|
||||
"Account Type, currently this can be service account (used by apps) or user account."
|
||||
enabled: bool = True
|
||||
|
||||
def set_original_state(self):
|
||||
"""Sets the original state."""
|
||||
vals_to_include = {
|
||||
"num_logons",
|
||||
"num_logoffs",
|
||||
"num_group_changes",
|
||||
"username",
|
||||
"password",
|
||||
"account_type",
|
||||
"enabled",
|
||||
}
|
||||
self._original_state = self.model_dump(include=vals_to_include)
|
||||
|
||||
def describe_state(self) -> Dict:
|
||||
"""
|
||||
Produce a dictionary describing the current state of this object.
|
||||
|
||||
@@ -73,6 +73,18 @@ class File(FileSystemItemABC):
|
||||
|
||||
self.sys_log.info(f"Created file /{self.path} (id: {self.uuid})")
|
||||
|
||||
self.set_original_state()
|
||||
|
||||
def set_original_state(self):
|
||||
"""Sets the original state."""
|
||||
super().set_original_state()
|
||||
vals_to_include = {"folder_id", "folder_name", "file_type", "sim_size", "real", "sim_path", "sim_root"}
|
||||
self._original_state.update(self.model_dump(include=vals_to_include))
|
||||
|
||||
def reset_component_for_episode(self, episode: int):
|
||||
"""Reset the original state of the SimComponent."""
|
||||
super().reset_component_for_episode(episode)
|
||||
|
||||
@property
|
||||
def path(self) -> str:
|
||||
"""
|
||||
|
||||
@@ -35,6 +35,36 @@ class FileSystem(SimComponent):
|
||||
if not self.folders:
|
||||
self.create_folder("root")
|
||||
|
||||
def set_original_state(self):
|
||||
"""Sets the original state."""
|
||||
for folder in self.folders.values():
|
||||
folder.set_original_state()
|
||||
super().set_original_state()
|
||||
# Capture a list of all 'original' file uuids
|
||||
self._original_state["original_folder_uuids"] = list(self.folders.keys())
|
||||
|
||||
def reset_component_for_episode(self, episode: int):
|
||||
"""Reset the original state of the SimComponent."""
|
||||
# Move any 'original' folder that have been deleted back to folders
|
||||
original_folder_uuids = self._original_state.pop("original_folder_uuids")
|
||||
for uuid in original_folder_uuids:
|
||||
if uuid in self.deleted_folders:
|
||||
self.folders[uuid] = self.deleted_folders.pop(uuid)
|
||||
|
||||
# Clear any other deleted folders that aren't original (have been created by agent)
|
||||
self.deleted_folders.clear()
|
||||
|
||||
# Now clear all non-original folders created by agent
|
||||
current_folder_uuids = list(self.folders.keys())
|
||||
for uuid in current_folder_uuids:
|
||||
if uuid not in original_folder_uuids:
|
||||
self.folders.pop(uuid)
|
||||
|
||||
# Now reset all remaining folders
|
||||
for folder in self.folders.values():
|
||||
folder.reset_component_for_episode(episode)
|
||||
super().reset_component_for_episode(episode)
|
||||
|
||||
def _init_request_manager(self) -> RequestManager:
|
||||
rm = super()._init_request_manager()
|
||||
|
||||
|
||||
@@ -85,6 +85,11 @@ class FileSystemItemABC(SimComponent):
|
||||
deleted: bool = False
|
||||
"If true, the FileSystemItem was deleted."
|
||||
|
||||
def set_original_state(self):
|
||||
"""Sets the original state."""
|
||||
vals_to_keep = {"name", "health_status", "visible_health_status", "previous_hash", "revealed_to_red"}
|
||||
self._original_state = self.model_dump(include=vals_to_keep)
|
||||
|
||||
def describe_state(self) -> Dict:
|
||||
"""
|
||||
Produce a dictionary describing the current state of this object.
|
||||
|
||||
@@ -51,6 +51,44 @@ class Folder(FileSystemItemABC):
|
||||
|
||||
self.sys_log.info(f"Created file /{self.name} (id: {self.uuid})")
|
||||
|
||||
def set_original_state(self):
|
||||
"""Sets the original state."""
|
||||
for file in self.files.values():
|
||||
file.set_original_state()
|
||||
super().set_original_state()
|
||||
vals_to_include = {
|
||||
"scan_duration",
|
||||
"scan_countdown",
|
||||
"red_scan_duration",
|
||||
"red_scan_countdown",
|
||||
"restore_duration",
|
||||
"restore_countdown",
|
||||
}
|
||||
self._original_state.update(self.model_dump(include=vals_to_include))
|
||||
self._original_state["original_file_uuids"] = list(self.files.keys())
|
||||
|
||||
def reset_component_for_episode(self, episode: int):
|
||||
"""Reset the original state of the SimComponent."""
|
||||
# Move any 'original' file that have been deleted back to files
|
||||
original_file_uuids = self._original_state.pop("original_file_uuids")
|
||||
for uuid in original_file_uuids:
|
||||
if uuid in self.deleted_files:
|
||||
self.files[uuid] = self.deleted_files.pop(uuid)
|
||||
|
||||
# Clear any other deleted files that aren't original (have been created by agent)
|
||||
self.deleted_files.clear()
|
||||
|
||||
# Now clear all non-original files created by agent
|
||||
current_file_uuids = list(self.files.keys())
|
||||
for uuid in current_file_uuids:
|
||||
if uuid not in original_file_uuids:
|
||||
self.files.pop(uuid)
|
||||
|
||||
# Now reset all remaining files
|
||||
for file in self.files.values():
|
||||
file.reset_component_for_episode(episode)
|
||||
super().reset_component_for_episode(episode)
|
||||
|
||||
def _init_request_manager(self) -> RequestManager:
|
||||
rm = super()._init_request_manager()
|
||||
rm.add_request(
|
||||
|
||||
@@ -43,6 +43,20 @@ class Network(SimComponent):
|
||||
|
||||
self._nx_graph = MultiGraph()
|
||||
|
||||
def set_original_state(self):
|
||||
"""Sets the original state."""
|
||||
for node in self.nodes.values():
|
||||
node.set_original_state()
|
||||
for link in self.links.values():
|
||||
link.set_original_state()
|
||||
|
||||
def reset_component_for_episode(self, episode: int):
|
||||
"""Reset the original state of the SimComponent."""
|
||||
for node in self.nodes.values():
|
||||
node.reset_component_for_episode(episode)
|
||||
for link in self.links.values():
|
||||
link.reset_component_for_episode(episode)
|
||||
|
||||
def _init_request_manager(self) -> RequestManager:
|
||||
rm = super()._init_request_manager()
|
||||
self._node_request_manager = RequestManager()
|
||||
|
||||
@@ -121,6 +121,20 @@ class NIC(SimComponent):
|
||||
_LOGGER.error(msg)
|
||||
raise ValueError(msg)
|
||||
|
||||
self.set_original_state()
|
||||
|
||||
def set_original_state(self):
|
||||
"""Sets the original state."""
|
||||
vals_to_include = {"ip_address", "subnet_mask", "mac_address", "speed", "mtu", "wake_on_lan", "enabled"}
|
||||
self._original_state = self.model_dump(include=vals_to_include)
|
||||
|
||||
def reset_component_for_episode(self, episode: int):
|
||||
"""Reset the original state of the SimComponent."""
|
||||
super().reset_component_for_episode(episode)
|
||||
if episode and self.pcap:
|
||||
self.pcap.current_episode = episode
|
||||
self.pcap.setup_logger()
|
||||
|
||||
def describe_state(self) -> Dict:
|
||||
"""
|
||||
Produce a dictionary describing the current state of this object.
|
||||
@@ -308,6 +322,14 @@ class SwitchPort(SimComponent):
|
||||
kwargs["mac_address"] = generate_mac_address()
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.set_original_state()
|
||||
|
||||
def set_original_state(self):
|
||||
"""Sets the original state."""
|
||||
vals_to_include = {"port_num", "mac_address", "speed", "mtu", "enabled"}
|
||||
self._original_state = self.model_dump(include=vals_to_include)
|
||||
super().set_original_state()
|
||||
|
||||
def describe_state(self) -> Dict:
|
||||
"""
|
||||
Produce a dictionary describing the current state of this object.
|
||||
@@ -454,6 +476,14 @@ class Link(SimComponent):
|
||||
self.endpoint_b.connect_link(self)
|
||||
self.endpoint_up()
|
||||
|
||||
self.set_original_state()
|
||||
|
||||
def set_original_state(self):
|
||||
"""Sets the original state."""
|
||||
vals_to_include = {"bandwidth", "current_load"}
|
||||
self._original_state = self.model_dump(include=vals_to_include)
|
||||
super().set_original_state()
|
||||
|
||||
def describe_state(self) -> Dict:
|
||||
"""
|
||||
Produce a dictionary describing the current state of this object.
|
||||
@@ -536,15 +566,6 @@ class Link(SimComponent):
|
||||
return True
|
||||
return False
|
||||
|
||||
def reset_component_for_episode(self, episode: int):
|
||||
"""
|
||||
Link reset function.
|
||||
|
||||
Reset:
|
||||
- returns the link current_load to 0.
|
||||
"""
|
||||
self.current_load = 0
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"{self.endpoint_a}<-->{self.endpoint_b}"
|
||||
|
||||
@@ -584,6 +605,10 @@ class ARPCache:
|
||||
)
|
||||
print(table)
|
||||
|
||||
def clear(self):
|
||||
"""Clears the arp cache."""
|
||||
self.arp.clear()
|
||||
|
||||
def add_arp_cache_entry(self, ip_address: IPv4Address, mac_address: str, nic: NIC, override: bool = False):
|
||||
"""
|
||||
Add an ARP entry to the cache.
|
||||
@@ -756,6 +781,10 @@ class ICMP:
|
||||
self.arp: ARPCache = arp_cache
|
||||
self.request_replies = {}
|
||||
|
||||
def clear(self):
|
||||
"""Clears the ICMP request replies tracker."""
|
||||
self.request_replies.clear()
|
||||
|
||||
def process_icmp(self, frame: Frame, from_nic: NIC, is_reattempt: bool = False):
|
||||
"""
|
||||
Process an ICMP packet, including handling echo requests and replies.
|
||||
@@ -972,6 +1001,55 @@ class Node(SimComponent):
|
||||
self.arp.nics = self.nics
|
||||
self.session_manager.software_manager = self.software_manager
|
||||
self._install_system_software()
|
||||
self.set_original_state()
|
||||
|
||||
def set_original_state(self):
|
||||
"""Sets the original state."""
|
||||
for software in self.software_manager.software.values():
|
||||
software.set_original_state()
|
||||
|
||||
for nic in self.nics.values():
|
||||
nic.set_original_state()
|
||||
|
||||
vals_to_include = {
|
||||
"hostname",
|
||||
"default_gateway",
|
||||
"operating_state",
|
||||
"revealed_to_red",
|
||||
"start_up_duration",
|
||||
"start_up_countdown",
|
||||
"shut_down_duration",
|
||||
"shut_down_countdown",
|
||||
"is_resetting",
|
||||
"node_scan_duration",
|
||||
"node_scan_countdown",
|
||||
"red_scan_countdown",
|
||||
}
|
||||
self._original_state = self.model_dump(include=vals_to_include)
|
||||
|
||||
def reset_component_for_episode(self, episode: int):
|
||||
"""Reset the original state of the SimComponent."""
|
||||
# Reset ARP Cache
|
||||
self.arp.clear()
|
||||
|
||||
# Reset ICMP
|
||||
self.icmp.clear()
|
||||
|
||||
# Reset Session Manager
|
||||
self.session_manager.clear()
|
||||
|
||||
for software in self.software_manager.software.values():
|
||||
software.reset_component_for_episode(episode)
|
||||
|
||||
# Reset all Nics
|
||||
for nic in self.nics.values():
|
||||
nic.reset_component_for_episode(episode)
|
||||
|
||||
if episode and self.sys_log:
|
||||
self.sys_log.current_episode = episode
|
||||
self.sys_log.setup_logger()
|
||||
|
||||
super().reset_component_for_episode(episode)
|
||||
|
||||
def _init_request_manager(self) -> RequestManager:
|
||||
# TODO: I see that this code is really confusing and hard to read right now... I think some of these things will
|
||||
@@ -1005,9 +1083,6 @@ class Node(SimComponent):
|
||||
|
||||
return rm
|
||||
|
||||
def reset_component_for_episode(self, episode: int):
|
||||
self._init_request_manager()
|
||||
|
||||
def _install_system_software(self):
|
||||
"""Install System Software - software that is usually provided with the OS."""
|
||||
pass
|
||||
@@ -1425,99 +1500,3 @@ class Node(SimComponent):
|
||||
if isinstance(item, Service):
|
||||
return item.uuid in self.services
|
||||
return None
|
||||
|
||||
|
||||
class Switch(Node):
|
||||
"""A class representing a Layer 2 network switch."""
|
||||
|
||||
num_ports: int = 24
|
||||
"The number of ports on the switch."
|
||||
switch_ports: Dict[int, SwitchPort] = {}
|
||||
"The SwitchPorts on the switch."
|
||||
mac_address_table: Dict[str, SwitchPort] = {}
|
||||
"A MAC address table mapping destination MAC addresses to corresponding SwitchPorts."
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
if not self.switch_ports:
|
||||
self.switch_ports = {i: SwitchPort() for i in range(1, self.num_ports + 1)}
|
||||
for port_num, port in self.switch_ports.items():
|
||||
port._connected_node = self
|
||||
port.parent = self
|
||||
port.port_num = port_num
|
||||
|
||||
def show(self):
|
||||
"""Prints a table of the SwitchPorts on the Switch."""
|
||||
table = PrettyTable(["Port", "MAC Address", "Speed", "Status"])
|
||||
|
||||
for port_num, port in self.switch_ports.items():
|
||||
table.add_row([port_num, port.mac_address, port.speed, "Enabled" if port.enabled else "Disabled"])
|
||||
print(table)
|
||||
|
||||
def describe_state(self) -> Dict:
|
||||
"""
|
||||
Produce a dictionary describing the current state of this object.
|
||||
|
||||
Please see :py:meth:`primaite.simulator.core.SimComponent.describe_state` for a more detailed explanation.
|
||||
|
||||
:return: Current state of this object and child objects.
|
||||
:rtype: Dict
|
||||
"""
|
||||
return {
|
||||
"uuid": self.uuid,
|
||||
"num_ports": self.num_ports, # redundant?
|
||||
"ports": {port_num: port.describe_state() for port_num, port in self.switch_ports.items()},
|
||||
"mac_address_table": {mac: port for mac, port in self.mac_address_table.items()},
|
||||
}
|
||||
|
||||
def _add_mac_table_entry(self, mac_address: str, switch_port: SwitchPort):
|
||||
mac_table_port = self.mac_address_table.get(mac_address)
|
||||
if not mac_table_port:
|
||||
self.mac_address_table[mac_address] = switch_port
|
||||
self.sys_log.info(f"Added MAC table entry: Port {switch_port.port_num} -> {mac_address}")
|
||||
else:
|
||||
if mac_table_port != switch_port:
|
||||
self.mac_address_table.pop(mac_address)
|
||||
self.sys_log.info(f"Removed MAC table entry: Port {mac_table_port.port_num} -> {mac_address}")
|
||||
self._add_mac_table_entry(mac_address, switch_port)
|
||||
|
||||
def forward_frame(self, frame: Frame, incoming_port: SwitchPort):
|
||||
"""
|
||||
Forward a frame to the appropriate port based on the destination MAC address.
|
||||
|
||||
:param frame: The Frame to be forwarded.
|
||||
:param incoming_port: The port number from which the frame was received.
|
||||
"""
|
||||
src_mac = frame.ethernet.src_mac_addr
|
||||
dst_mac = frame.ethernet.dst_mac_addr
|
||||
self._add_mac_table_entry(src_mac, incoming_port)
|
||||
|
||||
outgoing_port = self.mac_address_table.get(dst_mac)
|
||||
if outgoing_port or dst_mac != "ff:ff:ff:ff:ff:ff":
|
||||
outgoing_port.send_frame(frame)
|
||||
else:
|
||||
# If the destination MAC is not in the table, flood to all ports except incoming
|
||||
for port in self.switch_ports.values():
|
||||
if port != incoming_port:
|
||||
port.send_frame(frame)
|
||||
|
||||
def disconnect_link_from_port(self, link: Link, port_number: int):
|
||||
"""
|
||||
Disconnect a given link from the specified port number on the switch.
|
||||
|
||||
:param link: The Link object to be disconnected.
|
||||
:param port_number: The port number on the switch from where the link should be disconnected.
|
||||
:raise NetworkError: When an invalid port number is provided or the link does not match the connection.
|
||||
"""
|
||||
port = self.switch_ports.get(port_number)
|
||||
if port is None:
|
||||
msg = f"Invalid port number {port_number} on the switch"
|
||||
_LOGGER.error(msg)
|
||||
raise NetworkError(msg)
|
||||
|
||||
if port._connected_link != link:
|
||||
msg = f"The link does not match the connection at port number {port_number}"
|
||||
_LOGGER.error(msg)
|
||||
raise NetworkError(msg)
|
||||
|
||||
port.disconnect_link()
|
||||
|
||||
@@ -52,6 +52,11 @@ class ACLRule(SimComponent):
|
||||
rule_strings.append(f"{key}={value}")
|
||||
return ", ".join(rule_strings)
|
||||
|
||||
def set_original_state(self):
|
||||
"""Sets the original state."""
|
||||
vals_to_keep = {"action", "protocol", "src_ip_address", "src_port", "dst_ip_address", "dst_port"}
|
||||
self._original_state = self.model_dump(include=vals_to_keep, exclude_none=True)
|
||||
|
||||
def describe_state(self) -> Dict:
|
||||
"""
|
||||
Describes the current state of the ACLRule.
|
||||
@@ -93,6 +98,18 @@ class AccessControlList(SimComponent):
|
||||
|
||||
super().__init__(**kwargs)
|
||||
self._acl = [None] * (self.max_acl_rules - 1)
|
||||
self.set_original_state()
|
||||
|
||||
def set_original_state(self):
|
||||
"""Sets the original state."""
|
||||
self.implicit_rule.set_original_state()
|
||||
vals_to_keep = {"implicit_action", "max_acl_rules", "acl"}
|
||||
self._original_state = self.model_dump(include=vals_to_keep, exclude_none=True)
|
||||
|
||||
def reset_component_for_episode(self, episode: int):
|
||||
"""Reset the original state of the SimComponent."""
|
||||
self.implicit_rule.reset_component_for_episode(episode)
|
||||
super().reset_component_for_episode(episode)
|
||||
|
||||
def _init_request_manager(self) -> RequestManager:
|
||||
rm = super()._init_request_manager()
|
||||
@@ -638,6 +655,20 @@ class Router(Node):
|
||||
self.arp.nics = self.nics
|
||||
self.icmp.arp = self.arp
|
||||
|
||||
self.set_original_state()
|
||||
|
||||
def set_original_state(self):
|
||||
"""Sets the original state."""
|
||||
self.acl.set_original_state()
|
||||
vals_to_include = {"num_ports", "route_table"}
|
||||
self._original_state = self.model_dump(include=vals_to_include)
|
||||
|
||||
def reset_component_for_episode(self, episode: int):
|
||||
"""Reset the original state of the SimComponent."""
|
||||
self.arp.clear()
|
||||
self.acl.reset_component_for_episode(episode)
|
||||
super().reset_component_for_episode(episode)
|
||||
|
||||
def _init_request_manager(self) -> RequestManager:
|
||||
rm = super()._init_request_manager()
|
||||
rm.add_request("acl", RequestType(func=self.acl._request_manager))
|
||||
|
||||
@@ -9,7 +9,7 @@ class Simulation(SimComponent):
|
||||
"""Top-level simulation object which holds a reference to all other parts of the simulation."""
|
||||
|
||||
network: Network
|
||||
domain: DomainController
|
||||
# domain: DomainController
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
"""Initialise the Simulation."""
|
||||
@@ -21,6 +21,14 @@ class Simulation(SimComponent):
|
||||
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def set_original_state(self):
|
||||
"""Sets the original state."""
|
||||
self.network.set_original_state()
|
||||
|
||||
def reset_component_for_episode(self, episode: int):
|
||||
"""Reset the original state of the SimComponent."""
|
||||
self.network.reset_component_for_episode(episode)
|
||||
|
||||
def _init_request_manager(self) -> RequestManager:
|
||||
rm = super()._init_request_manager()
|
||||
# pass through network requests to the network objects
|
||||
|
||||
@@ -38,6 +38,12 @@ class Application(IOSoftware):
|
||||
self.health_state_visible = SoftwareHealthState.UNUSED
|
||||
self.health_state_actual = SoftwareHealthState.UNUSED
|
||||
|
||||
def set_original_state(self):
|
||||
"""Sets the original state."""
|
||||
super().set_original_state()
|
||||
vals_to_include = {"operating_state", "execution_control_status", "num_executions", "groups"}
|
||||
self._original_state.update(self.model_dump(include=vals_to_include))
|
||||
|
||||
@abstractmethod
|
||||
def describe_state(self) -> Dict:
|
||||
"""
|
||||
@@ -82,15 +88,6 @@ class Application(IOSoftware):
|
||||
self.sys_log.info(f"Installing Application {self.name}")
|
||||
self.operating_state = ApplicationOperatingState.INSTALLING
|
||||
|
||||
def reset_component_for_episode(self, episode: int):
|
||||
"""
|
||||
Resets the Application component for a new episode.
|
||||
|
||||
This method ensures the Application is ready for a new episode, including resetting any
|
||||
stateful properties or statistics, and clearing any message queues.
|
||||
"""
|
||||
pass
|
||||
|
||||
def receive(self, payload: Any, session_id: str, **kwargs) -> bool:
|
||||
"""
|
||||
Receives a payload from the SessionManager.
|
||||
|
||||
@@ -31,6 +31,13 @@ class DatabaseClient(Application):
|
||||
kwargs["port"] = Port.POSTGRES_SERVER
|
||||
kwargs["protocol"] = IPProtocol.TCP
|
||||
super().__init__(**kwargs)
|
||||
self.set_original_state()
|
||||
|
||||
def set_original_state(self):
|
||||
"""Sets the original state."""
|
||||
super().set_original_state()
|
||||
vals_to_include = {"server_ip_address", "server_password", "connected"}
|
||||
self._original_state.update(self.model_dump(include=vals_to_include))
|
||||
|
||||
def describe_state(self) -> Dict:
|
||||
"""
|
||||
|
||||
@@ -33,8 +33,15 @@ class WebBrowser(Application):
|
||||
kwargs["port"] = Port.HTTP
|
||||
|
||||
super().__init__(**kwargs)
|
||||
self.set_original_state()
|
||||
self.run()
|
||||
|
||||
def set_original_state(self):
|
||||
"""Sets the original state."""
|
||||
super().set_original_state()
|
||||
vals_to_include = {"target_url", "domain_name_ip_address", "latest_response"}
|
||||
self._original_state.update(self.model_dump(include=vals_to_include))
|
||||
|
||||
def _init_request_manager(self) -> RequestManager:
|
||||
rm = super()._init_request_manager()
|
||||
rm.add_request(
|
||||
@@ -43,13 +50,6 @@ class WebBrowser(Application):
|
||||
|
||||
return rm
|
||||
|
||||
def do_this(self):
|
||||
self._init_request_manager()
|
||||
print(f"Resetting WebBrowser for episode")
|
||||
|
||||
def reset_component_for_episode(self, episode: int):
|
||||
pass
|
||||
|
||||
def describe_state(self) -> Dict:
|
||||
"""
|
||||
Produce a dictionary describing the current state of the WebBrowser.
|
||||
@@ -60,14 +60,7 @@ class WebBrowser(Application):
|
||||
state["last_response_status_code"] = self.latest_response.status_code if self.latest_response else None
|
||||
|
||||
def reset_component_for_episode(self, episode: int):
|
||||
"""
|
||||
Resets the Application component for a new episode.
|
||||
|
||||
This method ensures the Application is ready for a new episode, including resetting any
|
||||
stateful properties or statistics, and clearing any message queues.
|
||||
"""
|
||||
self.domain_name_ip_address = None
|
||||
self.latest_response = None
|
||||
"""Reset the original state of the SimComponent."""
|
||||
|
||||
def get_webpage(self) -> bool:
|
||||
"""
|
||||
|
||||
@@ -34,9 +34,12 @@ class PacketCapture:
|
||||
"The IP address associated with the PCAP logs."
|
||||
self.switch_port_number = switch_port_number
|
||||
"The SwitchPort number."
|
||||
self._setup_logger()
|
||||
|
||||
def _setup_logger(self):
|
||||
self.current_episode: int = 1
|
||||
|
||||
self.setup_logger()
|
||||
|
||||
def setup_logger(self):
|
||||
"""Set up the logger configuration."""
|
||||
log_path = self._get_log_path()
|
||||
|
||||
@@ -75,7 +78,7 @@ class PacketCapture:
|
||||
|
||||
def _get_log_path(self) -> Path:
|
||||
"""Get the path for the log file."""
|
||||
root = SIM_OUTPUT.path / self.hostname
|
||||
root = SIM_OUTPUT.path / f"episode_{self.current_episode}" / self.hostname
|
||||
root.mkdir(exist_ok=True, parents=True)
|
||||
return root / f"{self._logger_name}.log"
|
||||
|
||||
|
||||
@@ -93,6 +93,11 @@ class SessionManager:
|
||||
"""
|
||||
pass
|
||||
|
||||
def clear(self):
|
||||
"""Clears the sessions."""
|
||||
self.sessions_by_key.clear()
|
||||
self.sessions_by_uuid.clear()
|
||||
|
||||
@staticmethod
|
||||
def _get_session_key(
|
||||
frame: Frame, inbound_frame: bool = True
|
||||
|
||||
@@ -31,9 +31,10 @@ class SysLog:
|
||||
:param hostname: The hostname associated with the system logs being recorded.
|
||||
"""
|
||||
self.hostname = hostname
|
||||
self._setup_logger()
|
||||
self.current_episode: int = 1
|
||||
self.setup_logger()
|
||||
|
||||
def _setup_logger(self):
|
||||
def setup_logger(self):
|
||||
"""
|
||||
Configures the logger for this SysLog instance.
|
||||
|
||||
@@ -80,7 +81,7 @@ class SysLog:
|
||||
|
||||
:return: Path object representing the location of the log file.
|
||||
"""
|
||||
root = SIM_OUTPUT.path / self.hostname
|
||||
root = SIM_OUTPUT.path / f"episode_{self.current_episode}" / self.hostname
|
||||
root.mkdir(exist_ok=True, parents=True)
|
||||
return root / f"{self.hostname}_sys.log"
|
||||
|
||||
|
||||
@@ -24,6 +24,12 @@ class Process(Software):
|
||||
operating_state: ProcessOperatingState
|
||||
"The current operating state of the Process."
|
||||
|
||||
def set_original_state(self):
|
||||
"""Sets the original state."""
|
||||
super().set_original_state()
|
||||
vals_to_include = {"operating_state"}
|
||||
self._original_state.update(self.model_dump(include=vals_to_include))
|
||||
|
||||
@abstractmethod
|
||||
def describe_state(self) -> Dict:
|
||||
"""
|
||||
|
||||
@@ -38,6 +38,23 @@ class DatabaseService(Service):
|
||||
self._db_file: File
|
||||
self._create_db_file()
|
||||
|
||||
def set_original_state(self):
|
||||
"""Sets the original state."""
|
||||
super().set_original_state()
|
||||
vals_to_include = {
|
||||
"password",
|
||||
"connections",
|
||||
"backup_server",
|
||||
"latest_backup_directory",
|
||||
"latest_backup_file_name",
|
||||
}
|
||||
self._original_state.update(self.model_dump(include=vals_to_include))
|
||||
|
||||
def reset_component_for_episode(self, episode: int):
|
||||
"""Reset the original state of the SimComponent."""
|
||||
self.connections.clear()
|
||||
super().reset_component_for_episode(episode)
|
||||
|
||||
def configure_backup(self, backup_server: IPv4Address):
|
||||
"""
|
||||
Set up the database backup.
|
||||
|
||||
@@ -29,6 +29,17 @@ class DNSClient(Service):
|
||||
super().__init__(**kwargs)
|
||||
self.start()
|
||||
|
||||
def set_original_state(self):
|
||||
"""Sets the original state."""
|
||||
super().set_original_state()
|
||||
vals_to_include = {"dns_server"}
|
||||
self._original_state.update(self.model_dump(include=vals_to_include))
|
||||
|
||||
def reset_component_for_episode(self, episode: int):
|
||||
"""Reset the original state of the SimComponent."""
|
||||
self.dns_cache.clear()
|
||||
super().reset_component_for_episode(episode)
|
||||
|
||||
def describe_state(self) -> Dict:
|
||||
"""
|
||||
Describes the current state of the software.
|
||||
@@ -42,15 +53,6 @@ class DNSClient(Service):
|
||||
state = super().describe_state()
|
||||
return state
|
||||
|
||||
def reset_component_for_episode(self, episode: int):
|
||||
"""
|
||||
Resets the Service component for a new episode.
|
||||
|
||||
This method ensures the Service is ready for a new episode, including resetting any
|
||||
stateful properties or statistics, and clearing any message queues.
|
||||
"""
|
||||
pass
|
||||
|
||||
def add_domain_to_cache(self, domain_name: str, ip_address: IPv4Address):
|
||||
"""
|
||||
Adds a domain name to the DNS Client cache.
|
||||
|
||||
@@ -28,6 +28,11 @@ class DNSServer(Service):
|
||||
super().__init__(**kwargs)
|
||||
self.start()
|
||||
|
||||
def reset_component_for_episode(self, episode: int):
|
||||
"""Reset the original state of the SimComponent."""
|
||||
self.dns_table.clear()
|
||||
super().reset_component_for_episode(episode)
|
||||
|
||||
def describe_state(self) -> Dict:
|
||||
"""
|
||||
Describes the current state of the software.
|
||||
@@ -62,15 +67,6 @@ class DNSServer(Service):
|
||||
"""
|
||||
self.dns_table[domain_name] = domain_ip_address
|
||||
|
||||
def reset_component_for_episode(self, episode: int):
|
||||
"""
|
||||
Resets the Service component for a new episode.
|
||||
|
||||
This method ensures the Service is ready for a new episode, including resetting any
|
||||
stateful properties or statistics, and clearing any message queues.
|
||||
"""
|
||||
pass
|
||||
|
||||
def receive(
|
||||
self,
|
||||
payload: Any,
|
||||
|
||||
@@ -46,6 +46,12 @@ class Service(IOSoftware):
|
||||
self.health_state_visible = SoftwareHealthState.UNUSED
|
||||
self.health_state_actual = SoftwareHealthState.UNUSED
|
||||
|
||||
def set_original_state(self):
|
||||
"""Sets the original state."""
|
||||
super().set_original_state()
|
||||
vals_to_include = {"operating_state", "restart_duration", "restart_countdown"}
|
||||
self._original_state.update(self.model_dump(include=vals_to_include))
|
||||
|
||||
def _init_request_manager(self) -> RequestManager:
|
||||
rm = super()._init_request_manager()
|
||||
rm.add_request("scan", RequestType(func=lambda request, context: self.scan()))
|
||||
@@ -73,15 +79,6 @@ class Service(IOSoftware):
|
||||
state["health_state_visible"] = self.health_state_visible
|
||||
return state
|
||||
|
||||
def reset_component_for_episode(self, episode: int):
|
||||
"""
|
||||
Resets the Service component for a new episode.
|
||||
|
||||
This method ensures the Service is ready for a new episode, including resetting any
|
||||
stateful properties or statistics, and clearing any message queues.
|
||||
"""
|
||||
pass
|
||||
|
||||
def stop(self) -> None:
|
||||
"""Stop the service."""
|
||||
if self.operating_state in [ServiceOperatingState.RUNNING, ServiceOperatingState.PAUSED]:
|
||||
|
||||
@@ -19,8 +19,14 @@ class WebServer(Service):
|
||||
|
||||
_last_response_status_code: Optional[HttpStatusCode] = None
|
||||
|
||||
def reset_component_for_episode(self, episode: int):
|
||||
"""Reset the original state of the SimComponent."""
|
||||
self._last_response_status_code = None
|
||||
super().reset_component_for_episode(episode)
|
||||
|
||||
@property
|
||||
def last_response_status_code(self) -> HttpStatusCode:
|
||||
"""The latest http response code."""
|
||||
return self._last_response_status_code
|
||||
|
||||
@last_response_status_code.setter
|
||||
@@ -41,14 +47,6 @@ class WebServer(Service):
|
||||
state["last_response_status_code"] = (
|
||||
self.last_response_status_code.value if isinstance(self.last_response_status_code, HttpStatusCode) else None
|
||||
)
|
||||
|
||||
print(
|
||||
f""
|
||||
f"Printing state from Webserver describe func: "
|
||||
f"val={state['last_response_status_code']}, "
|
||||
f"type={type(state['last_response_status_code'])}, "
|
||||
f"Service obj ID={id(self)}"
|
||||
)
|
||||
return state
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
@@ -102,13 +100,6 @@ class WebServer(Service):
|
||||
# return true if response is OK
|
||||
self.last_response_status_code = response.status_code
|
||||
|
||||
print(
|
||||
f""
|
||||
f"Printing state from Webserver http request func: "
|
||||
f"val={self.last_response_status_code}, "
|
||||
f"type={type(self.last_response_status_code)}, "
|
||||
f"Service obj ID={id(self)}"
|
||||
)
|
||||
return response.status_code == HttpStatusCode.OK
|
||||
|
||||
def _handle_get_request(self, payload: HttpRequestPacket) -> HttpResponsePacket:
|
||||
|
||||
@@ -89,6 +89,19 @@ class Software(SimComponent):
|
||||
folder: Optional[Folder] = None
|
||||
"The folder on the file system the Software uses."
|
||||
|
||||
def set_original_state(self):
|
||||
"""Sets the original state."""
|
||||
vals_to_include = {
|
||||
"name",
|
||||
"health_state_actual",
|
||||
"health_state_visible",
|
||||
"criticality",
|
||||
"patching_count",
|
||||
"scanning_count",
|
||||
"revealed_to_red",
|
||||
}
|
||||
self._original_state = self.model_dump(include=vals_to_include)
|
||||
|
||||
def _init_request_manager(self) -> RequestManager:
|
||||
rm = super()._init_request_manager()
|
||||
rm.add_request(
|
||||
@@ -131,16 +144,6 @@ class Software(SimComponent):
|
||||
)
|
||||
return state
|
||||
|
||||
def reset_component_for_episode(self, episode: int):
|
||||
"""
|
||||
Resets the software component for a new episode.
|
||||
|
||||
This method should ensure the software is ready for a new episode, including resetting any
|
||||
stateful properties or statistics, and clearing any message queues. The specifics of what constitutes a
|
||||
"reset" should be implemented in subclasses.
|
||||
"""
|
||||
pass
|
||||
|
||||
def set_health_state(self, health_state: SoftwareHealthState) -> None:
|
||||
"""
|
||||
Assign a new health state to this software.
|
||||
@@ -203,6 +206,12 @@ class IOSoftware(Software):
|
||||
port: Port
|
||||
"The port to which the software is connected."
|
||||
|
||||
def set_original_state(self):
|
||||
"""Sets the original state."""
|
||||
super().set_original_state()
|
||||
vals_to_include = {"installing_count", "max_sessions", "tcp", "udp", "port"}
|
||||
self._original_state.update(self.model_dump(include=vals_to_include))
|
||||
|
||||
@abstractmethod
|
||||
def describe_state(self) -> Dict:
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user