Backport 3.3.1 fixes into Core

This commit is contained in:
Marek Wolan
2025-01-21 13:08:36 +00:00
parent 4b79c88ae5
commit 66daab3baf
57 changed files with 441 additions and 247 deletions

View File

@@ -161,8 +161,8 @@ agents:
wildcard_list:
- 0.0.0.1
port_list:
- 80
- 5432
- HTTP
- POSTGRES_SERVER
protocol_list:
- ICMP
- TCP

View File

@@ -153,8 +153,8 @@ agents:
wildcard_list:
- 0.0.0.1
port_list:
- 80
- 5432
- HTTP
- POSTGRES_SERVER
protocol_list:
- ICMP
- TCP
@@ -668,8 +668,8 @@ agents:
wildcard_list:
- 0.0.0.1
port_list:
- 80
- 5432
- HTTP
- POSTGRES_SERVER
protocol_list:
- ICMP
- TCP

View File

@@ -93,7 +93,7 @@ class AgentLog:
def _write_to_terminal(self, msg: str, level: str, to_terminal: bool = False):
if to_terminal or SIM_OUTPUT.write_agent_log_to_terminal:
print(f"{self.agent_name}: ({ self.timestep}) ({level}) {msg}")
print(f"{self.agent_name}: ({self.timestep}) ({level}) {msg}")
def debug(self, msg: str, to_terminal: bool = False):
"""

View File

@@ -24,8 +24,8 @@ class ACLObservation(AbstractObservation, identifier="ACL"):
"""List of IP addresses."""
wildcard_list: Optional[List[str]] = None
"""List of wildcard strings."""
port_list: Optional[List[int]] = None
"""List of port numbers."""
port_list: Optional[List[str]] = None
"""List of port names."""
protocol_list: Optional[List[str]] = None
"""List of protocol names."""
num_rules: Optional[int] = None
@@ -37,7 +37,7 @@ class ACLObservation(AbstractObservation, identifier="ACL"):
num_rules: int,
ip_list: List[IPv4Address],
wildcard_list: List[str],
port_list: List[int],
port_list: List[str],
protocol_list: List[str],
) -> None:
"""
@@ -51,8 +51,8 @@ class ACLObservation(AbstractObservation, identifier="ACL"):
:type ip_list: List[IPv4Address]
:param wildcard_list: List of wildcard strings.
:type wildcard_list: List[str]
:param port_list: List of port numbers.
:type port_list: List[int]
:param port_list: List of port names.
:type port_list: List[str]
:param protocol_list: List of protocol names.
:type protocol_list: List[str]
"""
@@ -60,7 +60,7 @@ class ACLObservation(AbstractObservation, identifier="ACL"):
self.num_rules: int = num_rules
self.ip_to_id: Dict[str, int] = {p: i + 2 for i, p in enumerate(ip_list)}
self.wildcard_to_id: Dict[str, int] = {p: i + 2 for i, p in enumerate(wildcard_list)}
self.port_to_id: Dict[int, int] = {p: i + 2 for i, p in enumerate(port_list)}
self.port_to_id: Dict[str, int] = {p: i + 2 for i, p in enumerate(port_list)}
self.protocol_to_id: Dict[str, int] = {p: i + 2 for i, p in enumerate(protocol_list)}
self.default_observation: Dict = {
i

View File

@@ -190,6 +190,8 @@ class FolderObservation(AbstractObservation, identifier="FOLDER"):
if self.files:
self.default_observation["FILES"] = {i + 1: f.default_observation for i, f in enumerate(self.files)}
self.cached_obs: Optional[ObsType] = self.default_observation
def observe(self, state: Dict) -> ObsType:
"""
Generate observation based on the current state of the simulation.
@@ -204,7 +206,10 @@ class FolderObservation(AbstractObservation, identifier="FOLDER"):
return self.default_observation
if self.file_system_requires_scan:
health_status = folder_state["visible_status"]
if not folder_state["scanned_this_step"]:
health_status = self.cached_obs["health_status"]
else:
health_status = folder_state["visible_status"]
else:
health_status = folder_state["health_status"]

