Merged PR 173: Database service (without networking)

## Summary
- add a database service
- change how `SimComponent` adds actions to allow inheritance of actions
- add service-based actions, like start, stop, pause, and compromise

## Test process
New test cases were added.

## Checklist
- [x] This PR is linked to a **work item**
- [x] I have performed **self-review** of the code
- [x] I have written **tests** for any new functionality added with this PR
- [ ] I have updated the **documentation** if this PR changes or adds functionality
- [ ] I have written/updated **design docs** if this PR implements new functionality
- [ ] I have update the **change log**
- [x] I have run **pre-commit** checks for code style

Related work items: #1801
This commit is contained in:
Marek Wolan
2023-09-04 19:41:17 +00:00
committed by Christopher McCarthy
19 changed files with 478 additions and 51 deletions

View File

@@ -49,16 +49,14 @@ snippet demonstrates usage of the ``ActionPermissionValidator``.
name: str
apps = []
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.action_manager = ActionManager()
self.action_manager.add_action(
def _init_action_manager(self) -> ActionManager:
am = super()._init_action_manager()
am.add_action(
"reset_factory_settings",
Action(
func = lambda request, context: self.reset_factory_settings(),
validator = GroupMembershipValidator([AccountGroup.DOMAIN_ADMIN]),
),
)
)
def reset_factory_settings(self):

View File

@@ -41,7 +41,9 @@ class Action:
the action can be performed or not.
"""
def __init__(self, func: Callable[[List[str], Dict], None], validator: ActionPermissionValidator) -> None:
def __init__(
self, func: Callable[[List[str], Dict], None], validator: ActionPermissionValidator = AllowAllValidator()
) -> None:
"""
Save the functions that are for this action.
@@ -58,7 +60,8 @@ class Action:
:param func: Function that performs the request.
:type func: Callable[[List[str], Dict], None]
:param validator: Function that checks if the request is authenticated given the context.
:param validator: Function that checks if the request is authenticated given the context. By default, if no
validator is provided, an 'allow all' validator is added which permits all requests.
:type validator: ActionPermissionValidator
"""
self.func: Callable[[List[str], Dict], None] = func
@@ -136,9 +139,31 @@ class SimComponent(BaseModel):
if not kwargs.get("uuid"):
kwargs["uuid"] = str(uuid4())
super().__init__(**kwargs)
self.action_manager: Optional[ActionManager] = None
self._action_manager: ActionManager = self._init_action_manager()
self._parent: Optional["SimComponent"] = None
def _init_action_manager(self) -> ActionManager:
"""
Initialise the action manager for this component.
When using a hierarchy of components, the child classes should call the parent class's _init_action_manager and
add additional actions on top of the existing generic ones.
Example usage for inherited classes:
..code::python
class WebBrowser(Application):
def _init_action_manager(self) -> ActionManager:
am = super()._init_action_manager() # all actions generic to any Application get initialised
am.add_action(...) # initialise any actions specific to the web browser
return am
:return: Actiona manager object belonging to this sim component.
:rtype: ActionManager
"""
return ActionManager()
@abstractmethod
def describe_state(self) -> Dict:
"""

View File

