Merged PR 256: Align the database backup and corruption processes with IYak

## Summary
**Changed:**
- Copying a file via FTP also copies its health status
- The database automatically attempts to make a backup on step 1
- Made the database file a property that is looked up by name in the file system, instead of a stored handle to a file
- Fixed the FTP server re-sending requests back to the client
- Fixed an issue where links at more than 100% bandwidth utilisation caused the observation space to crash
- Fixed an issue where starting a node didn't start its services (not sure how that one passed tests previously)

**To align with Yak:**
- Database service removed from the UC2 observation space
- The SQL attack now affects the health status of the database file instead of the health state of the service
- When the web server fails to fetch data, it goes into the compromised state until the next successful data fetch

## Test process
Notebooks

## Checklist
- [x] PR is linked to a **work item**
- [x] **acceptance criteria** of linked ticket are met
- [x] performed **self-review** of the code
- [ ] written **tests** for any new functionality added with this PR
- [ ] updated the **documentation** if this PR changes or adds functionality
- [ ] written/updated **design docs** if this PR implements new functionality
- [ ] updated the **change log**
- [x] ran **pre-commit** checks for code style
- [x] attended to any **TO-DOs** left in the code

Related work items: #2176
This commit is contained in:
Marek Wolan
2024-01-11 10:54:25 +00:00
19 changed files with 103 additions and 70 deletions

View File

@@ -112,10 +112,8 @@ agents:
- service_ref: domain_controller_dns_server
- node_ref: web_server
services:
- service_ref: web_server_database_client
- service_ref: web_server_web_service
- node_ref: database_server
services:
- service_ref: database_service
folders:
- folder_name: database
files:

View File

@@ -205,12 +205,15 @@ class LinkObservation(AbstractObservation):
bandwidth = link_state["bandwidth"]
load = link_state["current_load"]
utilisation_fraction = load / bandwidth
# 0 is UNUSED, 1 is 0%-10%. 2 is 10%-20%. 3 is 20%-30%. And so on... 10 is exactly 100%
utilisation_category = int(utilisation_fraction * 10) + 1
if load == 0:
utilisation_category = 0
else:
utilisation_fraction = load / bandwidth
# 0 is UNUSED. 1-9 each cover one ninth of the bandwidth (1 is 0%-11%, 2 is 11%-22%, ... 9 is 89%-100%). 10 is 100% or above.
utilisation_category = int(utilisation_fraction * 9) + 1
# TODO: once the links support separate load per protocol, this needs amendment to reflect that.
return {"PROTOCOLS": {"ALL": utilisation_category}}
return {"PROTOCOLS": {"ALL": min(utilisation_category, 10)}}
@property
def space(self) -> spaces.Space:
@@ -555,7 +558,7 @@ class NodeObservation(AbstractObservation):
folder_configs = config.get("folders", {})
folders = [
FolderObservation.from_config(
config=c, game=game, parent_where=where, num_files_per_folder=num_files_per_folder
config=c, game=game, parent_where=where + ["file_system"], num_files_per_folder=num_files_per_folder
)
for c in folder_configs
]

View File

@@ -23,7 +23,6 @@ class PrimaiteGymEnv(gymnasium.Env):
super().__init__()
self.game: "PrimaiteGame" = game
self.agent: ProxyAgent = self.game.rl_agents[0]
self.flatten_obs: bool = False
def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict[str, Any]]:
"""Perform a step in the environment."""

View File

@@ -1271,8 +1271,8 @@ class Node(SimComponent):
self.start_up_countdown = self.start_up_duration
if self.start_up_duration <= 0:
self._start_up_actions()
self.operating_state = NodeOperatingState.ON
self._start_up_actions()
self.sys_log.info("Turned on")
for nic in self.nics.values():
if nic._connected_link:

View File