View File

@@ -27,13 +27,13 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"):
"""List of IP addresses for encoding ACLs."""
wildcard_list: Optional[List[str]] = None
"""List of IP wildcards for encoding ACLs."""
port_list: Optional[List[int]] = None
port_list: Optional[List[str]] = None
"""List of ports for encoding ACLs."""
protocol_list: Optional[List[str]] = None
"""List of protocols for encoding ACLs."""
num_rules: Optional[int] = None
"""Number of rules ACL rules to show."""
include_users: Optional[bool] = True
include_users: Optional[bool] = None
"""If True, report user session information."""
def __init__(
@@ -41,7 +41,7 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"):
where: WhereType,
ip_list: List[str],
wildcard_list: List[str],
port_list: List[int],
port_list: List[str],
protocol_list: List[str],
num_rules: int,
include_users: bool,
@@ -56,8 +56,8 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"):
:type ip_list: List[str]
:param wildcard_list: List of wildcard rules.
:type wildcard_list: List[str]
:param port_list: List of port numbers.
:type port_list: List[int]
:param port_list: List of port names.
:type port_list: List[str]
:param protocol_list: List of protocol types.
:type protocol_list: List[str]
:param num_rules: Number of rules configured in the firewall.
@@ -140,6 +140,8 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"):
},
},
}
if self.include_users:
self.default_observation["users"] = {"local_login": 0, "remote_sessions": 0}
def observe(self, state: Dict) -> ObsType:
"""
@@ -153,29 +155,35 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"):
firewall_state = access_from_nested_dict(state, self.where)
if firewall_state is NOT_PRESENT_IN_STATE:
return self.default_observation
obs = {
"PORTS": {i + 1: p.observe(state) for i, p in enumerate(self.ports)},
"ACL": {
"INTERNAL": {
"INBOUND": self.internal_inbound_acl.observe(state),
"OUTBOUND": self.internal_outbound_acl.observe(state),
is_on = firewall_state["operating_state"] == 1
if not is_on:
obs = {**self.default_observation}
else:
obs = {
"PORTS": {i + 1: p.observe(state) for i, p in enumerate(self.ports)},
"ACL": {
"INTERNAL": {
"INBOUND": self.internal_inbound_acl.observe(state),
"OUTBOUND": self.internal_outbound_acl.observe(state),
},
"DMZ": {
"INBOUND": self.dmz_inbound_acl.observe(state),
"OUTBOUND": self.dmz_outbound_acl.observe(state),
},
"EXTERNAL": {
"INBOUND": self.external_inbound_acl.observe(state),
"OUTBOUND": self.external_outbound_acl.observe(state),
},
},
"DMZ": {
"INBOUND": self.dmz_inbound_acl.observe(state),
"OUTBOUND": self.dmz_outbound_acl.observe(state),
},
"EXTERNAL": {
"INBOUND": self.external_inbound_acl.observe(state),
"OUTBOUND": self.external_outbound_acl.observe(state),
},
},
}
if self.include_users:
sess = firewall_state["services"]["UserSessionManager"]
obs["users"] = {
"local_login": 1 if sess["current_local_user"] else 0,
"remote_sessions": min(self.max_users, len(sess["active_remote_sessions"])),
}
if self.include_users:
sess = firewall_state["services"]["UserSessionManager"]
obs["users"] = {
"local_login": 1 if sess["current_local_user"] else 0,
"remote_sessions": min(self.max_users, len(sess["active_remote_sessions"])),
}
return obs
@property
@@ -186,34 +194,36 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"):
:return: Gymnasium space representing the observation space for firewall status.
:rtype: spaces.Space
"""
space = spaces.Dict(
{
"PORTS": spaces.Dict({i + 1: p.space for i, p in enumerate(self.ports)}),
"ACL": spaces.Dict(
{
"INTERNAL": spaces.Dict(
{
"INBOUND": self.internal_inbound_acl.space,
"OUTBOUND": self.internal_outbound_acl.space,
}
),
"DMZ": spaces.Dict(
{
"INBOUND": self.dmz_inbound_acl.space,
"OUTBOUND": self.dmz_outbound_acl.space,
}
),
"EXTERNAL": spaces.Dict(
{
"INBOUND": self.external_inbound_acl.space,
"OUTBOUND": self.external_outbound_acl.space,
}
),
}
),
}
)
return space
shape = {
"PORTS": spaces.Dict({i + 1: p.space for i, p in enumerate(self.ports)}),
"ACL": spaces.Dict(
{
"INTERNAL": spaces.Dict(
{
"INBOUND": self.internal_inbound_acl.space,
"OUTBOUND": self.internal_outbound_acl.space,
}
),
"DMZ": spaces.Dict(
{
"INBOUND": self.dmz_inbound_acl.space,
"OUTBOUND": self.dmz_outbound_acl.space,
}
),
"EXTERNAL": spaces.Dict(
{
"INBOUND": self.external_inbound_acl.space,
"OUTBOUND": self.external_outbound_acl.space,
}
),
}
),
}
if self.include_users:
shape["users"] = spaces.Dict(
{"local_login": spaces.Discrete(2), "remote_sessions": spaces.Discrete(self.max_users + 1)}
)
return spaces.Dict(shape)
@classmethod
def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> FirewallObservation:

