Merged PR 226: Episode Reset

## Summary
After initial setup, get sim components to save their state. Then, during episode reset, set attributes to original values.

## Test process
Test script that checks for web client ability to resolve requests after a reset. Formal tests don't exist.

## Checklist
- [x] PR is linked to a **work item**
- [x] **acceptance criteria** of linked ticket are met
- [x] performed **self-review** of the code
- [ ] written **tests** for any new functionality added with this PR
- [ ] updated the **documentation** if this PR changes or adds functionality
- [ ] written/updated **design docs** if this PR implements new functionality
- [ ] updated the **change log**
- [x] ran **pre-commit** checks for code style
- [ ] attended to any **TO-DOs** left in the code

Related work items: #1859, #2087
This commit is contained in:
Marek Wolan
2023-11-30 09:10:06 +00:00
17 changed files with 133 additions and 39 deletions

View File

@@ -77,12 +77,14 @@ class File(FileSystemItemABC):
def set_original_state(self):
"""Sets the original state."""
_LOGGER.debug(f"Setting File ({self.path}) original state on node {self.sys_log.hostname}")
super().set_original_state()
vals_to_include = {"folder_id", "folder_name", "file_type", "sim_size", "real", "sim_path", "sim_root"}
self._original_state.update(self.model_dump(include=vals_to_include))
def reset_component_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
_LOGGER.debug(f"Resetting File ({self.path}) state on node {self.sys_log.hostname}")
super().reset_component_for_episode(episode)
@property

View File

@@ -37,16 +37,20 @@ class FileSystem(SimComponent):
def set_original_state(self):
"""Sets the original state."""
_LOGGER.debug(f"Setting FileSystem original state on node {self.sys_log.hostname}")
for folder in self.folders.values():
folder.set_original_state()
super().set_original_state()
# Capture a list of all 'original' file uuids
self._original_state["original_folder_uuids"] = list(self.folders.keys())
original_keys = list(self.folders.keys())
vals_to_include = {"sim_root"}
self._original_state.update(self.model_dump(include=vals_to_include))
self._original_state["original_folder_uuids"] = original_keys
def reset_component_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
_LOGGER.debug(f"Resetting FileSystem state on node {self.sys_log.hostname}")
# Move any 'original' folder that have been deleted back to folders
original_folder_uuids = self._original_state.pop("original_folder_uuids")
original_folder_uuids = self._original_state["original_folder_uuids"]
for uuid in original_folder_uuids:
if uuid in self.deleted_folders:
self.folders[uuid] = self.deleted_folders.pop(uuid)

View File

@@ -53,6 +53,7 @@ class Folder(FileSystemItemABC):
def set_original_state(self):
"""Sets the original state."""
_LOGGER.debug(f"Setting Folder ({self.name}) original state on node {self.sys_log.hostname}")
for file in self.files.values():
file.set_original_state()
super().set_original_state()
@@ -69,8 +70,9 @@ class Folder(FileSystemItemABC):
def reset_component_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
_LOGGER.debug(f"Resetting Folder ({self.name}) state on node {self.sys_log.hostname}")
# Move any 'original' file that have been deleted back to files
original_file_uuids = self._original_state.pop("original_file_uuids")
original_file_uuids = self._original_state["original_file_uuids"]
for uuid in original_file_uuids:
if uuid in self.deleted_files:
self.files[uuid] = self.deleted_files.pop(uuid)

View File

@@ -12,6 +12,8 @@ from primaite.simulator.network.hardware.nodes.computer import Computer
from primaite.simulator.network.hardware.nodes.router import Router
from primaite.simulator.network.hardware.nodes.server import Server
from primaite.simulator.network.hardware.nodes.switch import Switch
from primaite.simulator.system.applications.application import Application
from primaite.simulator.system.services.service import Service
_LOGGER = getLogger(__name__)
@@ -57,6 +59,18 @@ class Network(SimComponent):
for link in self.links.values():
link.reset_component_for_episode(episode)
for node in self.nodes.values():
node.power_on()
for nic in node.nics.values():
nic.enable()
# Reset software
for software in node.software_manager.software.values():
if isinstance(software, Service):
software.start()
elif isinstance(software, Application):
software.run()
def _init_request_manager(self) -> RequestManager:
rm = super()._init_request_manager()
self._node_request_manager = RequestManager()

View File

