Merge 'origin/dev-game-layer' into feature/1924-Agent-Interface

This commit is contained in:
Marek Wolan
2023-10-25 09:58:04 +01:00
24 changed files with 1111 additions and 196 deletions

View File

@@ -173,9 +173,9 @@ class SimComponent(BaseModel):
class WebBrowser(Application):
def _init_request_manager(self) -> RequestManager:
am = super()._init_request_manager() # all requests generic to any Application get initialised
am.add_request(...) # initialise any requests specific to the web browser
return am
rm = super()._init_request_manager() # all requests generic to any Application get initialised
rm.add_request(...) # initialise any requests specific to the web browser
return rm
:return: Request manager object belonging to this sim component.
:rtype: RequestManager

View File

@@ -80,17 +80,17 @@ class DomainController(SimComponent):
super().__init__(**kwargs)
def _init_request_manager(self) -> RequestManager:
am = super()._init_request_manager()
rm = super()._init_request_manager()
# Action 'account' matches requests like:
# ['account', '<account-uuid>', *account_action]
am.add_request(
rm.add_request(
"account",
RequestType(
func=lambda request, context: self.accounts[request.pop(0)].apply_request(request, context),
validator=GroupMembershipValidator(allowed_groups=[AccountGroup.DOMAIN_ADMIN]),
),
)
return am
return rm
def describe_state(self) -> Dict:
"""

View File

@@ -1,8 +1,11 @@
from __future__ import annotations
import hashlib
import json
import math
import os.path
import shutil
from abc import abstractmethod
from enum import Enum
from pathlib import Path
from typing import Dict, Optional
@@ -47,10 +50,19 @@ class FileSystemItemHealthStatus(Enum):
"""Health status for folders and files."""
GOOD = 1
"""File/Folder is OK."""
COMPROMISED = 2
"""File/Folder is quarantined."""
CORRUPT = 3
"""File/Folder is corrupted."""
RESTORING = 4
"""File/Folder is in the process of being restored."""
REPAIRING = 5
"""File/Folder is in the process of being repaired."""
class FileSystemItemABC(SimComponent):
@@ -64,6 +76,15 @@ class FileSystemItemABC(SimComponent):
"The name of the FileSystemItemABC."
health_status: FileSystemItemHealthStatus = FileSystemItemHealthStatus.GOOD
health_status: FileSystemItemHealthStatus = FileSystemItemHealthStatus.GOOD
"Actual status of the current FileSystemItem"
visible_health_status: FileSystemItemHealthStatus = FileSystemItemHealthStatus.GOOD
"Visible status of the current FileSystemItem"
previous_hash: Optional[str] = None
"Hash of the file contents or the description state"
def describe_state(self) -> Dict:
"""
Produce a dictionary describing the current state of this object.
@@ -71,9 +92,24 @@ class FileSystemItemABC(SimComponent):
:return: Current state of this object and child objects.
"""
state = super().describe_state()
state.update({"name": self.name, "health_status": self.health_status.value})
state["name"] = self.name
state["status"] = self.health_status.value
state["visible_status"] = self.visible_health_status.value
state["previous_hash"] = self.previous_hash
return state
def _init_request_manager(self) -> RequestManager:
rm = super()._init_request_manager()
rm.add_request(name="scan", request_type=RequestType(func=lambda request, context: self.scan()))
rm.add_request(name="checkhash", request_type=RequestType(func=lambda request, context: self.check_hash()))
rm.add_request(name="repair", request_type=RequestType(func=lambda request, context: self.repair()))
rm.add_request(name="restore", request_type=RequestType(func=lambda request, context: self.restore()))
rm.add_request(name="corrupt", request_type=RequestType(func=lambda request, context: self.corrupt()))
return rm
@property
def size_str(self) -> str:
"""
@@ -86,6 +122,39 @@ class FileSystemItemABC(SimComponent):
"""
return convert_size(self.size)
@abstractmethod
def check_hash(self) -> bool:
"""
Checks the has of the file to detect any changes.
For current implementation, any change in file hash means it is compromised.
Return False if corruption is detected, otherwise True
"""
pass
@abstractmethod
def repair(self) -> bool:
"""
Repair the FileSystemItem.
True if successfully repaired. False otherwise.
"""
pass
@abstractmethod
def corrupt(self) -> bool:
"""
Corrupt the FileSystemItem.
True if successfully corrupted. False otherwise.
"""
pass
def restore(self) -> None:
"""Restore the file/folder to the state before it got ruined."""
pass
class FileSystem(SimComponent):
"""Class that contains all the simulation File System."""
@@ -103,15 +172,20 @@ class FileSystem(SimComponent):
self.create_folder("root")
def _init_request_manager(self) -> RequestManager:
am = super()._init_request_manager()
rm = super()._init_request_manager()
rm.add_request(
name="delete",
request_type=RequestType(func=lambda request, context: self.delete_folder_by_id(folder_uuid=request[0])),
)
self._folder_request_manager = RequestManager()
am.add_request("folder", RequestType(func=self._folder_request_manager))
rm.add_request("folder", RequestType(func=self._folder_request_manager))
self._file_request_manager = RequestManager()
am.add_request("file", RequestType(func=self._file_request_manager))
rm.add_request("file", RequestType(func=self._file_request_manager))
return am
return rm
@property
def size(self) -> int:
@@ -173,7 +247,9 @@ class FileSystem(SimComponent):
self.folders[folder.uuid] = folder
self._folders_by_name[folder.name] = folder
self.sys_log.info(f"Created folder /{folder.name}")
self._folder_request_manager.add_request(folder.uuid, RequestType(func=folder._request_manager))
self._folder_request_manager.add_request(
name=folder.uuid, request_type=RequestType(func=folder._request_manager)
)
return folder
def delete_folder(self, folder_name: str):
@@ -187,15 +263,29 @@ class FileSystem(SimComponent):
return
folder = self._folders_by_name.get(folder_name)
if folder:
for file in folder.files.values():
self.delete_file(file)
# remove from folder list
self.folders.pop(folder.uuid)
self._folders_by_name.pop(folder.name)
self.sys_log.info(f"Deleted folder /{folder.name} and its contents")
self._folder_request_manager.remove_request(folder.uuid)
folder.remove_all_files()
else:
_LOGGER.debug(f"Cannot delete folder as it does not exist: {folder_name}")
def delete_folder_by_id(self, folder_uuid: str):
"""
Deletes a folder via its uuid.
:param: folder_uuid: UUID of the folder to delete
"""
folder = self.get_folder_by_id(folder_uuid=folder_uuid)
self.delete_folder(folder_name=folder.name)
def restore_folder(self, folder_id: str):
"""TODO."""
pass
def create_file(
self,
file_name: str,
@@ -234,7 +324,7 @@ class FileSystem(SimComponent):
)
folder.add_file(file)
self.sys_log.info(f"Created file /{file.path}")
self._file_request_manager.add_request(file.uuid, RequestType(func=file._request_manager))
self._file_request_manager.add_request(name=file.uuid, request_type=RequestType(func=file._request_manager))
return file
def get_file(self, folder_name: str, file_name: str) -> Optional[File]:
@@ -309,6 +399,20 @@ class FileSystem(SimComponent):
new_file.sim_path.parent.mkdir(exist_ok=True)
shutil.copy2(file.sim_path, new_file.sim_path)
def restore_file(self, folder_id: str, file_id: str):
"""
Restore a file.
Checks the current file's status and applies the correct fix for the file.
:param: folder_id: id of the folder where the file is stored
:type: folder_id: str
:param: folder_id: id of the file to restore
:type: folder_id: str
"""
pass
def get_folder(self, folder_name: str) -> Optional[Folder]:
"""
Get a folder by its name if it exists.
@@ -322,7 +426,7 @@ class FileSystem(SimComponent):
"""
Get a folder by its uuid if it exists.
:param folder_uuid: The folder uuid.
:param: folder_uuid: The folder uuid.
:return: The matching Folder.
"""
return self.folders.get(folder_uuid)
@@ -337,20 +441,17 @@ class Folder(FileSystemItemABC):
"Files stored in the folder."
_files_by_name: Dict[str, File] = {}
"Files by their name as <file name>.<file type>."
is_quarantined: bool = False
"Flag that marks the folder as quarantined if true."
scan_duration: int = -1
"How many timesteps to complete a scan."
def _init_request_manager(self) -> RequestManager:
am = super()._init_request_manager()
am.add_request("scan", RequestType(func=lambda request, context: ...)) # TODO implement request
am.add_request("checkhash", RequestType(func=lambda request, context: ...)) # TODO implement request
am.add_request("repair", RequestType(func=lambda request, context: ...)) # TODO implement request
am.add_request("restore", RequestType(func=lambda request, context: ...)) # TODO implement request
am.add_request("delete", RequestType(func=lambda request, context: ...)) # TODO implement request
am.add_request("corrupt", RequestType(func=lambda request, context: ...)) # TODO implement request
return am
rm = super()._init_request_manager()
rm.add_request(
name="delete",
request_type=RequestType(func=lambda request, context: self.remove_file_by_id(file_uuid=request[0])),
)
return rm
def describe_state(self) -> Dict:
"""
@@ -360,7 +461,6 @@ class Folder(FileSystemItemABC):
"""
state = super().describe_state()
state["files"] = {file.name: file.describe_state() for uuid, file in self.files.items()}
state["is_quarantined"] = self.is_quarantined
return state
def show(self, markdown: bool = False):
@@ -388,6 +488,27 @@ class Folder(FileSystemItemABC):
"""
return sum(file.size for file in self.files.values() if file.size is not None)
def apply_timestep(self, timestep: int):
"""
Apply a single timestep of simulation dynamics to this service.
In this instance, if any multi-timestep processes are currently occurring (such as scanning),
then they are brought one step closer to being finished.
:param timestep: The current timestep number. (Amount of time since simulation episode began)
:type timestep: int
"""
super().apply_timestep(timestep=timestep)
# scan files each timestep
if self.scan_duration > -1:
# scan one file per timestep
file = self.get_file_by_id(file_uuid=list(self.files)[self.scan_duration - 1])
file.scan()
if file.visible_health_status == FileSystemItemHealthStatus.CORRUPT:
self.visible_health_status = FileSystemItemHealthStatus.CORRUPT
self.scan_duration -= 1
def get_file(self, file_name: str) -> Optional[File]:
"""
Get a file by its name.
@@ -404,7 +525,7 @@ class Folder(FileSystemItemABC):
"""
Get a file by its uuid.
:param file_uuid: The file uuid.
:param: file_uuid: The file uuid.
:return: The matching File.
"""
return self.files.get(file_uuid)
@@ -446,21 +567,121 @@ class Folder(FileSystemItemABC):
else:
_LOGGER.debug(f"File with UUID {file.uuid} was not found.")
def remove_file_by_id(self, file_uuid: str):
"""
Remove a file using id.
:param: file_uuid: The UUID of the file to remove.
"""
file = self.get_file_by_id(file_uuid=file_uuid)
self.remove_file(file=file)
def remove_all_files(self):
"""Removes all the files in the folder."""
self.files = {}
self._files_by_name = {}
def restore_file(self, file: Optional[File]):
"""
Restores a file.
The method can take a File object or a file id.
:param file: The file to remove
"""
pass
def quarantine(self):
"""Quarantines the File System Folder."""
if not self.is_quarantined:
self.is_quarantined = True
self.fs.sys_log.info(f"Quarantined folder ./{self.name}")
pass
def unquarantine(self):
"""Unquarantine of the File System Folder."""
if self.is_quarantined:
self.is_quarantined = False
self.fs.sys_log.info(f"Quarantined folder ./{self.name}")
pass
def quarantine_status(self) -> bool:
"""Returns true if the folder is being quarantined."""
return self.is_quarantined
pass
def scan(self) -> None:
"""Update Folder visible status."""
if self.scan_duration <= -1:
# scan one file per timestep
self.scan_duration = len(self.files)
self.fs.sys_log.info(f"Scanning folder {self.name} (id: {self.uuid})")
else:
# scan already in progress
self.fs.sys_log.info(f"Scan is already in progress {self.name} (id: {self.uuid})")
def check_hash(self) -> bool:
"""
Runs a :func:`check_hash` on all files in the folder.
If a file under the folder is corrupted, the whole folder is considered corrupted.
TODO: For now this will just iterate through the files and run :func:`check_hash` and ignores
any other changes to the folder
Return False if corruption is detected, otherwise True
"""
super().check_hash()
# iterate through the files and run a check hash
no_corrupted_files = True
for file_id in self.files:
file = self.get_file_by_id(file_uuid=file_id)
no_corrupted_files = file.check_hash()
# if one file in the folder is corrupted, set the folder status to corrupted
if not no_corrupted_files:
self.corrupt()
self.fs.sys_log.info(f"Checking hash of folder {self.name} (id: {self.uuid})")
return no_corrupted_files
def repair(self) -> bool:
"""Repair a corrupted Folder by setting the folder and containing files status to FileSystemItemStatus.GOOD."""
super().repair()
repaired = False
# iterate through the files in the folder
for file_id in self.files:
file = self.get_file_by_id(file_uuid=file_id)
repaired = file.repair()
# set file status to good if corrupt
if self.health_status == FileSystemItemHealthStatus.CORRUPT:
self.health_status = FileSystemItemHealthStatus.GOOD
repaired = True
self.fs.sys_log.info(f"Repaired folder {self.name} (id: {self.uuid})")
return repaired
def restore(self) -> None:
"""TODO."""
pass
def corrupt(self) -> bool:
"""Corrupt a File by setting the folder and containing files status to FileSystemItemStatus.CORRUPT."""
super().corrupt()
corrupted = False
# iterate through the files in the folder
for file_id in self.files:
file = self.get_file_by_id(file_uuid=file_id)
corrupted = file.corrupt()
# set file status to corrupt if good
if self.health_status == FileSystemItemHealthStatus.GOOD:
self.health_status = FileSystemItemHealthStatus.CORRUPT
corrupted = True
self.fs.sys_log.info(f"Corrupted folder {self.name} (id: {self.uuid})")
return corrupted
class File(FileSystemItemABC):
@@ -517,18 +738,6 @@ class File(FileSystemItemABC):
with open(self.sim_path, mode="a"):
pass
def _init_request_manager(self) -> RequestManager:
am = super()._init_request_manager()
am.add_request("scan", RequestType(func=lambda request, context: ...)) # TODO implement request
am.add_request("checkhash", RequestType(func=lambda request, context: ...)) # TODO implement request
am.add_request("delete", RequestType(func=lambda request, context: ...)) # TODO implement request
am.add_request("repair", RequestType(func=lambda request, context: ...)) # TODO implement request
am.add_request("restore", RequestType(func=lambda request, context: ...)) # TODO implement request
am.add_request("corrupt", RequestType(func=lambda request, context: ...)) # TODO implement request
return am
def make_copy(self, dst_folder: Folder) -> File:
"""
Create a copy of the current File object in the given destination folder.
@@ -564,3 +773,73 @@ class File(FileSystemItemABC):
state["size"] = self.size
state["file_type"] = self.file_type.name
return state
def scan(self) -> None:
"""Updates the visible statuses of the file."""
path = self.folder.name + "/" + self.name
self.folder.fs.sys_log.info(f"Scanning file {self.sim_path if self.sim_path else path}")
self.visible_health_status = self.health_status
def check_hash(self) -> bool:
"""
Check if the file has been changed.
If changed, the file is considered corrupted.
Return False if corruption is detected, otherwise True
"""
current_hash = None
# if file is real, read the file contents
if self.real:
with open(self.sim_path, "rb") as f:
file_hash = hashlib.blake2b()
while chunk := f.read(8192):
file_hash.update(chunk)
current_hash = file_hash.hexdigest()
else:
# otherwise get describe_state dict and hash that
current_hash = hashlib.blake2b(json.dumps(self.describe_state(), sort_keys=True).encode()).hexdigest()
# if the previous hash is None, set the current hash to previous
if self.previous_hash is None:
self.previous_hash = current_hash
# if the previous hash and current hash do not match, mark file as corrupted
if self.previous_hash is not current_hash:
self.corrupt()
return False
return True
def repair(self) -> bool:
"""Repair a corrupted File by setting the status to FileSystemItemStatus.GOOD."""
super().repair()
# set file status to good if corrupt
if self.health_status == FileSystemItemHealthStatus.CORRUPT:
self.health_status = FileSystemItemHealthStatus.GOOD
path = self.folder.name + "/" + self.name
self.folder.fs.sys_log.info(f"Repaired file {self.sim_path if self.sim_path else path}")
return True
def restore(self) -> None:
"""Restore a corrupted File by setting the status to FileSystemItemStatus.GOOD."""
pass
def corrupt(self) -> bool:
"""Corrupt a File by setting the status to FileSystemItemStatus.CORRUPT."""
super().corrupt()
corrupted = False
# set file status to good if corrupt
if self.health_status == FileSystemItemHealthStatus.GOOD:
self.health_status = FileSystemItemHealthStatus.CORRUPT
corrupted = True
path = self.folder.name + "/" + self.name
self.folder.fs.sys_log.info(f"Corrupted file {self.sim_path if self.sim_path else path}")
return corrupted

View File

@@ -44,13 +44,13 @@ class Network(SimComponent):
self._nx_graph = MultiGraph()
def _init_request_manager(self) -> RequestManager:
am = super()._init_request_manager()
rm = super()._init_request_manager()
self._node_request_manager = RequestManager()
am.add_request(
rm.add_request(
"node",
RequestType(func=self._node_request_manager),
)
return am
return rm
@property
def routers(self) -> List[Router]:

View File

@@ -145,12 +145,12 @@ class NIC(SimComponent):
return state
def _init_request_manager(self) -> RequestManager:
am = super()._init_request_manager()
rm = super()._init_request_manager()
am.add_request("enable", RequestType(func=lambda request, context: self.enable()))
am.add_request("disable", RequestType(func=lambda request, context: self.disable()))
rm.add_request("enable", RequestType(func=lambda request, context: self.enable()))
rm.add_request("disable", RequestType(func=lambda request, context: self.disable()))
return am
return rm
@property
def ip_network(self) -> IPv4Network:
@@ -914,6 +914,18 @@ class Node(SimComponent):
revealed_to_red: bool = False
"Informs whether the node has been revealed to a red agent."
start_up_duration: int = 3
"Time steps needed for the node to start up."
start_up_countdown: int = 0
"Time steps needed until node is booted up."
shut_down_duration: int = 3
"Time steps needed for the node to shut down."
shut_down_countdown: int = 0
"Time steps needed until node is shut down."
def __init__(self, **kwargs):
"""
Initialize the Node with various components and managers.
@@ -952,30 +964,30 @@ class Node(SimComponent):
def _init_request_manager(self) -> RequestManager:
# TODO: I see that this code is really confusing and hard to read right now... I think some of these things will
# need a better name and better documentation.
am = super()._init_request_manager()
rm = super()._init_request_manager()
# since there are potentially many services, create an request manager that can map service name
self._service_request_manager = RequestManager()
am.add_request("service", RequestType(func=self._service_request_manager))
rm.add_request("service", RequestType(func=self._service_request_manager))
self._nic_request_manager = RequestManager()
am.add_request("nic", RequestType(func=self._nic_request_manager))
rm.add_request("nic", RequestType(func=self._nic_request_manager))
am.add_request("file_system", RequestType(func=self.file_system._request_manager))
rm.add_request("file_system", RequestType(func=self.file_system._request_manager))
# currently we don't have any applications nor processes, so these will be empty
self._process_request_manager = RequestManager()
am.add_request("process", RequestType(func=self._process_request_manager))
rm.add_request("process", RequestType(func=self._process_request_manager))
self._application_request_manager = RequestManager()
am.add_request("application", RequestType(func=self._application_request_manager))
rm.add_request("application", RequestType(func=self._application_request_manager))
am.add_request("scan", RequestType(func=lambda request, context: ...)) # TODO implement OS scan
rm.add_request("scan", RequestType(func=lambda request, context: ...)) # TODO implement OS scan
am.add_request("shutdown", RequestType(func=lambda request, context: self.power_off()))
am.add_request("startup", RequestType(func=lambda request, context: self.power_on()))
am.add_request("reset", RequestType(func=lambda request, context: ...)) # TODO implement node reset
am.add_request("logon", RequestType(func=lambda request, context: ...)) # TODO implement logon request
am.add_request("logoff", RequestType(func=lambda request, context: ...)) # TODO implement logoff request
rm.add_request("shutdown", RequestType(func=lambda request, context: self.power_off()))
rm.add_request("startup", RequestType(func=lambda request, context: self.power_on()))
rm.add_request("reset", RequestType(func=lambda request, context: ...)) # TODO implement node reset
rm.add_request("logon", RequestType(func=lambda request, context: ...)) # TODO implement logon request
rm.add_request("logoff", RequestType(func=lambda request, context: ...)) # TODO implement logoff request
return am
return rm
def _install_system_software(self):
"""Install System Software - software that is usually provided with the OS."""
@@ -1001,6 +1013,7 @@ class Node(SimComponent):
"applications": {uuid: app.describe_state() for uuid, app in self.applications.items()},
"services": {uuid: svc.describe_state() for uuid, svc in self.services.items()},
"process": {uuid: proc.describe_state() for uuid, proc in self.processes.items()},
"revealed_to_red": self.revealed_to_red,
}
)
return state
@@ -1042,9 +1055,45 @@ class Node(SimComponent):
)
print(table)
def apply_timestep(self, timestep: int):
"""
Apply a single timestep of simulation dynamics to this service.
In this instance, if any multi-timestep processes are currently occurring
(such as starting up or shutting down), then they are brought one step closer to
being finished.
:param timestep: The current timestep number. (Amount of time since simulation episode began)
:type timestep: int
"""
super().apply_timestep(timestep=timestep)
# count down to boot up
if self.start_up_countdown > 0:
self.start_up_countdown -= 1
else:
if self.operating_state == NodeOperatingState.BOOTING:
self.operating_state = NodeOperatingState.ON
self.sys_log.info("Turned on")
for nic in self.nics.values():
if nic._connected_link:
nic.enable()
# count down to shut down
if self.shut_down_countdown > 0:
self.shut_down_countdown -= 1
else:
if self.operating_state == NodeOperatingState.SHUTTING_DOWN:
self.operating_state = NodeOperatingState.OFF
self.sys_log.info("Turned off")
def power_on(self):
"""Power on the Node, enabling its NICs if it is in the OFF state."""
if self.operating_state == NodeOperatingState.OFF:
self.operating_state = NodeOperatingState.BOOTING
self.start_up_countdown = self.start_up_duration
if self.start_up_duration <= 0:
self.operating_state = NodeOperatingState.ON
self.sys_log.info("Turned on")
for nic in self.nics.values():
@@ -1056,6 +1105,10 @@ class Node(SimComponent):
if self.operating_state == NodeOperatingState.ON:
for nic in self.nics.values():
nic.disable()
self.operating_state = NodeOperatingState.SHUTTING_DOWN
self.shut_down_countdown = self.shut_down_duration
if self.shut_down_duration <= 0:
self.operating_state = NodeOperatingState.OFF
self.sys_log.info("Turned off")
@@ -1135,7 +1188,7 @@ class Node(SimComponent):
f"Ping statistics for {target_ip_address}: "
f"Packets: Sent = {pings}, "
f"Received = {request_replies}, "
f"Lost = {pings-request_replies} ({(pings-request_replies)/pings*100}% loss)"
f"Lost = {pings - request_replies} ({(pings - request_replies) / pings * 100}% loss)"
)
return passed
return False