@@ -85,16 +85,18 @@ class DomainController(SimComponent):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.action_manager = ActionManager()
def _init_action_manager(self) -> ActionManager:
am = super()._init_action_manager()
# Action 'account' matches requests like:
# ['account', '<account-uuid>', *account_action]
self.action_manager.add_action(
am.add_action(
"account",
Action(
func=lambda request, context: self.accounts[request.pop(0)].apply_action(request, context),
validator=GroupMembershipValidator([AccountGroup.DOMAIN_ADMIN]),
),
)
return am
def describe_state(self) -> Dict:
"""

View File

@@ -211,7 +211,7 @@ class FileSystem(SimComponent):
if file is not None:
return file
def get_folder_by_name(self, folder_name: str) -> FileSystemFolder:
def get_folder_by_name(self, folder_name: str) -> Optional[FileSystemFolder]:
"""
Returns a the first folder with a matching name.

View File

@@ -87,6 +87,14 @@ class FileSystemFileType(str, Enum):
GZ = 31
"Gzip compressed file."
# Database file types
MDF = 32
"MS SQL Server primary database file"
NDF = 33
"MS SQL Server secondary database file"
LDF = 34
"MS SQL Server transaction log"
file_type_sizes_KB = {
FileSystemFileType.UNKNOWN: 0,

View File

@@ -39,15 +39,19 @@ class Network(SimComponent):
"""
super().__init__(**kwargs)
self.action_manager = ActionManager()
self.action_manager.add_action(
self._nx_graph = MultiGraph()
def _init_action_manager(self) -> ActionManager:
am = super()._init_action_manager()
am.add_action(
"node",
Action(
func=lambda request, context: self.nodes[request.pop(0)].apply_action(request, context),
validator=AllowAllValidator(),
),
)
self._nx_graph = MultiGraph()
return am
@property
def routers(self) -> List[Router]:

View File

@@ -1095,3 +1095,135 @@ class Node(SimComponent):
pass
elif frame.ip.protocol == IPProtocol.ICMP:
self.icmp.process_icmp(frame=frame, from_nic=from_nic)
def install_service(self, service: Service) -> None:
"""
Install a service on this node.
:param service: Service instance that has not been installed on any node yet.
:type service: Service
"""
if service in self:
_LOGGER.warning(f"Can't add service {service.uuid} to node {self.uuid}. It's already installed.")
return
self.services[service.uuid] = service
service.parent = self
service.install() # Perform any additional setup, such as creating files for this service on the node.
self.sys_log.info(f"Installed service {service.name}")
_LOGGER.info(f"Added service {service.uuid} to node {self.uuid}")
def uninstall_service(self, service: Service) -> None:
"""Uninstall and completely remove service from this node.
:param service: Service object that is currently associated with this node.
:type service: Service
"""
if service not in self:
_LOGGER.warning(f"Can't remove service {service.uuid} from node {self.uuid}. It's not installed.")
return
service.uninstall() # Perform additional teardown, such as removing files or restarting the machine.
self.services.pop(service.uuid)
service.parent = None
self.sys_log.info(f"Uninstalled service {service.name}")
_LOGGER.info(f"Removed service {service.uuid} from node {self.uuid}")
def __contains__(self, item: Any) -> bool:
if isinstance(item, Service):
return item.uuid in self.services
return None
class Switch(Node):
"""A class representing a Layer 2 network switch."""
num_ports: int = 24
"The number of ports on the switch."
switch_ports: Dict[int, SwitchPort] = {}
"The SwitchPorts on the switch."
mac_address_table: Dict[str, SwitchPort] = {}
"A MAC address table mapping destination MAC addresses to corresponding SwitchPorts."
def __init__(self, **kwargs):
super().__init__(**kwargs)
if not self.switch_ports:
self.switch_ports = {i: SwitchPort() for i in range(1, self.num_ports + 1)}
for port_num, port in self.switch_ports.items():
port.connected_node = self
port.parent = self
port.port_num = port_num
def show(self):
"""Prints a table of the SwitchPorts on the Switch."""
table = PrettyTable(["Port", "MAC Address", "Speed", "Status"])
for port_num, port in self.switch_ports.items():
table.add_row([port_num, port.mac_address, port.speed, "Enabled" if port.enabled else "Disabled"])
print(table)
def describe_state(self) -> Dict:
"""
Produce a dictionary describing the current state of this object.
Please see :py:meth:`primaite.simulator.core.SimComponent.describe_state` for a more detailed explanation.
:return: Current state of this object and child objects.
:rtype: Dict
"""
return {
"uuid": self.uuid,
"num_ports": self.num_ports, # redundant?
"ports": {port_num: port.describe_state() for port_num, port in self.switch_ports.items()},
"mac_address_table": {mac: port for mac, port in self.mac_address_table.items()},
}
def _add_mac_table_entry(self, mac_address: str, switch_port: SwitchPort):
mac_table_port = self.mac_address_table.get(mac_address)
if not mac_table_port:
self.mac_address_table[mac_address] = switch_port
self.sys_log.info(f"Added MAC table entry: Port {switch_port.port_num} -> {mac_address}")
else:
if mac_table_port != switch_port:
self.mac_address_table.pop(mac_address)
self.sys_log.info(f"Removed MAC table entry: Port {mac_table_port.port_num} -> {mac_address}")
self._add_mac_table_entry(mac_address, switch_port)
def forward_frame(self, frame: Frame, incoming_port: SwitchPort):
"""
Forward a frame to the appropriate port based on the destination MAC address.
:param frame: The Frame to be forwarded.
:param incoming_port: The port number from which the frame was received.
"""
src_mac = frame.ethernet.src_mac_addr
dst_mac = frame.ethernet.dst_mac_addr
self._add_mac_table_entry(src_mac, incoming_port)
outgoing_port = self.mac_address_table.get(dst_mac)
if outgoing_port or dst_mac != "ff:ff:ff:ff:ff:ff":
outgoing_port.send_frame(frame)
else:
# If the destination MAC is not in the table, flood to all ports except incoming
for port in self.switch_ports.values():
if port != incoming_port:
port.send_frame(frame)
def disconnect_link_from_port(self, link: Link, port_number: int):
"""
Disconnect a given link from the specified port number on the switch.
:param link: The Link object to be disconnected.
:param port_number: The port number on the switch from where the link should be disconnected.
:raise NetworkError: When an invalid port number is provided or the link does not match the connection.
"""
port = self.switch_ports.get(port_number)
if port is None:
msg = f"Invalid port number {port_number} on the switch"
_LOGGER.error(msg)
raise NetworkError(msg)
if port.connected_link != link:
msg = f"The link does not match the connection at port number {port_number}"
_LOGGER.error(msg)
raise NetworkError(msg)
port.disconnect_link()

View File

@@ -21,21 +21,23 @@ class Simulation(SimComponent):
super().__init__(**kwargs)
self.action_manager = ActionManager()
def _init_action_manager(self) -> ActionManager:
am = super()._init_action_manager()
# pass through network actions to the network objects
self.action_manager.add_action(
am.add_action(
"network",
Action(
func=lambda request, context: self.network.apply_action(request, context), validator=AllowAllValidator()
),
)
# pass through domain actions to the domain object
self.action_manager.add_action(
am.add_action(
"domain",
Action(
func=lambda request, context: self.domain.apply_action(request, context), validator=AllowAllValidator()
),
)
return am
def describe_state(self) -> Dict:
"""

View File

@@ -1,6 +1,6 @@
from abc import abstractmethod
from enum import Enum
from typing import Any, Dict, List, Set
from typing import Any, Dict, Set
from primaite.simulator.system.software import IOSoftware
@@ -53,14 +53,6 @@ class Application(IOSoftware):
)
return state
def apply_action(self, action: List[str]) -> None:
"""
Applies a list of actions to the Application.
:param action: A list of actions to apply.
"""
pass
def reset_component_for_episode(self, episode: int):
"""
Resets the Application component for a new episode.

View File

@@ -0,0 +1,76 @@
from typing import Dict
from primaite.simulator.file_system.file_system_file_type import FileSystemFileType
from primaite.simulator.network.hardware.base import Node
from primaite.simulator.system.services.service import Service
class DatabaseService(Service):
"""Service loosely modelled on Microsoft SQL Server."""
def describe_state(self) -> Dict:
"""
Produce a dictionary describing the current state of this object.
Please see :py:meth:`primaite.simulator.core.SimComponent.describe_state` for a more detailed explanation.
:return: Current state of this object and child objects.
:rtype: Dict
"""
return super().describe_state()
def uninstall(self) -> None:
"""
Undo installation procedure.
This method deletes files created when installing the database, and the database folder if it is empty.
"""
super().uninstall()
node: Node = self.parent
node.file_system.delete_file(self.primary_store)
node.file_system.delete_file(self.transaction_log)
if self.secondary_store:
node.file_system.delete_file(self.secondary_store)
if len(self.folder.files) == 0:
node.file_system.delete_folder(self.folder)
def install(self) -> None:
"""Perform first time install on a node, creating necessary files."""
super().install()
assert isinstance(self.parent, Node), "Database install can only happen after the db service is added to a node"
self._setup_files()
def _setup_files(
self,
db_size: int = 1000,
use_secondary_db_file: bool = False,
secondary_db_size: int = 300,
folder_name: str = "database",
):
"""Set up files that are required by the database on the parent host.
:param db_size: Initial file size of the main database file, defaults to 1000
:type db_size: int, optional
:param use_secondary_db_file: Whether to use a secondary database file, defaults to False
:type use_secondary_db_file: bool, optional
:param secondary_db_size: Size of the secondary db file, defaults to None
:type secondary_db_size: int, optional
:param folder_name: Name of the folder which will be setup to hold the db files, defaults to "database"
:type folder_name: str, optional
"""
# note that this parent.file_system.create_folder call in the future will be authenticated by using permissions
# handler. This permission will be granted based on service account given to the database service.
self.parent: Node
self.folder = self.parent.file_system.create_folder(folder_name)
self.primary_store = self.parent.file_system.create_file(
"db_primary_store", db_size, FileSystemFileType.MDF, folder=self.folder
)
self.transaction_log = self.parent.file_system.create_file(
"db_transaction_log", "1", FileSystemFileType.LDF, folder=self.folder
)
if use_secondary_db_file:
self.secondary_store = self.parent.file_system.create_file(
"db_secondary_store", secondary_db_size, FileSystemFileType.NDF, folder=self.folder
)
else:
self.secondary_store = None

View File

@@ -1,9 +1,13 @@
from abc import abstractmethod
from enum import Enum
from typing import Any, Dict, List
from typing import Any, Dict, Optional
from primaite import getLogger
from primaite.simulator.core import Action, ActionManager
from primaite.simulator.system.software import IOSoftware
_LOGGER = getLogger(__name__)
class ServiceOperatingState(Enum):
"""Enumeration of Service Operating States."""
@@ -31,6 +35,21 @@ class Service(IOSoftware):
operating_state: ServiceOperatingState
"The current operating state of the Service."
restart_duration: int = 5
"How many timesteps does it take to restart this service."
_restart_countdown: Optional[int] = None
"If currently restarting, how many timesteps remain until the restart is finished."
def _init_action_manager(self) -> ActionManager:
am = super()._init_action_manager()
am.add_action("stop", Action(func=lambda request, context: self.stop()))
am.add_action("start", Action(func=lambda request, context: self.start()))
am.add_action("pause", Action(func=lambda request, context: self.pause()))
am.add_action("resume", Action(func=lambda request, context: self.resume()))
am.add_action("restart", Action(func=lambda request, context: self.restart()))
am.add_action("disable", Action(func=lambda request, context: self.disable()))
am.add_action("enable", Action(func=lambda request, context: self.enable()))
return am
@abstractmethod
def describe_state(self) -> Dict:
@@ -46,14 +65,6 @@ class Service(IOSoftware):
state.update({"operating_state": self.operating_state.name})
return state
def apply_action(self, action: List[str]) -> None:
"""
Applies a list of actions to the Service.
:param action: A list of actions to apply.
"""
pass
def reset_component_for_episode(self, episode: int):
"""
Resets the Service component for a new episode.
@@ -86,3 +97,69 @@ class Service(IOSoftware):
:return: True if successful, False otherwise.
"""
pass
def stop(self) -> None:
"""Stop the service."""
_LOGGER.debug(f"Stopping service {self.name}")
if self.operating_state in [ServiceOperatingState.RUNNING, ServiceOperatingState.PAUSED]:
self.parent.sys_log.info(f"Stopping service {self.name}")
self.operating_state = ServiceOperatingState.STOPPED
def start(self) -> None:
"""Start the service."""
_LOGGER.debug(f"Starting service {self.name}")
if self.operating_state == ServiceOperatingState.STOPPED:
self.parent.sys_log.info(f"Starting service {self.name}")
self.operating_state = ServiceOperatingState.RUNNING
def pause(self) -> None:
"""Pause the service."""
_LOGGER.debug(f"Pausing service {self.name}")
if self.operating_state == ServiceOperatingState.RUNNING:
self.parent.sys_log.info(f"Pausing service {self.name}")
self.operating_state = ServiceOperatingState.PAUSED
def resume(self) -> None:
"""Resume paused service."""
_LOGGER.debug(f"Resuming service {self.name}")
if self.operating_state == ServiceOperatingState.PAUSED:
self.parent.sys_log.info(f"Resuming service {self.name}")
self.operating_state = ServiceOperatingState.RUNNING
def restart(self) -> None:
"""Restart running service."""
_LOGGER.debug(f"Restarting service {self.name}")
if self.operating_state in [ServiceOperatingState.RUNNING, ServiceOperatingState.PAUSED]:
self.parent.sys_log.info(f"Pausing service {self.name}")
self.operating_state = ServiceOperatingState.RESTARTING
self.restart_countdown = self.restarting_duration
def disable(self) -> None:
"""Disable the service."""
_LOGGER.debug(f"Disabling service {self.name}")
self.parent.sys_log.info(f"Disabling Application {self.name}")
self.operating_state = ServiceOperatingState.DISABLED
def enable(self) -> None:
"""Enable the disabled service."""
_LOGGER.debug(f"Enabling service {self.name}")
if self.operating_state == ServiceOperatingState.DISABLED:
self.parent.sys_log.info(f"Enabling Application {self.name}")
self.operating_state = ServiceOperatingState.STOPPED
def apply_timestep(self, timestep: int) -> None:
"""
Apply a single timestep of simulation dynamics to this service.
In this instance, if any multi-timestep processes are currently occurring (such as restarting or installation),
then they are brought one step closer to being finished.
:param timestep: The current timestep number. (Amount of time since simulation episode began)
:type timestep: int
"""
super().apply_timestep(timestep)
if self.operating_state == ServiceOperatingState.RESTARTING:
if self.restart_countdown <= 0:
_LOGGER.debug(f"Restarting finished for service {self.name}")
self.operating_state = ServiceOperatingState.RUNNING
self.restart_countdown -= 1

View File

@@ -1,8 +1,8 @@
from abc import abstractmethod
from enum import Enum
from typing import Any, Dict, List, Set
from typing import Any, Dict, Set
from primaite.simulator.core import SimComponent
from primaite.simulator.core import Action, ActionManager, SimComponent
from primaite.simulator.network.transmission.transport_layer import Port
@@ -75,6 +75,17 @@ class Software(SimComponent):
revealed_to_red: bool = False
"Indicates if the software has been revealed to red agent, defaults is False."
def _init_action_manager(self) -> ActionManager:
am = super()._init_action_manager()
am.add_action(
"compromise",
Action(
func=lambda request, context: self.set_health_state(SoftwareHealthState.COMPROMISED),
),
)
am.add_action("scan", Action(func=lambda request, context: self.scan()))
return am
@abstractmethod
def describe_state(self) -> Dict:
"""
@@ -98,17 +109,6 @@ class Software(SimComponent):
)
return state
def apply_action(self, action: List[str]) -> None:
"""
Applies a list of actions to the software.
The specifics of how these actions are applied should be implemented in subclasses.
:param action: A list of actions to apply.
:type action: List[str]
"""
pass
def reset_component_for_episode(self, episode: int):
"""
Resets the software component for a new episode.
@@ -119,6 +119,43 @@ class Software(SimComponent):
"""
pass
def set_health_state(self, health_state: SoftwareHealthState) -> None:
"""
Assign a new health state to this software.
Note: this should only be possible when the software is currently running, but the software base class has no
operating state, only subclasses do. So subclasses will need to implement this check. TODO: check if this should
be changed so that the base Software class has a running attr.
:param health_state: New health state to assign to the software
:type health_state: SoftwareHealthState
"""
self.health_state_actual = health_state
@abstractmethod
def install(self) -> None:
"""
Perform first-time setup of this service on a node.
This is an abstract class that should be overwritten by specific applications or services. It must be called
after the service is already associate with a node. For example, a service may need to authenticate with a
server during installation, or create files in the node's filesystem.
"""
pass
def uninstall(self) -> None:
"""Uninstall this service from a node.
This is an abstract class that should be overwritten by applications or services. It must be called after the
`install` method has already been run on that node. It should undo any installation steps, for example by
deleting files, or contacting a server.
"""
pass
def scan(self) -> None:
"""Update the observed health status to match the actual health status."""
self.health_state_visible = self.health_state_actual
class IOSoftware(Software):
"""

View File

@@ -0,0 +1,56 @@
from primaite.simulator.network.hardware.base import Node
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.services.database import DatabaseService
from primaite.simulator.system.services.service import ServiceOperatingState
from primaite.simulator.system.software import SoftwareCriticality, SoftwareHealthState
def test_installing_database():
db = DatabaseService(
name="SQL-database",
health_state_actual=SoftwareHealthState.GOOD,
health_state_visible=SoftwareHealthState.GOOD,
criticality=SoftwareCriticality.MEDIUM,
ports=[
Port.SQL_SERVER,
],
operating_state=ServiceOperatingState.RUNNING,
)
node = Node(hostname="db-server")
node.install_service(db)
assert db in node
file_exists = False
for folder in node.file_system.folders.values():
for file in folder.files.values():
if file.name == "db_primary_store":
file_exists = True
break
if file_exists:
break
assert file_exists
def test_uninstalling_database():
db = DatabaseService(
name="SQL-database",
health_state_actual=SoftwareHealthState.GOOD,
health_state_visible=SoftwareHealthState.GOOD,
criticality=SoftwareCriticality.MEDIUM,
ports=[
Port.SQL_SERVER,
],
operating_state=ServiceOperatingState.RUNNING,
)
node = Node(hostname="db-server")
node.install_service(db)
node.uninstall_service(db)
assert db not in node
assert node.file_system.get_folder_by_name("database") is None

View File

@@ -45,7 +45,7 @@ def test_nic_deserialize():
nic_json = nic.model_dump_json()
deserialized_nic = NIC.model_validate_json(nic_json)
assert nic == deserialized_nic
assert nic_json == deserialized_nic.model_dump_json()
def test_nic_ip_address_as_network_address_fails():

View File

@@ -0,0 +1,17 @@
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.services.database import DatabaseService
from primaite.simulator.system.services.service import ServiceOperatingState
from primaite.simulator.system.software import SoftwareCriticality, SoftwareHealthState
def test_creation():
db = DatabaseService(
name="SQL-database",
health_state_actual=SoftwareHealthState.GOOD,
health_state_visible=SoftwareHealthState.GOOD,
criticality=SoftwareCriticality.MEDIUM,
ports=[
Port.SQL_SERVER,
],
operating_state=ServiceOperatingState.RUNNING,
)

View File

@@ -43,4 +43,5 @@ class TestIsolatedSimComponent:
comp = TestComponent(name="computer", size=(5, 10))
dump = comp.model_dump_json()
assert comp == TestComponent.model_validate_json(dump)
reconstructed = TestComponent.model_validate_json(dump)
assert dump == reconstructed.model_dump_json()