Remove reset methods from most classes
This commit is contained in:
@@ -73,11 +73,6 @@ class File(FileSystemItemABC):
|
||||
|
||||
self.sys_log.info(f"Created file /{self.path} (id: {self.uuid})")
|
||||
|
||||
def reset_component_for_episode(self, episode: int):
|
||||
"""Reset the original state of the SimComponent."""
|
||||
_LOGGER.debug(f"Resetting File ({self.path}) state on node {self.sys_log.hostname}")
|
||||
super().reset_component_for_episode(episode)
|
||||
|
||||
@property
|
||||
def path(self) -> str:
|
||||
"""
|
||||
|
||||
@@ -34,32 +34,6 @@ class FileSystem(SimComponent):
|
||||
if not self.folders:
|
||||
self.create_folder("root")
|
||||
|
||||
def reset_component_for_episode(self, episode: int):
|
||||
"""Reset the original state of the SimComponent."""
|
||||
_LOGGER.debug(f"Resetting FileSystem state on node {self.sys_log.hostname}")
|
||||
# Move any 'original' folder that have been deleted back to folders
|
||||
original_folder_uuids = self._original_state["original_folder_uuids"]
|
||||
for uuid in original_folder_uuids:
|
||||
if uuid in self.deleted_folders:
|
||||
folder = self.deleted_folders[uuid]
|
||||
self.deleted_folders.pop(uuid)
|
||||
self.folders[uuid] = folder
|
||||
|
||||
# Clear any other deleted folders that aren't original (have been created by agent)
|
||||
self.deleted_folders.clear()
|
||||
|
||||
# Now clear all non-original folders created by agent
|
||||
current_folder_uuids = list(self.folders.keys())
|
||||
for uuid in current_folder_uuids:
|
||||
if uuid not in original_folder_uuids:
|
||||
folder = self.folders[uuid]
|
||||
self.folders.pop(uuid)
|
||||
|
||||
# Now reset all remaining folders
|
||||
for folder in self.folders.values():
|
||||
folder.reset_component_for_episode(episode)
|
||||
super().reset_component_for_episode(episode)
|
||||
|
||||
def _init_request_manager(self) -> RequestManager:
|
||||
rm = super()._init_request_manager()
|
||||
|
||||
|
||||
@@ -49,32 +49,6 @@ class Folder(FileSystemItemABC):
|
||||
|
||||
self.sys_log.info(f"Created file /{self.name} (id: {self.uuid})")
|
||||
|
||||
def reset_component_for_episode(self, episode: int):
|
||||
"""Reset the original state of the SimComponent."""
|
||||
_LOGGER.debug(f"Resetting Folder ({self.name}) state on node {self.sys_log.hostname}")
|
||||
# Move any 'original' file that have been deleted back to files
|
||||
original_file_uuids = self._original_state["original_file_uuids"]
|
||||
for uuid in original_file_uuids:
|
||||
if uuid in self.deleted_files:
|
||||
file = self.deleted_files[uuid]
|
||||
self.deleted_files.pop(uuid)
|
||||
self.files[uuid] = file
|
||||
|
||||
# Clear any other deleted files that aren't original (have been created by agent)
|
||||
self.deleted_files.clear()
|
||||
|
||||
# Now clear all non-original files created by agent
|
||||
current_file_uuids = list(self.files.keys())
|
||||
for uuid in current_file_uuids:
|
||||
if uuid not in original_file_uuids:
|
||||
file = self.files[uuid]
|
||||
self.files.pop(uuid)
|
||||
|
||||
# Now reset all remaining files
|
||||
for file in self.files.values():
|
||||
file.reset_component_for_episode(episode)
|
||||
super().reset_component_for_episode(episode)
|
||||
|
||||
def _init_request_manager(self) -> RequestManager:
|
||||
rm = super()._init_request_manager()
|
||||
rm.add_request(
|
||||
|
||||
@@ -1015,15 +1015,6 @@ class Node(SimComponent):
|
||||
"""Reset the original state of the SimComponent."""
|
||||
super().reset_component_for_episode(episode)
|
||||
|
||||
# Reset ARP Cache
|
||||
self.arp.clear()
|
||||
|
||||
# Reset ICMP
|
||||
self.icmp.clear()
|
||||
|
||||
# Reset Session Manager
|
||||
self.session_manager.clear()
|
||||
|
||||
# Reset File System
|
||||
self.file_system.reset_component_for_episode(episode)
|
||||
|
||||
|
||||
@@ -84,9 +84,7 @@ class AccessControlList(SimComponent):
|
||||
implicit_action: ACLAction
|
||||
implicit_rule: ACLRule
|
||||
max_acl_rules: int = 25
|
||||
_acl: List[Optional[ACLRule]] = [None] * 24
|
||||
_default_config: Dict[int, dict] = {}
|
||||
"""Config dict describing how the ACL list should look at episode start"""
|
||||
_acl: List[Optional[ACLRule]] = [None] * 24 # TODO: this ignores the max_acl_rules and assumes it's default
|
||||
|
||||
def __init__(self, **kwargs) -> None:
|
||||
if not kwargs.get("implicit_action"):
|
||||
@@ -97,26 +95,6 @@ class AccessControlList(SimComponent):
|
||||
super().__init__(**kwargs)
|
||||
self._acl = [None] * (self.max_acl_rules - 1)
|
||||
|
||||
def reset_component_for_episode(self, episode: int):
|
||||
"""Reset the original state of the SimComponent."""
|
||||
self.implicit_rule.reset_component_for_episode(episode)
|
||||
super().reset_component_for_episode(episode)
|
||||
self._reset_rules_to_default()
|
||||
|
||||
def _reset_rules_to_default(self) -> None:
|
||||
"""Clear all ACL rules and set them to the default rules config."""
|
||||
self._acl = [None] * (self.max_acl_rules - 1)
|
||||
for r_num, r_cfg in self._default_config.items():
|
||||
self.add_rule(
|
||||
action=ACLAction[r_cfg["action"]],
|
||||
src_port=None if not (p := r_cfg.get("src_port")) else Port[p],
|
||||
dst_port=None if not (p := r_cfg.get("dst_port")) else Port[p],
|
||||
protocol=None if not (p := r_cfg.get("protocol")) else IPProtocol[p],
|
||||
src_ip_address=r_cfg.get("src_ip"),
|
||||
dst_ip_address=r_cfg.get("dst_ip"),
|
||||
position=r_num,
|
||||
)
|
||||
|
||||
def _init_request_manager(self) -> RequestManager:
|
||||
rm = super()._init_request_manager()
|
||||
|
||||
@@ -394,12 +372,6 @@ class RouteTable(SimComponent):
|
||||
default_route: Optional[RouteEntry] = None
|
||||
sys_log: SysLog
|
||||
|
||||
def reset_component_for_episode(self, episode: int):
|
||||
"""Reset the original state of the SimComponent."""
|
||||
self.routes.clear()
|
||||
self.routes = self._original_state["routes_orig"]
|
||||
super().reset_component_for_episode(episode)
|
||||
|
||||
def describe_state(self) -> Dict:
|
||||
"""
|
||||
Describes the current state of the RouteTable.
|
||||
@@ -1040,7 +1012,18 @@ class Router(Node):
|
||||
ip_address=port_cfg["ip_address"],
|
||||
subnet_mask=port_cfg["subnet_mask"],
|
||||
)
|
||||
|
||||
# Add the router's default ACL rules from the config.
|
||||
if "acl" in cfg:
|
||||
new.acl._default_config = cfg["acl"] # save the config to allow resetting
|
||||
new.acl._reset_rules_to_default() # read the config and apply rules
|
||||
for r_num, r_cfg in cfg["acl"].items():
|
||||
new.add_rule(
|
||||
action=ACLAction[r_cfg["action"]],
|
||||
src_port=None if not (p := r_cfg.get("src_port")) else Port[p],
|
||||
dst_port=None if not (p := r_cfg.get("dst_port")) else Port[p],
|
||||
protocol=None if not (p := r_cfg.get("protocol")) else IPProtocol[p],
|
||||
src_ip_address=r_cfg.get("src_ip"),
|
||||
dst_ip_address=r_cfg.get("dst_ip"),
|
||||
position=r_num,
|
||||
)
|
||||
|
||||
return new
|
||||
|
||||
@@ -31,12 +31,6 @@ class DatabaseClient(Application):
|
||||
kwargs["protocol"] = IPProtocol.TCP
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def reset_component_for_episode(self, episode: int):
|
||||
"""Reset the original state of the SimComponent."""
|
||||
_LOGGER.debug(f"Resetting DataBaseClient state on node {self.software_manager.node.hostname}")
|
||||
super().reset_component_for_episode(episode)
|
||||
self._query_success_tracker.clear()
|
||||
|
||||
def describe_state(self) -> Dict:
|
||||
"""
|
||||
Describes the current state of the ACLRule.
|
||||
|
||||
@@ -49,11 +49,6 @@ class DataManipulationBot(DatabaseClient):
|
||||
super().__init__(**kwargs)
|
||||
self.name = "DataManipulationBot"
|
||||
|
||||
def reset_component_for_episode(self, episode: int):
|
||||
"""Reset the original state of the SimComponent."""
|
||||
_LOGGER.debug(f"Resetting DataManipulationBot state on node {self.software_manager.node.hostname}")
|
||||
super().reset_component_for_episode(episode)
|
||||
|
||||
def _init_request_manager(self) -> RequestManager:
|
||||
rm = super()._init_request_manager()
|
||||
|
||||
|
||||
@@ -57,11 +57,6 @@ class DoSBot(DatabaseClient, Application):
|
||||
self.name = "DoSBot"
|
||||
self.max_sessions = 1000 # override normal max sessions
|
||||
|
||||
def reset_component_for_episode(self, episode: int):
|
||||
"""Reset the original state of the SimComponent."""
|
||||
_LOGGER.debug(f"Resetting {self.name} state on node {self.software_manager.node.hostname}")
|
||||
super().reset_component_for_episode(episode)
|
||||
|
||||
def _init_request_manager(self) -> RequestManager:
|
||||
rm = super()._init_request_manager()
|
||||
|
||||
|
||||
@@ -49,11 +49,6 @@ class WebBrowser(Application):
|
||||
super().__init__(**kwargs)
|
||||
self.run()
|
||||
|
||||
def reset_component_for_episode(self, episode: int):
|
||||
"""Reset the original state of the SimComponent."""
|
||||
_LOGGER.debug(f"Resetting WebBrowser state on node {self.software_manager.node.hostname}")
|
||||
super().reset_component_for_episode(episode)
|
||||
|
||||
def _init_request_manager(self) -> RequestManager:
|
||||
rm = super()._init_request_manager()
|
||||
rm.add_request(
|
||||
@@ -72,9 +67,6 @@ class WebBrowser(Application):
|
||||
state["history"] = [hist_item.state() for hist_item in self.history]
|
||||
return state
|
||||
|
||||
def reset_component_for_episode(self, episode: int):
|
||||
"""Reset the original state of the SimComponent."""
|
||||
|
||||
def get_webpage(self, url: Optional[str] = None) -> bool:
|
||||
"""
|
||||
Retrieve the webpage.
|
||||
|
||||
@@ -40,12 +40,6 @@ class DatabaseService(Service):
|
||||
super().__init__(**kwargs)
|
||||
self._create_db_file()
|
||||
|
||||
def reset_component_for_episode(self, episode: int):
|
||||
"""Reset the original state of the SimComponent."""
|
||||
_LOGGER.debug("Resetting DatabaseService original state on node {self.software_manager.node.hostname}")
|
||||
self.clear_connections()
|
||||
super().reset_component_for_episode(episode)
|
||||
|
||||
def configure_backup(self, backup_server: IPv4Address):
|
||||
"""
|
||||
Set up the database backup.
|
||||
|
||||
@@ -29,11 +29,6 @@ class DNSClient(Service):
|
||||
super().__init__(**kwargs)
|
||||
self.start()
|
||||
|
||||
def reset_component_for_episode(self, episode: int):
|
||||
"""Reset the original state of the SimComponent."""
|
||||
self.dns_cache.clear()
|
||||
super().reset_component_for_episode(episode)
|
||||
|
||||
def describe_state(self) -> Dict:
|
||||
"""
|
||||
Describes the current state of the software.
|
||||
|
||||
@@ -28,13 +28,6 @@ class DNSServer(Service):
|
||||
super().__init__(**kwargs)
|
||||
self.start()
|
||||
|
||||
def reset_component_for_episode(self, episode: int):
|
||||
"""Reset the original state of the SimComponent."""
|
||||
self.dns_table.clear()
|
||||
for key, value in self._original_state["dns_table_orig"].items():
|
||||
self.dns_table[key] = value
|
||||
super().reset_component_for_episode(episode)
|
||||
|
||||
def describe_state(self) -> Dict:
|
||||
"""
|
||||
Describes the current state of the software.
|
||||
|
||||
@@ -27,11 +27,6 @@ class FTPClient(FTPServiceABC):
|
||||
super().__init__(**kwargs)
|
||||
self.start()
|
||||
|
||||
def reset_component_for_episode(self, episode: int):
|
||||
"""Reset the original state of the SimComponent."""
|
||||
_LOGGER.debug(f"Resetting FTPClient state on node {self.software_manager.node.hostname}")
|
||||
super().reset_component_for_episode(episode)
|
||||
|
||||
def _process_ftp_command(self, payload: FTPPacket, session_id: Optional[str] = None, **kwargs) -> FTPPacket:
|
||||
"""
|
||||
Process the command in the FTP Packet.
|
||||
|
||||
@@ -27,12 +27,6 @@ class FTPServer(FTPServiceABC):
|
||||
super().__init__(**kwargs)
|
||||
self.start()
|
||||
|
||||
def reset_component_for_episode(self, episode: int):
|
||||
"""Reset the original state of the SimComponent."""
|
||||
_LOGGER.debug(f"Resetting FTPServer state on node {self.software_manager.node.hostname}")
|
||||
self.clear_connections()
|
||||
super().reset_component_for_episode(episode)
|
||||
|
||||
def _process_ftp_command(self, payload: FTPPacket, session_id: Optional[str] = None, **kwargs) -> FTPPacket:
|
||||
"""
|
||||
Process the command in the FTP Packet.
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from datetime import datetime
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Dict, Optional
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.simulator.network.protocols.ntp import NTPPacket
|
||||
@@ -49,21 +49,12 @@ class NTPClient(Service):
|
||||
state = super().describe_state()
|
||||
return state
|
||||
|
||||
def reset_component_for_episode(self, episode: int):
|
||||
"""
|
||||
Resets the Service component for a new episode.
|
||||
|
||||
This method ensures the Service is ready for a new episode, including resetting any
|
||||
stateful properties or statistics, and clearing any message queues.
|
||||
"""
|
||||
pass
|
||||
|
||||
def send(
|
||||
self,
|
||||
payload: NTPPacket,
|
||||
session_id: Optional[str] = None,
|
||||
dest_ip_address: IPv4Address = None,
|
||||
dest_port: [Port] = Port.NTP,
|
||||
dest_port: List[Port] = Port.NTP,
|
||||
**kwargs,
|
||||
) -> bool:
|
||||
"""Requests NTP data from NTP server.
|
||||
|
||||
@@ -34,16 +34,6 @@ class NTPServer(Service):
|
||||
state = super().describe_state()
|
||||
return state
|
||||
|
||||
def reset_component_for_episode(self, episode: int):
|
||||
"""
|
||||
Resets the Service component for a new episode.
|
||||
|
||||
This method ensures the Service is ready for a new episode, including
|
||||
resetting any stateful properties or statistics, and clearing any message
|
||||
queues.
|
||||
"""
|
||||
pass
|
||||
|
||||
def receive(
|
||||
self,
|
||||
payload: NTPPacket,
|
||||
|
||||
@@ -23,11 +23,6 @@ class WebServer(Service):
|
||||
|
||||
last_response_status_code: Optional[HttpStatusCode] = None
|
||||
|
||||
def reset_component_for_episode(self, episode: int):
|
||||
"""Reset the original state of the SimComponent."""
|
||||
_LOGGER.debug(f"Resetting WebServer state on node {self.software_manager.node.hostname}")
|
||||
super().reset_component_for_episode(episode)
|
||||
|
||||
def describe_state(self) -> Dict:
|
||||
"""
|
||||
Produce a dictionary describing the current state of this object.
|
||||
|
||||
Reference in New Issue
Block a user