@@ -1019,6 +1019,8 @@ class Node(SimComponent):
def reset_component_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
super().reset_component_for_episode(episode)
# Reset ARP Cache
self.arp.clear()
@@ -1028,10 +1030,6 @@ class Node(SimComponent):
# Reset Session Manager
self.session_manager.clear()
# Reset software
for software in self.software_manager.software.values():
software.reset_component_for_episode(episode)
# Reset File System
self.file_system.reset_component_for_episode(episode)
@@ -1039,13 +1037,13 @@ class Node(SimComponent):
for nic in self.nics.values():
nic.reset_component_for_episode(episode)
#
for software in self.software_manager.software.values():
software.reset_component_for_episode(episode)
if episode and self.sys_log:
self.sys_log.current_episode = episode
self.sys_log.setup_logger()
super().reset_component_for_episode(episode)
def _init_request_manager(self) -> RequestManager:
# TODO: I see that this code is really confusing and hard to read right now... I think some of these things will
# need a better name and better documentation.

View File

@@ -678,8 +678,9 @@ class Router(Node):
"""Sets the original state."""
self.acl.set_original_state()
self.route_table.set_original_state()
super().set_original_state()
vals_to_include = {"num_ports"}
self._original_state = self.model_dump(include=vals_to_include)
self._original_state.update(self.model_dump(include=vals_to_include))
def reset_component_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""

View File

