#1859 - First pass at an implementation of the full reset method. Will now start testing...

This commit is contained in:
Chris McCarthy
2023-11-27 23:01:56 +00:00
parent ae5046b8fb
commit 58e9033a4c
26 changed files with 360 additions and 240 deletions

View File

@@ -15,7 +15,6 @@ from typing import Dict, List, Optional, Tuple, TYPE_CHECKING
from gymnasium import spaces
from primaite import getLogger
from primaite.simulator.sim_container import Simulation
_LOGGER = getLogger(__name__)

View File

@@ -25,7 +25,6 @@ the structure:
service_ref: web_server_database_client
```
"""
import json
from abc import abstractmethod
from typing import Dict, List, Tuple, Type, TYPE_CHECKING

View File

@@ -1,5 +1,4 @@
"""PrimAITE game - Encapsulates the simulation and agents."""
from copy import deepcopy
from ipaddress import IPv4Address
from typing import Dict, List
@@ -11,7 +10,7 @@ from primaite.game.agent.data_manipulation_bot import DataManipulationAgent
from primaite.game.agent.interface import AbstractAgent, AgentSettings, ProxyAgent, RandomAgent
from primaite.game.agent.observations import ObservationManager
from primaite.game.agent.rewards import RewardFunction
from primaite.simulator.network.hardware.base import Link, NIC, Node, NodeOperatingState
from primaite.simulator.network.hardware.base import NIC, NodeOperatingState
from primaite.simulator.network.hardware.nodes.computer import Computer
from primaite.simulator.network.hardware.nodes.router import ACLAction, Router
from primaite.simulator.network.hardware.nodes.server import Server
@@ -19,7 +18,6 @@ from primaite.simulator.network.hardware.nodes.switch import Switch
from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.sim_container import Simulation
from primaite.simulator.system.applications.application import Application
from primaite.simulator.system.applications.database_client import DatabaseClient
from primaite.simulator.system.applications.web_browser import WebBrowser
from primaite.simulator.system.services.database.database_service import DatabaseService
@@ -28,7 +26,6 @@ from primaite.simulator.system.services.dns.dns_server import DNSServer
from primaite.simulator.system.services.ftp.ftp_client import FTPClient
from primaite.simulator.system.services.ftp.ftp_server import FTPServer
from primaite.simulator.system.services.red_services.data_manipulation_bot import DataManipulationBot
from primaite.simulator.system.services.service import Service
from primaite.simulator.system.services.web_server.web_server import WebServer
_LOGGER = getLogger(__name__)
@@ -59,10 +56,6 @@ class PrimaiteGame:
"""Initialise a PrimaiteGame object."""
self.simulation: Simulation = Simulation()
"""Simulation object with which the agents will interact."""
print(f"Hello, welcome to PrimaiteGame. This is the ID of the ORIGINAL simulation {id(self.simulation)}")
self._simulation_initial_state = None
"""The Simulation original state (deepcopy of the original Simulation)."""
self.agents: List[AbstractAgent] = []
"""List of agents."""
@@ -161,34 +154,7 @@ class PrimaiteGame:
self.episode_counter += 1
self.step_counter = 0
_LOGGER.debug(f"Resetting primaite game, episode = {self.episode_counter}")
self.simulation = deepcopy(self._simulation_initial_state)
self._reset_components_for_episode()
print("Reset")
def _reset_components_for_episode(self):
print("Performing full reset for episode")
for node in self.simulation.network.nodes.values():
print(f"Resetting Node: {node.hostname}")
node.reset_component_for_episode(self.episode_counter)
# reset Node NIC
# Reset Node Services
# Reset Node Applications
print(f"Resetting Software...")
for application in node.software_manager.software.values():
print(f"Resetting {application.name}")
if isinstance(application, WebBrowser):
application.do_this()
# Reset Node FileSystem
# Reset Node FileSystemFolder's
# Reset Node FileSystemFile's
# Reset Router
# Reset Links
self.simulation.reset_component_for_episode(episode=self.episode_counter)
def close(self) -> None:
"""Close the game, this will close the simulation."""
@@ -452,8 +418,6 @@ class PrimaiteGame:
else:
print("agent type not found")
game._simulation_initial_state = deepcopy(game.simulation) # noqa
web_server = game.simulation.network.get_node_by_hostname("web_server").software_manager.software["WebServer"]
print(f"And this is the ID of the original WebServer {id(web_server)}")
game.simulation.set_original_state()
return game

View File

@@ -153,6 +153,8 @@ class SimComponent(BaseModel):
uuid: str
"""The component UUID."""
_original_state: Dict = {}
def __init__(self, **kwargs):
if not kwargs.get("uuid"):
kwargs["uuid"] = str(uuid4())
@@ -160,6 +162,16 @@ class SimComponent(BaseModel):
self._request_manager: RequestManager = self._init_request_manager()
self._parent: Optional["SimComponent"] = None
# @abstractmethod
def set_original_state(self):
"""Sets the original state."""
pass
def reset_component_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
for key, value in self._original_state.items():
self.__setattr__(key, value)
def _init_request_manager(self) -> RequestManager:
"""
Initialise the request manager for this component.
@@ -227,14 +239,6 @@ class SimComponent(BaseModel):
"""
pass
def reset_component_for_episode(self, episode: int):
"""
Reset this component to its original state for a new episode.
Override this method with anything that needs to happen within the component for it to be reset.
"""
pass
@property
def parent(self) -> "SimComponent":
"""Reference to the parent object which manages this object.

View File

@@ -42,6 +42,19 @@ class Account(SimComponent):
"Account Type, currently this can be service account (used by apps) or user account."
enabled: bool = True
def set_original_state(self):
"""Sets the original state."""
vals_to_include = {
"num_logons",
"num_logoffs",
"num_group_changes",
"username",
"password",
"account_type",
"enabled",
}
self._original_state = self.model_dump(include=vals_to_include)
def describe_state(self) -> Dict:
"""
Produce a dictionary describing the current state of this object.

View File

@@ -73,6 +73,18 @@ class File(FileSystemItemABC):
self.sys_log.info(f"Created file /{self.path} (id: {self.uuid})")
self.set_original_state()
def set_original_state(self):
"""Sets the original state."""
super().set_original_state()
vals_to_include = {"folder_id", "folder_name", "file_type", "sim_size", "real", "sim_path", "sim_root"}
self._original_state.update(self.model_dump(include=vals_to_include))
def reset_component_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
super().reset_component_for_episode(episode)
@property
def path(self) -> str:
"""

View File

@@ -35,6 +35,36 @@ class FileSystem(SimComponent):
if not self.folders:
self.create_folder("root")
def set_original_state(self):
"""Sets the original state."""
for folder in self.folders.values():
folder.set_original_state()
super().set_original_state()
# Capture a list of all 'original' file uuids
self._original_state["original_folder_uuids"] = list(self.folders.keys())
def reset_component_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
# Move any 'original' folder that have been deleted back to folders
original_folder_uuids = self._original_state.pop("original_folder_uuids")
for uuid in original_folder_uuids:
if uuid in self.deleted_folders:
self.folders[uuid] = self.deleted_folders.pop(uuid)
# Clear any other deleted folders that aren't original (have been created by agent)
self.deleted_folders.clear()
# Now clear all non-original folders created by agent
current_folder_uuids = list(self.folders.keys())
for uuid in current_folder_uuids:
if uuid not in original_folder_uuids:
self.folders.pop(uuid)
# Now reset all remaining folders
for folder in self.folders.values():
folder.reset_component_for_episode(episode)
super().reset_component_for_episode(episode)
def _init_request_manager(self) -> RequestManager:
rm = super()._init_request_manager()

View File

@@ -85,6 +85,11 @@ class FileSystemItemABC(SimComponent):
deleted: bool = False
"If true, the FileSystemItem was deleted."
def set_original_state(self):
"""Sets the original state."""
vals_to_keep = {"name", "health_status", "visible_health_status", "previous_hash", "revealed_to_red"}
self._original_state = self.model_dump(include=vals_to_keep)
def describe_state(self) -> Dict:
"""
Produce a dictionary describing the current state of this object.

View File

@@ -51,6 +51,44 @@ class Folder(FileSystemItemABC):
self.sys_log.info(f"Created file /{self.name} (id: {self.uuid})")
def set_original_state(self):
"""Sets the original state."""
for file in self.files.values():
file.set_original_state()
super().set_original_state()
vals_to_include = {
"scan_duration",
"scan_countdown",
"red_scan_duration",
"red_scan_countdown",
"restore_duration",
"restore_countdown",
}
self._original_state.update(self.model_dump(include=vals_to_include))
self._original_state["original_file_uuids"] = list(self.files.keys())
def reset_component_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
# Move any 'original' file that have been deleted back to files
original_file_uuids = self._original_state.pop("original_file_uuids")
for uuid in original_file_uuids:
if uuid in self.deleted_files:
self.files[uuid] = self.deleted_files.pop(uuid)
# Clear any other deleted files that aren't original (have been created by agent)
self.deleted_files.clear()
# Now clear all non-original files created by agent
current_file_uuids = list(self.files.keys())
for uuid in current_file_uuids:
if uuid not in original_file_uuids:
self.files.pop(uuid)
# Now reset all remaining files
for file in self.files.values():
file.reset_component_for_episode(episode)
super().reset_component_for_episode(episode)
def _init_request_manager(self) -> RequestManager:
rm = super()._init_request_manager()
rm.add_request(

View File

@@ -43,6 +43,20 @@ class Network(SimComponent):
self._nx_graph = MultiGraph()
def set_original_state(self):
"""Sets the original state."""
for node in self.nodes.values():
node.set_original_state()
for link in self.links.values():
link.set_original_state()
def reset_component_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
for node in self.nodes.values():
node.reset_component_for_episode(episode)
for link in self.links.values():
link.reset_component_for_episode(episode)
def _init_request_manager(self) -> RequestManager:
rm = super()._init_request_manager()
self._node_request_manager = RequestManager()

View File

@@ -121,6 +121,20 @@ class NIC(SimComponent):
_LOGGER.error(msg)
raise ValueError(msg)
self.set_original_state()
def set_original_state(self):
"""Sets the original state."""
vals_to_include = {"ip_address", "subnet_mask", "mac_address", "speed", "mtu", "wake_on_lan", "enabled"}
self._original_state = self.model_dump(include=vals_to_include)
def reset_component_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
super().reset_component_for_episode(episode)
if episode and self.pcap:
self.pcap.current_episode = episode
self.pcap.setup_logger()
def describe_state(self) -> Dict:
"""
Produce a dictionary describing the current state of this object.
@@ -308,6 +322,14 @@ class SwitchPort(SimComponent):
kwargs["mac_address"] = generate_mac_address()
super().__init__(**kwargs)
self.set_original_state()
def set_original_state(self):
"""Sets the original state."""
vals_to_include = {"port_num", "mac_address", "speed", "mtu", "enabled"}
self._original_state = self.model_dump(include=vals_to_include)
super().set_original_state()
def describe_state(self) -> Dict:
"""
Produce a dictionary describing the current state of this object.
@@ -454,6 +476,14 @@ class Link(SimComponent):
self.endpoint_b.connect_link(self)
self.endpoint_up()
self.set_original_state()
def set_original_state(self):
"""Sets the original state."""
vals_to_include = {"bandwidth", "current_load"}
self._original_state = self.model_dump(include=vals_to_include)
super().set_original_state()
def describe_state(self) -> Dict:
"""
Produce a dictionary describing the current state of this object.
@@ -536,15 +566,6 @@ class Link(SimComponent):
return True
return False
def reset_component_for_episode(self, episode: int):
"""
Link reset function.
Reset:
- returns the link current_load to 0.
"""
self.current_load = 0
def __str__(self) -> str:
return f"{self.endpoint_a}<-->{self.endpoint_b}"
@@ -584,6 +605,10 @@ class ARPCache:
)
print(table)
def clear(self):
"""Clears the arp cache."""
self.arp.clear()
def add_arp_cache_entry(self, ip_address: IPv4Address, mac_address: str, nic: NIC, override: bool = False):
"""
Add an ARP entry to the cache.
@@ -756,6 +781,10 @@ class ICMP:
self.arp: ARPCache = arp_cache
self.request_replies = {}
def clear(self):
"""Clears the ICMP request replies tracker."""
self.request_replies.clear()
def process_icmp(self, frame: Frame, from_nic: NIC, is_reattempt: bool = False):
"""
Process an ICMP packet, including handling echo requests and replies.
@@ -972,6 +1001,55 @@ class Node(SimComponent):
self.arp.nics = self.nics
self.session_manager.software_manager = self.software_manager
self._install_system_software()
self.set_original_state()
def set_original_state(self):
"""Sets the original state."""
for software in self.software_manager.software.values():
software.set_original_state()
for nic in self.nics.values():
nic.set_original_state()
vals_to_include = {
"hostname",
"default_gateway",
"operating_state",
"revealed_to_red",
"start_up_duration",
"start_up_countdown",
"shut_down_duration",
"shut_down_countdown",
"is_resetting",
"node_scan_duration",
"node_scan_countdown",
"red_scan_countdown",
}
self._original_state = self.model_dump(include=vals_to_include)
def reset_component_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
# Reset ARP Cache
self.arp.clear()
# Reset ICMP
self.icmp.clear()
# Reset Session Manager
self.session_manager.clear()
for software in self.software_manager.software.values():
software.reset_component_for_episode(episode)
# Reset all Nics
for nic in self.nics.values():
nic.reset_component_for_episode(episode)
if episode and self.sys_log:
self.sys_log.current_episode = episode
self.sys_log.setup_logger()
super().reset_component_for_episode(episode)
def _init_request_manager(self) -> RequestManager:
# TODO: I see that this code is really confusing and hard to read right now... I think some of these things will
@@ -1005,9 +1083,6 @@ class Node(SimComponent):
return rm
def reset_component_for_episode(self, episode: int):
self._init_request_manager()
def _install_system_software(self):
"""Install System Software - software that is usually provided with the OS."""
pass
@@ -1425,99 +1500,3 @@ class Node(SimComponent):
if isinstance(item, Service):
return item.uuid in self.services
return None
class Switch(Node):
"""A class representing a Layer 2 network switch."""
num_ports: int = 24
"The number of ports on the switch."
switch_ports: Dict[int, SwitchPort] = {}
"The SwitchPorts on the switch."
mac_address_table: Dict[str, SwitchPort] = {}
"A MAC address table mapping destination MAC addresses to corresponding SwitchPorts."
def __init__(self, **kwargs):
super().__init__(**kwargs)
if not self.switch_ports:
self.switch_ports = {i: SwitchPort() for i in range(1, self.num_ports + 1)}
for port_num, port in self.switch_ports.items():
port._connected_node = self
port.parent = self
port.port_num = port_num
def show(self):
"""Prints a table of the SwitchPorts on the Switch."""
table = PrettyTable(["Port", "MAC Address", "Speed", "Status"])
for port_num, port in self.switch_ports.items():
table.add_row([port_num, port.mac_address, port.speed, "Enabled" if port.enabled else "Disabled"])
print(table)
def describe_state(self) -> Dict:
"""
Produce a dictionary describing the current state of this object.
Please see :py:meth:`primaite.simulator.core.SimComponent.describe_state` for a more detailed explanation.
:return: Current state of this object and child objects.
:rtype: Dict
"""
return {
"uuid": self.uuid,
"num_ports": self.num_ports, # redundant?
"ports": {port_num: port.describe_state() for port_num, port in self.switch_ports.items()},
"mac_address_table": {mac: port for mac, port in self.mac_address_table.items()},
}
def _add_mac_table_entry(self, mac_address: str, switch_port: SwitchPort):
mac_table_port = self.mac_address_table.get(mac_address)
if not mac_table_port:
self.mac_address_table[mac_address] = switch_port
self.sys_log.info(f"Added MAC table entry: Port {switch_port.port_num} -> {mac_address}")
else:
if mac_table_port != switch_port:
self.mac_address_table.pop(mac_address)
self.sys_log.info(f"Removed MAC table entry: Port {mac_table_port.port_num} -> {mac_address}")
self._add_mac_table_entry(mac_address, switch_port)
def forward_frame(self, frame: Frame, incoming_port: SwitchPort):
"""
Forward a frame to the appropriate port based on the destination MAC address.
:param frame: The Frame to be forwarded.
:param incoming_port: The port number from which the frame was received.
"""
src_mac = frame.ethernet.src_mac_addr
dst_mac = frame.ethernet.dst_mac_addr
self._add_mac_table_entry(src_mac, incoming_port)
outgoing_port = self.mac_address_table.get(dst_mac)
if outgoing_port or dst_mac != "ff:ff:ff:ff:ff:ff":
outgoing_port.send_frame(frame)
else:
# If the destination MAC is not in the table, flood to all ports except incoming
for port in self.switch_ports.values():
if port != incoming_port:
port.send_frame(frame)
def disconnect_link_from_port(self, link: Link, port_number: int):
"""
Disconnect a given link from the specified port number on the switch.
:param link: The Link object to be disconnected.
:param port_number: The port number on the switch from where the link should be disconnected.
:raise NetworkError: When an invalid port number is provided or the link does not match the connection.
"""
port = self.switch_ports.get(port_number)
if port is None:
msg = f"Invalid port number {port_number} on the switch"
_LOGGER.error(msg)
raise NetworkError(msg)
if port._connected_link != link:
msg = f"The link does not match the connection at port number {port_number}"
_LOGGER.error(msg)
raise NetworkError(msg)
port.disconnect_link()

View File

@@ -52,6 +52,11 @@ class ACLRule(SimComponent):
rule_strings.append(f"{key}={value}")
return ", ".join(rule_strings)
def set_original_state(self):
"""Sets the original state."""
vals_to_keep = {"action", "protocol", "src_ip_address", "src_port", "dst_ip_address", "dst_port"}
self._original_state = self.model_dump(include=vals_to_keep, exclude_none=True)
def describe_state(self) -> Dict:
"""
Describes the current state of the ACLRule.
@@ -93,6 +98,18 @@ class AccessControlList(SimComponent):
super().__init__(**kwargs)
self._acl = [None] * (self.max_acl_rules - 1)
self.set_original_state()
def set_original_state(self):
"""Sets the original state."""
self.implicit_rule.set_original_state()
vals_to_keep = {"implicit_action", "max_acl_rules", "acl"}
self._original_state = self.model_dump(include=vals_to_keep, exclude_none=True)
def reset_component_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
self.implicit_rule.reset_component_for_episode(episode)
super().reset_component_for_episode(episode)
def _init_request_manager(self) -> RequestManager:
rm = super()._init_request_manager()
@@ -638,6 +655,20 @@ class Router(Node):
self.arp.nics = self.nics
self.icmp.arp = self.arp
self.set_original_state()
def set_original_state(self):
"""Sets the original state."""
self.acl.set_original_state()
vals_to_include = {"num_ports", "route_table"}
self._original_state = self.model_dump(include=vals_to_include)
def reset_component_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
self.arp.clear()
self.acl.reset_component_for_episode(episode)
super().reset_component_for_episode(episode)
def _init_request_manager(self) -> RequestManager:
rm = super()._init_request_manager()
rm.add_request("acl", RequestType(func=self.acl._request_manager))

View File

@@ -9,7 +9,7 @@ class Simulation(SimComponent):
"""Top-level simulation object which holds a reference to all other parts of the simulation."""
network: Network
domain: DomainController
# domain: DomainController
def __init__(self, **kwargs):
"""Initialise the Simulation."""
@@ -21,6 +21,14 @@ class Simulation(SimComponent):
super().__init__(**kwargs)
def set_original_state(self):
"""Sets the original state."""
self.network.set_original_state()
def reset_component_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
self.network.reset_component_for_episode(episode)
def _init_request_manager(self) -> RequestManager:
rm = super()._init_request_manager()
# pass through network requests to the network objects

View File

@@ -38,6 +38,12 @@ class Application(IOSoftware):
self.health_state_visible = SoftwareHealthState.UNUSED
self.health_state_actual = SoftwareHealthState.UNUSED
def set_original_state(self):
"""Sets the original state."""
super().set_original_state()
vals_to_include = {"operating_state", "execution_control_status", "num_executions", "groups"}
self._original_state.update(self.model_dump(include=vals_to_include))
@abstractmethod
def describe_state(self) -> Dict:
"""
@@ -82,15 +88,6 @@ class Application(IOSoftware):
self.sys_log.info(f"Installing Application {self.name}")
self.operating_state = ApplicationOperatingState.INSTALLING
def reset_component_for_episode(self, episode: int):
"""
Resets the Application component for a new episode.
This method ensures the Application is ready for a new episode, including resetting any
stateful properties or statistics, and clearing any message queues.
"""
pass
def receive(self, payload: Any, session_id: str, **kwargs) -> bool:
"""
Receives a payload from the SessionManager.

View File

@@ -31,6 +31,13 @@ class DatabaseClient(Application):
kwargs["port"] = Port.POSTGRES_SERVER
kwargs["protocol"] = IPProtocol.TCP
super().__init__(**kwargs)
self.set_original_state()
def set_original_state(self):
"""Sets the original state."""
super().set_original_state()
vals_to_include = {"server_ip_address", "server_password", "connected"}
self._original_state.update(self.model_dump(include=vals_to_include))
def describe_state(self) -> Dict:
"""

View File

@@ -33,8 +33,15 @@ class WebBrowser(Application):
kwargs["port"] = Port.HTTP
super().__init__(**kwargs)
self.set_original_state()
self.run()
def set_original_state(self):
"""Sets the original state."""
super().set_original_state()
vals_to_include = {"target_url", "domain_name_ip_address", "latest_response"}
self._original_state.update(self.model_dump(include=vals_to_include))
def _init_request_manager(self) -> RequestManager:
rm = super()._init_request_manager()
rm.add_request(
@@ -43,13 +50,6 @@ class WebBrowser(Application):
return rm
def do_this(self):
self._init_request_manager()
print(f"Resetting WebBrowser for episode")
def reset_component_for_episode(self, episode: int):
pass
def describe_state(self) -> Dict:
"""
Produce a dictionary describing the current state of the WebBrowser.
@@ -60,14 +60,7 @@ class WebBrowser(Application):
state["last_response_status_code"] = self.latest_response.status_code if self.latest_response else None
def reset_component_for_episode(self, episode: int):
"""
Resets the Application component for a new episode.
This method ensures the Application is ready for a new episode, including resetting any
stateful properties or statistics, and clearing any message queues.
"""
self.domain_name_ip_address = None
self.latest_response = None
"""Reset the original state of the SimComponent."""
def get_webpage(self) -> bool:
"""

View File

@@ -34,9 +34,12 @@ class PacketCapture:
"The IP address associated with the PCAP logs."
self.switch_port_number = switch_port_number
"The SwitchPort number."
self._setup_logger()
def _setup_logger(self):
self.current_episode: int = 1
self.setup_logger()
def setup_logger(self):
"""Set up the logger configuration."""
log_path = self._get_log_path()
@@ -75,7 +78,7 @@ class PacketCapture:
def _get_log_path(self) -> Path:
"""Get the path for the log file."""
root = SIM_OUTPUT.path / self.hostname
root = SIM_OUTPUT.path / f"episode_{self.current_episode}" / self.hostname
root.mkdir(exist_ok=True, parents=True)
return root / f"{self._logger_name}.log"

View File

@@ -93,6 +93,11 @@ class SessionManager:
"""
pass
def clear(self):
"""Clears the sessions."""
self.sessions_by_key.clear()
self.sessions_by_uuid.clear()
@staticmethod
def _get_session_key(
frame: Frame, inbound_frame: bool = True

View File

@@ -31,9 +31,10 @@ class SysLog:
:param hostname: The hostname associated with the system logs being recorded.
"""
self.hostname = hostname
self._setup_logger()
self.current_episode: int = 1
self.setup_logger()
def _setup_logger(self):
def setup_logger(self):
"""
Configures the logger for this SysLog instance.
@@ -80,7 +81,7 @@ class SysLog:
:return: Path object representing the location of the log file.
"""
root = SIM_OUTPUT.path / self.hostname
root = SIM_OUTPUT.path / f"episode_{self.current_episode}" / self.hostname
root.mkdir(exist_ok=True, parents=True)
return root / f"{self.hostname}_sys.log"

View File

@@ -24,6 +24,12 @@ class Process(Software):
operating_state: ProcessOperatingState
"The current operating state of the Process."
def set_original_state(self):
"""Sets the original state."""
super().set_original_state()
vals_to_include = {"operating_state"}
self._original_state.update(self.model_dump(include=vals_to_include))
@abstractmethod
def describe_state(self) -> Dict:
"""

View File

@@ -38,6 +38,23 @@ class DatabaseService(Service):
self._db_file: File
self._create_db_file()
def set_original_state(self):
"""Sets the original state."""
super().set_original_state()
vals_to_include = {
"password",
"connections",
"backup_server",
"latest_backup_directory",
"latest_backup_file_name",
}
self._original_state.update(self.model_dump(include=vals_to_include))
def reset_component_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
self.connections.clear()
super().reset_component_for_episode(episode)
def configure_backup(self, backup_server: IPv4Address):
"""
Set up the database backup.

View File

@@ -29,6 +29,17 @@ class DNSClient(Service):
super().__init__(**kwargs)
self.start()
def set_original_state(self):
"""Sets the original state."""
super().set_original_state()
vals_to_include = {"dns_server"}
self._original_state.update(self.model_dump(include=vals_to_include))
def reset_component_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
self.dns_cache.clear()
super().reset_component_for_episode(episode)
def describe_state(self) -> Dict:
"""
Describes the current state of the software.
@@ -42,15 +53,6 @@ class DNSClient(Service):
state = super().describe_state()
return state
def reset_component_for_episode(self, episode: int):
"""
Resets the Service component for a new episode.
This method ensures the Service is ready for a new episode, including resetting any
stateful properties or statistics, and clearing any message queues.
"""
pass
def add_domain_to_cache(self, domain_name: str, ip_address: IPv4Address):
"""
Adds a domain name to the DNS Client cache.

View File

@@ -28,6 +28,11 @@ class DNSServer(Service):
super().__init__(**kwargs)
self.start()
def reset_component_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
self.dns_table.clear()
super().reset_component_for_episode(episode)
def describe_state(self) -> Dict:
"""
Describes the current state of the software.
@@ -62,15 +67,6 @@ class DNSServer(Service):
"""
self.dns_table[domain_name] = domain_ip_address
def reset_component_for_episode(self, episode: int):
"""
Resets the Service component for a new episode.
This method ensures the Service is ready for a new episode, including resetting any
stateful properties or statistics, and clearing any message queues.
"""
pass
def receive(
self,
payload: Any,

View File

@@ -46,6 +46,12 @@ class Service(IOSoftware):
self.health_state_visible = SoftwareHealthState.UNUSED
self.health_state_actual = SoftwareHealthState.UNUSED
def set_original_state(self):
"""Sets the original state."""
super().set_original_state()
vals_to_include = {"operating_state", "restart_duration", "restart_countdown"}
self._original_state.update(self.model_dump(include=vals_to_include))
def _init_request_manager(self) -> RequestManager:
rm = super()._init_request_manager()
rm.add_request("scan", RequestType(func=lambda request, context: self.scan()))
@@ -73,15 +79,6 @@ class Service(IOSoftware):
state["health_state_visible"] = self.health_state_visible
return state
def reset_component_for_episode(self, episode: int):
"""
Resets the Service component for a new episode.
This method ensures the Service is ready for a new episode, including resetting any
stateful properties or statistics, and clearing any message queues.
"""
pass
def stop(self) -> None:
"""Stop the service."""
if self.operating_state in [ServiceOperatingState.RUNNING, ServiceOperatingState.PAUSED]:

View File

@@ -19,8 +19,14 @@ class WebServer(Service):
_last_response_status_code: Optional[HttpStatusCode] = None
def reset_component_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
self._last_response_status_code = None
super().reset_component_for_episode(episode)
@property
def last_response_status_code(self) -> HttpStatusCode:
"""The latest http response code."""
return self._last_response_status_code
@last_response_status_code.setter
@@ -41,14 +47,6 @@ class WebServer(Service):
state["last_response_status_code"] = (
self.last_response_status_code.value if isinstance(self.last_response_status_code, HttpStatusCode) else None
)
print(
f""
f"Printing state from Webserver describe func: "
f"val={state['last_response_status_code']}, "
f"type={type(state['last_response_status_code'])}, "
f"Service obj ID={id(self)}"
)
return state
def __init__(self, **kwargs):
@@ -102,13 +100,6 @@ class WebServer(Service):
# return true if response is OK
self.last_response_status_code = response.status_code
print(
f""
f"Printing state from Webserver http request func: "
f"val={self.last_response_status_code}, "
f"type={type(self.last_response_status_code)}, "
f"Service obj ID={id(self)}"
)
return response.status_code == HttpStatusCode.OK
def _handle_get_request(self, payload: HttpRequestPacket) -> HttpResponsePacket:

View File

@@ -89,6 +89,19 @@ class Software(SimComponent):
folder: Optional[Folder] = None
"The folder on the file system the Software uses."
def set_original_state(self):
"""Sets the original state."""
vals_to_include = {
"name",
"health_state_actual",
"health_state_visible",
"criticality",
"patching_count",
"scanning_count",
"revealed_to_red",
}
self._original_state = self.model_dump(include=vals_to_include)
def _init_request_manager(self) -> RequestManager:
rm = super()._init_request_manager()
rm.add_request(
@@ -131,16 +144,6 @@ class Software(SimComponent):
)
return state
def reset_component_for_episode(self, episode: int):
"""
Resets the software component for a new episode.
This method should ensure the software is ready for a new episode, including resetting any
stateful properties or statistics, and clearing any message queues. The specifics of what constitutes a
"reset" should be implemented in subclasses.
"""
pass
def set_health_state(self, health_state: SoftwareHealthState) -> None:
"""
Assign a new health state to this software.
@@ -203,6 +206,12 @@ class IOSoftware(Software):
port: Port
"The port to which the software is connected."
def set_original_state(self):
"""Sets the original state."""
super().set_original_state()
vals_to_include = {"installing_count", "max_sessions", "tcp", "udp", "port"}
self._original_state.update(self.model_dump(include=vals_to_include))
@abstractmethod
def describe_state(self) -> Dict:
"""