@@ -38,9 +38,6 @@ class Application(IOSoftware):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.health_state_visible = SoftwareHealthState.UNUSED
self.health_state_actual = SoftwareHealthState.UNUSED
def set_original_state(self):
"""Sets the original state."""
super().set_original_state()
@@ -95,6 +92,9 @@ class Application(IOSoftware):
if self.operating_state == ApplicationOperatingState.CLOSED:
self.sys_log.info(f"Running Application {self.name}")
self.operating_state = ApplicationOperatingState.RUNNING
# set software health state to GOOD if initially set to UNUSED
if self.health_state_actual == SoftwareHealthState.UNUSED:
self.set_health_state(SoftwareHealthState.GOOD)
def _application_loop(self):
"""The main application loop."""

View File

@@ -4,6 +4,8 @@ from typing import Any, Dict, List, Literal, Optional, Union
from primaite import getLogger
from primaite.simulator.file_system.file_system import File
from primaite.simulator.file_system.file_system_item_abc import FileSystemItemHealthStatus
from primaite.simulator.file_system.folder import Folder
from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.core.software_manager import SoftwareManager
@@ -24,7 +26,7 @@ class DatabaseService(Service):
password: Optional[str] = None
connections: Dict[str, datetime] = {}
backup_server: IPv4Address = None
backup_server_ip: IPv4Address = None
"""IP address of the backup server."""
latest_backup_directory: str = None
@@ -38,7 +40,6 @@ class DatabaseService(Service):
kwargs["port"] = Port.POSTGRES_SERVER
kwargs["protocol"] = IPProtocol.TCP
super().__init__(**kwargs)
self._db_file: File
self._create_db_file()
def set_original_state(self):
@@ -48,7 +49,7 @@ class DatabaseService(Service):
vals_to_include = {
"password",
"connections",
"backup_server",
"backup_server_ip",
"latest_backup_directory",
"latest_backup_file_name",
}
@@ -66,7 +67,7 @@ class DatabaseService(Service):
:param: backup_server_ip: The IP address of the backup server
"""
self.backup_server = backup_server
self.backup_server_ip = backup_server
def backup_database(self) -> bool:
"""Create a backup of the database to the configured backup server."""
@@ -75,7 +76,7 @@ class DatabaseService(Service):
return False
# check if the backup server was configured
if self.backup_server is None:
if self.backup_server_ip is None:
self.sys_log.error(f"{self.name} - {self.sys_log.hostname}: not configured.")
return False
@@ -84,9 +85,9 @@ class DatabaseService(Service):
# send backup copy of database file to FTP server
response = ftp_client_service.send_file(
dest_ip_address=self.backup_server,
src_file_name=self._db_file.name,
src_folder_name=self.folder.name,
dest_ip_address=self.backup_server_ip,
src_file_name=self.db_file.name,
src_folder_name="database",
dest_folder_name=str(self.uuid),
dest_file_name="database.db",
)
@@ -112,7 +113,7 @@ class DatabaseService(Service):
src_file_name="database.db",
dest_folder_name="downloads",
dest_file_name="database.db",
dest_ip_address=self.backup_server,
dest_ip_address=self.backup_server_ip,
)
if not response:
@@ -120,13 +121,10 @@ class DatabaseService(Service):
return False
# replace db file
self.file_system.delete_file(folder_name=self.folder.name, file_name="downloads.db")
self.file_system.copy_file(
src_folder_name="downloads", src_file_name="database.db", dst_folder_name=self.folder.name
)
self._db_file = self.file_system.get_file(folder_name=self.folder.name, file_name="database.db")
self.file_system.delete_file(folder_name="database", file_name="downloads.db")
self.file_system.copy_file(src_folder_name="downloads", src_file_name="database.db", dst_folder_name="database")
if self._db_file is None:
if self.db_file is None:
self.sys_log.error("Copying database backup failed.")
return False
@@ -136,8 +134,17 @@ class DatabaseService(Service):
def _create_db_file(self):
"""Creates the Simulation File and sqlite file in the file system."""
self._db_file: File = self.file_system.create_file(folder_name="database", file_name="database.db")
self.folder = self.file_system.get_folder_by_id(self._db_file.folder_id)
self.file_system.create_file(folder_name="database", file_name="database.db")
@property
def db_file(self) -> File:
    """
    Fetch the database file from the file system by name.

    The lookup is performed on every access; nothing is cached on the service.

    :return: Whatever ``file_system.get_file`` yields for ``database.db`` in the ``database`` folder.
    """
    database_file = self.file_system.get_file(file_name="database.db", folder_name="database")
    return database_file
@property
def folder(self) -> Folder:
    """
    Resolve the folder that holds the database file.

    NOTE(review): ``db_file`` is compared against ``None`` elsewhere in this service; if the lookup can
    return ``None`` here, reading ``folder_id`` raises ``AttributeError`` -- confirm.

    :return: The folder found via the ``folder_id`` of ``db_file``.
    """
    find_folder = self.file_system.get_folder_by_id
    return find_folder(self.db_file.folder_id)
def _process_connect(
self, session_id: str, password: Optional[str] = None
@@ -170,16 +177,13 @@ class DatabaseService(Service):
"""
self.sys_log.info(f"{self.name}: Running {query}")
if query == "SELECT":
if self.health_state_actual == SoftwareHealthState.GOOD:
if self.db_file.health_status == FileSystemItemHealthStatus.GOOD:
return {"status_code": 200, "type": "sql", "data": True, "uuid": query_id}
else:
return {"status_code": 404, "data": False}
elif query == "DELETE":
if self.health_state_actual == SoftwareHealthState.GOOD:
self.health_state_actual = SoftwareHealthState.COMPROMISED
return {"status_code": 200, "type": "sql", "data": False, "uuid": query_id}
else:
return {"status_code": 404, "data": False}
self.db_file.health_status = FileSystemItemHealthStatus.COMPROMISED
return {"status_code": 200, "type": "sql", "data": False, "uuid": query_id}
else:
# Invalid query
return {"status_code": 500, "data": False}
@@ -233,3 +237,13 @@ class DatabaseService(Service):
software_manager.send_payload_to_session_manager(payload=payload, session_id=session_id)
return payload["status_code"] == 200
def apply_timestep(self, timestep: int) -> None:
    """
    Advance this service by one simulation timestep.

    On timestep 1 a database backup is attempted via ``backup_database`` before the parent class's
    timestep logic runs. The result of the backup attempt is not inspected here.

    :param timestep: The number of the timestep being applied.
    """
    is_first_step = timestep == 1
    if is_first_step:
        self.backup_database()
    return super().apply_timestep(timestep)

