diff --git a/src/primaite/game/agent/actions.py b/src/primaite/game/agent/actions.py index 62e56c6e..c70d4d66 100644 --- a/src/primaite/game/agent/actions.py +++ b/src/primaite/game/agent/actions.py @@ -15,7 +15,6 @@ from typing import Dict, List, Optional, Tuple, TYPE_CHECKING from gymnasium import spaces from primaite import getLogger -from primaite.simulator.sim_container import Simulation _LOGGER = getLogger(__name__) diff --git a/src/primaite/game/agent/rewards.py b/src/primaite/game/agent/rewards.py index 7cca9116..3466114c 100644 --- a/src/primaite/game/agent/rewards.py +++ b/src/primaite/game/agent/rewards.py @@ -25,7 +25,6 @@ the structure: service_ref: web_server_database_client ``` """ -import json from abc import abstractmethod from typing import Dict, List, Tuple, Type, TYPE_CHECKING diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index 147ed499..38e9d5fc 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -1,5 +1,4 @@ """PrimAITE game - Encapsulates the simulation and agents.""" -from copy import deepcopy from ipaddress import IPv4Address from typing import Dict, List @@ -11,7 +10,7 @@ from primaite.game.agent.data_manipulation_bot import DataManipulationAgent from primaite.game.agent.interface import AbstractAgent, AgentSettings, ProxyAgent, RandomAgent from primaite.game.agent.observations import ObservationManager from primaite.game.agent.rewards import RewardFunction -from primaite.simulator.network.hardware.base import Link, NIC, Node, NodeOperatingState +from primaite.simulator.network.hardware.base import NIC, NodeOperatingState from primaite.simulator.network.hardware.nodes.computer import Computer from primaite.simulator.network.hardware.nodes.router import ACLAction, Router from primaite.simulator.network.hardware.nodes.server import Server @@ -19,7 +18,6 @@ from primaite.simulator.network.hardware.nodes.switch import Switch from primaite.simulator.network.transmission.network_layer import IPProtocol from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.sim_container import Simulation -from primaite.simulator.system.applications.application import Application from primaite.simulator.system.applications.database_client import DatabaseClient from primaite.simulator.system.applications.web_browser import WebBrowser from primaite.simulator.system.services.database.database_service import DatabaseService @@ -28,7 +26,6 @@ from primaite.simulator.system.services.dns.dns_server import DNSServer from primaite.simulator.system.services.ftp.ftp_client import FTPClient from primaite.simulator.system.services.ftp.ftp_server import FTPServer from primaite.simulator.system.services.red_services.data_manipulation_bot import DataManipulationBot -from primaite.simulator.system.services.service import Service from primaite.simulator.system.services.web_server.web_server import WebServer _LOGGER = getLogger(__name__) @@ -59,10 +56,6 @@ class PrimaiteGame: """Initialise a PrimaiteGame object.""" self.simulation: Simulation = Simulation() """Simulation object with which the agents will interact.""" - print(f"Hello, welcome to PrimaiteGame. This is the ID of the ORIGINAL simulation {id(self.simulation)}") - - self._simulation_initial_state = None - """The Simulation original state (deepcopy of the original Simulation).""" self.agents: List[AbstractAgent] = [] """List of agents.""" @@ -161,34 +154,7 @@ class PrimaiteGame: self.episode_counter += 1 self.step_counter = 0 _LOGGER.debug(f"Resetting primaite game, episode = {self.episode_counter}") - self.simulation = deepcopy(self._simulation_initial_state) - self._reset_components_for_episode() - print("Reset") - - def _reset_components_for_episode(self): - print("Performing full reset for episode") - for node in self.simulation.network.nodes.values(): - print(f"Resetting Node: {node.hostname}") - node.reset_component_for_episode(self.episode_counter) - - # reset Node NIC - - # Reset Node Services - - # Reset Node Applications - print(f"Resetting Software...") - for application in node.software_manager.software.values(): - print(f"Resetting {application.name}") - if isinstance(application, WebBrowser): - application.do_this() - - # Reset Node FileSystem - # Reset Node FileSystemFolder's - # Reset Node FileSystemFile's - - # Reset Router - - # Reset Links + self.simulation.reset_component_for_episode(episode=self.episode_counter) def close(self) -> None: """Close the game, this will close the simulation.""" @@ -452,8 +418,6 @@ class PrimaiteGame: else: print("agent type not found") - game._simulation_initial_state = deepcopy(game.simulation) # noqa - web_server = game.simulation.network.get_node_by_hostname("web_server").software_manager.software["WebServer"] - print(f"And this is the ID of the original WebServer {id(web_server)}") + game.simulation.set_original_state() return game diff --git a/src/primaite/simulator/core.py b/src/primaite/simulator/core.py index 9ead877e..18a470cd 100644 --- a/src/primaite/simulator/core.py +++ b/src/primaite/simulator/core.py @@ -153,6 +153,8 @@ class SimComponent(BaseModel): uuid: str """The component UUID.""" + _original_state: Dict = {} + def __init__(self, **kwargs): if not kwargs.get("uuid"): kwargs["uuid"] = str(uuid4()) @@ -160,6 +162,16 @@ class SimComponent(BaseModel): self._request_manager: RequestManager = self._init_request_manager() self._parent: Optional["SimComponent"] = None + # @abstractmethod + def set_original_state(self): + """Sets the original state.""" + pass + + def reset_component_for_episode(self, episode: int): + """Reset the original state of the SimComponent.""" + for key, value in self._original_state.items(): + self.__setattr__(key, value) + def _init_request_manager(self) -> RequestManager: """ Initialise the request manager for this component. @@ -227,14 +239,6 @@ class SimComponent(BaseModel): """ pass - def reset_component_for_episode(self, episode: int): - """ - Reset this component to its original state for a new episode. - - Override this method with anything that needs to happen within the component for it to be reset. - """ - pass - @property def parent(self) -> "SimComponent": """Reference to the parent object which manages this object. diff --git a/src/primaite/simulator/domain/account.py b/src/primaite/simulator/domain/account.py index d235c00e..1402a474 100644 --- a/src/primaite/simulator/domain/account.py +++ b/src/primaite/simulator/domain/account.py @@ -42,6 +42,19 @@ class Account(SimComponent): "Account Type, currently this can be service account (used by apps) or user account." enabled: bool = True + def set_original_state(self): + """Sets the original state.""" + vals_to_include = { + "num_logons", + "num_logoffs", + "num_group_changes", + "username", + "password", + "account_type", + "enabled", + } + self._original_state = self.model_dump(include=vals_to_include) + def describe_state(self) -> Dict: """ Produce a dictionary describing the current state of this object. diff --git a/src/primaite/simulator/file_system/file.py b/src/primaite/simulator/file_system/file.py index d9b02e8e..8f0abb3c 100644 --- a/src/primaite/simulator/file_system/file.py +++ b/src/primaite/simulator/file_system/file.py @@ -73,6 +73,18 @@ class File(FileSystemItemABC): self.sys_log.info(f"Created file /{self.path} (id: {self.uuid})") + self.set_original_state() + + def set_original_state(self): + """Sets the original state.""" + super().set_original_state() + vals_to_include = {"folder_id", "folder_name", "file_type", "sim_size", "real", "sim_path", "sim_root"} + self._original_state.update(self.model_dump(include=vals_to_include)) + + def reset_component_for_episode(self, episode: int): + """Reset the original state of the SimComponent.""" + super().reset_component_for_episode(episode) + @property def path(self) -> str: """ diff --git a/src/primaite/simulator/file_system/file_system.py b/src/primaite/simulator/file_system/file_system.py index 41f02270..dc6f01a3 100644 --- a/src/primaite/simulator/file_system/file_system.py +++ b/src/primaite/simulator/file_system/file_system.py @@ -35,6 +35,36 @@ class FileSystem(SimComponent): if not self.folders: self.create_folder("root") + def set_original_state(self): + """Sets the original state.""" + for folder in self.folders.values(): + folder.set_original_state() + super().set_original_state() + # Capture a list of all 'original' file uuids + self._original_state["original_folder_uuids"] = list(self.folders.keys()) + + def reset_component_for_episode(self, episode: int): + """Reset the original state of the SimComponent.""" + # Move any 'original' folder that have been deleted back to folders + original_folder_uuids = self._original_state.pop("original_folder_uuids") + for uuid in original_folder_uuids: + if uuid in self.deleted_folders: + self.folders[uuid] = self.deleted_folders.pop(uuid) + + # Clear any other deleted folders that aren't original (have been created by agent) + self.deleted_folders.clear() + + # Now clear all non-original folders created by agent + current_folder_uuids = list(self.folders.keys()) + for uuid in current_folder_uuids: + if uuid not in original_folder_uuids: + self.folders.pop(uuid) + + # Now reset all remaining folders + for folder in self.folders.values(): + folder.reset_component_for_episode(episode) + super().reset_component_for_episode(episode) + def _init_request_manager(self) -> RequestManager: rm = super()._init_request_manager() diff --git a/src/primaite/simulator/file_system/file_system_item_abc.py b/src/primaite/simulator/file_system/file_system_item_abc.py index fbe5f4b3..86cd1ee7 100644 --- a/src/primaite/simulator/file_system/file_system_item_abc.py +++ b/src/primaite/simulator/file_system/file_system_item_abc.py @@ -85,6 +85,11 @@ class FileSystemItemABC(SimComponent): deleted: bool = False "If true, the FileSystemItem was deleted." + def set_original_state(self): + """Sets the original state.""" + vals_to_keep = {"name", "health_status", "visible_health_status", "previous_hash", "revealed_to_red"} + self._original_state = self.model_dump(include=vals_to_keep) + def describe_state(self) -> Dict: """ Produce a dictionary describing the current state of this object. diff --git a/src/primaite/simulator/file_system/folder.py b/src/primaite/simulator/file_system/folder.py index f0d55ef8..8e577097 100644 --- a/src/primaite/simulator/file_system/folder.py +++ b/src/primaite/simulator/file_system/folder.py @@ -51,6 +51,44 @@ class Folder(FileSystemItemABC): self.sys_log.info(f"Created file /{self.name} (id: {self.uuid})") + def set_original_state(self): + """Sets the original state.""" + for file in self.files.values(): + file.set_original_state() + super().set_original_state() + vals_to_include = { + "scan_duration", + "scan_countdown", + "red_scan_duration", + "red_scan_countdown", + "restore_duration", + "restore_countdown", + } + self._original_state.update(self.model_dump(include=vals_to_include)) + self._original_state["original_file_uuids"] = list(self.files.keys()) + + def reset_component_for_episode(self, episode: int): + """Reset the original state of the SimComponent.""" + # Move any 'original' file that have been deleted back to files + original_file_uuids = self._original_state.pop("original_file_uuids") + for uuid in original_file_uuids: + if uuid in self.deleted_files: + self.files[uuid] = self.deleted_files.pop(uuid) + + # Clear any other deleted files that aren't original (have been created by agent) + self.deleted_files.clear() + + # Now clear all non-original files created by agent + current_file_uuids = list(self.files.keys()) + for uuid in current_file_uuids: + if uuid not in original_file_uuids: + self.files.pop(uuid) + + # Now reset all remaining files + for file in self.files.values(): + file.reset_component_for_episode(episode) + super().reset_component_for_episode(episode) + def _init_request_manager(self) -> RequestManager: rm = super()._init_request_manager() rm.add_request( diff --git a/src/primaite/simulator/network/container.py b/src/primaite/simulator/network/container.py index 9fbafc29..cab983c7 100644 --- a/src/primaite/simulator/network/container.py +++ b/src/primaite/simulator/network/container.py @@ -43,6 +43,20 @@ class Network(SimComponent): self._nx_graph = MultiGraph() + def set_original_state(self): + """Sets the original state.""" + for node in self.nodes.values(): + node.set_original_state() + for link in self.links.values(): + link.set_original_state() + + def reset_component_for_episode(self, episode: int): + """Reset the original state of the SimComponent.""" + for node in self.nodes.values(): + node.reset_component_for_episode(episode) + for link in self.links.values(): + link.reset_component_for_episode(episode) + def _init_request_manager(self) -> RequestManager: rm = super()._init_request_manager() self._node_request_manager = RequestManager() diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index 0717f813..2863dd22 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -121,6 +121,20 @@ class NIC(SimComponent): _LOGGER.error(msg) raise ValueError(msg) + self.set_original_state() + + def set_original_state(self): + """Sets the original state.""" + vals_to_include = {"ip_address", "subnet_mask", "mac_address", "speed", "mtu", "wake_on_lan", "enabled"} + self._original_state = self.model_dump(include=vals_to_include) + + def reset_component_for_episode(self, episode: int): + """Reset the original state of the SimComponent.""" + super().reset_component_for_episode(episode) + if episode and self.pcap: + self.pcap.current_episode = episode + self.pcap.setup_logger() + def describe_state(self) -> Dict: """ Produce a dictionary describing the current state of this object. @@ -308,6 +322,14 @@ class SwitchPort(SimComponent): kwargs["mac_address"] = generate_mac_address() super().__init__(**kwargs) + self.set_original_state() + + def set_original_state(self): + """Sets the original state.""" + vals_to_include = {"port_num", "mac_address", "speed", "mtu", "enabled"} + self._original_state = self.model_dump(include=vals_to_include) + super().set_original_state() + def describe_state(self) -> Dict: """ Produce a dictionary describing the current state of this object. @@ -454,6 +476,14 @@ class Link(SimComponent): self.endpoint_b.connect_link(self) self.endpoint_up() + self.set_original_state() + + def set_original_state(self): + """Sets the original state.""" + vals_to_include = {"bandwidth", "current_load"} + self._original_state = self.model_dump(include=vals_to_include) + super().set_original_state() + def describe_state(self) -> Dict: """ Produce a dictionary describing the current state of this object. @@ -536,15 +566,6 @@ class Link(SimComponent): return True return False - def reset_component_for_episode(self, episode: int): - """ - Link reset function. - - Reset: - - returns the link current_load to 0. - """ - self.current_load = 0 - def __str__(self) -> str: return f"{self.endpoint_a}<-->{self.endpoint_b}" @@ -584,6 +605,10 @@ class ARPCache: ) print(table) + def clear(self): + """Clears the arp cache.""" + self.arp.clear() + def add_arp_cache_entry(self, ip_address: IPv4Address, mac_address: str, nic: NIC, override: bool = False): """ Add an ARP entry to the cache. @@ -756,6 +781,10 @@ class ICMP: self.arp: ARPCache = arp_cache self.request_replies = {} + def clear(self): + """Clears the ICMP request replies tracker.""" + self.request_replies.clear() + def process_icmp(self, frame: Frame, from_nic: NIC, is_reattempt: bool = False): """ Process an ICMP packet, including handling echo requests and replies. @@ -972,6 +1001,55 @@ class Node(SimComponent): self.arp.nics = self.nics self.session_manager.software_manager = self.software_manager self._install_system_software() + self.set_original_state() + + def set_original_state(self): + """Sets the original state.""" + for software in self.software_manager.software.values(): + software.set_original_state() + + for nic in self.nics.values(): + nic.set_original_state() + + vals_to_include = { + "hostname", + "default_gateway", + "operating_state", + "revealed_to_red", + "start_up_duration", + "start_up_countdown", + "shut_down_duration", + "shut_down_countdown", + "is_resetting", + "node_scan_duration", + "node_scan_countdown", + "red_scan_countdown", + } + self._original_state = self.model_dump(include=vals_to_include) + + def reset_component_for_episode(self, episode: int): + """Reset the original state of the SimComponent.""" + # Reset ARP Cache + self.arp.clear() + + # Reset ICMP + self.icmp.clear() + + # Reset Session Manager + self.session_manager.clear() + + for software in self.software_manager.software.values(): + software.reset_component_for_episode(episode) + + # Reset all Nics + for nic in self.nics.values(): + nic.reset_component_for_episode(episode) + + if episode and self.sys_log: + self.sys_log.current_episode = episode + self.sys_log.setup_logger() + + super().reset_component_for_episode(episode) def _init_request_manager(self) -> RequestManager: # TODO: I see that this code is really confusing and hard to read right now... I think some of these things will @@ -1005,9 +1083,6 @@ class Node(SimComponent): return rm - def reset_component_for_episode(self, episode: int): - self._init_request_manager() - def _install_system_software(self): """Install System Software - software that is usually provided with the OS.""" pass @@ -1425,99 +1500,3 @@ class Node(SimComponent): if isinstance(item, Service): return item.uuid in self.services return None - - -class Switch(Node): - """A class representing a Layer 2 network switch.""" - - num_ports: int = 24 - "The number of ports on the switch." - switch_ports: Dict[int, SwitchPort] = {} - "The SwitchPorts on the switch." - mac_address_table: Dict[str, SwitchPort] = {} - "A MAC address table mapping destination MAC addresses to corresponding SwitchPorts." - - def __init__(self, **kwargs): - super().__init__(**kwargs) - if not self.switch_ports: - self.switch_ports = {i: SwitchPort() for i in range(1, self.num_ports + 1)} - for port_num, port in self.switch_ports.items(): - port._connected_node = self - port.parent = self - port.port_num = port_num - - def show(self): - """Prints a table of the SwitchPorts on the Switch.""" - table = PrettyTable(["Port", "MAC Address", "Speed", "Status"]) - - for port_num, port in self.switch_ports.items(): - table.add_row([port_num, port.mac_address, port.speed, "Enabled" if port.enabled else "Disabled"]) - print(table) - - def describe_state(self) -> Dict: - """ - Produce a dictionary describing the current state of this object. - - Please see :py:meth:`primaite.simulator.core.SimComponent.describe_state` for a more detailed explanation. - - :return: Current state of this object and child objects. - :rtype: Dict - """ - return { - "uuid": self.uuid, - "num_ports": self.num_ports, # redundant? - "ports": {port_num: port.describe_state() for port_num, port in self.switch_ports.items()}, - "mac_address_table": {mac: port for mac, port in self.mac_address_table.items()}, - } - - def _add_mac_table_entry(self, mac_address: str, switch_port: SwitchPort): - mac_table_port = self.mac_address_table.get(mac_address) - if not mac_table_port: - self.mac_address_table[mac_address] = switch_port - self.sys_log.info(f"Added MAC table entry: Port {switch_port.port_num} -> {mac_address}") - else: - if mac_table_port != switch_port: - self.mac_address_table.pop(mac_address) - self.sys_log.info(f"Removed MAC table entry: Port {mac_table_port.port_num} -> {mac_address}") - self._add_mac_table_entry(mac_address, switch_port) - - def forward_frame(self, frame: Frame, incoming_port: SwitchPort): - """ - Forward a frame to the appropriate port based on the destination MAC address. - - :param frame: The Frame to be forwarded. - :param incoming_port: The port number from which the frame was received. - """ - src_mac = frame.ethernet.src_mac_addr - dst_mac = frame.ethernet.dst_mac_addr - self._add_mac_table_entry(src_mac, incoming_port) - - outgoing_port = self.mac_address_table.get(dst_mac) - if outgoing_port or dst_mac != "ff:ff:ff:ff:ff:ff": - outgoing_port.send_frame(frame) - else: - # If the destination MAC is not in the table, flood to all ports except incoming - for port in self.switch_ports.values(): - if port != incoming_port: - port.send_frame(frame) - - def disconnect_link_from_port(self, link: Link, port_number: int): - """ - Disconnect a given link from the specified port number on the switch. - - :param link: The Link object to be disconnected. - :param port_number: The port number on the switch from where the link should be disconnected. - :raise NetworkError: When an invalid port number is provided or the link does not match the connection. - """ - port = self.switch_ports.get(port_number) - if port is None: - msg = f"Invalid port number {port_number} on the switch" - _LOGGER.error(msg) - raise NetworkError(msg) - - if port._connected_link != link: - msg = f"The link does not match the connection at port number {port_number}" - _LOGGER.error(msg) - raise NetworkError(msg) - - port.disconnect_link() diff --git a/src/primaite/simulator/network/hardware/nodes/router.py b/src/primaite/simulator/network/hardware/nodes/router.py index c2a38aba..8e03cfa3 100644 --- a/src/primaite/simulator/network/hardware/nodes/router.py +++ b/src/primaite/simulator/network/hardware/nodes/router.py @@ -52,6 +52,11 @@ class ACLRule(SimComponent): rule_strings.append(f"{key}={value}") return ", ".join(rule_strings) + def set_original_state(self): + """Sets the original state.""" + vals_to_keep = {"action", "protocol", "src_ip_address", "src_port", "dst_ip_address", "dst_port"} + self._original_state = self.model_dump(include=vals_to_keep, exclude_none=True) + def describe_state(self) -> Dict: """ Describes the current state of the ACLRule. @@ -93,6 +98,18 @@ class AccessControlList(SimComponent): super().__init__(**kwargs) self._acl = [None] * (self.max_acl_rules - 1) + self.set_original_state() + + def set_original_state(self): + """Sets the original state.""" + self.implicit_rule.set_original_state() + vals_to_keep = {"implicit_action", "max_acl_rules", "acl"} + self._original_state = self.model_dump(include=vals_to_keep, exclude_none=True) + + def reset_component_for_episode(self, episode: int): + """Reset the original state of the SimComponent.""" + self.implicit_rule.reset_component_for_episode(episode) + super().reset_component_for_episode(episode) def _init_request_manager(self) -> RequestManager: rm = super()._init_request_manager() @@ -638,6 +655,20 @@ class Router(Node): self.arp.nics = self.nics self.icmp.arp = self.arp + self.set_original_state() + + def set_original_state(self): + """Sets the original state.""" + self.acl.set_original_state() + vals_to_include = {"num_ports", "route_table"} + self._original_state = self.model_dump(include=vals_to_include) + + def reset_component_for_episode(self, episode: int): + """Reset the original state of the SimComponent.""" + self.arp.clear() + self.acl.reset_component_for_episode(episode) + super().reset_component_for_episode(episode) + def _init_request_manager(self) -> RequestManager: rm = super()._init_request_manager() rm.add_request("acl", RequestType(func=self.acl._request_manager)) diff --git a/src/primaite/simulator/sim_container.py b/src/primaite/simulator/sim_container.py index 8e820ec8..c529ed04 100644 --- a/src/primaite/simulator/sim_container.py +++ b/src/primaite/simulator/sim_container.py @@ -9,7 +9,7 @@ class Simulation(SimComponent): """Top-level simulation object which holds a reference to all other parts of the simulation.""" network: Network - domain: DomainController + # domain: DomainController def __init__(self, **kwargs): """Initialise the Simulation.""" @@ -21,6 +21,14 @@ class Simulation(SimComponent): super().__init__(**kwargs) + def set_original_state(self): + """Sets the original state.""" + self.network.set_original_state() + + def reset_component_for_episode(self, episode: int): + """Reset the original state of the SimComponent.""" + self.network.reset_component_for_episode(episode) + def _init_request_manager(self) -> RequestManager: rm = super()._init_request_manager() # pass through network requests to the network objects diff --git a/src/primaite/simulator/system/applications/application.py b/src/primaite/simulator/system/applications/application.py index 9a58c98a..c69f745d 100644 --- a/src/primaite/simulator/system/applications/application.py +++ b/src/primaite/simulator/system/applications/application.py @@ -38,6 +38,12 @@ class Application(IOSoftware): self.health_state_visible = SoftwareHealthState.UNUSED self.health_state_actual = SoftwareHealthState.UNUSED + def set_original_state(self): + """Sets the original state.""" + super().set_original_state() + vals_to_include = {"operating_state", "execution_control_status", "num_executions", "groups"} + self._original_state.update(self.model_dump(include=vals_to_include)) + @abstractmethod def describe_state(self) -> Dict: """ @@ -82,15 +88,6 @@ class Application(IOSoftware): self.sys_log.info(f"Installing Application {self.name}") self.operating_state = ApplicationOperatingState.INSTALLING - def reset_component_for_episode(self, episode: int): - """ - Resets the Application component for a new episode. - - This method ensures the Application is ready for a new episode, including resetting any - stateful properties or statistics, and clearing any message queues. - """ - pass - def receive(self, payload: Any, session_id: str, **kwargs) -> bool: """ Receives a payload from the SessionManager. diff --git a/src/primaite/simulator/system/applications/database_client.py b/src/primaite/simulator/system/applications/database_client.py index 37236e69..12dfc0ac 100644 --- a/src/primaite/simulator/system/applications/database_client.py +++ b/src/primaite/simulator/system/applications/database_client.py @@ -31,6 +31,13 @@ class DatabaseClient(Application): kwargs["port"] = Port.POSTGRES_SERVER kwargs["protocol"] = IPProtocol.TCP super().__init__(**kwargs) + self.set_original_state() + + def set_original_state(self): + """Sets the original state.""" + super().set_original_state() + vals_to_include = {"server_ip_address", "server_password", "connected"} + self._original_state.update(self.model_dump(include=vals_to_include)) def describe_state(self) -> Dict: """ diff --git a/src/primaite/simulator/system/applications/web_browser.py b/src/primaite/simulator/system/applications/web_browser.py index ef9ac0e7..32dd9cd2 100644 --- a/src/primaite/simulator/system/applications/web_browser.py +++ b/src/primaite/simulator/system/applications/web_browser.py @@ -33,8 +33,15 @@ class WebBrowser(Application): kwargs["port"] = Port.HTTP super().__init__(**kwargs) + self.set_original_state() self.run() + def set_original_state(self): + """Sets the original state.""" + super().set_original_state() + vals_to_include = {"target_url", "domain_name_ip_address", "latest_response"} + self._original_state.update(self.model_dump(include=vals_to_include)) + def _init_request_manager(self) -> RequestManager: rm = super()._init_request_manager() rm.add_request( @@ -43,13 +50,6 @@ class WebBrowser(Application): return rm - def do_this(self): - self._init_request_manager() - print(f"Resetting WebBrowser for episode") - - def reset_component_for_episode(self, episode: int): - pass - def describe_state(self) -> Dict: """ Produce a dictionary describing the current state of the WebBrowser. @@ -60,14 +60,7 @@ class WebBrowser(Application): state["last_response_status_code"] = self.latest_response.status_code if self.latest_response else None def reset_component_for_episode(self, episode: int): - """ - Resets the Application component for a new episode. - - This method ensures the Application is ready for a new episode, including resetting any - stateful properties or statistics, and clearing any message queues. - """ - self.domain_name_ip_address = None - self.latest_response = None + """Reset the original state of the SimComponent.""" def get_webpage(self) -> bool: """ diff --git a/src/primaite/simulator/system/core/packet_capture.py b/src/primaite/simulator/system/core/packet_capture.py index c2faeb10..1539e024 100644 --- a/src/primaite/simulator/system/core/packet_capture.py +++ b/src/primaite/simulator/system/core/packet_capture.py @@ -34,9 +34,12 @@ class PacketCapture: "The IP address associated with the PCAP logs." self.switch_port_number = switch_port_number "The SwitchPort number." - self._setup_logger() - def _setup_logger(self): + self.current_episode: int = 1 + + self.setup_logger() + + def setup_logger(self): """Set up the logger configuration.""" log_path = self._get_log_path() @@ -75,7 +78,7 @@ class PacketCapture: def _get_log_path(self) -> Path: """Get the path for the log file.""" - root = SIM_OUTPUT.path / self.hostname + root = SIM_OUTPUT.path / f"episode_{self.current_episode}" / self.hostname root.mkdir(exist_ok=True, parents=True) return root / f"{self._logger_name}.log" diff --git a/src/primaite/simulator/system/core/session_manager.py b/src/primaite/simulator/system/core/session_manager.py index 360b5e73..8658f155 100644 --- a/src/primaite/simulator/system/core/session_manager.py +++ b/src/primaite/simulator/system/core/session_manager.py @@ -93,6 +93,11 @@ class SessionManager: """ pass + def clear(self): + """Clears the sessions.""" + self.sessions_by_key.clear() + self.sessions_by_uuid.clear() + @staticmethod def _get_session_key( frame: Frame, inbound_frame: bool = True diff --git a/src/primaite/simulator/system/core/sys_log.py b/src/primaite/simulator/system/core/sys_log.py index 7ac6df85..41ce8fee 100644 --- a/src/primaite/simulator/system/core/sys_log.py +++ b/src/primaite/simulator/system/core/sys_log.py @@ -31,9 +31,10 @@ class SysLog: :param hostname: The hostname associated with the system logs being recorded. """ self.hostname = hostname - self._setup_logger() + self.current_episode: int = 1 + self.setup_logger() - def _setup_logger(self): + def setup_logger(self): """ Configures the logger for this SysLog instance. @@ -80,7 +81,7 @@ class SysLog: :return: Path object representing the location of the log file. """ - root = SIM_OUTPUT.path / self.hostname + root = SIM_OUTPUT.path / f"episode_{self.current_episode}" / self.hostname root.mkdir(exist_ok=True, parents=True) return root / f"{self.hostname}_sys.log" diff --git a/src/primaite/simulator/system/processes/process.py b/src/primaite/simulator/system/processes/process.py index c4e94845..ad9af335 100644 --- a/src/primaite/simulator/system/processes/process.py +++ b/src/primaite/simulator/system/processes/process.py @@ -24,6 +24,12 @@ class Process(Software): operating_state: ProcessOperatingState "The current operating state of the Process." + def set_original_state(self): + """Sets the original state.""" + super().set_original_state() + vals_to_include = {"operating_state"} + self._original_state.update(self.model_dump(include=vals_to_include)) + @abstractmethod def describe_state(self) -> Dict: """ diff --git a/src/primaite/simulator/system/services/database/database_service.py b/src/primaite/simulator/system/services/database/database_service.py index d7277e1e..616cbedd 100644 --- a/src/primaite/simulator/system/services/database/database_service.py +++ b/src/primaite/simulator/system/services/database/database_service.py @@ -38,6 +38,23 @@ class DatabaseService(Service): self._db_file: File self._create_db_file() + def set_original_state(self): + """Sets the original state.""" + super().set_original_state() + vals_to_include = { + "password", + "connections", + "backup_server", + "latest_backup_directory", + "latest_backup_file_name", + } + self._original_state.update(self.model_dump(include=vals_to_include)) + + def reset_component_for_episode(self, episode: int): + """Reset the original state of the SimComponent.""" + self.connections.clear() + super().reset_component_for_episode(episode) + def configure_backup(self, backup_server: IPv4Address): """ Set up the database backup. diff --git a/src/primaite/simulator/system/services/dns/dns_client.py b/src/primaite/simulator/system/services/dns/dns_client.py index 266ac4f6..c6c3e09a 100644 --- a/src/primaite/simulator/system/services/dns/dns_client.py +++ b/src/primaite/simulator/system/services/dns/dns_client.py @@ -29,6 +29,17 @@ class DNSClient(Service): super().__init__(**kwargs) self.start() + def set_original_state(self): + """Sets the original state.""" + super().set_original_state() + vals_to_include = {"dns_server"} + self._original_state.update(self.model_dump(include=vals_to_include)) + + def reset_component_for_episode(self, episode: int): + """Reset the original state of the SimComponent.""" + self.dns_cache.clear() + super().reset_component_for_episode(episode) + def describe_state(self) -> Dict: """ Describes the current state of the software. @@ -42,15 +53,6 @@ class DNSClient(Service): state = super().describe_state() return state - def reset_component_for_episode(self, episode: int): - """ - Resets the Service component for a new episode. - - This method ensures the Service is ready for a new episode, including resetting any - stateful properties or statistics, and clearing any message queues. - """ - pass - def add_domain_to_cache(self, domain_name: str, ip_address: IPv4Address): """ Adds a domain name to the DNS Client cache. diff --git a/src/primaite/simulator/system/services/dns/dns_server.py b/src/primaite/simulator/system/services/dns/dns_server.py index 90a350c8..bbeaa62c 100644 --- a/src/primaite/simulator/system/services/dns/dns_server.py +++ b/src/primaite/simulator/system/services/dns/dns_server.py @@ -28,6 +28,11 @@ class DNSServer(Service): super().__init__(**kwargs) self.start() + def reset_component_for_episode(self, episode: int): + """Reset the original state of the SimComponent.""" + self.dns_table.clear() + super().reset_component_for_episode(episode) + def describe_state(self) -> Dict: """ Describes the current state of the software. @@ -62,15 +67,6 @@ class DNSServer(Service): """ self.dns_table[domain_name] = domain_ip_address - def reset_component_for_episode(self, episode: int): - """ - Resets the Service component for a new episode. - - This method ensures the Service is ready for a new episode, including resetting any - stateful properties or statistics, and clearing any message queues. - """ - pass - def receive( self, payload: Any, diff --git a/src/primaite/simulator/system/services/service.py b/src/primaite/simulator/system/services/service.py index e2b04c15..d519da8e 100644 --- a/src/primaite/simulator/system/services/service.py +++ b/src/primaite/simulator/system/services/service.py @@ -46,6 +46,12 @@ class Service(IOSoftware): self.health_state_visible = SoftwareHealthState.UNUSED self.health_state_actual = SoftwareHealthState.UNUSED + def set_original_state(self): + """Sets the original state.""" + super().set_original_state() + vals_to_include = {"operating_state", "restart_duration", "restart_countdown"} + self._original_state.update(self.model_dump(include=vals_to_include)) + def _init_request_manager(self) -> RequestManager: rm = super()._init_request_manager() rm.add_request("scan", RequestType(func=lambda request, context: self.scan())) @@ -73,15 +79,6 @@ class Service(IOSoftware): state["health_state_visible"] = self.health_state_visible return state - def reset_component_for_episode(self, episode: int): - """ - Resets the Service component for a new episode. - - This method ensures the Service is ready for a new episode, including resetting any - stateful properties or statistics, and clearing any message queues. - """ - pass - def stop(self) -> None: """Stop the service.""" if self.operating_state in [ServiceOperatingState.RUNNING, ServiceOperatingState.PAUSED]: diff --git a/src/primaite/simulator/system/services/web_server/web_server.py b/src/primaite/simulator/system/services/web_server/web_server.py index 86a4e4f1..754aa22f 100644 --- a/src/primaite/simulator/system/services/web_server/web_server.py +++ b/src/primaite/simulator/system/services/web_server/web_server.py @@ -19,8 +19,14 @@ class WebServer(Service): _last_response_status_code: Optional[HttpStatusCode] = None + def reset_component_for_episode(self, episode: int): + """Reset the original state of the SimComponent.""" + self._last_response_status_code = None + super().reset_component_for_episode(episode) + @property def last_response_status_code(self) -> HttpStatusCode: + """The latest http response code.""" return self._last_response_status_code @last_response_status_code.setter @@ -41,14 +47,6 @@ class WebServer(Service): state["last_response_status_code"] = ( self.last_response_status_code.value if isinstance(self.last_response_status_code, HttpStatusCode) else None ) - - print( - f"" - f"Printing state from Webserver describe func: " - f"val={state['last_response_status_code']}, " - f"type={type(state['last_response_status_code'])}, " - f"Service obj ID={id(self)}" - ) return state def __init__(self, **kwargs): @@ -102,13 +100,6 @@ class WebServer(Service): # return true if response is OK self.last_response_status_code = response.status_code - print( - f"" - f"Printing state from Webserver http request func: " - f"val={self.last_response_status_code}, " - f"type={type(self.last_response_status_code)}, " - f"Service obj ID={id(self)}" - ) return response.status_code == HttpStatusCode.OK def _handle_get_request(self, payload: HttpRequestPacket) -> HttpResponsePacket: diff --git a/src/primaite/simulator/system/software.py b/src/primaite/simulator/system/software.py index f2627557..413da959 100644 --- a/src/primaite/simulator/system/software.py +++ b/src/primaite/simulator/system/software.py @@ -89,6 +89,19 @@ class Software(SimComponent): folder: Optional[Folder] = None "The folder on the file system the Software uses." + def set_original_state(self): + """Sets the original state.""" + vals_to_include = { + "name", + "health_state_actual", + "health_state_visible", + "criticality", + "patching_count", + "scanning_count", + "revealed_to_red", + } + self._original_state = self.model_dump(include=vals_to_include) + def _init_request_manager(self) -> RequestManager: rm = super()._init_request_manager() rm.add_request( @@ -131,16 +144,6 @@ class Software(SimComponent): ) return state - def reset_component_for_episode(self, episode: int): - """ - Resets the software component for a new episode. - - This method should ensure the software is ready for a new episode, including resetting any - stateful properties or statistics, and clearing any message queues. The specifics of what constitutes a - "reset" should be implemented in subclasses. - """ - pass - def set_health_state(self, health_state: SoftwareHealthState) -> None: """ Assign a new health state to this software. @@ -203,6 +206,12 @@ class IOSoftware(Software): port: Port "The port to which the software is connected." + def set_original_state(self): + """Sets the original state.""" + super().set_original_state() + vals_to_include = {"installing_count", "max_sessions", "tcp", "udp", "port"} + self._original_state.update(self.model_dump(include=vals_to_include)) + @abstractmethod def describe_state(self) -> Dict: """