@@ -35,10 +35,17 @@ class DatabaseClient(Application):
def set_original_state(self):
"""Sets the original state."""
_LOGGER.debug(f"Setting DatabaseClient WebServer original state on node {self.software_manager.node.hostname}")
super().set_original_state()
vals_to_include = {"server_ip_address", "server_password", "connected"}
vals_to_include = {"server_ip_address", "server_password", "connected", "_query_success_tracker"}
self._original_state.update(self.model_dump(include=vals_to_include))
def reset_component_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
_LOGGER.debug(f"Resetting DataBaseClient state on node {self.software_manager.node.hostname}")
super().reset_component_for_episode(episode)
self._query_success_tracker.clear()
def describe_state(self) -> Dict:
"""
Describes the current state of the ACLRule.
@@ -188,4 +195,6 @@ class DatabaseClient(Application):
self._query_success_tracker[query_id] = status_code == 200
if self._query_success_tracker[query_id]:
_LOGGER.debug(f"Received payload {payload}")
else:
self.connected = False
return True

View File

@@ -2,6 +2,7 @@ from ipaddress import IPv4Address
from typing import Dict, Optional
from urllib.parse import urlparse
from primaite import getLogger
from primaite.simulator.core import RequestManager, RequestType
from primaite.simulator.network.protocols.http import (
HttpRequestMethod,
@@ -14,6 +15,8 @@ from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.applications.application import Application
from primaite.simulator.system.services.dns.dns_client import DNSClient
_LOGGER = getLogger(__name__)
class WebBrowser(Application):
"""
@@ -43,10 +46,16 @@ class WebBrowser(Application):
def set_original_state(self):
"""Sets the original state."""
_LOGGER.debug(f"Setting WebBrowser original state on node {self.software_manager.node.hostname}")
super().set_original_state()
vals_to_include = {"target_url", "domain_name_ip_address", "latest_response"}
self._original_state.update(self.model_dump(include=vals_to_include))
def reset_component_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
_LOGGER.debug(f"Resetting WebBrowser state on node {self.software_manager.node.hostname}")
super().reset_component_for_episode(episode)
def _init_request_manager(self) -> RequestManager:
rm = super()._init_request_manager()
rm.add_request(
@@ -67,7 +76,7 @@ class WebBrowser(Application):
def reset_component_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
def get_webpage(self) -> bool:
def get_webpage(self, url: Optional[str] = None) -> bool:
"""
Retrieve the webpage.
@@ -76,7 +85,7 @@ class WebBrowser(Application):
:param: url: The address of the web page the browser requests
:type: url: str
"""
url = self.target_url
url = url or self.target_url
if not self._can_perform_action():
return False

View File

@@ -2,6 +2,7 @@ from datetime import datetime
from ipaddress import IPv4Address
from typing import Any, Dict, List, Literal, Optional, Union
from primaite import getLogger
from primaite.simulator.file_system.file_system import File
from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
@@ -10,6 +11,8 @@ from primaite.simulator.system.services.ftp.ftp_client import FTPClient
from primaite.simulator.system.services.service import Service, ServiceOperatingState
from primaite.simulator.system.software import SoftwareHealthState
_LOGGER = getLogger(__name__)
class DatabaseService(Service):
"""
@@ -40,6 +43,7 @@ class DatabaseService(Service):
def set_original_state(self):
"""Sets the original state."""
_LOGGER.debug(f"Setting DatabaseService original state on node {self.software_manager.node.hostname}")
super().set_original_state()
vals_to_include = {
"password",
@@ -52,6 +56,7 @@ class DatabaseService(Service):
def reset_component_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
print("Resetting DatabaseService original state on node {self.software_manager.node.hostname}")
self.connections.clear()
super().reset_component_for_episode(episode)

View File

@@ -31,6 +31,7 @@ class DNSClient(Service):
def set_original_state(self):
"""Sets the original state."""
_LOGGER.debug(f"Setting DNSClient original state on node {self.software_manager.node.hostname}")
super().set_original_state()
vals_to_include = {"dns_server"}
self._original_state.update(self.model_dump(include=vals_to_include))
@@ -53,15 +54,6 @@ class DNSClient(Service):
state = super().describe_state()
return state
def reset_component_for_episode(self, episode: int):
"""
Resets the Service component for a new episode.
This method ensures the Service is ready for a new episode, including resetting any
stateful properties or statistics, and clearing any message queues.
"""
pass
def add_domain_to_cache(self, domain_name: str, ip_address: IPv4Address) -> bool:
"""
Adds a domain name to the DNS Client cache.

View File

@@ -30,19 +30,17 @@ class DNSServer(Service):
def set_original_state(self):
"""Sets the original state."""
_LOGGER.debug(f"Setting DNSServer original state on node {self.software_manager.node.hostname}")
super().set_original_state()
vals_to_include = {"dns_table"}
self._original_state["dns_table_orig"] = self.model_dump(include=vals_to_include)["dns_table"]
def reset_component_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
print("dns reset")
print("DNSServer original state", self._original_state)
self.dns_table.clear()
for key, value in self._original_state["dns_table_orig"].items():
self.dns_table[key] = value
super().reset_component_for_episode(episode)
self.show()
def describe_state(self) -> Dict:
"""

View File

@@ -1,6 +1,7 @@
from ipaddress import IPv4Address
from typing import Optional
from primaite import getLogger
from primaite.simulator.file_system.file_system import File
from primaite.simulator.network.protocols.ftp import FTPCommand, FTPPacket, FTPStatusCode
from primaite.simulator.network.transmission.network_layer import IPProtocol
@@ -9,6 +10,8 @@ from primaite.simulator.system.core.software_manager import SoftwareManager
from primaite.simulator.system.services.ftp.ftp_service import FTPServiceABC
from primaite.simulator.system.services.service import ServiceOperatingState
_LOGGER = getLogger(__name__)
class FTPClient(FTPServiceABC):
"""
@@ -28,6 +31,18 @@ class FTPClient(FTPServiceABC):
super().__init__(**kwargs)
self.start()
def set_original_state(self):
"""Sets the original state."""
_LOGGER.debug(f"Setting FTPClient original state on node {self.software_manager.node.hostname}")
super().set_original_state()
vals_to_include = {"connected"}
self._original_state.update(self.model_dump(include=vals_to_include))
def reset_component_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
_LOGGER.debug(f"Resetting FTPClient state on node {self.software_manager.node.hostname}")
super().reset_component_for_episode(episode)
def _process_ftp_command(self, payload: FTPPacket, session_id: Optional[str] = None, **kwargs) -> FTPPacket:
"""
Process the command in the FTP Packet.

View File

@@ -1,12 +1,15 @@
from ipaddress import IPv4Address
from typing import Any, Dict, Optional
from primaite import getLogger
from primaite.simulator.network.protocols.ftp import FTPCommand, FTPPacket, FTPStatusCode
from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.services.ftp.ftp_service import FTPServiceABC
from primaite.simulator.system.services.service import ServiceOperatingState
_LOGGER = getLogger(__name__)
class FTPServer(FTPServiceABC):
"""
@@ -29,6 +32,19 @@ class FTPServer(FTPServiceABC):
super().__init__(**kwargs)
self.start()
def set_original_state(self):
"""Sets the original state."""
_LOGGER.debug(f"Setting FTPServer original state on node {self.software_manager.node.hostname}")
super().set_original_state()
vals_to_include = {"server_password"}
self._original_state.update(self.model_dump(include=vals_to_include))
def reset_component_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
_LOGGER.debug(f"Resetting FTPServer state on node {self.software_manager.node.hostname}")
self.connections.clear()
super().reset_component_for_episode(episode)
def _process_ftp_command(self, payload: FTPPacket, session_id: Optional[str] = None, **kwargs) -> FTPPacket:
"""
Process the command in the FTP Packet.

View File

@@ -2,11 +2,14 @@ from enum import IntEnum
from ipaddress import IPv4Address
from typing import Optional
from primaite import getLogger
from primaite.game.science import simulate_trial
from primaite.simulator.core import RequestManager, RequestType
from primaite.simulator.system.applications.application import ApplicationOperatingState
from primaite.simulator.system.applications.database_client import DatabaseClient
_LOGGER = getLogger(__name__)
class DataManipulationAttackStage(IntEnum):
"""
@@ -47,6 +50,26 @@ class DataManipulationBot(DatabaseClient):
super().__init__(**kwargs)
self.name = "DataManipulationBot"
def set_original_state(self):
"""Sets the original state."""
_LOGGER.debug(f"Setting DataManipulationBot original state on node {self.software_manager.node.hostname}")
super().set_original_state()
vals_to_include = {
"server_ip_address",
"payload",
"server_password",
"port_scan_p_of_success",
"data_manipulation_p_of_success",
"attack_stage",
"repeat",
}
self._original_state.update(self.model_dump(include=vals_to_include))
def reset_component_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
_LOGGER.debug(f"Resetting DataManipulationBot state on node {self.software_manager.node.hostname}")
super().reset_component_for_episode(episode)
def _init_request_manager(self) -> RequestManager:
rm = super()._init_request_manager()

View File

@@ -2,6 +2,7 @@ from ipaddress import IPv4Address
from typing import Any, Dict, Optional
from urllib.parse import urlparse
from primaite import getLogger
from primaite.simulator.network.protocols.http import (
HttpRequestMethod,
HttpRequestPacket,
@@ -13,26 +14,26 @@ from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.applications.database_client import DatabaseClient
from primaite.simulator.system.services.service import Service
_LOGGER = getLogger(__name__)
class WebServer(Service):
"""Class used to represent a Web Server Service in simulation."""
_last_response_status_code: Optional[HttpStatusCode] = None
last_response_status_code: Optional[HttpStatusCode] = None
def set_original_state(self):
"""Sets the original state."""
_LOGGER.debug(f"Setting WebServer original state on node {self.software_manager.node.hostname}")
super().set_original_state()
vals_to_include = {"last_response_status_code"}
self._original_state.update(self.model_dump(include=vals_to_include))
def reset_component_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
self._last_response_status_code = None
_LOGGER.debug(f"Resetting WebServer state on node {self.software_manager.node.hostname}")
super().reset_component_for_episode(episode)
@property
def last_response_status_code(self) -> HttpStatusCode:
"""The latest http response code."""
return self._last_response_status_code
@last_response_status_code.setter
def last_response_status_code(self, val: Any):
self._last_response_status_code = val
def describe_state(self) -> Dict:
"""
Produce a dictionary describing the current state of this object.

View File

@@ -76,6 +76,10 @@ class TestPrimaiteSession:
with pytest.raises(pydantic.ValidationError):
session = TempPrimaiteSession.from_config(MISCONFIGURED_PATH)
@pytest.mark.skip(
reason="Currently software cannot be dynamically created/destroyed during simulation. Therefore, "
"reset doesn't implement software restore."
)
@pytest.mark.parametrize("temp_primaite_session", [[CFG_PATH]], indirect=True)
def test_session_sim_reset(self, temp_primaite_session):
with temp_primaite_session as session:

View File

@@ -27,10 +27,11 @@ def test_web_page_get_users_page_request_with_domain_name(uc2_network):
web_client: WebBrowser = client_1.software_manager.software["WebBrowser"]
web_client.run()
assert web_client.operating_state == ApplicationOperatingState.RUNNING
web_client.target_url = "http://arcd.com/users/"
assert web_client.get_webpage() is True
# latest reponse should have status code 200
# latest response should have status code 200
assert web_client.latest_response is not None
assert web_client.latest_response.status_code == HttpStatusCode.OK