View File

@@ -54,7 +54,7 @@ class HostObservation(AbstractObservation, identifier="HOST"):
"""
If True, files and folders must be scanned to update the health state. If False, true state is always shown.
"""
include_users: Optional[bool] = True
include_users: Optional[bool] = None
"""If True, report user session information."""
def __init__(
@@ -191,25 +191,31 @@ class HostObservation(AbstractObservation, identifier="HOST"):
if node_state is NOT_PRESENT_IN_STATE:
return self.default_observation
obs = {}
is_on = node_state["operating_state"] == 1
if not is_on:
obs = {**self.default_observation}
else:
obs = {}
if self.services:
obs["SERVICES"] = {i + 1: service.observe(state) for i, service in enumerate(self.services)}
if self.applications:
obs["APPLICATIONS"] = {i + 1: app.observe(state) for i, app in enumerate(self.applications)}
if self.folders:
obs["FOLDERS"] = {i + 1: folder.observe(state) for i, folder in enumerate(self.folders)}
if self.nics:
obs["NICS"] = {i + 1: nic.observe(state) for i, nic in enumerate(self.nics)}
if self.include_num_access:
obs["num_file_creations"] = node_state["file_system"]["num_file_creations"]
obs["num_file_deletions"] = node_state["file_system"]["num_file_deletions"]
if self.include_users:
sess = node_state["services"]["UserSessionManager"]
obs["users"] = {
"local_login": 1 if sess["current_local_user"] else 0,
"remote_sessions": min(self.max_users, len(sess["active_remote_sessions"])),
}
obs["operating_status"] = node_state["operating_state"]
if self.services:
obs["SERVICES"] = {i + 1: service.observe(state) for i, service in enumerate(self.services)}
if self.applications:
obs["APPLICATIONS"] = {i + 1: app.observe(state) for i, app in enumerate(self.applications)}
if self.folders:
obs["FOLDERS"] = {i + 1: folder.observe(state) for i, folder in enumerate(self.folders)}
if self.nics:
obs["NICS"] = {i + 1: nic.observe(state) for i, nic in enumerate(self.nics)}
if self.include_num_access:
obs["num_file_creations"] = node_state["file_system"]["num_file_creations"]
obs["num_file_deletions"] = node_state["file_system"]["num_file_deletions"]
if self.include_users:
sess = node_state["services"]["UserSessionManager"]
obs["users"] = {
"local_login": 1 if sess["current_local_user"] else 0,
"remote_sessions": min(self.max_users, len(sess["active_remote_sessions"])),
}
return obs
@property

View File

@@ -56,7 +56,7 @@ class NodesObservation(AbstractObservation, identifier="NODES"):
"""List of IP addresses for encoding ACLs."""
wildcard_list: Optional[List[str]] = None
"""List of IP wildcards for encoding ACLs."""
port_list: Optional[List[int]] = None
port_list: Optional[List[str]] = None
"""List of ports for encoding ACLs."""
protocol_list: Optional[List[str]] = None
"""List of protocols for encoding ACLs."""

View File

@@ -33,13 +33,13 @@ class RouterObservation(AbstractObservation, identifier="ROUTER"):
"""List of IP addresses for encoding ACLs."""
wildcard_list: Optional[List[str]] = None
"""List of IP wildcards for encoding ACLs."""
port_list: Optional[List[int]] = None
port_list: Optional[List[str]] = None
"""List of ports for encoding ACLs."""
protocol_list: Optional[List[str]] = None
"""List of protocols for encoding ACLs."""
num_rules: Optional[int] = None
"""Number of rules ACL rules to show."""
include_users: Optional[bool] = True
include_users: Optional[bool] = None
"""If True, report user session information."""
def __init__(
@@ -84,6 +84,8 @@ class RouterObservation(AbstractObservation, identifier="ROUTER"):
}
if self.ports:
self.default_observation["PORTS"] = {i + 1: p.default_observation for i, p in enumerate(self.ports)}
if self.include_users:
self.default_observation["users"] = {"local_login": 0, "remote_sessions": 0}
def observe(self, state: Dict) -> ObsType:
"""
@@ -98,16 +100,21 @@ class RouterObservation(AbstractObservation, identifier="ROUTER"):
if router_state is NOT_PRESENT_IN_STATE:
return self.default_observation
obs = {}
obs["ACL"] = self.acl.observe(state)
if self.ports:
obs["PORTS"] = {i + 1: p.observe(state) for i, p in enumerate(self.ports)}
if self.include_users:
sess = router_state["services"]["UserSessionManager"]
obs["users"] = {
"local_login": 1 if sess["current_local_user"] else 0,
"remote_sessions": min(self.max_users, len(sess["active_remote_sessions"])),
}
is_on = router_state["operating_state"] == 1
if not is_on:
obs = {**self.default_observation}
else:
obs = {}
obs["ACL"] = self.acl.observe(state)
if self.ports:
obs["PORTS"] = {i + 1: p.observe(state) for i, p in enumerate(self.ports)}
if self.include_users:
sess = router_state["services"]["UserSessionManager"]
obs["users"] = {
"local_login": 1 if sess["current_local_user"] else 0,
"remote_sessions": min(self.max_users, len(sess["active_remote_sessions"])),
}
return obs
@property
@@ -121,6 +128,10 @@ class RouterObservation(AbstractObservation, identifier="ROUTER"):
shape = {"ACL": self.acl.space}
if self.ports:
shape["PORTS"] = spaces.Dict({i + 1: p.space for i, p in enumerate(self.ports)})
if self.include_users:
shape["users"] = spaces.Dict(
{"local_login": spaces.Discrete(2), "remote_sessions": spaces.Discrete(self.max_users + 1)}
)
return spaces.Dict(shape)
@classmethod

View File

@@ -258,6 +258,7 @@ class PrimaiteGame:
net = sim.network
simulation_config = cfg.get("simulation", {})
defaults_config = cfg.get("defaults", {})
network_config = simulation_config.get("network", {})
airspace_cfg = network_config.get("airspace", {})
frequency_max_capacity_mbps_cfg = airspace_cfg.get("frequency_max_capacity_mbps", {})
@@ -338,6 +339,18 @@ class PrimaiteGame:
_LOGGER.error(msg)
raise ValueError(msg)
# TODO: handle simulation defaults more cleanly
if "node_start_up_duration" in defaults_config:
new_node.start_up_duration = defaults_config["node_startup_duration"]
if "node_shut_down_duration" in defaults_config:
new_node.shut_down_duration = defaults_config["node_shut_down_duration"]
if "node_scan_duration" in defaults_config:
new_node.node_scan_duration = defaults_config["node_scan_duration"]
if "folder_scan_duration" in defaults_config:
new_node.file_system._default_folder_scan_duration = defaults_config["folder_scan_duration"]
if "folder_restore_duration" in defaults_config:
new_node.file_system._default_folder_restore_duration = defaults_config["folder_restore_duration"]
if "users" in node_cfg and new_node.software_manager.software.get("UserManager"):
user_manager: UserManager = new_node.software_manager.software["UserManager"] # noqa
for user_cfg in node_cfg["users"]:
@@ -384,6 +397,15 @@ class PrimaiteGame:
msg = f"Configuration contains an invalid service type: {service_type}"
_LOGGER.error(msg)
raise ValueError(msg)
# TODO: handle simulation defaults more cleanly
if "service_fix_duration" in defaults_config:
new_service.fixing_duration = defaults_config["service_fix_duration"]
if "service_restart_duration" in defaults_config:
new_service.restart_duration = defaults_config["service_restart_duration"]
if "service_install_duration" in defaults_config:
new_service.install_duration = defaults_config["service_install_duration"]
# service-dependent options
if service_type == "DNSClient":
if "options" in service_cfg:

View File

@@ -11,6 +11,15 @@
"PrimAITE environments support action masking. The action mask shows which of the agent's actions are applicable with the current environment state. For example, a node can only be turned on if it is currently turned off."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!primaite setup"
]
},
{
"cell_type": "code",
"execution_count": null,

View File

@@ -15,6 +15,15 @@
"*(For a full explanation of the Data Manipulation scenario, check out the data manipulation scenario notebook)*"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!primaite setup"
]
},
{
"cell_type": "code",
"execution_count": null,

View File

@@ -371,6 +371,15 @@
"First, load the required modules"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!primaite setup"
]
},
{
"cell_type": "code",
"execution_count": null,

View File

@@ -9,6 +9,15 @@
"© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!primaite setup"
]
},
{
"cell_type": "code",
"execution_count": null,

View File

@@ -25,6 +25,15 @@
"Let's set up a minimal network simulation and send some requests to see how it works."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!primaite setup"
]
},
{
"cell_type": "code",
"execution_count": null,

View File

@@ -18,6 +18,15 @@
"The Terminal service comes pre-installed on most Nodes (The exception being Switches, as these are currently dumb). "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!primaite setup"
]
},
{
"cell_type": "code",
"execution_count": null,

View File

@@ -18,6 +18,15 @@
"#### First, Import packages and read our config file."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!primaite setup"
]
},
{
"cell_type": "code",
"execution_count": null,
@@ -32,8 +41,6 @@
"from ray.rllib.algorithms.ppo import PPOConfig\n",
"from primaite.session.ray_envs import PrimaiteRayMARLEnv\n",
"\n",
"# If you get an error saying this config file doesn't exist, you may need to run `primaite setup` in your command line\n",
"# to copy the files to your user data path.\n",
"with open(PRIMAITE_PATHS.user_config_path / 'example_config/data_manipulation_marl.yaml', 'r') as f:\n",
" cfg = yaml.safe_load(f)\n",
"\n",

View File

@@ -11,6 +11,15 @@
"This notebook will demonstrate how to use PrimaiteRayEnv to train a basic PPO agent."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!primaite setup"
]
},
{
"cell_type": "code",
"execution_count": null,

View File

@@ -18,6 +18,15 @@
"#### First, we import the inital packages and read in our configuration file."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!primaite setup"
]
},
{
"cell_type": "code",
"execution_count": null,

Binary file not shown.

Before

Width:  |  Height:  |  Size: 110 KiB

After

Width:  |  Height:  |  Size: 110 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 69 KiB

After

Width:  |  Height:  |  Size: 69 KiB

View File

@@ -18,6 +18,15 @@
"Import packages and read config file."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!primaite setup"
]
},
{
"cell_type": "code",
"execution_count": null,

View File

@@ -30,6 +30,11 @@ class FileSystem(SimComponent):
num_file_deletions: int = 0
"Number of file deletions in the current step."
_default_folder_scan_duration: Optional[int] = None
"Override default scan duration for folders"
_default_folder_restore_duration: Optional[int] = None
"Override default restore duration for folders"
def __init__(self, **kwargs):
super().__init__(**kwargs)
# Ensure a default root folder
@@ -258,6 +263,11 @@ class FileSystem(SimComponent):
name=folder.name, request_type=RequestType(func=folder._request_manager)
)
self.folders[folder.uuid] = folder
# set the folder scan and restore durations.
if self._default_folder_scan_duration is not None:
folder.scan_duration = self._default_folder_scan_duration
if self._default_folder_restore_duration is not None:
folder.restore_duration = self._default_folder_restore_duration
return folder
def delete_folder(self, folder_name: str) -> bool:

View File

@@ -43,6 +43,9 @@ def convert_size(size_bytes: int) -> str:
class FileSystemItemHealthStatus(Enum):
"""Status of the FileSystemItem."""
NONE = 0
"""File system item health status is not known."""
GOOD = 1
"""File/Folder is OK."""
@@ -72,7 +75,7 @@ class FileSystemItemABC(SimComponent):
health_status: FileSystemItemHealthStatus = FileSystemItemHealthStatus.GOOD
"Actual status of the current FileSystemItem"
visible_health_status: FileSystemItemHealthStatus = FileSystemItemHealthStatus.GOOD
visible_health_status: FileSystemItemHealthStatus = FileSystemItemHealthStatus.NONE
"Visible status of the current FileSystemItem"
previous_hash: Optional[str] = None

View File

@@ -46,7 +46,7 @@ class Folder(FileSystemItemABC):
:param sys_log: The SysLog instance to us to create system logs.
"""
super().__init__(**kwargs)
self._scanned_this_step: bool = False
self.sys_log.info(f"Created file /{self.name} (id: {self.uuid})")
def _init_request_manager(self) -> RequestManager:
@@ -83,6 +83,7 @@ class Folder(FileSystemItemABC):
state = super().describe_state()
state["files"] = {file.name: file.describe_state() for uuid, file in self.files.items()}
state["deleted_files"] = {file.name: file.describe_state() for uuid, file in self.deleted_files.items()}
state["scanned_this_step"] = self._scanned_this_step
return state
def show(self, markdown: bool = False):
@@ -135,7 +136,7 @@ class Folder(FileSystemItemABC):
def pre_timestep(self, timestep: int) -> None:
"""Apply pre-timestep logic."""
super().pre_timestep(timestep)
self._scanned_this_step = False
for file in self.files.values():
file.pre_timestep(timestep)
@@ -148,9 +149,17 @@ class Folder(FileSystemItemABC):
for file_id in self.files:
file = self.get_file_by_id(file_uuid=file_id)
file.scan()
if file.visible_health_status == FileSystemItemHealthStatus.CORRUPT:
self.health_status = FileSystemItemHealthStatus.CORRUPT
# set folder health to worst file's health by generating a list of file healths. If no files, use 0
self.health_status = FileSystemItemHealthStatus(
max(
[f.health_status.value for f in self.files.values()]
or [
0,
]
)
)
self.visible_health_status = self.health_status
self._scanned_this_step = True
def _reveal_to_red_timestep(self) -> None:
"""Apply reveal to red timestep."""

View File

@@ -118,6 +118,7 @@ class FTPClient(FTPServiceABC, identifier="FTPClient"):
session_id: Optional[str] = None,
is_reattempt: Optional[bool] = False,
) -> bool:
self._active = True
"""
Connects the client to a given FTP server.
@@ -174,6 +175,7 @@ class FTPClient(FTPServiceABC, identifier="FTPClient"):
:param: is_reattempt: Set to True if attempt to disconnect from FTP Server has been attempted. Default False.
:type: is_reattempt: Optional[bool]
"""
self._active = True
# send a disconnect request payload to FTP server
payload: FTPPacket = FTPPacket(ftp_command=FTPCommand.QUIT)
software_manager: SoftwareManager = self.software_manager
@@ -219,6 +221,7 @@ class FTPClient(FTPServiceABC, identifier="FTPClient"):
:param: session_id: The id of the session
:type: session_id: Optional[str]
"""
self._active = True
# check if the file to transfer exists on the client
file_to_transfer: File = self.file_system.get_file(folder_name=src_folder_name, file_name=src_file_name)
if not file_to_transfer:
@@ -276,6 +279,7 @@ class FTPClient(FTPServiceABC, identifier="FTPClient"):
:param: dest_port: The open port of the machine that hosts the FTP Server. Default is Port["FTP"].
:type: dest_port: Optional[int]
"""
self._active = True
# check if FTP is currently connected to IP
self._connect_to_server(dest_ip_address=dest_ip_address, dest_port=dest_port)
@@ -327,6 +331,7 @@ class FTPClient(FTPServiceABC, identifier="FTPClient"):
This helps prevent an FTP request loop - FTP client and servers can exist on
the same node.
"""
self._active = True
if not self._can_perform_action():
return False

View File

@@ -3,9 +3,11 @@ from abc import ABC
from ipaddress import IPv4Address
from typing import Dict, Optional
from pydantic import StrictBool
from primaite.simulator.file_system.file_system import File
from primaite.simulator.network.protocols.ftp import FTPCommand, FTPPacket, FTPStatusCode
from primaite.simulator.system.services.service import Service
from primaite.simulator.system.services.service import Service, ServiceOperatingState
from primaite.utils.validation.port import Port
@@ -16,9 +18,22 @@ class FTPServiceABC(Service, ABC):
Contains shared methods between both classes.
"""
_active: StrictBool = False
"""Flag that is True on timesteps where service transmits data and False when idle. Used for describe_state."""
def pre_timestep(self, timestep: int) -> None:
"""When a new timestep begins, clear the _active attribute."""
self._active = False
return super().pre_timestep(timestep)
def describe_state(self) -> Dict:
"""Returns a Dict of the FTPService state."""
return super().describe_state()
state = super().describe_state()
# override so that the service is shows as running only if actively transmitting data this timestep
if self.operating_state == ServiceOperatingState.RUNNING and not self._active:
state["operating_state"] = ServiceOperatingState.STOPPED.value
return state
def _process_ftp_command(self, payload: FTPPacket, session_id: Optional[str] = None, **kwargs) -> FTPPacket:
"""
@@ -29,6 +44,7 @@ class FTPServiceABC(Service, ABC):
:param: session_id: session ID linked to the FTP Packet. Optional.
:type: session_id: Optional[str]
"""
self._active = True
if payload.ftp_command is not None:
self.sys_log.info(f"Received FTP {payload.ftp_command.name} command.")
@@ -51,6 +67,7 @@ class FTPServiceABC(Service, ABC):
:param: payload: The FTP Packet that contains the file data
:type: FTPPacket
"""
self._active = True
try:
file_name = payload.ftp_command_args["dest_file_name"]
folder_name = payload.ftp_command_args["dest_folder_name"]
@@ -106,6 +123,7 @@ class FTPServiceABC(Service, ABC):
:param: is_response: is true if the data being sent is in response to a request. Default False.
:type: is_response: bool
"""
self._active = True
# send STOR request
payload: FTPPacket = FTPPacket(
ftp_command=FTPCommand.STOR,
@@ -135,6 +153,7 @@ class FTPServiceABC(Service, ABC):
:param: payload: The FTP Packet that contains the file data
:type: FTPPacket
"""
self._active = True
try:
# find the file
file_name = payload.ftp_command_args["src_file_name"]
@@ -181,6 +200,7 @@ class FTPServiceABC(Service, ABC):
:return: True if successful, False otherwise.
"""
self._active = True
self.sys_log.info(f"{self.name}: Sending FTP {payload.ftp_command.name} {payload.ftp_command_args}")
return super().send(