Revert pre-commit deleting files
This commit is contained in:
@@ -1 +1,259 @@
|
||||
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Dict, Iterable, List, Optional
|
||||
|
||||
from gymnasium import spaces
|
||||
from gymnasium.core import ObsType
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.game.agent.observations.observations import AbstractObservation, WhereType
|
||||
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
class FileObservation(AbstractObservation, identifier="FILE"):
|
||||
"""File observation, provides status information about a file within the simulation environment."""
|
||||
|
||||
class ConfigSchema(AbstractObservation.ConfigSchema):
|
||||
"""Configuration schema for FileObservation."""
|
||||
|
||||
file_name: str
|
||||
"""Name of the file, used for querying simulation state dictionary."""
|
||||
include_num_access: Optional[bool] = None
|
||||
"""Whether to include the number of accesses to the file in the observation."""
|
||||
file_system_requires_scan: Optional[bool] = None
|
||||
"""If True, the file must be scanned to update the health state. Tf False, the true state is always shown."""
|
||||
|
||||
def __init__(self, where: WhereType, include_num_access: bool, file_system_requires_scan: bool) -> None:
|
||||
"""
|
||||
Initialise a file observation instance.
|
||||
|
||||
:param where: Where in the simulation state dictionary to find the relevant information for this file.
|
||||
A typical location for a file might be
|
||||
['network', 'nodes', <node_hostname>, 'file_system', 'folder', <folder_name>, 'files', <file_name>].
|
||||
:type where: WhereType
|
||||
:param include_num_access: Whether to include the number of accesses to the file in the observation.
|
||||
:type include_num_access: bool
|
||||
:param file_system_requires_scan: If True, the file must be scanned to update the health state. Tf False,
|
||||
the true state is always shown.
|
||||
:type file_system_requires_scan: bool
|
||||
"""
|
||||
self.where: WhereType = where
|
||||
self.include_num_access: bool = include_num_access
|
||||
self.file_system_requires_scan: bool = file_system_requires_scan
|
||||
|
||||
self.default_observation: ObsType = {"health_status": 0}
|
||||
if self.include_num_access:
|
||||
self.default_observation["num_access"] = 0
|
||||
|
||||
# TODO: allow these to be configured in yaml
|
||||
self.high_threshold = 10
|
||||
self.med_threshold = 5
|
||||
self.low_threshold = 0
|
||||
|
||||
def _categorise_num_access(self, num_access: int) -> int:
|
||||
"""
|
||||
Represent number of file accesses as a categorical variable.
|
||||
|
||||
:param num_access: Number of file accesses.
|
||||
:return: Bin number corresponding to the number of accesses.
|
||||
"""
|
||||
if num_access > self.high_threshold:
|
||||
return 3
|
||||
elif num_access > self.med_threshold:
|
||||
return 2
|
||||
elif num_access > self.low_threshold:
|
||||
return 1
|
||||
return 0
|
||||
|
||||
def observe(self, state: Dict) -> ObsType:
|
||||
"""
|
||||
Generate observation based on the current state of the simulation.
|
||||
|
||||
:param state: Simulation state dictionary.
|
||||
:type state: Dict
|
||||
:return: Observation containing the health status of the file and optionally the number of accesses.
|
||||
:rtype: ObsType
|
||||
"""
|
||||
file_state = access_from_nested_dict(state, self.where)
|
||||
if file_state is NOT_PRESENT_IN_STATE:
|
||||
return self.default_observation
|
||||
if self.file_system_requires_scan:
|
||||
health_status = file_state["visible_status"]
|
||||
else:
|
||||
health_status = file_state["health_status"]
|
||||
obs = {"health_status": health_status}
|
||||
if self.include_num_access:
|
||||
obs["num_access"] = self._categorise_num_access(file_state["num_access"])
|
||||
return obs
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
"""
|
||||
Gymnasium space object describing the observation space shape.
|
||||
|
||||
:return: Gymnasium space representing the observation space for file status.
|
||||
:rtype: spaces.Space
|
||||
"""
|
||||
space = {"health_status": spaces.Discrete(6)}
|
||||
if self.include_num_access:
|
||||
space["num_access"] = spaces.Discrete(4)
|
||||
return spaces.Dict(space)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> FileObservation:
|
||||
"""
|
||||
Create a file observation from a configuration schema.
|
||||
|
||||
:param config: Configuration schema containing the necessary information for the file observation.
|
||||
:type config: ConfigSchema
|
||||
:param parent_where: Where in the simulation state dictionary to find the information about this file's
|
||||
parent node. A typical location for a node might be ['network', 'nodes', <node_hostname>].
|
||||
:type parent_where: WhereType, optional
|
||||
:return: Constructed file observation instance.
|
||||
:rtype: FileObservation
|
||||
:param file_system_requires_scan: If True, the folder must be scanned to update the health state. Tf False,
|
||||
the true state is always shown.
|
||||
:type file_system_requires_scan: bool
|
||||
"""
|
||||
return cls(
|
||||
where=parent_where + ["files", config.file_name],
|
||||
include_num_access=config.include_num_access,
|
||||
file_system_requires_scan=config.file_system_requires_scan,
|
||||
)
|
||||
|
||||
|
||||
class FolderObservation(AbstractObservation, identifier="FOLDER"):
|
||||
"""Folder observation, provides status information about a folder within the simulation environment."""
|
||||
|
||||
class ConfigSchema(AbstractObservation.ConfigSchema):
|
||||
"""Configuration schema for FolderObservation."""
|
||||
|
||||
folder_name: str
|
||||
"""Name of the folder, used for querying simulation state dictionary."""
|
||||
files: List[FileObservation.ConfigSchema] = []
|
||||
"""List of file configurations within the folder."""
|
||||
num_files: Optional[int] = None
|
||||
"""Number of spaces for file observations in this folder."""
|
||||
include_num_access: Optional[bool] = None
|
||||
"""Whether files in this folder should include the number of accesses in their observation."""
|
||||
file_system_requires_scan: Optional[bool] = None
|
||||
"""If True, the folder must be scanned to update the health state. Tf False, the true state is always shown."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
where: WhereType,
|
||||
files: Iterable[FileObservation],
|
||||
num_files: int,
|
||||
include_num_access: bool,
|
||||
file_system_requires_scan: bool,
|
||||
) -> None:
|
||||
"""
|
||||
Initialise a folder observation instance.
|
||||
|
||||
:param where: Where in the simulation state dictionary to find the relevant information for this folder.
|
||||
A typical location for a folder might be ['network', 'nodes', <node_hostname>, 'folders', <folder_name>].
|
||||
:type where: WhereType
|
||||
:param files: List of file observation instances within the folder.
|
||||
:type files: Iterable[FileObservation]
|
||||
:param num_files: Number of files expected in the folder.
|
||||
:type num_files: int
|
||||
:param include_num_access: Whether to include the number of accesses to files in the observation.
|
||||
:type include_num_access: bool
|
||||
:param file_system_requires_scan: If True, the folder must be scanned to update the health state. Tf False,
|
||||
the true state is always shown.
|
||||
:type file_system_requires_scan: bool
|
||||
"""
|
||||
self.where: WhereType = where
|
||||
|
||||
self.file_system_requires_scan: bool = file_system_requires_scan
|
||||
|
||||
self.files: List[FileObservation] = files
|
||||
while len(self.files) < num_files:
|
||||
self.files.append(
|
||||
FileObservation(
|
||||
where=None,
|
||||
include_num_access=include_num_access,
|
||||
file_system_requires_scan=self.file_system_requires_scan,
|
||||
)
|
||||
)
|
||||
while len(self.files) > num_files:
|
||||
truncated_file = self.files.pop()
|
||||
msg = f"Too many files in folder observation. Truncating file {truncated_file}"
|
||||
_LOGGER.warning(msg)
|
||||
|
||||
self.default_observation = {
|
||||
"health_status": 0,
|
||||
}
|
||||
if self.files:
|
||||
self.default_observation["FILES"] = {i + 1: f.default_observation for i, f in enumerate(self.files)}
|
||||
|
||||
def observe(self, state: Dict) -> ObsType:
|
||||
"""
|
||||
Generate observation based on the current state of the simulation.
|
||||
|
||||
:param state: Simulation state dictionary.
|
||||
:type state: Dict
|
||||
:return: Observation containing the health status of the folder and status of files within the folder.
|
||||
:rtype: ObsType
|
||||
"""
|
||||
folder_state = access_from_nested_dict(state, self.where)
|
||||
if folder_state is NOT_PRESENT_IN_STATE:
|
||||
return self.default_observation
|
||||
|
||||
if self.file_system_requires_scan:
|
||||
health_status = folder_state["visible_status"]
|
||||
else:
|
||||
health_status = folder_state["health_status"]
|
||||
|
||||
obs = {}
|
||||
|
||||
obs["health_status"] = health_status
|
||||
if self.files:
|
||||
obs["FILES"] = {i + 1: file.observe(state) for i, file in enumerate(self.files)}
|
||||
|
||||
return obs
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
"""
|
||||
Gymnasium space object describing the observation space shape.
|
||||
|
||||
:return: Gymnasium space representing the observation space for folder status.
|
||||
:rtype: spaces.Space
|
||||
"""
|
||||
shape = {"health_status": spaces.Discrete(6)}
|
||||
if self.files:
|
||||
shape["FILES"] = spaces.Dict({i + 1: f.space for i, f in enumerate(self.files)})
|
||||
return spaces.Dict(shape)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> FolderObservation:
|
||||
"""
|
||||
Create a folder observation from a configuration schema.
|
||||
|
||||
:param config: Configuration schema containing the necessary information for the folder observation.
|
||||
:type config: ConfigSchema
|
||||
:param parent_where: Where in the simulation state dictionary to find the information about this folder's
|
||||
parent node. A typical location for a node might be ['network', 'nodes', <node_hostname>].
|
||||
:type parent_where: WhereType, optional
|
||||
:return: Constructed folder observation instance.
|
||||
:rtype: FolderObservation
|
||||
"""
|
||||
where = parent_where + ["file_system", "folders", config.folder_name]
|
||||
|
||||
# pass down shared/common config items
|
||||
for file_config in config.files:
|
||||
file_config.include_num_access = config.include_num_access
|
||||
file_config.file_system_requires_scan = config.file_system_requires_scan
|
||||
|
||||
files = [FileObservation.from_config(config=f, parent_where=where) for f in config.files]
|
||||
return cls(
|
||||
where=where,
|
||||
files=files,
|
||||
num_files=config.num_files,
|
||||
include_num_access=config.include_num_access,
|
||||
file_system_requires_scan=config.file_system_requires_scan,
|
||||
)
|
||||
|
||||
@@ -1 +1,153 @@
|
||||
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from gymnasium import spaces
|
||||
from gymnasium.core import ObsType
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.game.agent.observations.observations import AbstractObservation, WhereType
|
||||
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
class LinkObservation(AbstractObservation, identifier="LINK"):
|
||||
"""Link observation, providing information about a specific link within the simulation environment."""
|
||||
|
||||
class ConfigSchema(AbstractObservation.ConfigSchema):
|
||||
"""Configuration schema for LinkObservation."""
|
||||
|
||||
link_reference: str
|
||||
"""Reference identifier for the link."""
|
||||
|
||||
def __init__(self, where: WhereType) -> None:
|
||||
"""
|
||||
Initialise a link observation instance.
|
||||
|
||||
:param where: Where in the simulation state dictionary to find the relevant information for this link.
|
||||
A typical location for a link might be ['network', 'links', <link_reference>].
|
||||
:type where: WhereType
|
||||
"""
|
||||
self.where = where
|
||||
self.default_observation: ObsType = {"PROTOCOLS": {"ALL": 0}}
|
||||
|
||||
def observe(self, state: Dict) -> Any:
|
||||
"""
|
||||
Generate observation based on the current state of the simulation.
|
||||
|
||||
:param state: Simulation state dictionary.
|
||||
:type state: Dict
|
||||
:return: Observation containing information about the link.
|
||||
:rtype: Any
|
||||
"""
|
||||
link_state = access_from_nested_dict(state, self.where)
|
||||
if link_state is NOT_PRESENT_IN_STATE:
|
||||
self.where[-1] = "<->".join(self.where[-1].split("<->")[::-1]) # try swapping endpoint A and B
|
||||
link_state = access_from_nested_dict(state, self.where)
|
||||
if link_state is NOT_PRESENT_IN_STATE:
|
||||
return self.default_observation
|
||||
|
||||
bandwidth = link_state["bandwidth"]
|
||||
load = link_state["current_load"]
|
||||
if load == 0:
|
||||
utilisation_category = 0
|
||||
else:
|
||||
utilisation_fraction = load / bandwidth
|
||||
utilisation_category = int(utilisation_fraction * 9) + 1
|
||||
|
||||
return {"PROTOCOLS": {"ALL": min(utilisation_category, 10)}}
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
"""
|
||||
Gymnasium space object describing the observation space shape.
|
||||
|
||||
:return: Gymnasium space representing the observation space for link status.
|
||||
:rtype: spaces.Space
|
||||
"""
|
||||
return spaces.Dict({"PROTOCOLS": spaces.Dict({"ALL": spaces.Discrete(11)})})
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> LinkObservation:
|
||||
"""
|
||||
Create a link observation from a configuration schema.
|
||||
|
||||
:param config: Configuration schema containing the necessary information for the link observation.
|
||||
:type config: ConfigSchema
|
||||
:param parent_where: Where in the simulation state dictionary to find the information about this link.
|
||||
A typical location might be ['network', 'links', <link_reference>].
|
||||
:type parent_where: WhereType, optional
|
||||
:return: Constructed link observation instance.
|
||||
:rtype: LinkObservation
|
||||
"""
|
||||
link_reference = config.link_reference
|
||||
if parent_where == []:
|
||||
where = ["network", "links", link_reference]
|
||||
else:
|
||||
where = parent_where + ["links", link_reference]
|
||||
return cls(where=where)
|
||||
|
||||
|
||||
class LinksObservation(AbstractObservation, identifier="LINKS"):
|
||||
"""Collection of link observations representing multiple links within the simulation environment."""
|
||||
|
||||
class ConfigSchema(AbstractObservation.ConfigSchema):
|
||||
"""Configuration schema for LinksObservation."""
|
||||
|
||||
link_references: List[str]
|
||||
"""List of reference identifiers for the links."""
|
||||
|
||||
def __init__(self, where: WhereType, links: List[LinkObservation]) -> None:
|
||||
"""
|
||||
Initialise a links observation instance.
|
||||
|
||||
:param where: Where in the simulation state dictionary to find the relevant information for these links.
|
||||
A typical location for links might be ['network', 'links'].
|
||||
:type where: WhereType
|
||||
:param links: List of link observations.
|
||||
:type links: List[LinkObservation]
|
||||
"""
|
||||
self.where: WhereType = where
|
||||
self.links: List[LinkObservation] = links
|
||||
self.default_observation: ObsType = {i + 1: l.default_observation for i, l in enumerate(self.links)}
|
||||
|
||||
def observe(self, state: Dict) -> ObsType:
|
||||
"""
|
||||
Generate observation based on the current state of the simulation.
|
||||
|
||||
:param state: Simulation state dictionary.
|
||||
:type state: Dict
|
||||
:return: Observation containing information about multiple links.
|
||||
:rtype: ObsType
|
||||
"""
|
||||
return {i + 1: l.observe(state) for i, l in enumerate(self.links)}
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
"""
|
||||
Gymnasium space object describing the observation space shape.
|
||||
|
||||
:return: Gymnasium space representing the observation space for multiple links.
|
||||
:rtype: spaces.Space
|
||||
"""
|
||||
return spaces.Dict({i + 1: l.space for i, l in enumerate(self.links)})
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> LinksObservation:
|
||||
"""
|
||||
Create a links observation from a configuration schema.
|
||||
|
||||
:param config: Configuration schema containing the necessary information for the links observation.
|
||||
:type config: ConfigSchema
|
||||
:param parent_where: Where in the simulation state dictionary to find the information about these links.
|
||||
A typical location might be ['network'].
|
||||
:type parent_where: WhereType, optional
|
||||
:return: Constructed links observation instance.
|
||||
:rtype: LinksObservation
|
||||
"""
|
||||
where = parent_where + ["network"]
|
||||
link_cfgs = [LinkObservation.ConfigSchema(link_reference=ref) for ref in config.link_references]
|
||||
links = [LinkObservation.from_config(c, parent_where=where) for c in link_cfgs]
|
||||
return cls(where=where, links=links)
|
||||
|
||||
@@ -1 +1,164 @@
|
||||
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Dict
|
||||
|
||||
from gymnasium import spaces
|
||||
from gymnasium.core import ObsType
|
||||
|
||||
from primaite.game.agent.observations.observations import AbstractObservation, WhereType
|
||||
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
|
||||
|
||||
|
||||
class ServiceObservation(AbstractObservation, identifier="SERVICE"):
|
||||
"""Service observation, shows status of a service in the simulation environment."""
|
||||
|
||||
class ConfigSchema(AbstractObservation.ConfigSchema):
|
||||
"""Configuration schema for ServiceObservation."""
|
||||
|
||||
service_name: str
|
||||
"""Name of the service, used for querying simulation state dictionary"""
|
||||
|
||||
def __init__(self, where: WhereType) -> None:
|
||||
"""
|
||||
Initialise a service observation instance.
|
||||
|
||||
:param where: Where in the simulation state dictionary to find the relevant information for this service.
|
||||
A typical location for a service might be ['network', 'nodes', <node_hostname>, 'services', <service_name>].
|
||||
:type where: WhereType
|
||||
"""
|
||||
self.where = where
|
||||
self.default_observation = {"operating_status": 0, "health_status": 0}
|
||||
|
||||
def observe(self, state: Dict) -> ObsType:
|
||||
"""
|
||||
Generate observation based on the current state of the simulation.
|
||||
|
||||
:param state: Simulation state dictionary.
|
||||
:type state: Dict
|
||||
:return: Observation containing the operating status and health status of the service.
|
||||
:rtype: ObsType
|
||||
"""
|
||||
service_state = access_from_nested_dict(state, self.where)
|
||||
if service_state is NOT_PRESENT_IN_STATE:
|
||||
return self.default_observation
|
||||
return {
|
||||
"operating_status": service_state["operating_state"],
|
||||
"health_status": service_state["health_state_visible"],
|
||||
}
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
"""
|
||||
Gymnasium space object describing the observation space shape.
|
||||
|
||||
:return: Gymnasium space representing the observation space for service status.
|
||||
:rtype: spaces.Space
|
||||
"""
|
||||
return spaces.Dict({"operating_status": spaces.Discrete(7), "health_status": spaces.Discrete(5)})
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> ServiceObservation:
|
||||
"""
|
||||
Create a service observation from a configuration schema.
|
||||
|
||||
:param config: Configuration schema containing the necessary information for the service observation.
|
||||
:type config: ConfigSchema
|
||||
:param parent_where: Where in the simulation state dictionary to find the information about this service's
|
||||
parent node. A typical location for a node might be ['network', 'nodes', <node_hostname>].
|
||||
:type parent_where: WhereType, optional
|
||||
:return: Constructed service observation instance.
|
||||
:rtype: ServiceObservation
|
||||
"""
|
||||
return cls(where=parent_where + ["services", config.service_name])
|
||||
|
||||
|
||||
class ApplicationObservation(AbstractObservation, identifier="APPLICATION"):
|
||||
"""Application observation, shows the status of an application within the simulation environment."""
|
||||
|
||||
class ConfigSchema(AbstractObservation.ConfigSchema):
|
||||
"""Configuration schema for ApplicationObservation."""
|
||||
|
||||
application_name: str
|
||||
"""Name of the application, used for querying simulation state dictionary"""
|
||||
|
||||
def __init__(self, where: WhereType) -> None:
|
||||
"""
|
||||
Initialise an application observation instance.
|
||||
|
||||
:param where: Where in the simulation state dictionary to find the relevant information for this application.
|
||||
A typical location for an application might be
|
||||
['network', 'nodes', <node_hostname>, 'applications', <application_name>].
|
||||
:type where: WhereType
|
||||
"""
|
||||
self.where = where
|
||||
self.default_observation = {"operating_status": 0, "health_status": 0, "num_executions": 0}
|
||||
|
||||
# TODO: allow these to be configured in yaml
|
||||
self.high_threshold = 10
|
||||
self.med_threshold = 5
|
||||
self.low_threshold = 0
|
||||
|
||||
def _categorise_num_executions(self, num_executions: int) -> int:
|
||||
"""
|
||||
Represent number of file accesses as a categorical variable.
|
||||
|
||||
:param num_access: Number of file accesses.
|
||||
:return: Bin number corresponding to the number of accesses.
|
||||
"""
|
||||
if num_executions > self.high_threshold:
|
||||
return 3
|
||||
elif num_executions > self.med_threshold:
|
||||
return 2
|
||||
elif num_executions > self.low_threshold:
|
||||
return 1
|
||||
return 0
|
||||
|
||||
def observe(self, state: Dict) -> ObsType:
|
||||
"""
|
||||
Generate observation based on the current state of the simulation.
|
||||
|
||||
:param state: Simulation state dictionary.
|
||||
:type state: Dict
|
||||
:return: Obs containing the operating status, health status, and number of executions of the application.
|
||||
:rtype: ObsType
|
||||
"""
|
||||
application_state = access_from_nested_dict(state, self.where)
|
||||
if application_state is NOT_PRESENT_IN_STATE:
|
||||
return self.default_observation
|
||||
return {
|
||||
"operating_status": application_state["operating_state"],
|
||||
"health_status": application_state["health_state_visible"],
|
||||
"num_executions": self._categorise_num_executions(application_state["num_executions"]),
|
||||
}
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
"""
|
||||
Gymnasium space object describing the observation space shape.
|
||||
|
||||
:return: Gymnasium space representing the observation space for application status.
|
||||
:rtype: spaces.Space
|
||||
"""
|
||||
return spaces.Dict(
|
||||
{
|
||||
"operating_status": spaces.Discrete(7),
|
||||
"health_status": spaces.Discrete(5),
|
||||
"num_executions": spaces.Discrete(4),
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> ApplicationObservation:
|
||||
"""
|
||||
Create an application observation from a configuration schema.
|
||||
|
||||
:param config: Configuration schema containing the necessary information for the application observation.
|
||||
:type config: ConfigSchema
|
||||
:param parent_where: Where in the simulation state dictionary to find the information about this application's
|
||||
parent node. A typical location for a node might be ['network', 'nodes', <node_hostname>].
|
||||
:type parent_where: WhereType, optional
|
||||
:return: Constructed application observation instance.
|
||||
:rtype: ApplicationObservation
|
||||
"""
|
||||
return cls(where=parent_where + ["applications", config.application_name])
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1 +1,338 @@
|
||||
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Dict, Optional
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.interface.request import RequestFormat, RequestResponse
|
||||
from primaite.simulator.core import RequestManager, RequestType
|
||||
from primaite.simulator.file_system.file_system import File
|
||||
from primaite.simulator.network.protocols.ftp import FTPCommand, FTPPacket, FTPStatusCode
|
||||
from primaite.simulator.system.core.software_manager import SoftwareManager
|
||||
from primaite.simulator.system.services.ftp.ftp_service import FTPServiceABC
|
||||
from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP
|
||||
from primaite.utils.validation.port import Port, PORT_LOOKUP
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
class FTPClient(FTPServiceABC):
|
||||
"""
|
||||
A class for simulating an FTP client service.
|
||||
|
||||
This class inherits from the `Service` class and provides methods to emulate FTP
|
||||
RFC 959: https://datatracker.ietf.org/doc/html/rfc959
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
kwargs["name"] = "FTPClient"
|
||||
kwargs["port"] = PORT_LOOKUP["FTP"]
|
||||
kwargs["protocol"] = PROTOCOL_LOOKUP["TCP"]
|
||||
super().__init__(**kwargs)
|
||||
self.start()
|
||||
|
||||
def _init_request_manager(self) -> RequestManager:
|
||||
"""
|
||||
Initialise the request manager.
|
||||
|
||||
More information in user guide and docstring for SimComponent._init_request_manager.
|
||||
"""
|
||||
rm = super()._init_request_manager()
|
||||
|
||||
def _send_data_request(request: RequestFormat, context: Dict) -> RequestResponse:
|
||||
"""
|
||||
Request for sending data via the ftp_client using the request options parameters.
|
||||
|
||||
:param request: Request with one element containing a dict of parameters for the send method.
|
||||
:type request: RequestFormat
|
||||
:param context: additional context for resolving this action, currently unused
|
||||
:type context: dict
|
||||
:return: RequestResponse object with a success code reflecting whether the configuration could be applied.
|
||||
:rtype: RequestResponse
|
||||
"""
|
||||
dest_ip = request[-1].get("dest_ip_address")
|
||||
dest_ip = None if dest_ip is None else IPv4Address(dest_ip)
|
||||
|
||||
# Missing FTP Options results is an automatic failure.
|
||||
src_folder = request[-1].get("src_folder_name", None)
|
||||
src_file_name = request[-1].get("src_file_name", None)
|
||||
dest_folder = request[-1].get("dest_folder_name", None)
|
||||
dest_file_name = request[-1].get("dest_file_name", None)
|
||||
|
||||
if not self.file_system.access_file(folder_name=src_folder, file_name=src_file_name):
|
||||
self.sys_log.debug(
|
||||
f"{self.name}: Received a FTP Request to transfer file: {src_file_name} to Remote IP: {dest_ip}."
|
||||
)
|
||||
return RequestResponse(
|
||||
status="failure",
|
||||
data={
|
||||
"reason": "Unable to locate given file on local file system. Perhaps given options are invalid?"
|
||||
},
|
||||
)
|
||||
|
||||
return RequestResponse.from_bool(
|
||||
self.send_file(
|
||||
dest_ip_address=dest_ip,
|
||||
src_folder_name=src_folder,
|
||||
src_file_name=src_file_name,
|
||||
dest_folder_name=dest_folder,
|
||||
dest_file_name=dest_file_name,
|
||||
)
|
||||
)
|
||||
|
||||
rm.add_request("send", request_type=RequestType(func=_send_data_request)),
|
||||
return rm
|
||||
|
||||
def _process_ftp_command(self, payload: FTPPacket, session_id: Optional[str] = None, **kwargs) -> FTPPacket:
|
||||
"""
|
||||
Process the command in the FTP Packet.
|
||||
|
||||
:param: payload: The FTP Packet to process
|
||||
:type: payload: FTPPacket
|
||||
:param: session_id: session ID linked to the FTP Packet. Optional.
|
||||
:type: session_id: Optional[str]
|
||||
"""
|
||||
# if client service is down, return error
|
||||
if not self._can_perform_action():
|
||||
payload.status_code = FTPStatusCode.ERROR
|
||||
return payload
|
||||
|
||||
self.sys_log.info(f"{self.name}: Received FTP {payload.ftp_command.name} {payload.ftp_command_args}")
|
||||
|
||||
# process client specific commands, otherwise call super
|
||||
return super()._process_ftp_command(payload=payload, session_id=session_id, **kwargs)
|
||||
|
||||
def _connect_to_server(
|
||||
self,
|
||||
dest_ip_address: Optional[IPv4Address] = None,
|
||||
dest_port: Optional[Port] = PORT_LOOKUP["FTP"],
|
||||
session_id: Optional[str] = None,
|
||||
is_reattempt: Optional[bool] = False,
|
||||
) -> bool:
|
||||
"""
|
||||
Connects the client to a given FTP server.
|
||||
|
||||
:param: dest_ip_address: IP address of the FTP server the client needs to connect to. Optional.
|
||||
:type: dest_ip_address: Optional[IPv4Address]
|
||||
:param: dest_port: Port of the FTP server the client needs to connect to. Optional.
|
||||
:type: dest_port: Optional[Port]
|
||||
:param: is_reattempt: Set to True if attempt to connect to FTP Server has been attempted. Default False.
|
||||
:type: is_reattempt: Optional[bool]
|
||||
"""
|
||||
# make sure the service is running before attempting
|
||||
if not self._can_perform_action():
|
||||
return False
|
||||
|
||||
# normally FTP will choose a random port for the transfer, but using the FTP command port will do for now
|
||||
# create FTP packet
|
||||
payload: FTPPacket = FTPPacket(ftp_command=FTPCommand.PORT, ftp_command_args=PORT_LOOKUP["FTP"])
|
||||
|
||||
if self.send(payload=payload, dest_ip_address=dest_ip_address, dest_port=dest_port, session_id=session_id):
|
||||
if payload.status_code == FTPStatusCode.OK:
|
||||
self.sys_log.info(
|
||||
f"{self.name}: Successfully connected to FTP Server "
|
||||
f"{dest_ip_address} via port {payload.ftp_command_args}"
|
||||
)
|
||||
self.add_connection(connection_id="server_connection", session_id=session_id)
|
||||
return True
|
||||
else:
|
||||
if is_reattempt:
|
||||
# reattempt failed
|
||||
self.sys_log.warning(
|
||||
f"{self.name}: Unable to connect to FTP Server "
|
||||
f"{dest_ip_address} via port {payload.ftp_command_args}"
|
||||
)
|
||||
return False
|
||||
else:
|
||||
# try again
|
||||
self._connect_to_server(
|
||||
dest_ip_address=dest_ip_address, dest_port=dest_port, session_id=session_id, is_reattempt=True
|
||||
)
|
||||
else:
|
||||
self.sys_log.warning(f"{self.name}: Unable to send FTPPacket")
|
||||
return False
|
||||
|
||||
def _disconnect_from_server(
|
||||
self, dest_ip_address: Optional[IPv4Address] = None, dest_port: Optional[Port] = PORT_LOOKUP["FTP"]
|
||||
) -> bool:
|
||||
"""
|
||||
Connects the client from a given FTP server.
|
||||
|
||||
:param: dest_ip_address: IP address of the FTP server the client needs to disconnect from. Optional.
|
||||
:type: dest_ip_address: Optional[IPv4Address]
|
||||
:param: dest_port: Port of the FTP server the client needs to disconnect from. Optional.
|
||||
:type: dest_port: Optional[Port]
|
||||
:param: is_reattempt: Set to True if attempt to disconnect from FTP Server has been attempted. Default False.
|
||||
:type: is_reattempt: Optional[bool]
|
||||
"""
|
||||
# send a disconnect request payload to FTP server
|
||||
payload: FTPPacket = FTPPacket(ftp_command=FTPCommand.QUIT)
|
||||
software_manager: SoftwareManager = self.software_manager
|
||||
software_manager.send_payload_to_session_manager(
|
||||
payload=payload, dest_ip_address=dest_ip_address, dest_port=dest_port
|
||||
)
|
||||
return payload.status_code == FTPStatusCode.OK
|
||||
|
||||
def send_file(
|
||||
self,
|
||||
dest_ip_address: IPv4Address,
|
||||
src_folder_name: str,
|
||||
src_file_name: str,
|
||||
dest_folder_name: str,
|
||||
dest_file_name: str,
|
||||
dest_port: Optional[Port] = PORT_LOOKUP["FTP"],
|
||||
session_id: Optional[str] = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Send a file to a target IP address.
|
||||
|
||||
The function checks if the file exists in the FTP Client host.
|
||||
The STOR command is then sent to the FTP Server.
|
||||
|
||||
:param: dest_ip_address: The IP address of the machine that hosts the FTP Server.
|
||||
:type: dest_ip_address: IPv4Address
|
||||
|
||||
:param: src_folder_name: The name of the folder that contains the file to send to the FTP Server.
|
||||
:type: src_folder_name: str
|
||||
|
||||
:param: src_file_name: The name of the file to send to the FTP Server.
|
||||
:type: src_file_name: str
|
||||
|
||||
:param: dest_folder_name: The name of the folder where the file will be stored in the FTP Server.
|
||||
:type: dest_folder_name: str
|
||||
|
||||
:param: dest_file_name: The name of the file to be saved on the FTP Server.
|
||||
:type: dest_file_name: str
|
||||
|
||||
:param: dest_port: The open port of the machine that hosts the FTP Server. Default is Port["FTP"].
|
||||
:type: dest_port: Optional[Port]
|
||||
|
||||
:param: session_id: The id of the session
|
||||
:type: session_id: Optional[str]
|
||||
"""
|
||||
# check if the file to transfer exists on the client
|
||||
file_to_transfer: File = self.file_system.get_file(folder_name=src_folder_name, file_name=src_file_name)
|
||||
if not file_to_transfer:
|
||||
self.sys_log.warning(f"Unable to send file that does not exist: {src_folder_name}/{src_file_name}")
|
||||
return False
|
||||
|
||||
# check if FTP is currently connected to IP
|
||||
self._connect_to_server(dest_ip_address=dest_ip_address, dest_port=dest_port)
|
||||
|
||||
if not len(self.connections):
|
||||
return False
|
||||
else:
|
||||
self.sys_log.info(f"Sending file {src_folder_name}/{src_file_name} to {str(dest_ip_address)}")
|
||||
# send STOR request
|
||||
if self._send_data(
|
||||
file=file_to_transfer,
|
||||
dest_folder_name=dest_folder_name,
|
||||
dest_file_name=dest_file_name,
|
||||
dest_ip_address=dest_ip_address,
|
||||
dest_port=dest_port,
|
||||
):
|
||||
return self._disconnect_from_server(dest_ip_address=dest_ip_address, dest_port=dest_port)
|
||||
|
||||
return False
|
||||
|
||||
def request_file(
|
||||
self,
|
||||
dest_ip_address: IPv4Address,
|
||||
src_folder_name: str,
|
||||
src_file_name: str,
|
||||
dest_folder_name: str,
|
||||
dest_file_name: str,
|
||||
dest_port: Optional[Port] = PORT_LOOKUP["FTP"],
|
||||
) -> bool:
|
||||
"""
|
||||
Request a file from a target IP address.
|
||||
|
||||
Sends a RETR command to the FTP Server.
|
||||
|
||||
:param: dest_ip_address: The IP address of the machine that hosts the FTP Server.
|
||||
:type: dest_ip_address: IPv4Address
|
||||
|
||||
:param: src_folder_name: The name of the folder that contains the file to send to the FTP Server.
|
||||
:type: src_folder_name: str
|
||||
|
||||
:param: src_file_name: The name of the file to send to the FTP Server.
|
||||
:type: src_file_name: str
|
||||
|
||||
:param: dest_folder_name: The name of the folder where the file will be stored in the FTP Server.
|
||||
:type: dest_folder_name: str
|
||||
|
||||
:param: dest_file_name: The name of the file to be saved on the FTP Server.
|
||||
:type: dest_file_name: str
|
||||
|
||||
:param: dest_port: The open port of the machine that hosts the FTP Server. Default is Port["FTP"].
|
||||
:type: dest_port: Optional[int]
|
||||
"""
|
||||
# check if FTP is currently connected to IP
|
||||
self._connect_to_server(dest_ip_address=dest_ip_address, dest_port=dest_port)
|
||||
|
||||
if not len(self.connections):
|
||||
return False
|
||||
else:
|
||||
# send retrieve request
|
||||
payload: FTPPacket = FTPPacket(
|
||||
ftp_command=FTPCommand.RETR,
|
||||
ftp_command_args={
|
||||
"src_folder_name": src_folder_name,
|
||||
"src_file_name": src_file_name,
|
||||
"dest_file_name": dest_file_name,
|
||||
"dest_folder_name": dest_folder_name,
|
||||
},
|
||||
)
|
||||
self.sys_log.info(f"Requesting file {src_folder_name}/{src_file_name} from {str(dest_ip_address)}")
|
||||
software_manager: SoftwareManager = self.software_manager
|
||||
software_manager.send_payload_to_session_manager(
|
||||
payload=payload, dest_ip_address=dest_ip_address, dest_port=dest_port
|
||||
)
|
||||
|
||||
# the payload should have ok status code
|
||||
if payload.status_code == FTPStatusCode.OK:
|
||||
self.sys_log.info(f"{self.name}: File {src_folder_name}/{src_file_name} found in FTP server.")
|
||||
return True
|
||||
else:
|
||||
self.sys_log.error(f"{self.name}: File {src_folder_name}/{src_file_name} does not exist in FTP server")
|
||||
return False
|
||||
|
||||
def receive(self, payload: FTPPacket, session_id: Optional[str] = None, **kwargs) -> bool:
|
||||
"""
|
||||
Receives a payload from the SessionManager.
|
||||
|
||||
:param: payload: FTPPacket payload.
|
||||
:type: payload: FTPPacket
|
||||
|
||||
:param: session_id: ID of the session. Optional.
|
||||
:type: session_id: Optional[str]
|
||||
"""
|
||||
if not isinstance(payload, FTPPacket):
|
||||
self.sys_log.warning(f"{self.name}: Payload is not an FTP packet")
|
||||
self.sys_log.debug(f"{self.name}: {payload}")
|
||||
return False
|
||||
|
||||
"""
|
||||
Ignore ftp payload if status code is None.
|
||||
|
||||
This helps prevent an FTP request loop - FTP client and servers can exist on
|
||||
the same node.
|
||||
"""
|
||||
if not self._can_perform_action():
|
||||
return False
|
||||
|
||||
if payload.status_code is None:
|
||||
self.sys_log.error(f"FTP Server could not be found - Error Code: {FTPStatusCode.NOT_FOUND.value}")
|
||||
return False
|
||||
|
||||
# if PORT succeeded, add the connection as an active connection list
|
||||
if payload.ftp_command is FTPCommand.PORT and payload.status_code is FTPStatusCode.OK:
|
||||
self.add_connection(connection_id=session_id, session_id=session_id)
|
||||
|
||||
# if QUIT succeeded, remove the session from active connection list
|
||||
if payload.ftp_command is FTPCommand.QUIT and payload.status_code is FTPStatusCode.OK:
|
||||
self.terminate_connection(connection_id=session_id)
|
||||
|
||||
self.sys_log.info(f"{self.name}: Received FTP Response {payload.ftp_command.name} {payload.status_code.value}")
|
||||
|
||||
self._process_ftp_command(payload=payload, session_id=session_id)
|
||||
return True
|
||||
|
||||
@@ -1 +1,545 @@
|
||||
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import abstractmethod
|
||||
from datetime import datetime
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from uuid import uuid4
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from primaite.interface.request import RequestFormat, RequestResponse
|
||||
from primaite.simulator.core import RequestManager, RequestType
|
||||
from primaite.simulator.network.protocols.ssh import (
|
||||
SSHConnectionMessage,
|
||||
SSHPacket,
|
||||
SSHTransportMessage,
|
||||
SSHUserCredentials,
|
||||
)
|
||||
from primaite.simulator.system.core.software_manager import SoftwareManager
|
||||
from primaite.simulator.system.services.service import Service, ServiceOperatingState
|
||||
from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP
|
||||
from primaite.utils.validation.port import PORT_LOOKUP
|
||||
|
||||
|
||||
# TODO 2824: Since remote terminal connections and remote user sessions are the same thing, we could refactor
|
||||
# the terminal to leverage the user session manager's list. This way we avoid potential bugs and code ducplication
|
||||
class TerminalClientConnection(BaseModel):
|
||||
"""
|
||||
TerminalClientConnection Class.
|
||||
|
||||
This class is used to record current User Connections to the Terminal class.
|
||||
"""
|
||||
|
||||
parent_terminal: Terminal
|
||||
"""The parent Node that this connection was created on."""
|
||||
|
||||
ssh_session_id: str = None
|
||||
"""Session ID that connection is linked to, used for sending commands via session manager."""
|
||||
|
||||
connection_uuid: str = None
|
||||
"""Connection UUID"""
|
||||
|
||||
connection_request_id: str = None
|
||||
"""Connection request ID"""
|
||||
|
||||
time: datetime = None
|
||||
"""Timestamp connection was created."""
|
||||
|
||||
ip_address: IPv4Address
|
||||
"""Source IP of Connection"""
|
||||
|
||||
is_active: bool = True
|
||||
"""Flag to state whether the connection is active or not"""
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"{self.__class__.__name__}(connection_id: '{self.connection_uuid}, ip_address: {self.ip_address}')"
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return self.__str__()
|
||||
|
||||
def __getitem__(self, key: Any) -> Any:
|
||||
return getattr(self, key)
|
||||
|
||||
@property
|
||||
def client(self) -> Optional[Terminal]:
|
||||
"""The Terminal that holds this connection."""
|
||||
return self.parent_terminal
|
||||
|
||||
def disconnect(self) -> bool:
|
||||
"""Disconnect the session."""
|
||||
return self.parent_terminal._disconnect(connection_uuid=self.connection_uuid)
|
||||
|
||||
@abstractmethod
|
||||
def execute(self, command: Any) -> bool:
|
||||
"""Execute a given command."""
|
||||
pass
|
||||
|
||||
|
||||
class LocalTerminalConnection(TerminalClientConnection):
|
||||
"""
|
||||
LocalTerminalConnectionClass.
|
||||
|
||||
This class represents a local terminal when connected.
|
||||
"""
|
||||
|
||||
ip_address: str = "Local Connection"
|
||||
|
||||
def execute(self, command: Any) -> Optional[RequestResponse]:
|
||||
"""Execute a given command on local Terminal."""
|
||||
if self.parent_terminal.operating_state != ServiceOperatingState.RUNNING:
|
||||
self.parent_terminal.sys_log.warning("Cannot process command as system not running")
|
||||
return None
|
||||
if not self.is_active:
|
||||
self.parent_terminal.sys_log.warning("Connection inactive, cannot execute")
|
||||
return None
|
||||
return self.parent_terminal.execute(command)
|
||||
|
||||
|
||||
class RemoteTerminalConnection(TerminalClientConnection):
|
||||
"""
|
||||
RemoteTerminalConnection Class.
|
||||
|
||||
This class acts as broker between the terminal and remote.
|
||||
|
||||
"""
|
||||
|
||||
def execute(self, command: Any) -> bool:
|
||||
"""Execute a given command on the remote Terminal."""
|
||||
if self.parent_terminal.operating_state != ServiceOperatingState.RUNNING:
|
||||
self.parent_terminal.sys_log.warning("Cannot process command as system not running")
|
||||
return False
|
||||
if not self.is_active:
|
||||
self.parent_terminal.sys_log.warning("Connection inactive, cannot execute")
|
||||
return False
|
||||
# Send command to remote terminal to process.
|
||||
|
||||
transport_message: SSHTransportMessage = SSHTransportMessage.SSH_MSG_SERVICE_REQUEST
|
||||
connection_message: SSHConnectionMessage = SSHConnectionMessage.SSH_MSG_CHANNEL_DATA
|
||||
|
||||
payload: SSHPacket = SSHPacket(
|
||||
transport_message=transport_message,
|
||||
connection_message=connection_message,
|
||||
connection_request_uuid=self.connection_request_id,
|
||||
connection_uuid=self.connection_uuid,
|
||||
ssh_command=command,
|
||||
)
|
||||
|
||||
return self.parent_terminal.send(payload=payload, session_id=self.ssh_session_id)
|
||||
|
||||
|
||||
class Terminal(Service):
|
||||
"""Class used to simulate a generic terminal service. Can be interacted with by other terminals via SSH."""
|
||||
|
||||
_client_connection_requests: Dict[str, Optional[Union[str, TerminalClientConnection]]] = {}
|
||||
"""Dictionary of connect requests made to remote nodes."""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
kwargs["name"] = "Terminal"
|
||||
kwargs["port"] = PORT_LOOKUP["SSH"]
|
||||
kwargs["protocol"] = PROTOCOL_LOOKUP["TCP"]
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def describe_state(self) -> Dict:
|
||||
"""
|
||||
Produce a dictionary describing the current state of this object.
|
||||
|
||||
Please see :py:meth:`primaite.simulator.core.SimComponent.describe_state` for a more detailed explanation.
|
||||
|
||||
:return: Current state of this object and child objects.
|
||||
:rtype: Dict
|
||||
"""
|
||||
state = super().describe_state()
|
||||
return state
|
||||
|
||||
def show(self, markdown: bool = False):
|
||||
"""
|
||||
Display the remote connections to this terminal instance in tabular format.
|
||||
|
||||
:param markdown: Whether to display the table in Markdown format or not. Default is `False`.
|
||||
"""
|
||||
self.show_connections(markdown=markdown)
|
||||
|
||||
def _init_request_manager(self) -> RequestManager:
|
||||
"""Initialise Request manager."""
|
||||
rm = super()._init_request_manager()
|
||||
|
||||
def _remote_login(request: RequestFormat, context: Dict) -> RequestResponse:
|
||||
login = self._send_remote_login(username=request[0], password=request[1], ip_address=request[2])
|
||||
if login:
|
||||
return RequestResponse(
|
||||
status="success",
|
||||
data={
|
||||
"ip_address": str(login.ip_address),
|
||||
"username": request[0],
|
||||
},
|
||||
)
|
||||
else:
|
||||
return RequestResponse(status="failure", data={})
|
||||
|
||||
rm.add_request(
|
||||
"ssh_to_remote",
|
||||
request_type=RequestType(func=_remote_login),
|
||||
)
|
||||
|
||||
def _remote_logoff(request: RequestFormat, context: Dict) -> RequestResponse:
|
||||
"""Logoff from remote connection."""
|
||||
ip_address = IPv4Address(request[0])
|
||||
remote_connection = self._get_connection_from_ip(ip_address=ip_address)
|
||||
if remote_connection:
|
||||
outcome = self._disconnect(remote_connection.connection_uuid)
|
||||
if outcome:
|
||||
return RequestResponse(status="success", data={})
|
||||
|
||||
return RequestResponse(status="failure", data={})
|
||||
|
||||
rm.add_request("remote_logoff", request_type=RequestType(func=_remote_logoff))
|
||||
|
||||
def remote_execute_request(request: RequestFormat, context: Dict) -> RequestResponse:
|
||||
"""Execute an instruction."""
|
||||
ip_address: IPv4Address = IPv4Address(request[0])
|
||||
command: str = request[1]["command"]
|
||||
remote_connection = self._get_connection_from_ip(ip_address=ip_address)
|
||||
if remote_connection:
|
||||
outcome = remote_connection.execute(command)
|
||||
if outcome:
|
||||
return RequestResponse(
|
||||
status="success",
|
||||
data={},
|
||||
)
|
||||
else:
|
||||
return RequestResponse(
|
||||
status="failure",
|
||||
data={},
|
||||
)
|
||||
|
||||
rm.add_request(
|
||||
"send_remote_command",
|
||||
request_type=RequestType(func=remote_execute_request),
|
||||
)
|
||||
|
||||
return rm
|
||||
|
||||
def execute(self, command: List[Any]) -> Optional[RequestResponse]:
|
||||
"""Execute a passed ssh command via the request manager."""
|
||||
return self.parent.apply_request(command)
|
||||
|
||||
def _get_connection_from_ip(self, ip_address: IPv4Address) -> Optional[RemoteTerminalConnection]:
|
||||
"""Find Remote Terminal Connection from a given IP."""
|
||||
for connection in self._connections.values():
|
||||
if connection.ip_address == ip_address:
|
||||
return connection
|
||||
|
||||
def _create_local_connection(self, connection_uuid: str, session_id: str) -> TerminalClientConnection:
|
||||
"""Create a new connection object and amend to list of active connections.
|
||||
|
||||
:param connection_uuid: Connection ID of the new local connection
|
||||
:param session_id: Session ID of the new local connection
|
||||
:return: TerminalClientConnection object
|
||||
"""
|
||||
new_connection = LocalTerminalConnection(
|
||||
parent_terminal=self,
|
||||
connection_uuid=connection_uuid,
|
||||
ssh_session_id=session_id,
|
||||
time=datetime.now(),
|
||||
)
|
||||
self._connections[connection_uuid] = new_connection
|
||||
self._client_connection_requests[connection_uuid] = new_connection
|
||||
|
||||
return new_connection
|
||||
|
||||
def login(
|
||||
self, username: str, password: str, ip_address: Optional[IPv4Address] = None
|
||||
) -> Optional[TerminalClientConnection]:
|
||||
"""Login to the terminal. Will attempt a remote login if ip_address is given, else local.
|
||||
|
||||
:param: username: Username used to connect to the remote node.
|
||||
:type: username: str
|
||||
|
||||
:param: password: Password used to connect to the remote node
|
||||
:type: password: str
|
||||
|
||||
:param: ip_address: Target Node IP address for login attempt. If None, login is assumed local.
|
||||
:type: ip_address: Optional[IPv4Address]
|
||||
"""
|
||||
if self.operating_state != ServiceOperatingState.RUNNING:
|
||||
self.sys_log.warning(f"{self.name}: Cannot login as service is not running.")
|
||||
return None
|
||||
if ip_address:
|
||||
# Assuming that if IP is passed we are connecting to remote
|
||||
return self._send_remote_login(username=username, password=password, ip_address=ip_address)
|
||||
else:
|
||||
return self._process_local_login(username=username, password=password)
|
||||
|
||||
def _process_local_login(self, username: str, password: str) -> Optional[TerminalClientConnection]:
|
||||
"""Local session login to terminal.
|
||||
|
||||
:param username: Username for login.
|
||||
:param password: Password for login.
|
||||
:return: boolean, True if successful, else False
|
||||
"""
|
||||
# TODO: Un-comment this when UserSessionManager is merged.
|
||||
connection_uuid = self.parent.user_session_manager.local_login(username=username, password=password)
|
||||
if connection_uuid:
|
||||
self.sys_log.info(f"{self.name}: Login request authorised, connection uuid: {connection_uuid}")
|
||||
# Add new local session to list of connections and return
|
||||
return self._create_local_connection(connection_uuid=connection_uuid, session_id="Local_Connection")
|
||||
else:
|
||||
self.sys_log.warning(f"{self.name}: Login failed, incorrect Username or Password")
|
||||
return None
|
||||
|
||||
def _validate_client_connection_request(self, connection_id: str) -> bool:
|
||||
"""Check that client_connection_id is valid."""
|
||||
return connection_id in self._client_connection_requests
|
||||
|
||||
def _check_client_connection(self, connection_id: str) -> bool:
|
||||
"""Check that client_connection_id is valid."""
|
||||
if not self.parent.user_session_manager.validate_remote_session_uuid(connection_id):
|
||||
self._disconnect(connection_id)
|
||||
return False
|
||||
return connection_id in self._connections
|
||||
|
||||
def _send_remote_login(
|
||||
self,
|
||||
username: str,
|
||||
password: str,
|
||||
ip_address: IPv4Address,
|
||||
connection_request_id: Optional[str] = None,
|
||||
is_reattempt: bool = False,
|
||||
) -> Optional[RemoteTerminalConnection]:
|
||||
"""Send a remote login attempt and connect to Node.
|
||||
|
||||
:param: username: Username used to connect to the remote node.
|
||||
:type: username: str
|
||||
:param: password: Password used to connect to the remote node
|
||||
:type: password: str
|
||||
:param: ip_address: Target Node IP address for login attempt.
|
||||
:type: ip_address: IPv4Address
|
||||
:param: connection_request_id: Connection Request ID, if not provided, a new one is generated
|
||||
:type: connection_request_id: Optional[str]
|
||||
:param: is_reattempt: True if the request has been reattempted. Default False.
|
||||
:type: is_reattempt: Optional[bool]
|
||||
:return: RemoteTerminalConnection: Connection Object for sending further commands if successful, else False.
|
||||
"""
|
||||
connection_request_id = connection_request_id or str(uuid4())
|
||||
if is_reattempt:
|
||||
valid_connection_request = self._validate_client_connection_request(connection_id=connection_request_id)
|
||||
if valid_connection_request:
|
||||
remote_terminal_connection = self._client_connection_requests.pop(connection_request_id)
|
||||
if isinstance(remote_terminal_connection, RemoteTerminalConnection):
|
||||
self.sys_log.info(f"{self.name}: Remote Connection to {ip_address} authorised.")
|
||||
return remote_terminal_connection
|
||||
else:
|
||||
self.sys_log.warning(f"{self.name}: Connection request {connection_request_id} declined")
|
||||
return None
|
||||
else:
|
||||
self.sys_log.warning(f"{self.name}: Remote connection to {ip_address} declined.")
|
||||
return None
|
||||
|
||||
self.sys_log.info(
|
||||
f"{self.name}: Sending Remote login attempt to {ip_address}. Connection_id is {connection_request_id}"
|
||||
)
|
||||
transport_message: SSHTransportMessage = SSHTransportMessage.SSH_MSG_USERAUTH_REQUEST
|
||||
connection_message: SSHConnectionMessage = SSHConnectionMessage.SSH_MSG_CHANNEL_DATA
|
||||
user_details: SSHUserCredentials = SSHUserCredentials(username=username, password=password)
|
||||
|
||||
payload_contents = {
|
||||
"type": "login_request",
|
||||
"username": username,
|
||||
"password": password,
|
||||
"connection_request_id": connection_request_id,
|
||||
}
|
||||
|
||||
payload: SSHPacket = SSHPacket(
|
||||
payload=payload_contents,
|
||||
transport_message=transport_message,
|
||||
connection_message=connection_message,
|
||||
user_account=user_details,
|
||||
connection_request_uuid=connection_request_id,
|
||||
)
|
||||
|
||||
software_manager: SoftwareManager = self.software_manager
|
||||
software_manager.send_payload_to_session_manager(
|
||||
payload=payload, dest_ip_address=ip_address, dest_port=self.port
|
||||
)
|
||||
return self._send_remote_login(
|
||||
username=username,
|
||||
password=password,
|
||||
ip_address=ip_address,
|
||||
is_reattempt=True,
|
||||
connection_request_id=connection_request_id,
|
||||
)
|
||||
|
||||
def _create_remote_connection(
|
||||
self, connection_id: str, connection_request_id: str, session_id: str, source_ip: str
|
||||
) -> None:
|
||||
"""Create a new TerminalClientConnection Object.
|
||||
|
||||
:param: connection_request_id: Connection Request ID
|
||||
:type: connection_request_id: str
|
||||
|
||||
:param: session_id: Session ID of connection.
|
||||
:type: session_id: str
|
||||
"""
|
||||
client_connection = RemoteTerminalConnection(
|
||||
parent_terminal=self,
|
||||
ssh_session_id=session_id,
|
||||
connection_uuid=connection_id,
|
||||
ip_address=source_ip,
|
||||
connection_request_id=connection_request_id,
|
||||
time=datetime.now(),
|
||||
)
|
||||
self._connections[connection_id] = client_connection
|
||||
self._client_connection_requests[connection_request_id] = client_connection
|
||||
|
||||
def receive(self, session_id: str, payload: Union[SSHPacket, Dict], **kwargs) -> bool:
|
||||
"""
|
||||
Receive a payload from the Software Manager.
|
||||
|
||||
:param payload: A payload to receive.
|
||||
:param session_id: The session id the payload relates to.
|
||||
:return: True.
|
||||
"""
|
||||
source_ip = kwargs["frame"].ip.src_ip_address
|
||||
self.sys_log.info(f"{self.name}: Received payload: {payload}. Source: {source_ip}")
|
||||
if isinstance(payload, SSHPacket):
|
||||
if payload.transport_message == SSHTransportMessage.SSH_MSG_USERAUTH_REQUEST:
|
||||
# validate & add connection
|
||||
# TODO: uncomment this as part of 2781
|
||||
username = payload.user_account.username
|
||||
password = payload.user_account.password
|
||||
connection_id = self.parent.user_session_manager.remote_login(
|
||||
username=username, password=password, remote_ip_address=source_ip
|
||||
)
|
||||
if connection_id:
|
||||
connection_request_id = payload.connection_request_uuid
|
||||
self.sys_log.info(f"{self.name}: Connection authorised, session_id: {session_id}")
|
||||
self._create_remote_connection(
|
||||
connection_id=connection_id,
|
||||
connection_request_id=connection_request_id,
|
||||
session_id=session_id,
|
||||
source_ip=source_ip,
|
||||
)
|
||||
|
||||
transport_message = SSHTransportMessage.SSH_MSG_USERAUTH_SUCCESS
|
||||
connection_message = SSHConnectionMessage.SSH_MSG_CHANNEL_DATA
|
||||
|
||||
payload_contents = {
|
||||
"type": "login_success",
|
||||
"username": username,
|
||||
"password": password,
|
||||
"connection_request_id": connection_request_id,
|
||||
"connection_id": connection_id,
|
||||
}
|
||||
payload: SSHPacket = SSHPacket(
|
||||
payload=payload_contents,
|
||||
transport_message=transport_message,
|
||||
connection_message=connection_message,
|
||||
connection_request_uuid=connection_request_id,
|
||||
connection_uuid=connection_id,
|
||||
)
|
||||
|
||||
software_manager: SoftwareManager = self.software_manager
|
||||
software_manager.send_payload_to_session_manager(
|
||||
payload=payload, dest_port=self.port, session_id=session_id
|
||||
)
|
||||
elif payload.transport_message == SSHTransportMessage.SSH_MSG_USERAUTH_SUCCESS:
|
||||
self.sys_log.info(f"{self.name}: Login Successful")
|
||||
self._create_remote_connection(
|
||||
connection_id=payload.connection_uuid,
|
||||
connection_request_id=payload.connection_request_uuid,
|
||||
session_id=session_id,
|
||||
source_ip=source_ip,
|
||||
)
|
||||
|
||||
elif payload.transport_message == SSHTransportMessage.SSH_MSG_SERVICE_REQUEST:
|
||||
# Requesting a command to be executed
|
||||
self.sys_log.info(f"{self.name}: Received command to execute")
|
||||
command = payload.ssh_command
|
||||
valid_connection = self._check_client_connection(payload.connection_uuid)
|
||||
if valid_connection:
|
||||
remote_session = self.software_manager.node.user_session_manager.remote_sessions.get(
|
||||
payload.connection_uuid
|
||||
)
|
||||
remote_session.last_active_step = self.software_manager.node.user_session_manager.current_timestep
|
||||
self.execute(command)
|
||||
return True
|
||||
else:
|
||||
self.sys_log.error(
|
||||
f"{self.name}: Connection UUID:{payload.connection_uuid} is not valid. Rejecting Command."
|
||||
)
|
||||
|
||||
if isinstance(payload, dict) and payload.get("type"):
|
||||
if payload["type"] == "disconnect":
|
||||
connection_id = payload["connection_id"]
|
||||
valid_id = self._check_client_connection(connection_id)
|
||||
if valid_id:
|
||||
self.sys_log.info(f"{self.name}: Received disconnect command for {connection_id=} from remote.")
|
||||
self._disconnect(payload["connection_id"])
|
||||
self.parent.user_session_manager.remote_logout(remote_session_id=connection_id)
|
||||
else:
|
||||
self.sys_log.error(f"{self.name}: No Active connection held for received connection ID.")
|
||||
|
||||
if payload["type"] == "user_timeout":
|
||||
connection_id = payload["connection_id"]
|
||||
valid_id = connection_id in self._connections
|
||||
if valid_id:
|
||||
connection = self._connections.pop(connection_id)
|
||||
connection.is_active = False
|
||||
self.sys_log.info(f"{self.name}: Connection {connection_id} disconnected due to inactivity.")
|
||||
else:
|
||||
self.sys_log.error(f"{self.name}: Connection {connection_id} is invalid.")
|
||||
|
||||
return True
|
||||
|
||||
def _disconnect(self, connection_uuid: str) -> bool:
|
||||
"""Disconnect connection.
|
||||
|
||||
:param connection_uuid: Connection ID that we want to disconnect.
|
||||
:return True if successful, False otherwise.
|
||||
"""
|
||||
# TODO: Handle the possibility of attempting to disconnect
|
||||
if not self._connections:
|
||||
self.sys_log.warning(f"{self.name}: No remote connection present")
|
||||
return False
|
||||
|
||||
connection = self._connections.pop(connection_uuid, None)
|
||||
if not connection:
|
||||
return False
|
||||
connection.is_active = False
|
||||
|
||||
if isinstance(connection, RemoteTerminalConnection):
|
||||
# Send disconnect command via software manager
|
||||
session_id = connection.ssh_session_id
|
||||
|
||||
software_manager: SoftwareManager = self.software_manager
|
||||
software_manager.send_payload_to_session_manager(
|
||||
payload={"type": "disconnect", "connection_id": connection_uuid},
|
||||
dest_port=self.port,
|
||||
session_id=session_id,
|
||||
)
|
||||
self.sys_log.info(f"{self.name}: Disconnected {connection_uuid}")
|
||||
return True
|
||||
|
||||
elif isinstance(connection, LocalTerminalConnection):
|
||||
self.parent.user_session_manager.local_logout()
|
||||
return True
|
||||
|
||||
def send(
|
||||
self, payload: SSHPacket, dest_ip_address: Optional[IPv4Address] = None, session_id: Optional[str] = None
|
||||
) -> bool:
|
||||
"""
|
||||
Send a payload out from the Terminal.
|
||||
|
||||
:param payload: The payload to be sent.
|
||||
:param dest_up_address: The IP address of the payload destination.
|
||||
"""
|
||||
if self.operating_state != ServiceOperatingState.RUNNING:
|
||||
self.sys_log.warning(f"{self.name}: Cannot send commands when Operating state is {self.operating_state}!")
|
||||
return False
|
||||
|
||||
self.sys_log.debug(f"{self.name}: Sending payload: {payload}")
|
||||
return super().send(
|
||||
payload=payload, dest_ip_address=dest_ip_address, dest_port=self.port, session_id=session_id
|
||||
)
|
||||
|
||||
@@ -1 +1,109 @@
|
||||
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
|
||||
from typing import Tuple
|
||||
|
||||
import pytest
|
||||
|
||||
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
|
||||
from primaite.simulator.network.hardware.nodes.host.computer import Computer
|
||||
from primaite.simulator.system.applications.application import Application, ApplicationOperatingState
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def populated_node(application_class) -> Tuple[Application, Computer]:
|
||||
computer: Computer = Computer(
|
||||
hostname="test_computer",
|
||||
ip_address="192.168.1.2",
|
||||
subnet_mask="255.255.255.0",
|
||||
default_gateway="192.168.1.1",
|
||||
start_up_duration=0,
|
||||
shut_down_duration=0,
|
||||
)
|
||||
computer.power_on()
|
||||
computer.software_manager.install(application_class)
|
||||
|
||||
app = computer.software_manager.software.get("DummyApplication")
|
||||
app.run()
|
||||
|
||||
return app, computer
|
||||
|
||||
|
||||
def test_application_on_offline_node(application_class):
|
||||
"""Test to check that the application cannot be interacted with when node it is on is off."""
|
||||
computer: Computer = Computer(
|
||||
hostname="test_computer",
|
||||
ip_address="192.168.1.2",
|
||||
subnet_mask="255.255.255.0",
|
||||
default_gateway="192.168.1.1",
|
||||
start_up_duration=0,
|
||||
shut_down_duration=0,
|
||||
)
|
||||
computer.software_manager.install(application_class)
|
||||
|
||||
app: Application = computer.software_manager.software.get("DummyApplication")
|
||||
|
||||
computer.power_off()
|
||||
|
||||
assert computer.operating_state is NodeOperatingState.OFF
|
||||
assert app.operating_state is ApplicationOperatingState.CLOSED
|
||||
|
||||
app.run()
|
||||
assert app.operating_state is ApplicationOperatingState.CLOSED
|
||||
|
||||
|
||||
def test_server_turns_off_application(populated_node):
|
||||
"""Check that the application is turned off when the server is turned off"""
|
||||
app, computer = populated_node
|
||||
|
||||
assert computer.operating_state is NodeOperatingState.ON
|
||||
assert app.operating_state is ApplicationOperatingState.RUNNING
|
||||
|
||||
computer.power_off()
|
||||
|
||||
assert computer.operating_state is NodeOperatingState.OFF
|
||||
assert app.operating_state is ApplicationOperatingState.CLOSED
|
||||
|
||||
|
||||
def test_application_cannot_be_turned_on_when_computer_is_off(populated_node):
|
||||
"""Check that the application cannot be started when the computer is off."""
|
||||
app, computer = populated_node
|
||||
|
||||
assert computer.operating_state is NodeOperatingState.ON
|
||||
assert app.operating_state is ApplicationOperatingState.RUNNING
|
||||
|
||||
computer.power_off()
|
||||
|
||||
assert computer.operating_state is NodeOperatingState.OFF
|
||||
assert app.operating_state is ApplicationOperatingState.CLOSED
|
||||
|
||||
app.run()
|
||||
|
||||
assert computer.operating_state is NodeOperatingState.OFF
|
||||
assert app.operating_state is ApplicationOperatingState.CLOSED
|
||||
|
||||
|
||||
def test_computer_runs_applications(populated_node):
|
||||
"""Check that turning on the computer will turn on applications."""
|
||||
app, computer = populated_node
|
||||
|
||||
assert computer.operating_state is NodeOperatingState.ON
|
||||
assert app.operating_state is ApplicationOperatingState.RUNNING
|
||||
|
||||
computer.power_off()
|
||||
|
||||
assert computer.operating_state is NodeOperatingState.OFF
|
||||
assert app.operating_state is ApplicationOperatingState.CLOSED
|
||||
|
||||
computer.power_on()
|
||||
|
||||
assert computer.operating_state is NodeOperatingState.ON
|
||||
assert app.operating_state is ApplicationOperatingState.RUNNING
|
||||
|
||||
computer.power_off()
|
||||
|
||||
assert computer.operating_state is NodeOperatingState.OFF
|
||||
assert app.operating_state is ApplicationOperatingState.CLOSED
|
||||
|
||||
computer.power_on()
|
||||
|
||||
assert computer.operating_state is NodeOperatingState.ON
|
||||
assert app.operating_state is ApplicationOperatingState.RUNNING
|
||||
|
||||
Reference in New Issue
Block a user