View File

@@ -95,7 +95,7 @@ class AccessControlList(SimComponent):
self._acl = [None] * (self.max_acl_rules - 1)
def _init_request_manager(self) -> RequestManager:
am = super()._init_request_manager()
rm = super()._init_request_manager()
# When the request reaches this action, it should now contain solely positional args for the 'add_rule' action.
# POSITIONAL ARGUMENTS:
@@ -106,7 +106,7 @@ class AccessControlList(SimComponent):
# 4: destination ip address (str castable to IPV4Address (e.g. '10.10.1.2'))
# 5: destination port (str name of a Port (e.g. "HTTP"))
# 6: position (int)
am.add_request(
rm.add_request(
"add_rule",
RequestType(
func=lambda request, context: self.add_rule(
@@ -121,8 +121,8 @@ class AccessControlList(SimComponent):
),
)
am.add_request("remove_rule", RequestType(func=lambda request, context: self.remove_rule(int(request[0]))))
return am
rm.add_request("remove_rule", RequestType(func=lambda request, context: self.remove_rule(int(request[0]))))
return rm
def describe_state(self) -> Dict:
"""
@@ -639,9 +639,9 @@ class Router(Node):
self.icmp.arp = self.arp
def _init_request_manager(self) -> RequestManager:
am = super()._init_request_manager()
am.add_request("acl", RequestType(func=self.acl._request_manager))
return am
rm = super()._init_request_manager()
rm.add_request("acl", RequestType(func=self.acl._request_manager))
return rm
def _get_port_of_nic(self, target_nic: NIC) -> Optional[int]:
"""

View File

@@ -1,7 +1,7 @@
from ipaddress import IPv4Address
from primaite.simulator.network.container import Network
from primaite.simulator.network.hardware.base import NIC
from primaite.simulator.network.hardware.base import NIC, NodeOperatingState
from primaite.simulator.network.hardware.nodes.computer import Computer
from primaite.simulator.network.hardware.nodes.router import ACLAction, Router
from primaite.simulator.network.hardware.nodes.server import Server
@@ -110,19 +110,19 @@ def arcd_uc2_network() -> Network:
network = Network()
# Router 1
router_1 = Router(hostname="router_1", num_ports=5)
router_1 = Router(hostname="router_1", num_ports=5, operating_state=NodeOperatingState.ON)
router_1.power_on()
router_1.configure_port(port=1, ip_address="192.168.1.1", subnet_mask="255.255.255.0")
router_1.configure_port(port=2, ip_address="192.168.10.1", subnet_mask="255.255.255.0")
# Switch 1
switch_1 = Switch(hostname="switch_1", num_ports=8)
switch_1 = Switch(hostname="switch_1", num_ports=8, operating_state=NodeOperatingState.ON)
switch_1.power_on()
network.connect(endpoint_a=router_1.ethernet_ports[1], endpoint_b=switch_1.switch_ports[8])
router_1.enable_port(1)
# Switch 2
switch_2 = Switch(hostname="switch_2", num_ports=8)
switch_2 = Switch(hostname="switch_2", num_ports=8, operating_state=NodeOperatingState.ON)
switch_2.power_on()
network.connect(endpoint_a=router_1.ethernet_ports[2], endpoint_b=switch_2.switch_ports[8])
router_1.enable_port(2)
@@ -134,6 +134,7 @@ def arcd_uc2_network() -> Network:
subnet_mask="255.255.255.0",
default_gateway="192.168.10.1",
dns_server=IPv4Address("192.168.1.10"),
operating_state=NodeOperatingState.ON,
)
client_1.power_on()
network.connect(endpoint_b=client_1.ethernet_port[1], endpoint_a=switch_2.switch_ports[1])
@@ -148,6 +149,7 @@ def arcd_uc2_network() -> Network:
subnet_mask="255.255.255.0",
default_gateway="192.168.10.1",
dns_server=IPv4Address("192.168.1.10"),
operating_state=NodeOperatingState.ON,
)
client_2.power_on()
network.connect(endpoint_b=client_2.ethernet_port[1], endpoint_a=switch_2.switch_ports[2])
@@ -158,6 +160,7 @@ def arcd_uc2_network() -> Network:
ip_address="192.168.1.10",
subnet_mask="255.255.255.0",
default_gateway="192.168.1.1",
operating_state=NodeOperatingState.ON,
)
domain_controller.power_on()
domain_controller.software_manager.install(DNSServer)
@@ -171,6 +174,7 @@ def arcd_uc2_network() -> Network:
subnet_mask="255.255.255.0",
default_gateway="192.168.1.1",
dns_server=IPv4Address("192.168.1.10"),
operating_state=NodeOperatingState.ON,
)
database_server.power_on()
network.connect(endpoint_b=database_server.ethernet_port[1], endpoint_a=switch_1.switch_ports[3])
@@ -244,6 +248,7 @@ def arcd_uc2_network() -> Network:
subnet_mask="255.255.255.0",
default_gateway="192.168.1.1",
dns_server=IPv4Address("192.168.1.10"),
operating_state=NodeOperatingState.ON,
)
web_server.power_on()
web_server.software_manager.install(DatabaseClient)
@@ -267,6 +272,7 @@ def arcd_uc2_network() -> Network:
subnet_mask="255.255.255.0",
default_gateway="192.168.1.1",
dns_server=IPv4Address("192.168.1.10"),
operating_state=NodeOperatingState.ON,
)
backup_server.power_on()
backup_server.software_manager.install(FTPServer)
@@ -279,6 +285,7 @@ def arcd_uc2_network() -> Network:
subnet_mask="255.255.255.0",
default_gateway="192.168.1.1",
dns_server=IPv4Address("192.168.1.10"),
operating_state=NodeOperatingState.ON,
)
security_suite.power_on()
network.connect(endpoint_b=security_suite.ethernet_port[1], endpoint_a=switch_1.switch_ports[7])