View File

@@ -106,5 +106,5 @@ class FTPServer(FTPServiceABC):
if payload.status_code is not None:
return False
self.send(self._process_ftp_command(payload=payload, session_id=session_id), session_id)
self._process_ftp_command(payload=payload, session_id=session_id)
return True

View File

@@ -1,7 +1,7 @@
import shutil
from abc import ABC
from ipaddress import IPv4Address
from typing import Optional
from typing import Dict, Optional
from primaite.simulator.file_system.file_system import File
from primaite.simulator.network.protocols.ftp import FTPCommand, FTPPacket, FTPStatusCode
@@ -16,6 +16,10 @@ class FTPServiceABC(Service, ABC):
Contains shared methods between both classes.
"""
def describe_state(self) -> Dict:
    """
    Describe the current state of the FTP service.

    :return: The state dictionary produced by the parent class, unchanged.
    """
    state = super().describe_state()
    return state
def _process_ftp_command(self, payload: FTPPacket, session_id: Optional[str] = None, **kwargs) -> FTPPacket:
"""
Process the command in the FTP Packet.
@@ -52,10 +56,12 @@ class FTPServiceABC(Service, ABC):
folder_name = payload.ftp_command_args["dest_folder_name"]
file_size = payload.ftp_command_args["file_size"]
real_file_path = payload.ftp_command_args.get("real_file_path")
health_status = payload.ftp_command_args["health_status"]
is_real = real_file_path is not None
file = self.file_system.create_file(
file_name=file_name, folder_name=folder_name, size=file_size, real=is_real
)
file.health_status = health_status
self.sys_log.info(
f"{self.name}: Created item in {self.sys_log.hostname}: {payload.ftp_command_args['dest_folder_name']}/"
f"{payload.ftp_command_args['dest_file_name']}"
@@ -110,6 +116,7 @@ class FTPServiceABC(Service, ABC):
"dest_file_name": dest_file_name,
"file_size": file.sim_size,
"real_file_path": file.sim_path if file.real else None,
"health_status": file.health_status,
},
packet_payload_size=file.sim_size,
status_code=FTPStatusCode.OK if is_response else None,

