diff --git a/docs/source/action_system.rst b/docs/source/action_system.rst index 11b74abf..88baf232 100644 --- a/docs/source/action_system.rst +++ b/docs/source/action_system.rst @@ -27,7 +27,7 @@ Just like other aspects of SimComponent, the actions are not managed centrally f 4. ``Service`` receives ``['restart']``. Since ``restart`` is a defined action in the service's own RequestManager, the service performs a restart. -Techincal Detail +Technical Detail ================ This system was achieved by implementing two classes, :py:class:`primaite.simulator.core.Action`, and :py:class:`primaite.simulator.core.RequestManager`. @@ -35,12 +35,12 @@ This system was achieved by implementing two classes, :py:class:`primaite.simula Action ------ -The ``Action`` object stores a reference to a method that performs the action, for example a node could have an action that stores a reference to ``self.turn_on()``. Techincally, this can be any callable that accepts `request, context` as it's parameters. In practice, this is often defined using ``lambda`` functions within a component's ``self._init_request_manager()`` method. Optionally, the ``Action`` object can also hold a validator that will permit/deny the action depending on context. +The ``Action`` object stores a reference to a method that performs the action, for example a node could have an action that stores a reference to ``self.turn_on()``. Technically, this can be any callable that accepts `request, context` as it's parameters. In practice, this is often defined using ``lambda`` functions within a component's ``self._init_request_manager()`` method. Optionally, the ``Action`` object can also hold a validator that will permit/deny the action depending on context. RequestManager ------------- -The ``RequestManager`` object stores a mapping between strings and actions. It is responsible for processing the ``request`` and passing it down the ownership tree. Techincally, the ``RequestManager`` is itself a callable that accepts `request, context` tuple, and so it can be chained with other action managers. +The ``RequestManager`` object stores a mapping between strings and actions. It is responsible for processing the ``request`` and passing it down the ownership tree. Technically, the ``RequestManager`` is itself a callable that accepts `request, context` tuple, and so it can be chained with other action managers. A simple example without chaining can be seen in the :py:class:`primaite.simulator.file_system.file_system.File` class. @@ -50,9 +50,9 @@ A simple example without chaining can be seen in the :py:class:`primaite.simulat ... def _init_request_manager(self): ... - request_manager.add_action("scan", Action(func=lambda request, context: self.scan())) - request_manager.add_action("repair", Action(func=lambda request, context: self.repair())) - request_manager.add_action("restore", Action(func=lambda request, context: self.restore())) + request_manager.add_request("scan", Action(func=lambda request, context: self.scan())) + request_manager.add_request("repair", Action(func=lambda request, context: self.repair())) + request_manager.add_request("restore", Action(func=lambda request, context: self.restore())) *ellipses (``...``) used to omit code impertinent to this explanation* @@ -70,7 +70,7 @@ An example of how this works is in the :py:class:`primaite.simulator.network.har def _init_request_manager(self): ... # a regular action which is processed by the Node itself - request_manager.add_action("turn_on", Action(func=lambda request, context: self.turn_on())) + request_manager.add_request("turn_on", Action(func=lambda request, context: self.turn_on())) # if the Node receives a request where the first word is 'service', it will use a dummy manager # called self._service_request_manager to pass on the reqeust to the relevant service. This dummy @@ -78,11 +78,11 @@ An example of how this works is in the :py:class:`primaite.simulator.network.har # done because the next string after "service" is always the uuid of that service, so we need an # RequestManager to pop that string before sending it onto the relevant service's RequestManager. self._service_request_manager = RequestManager() - request_manager.add_action("service", Action(func=self._service_request_manager)) + request_manager.add_request("service", Action(func=self._service_request_manager)) ... def install_service(self, service): self.services[service.uuid] = service ... # Here, the service UUID is registered to allow passing actions between the node and the service. - self._service_request_manager.add_action(service.uuid, Action(func=service._request_manager)) + self._service_request_manager.add_request(service.uuid, Action(func=service._request_manager)) diff --git a/docs/source/getting_started.rst b/docs/source/getting_started.rst index 1dbf9dec..0801c79e 100644 --- a/docs/source/getting_started.rst +++ b/docs/source/getting_started.rst @@ -110,11 +110,9 @@ Clone & Install PrimAITE for Development To be able to extend PrimAITE further, or to build wheels manually before install, clone the repository to a location of your choice: -.. TODO:: Add repo path once we know what it is - .. code-block:: bash - git clone + git clone https://github.com/Autonomous-Resilient-Cyber-Defence/PrimAITE cd primaite Create and activate your Python virtual environment (venv) diff --git a/docs/source/simulation_components/network/base_hardware.rst b/docs/source/simulation_components/network/base_hardware.rst index 452667d2..af4ec26c 100644 --- a/docs/source/simulation_components/network/base_hardware.rst +++ b/docs/source/simulation_components/network/base_hardware.rst @@ -2,20 +2,24 @@ © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK +############# Base Hardware -============= +############# The physical layer components are models of a NIC (Network Interface Card), SwitchPort, Node, Switch, and a Link. These components allow modelling of layer 1 (physical layer) in the OSI model and the nodes that connect to and transmit across layer 1. +=== NIC -### +=== + The NIC class provides a realistic model of a Network Interface Card. The NIC acts as the interface between a Node and a Link, handling IP and MAC addressing, status, and sending/receiving frames. +---------- Addressing -********** +---------- A NIC has both an IPv4 address and MAC address assigned: @@ -24,8 +28,10 @@ A NIC has both an IPv4 address and MAC address assigned: - **gateway** - The default gateway IP address for routing traffic beyond the local network. - **mac_address** - A unique MAC address assigned to the NIC by the manufacturer. + +------ Status -****** +------ The status of the NIC is represented by: @@ -33,14 +39,17 @@ The status of the NIC is represented by: - **connected_node** - The Node instance the NIC is attached to. - **connected_link** - The Link instance the NIC is wired to. + +-------------- Packet Capture -************** +-------------- - **pcap** - A PacketCapture instance attached to the NIC for capturing all frames sent and received. This allows packet capture and analysis. +------------------------ Sending/Receiving Frames -************************ +------------------------ The NIC can send and receive Frames to/from the connected Link: @@ -50,8 +59,9 @@ The NIC can send and receive Frames to/from the connected Link: This allows a NIC to handle sending, receiving, and forwarding of network traffic at layer 2 of the OSI model. The Frames contain network data encapsulated with various protocol headers. +----------- Basic Usage -*********** +----------- .. code-block:: python @@ -64,8 +74,9 @@ Basic Usage frame = Frame(...) nic1.send_frame(frame) +========== SwitchPort -########## +========== The SwitchPort models a port on a network switch. It has similar attributes and methods to NIC for addressing, status, packet capture, sending/receiving frames, etc. @@ -75,26 +86,47 @@ Key attributes: - **port_num**: The port number on the switch. - **connected_switch**: The switch to which this port belongs. +==== Node -#### +==== The Node class represents a base node that communicates on the Network. +Nodes take more than 1 time step to power on (3 time steps by default). +To create a Node that is already powered on, the Node's operating state can be overriden. +Otherwise, the node ``start_up_duration`` (and ``shut_down_duration``) can be set to 0 if +the node will be powered off or on multiple times. This will still need ``power_on()`` to +be called to turn the node on. + +e.g. + +.. code-block:: python + + active_node = Node(hostname='server1', operating_state=NodeOperatingState.ON) + # node is already on, no need to call power_on() + + + instant_start_node = Node(hostname="client", start_up_duration=0, shut_down_duration=0) + instant_start_node.power_on() # node will still need to be powered on + +------------------ Network Interfaces -****************** +------------------ A Node will typically have one or more NICs attached to it for network connectivity: - **nics** - A dictionary containing the NIC instances attached to the Node. NICs can be added/removed. +------------- Configuration -************* +------------- - **hostname** - Configured hostname of the Node. - **operating_state** - Current operating state like ON or OFF. The NICs will be enabled/disabled based on this. +---------------- Network Services -**************** +---------------- A Node runs various network services and components for handling traffic: @@ -110,8 +142,9 @@ The SysLog records informational, warning, and error events that occur on the No debugging and tracing program execution and network activity for each simulated Node. Other Node services like ARP and ICMP, along with custom Applications, services, and Processes will log to the SysLog. +----------------- Sending/Receiving -***************** +----------------- The Node handles sending and receiving Frames via its attached NICs: @@ -119,8 +152,9 @@ The Node handles sending and receiving Frames via its attached NICs: - **receive_frame()** - Receives a Frame from the network through a NIC. The Node then processes it appropriately based on the protocols and payload. +----------- Basic Usage -*********** +----------- .. code-block:: python @@ -137,15 +171,16 @@ Basic Usage The Node class brings together the NICs, configuration, and services to model a full network node that can send, receive, process, and forward traffic on a simulated network. - +====== Switch -###### +====== The Switch subclass models a network switch. It inherits from Node and acts at layer 2 of the OSI model to forward frames based on MAC addresses. +-------------------------- Inherits Node Capabilities -************************** +-------------------------- Since Switch subclasses Node, it inherits all capabilities from Node like: @@ -154,16 +189,18 @@ Since Switch subclasses Node, it inherits all capabilities from Node like: - **Sending and receiving frames** - **Maintaining system logs** +----- Ports -***** +----- A Switch has multiple ports implemented using SwitchPort instances: - **switch_ports** - A dictionary mapping port numbers to SwitchPort instances. - **num_ports** - The number of ports the Switch has. +---------- Forwarding -********** +---------- A Switch forwards frames between ports based on the destination MAC: @@ -179,21 +216,24 @@ When a frame is received on a SwitchPort: This allows the Switch to dynamically build up a mapping table between MAC addresses and SwitchPorts based on traffic received. If no entry exists for a destination MAC, it floods the frame out all ports. +==== Link -#### +==== The Link class represents a physical link or connection between two network endpoints like NICs or SwitchPorts. +--------- Endpoints -********* +--------- A Link connects two endpoints: - **endpoint_a** - The first endpoint, a NIC or SwitchPort. - **endpoint_b** - The second endpoint, a NIC or SwitchPort. +------------ Transmission -************ +------------ Links transmit Frames between the endpoints: @@ -201,8 +241,9 @@ Links transmit Frames between the endpoints: Uses bandwidth/load properties to determine if transmission is possible. +---------------- Bandwidth & Load -**************** +---------------- - **bandwidth** - The total capacity of the Link in Mbps. - **current_load** - The current bandwidth utilization of the Link in Mbps. @@ -210,16 +251,18 @@ Bandwidth & Load As Frames are sent over the Link, the load increases. The Link tracks if there is enough unused capacity to transmit a Frame based on its size and the current load. +------ Status -****** +------ - **up** - Boolean indicating if the Link is currently up/active based on the endpoint status. - **endpoint_up()/down()** - Notifies the Link when an endpoint goes up or down. This allows the Link to realistically model the connection and transmission characteristics between two endpoints. +======================= Putting it all Together -####################### +======================= We'll now demonstrate how the nodes, NICs, switches, and links connect in a network, including full code examples and syslog extracts to illustrate the step-by-step process. @@ -230,35 +273,33 @@ PC's and two switches. .. image:: ../../../_static/four_node_two_switch_network.png +------------------- Create Nodes & NICs -******************* +------------------- First, we'll create the four nodes, each with a single NIC. .. code-block:: python - pc_a = Node(hostname="pc_a") + from primaite.simulator.network.hardware.base import Node, NodeOperatingState, NIC + + pc_a = Node(hostname="pc_a", operating_state=NodeOperatingState.ON) nic_a = NIC(ip_address="192.168.0.10", subnet_mask="255.255.255.0", gateway="192.168.0.1") pc_a.connect_nic(nic_a) - pc_a.power_on() - pc_b = Node(hostname="pc_b") + pc_b = Node(hostname="pc_b", operating_state=NodeOperatingState.ON) nic_b = NIC(ip_address="192.168.0.11", subnet_mask="255.255.255.0", gateway="192.168.0.1") pc_b.connect_nic(nic_b) - pc_b.power_on() - pc_c = Node(hostname="pc_c") + pc_c = Node(hostname="pc_c", operating_state=NodeOperatingState.ON) nic_c = NIC(ip_address="192.168.0.12", subnet_mask="255.255.255.0", gateway="192.168.0.1") pc_c.connect_nic(nic_c) - pc_c.power_on() - pc_d = Node(hostname="pc_d") + pc_d = Node(hostname="pc_d", operating_state=NodeOperatingState.ON) nic_d = NIC(ip_address="192.168.0.13", subnet_mask="255.255.255.0", gateway="192.168.0.1") pc_d.connect_nic(nic_d) - pc_d.power_on() - -This produces: +Creating the four nodes results in: **node_a NIC table** @@ -273,7 +314,6 @@ This produces: .. code-block:: 2023-08-08 15:50:08,355 INFO: Connected NIC 80:af:f2:f6:58:b7/192.168.0.10 - 2023-08-08 15:50:08,355 INFO: Turned on **node_b NIC table** @@ -288,7 +328,6 @@ This produces: .. code-block:: 2023-08-08 15:50:08,357 INFO: Connected NIC 98:ad:eb:7c:dc:cb/192.168.0.11 - 2023-08-08 15:50:08,357 INFO: Turned on **node_c NIC table** @@ -303,7 +342,6 @@ This produces: .. code-block:: 2023-08-08 15:50:08,358 INFO: Connected NIC bc:72:82:5d:82:a4/192.168.0.12 - 2023-08-08 15:50:08,358 INFO: Turned on **node_d NIC table** @@ -318,21 +356,19 @@ This produces: .. code-block:: 2023-08-08 15:50:08,359 INFO: Connected NIC 84:20:7c:ec:a5:c6/192.168.0.13 - 2023-08-08 15:50:08,360 INFO: Turned on +--------------- Create Switches -*************** +--------------- Next, we'll create two six-port switches: .. code-block:: python - switch_1 = Switch(hostname="switch_1", num_ports=6) - switch_1.power_on() + switch_1 = Switch(hostname="switch_1", num_ports=6, operating_state=NodeOperatingState.ON) - switch_2 = Switch(hostname="switch_2", num_ports=6) - switch_2.power_on() + switch_2 = Switch(hostname="switch_2", num_ports=6, operating_state=NodeOperatingState.ON) This produces: @@ -384,8 +420,9 @@ This produces: 2023-08-08 15:50:08,374 INFO: Turned on +------------ Create Links -************ +------------ Finally, we'll create the five links that connect the nodes and the switches: @@ -523,8 +560,9 @@ This produces: 2023-08-08 15:50:08,384 INFO: SwitchPort 96:77:39:d1:de:44 enabled +------------ Perform Ping -************ +------------ Now with the network setup and operational, we can perform a ping to confirm that communication between nodes over a switched network is possible. In the below example, we ping 192.168.0.13 (node_d) from node_a: diff --git a/docs/source/simulation_structure.rst b/docs/source/simulation_structure.rst index 20d2d2d3..2f0a56e8 100644 --- a/docs/source/simulation_structure.rst +++ b/docs/source/simulation_structure.rst @@ -51,7 +51,7 @@ snippet demonstrates usage of the ``ActionPermissionValidator``. def _init_request_manager(self) -> RequestManager: am = super()._init_request_manager() - am.add_action( + am.add_request( "reset_factory_settings", Action( func = lambda request, context: self.reset_factory_settings(), diff --git a/src/primaite/simulator/core.py b/src/primaite/simulator/core.py index eceddfd5..9ead877e 100644 --- a/src/primaite/simulator/core.py +++ b/src/primaite/simulator/core.py @@ -173,9 +173,9 @@ class SimComponent(BaseModel): class WebBrowser(Application): def _init_request_manager(self) -> RequestManager: - am = super()._init_request_manager() # all requests generic to any Application get initialised - am.add_request(...) # initialise any requests specific to the web browser - return am + rm = super()._init_request_manager() # all requests generic to any Application get initialised + rm.add_request(...) # initialise any requests specific to the web browser + return rm :return: Request manager object belonging to this sim component. :rtype: RequestManager diff --git a/src/primaite/simulator/domain/controller.py b/src/primaite/simulator/domain/controller.py index 66900327..e9f3b26d 100644 --- a/src/primaite/simulator/domain/controller.py +++ b/src/primaite/simulator/domain/controller.py @@ -80,17 +80,17 @@ class DomainController(SimComponent): super().__init__(**kwargs) def _init_request_manager(self) -> RequestManager: - am = super()._init_request_manager() + rm = super()._init_request_manager() # Action 'account' matches requests like: # ['account', '', *account_action] - am.add_request( + rm.add_request( "account", RequestType( func=lambda request, context: self.accounts[request.pop(0)].apply_request(request, context), validator=GroupMembershipValidator(allowed_groups=[AccountGroup.DOMAIN_ADMIN]), ), ) - return am + return rm def describe_state(self) -> Dict: """ diff --git a/src/primaite/simulator/file_system/file_system.py b/src/primaite/simulator/file_system/file_system.py index d66f568a..5d0dbedf 100644 --- a/src/primaite/simulator/file_system/file_system.py +++ b/src/primaite/simulator/file_system/file_system.py @@ -1,8 +1,11 @@ from __future__ import annotations +import hashlib +import json import math import os.path import shutil +from abc import abstractmethod from enum import Enum from pathlib import Path from typing import Dict, Optional @@ -47,10 +50,19 @@ class FileSystemItemHealthStatus(Enum): """Health status for folders and files.""" GOOD = 1 + """File/Folder is OK.""" + COMPROMISED = 2 + """File/Folder is quarantined.""" + CORRUPT = 3 + """File/Folder is corrupted.""" + RESTORING = 4 + """File/Folder is in the process of being restored.""" + REPAIRING = 5 + """File/Folder is in the process of being repaired.""" class FileSystemItemABC(SimComponent): @@ -64,6 +76,15 @@ class FileSystemItemABC(SimComponent): "The name of the FileSystemItemABC." health_status: FileSystemItemHealthStatus = FileSystemItemHealthStatus.GOOD + health_status: FileSystemItemHealthStatus = FileSystemItemHealthStatus.GOOD + "Actual status of the current FileSystemItem" + + visible_health_status: FileSystemItemHealthStatus = FileSystemItemHealthStatus.GOOD + "Visible status of the current FileSystemItem" + + previous_hash: Optional[str] = None + "Hash of the file contents or the description state" + def describe_state(self) -> Dict: """ Produce a dictionary describing the current state of this object. @@ -71,9 +92,24 @@ class FileSystemItemABC(SimComponent): :return: Current state of this object and child objects. """ state = super().describe_state() - state.update({"name": self.name, "health_status": self.health_status.value}) + state["name"] = self.name + state["status"] = self.health_status.value + state["visible_status"] = self.visible_health_status.value + state["previous_hash"] = self.previous_hash return state + def _init_request_manager(self) -> RequestManager: + rm = super()._init_request_manager() + + rm.add_request(name="scan", request_type=RequestType(func=lambda request, context: self.scan())) + rm.add_request(name="checkhash", request_type=RequestType(func=lambda request, context: self.check_hash())) + rm.add_request(name="repair", request_type=RequestType(func=lambda request, context: self.repair())) + rm.add_request(name="restore", request_type=RequestType(func=lambda request, context: self.restore())) + + rm.add_request(name="corrupt", request_type=RequestType(func=lambda request, context: self.corrupt())) + + return rm + @property def size_str(self) -> str: """ @@ -86,6 +122,39 @@ class FileSystemItemABC(SimComponent): """ return convert_size(self.size) + @abstractmethod + def check_hash(self) -> bool: + """ + Checks the has of the file to detect any changes. + + For current implementation, any change in file hash means it is compromised. + + Return False if corruption is detected, otherwise True + """ + pass + + @abstractmethod + def repair(self) -> bool: + """ + Repair the FileSystemItem. + + True if successfully repaired. False otherwise. + """ + pass + + @abstractmethod + def corrupt(self) -> bool: + """ + Corrupt the FileSystemItem. + + True if successfully corrupted. False otherwise. + """ + pass + + def restore(self) -> None: + """Restore the file/folder to the state before it got ruined.""" + pass + class FileSystem(SimComponent): """Class that contains all the simulation File System.""" @@ -103,15 +172,20 @@ class FileSystem(SimComponent): self.create_folder("root") def _init_request_manager(self) -> RequestManager: - am = super()._init_request_manager() + rm = super()._init_request_manager() + + rm.add_request( + name="delete", + request_type=RequestType(func=lambda request, context: self.delete_folder_by_id(folder_uuid=request[0])), + ) self._folder_request_manager = RequestManager() - am.add_request("folder", RequestType(func=self._folder_request_manager)) + rm.add_request("folder", RequestType(func=self._folder_request_manager)) self._file_request_manager = RequestManager() - am.add_request("file", RequestType(func=self._file_request_manager)) + rm.add_request("file", RequestType(func=self._file_request_manager)) - return am + return rm @property def size(self) -> int: @@ -173,7 +247,9 @@ class FileSystem(SimComponent): self.folders[folder.uuid] = folder self._folders_by_name[folder.name] = folder self.sys_log.info(f"Created folder /{folder.name}") - self._folder_request_manager.add_request(folder.uuid, RequestType(func=folder._request_manager)) + self._folder_request_manager.add_request( + name=folder.uuid, request_type=RequestType(func=folder._request_manager) + ) return folder def delete_folder(self, folder_name: str): @@ -187,15 +263,29 @@ class FileSystem(SimComponent): return folder = self._folders_by_name.get(folder_name) if folder: - for file in folder.files.values(): - self.delete_file(file) + # remove from folder list self.folders.pop(folder.uuid) self._folders_by_name.pop(folder.name) self.sys_log.info(f"Deleted folder /{folder.name} and its contents") self._folder_request_manager.remove_request(folder.uuid) + folder.remove_all_files() + else: _LOGGER.debug(f"Cannot delete folder as it does not exist: {folder_name}") + def delete_folder_by_id(self, folder_uuid: str): + """ + Deletes a folder via its uuid. + + :param: folder_uuid: UUID of the folder to delete + """ + folder = self.get_folder_by_id(folder_uuid=folder_uuid) + self.delete_folder(folder_name=folder.name) + + def restore_folder(self, folder_id: str): + """TODO.""" + pass + def create_file( self, file_name: str, @@ -234,7 +324,7 @@ class FileSystem(SimComponent): ) folder.add_file(file) self.sys_log.info(f"Created file /{file.path}") - self._file_request_manager.add_request(file.uuid, RequestType(func=file._request_manager)) + self._file_request_manager.add_request(name=file.uuid, request_type=RequestType(func=file._request_manager)) return file def get_file(self, folder_name: str, file_name: str) -> Optional[File]: @@ -309,6 +399,20 @@ class FileSystem(SimComponent): new_file.sim_path.parent.mkdir(exist_ok=True) shutil.copy2(file.sim_path, new_file.sim_path) + def restore_file(self, folder_id: str, file_id: str): + """ + Restore a file. + + Checks the current file's status and applies the correct fix for the file. + + :param: folder_id: id of the folder where the file is stored + :type: folder_id: str + + :param: folder_id: id of the file to restore + :type: folder_id: str + """ + pass + def get_folder(self, folder_name: str) -> Optional[Folder]: """ Get a folder by its name if it exists. @@ -322,7 +426,7 @@ class FileSystem(SimComponent): """ Get a folder by its uuid if it exists. - :param folder_uuid: The folder uuid. + :param: folder_uuid: The folder uuid. :return: The matching Folder. """ return self.folders.get(folder_uuid) @@ -337,20 +441,17 @@ class Folder(FileSystemItemABC): "Files stored in the folder." _files_by_name: Dict[str, File] = {} "Files by their name as .." - is_quarantined: bool = False - "Flag that marks the folder as quarantined if true." + + scan_duration: int = -1 + "How many timesteps to complete a scan." def _init_request_manager(self) -> RequestManager: - am = super()._init_request_manager() - - am.add_request("scan", RequestType(func=lambda request, context: ...)) # TODO implement request - am.add_request("checkhash", RequestType(func=lambda request, context: ...)) # TODO implement request - am.add_request("repair", RequestType(func=lambda request, context: ...)) # TODO implement request - am.add_request("restore", RequestType(func=lambda request, context: ...)) # TODO implement request - am.add_request("delete", RequestType(func=lambda request, context: ...)) # TODO implement request - am.add_request("corrupt", RequestType(func=lambda request, context: ...)) # TODO implement request - - return am + rm = super()._init_request_manager() + rm.add_request( + name="delete", + request_type=RequestType(func=lambda request, context: self.remove_file_by_id(file_uuid=request[0])), + ) + return rm def describe_state(self) -> Dict: """ @@ -360,7 +461,6 @@ class Folder(FileSystemItemABC): """ state = super().describe_state() state["files"] = {file.name: file.describe_state() for uuid, file in self.files.items()} - state["is_quarantined"] = self.is_quarantined return state def show(self, markdown: bool = False): @@ -388,6 +488,27 @@ class Folder(FileSystemItemABC): """ return sum(file.size for file in self.files.values() if file.size is not None) + def apply_timestep(self, timestep: int): + """ + Apply a single timestep of simulation dynamics to this service. + + In this instance, if any multi-timestep processes are currently occurring (such as scanning), + then they are brought one step closer to being finished. + + :param timestep: The current timestep number. (Amount of time since simulation episode began) + :type timestep: int + """ + super().apply_timestep(timestep=timestep) + + # scan files each timestep + if self.scan_duration > -1: + # scan one file per timestep + file = self.get_file_by_id(file_uuid=list(self.files)[self.scan_duration - 1]) + file.scan() + if file.visible_health_status == FileSystemItemHealthStatus.CORRUPT: + self.visible_health_status = FileSystemItemHealthStatus.CORRUPT + self.scan_duration -= 1 + def get_file(self, file_name: str) -> Optional[File]: """ Get a file by its name. @@ -404,7 +525,7 @@ class Folder(FileSystemItemABC): """ Get a file by its uuid. - :param file_uuid: The file uuid. + :param: file_uuid: The file uuid. :return: The matching File. """ return self.files.get(file_uuid) @@ -446,21 +567,121 @@ class Folder(FileSystemItemABC): else: _LOGGER.debug(f"File with UUID {file.uuid} was not found.") + def remove_file_by_id(self, file_uuid: str): + """ + Remove a file using id. + + :param: file_uuid: The UUID of the file to remove. + """ + file = self.get_file_by_id(file_uuid=file_uuid) + self.remove_file(file=file) + + def remove_all_files(self): + """Removes all the files in the folder.""" + self.files = {} + self._files_by_name = {} + + def restore_file(self, file: Optional[File]): + """ + Restores a file. + + The method can take a File object or a file id. + + :param file: The file to remove + """ + pass + def quarantine(self): """Quarantines the File System Folder.""" - if not self.is_quarantined: - self.is_quarantined = True - self.fs.sys_log.info(f"Quarantined folder ./{self.name}") + pass def unquarantine(self): """Unquarantine of the File System Folder.""" - if self.is_quarantined: - self.is_quarantined = False - self.fs.sys_log.info(f"Quarantined folder ./{self.name}") + pass def quarantine_status(self) -> bool: """Returns true if the folder is being quarantined.""" - return self.is_quarantined + pass + + def scan(self) -> None: + """Update Folder visible status.""" + if self.scan_duration <= -1: + # scan one file per timestep + self.scan_duration = len(self.files) + self.fs.sys_log.info(f"Scanning folder {self.name} (id: {self.uuid})") + else: + # scan already in progress + self.fs.sys_log.info(f"Scan is already in progress {self.name} (id: {self.uuid})") + + def check_hash(self) -> bool: + """ + Runs a :func:`check_hash` on all files in the folder. + + If a file under the folder is corrupted, the whole folder is considered corrupted. + + TODO: For now this will just iterate through the files and run :func:`check_hash` and ignores + any other changes to the folder + + Return False if corruption is detected, otherwise True + """ + super().check_hash() + + # iterate through the files and run a check hash + no_corrupted_files = True + + for file_id in self.files: + file = self.get_file_by_id(file_uuid=file_id) + no_corrupted_files = file.check_hash() + + # if one file in the folder is corrupted, set the folder status to corrupted + if not no_corrupted_files: + self.corrupt() + + self.fs.sys_log.info(f"Checking hash of folder {self.name} (id: {self.uuid})") + + return no_corrupted_files + + def repair(self) -> bool: + """Repair a corrupted Folder by setting the folder and containing files status to FileSystemItemStatus.GOOD.""" + super().repair() + + repaired = False + + # iterate through the files in the folder + for file_id in self.files: + file = self.get_file_by_id(file_uuid=file_id) + repaired = file.repair() + + # set file status to good if corrupt + if self.health_status == FileSystemItemHealthStatus.CORRUPT: + self.health_status = FileSystemItemHealthStatus.GOOD + repaired = True + + self.fs.sys_log.info(f"Repaired folder {self.name} (id: {self.uuid})") + return repaired + + def restore(self) -> None: + """TODO.""" + pass + + def corrupt(self) -> bool: + """Corrupt a File by setting the folder and containing files status to FileSystemItemStatus.CORRUPT.""" + super().corrupt() + + corrupted = False + + # iterate through the files in the folder + for file_id in self.files: + file = self.get_file_by_id(file_uuid=file_id) + corrupted = file.corrupt() + + # set file status to corrupt if good + if self.health_status == FileSystemItemHealthStatus.GOOD: + self.health_status = FileSystemItemHealthStatus.CORRUPT + corrupted = True + + self.fs.sys_log.info(f"Corrupted folder {self.name} (id: {self.uuid})") + return corrupted class File(FileSystemItemABC): @@ -517,18 +738,6 @@ class File(FileSystemItemABC): with open(self.sim_path, mode="a"): pass - def _init_request_manager(self) -> RequestManager: - am = super()._init_request_manager() - - am.add_request("scan", RequestType(func=lambda request, context: ...)) # TODO implement request - am.add_request("checkhash", RequestType(func=lambda request, context: ...)) # TODO implement request - am.add_request("delete", RequestType(func=lambda request, context: ...)) # TODO implement request - am.add_request("repair", RequestType(func=lambda request, context: ...)) # TODO implement request - am.add_request("restore", RequestType(func=lambda request, context: ...)) # TODO implement request - am.add_request("corrupt", RequestType(func=lambda request, context: ...)) # TODO implement request - - return am - def make_copy(self, dst_folder: Folder) -> File: """ Create a copy of the current File object in the given destination folder. @@ -564,3 +773,73 @@ class File(FileSystemItemABC): state["size"] = self.size state["file_type"] = self.file_type.name return state + + def scan(self) -> None: + """Updates the visible statuses of the file.""" + path = self.folder.name + "/" + self.name + self.folder.fs.sys_log.info(f"Scanning file {self.sim_path if self.sim_path else path}") + self.visible_health_status = self.health_status + + def check_hash(self) -> bool: + """ + Check if the file has been changed. + + If changed, the file is considered corrupted. + + Return False if corruption is detected, otherwise True + """ + current_hash = None + + # if file is real, read the file contents + if self.real: + with open(self.sim_path, "rb") as f: + file_hash = hashlib.blake2b() + while chunk := f.read(8192): + file_hash.update(chunk) + + current_hash = file_hash.hexdigest() + else: + # otherwise get describe_state dict and hash that + current_hash = hashlib.blake2b(json.dumps(self.describe_state(), sort_keys=True).encode()).hexdigest() + + # if the previous hash is None, set the current hash to previous + if self.previous_hash is None: + self.previous_hash = current_hash + + # if the previous hash and current hash do not match, mark file as corrupted + if self.previous_hash is not current_hash: + self.corrupt() + return False + + return True + + def repair(self) -> bool: + """Repair a corrupted File by setting the status to FileSystemItemStatus.GOOD.""" + super().repair() + + # set file status to good if corrupt + if self.health_status == FileSystemItemHealthStatus.CORRUPT: + self.health_status = FileSystemItemHealthStatus.GOOD + + path = self.folder.name + "/" + self.name + self.folder.fs.sys_log.info(f"Repaired file {self.sim_path if self.sim_path else path}") + return True + + def restore(self) -> None: + """Restore a corrupted File by setting the status to FileSystemItemStatus.GOOD.""" + pass + + def corrupt(self) -> bool: + """Corrupt a File by setting the status to FileSystemItemStatus.CORRUPT.""" + super().corrupt() + + corrupted = False + + # set file status to good if corrupt + if self.health_status == FileSystemItemHealthStatus.GOOD: + self.health_status = FileSystemItemHealthStatus.CORRUPT + corrupted = True + + path = self.folder.name + "/" + self.name + self.folder.fs.sys_log.info(f"Corrupted file {self.sim_path if self.sim_path else path}") + return corrupted diff --git a/src/primaite/simulator/network/container.py b/src/primaite/simulator/network/container.py index f4590c66..9fbafc29 100644 --- a/src/primaite/simulator/network/container.py +++ b/src/primaite/simulator/network/container.py @@ -44,13 +44,13 @@ class Network(SimComponent): self._nx_graph = MultiGraph() def _init_request_manager(self) -> RequestManager: - am = super()._init_request_manager() + rm = super()._init_request_manager() self._node_request_manager = RequestManager() - am.add_request( + rm.add_request( "node", RequestType(func=self._node_request_manager), ) - return am + return rm @property def routers(self) -> List[Router]: diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index 607e348b..bdb9b83c 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -145,12 +145,12 @@ class NIC(SimComponent): return state def _init_request_manager(self) -> RequestManager: - am = super()._init_request_manager() + rm = super()._init_request_manager() - am.add_request("enable", RequestType(func=lambda request, context: self.enable())) - am.add_request("disable", RequestType(func=lambda request, context: self.disable())) + rm.add_request("enable", RequestType(func=lambda request, context: self.enable())) + rm.add_request("disable", RequestType(func=lambda request, context: self.disable())) - return am + return rm @property def ip_network(self) -> IPv4Network: @@ -914,6 +914,18 @@ class Node(SimComponent): revealed_to_red: bool = False "Informs whether the node has been revealed to a red agent." + start_up_duration: int = 3 + "Time steps needed for the node to start up." + + start_up_countdown: int = 0 + "Time steps needed until node is booted up." + + shut_down_duration: int = 3 + "Time steps needed for the node to shut down." + + shut_down_countdown: int = 0 + "Time steps needed until node is shut down." + def __init__(self, **kwargs): """ Initialize the Node with various components and managers. @@ -952,30 +964,30 @@ class Node(SimComponent): def _init_request_manager(self) -> RequestManager: # TODO: I see that this code is really confusing and hard to read right now... I think some of these things will # need a better name and better documentation. - am = super()._init_request_manager() + rm = super()._init_request_manager() # since there are potentially many services, create an request manager that can map service name self._service_request_manager = RequestManager() - am.add_request("service", RequestType(func=self._service_request_manager)) + rm.add_request("service", RequestType(func=self._service_request_manager)) self._nic_request_manager = RequestManager() - am.add_request("nic", RequestType(func=self._nic_request_manager)) + rm.add_request("nic", RequestType(func=self._nic_request_manager)) - am.add_request("file_system", RequestType(func=self.file_system._request_manager)) + rm.add_request("file_system", RequestType(func=self.file_system._request_manager)) # currently we don't have any applications nor processes, so these will be empty self._process_request_manager = RequestManager() - am.add_request("process", RequestType(func=self._process_request_manager)) + rm.add_request("process", RequestType(func=self._process_request_manager)) self._application_request_manager = RequestManager() - am.add_request("application", RequestType(func=self._application_request_manager)) + rm.add_request("application", RequestType(func=self._application_request_manager)) - am.add_request("scan", RequestType(func=lambda request, context: ...)) # TODO implement OS scan + rm.add_request("scan", RequestType(func=lambda request, context: ...)) # TODO implement OS scan - am.add_request("shutdown", RequestType(func=lambda request, context: self.power_off())) - am.add_request("startup", RequestType(func=lambda request, context: self.power_on())) - am.add_request("reset", RequestType(func=lambda request, context: ...)) # TODO implement node reset - am.add_request("logon", RequestType(func=lambda request, context: ...)) # TODO implement logon request - am.add_request("logoff", RequestType(func=lambda request, context: ...)) # TODO implement logoff request + rm.add_request("shutdown", RequestType(func=lambda request, context: self.power_off())) + rm.add_request("startup", RequestType(func=lambda request, context: self.power_on())) + rm.add_request("reset", RequestType(func=lambda request, context: ...)) # TODO implement node reset + rm.add_request("logon", RequestType(func=lambda request, context: ...)) # TODO implement logon request + rm.add_request("logoff", RequestType(func=lambda request, context: ...)) # TODO implement logoff request - return am + return rm def _install_system_software(self): """Install System Software - software that is usually provided with the OS.""" @@ -1001,6 +1013,7 @@ class Node(SimComponent): "applications": {uuid: app.describe_state() for uuid, app in self.applications.items()}, "services": {uuid: svc.describe_state() for uuid, svc in self.services.items()}, "process": {uuid: proc.describe_state() for uuid, proc in self.processes.items()}, + "revealed_to_red": self.revealed_to_red, } ) return state @@ -1042,9 +1055,45 @@ class Node(SimComponent): ) print(table) + def apply_timestep(self, timestep: int): + """ + Apply a single timestep of simulation dynamics to this service. + + In this instance, if any multi-timestep processes are currently occurring + (such as starting up or shutting down), then they are brought one step closer to + being finished. + + :param timestep: The current timestep number. (Amount of time since simulation episode began) + :type timestep: int + """ + super().apply_timestep(timestep=timestep) + + # count down to boot up + if self.start_up_countdown > 0: + self.start_up_countdown -= 1 + else: + if self.operating_state == NodeOperatingState.BOOTING: + self.operating_state = NodeOperatingState.ON + self.sys_log.info("Turned on") + for nic in self.nics.values(): + if nic._connected_link: + nic.enable() + + # count down to shut down + if self.shut_down_countdown > 0: + self.shut_down_countdown -= 1 + else: + if self.operating_state == NodeOperatingState.SHUTTING_DOWN: + self.operating_state = NodeOperatingState.OFF + self.sys_log.info("Turned off") + def power_on(self): """Power on the Node, enabling its NICs if it is in the OFF state.""" if self.operating_state == NodeOperatingState.OFF: + self.operating_state = NodeOperatingState.BOOTING + self.start_up_countdown = self.start_up_duration + + if self.start_up_duration <= 0: self.operating_state = NodeOperatingState.ON self.sys_log.info("Turned on") for nic in self.nics.values(): @@ -1056,6 +1105,10 @@ class Node(SimComponent): if self.operating_state == NodeOperatingState.ON: for nic in self.nics.values(): nic.disable() + self.operating_state = NodeOperatingState.SHUTTING_DOWN + self.shut_down_countdown = self.shut_down_duration + + if self.shut_down_duration <= 0: self.operating_state = NodeOperatingState.OFF self.sys_log.info("Turned off") @@ -1135,7 +1188,7 @@ class Node(SimComponent): f"Ping statistics for {target_ip_address}: " f"Packets: Sent = {pings}, " f"Received = {request_replies}, " - f"Lost = {pings-request_replies} ({(pings-request_replies)/pings*100}% loss)" + f"Lost = {pings - request_replies} ({(pings - request_replies) / pings * 100}% loss)" ) return passed return False diff --git a/src/primaite/simulator/network/hardware/nodes/router.py b/src/primaite/simulator/network/hardware/nodes/router.py index 82f08ae6..c2a38aba 100644 --- a/src/primaite/simulator/network/hardware/nodes/router.py +++ b/src/primaite/simulator/network/hardware/nodes/router.py @@ -95,7 +95,7 @@ class AccessControlList(SimComponent): self._acl = [None] * (self.max_acl_rules - 1) def _init_request_manager(self) -> RequestManager: - am = super()._init_request_manager() + rm = super()._init_request_manager() # When the request reaches this action, it should now contain solely positional args for the 'add_rule' action. # POSITIONAL ARGUMENTS: @@ -106,7 +106,7 @@ class AccessControlList(SimComponent): # 4: destination ip address (str castable to IPV4Address (e.g. '10.10.1.2')) # 5: destination port (str name of a Port (e.g. "HTTP")) # 6: position (int) - am.add_request( + rm.add_request( "add_rule", RequestType( func=lambda request, context: self.add_rule( @@ -121,8 +121,8 @@ class AccessControlList(SimComponent): ), ) - am.add_request("remove_rule", RequestType(func=lambda request, context: self.remove_rule(int(request[0])))) - return am + rm.add_request("remove_rule", RequestType(func=lambda request, context: self.remove_rule(int(request[0])))) + return rm def describe_state(self) -> Dict: """ @@ -639,9 +639,9 @@ class Router(Node): self.icmp.arp = self.arp def _init_request_manager(self) -> RequestManager: - am = super()._init_request_manager() - am.add_request("acl", RequestType(func=self.acl._request_manager)) - return am + rm = super()._init_request_manager() + rm.add_request("acl", RequestType(func=self.acl._request_manager)) + return rm def _get_port_of_nic(self, target_nic: NIC) -> Optional[int]: """ diff --git a/src/primaite/simulator/network/networks.py b/src/primaite/simulator/network/networks.py index be20f89f..25d1bd21 100644 --- a/src/primaite/simulator/network/networks.py +++ b/src/primaite/simulator/network/networks.py @@ -1,7 +1,7 @@ from ipaddress import IPv4Address from primaite.simulator.network.container import Network -from primaite.simulator.network.hardware.base import NIC +from primaite.simulator.network.hardware.base import NIC, NodeOperatingState from primaite.simulator.network.hardware.nodes.computer import Computer from primaite.simulator.network.hardware.nodes.router import ACLAction, Router from primaite.simulator.network.hardware.nodes.server import Server @@ -110,19 +110,19 @@ def arcd_uc2_network() -> Network: network = Network() # Router 1 - router_1 = Router(hostname="router_1", num_ports=5) + router_1 = Router(hostname="router_1", num_ports=5, operating_state=NodeOperatingState.ON) router_1.power_on() router_1.configure_port(port=1, ip_address="192.168.1.1", subnet_mask="255.255.255.0") router_1.configure_port(port=2, ip_address="192.168.10.1", subnet_mask="255.255.255.0") # Switch 1 - switch_1 = Switch(hostname="switch_1", num_ports=8) + switch_1 = Switch(hostname="switch_1", num_ports=8, operating_state=NodeOperatingState.ON) switch_1.power_on() network.connect(endpoint_a=router_1.ethernet_ports[1], endpoint_b=switch_1.switch_ports[8]) router_1.enable_port(1) # Switch 2 - switch_2 = Switch(hostname="switch_2", num_ports=8) + switch_2 = Switch(hostname="switch_2", num_ports=8, operating_state=NodeOperatingState.ON) switch_2.power_on() network.connect(endpoint_a=router_1.ethernet_ports[2], endpoint_b=switch_2.switch_ports[8]) router_1.enable_port(2) @@ -134,6 +134,7 @@ def arcd_uc2_network() -> Network: subnet_mask="255.255.255.0", default_gateway="192.168.10.1", dns_server=IPv4Address("192.168.1.10"), + operating_state=NodeOperatingState.ON, ) client_1.power_on() network.connect(endpoint_b=client_1.ethernet_port[1], endpoint_a=switch_2.switch_ports[1]) @@ -148,6 +149,7 @@ def arcd_uc2_network() -> Network: subnet_mask="255.255.255.0", default_gateway="192.168.10.1", dns_server=IPv4Address("192.168.1.10"), + operating_state=NodeOperatingState.ON, ) client_2.power_on() network.connect(endpoint_b=client_2.ethernet_port[1], endpoint_a=switch_2.switch_ports[2]) @@ -158,6 +160,7 @@ def arcd_uc2_network() -> Network: ip_address="192.168.1.10", subnet_mask="255.255.255.0", default_gateway="192.168.1.1", + operating_state=NodeOperatingState.ON, ) domain_controller.power_on() domain_controller.software_manager.install(DNSServer) @@ -171,6 +174,7 @@ def arcd_uc2_network() -> Network: subnet_mask="255.255.255.0", default_gateway="192.168.1.1", dns_server=IPv4Address("192.168.1.10"), + operating_state=NodeOperatingState.ON, ) database_server.power_on() network.connect(endpoint_b=database_server.ethernet_port[1], endpoint_a=switch_1.switch_ports[3]) @@ -244,6 +248,7 @@ def arcd_uc2_network() -> Network: subnet_mask="255.255.255.0", default_gateway="192.168.1.1", dns_server=IPv4Address("192.168.1.10"), + operating_state=NodeOperatingState.ON, ) web_server.power_on() web_server.software_manager.install(DatabaseClient) @@ -267,6 +272,7 @@ def arcd_uc2_network() -> Network: subnet_mask="255.255.255.0", default_gateway="192.168.1.1", dns_server=IPv4Address("192.168.1.10"), + operating_state=NodeOperatingState.ON, ) backup_server.power_on() backup_server.software_manager.install(FTPServer) @@ -279,6 +285,7 @@ def arcd_uc2_network() -> Network: subnet_mask="255.255.255.0", default_gateway="192.168.1.1", dns_server=IPv4Address("192.168.1.10"), + operating_state=NodeOperatingState.ON, ) security_suite.power_on() network.connect(endpoint_b=security_suite.ethernet_port[1], endpoint_a=switch_1.switch_ports[7]) diff --git a/src/primaite/simulator/sim_container.py b/src/primaite/simulator/sim_container.py index 230daf2c..8e820ec8 100644 --- a/src/primaite/simulator/sim_container.py +++ b/src/primaite/simulator/sim_container.py @@ -22,13 +22,13 @@ class Simulation(SimComponent): super().__init__(**kwargs) def _init_request_manager(self) -> RequestManager: - am = super()._init_request_manager() + rm = super()._init_request_manager() # pass through network requests to the network objects - am.add_request("network", RequestType(func=self.network._request_manager)) + rm.add_request("network", RequestType(func=self.network._request_manager)) # pass through domain requests to the domain object - am.add_request("domain", RequestType(func=self.domain._request_manager)) - am.add_request("do_nothing", RequestType(func=lambda request, context: ())) - return am + rm.add_request("domain", RequestType(func=self.domain._request_manager)) + rm.add_request("do_nothing", RequestType(func=lambda request, context: ())) + return rm def describe_state(self) -> Dict: """ diff --git a/src/primaite/simulator/system/services/service.py b/src/primaite/simulator/system/services/service.py index c2631455..50386d7c 100644 --- a/src/primaite/simulator/system/services/service.py +++ b/src/primaite/simulator/system/services/service.py @@ -3,7 +3,7 @@ from typing import Dict, Optional from primaite import getLogger from primaite.simulator.core import RequestManager, RequestType -from primaite.simulator.system.software import IOSoftware +from primaite.simulator.system.software import IOSoftware, SoftwareHealthState _LOGGER = getLogger(__name__) @@ -34,21 +34,29 @@ class Service(IOSoftware): operating_state: ServiceOperatingState = ServiceOperatingState.STOPPED "The current operating state of the Service." + restart_duration: int = 5 "How many timesteps does it take to restart this service." - _restart_countdown: Optional[int] = None + restart_countdown: Optional[int] = None "If currently restarting, how many timesteps remain until the restart is finished." + def __init__(self, **kwargs): + super().__init__(**kwargs) + + self.health_state_visible = SoftwareHealthState.UNUSED + self.health_state_actual = SoftwareHealthState.UNUSED + def _init_request_manager(self) -> RequestManager: - am = super()._init_request_manager() - am.add_request("stop", RequestType(func=lambda request, context: self.stop())) - am.add_request("start", RequestType(func=lambda request, context: self.start())) - am.add_request("pause", RequestType(func=lambda request, context: self.pause())) - am.add_request("resume", RequestType(func=lambda request, context: self.resume())) - am.add_request("restart", RequestType(func=lambda request, context: self.restart())) - am.add_request("disable", RequestType(func=lambda request, context: self.disable())) - am.add_request("enable", RequestType(func=lambda request, context: self.enable())) - return am + rm = super()._init_request_manager() + rm.add_request("scan", RequestType(func=lambda request, context: self.scan())) + rm.add_request("stop", RequestType(func=lambda request, context: self.stop())) + rm.add_request("start", RequestType(func=lambda request, context: self.start())) + rm.add_request("pause", RequestType(func=lambda request, context: self.pause())) + rm.add_request("resume", RequestType(func=lambda request, context: self.resume())) + rm.add_request("restart", RequestType(func=lambda request, context: self.restart())) + rm.add_request("disable", RequestType(func=lambda request, context: self.disable())) + rm.add_request("enable", RequestType(func=lambda request, context: self.enable())) + return rm def describe_state(self) -> Dict: """ @@ -60,7 +68,9 @@ class Service(IOSoftware): :rtype: Dict """ state = super().describe_state() - state.update({"operating_state": self.operating_state.value}) + state["operating_state"] = self.operating_state.value + state["health_state_actual"] = self.health_state_actual + state["health_state_visible"] = self.health_state_visible return state def reset_component_for_episode(self, episode: int): @@ -72,47 +82,59 @@ class Service(IOSoftware): """ pass + def scan(self) -> None: + """Update the service visible states.""" + # update the visible operating state + self.health_state_visible = self.health_state_actual + def stop(self) -> None: """Stop the service.""" if self.operating_state in [ServiceOperatingState.RUNNING, ServiceOperatingState.PAUSED]: self.sys_log.info(f"Stopping service {self.name}") self.operating_state = ServiceOperatingState.STOPPED + self.health_state_actual = SoftwareHealthState.UNUSED def start(self, **kwargs) -> None: """Start the service.""" if self.operating_state == ServiceOperatingState.STOPPED: self.sys_log.info(f"Starting service {self.name}") self.operating_state = ServiceOperatingState.RUNNING + self.health_state_actual = SoftwareHealthState.GOOD def pause(self) -> None: """Pause the service.""" if self.operating_state == ServiceOperatingState.RUNNING: self.sys_log.info(f"Pausing service {self.name}") self.operating_state = ServiceOperatingState.PAUSED + self.health_state_actual = SoftwareHealthState.OVERWHELMED def resume(self) -> None: """Resume paused service.""" if self.operating_state == ServiceOperatingState.PAUSED: self.sys_log.info(f"Resuming service {self.name}") self.operating_state = ServiceOperatingState.RUNNING + self.health_state_actual = SoftwareHealthState.GOOD def restart(self) -> None: """Restart running service.""" if self.operating_state in [ServiceOperatingState.RUNNING, ServiceOperatingState.PAUSED]: self.sys_log.info(f"Pausing service {self.name}") self.operating_state = ServiceOperatingState.RESTARTING - self.restart_countdown = self.restarting_duration + self.health_state_actual = SoftwareHealthState.OVERWHELMED + self.restart_countdown = self.restart_duration def disable(self) -> None: """Disable the service.""" self.sys_log.info(f"Disabling Application {self.name}") self.operating_state = ServiceOperatingState.DISABLED + self.health_state_actual = SoftwareHealthState.OVERWHELMED def enable(self) -> None: """Enable the disabled service.""" if self.operating_state == ServiceOperatingState.DISABLED: self.sys_log.info(f"Enabling Application {self.name}") self.operating_state = ServiceOperatingState.STOPPED + self.health_state_actual = SoftwareHealthState.OVERWHELMED def apply_timestep(self, timestep: int) -> None: """ @@ -129,4 +151,5 @@ class Service(IOSoftware): if self.restart_countdown <= 0: _LOGGER.debug(f"Restarting finished for service {self.name}") self.operating_state = ServiceOperatingState.RUNNING + self.health_state_actual = SoftwareHealthState.GOOD self.restart_countdown -= 1 diff --git a/src/primaite/simulator/system/software.py b/src/primaite/simulator/system/software.py index 8cd13d1a..b9e58c89 100644 --- a/src/primaite/simulator/system/software.py +++ b/src/primaite/simulator/system/software.py @@ -31,6 +31,8 @@ class SoftwareType(Enum): class SoftwareHealthState(Enum): """Enumeration of the Software Health States.""" + UNUSED = 0 + "Unused state." GOOD = 1 "The software is in a good and healthy condition." COMPROMISED = 2 @@ -88,15 +90,15 @@ class Software(SimComponent): "The folder on the file system the Software uses." def _init_request_manager(self) -> RequestManager: - am = super()._init_request_manager() - am.add_request( + rm = super()._init_request_manager() + rm.add_request( "compromise", RequestType( func=lambda request, context: self.set_health_state(SoftwareHealthState.COMPROMISED), ), ) - am.add_request("scan", RequestType(func=lambda request, context: self.scan())) - return am + rm.add_request("scan", RequestType(func=lambda request, context: self.scan())) + return rm def _get_session_details(self, session_id: str) -> Session: """ @@ -241,6 +243,7 @@ class IOSoftware(Software): payload=payload, dest_ip_address=dest_ip_address, dest_port=dest_port, session_id=session_id ) + @abstractmethod def receive(self, payload: Any, session_id: str, **kwargs) -> bool: """ Receives a payload from the SessionManager. diff --git a/tests/conftest.py b/tests/conftest.py index 35548f2a..d8c9cc50 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,7 +4,7 @@ import shutil import tempfile from datetime import datetime from pathlib import Path -from typing import Union +from typing import Any, Union from unittest.mock import patch import pytest @@ -14,6 +14,9 @@ from primaite.environment.primaite_env import Primaite from primaite.primaite_session import PrimaiteSession from primaite.simulator.network.container import Network from primaite.simulator.network.networks import arcd_uc2_network +from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.system.core.sys_log import SysLog +from primaite.simulator.system.services.service import Service from tests.mock_and_patch.get_session_path_mock import get_temp_session_path ACTION_SPACE_NODE_VALUES = 1 @@ -26,11 +29,25 @@ from primaite.simulator.file_system.file_system import FileSystem from primaite.simulator.network.hardware.base import Node +class TestService(Service): + """Test Service class""" + + def receive(self, payload: Any, session_id: str, **kwargs) -> bool: + pass + + @pytest.fixture(scope="function") def uc2_network() -> Network: return arcd_uc2_network() +@pytest.fixture(scope="function") +def service(file_system) -> TestService: + return TestService( + name="TestService", port=Port.ARP, file_system=file_system, sys_log=SysLog(hostname="test_service") + ) + + @pytest.fixture(scope="function") def file_system() -> FileSystem: return Node(hostname="fs_node").file_system diff --git a/tests/integration_tests/network/test_frame_transmission.py b/tests/integration_tests/network/test_frame_transmission.py index 85717b25..7da9fe76 100644 --- a/tests/integration_tests/network/test_frame_transmission.py +++ b/tests/integration_tests/network/test_frame_transmission.py @@ -1,17 +1,15 @@ -from primaite.simulator.network.hardware.base import Link, NIC, Node +from primaite.simulator.network.hardware.base import Link, NIC, Node, NodeOperatingState def test_node_to_node_ping(): """Tests two Nodes are able to ping each other.""" - node_a = Node(hostname="node_a") - nic_a = NIC(ip_address="192.168.0.10", subnet_mask="255.255.255.0") + node_a = Node(hostname="node_a", operating_state=NodeOperatingState.ON) + nic_a = NIC(ip_address="192.168.0.10", subnet_mask="255.255.255.0", operating_state=NodeOperatingState.ON) node_a.connect_nic(nic_a) - node_a.power_on() - node_b = Node(hostname="node_b") + node_b = Node(hostname="node_b", operating_state=NodeOperatingState.ON) nic_b = NIC(ip_address="192.168.0.11", subnet_mask="255.255.255.0") node_b.connect_nic(nic_b) - node_b.power_on() Link(endpoint_a=nic_a, endpoint_b=nic_b) @@ -20,22 +18,19 @@ def test_node_to_node_ping(): def test_multi_nic(): """Tests that Nodes with multiple NICs can ping each other and the data go across the correct links.""" - node_a = Node(hostname="node_a") + node_a = Node(hostname="node_a", operating_state=NodeOperatingState.ON) nic_a = NIC(ip_address="192.168.0.10", subnet_mask="255.255.255.0") node_a.connect_nic(nic_a) - node_a.power_on() - node_b = Node(hostname="node_b") + node_b = Node(hostname="node_b", operating_state=NodeOperatingState.ON) nic_b1 = NIC(ip_address="192.168.0.11", subnet_mask="255.255.255.0") nic_b2 = NIC(ip_address="10.0.0.12", subnet_mask="255.0.0.0") node_b.connect_nic(nic_b1) node_b.connect_nic(nic_b2) - node_b.power_on() - node_c = Node(hostname="node_c") + node_c = Node(hostname="node_c", operating_state=NodeOperatingState.ON) nic_c = NIC(ip_address="10.0.0.13", subnet_mask="255.0.0.0") node_c.connect_nic(nic_c) - node_c.power_on() Link(endpoint_a=nic_a, endpoint_b=nic_b1) diff --git a/tests/integration_tests/network/test_link_connection.py b/tests/integration_tests/network/test_link_connection.py index ef65f078..0ddf54df 100644 --- a/tests/integration_tests/network/test_link_connection.py +++ b/tests/integration_tests/network/test_link_connection.py @@ -1,17 +1,15 @@ -from primaite.simulator.network.hardware.base import Link, NIC, Node +from primaite.simulator.network.hardware.base import Link, NIC, Node, NodeOperatingState def test_link_up(): """Tests Nodes, NICs, and Links can all be connected and be in an enabled/up state.""" - node_a = Node(hostname="node_a") + node_a = Node(hostname="node_a", operating_state=NodeOperatingState.ON) nic_a = NIC(ip_address="192.168.0.10", subnet_mask="255.255.255.0") node_a.connect_nic(nic_a) - node_a.power_on() - node_b = Node(hostname="node_b") + node_b = Node(hostname="node_b", operating_state=NodeOperatingState.ON) nic_b = NIC(ip_address="192.168.0.11", subnet_mask="255.255.255.0") node_b.connect_nic(nic_b) - node_b.power_on() link = Link(endpoint_a=nic_a, endpoint_b=nic_b) diff --git a/tests/integration_tests/network/test_routing.py b/tests/integration_tests/network/test_routing.py index cb420e22..6053c457 100644 --- a/tests/integration_tests/network/test_routing.py +++ b/tests/integration_tests/network/test_routing.py @@ -2,7 +2,7 @@ from typing import Tuple import pytest -from primaite.simulator.network.hardware.base import Link, NIC, Node +from primaite.simulator.network.hardware.base import Link, NIC, Node, NodeOperatingState from primaite.simulator.network.hardware.nodes.router import ACLAction, Router from primaite.simulator.network.transmission.network_layer import IPProtocol from primaite.simulator.network.transmission.transport_layer import Port @@ -10,18 +10,15 @@ from primaite.simulator.network.transmission.transport_layer import Port @pytest.fixture(scope="function") def pc_a_pc_b_router_1() -> Tuple[Node, Node, Router]: - pc_a = Node(hostname="pc_a", default_gateway="192.168.0.1") + pc_a = Node(hostname="pc_a", default_gateway="192.168.0.1", operating_state=NodeOperatingState.ON) nic_a = NIC(ip_address="192.168.0.10", subnet_mask="255.255.255.0") pc_a.connect_nic(nic_a) - pc_a.power_on() - pc_b = Node(hostname="pc_b", default_gateway="192.168.1.1") + pc_b = Node(hostname="pc_b", default_gateway="192.168.1.1", operating_state=NodeOperatingState.ON) nic_b = NIC(ip_address="192.168.1.10", subnet_mask="255.255.255.0") pc_b.connect_nic(nic_b) - pc_b.power_on() - router_1 = Router(hostname="router_1") - router_1.power_on() + router_1 = Router(hostname="router_1", operating_state=NodeOperatingState.ON) router_1.configure_port(1, "192.168.0.1", "255.255.255.0") router_1.configure_port(2, "192.168.1.1", "255.255.255.0") diff --git a/tests/integration_tests/network/test_switched_network.py b/tests/integration_tests/network/test_switched_network.py index dc7742f4..5b305702 100644 --- a/tests/integration_tests/network/test_switched_network.py +++ b/tests/integration_tests/network/test_switched_network.py @@ -1,4 +1,4 @@ -from primaite.simulator.network.hardware.base import Link +from primaite.simulator.network.hardware.base import Link, NodeOperatingState from primaite.simulator.network.hardware.nodes.computer import Computer from primaite.simulator.network.hardware.nodes.server import Server from primaite.simulator.network.hardware.nodes.switch import Switch @@ -7,17 +7,22 @@ from primaite.simulator.network.hardware.nodes.switch import Switch def test_switched_network(): """Tests a node can ping another node via the switch.""" client_1 = Computer( - hostname="client_1", ip_address="192.168.1.10", subnet_mask="255.255.255.0", default_gateway="192.168.1.0" + hostname="client_1", + ip_address="192.168.1.10", + subnet_mask="255.255.255.0", + default_gateway="192.168.1.0", + operating_state=NodeOperatingState.ON, ) - client_1.power_on() server_1 = Server( - hostname=" server_1", ip_address="192.168.1.11", subnet_mask="255.255.255.0", default_gateway="192.168.1.11" + hostname=" server_1", + ip_address="192.168.1.11", + subnet_mask="255.255.255.0", + default_gateway="192.168.1.11", + operating_state=NodeOperatingState.ON, ) - server_1.power_on() - switch_1 = Switch(hostname="switch_1", num_ports=6) - switch_1.power_on() + switch_1 = Switch(hostname="switch_1", num_ports=6, operating_state=NodeOperatingState.ON) Link(endpoint_a=client_1.ethernet_port[1], endpoint_b=switch_1.switch_ports[1]) Link(endpoint_a=server_1.ethernet_port[1], endpoint_b=switch_1.switch_ports[2]) diff --git a/tests/unit_tests/_primaite/_simulator/_file_system/test_file_system.py b/tests/unit_tests/_primaite/_simulator/_file_system/test_file_system.py index d1d78003..2404f30d 100644 --- a/tests/unit_tests/_primaite/_simulator/_file_system/test_file_system.py +++ b/tests/unit_tests/_primaite/_simulator/_file_system/test_file_system.py @@ -1,6 +1,6 @@ import pytest -from primaite.simulator.file_system.file_system import FileSystem +from primaite.simulator.file_system.file_system import File, FileSystem, FileSystemItemHealthStatus, Folder from primaite.simulator.file_system.file_type import FileType @@ -122,6 +122,7 @@ def test_copy_file(file_system): assert file_system.get_file("dst_folder", "test_file.txt").uuid != original_uuid +@pytest.mark.skip(reason="Implementation for quarantine not needed yet") def test_folder_quarantine_state(file_system): """Tests the changing of folder quarantine status.""" folder = file_system.get_folder("root") @@ -135,6 +136,158 @@ def test_folder_quarantine_state(file_system): assert folder.quarantine_status() is False +def test_file_corrupt_repair(file_system): + """Test the ability to corrupt and repair files.""" + folder: Folder = file_system.create_folder(folder_name="test_folder") + file: File = file_system.create_file(file_name="test_file.txt", folder_name="test_folder") + + file.corrupt() + + assert folder.health_status == FileSystemItemHealthStatus.GOOD + assert file.health_status == FileSystemItemHealthStatus.CORRUPT + + file.repair() + + assert folder.health_status == FileSystemItemHealthStatus.GOOD + assert file.health_status == FileSystemItemHealthStatus.GOOD + + +def test_folder_corrupt_repair(file_system): + """Test the ability to corrupt and repair folders.""" + folder: Folder = file_system.create_folder(folder_name="test_folder") + file_system.create_file(file_name="test_file.txt", folder_name="test_folder") + + folder.corrupt() + + file = folder.get_file(file_name="test_file.txt") + assert folder.health_status == FileSystemItemHealthStatus.CORRUPT + assert file.health_status == FileSystemItemHealthStatus.CORRUPT + + folder.repair() + + file = folder.get_file(file_name="test_file.txt") + assert folder.health_status == FileSystemItemHealthStatus.GOOD + assert file.health_status == FileSystemItemHealthStatus.GOOD + + +def test_file_scan(file_system): + """Test the ability to update visible status.""" + folder: Folder = file_system.create_folder(folder_name="test_folder") + file: File = file_system.create_file(file_name="test_file.txt", folder_name="test_folder") + + assert file.health_status == FileSystemItemHealthStatus.GOOD + assert file.visible_health_status == FileSystemItemHealthStatus.GOOD + + file.corrupt() + + assert file.health_status == FileSystemItemHealthStatus.CORRUPT + assert file.visible_health_status == FileSystemItemHealthStatus.GOOD + + file.scan() + + assert file.health_status == FileSystemItemHealthStatus.CORRUPT + assert file.visible_health_status == FileSystemItemHealthStatus.CORRUPT + + +def test_folder_scan(file_system): + """Test the ability to update visible status.""" + folder: Folder = file_system.create_folder(folder_name="test_folder") + file_system.create_file(file_name="test_file.txt", folder_name="test_folder") + file_system.create_file(file_name="test_file2.txt", folder_name="test_folder") + + file1: File = folder.get_file_by_id(file_uuid=list(folder.files)[1]) + file2: File = folder.get_file_by_id(file_uuid=list(folder.files)[0]) + + assert folder.health_status == FileSystemItemHealthStatus.GOOD + assert folder.visible_health_status == FileSystemItemHealthStatus.GOOD + assert file1.visible_health_status == FileSystemItemHealthStatus.GOOD + assert file2.visible_health_status == FileSystemItemHealthStatus.GOOD + + folder.corrupt() + + assert folder.health_status == FileSystemItemHealthStatus.CORRUPT + assert folder.visible_health_status == FileSystemItemHealthStatus.GOOD + assert file1.visible_health_status == FileSystemItemHealthStatus.GOOD + assert file2.visible_health_status == FileSystemItemHealthStatus.GOOD + + folder.scan() + + folder.apply_timestep(timestep=0) + + assert folder.health_status == FileSystemItemHealthStatus.CORRUPT + assert folder.visible_health_status == FileSystemItemHealthStatus.CORRUPT + assert file1.visible_health_status == FileSystemItemHealthStatus.CORRUPT + assert file2.visible_health_status == FileSystemItemHealthStatus.GOOD + + folder.apply_timestep(timestep=1) + + assert folder.health_status == FileSystemItemHealthStatus.CORRUPT + assert folder.visible_health_status == FileSystemItemHealthStatus.CORRUPT + assert file1.visible_health_status == FileSystemItemHealthStatus.CORRUPT + assert file2.visible_health_status == FileSystemItemHealthStatus.CORRUPT + + folder.apply_timestep(timestep=2) + + assert folder.health_status == FileSystemItemHealthStatus.CORRUPT + assert folder.visible_health_status == FileSystemItemHealthStatus.CORRUPT + assert file1.visible_health_status == FileSystemItemHealthStatus.CORRUPT + assert file2.visible_health_status == FileSystemItemHealthStatus.CORRUPT + + +def test_simulated_file_check_hash(file_system): + file: File = file_system.create_file(file_name="test_file.txt", folder_name="test_folder") + + assert file.check_hash() is True + + # change simulated file size + file.sim_size = 0 + assert file.check_hash() is False + assert file.health_status == FileSystemItemHealthStatus.CORRUPT + + +def test_real_file_check_hash(file_system): + file: File = file_system.create_file(file_name="test_file.txt", folder_name="test_folder", real=True) + + assert file.check_hash() is True + + # change file content + with open(file.sim_path, "a") as f: + f.write("get hacked scrub lol xD\n") + + assert file.check_hash() is False + assert file.health_status == FileSystemItemHealthStatus.CORRUPT + + +def test_simulated_folder_check_hash(file_system): + folder: Folder = file_system.create_folder(folder_name="test_folder") + file_system.create_file(file_name="test_file.txt", folder_name="test_folder") + + assert folder.check_hash() is True + + # change simulated file size + file = folder.get_file(file_name="test_file.txt") + file.sim_size = 0 + assert folder.check_hash() is False + assert folder.health_status == FileSystemItemHealthStatus.CORRUPT + + +def test_real_folder_check_hash(file_system): + folder: Folder = file_system.create_folder(folder_name="test_folder") + file_system.create_file(file_name="test_file.txt", folder_name="test_folder", real=True) + + assert folder.check_hash() is True + + # change simulated file size + file = folder.get_file(file_name="test_file.txt") + + # change file content + with open(file.sim_path, "a") as f: + f.write("get hacked scrub lol xD\n") + + assert folder.check_hash() is False + assert folder.health_status == FileSystemItemHealthStatus.CORRUPT + + @pytest.mark.skip(reason="Skipping until we tackle serialisation") def test_serialisation(file_system): """Test to check that the object serialisation works correctly.""" diff --git a/tests/unit_tests/_primaite/_simulator/_file_system/test_file_system_actions.py b/tests/unit_tests/_primaite/_simulator/_file_system/test_file_system_actions.py new file mode 100644 index 00000000..23115fd7 --- /dev/null +++ b/tests/unit_tests/_primaite/_simulator/_file_system/test_file_system_actions.py @@ -0,0 +1,160 @@ +from typing import Tuple + +import pytest + +from primaite.simulator.file_system.file_system import File, FileSystem, FileSystemItemHealthStatus, Folder + + +@pytest.fixture(scope="function") +def populated_file_system(file_system) -> Tuple[FileSystem, Folder, File]: + """Test that an agent can request a file scan.""" + folder = file_system.create_folder(folder_name="test_folder") + file = file_system.create_file(folder_name="test_folder", file_name="test_file.txt") + + return file_system, folder, file + + +def test_file_scan_request(populated_file_system): + """Test that an agent can request a file scan.""" + fs, folder, file = populated_file_system + + file.corrupt() + assert file.health_status == FileSystemItemHealthStatus.CORRUPT + assert file.visible_health_status == FileSystemItemHealthStatus.GOOD + + fs.apply_request(request=["file", file.uuid, "scan"]) + + assert file.health_status == FileSystemItemHealthStatus.CORRUPT + assert file.visible_health_status == FileSystemItemHealthStatus.CORRUPT + + +def test_folder_scan_request(populated_file_system): + """Test that an agent can request a folder scan.""" + fs, folder, file = populated_file_system + fs.create_file(file_name="test_file2.txt", folder_name="test_folder") + + file1: File = folder.get_file_by_id(file_uuid=list(folder.files)[1]) + file2: File = folder.get_file_by_id(file_uuid=list(folder.files)[0]) + + folder.corrupt() + assert folder.health_status == FileSystemItemHealthStatus.CORRUPT + assert folder.visible_health_status == FileSystemItemHealthStatus.GOOD + assert file1.visible_health_status == FileSystemItemHealthStatus.GOOD + assert file2.visible_health_status == FileSystemItemHealthStatus.GOOD + + fs.apply_request(request=["folder", folder.uuid, "scan"]) + + folder.apply_timestep(timestep=0) + + assert folder.health_status == FileSystemItemHealthStatus.CORRUPT + assert folder.visible_health_status == FileSystemItemHealthStatus.CORRUPT + assert file1.visible_health_status == FileSystemItemHealthStatus.CORRUPT + assert file2.visible_health_status == FileSystemItemHealthStatus.GOOD + + folder.apply_timestep(timestep=1) + + assert folder.health_status == FileSystemItemHealthStatus.CORRUPT + assert folder.visible_health_status == FileSystemItemHealthStatus.CORRUPT + assert file1.visible_health_status == FileSystemItemHealthStatus.CORRUPT + assert file2.visible_health_status == FileSystemItemHealthStatus.CORRUPT + + folder.apply_timestep(timestep=2) + + assert folder.health_status == FileSystemItemHealthStatus.CORRUPT + assert folder.visible_health_status == FileSystemItemHealthStatus.CORRUPT + assert file1.visible_health_status == FileSystemItemHealthStatus.CORRUPT + assert file2.visible_health_status == FileSystemItemHealthStatus.CORRUPT + + +def test_file_checkhash_request(populated_file_system): + """Test that an agent can request a file hash check.""" + fs, folder, file = populated_file_system + + fs.apply_request(request=["file", file.uuid, "checkhash"]) + + assert file.health_status == FileSystemItemHealthStatus.GOOD + file.sim_size = 0 + + fs.apply_request(request=["file", file.uuid, "checkhash"]) + + assert file.health_status == FileSystemItemHealthStatus.CORRUPT + + +def test_folder_checkhash_request(populated_file_system): + """Test that an agent can request a folder hash check.""" + fs, folder, file = populated_file_system + + fs.apply_request(request=["folder", folder.uuid, "checkhash"]) + + assert folder.health_status == FileSystemItemHealthStatus.GOOD + file.sim_size = 0 + + fs.apply_request(request=["folder", folder.uuid, "checkhash"]) + assert folder.health_status == FileSystemItemHealthStatus.CORRUPT + + +def test_file_repair_request(populated_file_system): + """Test that an agent can request a file repair.""" + fs, folder, file = populated_file_system + + file.corrupt() + assert file.health_status == FileSystemItemHealthStatus.CORRUPT + + fs.apply_request(request=["file", file.uuid, "repair"]) + assert file.health_status == FileSystemItemHealthStatus.GOOD + + +def test_folder_repair_request(populated_file_system): + """Test that an agent can request a folder repair.""" + fs, folder, file = populated_file_system + + folder.corrupt() + assert file.health_status == FileSystemItemHealthStatus.CORRUPT + assert folder.health_status == FileSystemItemHealthStatus.CORRUPT + + fs.apply_request(request=["folder", folder.uuid, "repair"]) + assert file.health_status == FileSystemItemHealthStatus.GOOD + assert folder.health_status == FileSystemItemHealthStatus.GOOD + + +def test_file_restore_request(populated_file_system): + pass + + +def test_folder_restore_request(populated_file_system): + pass + + +def test_file_corrupt_request(populated_file_system): + """Test that an agent can request a file corruption.""" + fs, folder, file = populated_file_system + fs.apply_request(request=["file", file.uuid, "corrupt"]) + assert file.health_status == FileSystemItemHealthStatus.CORRUPT + + +def test_folder_corrupt_request(populated_file_system): + """Test that an agent can request a folder corruption.""" + fs, folder, file = populated_file_system + fs.apply_request(request=["folder", folder.uuid, "corrupt"]) + assert file.health_status == FileSystemItemHealthStatus.CORRUPT + assert folder.health_status == FileSystemItemHealthStatus.CORRUPT + + +def test_file_delete_request(populated_file_system): + """Test that an agent can request a file deletion.""" + fs, folder, file = populated_file_system + assert folder.get_file_by_id(file_uuid=file.uuid) is not None + + fs.apply_request(request=["folder", folder.uuid, "delete", file.uuid]) + assert folder.get_file_by_id(file_uuid=file.uuid) is None + + +def test_folder_delete_request(populated_file_system): + """Test that an agent can request a folder deletion.""" + fs, folder, file = populated_file_system + assert folder.get_file_by_id(file_uuid=file.uuid) is not None + assert fs.get_folder_by_id(folder_uuid=folder.uuid) is not None + + fs.apply_request(request=["delete", folder.uuid]) + assert fs.get_folder_by_id(folder_uuid=folder.uuid) is None + assert folder.get_file_by_id(file_uuid=file.uuid) is None diff --git a/tests/unit_tests/_primaite/_simulator/_network/_hardware/test_node_actions.py b/tests/unit_tests/_primaite/_simulator/_network/_hardware/test_node_actions.py new file mode 100644 index 00000000..e03e1d28 --- /dev/null +++ b/tests/unit_tests/_primaite/_simulator/_network/_hardware/test_node_actions.py @@ -0,0 +1,41 @@ +import pytest + +from primaite.simulator.network.hardware.base import Node, NodeOperatingState + + +@pytest.fixture +def node() -> Node: + return Node(hostname="test") + + +def test_node_startup(node): + assert node.operating_state == NodeOperatingState.OFF + node.apply_request(["startup"]) + assert node.operating_state == NodeOperatingState.BOOTING + + idx = 0 + while node.operating_state == NodeOperatingState.BOOTING: + node.apply_timestep(timestep=idx) + idx += 1 + + assert node.operating_state == NodeOperatingState.ON + + +def test_node_shutdown(node): + assert node.operating_state == NodeOperatingState.OFF + node.apply_request(["startup"]) + idx = 0 + while node.operating_state == NodeOperatingState.BOOTING: + node.apply_timestep(timestep=idx) + idx += 1 + + assert node.operating_state == NodeOperatingState.ON + + node.apply_request(["shutdown"]) + + idx = 0 + while node.operating_state == NodeOperatingState.SHUTTING_DOWN: + node.apply_timestep(timestep=idx) + idx += 1 + + assert node.operating_state == NodeOperatingState.OFF diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_service_actions.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_service_actions.py new file mode 100644 index 00000000..6b2ee0a7 --- /dev/null +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_service_actions.py @@ -0,0 +1,80 @@ +from primaite.simulator.system.services.service import ServiceOperatingState +from primaite.simulator.system.software import SoftwareHealthState + + +def test_service_scan(service): + """Test that an agent can request a service scan.""" + service.start() + assert service.operating_state == ServiceOperatingState.RUNNING + assert service.health_state_visible == SoftwareHealthState.UNUSED + + service.apply_request(["scan"]) + assert service.operating_state == ServiceOperatingState.RUNNING + assert service.health_state_visible == SoftwareHealthState.GOOD + + +def test_service_stop(service): + """Test that an agent can request to stop a service.""" + service.start() + assert service.operating_state == ServiceOperatingState.RUNNING + + service.apply_request(["stop"]) + assert service.operating_state == ServiceOperatingState.STOPPED + + +def test_service_start(service): + """Test that an agent can request to start a service.""" + assert service.operating_state == ServiceOperatingState.STOPPED + service.apply_request(["start"]) + assert service.operating_state == ServiceOperatingState.RUNNING + + +def test_service_pause(service): + """Test that an agent can request to pause a service.""" + service.start() + assert service.operating_state == ServiceOperatingState.RUNNING + + service.apply_request(["pause"]) + assert service.operating_state == ServiceOperatingState.PAUSED + + +def test_service_resume(service): + """Test that an agent can request to resume a service.""" + service.start() + assert service.operating_state == ServiceOperatingState.RUNNING + + service.apply_request(["pause"]) + assert service.operating_state == ServiceOperatingState.PAUSED + + service.apply_request(["resume"]) + assert service.operating_state == ServiceOperatingState.RUNNING + + +def test_service_restart(service): + """Test that an agent can request to restart a service.""" + service.start() + assert service.operating_state == ServiceOperatingState.RUNNING + + service.apply_request(["restart"]) + assert service.operating_state == ServiceOperatingState.RESTARTING + + +def test_service_disable(service): + """Test that an agent can request to disable a service.""" + service.start() + assert service.operating_state == ServiceOperatingState.RUNNING + + service.apply_request(["disable"]) + assert service.operating_state == ServiceOperatingState.DISABLED + + +def test_service_enable(service): + """Test that an agent can request to enable a service.""" + service.start() + assert service.operating_state == ServiceOperatingState.RUNNING + + service.apply_request(["disable"]) + assert service.operating_state == ServiceOperatingState.DISABLED + + service.apply_request(["enable"]) + assert service.operating_state == ServiceOperatingState.STOPPED diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_services.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_services.py new file mode 100644 index 00000000..b32463a2 --- /dev/null +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_services.py @@ -0,0 +1,68 @@ +from primaite.simulator.system.services.service import ServiceOperatingState +from primaite.simulator.system.software import SoftwareHealthState + + +def test_scan(service): + assert service.operating_state == ServiceOperatingState.STOPPED + assert service.health_state_visible == SoftwareHealthState.UNUSED + + service.start() + assert service.operating_state == ServiceOperatingState.RUNNING + assert service.health_state_visible == SoftwareHealthState.UNUSED + + service.scan() + assert service.operating_state == ServiceOperatingState.RUNNING + assert service.health_state_visible == SoftwareHealthState.GOOD + + +def test_start_service(service): + assert service.operating_state == ServiceOperatingState.STOPPED + service.start() + + assert service.operating_state == ServiceOperatingState.RUNNING + + +def test_stop_service(service): + service.start() + assert service.operating_state == ServiceOperatingState.RUNNING + + service.stop() + assert service.operating_state == ServiceOperatingState.STOPPED + + +def test_pause_and_resume_service(service): + assert service.operating_state == ServiceOperatingState.STOPPED + service.resume() + assert service.operating_state == ServiceOperatingState.STOPPED + + service.start() + service.pause() + assert service.operating_state == ServiceOperatingState.PAUSED + + service.resume() + assert service.operating_state == ServiceOperatingState.RUNNING + + +def test_restart(service): + assert service.operating_state == ServiceOperatingState.STOPPED + service.restart() + assert service.operating_state == ServiceOperatingState.STOPPED + + service.start() + service.restart() + assert service.operating_state == ServiceOperatingState.RESTARTING + + timestep = 0 + while service.operating_state == ServiceOperatingState.RESTARTING: + service.apply_timestep(timestep) + timestep += 1 + + assert service.operating_state == ServiceOperatingState.RUNNING + + +def test_enable_disable(service): + service.disable() + assert service.operating_state == ServiceOperatingState.DISABLED + + service.enable() + assert service.operating_state == ServiceOperatingState.STOPPED