View File

@@ -22,13 +22,13 @@ class Simulation(SimComponent):
super().__init__(**kwargs)
def _init_request_manager(self) -> RequestManager:
am = super()._init_request_manager()
rm = super()._init_request_manager()
# pass through network requests to the network objects
am.add_request("network", RequestType(func=self.network._request_manager))
rm.add_request("network", RequestType(func=self.network._request_manager))
# pass through domain requests to the domain object
am.add_request("domain", RequestType(func=self.domain._request_manager))
am.add_request("do_nothing", RequestType(func=lambda request, context: ()))
return am
rm.add_request("domain", RequestType(func=self.domain._request_manager))
rm.add_request("do_nothing", RequestType(func=lambda request, context: ()))
return rm
def describe_state(self) -> Dict:
"""

View File

@@ -3,7 +3,7 @@ from typing import Dict, Optional
from primaite import getLogger
from primaite.simulator.core import RequestManager, RequestType
from primaite.simulator.system.software import IOSoftware
from primaite.simulator.system.software import IOSoftware, SoftwareHealthState
_LOGGER = getLogger(__name__)
@@ -34,21 +34,29 @@ class Service(IOSoftware):
operating_state: ServiceOperatingState = ServiceOperatingState.STOPPED
"The current operating state of the Service."
restart_duration: int = 5
"How many timesteps does it take to restart this service."
_restart_countdown: Optional[int] = None
restart_countdown: Optional[int] = None
"If currently restarting, how many timesteps remain until the restart is finished."
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.health_state_visible = SoftwareHealthState.UNUSED
self.health_state_actual = SoftwareHealthState.UNUSED
def _init_request_manager(self) -> RequestManager:
am = super()._init_request_manager()
am.add_request("stop", RequestType(func=lambda request, context: self.stop()))
am.add_request("start", RequestType(func=lambda request, context: self.start()))
am.add_request("pause", RequestType(func=lambda request, context: self.pause()))
am.add_request("resume", RequestType(func=lambda request, context: self.resume()))
am.add_request("restart", RequestType(func=lambda request, context: self.restart()))
am.add_request("disable", RequestType(func=lambda request, context: self.disable()))
am.add_request("enable", RequestType(func=lambda request, context: self.enable()))
return am
rm = super()._init_request_manager()
rm.add_request("scan", RequestType(func=lambda request, context: self.scan()))
rm.add_request("stop", RequestType(func=lambda request, context: self.stop()))
rm.add_request("start", RequestType(func=lambda request, context: self.start()))
rm.add_request("pause", RequestType(func=lambda request, context: self.pause()))
rm.add_request("resume", RequestType(func=lambda request, context: self.resume()))
rm.add_request("restart", RequestType(func=lambda request, context: self.restart()))
rm.add_request("disable", RequestType(func=lambda request, context: self.disable()))
rm.add_request("enable", RequestType(func=lambda request, context: self.enable()))
return rm
def describe_state(self) -> Dict:
"""
@@ -60,7 +68,9 @@ class Service(IOSoftware):
:rtype: Dict
"""
state = super().describe_state()
state.update({"operating_state": self.operating_state.value})
state["operating_state"] = self.operating_state.value
state["health_state_actual"] = self.health_state_actual
state["health_state_visible"] = self.health_state_visible
return state
def reset_component_for_episode(self, episode: int):
@@ -72,47 +82,59 @@ class Service(IOSoftware):
"""
pass
def scan(self) -> None:
"""Update the service visible states."""
# update the visible operating state
self.health_state_visible = self.health_state_actual
def stop(self) -> None:
"""Stop the service."""
if self.operating_state in [ServiceOperatingState.RUNNING, ServiceOperatingState.PAUSED]:
self.sys_log.info(f"Stopping service {self.name}")
self.operating_state = ServiceOperatingState.STOPPED
self.health_state_actual = SoftwareHealthState.UNUSED
def start(self, **kwargs) -> None:
"""Start the service."""
if self.operating_state == ServiceOperatingState.STOPPED:
self.sys_log.info(f"Starting service {self.name}")
self.operating_state = ServiceOperatingState.RUNNING
self.health_state_actual = SoftwareHealthState.GOOD
def pause(self) -> None:
"""Pause the service."""
if self.operating_state == ServiceOperatingState.RUNNING:
self.sys_log.info(f"Pausing service {self.name}")
self.operating_state = ServiceOperatingState.PAUSED
self.health_state_actual = SoftwareHealthState.OVERWHELMED
def resume(self) -> None:
"""Resume paused service."""
if self.operating_state == ServiceOperatingState.PAUSED:
self.sys_log.info(f"Resuming service {self.name}")
self.operating_state = ServiceOperatingState.RUNNING
self.health_state_actual = SoftwareHealthState.GOOD
def restart(self) -> None:
"""Restart running service."""
if self.operating_state in [ServiceOperatingState.RUNNING, ServiceOperatingState.PAUSED]:
self.sys_log.info(f"Pausing service {self.name}")
self.operating_state = ServiceOperatingState.RESTARTING
self.restart_countdown = self.restarting_duration
self.health_state_actual = SoftwareHealthState.OVERWHELMED
self.restart_countdown = self.restart_duration
def disable(self) -> None:
"""Disable the service."""
self.sys_log.info(f"Disabling Application {self.name}")
self.operating_state = ServiceOperatingState.DISABLED
self.health_state_actual = SoftwareHealthState.OVERWHELMED
def enable(self) -> None:
"""Enable the disabled service."""
if self.operating_state == ServiceOperatingState.DISABLED:
self.sys_log.info(f"Enabling Application {self.name}")
self.operating_state = ServiceOperatingState.STOPPED
self.health_state_actual = SoftwareHealthState.OVERWHELMED
def apply_timestep(self, timestep: int) -> None:
"""
@@ -129,4 +151,5 @@ class Service(IOSoftware):
if self.restart_countdown <= 0:
_LOGGER.debug(f"Restarting finished for service {self.name}")
self.operating_state = ServiceOperatingState.RUNNING
self.health_state_actual = SoftwareHealthState.GOOD
self.restart_countdown -= 1

View File

@@ -31,6 +31,8 @@ class SoftwareType(Enum):
class SoftwareHealthState(Enum):
"""Enumeration of the Software Health States."""
UNUSED = 0
"Unused state."
GOOD = 1
"The software is in a good and healthy condition."
COMPROMISED = 2
@@ -88,15 +90,15 @@ class Software(SimComponent):
"The folder on the file system the Software uses."
def _init_request_manager(self) -> RequestManager:
am = super()._init_request_manager()
am.add_request(
rm = super()._init_request_manager()
rm.add_request(
"compromise",
RequestType(
func=lambda request, context: self.set_health_state(SoftwareHealthState.COMPROMISED),
),
)
am.add_request("scan", RequestType(func=lambda request, context: self.scan()))
return am
rm.add_request("scan", RequestType(func=lambda request, context: self.scan()))
return rm
def _get_session_details(self, session_id: str) -> Session:
"""
@@ -241,6 +243,7 @@ class IOSoftware(Software):
payload=payload, dest_ip_address=dest_ip_address, dest_port=dest_port, session_id=session_id
)
@abstractmethod
def receive(self, payload: Any, session_id: str, **kwargs) -> bool:
"""
Receives a payload from the SessionManager.