View File

@@ -84,7 +84,7 @@ class DataManipulationBot(DatabaseClient):
payload: Optional[str] = None,
port_scan_p_of_success: float = 0.1,
data_manipulation_p_of_success: float = 0.1,
repeat: bool = False,
repeat: bool = True,
):
"""
Configure the DataManipulatorBot to communicate with a DatabaseService.

View File

@@ -1,3 +1,4 @@
from abc import abstractmethod
from enum import Enum
from typing import Any, Dict, Optional
@@ -77,9 +78,6 @@ class Service(IOSoftware):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.health_state_visible = SoftwareHealthState.UNUSED
self.health_state_actual = SoftwareHealthState.UNUSED
def set_original_state(self):
"""Sets the original state."""
super().set_original_state()
@@ -98,6 +96,7 @@ class Service(IOSoftware):
rm.add_request("enable", RequestType(func=lambda request, context: self.enable()))
return rm
@abstractmethod
def describe_state(self) -> Dict:
"""
Produce a dictionary describing the current state of this object.
@@ -118,7 +117,6 @@ class Service(IOSoftware):
if self.operating_state in [ServiceOperatingState.RUNNING, ServiceOperatingState.PAUSED]:
self.sys_log.info(f"Stopping service {self.name}")
self.operating_state = ServiceOperatingState.STOPPED
self.health_state_actual = SoftwareHealthState.UNUSED
def start(self, **kwargs) -> None:
"""Start the service."""
@@ -129,42 +127,39 @@ class Service(IOSoftware):
if self.operating_state == ServiceOperatingState.STOPPED:
self.sys_log.info(f"Starting service {self.name}")
self.operating_state = ServiceOperatingState.RUNNING
self.health_state_actual = SoftwareHealthState.GOOD
# set software health state to GOOD if initially set to UNUSED
if self.health_state_actual == SoftwareHealthState.UNUSED:
self.set_health_state(SoftwareHealthState.GOOD)
def pause(self) -> None:
"""Pause the service."""
if self.operating_state == ServiceOperatingState.RUNNING:
self.sys_log.info(f"Pausing service {self.name}")
self.operating_state = ServiceOperatingState.PAUSED
self.health_state_actual = SoftwareHealthState.OVERWHELMED
def resume(self) -> None:
"""Resume paused service."""
if self.operating_state == ServiceOperatingState.PAUSED:
self.sys_log.info(f"Resuming service {self.name}")
self.operating_state = ServiceOperatingState.RUNNING
self.health_state_actual = SoftwareHealthState.GOOD
def restart(self) -> None:
"""Restart running service."""
if self.operating_state in [ServiceOperatingState.RUNNING, ServiceOperatingState.PAUSED]:
self.sys_log.info(f"Pausing service {self.name}")
self.operating_state = ServiceOperatingState.RESTARTING
self.health_state_actual = SoftwareHealthState.OVERWHELMED
self.restart_countdown = self.restart_duration
def disable(self) -> None:
"""Disable the service."""
self.sys_log.info(f"Disabling Application {self.name}")
self.operating_state = ServiceOperatingState.DISABLED
self.health_state_actual = SoftwareHealthState.OVERWHELMED
def enable(self) -> None:
"""Enable the disabled service."""
if self.operating_state == ServiceOperatingState.DISABLED:
self.sys_log.info(f"Enabling Application {self.name}")
self.operating_state = ServiceOperatingState.STOPPED
self.health_state_actual = SoftwareHealthState.OVERWHELMED
def apply_timestep(self, timestep: int) -> None:
"""
@@ -181,5 +176,4 @@ class Service(IOSoftware):
if self.restart_countdown <= 0:
_LOGGER.debug(f"Restarting finished for service {self.name}")
self.operating_state = ServiceOperatingState.RUNNING
self.health_state_actual = SoftwareHealthState.GOOD
self.restart_countdown -= 1

View File

@@ -13,6 +13,7 @@ from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.applications.database_client import DatabaseClient
from primaite.simulator.system.services.service import Service
from primaite.simulator.system.software import SoftwareHealthState
_LOGGER = getLogger(__name__)
@@ -123,7 +124,10 @@ class WebServer(Service):
# get all users
if db_client.query("SELECT"):
# query succeeded
self.set_health_state(SoftwareHealthState.GOOD)
response.status_code = HttpStatusCode.OK
else:
self.set_health_state(SoftwareHealthState.COMPROMISED)
return response
except Exception:

View File

@@ -69,9 +69,9 @@ class Software(SimComponent):
name: str
"The name of the software."
health_state_actual: SoftwareHealthState = SoftwareHealthState.GOOD
health_state_actual: SoftwareHealthState = SoftwareHealthState.UNUSED
"The actual health state of the software."
health_state_visible: SoftwareHealthState = SoftwareHealthState.GOOD
health_state_visible: SoftwareHealthState = SoftwareHealthState.UNUSED
"The health state of the software visible to the red agent."
criticality: SoftwareCriticality = SoftwareCriticality.LOWEST
"The criticality level of the software."
@@ -278,7 +278,7 @@ class IOSoftware(Software):
Returns true if the software can perform actions.
"""
if self.software_manager and self.software_manager.node.operating_state is not NodeOperatingState.ON:
if self.software_manager and self.software_manager.node.operating_state != NodeOperatingState.ON:
_LOGGER.debug(f"{self.name} Error: {self.software_manager.node.hostname} is not online.")
return False
return True

View File

@@ -40,6 +40,9 @@ from primaite.simulator.network.hardware.base import Link, Node
class TestService(Service):
"""Test Service class"""
def describe_state(self) -> Dict:
    """Return the parent class's state dictionary unchanged."""
    state = super().describe_state()
    return state
def __init__(self, **kwargs):
kwargs["name"] = "TestService"
kwargs["port"] = Port.HTTP
@@ -60,7 +63,7 @@ class TestApplication(Application):
super().__init__(**kwargs)
def describe_state(self) -> Dict:
pass
return super().describe_state()
@pytest.fixture(scope="function")

View File

@@ -22,7 +22,7 @@ def test_data_manipulation(uc2_network):
assert db_client.query("SELECT")
# Now we run the DataManipulationBot
db_manipulation_bot.run()
db_manipulation_bot.attack()
# Now check that the DB client on the web_server cannot query the users table on the database
assert not db_client.query("SELECT")

View File

@@ -24,8 +24,8 @@ def populated_node(application_class) -> Tuple[Application, Computer]:
return app, computer
def test_service_on_offline_node(application_class):
"""Test to check that the service cannot be interacted with when node it is on is off."""
def test_application_on_offline_node(application_class):
"""Test to check that the application cannot be interacted with when node it is on is off."""
computer: Computer = Computer(
hostname="test_computer",
ip_address="192.168.1.2",
@@ -49,8 +49,8 @@ def test_service_on_offline_node(application_class):
assert app.operating_state is ApplicationOperatingState.CLOSED
def test_server_turns_off_service(populated_node):
"""Check that the service is turned off when the server is turned off"""
def test_server_turns_off_application(populated_node):
"""Check that the application is turned off when the server is turned off"""
app, computer = populated_node
assert computer.operating_state is NodeOperatingState.ON
@@ -65,8 +65,8 @@ def test_server_turns_off_service(populated_node):
assert app.operating_state is ApplicationOperatingState.CLOSED
def test_service_cannot_be_turned_on_when_server_is_off(populated_node):
"""Check that the service cannot be started when the server is off."""
def test_application_cannot_be_turned_on_when_server_is_off(populated_node):
"""Check that the application cannot be started when the server is off."""
app, computer = populated_node
assert computer.operating_state is NodeOperatingState.ON
@@ -86,8 +86,8 @@ def test_service_cannot_be_turned_on_when_server_is_off(populated_node):
assert app.operating_state is ApplicationOperatingState.CLOSED
def test_server_turns_on_service(populated_node):
"""Check that turning on the server turns on service."""
def test_server_turns_on_application(populated_node):
"""Check that turning on the server turns on application."""
app, computer = populated_node
assert computer.operating_state is NodeOperatingState.ON
@@ -109,13 +109,14 @@ def test_server_turns_on_service(populated_node):
assert computer.operating_state is NodeOperatingState.ON
assert app.operating_state is ApplicationOperatingState.RUNNING
computer.start_up_duration = 0
computer.shut_down_duration = 0
computer.power_off()
for i in range(computer.start_up_duration + 1):
computer.apply_timestep(timestep=i)
assert computer.operating_state is NodeOperatingState.OFF
assert app.operating_state is ApplicationOperatingState.CLOSED
computer.power_on()
for i in range(computer.start_up_duration + 1):
computer.apply_timestep(timestep=i)
assert computer.operating_state is NodeOperatingState.ON
assert app.operating_state is ApplicationOperatingState.RUNNING

View File

@@ -117,13 +117,14 @@ def test_server_turns_on_service(populated_node):
assert server.operating_state is NodeOperatingState.ON
assert service.operating_state is ServiceOperatingState.RUNNING
server.start_up_duration = 0
server.shut_down_duration = 0
server.power_off()
for i in range(server.start_up_duration + 1):
server.apply_timestep(timestep=i)
assert server.operating_state is NodeOperatingState.OFF
assert service.operating_state is ServiceOperatingState.STOPPED
server.power_on()
for i in range(server.start_up_duration + 1):
server.apply_timestep(timestep=i)
assert server.operating_state is NodeOperatingState.ON
assert service.operating_state is ServiceOperatingState.RUNNING

View File

@@ -2,6 +2,7 @@ from ipaddress import IPv4Address
import pytest
from primaite.simulator.file_system.file_system_item_abc import FileSystemItemHealthStatus
from primaite.simulator.network.hardware.base import Node
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
from primaite.simulator.network.hardware.nodes.computer import Computer
@@ -42,6 +43,7 @@ def test_ftp_client_store_file(ftp_client):
"dest_folder_name": "downloads",
"dest_file_name": "file.txt",
"file_size": 24,
"health_status": FileSystemItemHealthStatus.GOOD,
},
packet_payload_size=24,
status_code=FTPStatusCode.OK,

View File

@@ -1,5 +1,6 @@
import pytest
from primaite.simulator.file_system.file_system_item_abc import FileSystemItemHealthStatus
from primaite.simulator.network.hardware.base import Node
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
from primaite.simulator.network.hardware.nodes.server import Server
@@ -41,6 +42,7 @@ def test_ftp_server_store_file(ftp_server):
"dest_folder_name": "downloads",
"dest_file_name": "file.txt",
"file_size": 24,
"health_status": FileSystemItemHealthStatus.GOOD,
},
packet_payload_size=24,
)

View File

@@ -1,5 +1,6 @@
import pytest
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
from primaite.simulator.network.hardware.nodes.server import Server
from primaite.simulator.network.protocols.http import (
HttpRequestMethod,
@@ -15,7 +16,11 @@ from primaite.simulator.system.services.web_server.web_server import WebServer
@pytest.fixture(scope="function")
def web_server() -> Server:
node = Server(
hostname="web_server", ip_address="192.168.1.10", subnet_mask="255.255.255.0", default_gateway="192.168.1.1"
hostname="web_server",
ip_address="192.168.1.10",
subnet_mask="255.255.255.0",
default_gateway="192.168.1.1",
operating_state=NodeOperatingState.ON,
)
node.software_manager.install(software_class=WebServer)
node.software_manager.software.get("WebServer").start()