Merge branch 'dev' into feature/2777_set_RNG_seed
This commit is contained in:
@@ -129,6 +129,10 @@ agents:
|
||||
|
||||
simulation:
|
||||
network:
|
||||
nmne_config:
|
||||
capture_nmne: true
|
||||
nmne_capture_keywords:
|
||||
- DELETE
|
||||
nodes:
|
||||
- hostname: client
|
||||
type: computer
|
||||
|
||||
@@ -294,7 +294,7 @@ class ConfigureDoSBotAction(AbstractAction):
|
||||
"""Action which sets config parameters for a DoS bot on a node."""
|
||||
|
||||
class _Opts(BaseModel):
|
||||
"""Schema for options that can be passed to this option."""
|
||||
"""Schema for options that can be passed to this action."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
target_ip_address: Optional[str] = None
|
||||
|
||||
@@ -18,7 +18,7 @@ from primaite.game.agent.scripted_agents.tap001 import TAP001
|
||||
from primaite.game.science import graph_has_cycle, topological_sort
|
||||
from primaite.simulator import SIM_OUTPUT
|
||||
from primaite.simulator.network.airspace import AirSpaceFrequency
|
||||
from primaite.simulator.network.hardware.base import NodeOperatingState
|
||||
from primaite.simulator.network.hardware.base import NetworkInterface, NodeOperatingState, UserManager
|
||||
from primaite.simulator.network.hardware.nodes.host.computer import Computer
|
||||
from primaite.simulator.network.hardware.nodes.host.host_node import NIC
|
||||
from primaite.simulator.network.hardware.nodes.host.server import Printer, Server
|
||||
@@ -26,7 +26,7 @@ from primaite.simulator.network.hardware.nodes.network.firewall import Firewall
|
||||
from primaite.simulator.network.hardware.nodes.network.router import Router
|
||||
from primaite.simulator.network.hardware.nodes.network.switch import Switch
|
||||
from primaite.simulator.network.hardware.nodes.network.wireless_router import WirelessRouter
|
||||
from primaite.simulator.network.nmne import set_nmne_config
|
||||
from primaite.simulator.network.nmne import NMNEConfig
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.sim_container import Simulation
|
||||
from primaite.simulator.system.applications.application import Application
|
||||
@@ -266,9 +266,12 @@ class PrimaiteGame:
|
||||
|
||||
nodes_cfg = network_config.get("nodes", [])
|
||||
links_cfg = network_config.get("links", [])
|
||||
# Set the NMNE capture config
|
||||
NetworkInterface.nmne_config = NMNEConfig(**network_config.get("nmne_config", {}))
|
||||
|
||||
for node_cfg in nodes_cfg:
|
||||
n_type = node_cfg["type"]
|
||||
new_node = None
|
||||
if n_type == "computer":
|
||||
new_node = Computer(
|
||||
hostname=node_cfg["hostname"],
|
||||
@@ -318,6 +321,11 @@ class PrimaiteGame:
|
||||
msg = f"invalid node type {n_type} in config"
|
||||
_LOGGER.error(msg)
|
||||
raise ValueError(msg)
|
||||
|
||||
if "users" in node_cfg and new_node.software_manager.software.get("UserManager"):
|
||||
user_manager: UserManager = new_node.software_manager.software["UserManager"] # noqa
|
||||
for user_cfg in node_cfg["users"]:
|
||||
user_manager.add_user(**user_cfg, bypass_can_perform_action=True)
|
||||
if "services" in node_cfg:
|
||||
for service_cfg in node_cfg["services"]:
|
||||
new_service = None
|
||||
@@ -535,10 +543,7 @@ class PrimaiteGame:
|
||||
# Validate that if any agents are sharing rewards, they aren't forming an infinite loop.
|
||||
game.setup_reward_sharing()
|
||||
|
||||
# Set the NMNE capture config
|
||||
set_nmne_config(network_config.get("nmne_config", {}))
|
||||
game.update_agents(game.get_sim_state())
|
||||
|
||||
return game
|
||||
|
||||
def setup_reward_sharing(self):
|
||||
|
||||
@@ -101,7 +101,6 @@
|
||||
"from primaite.session.ray_envs import PrimaiteRayEnv\n",
|
||||
"from ray.rllib.algorithms.ppo import PPOConfig\n",
|
||||
"import yaml\n",
|
||||
"from ray import air, tune\n",
|
||||
"from ray.rllib.examples.rl_modules.classes.action_masking_rlm import ActionMaskingTorchRLModule\n",
|
||||
"from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec\n"
|
||||
]
|
||||
@@ -135,8 +134,7 @@
|
||||
" .training(train_batch_size=128)\n",
|
||||
")\n",
|
||||
"algo = config.build()\n",
|
||||
"for i in range(2):\n",
|
||||
" results = algo.train()"
|
||||
"results = algo.train()"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -191,8 +189,7 @@
|
||||
" .training(train_batch_size=128)\n",
|
||||
")\n",
|
||||
"algo = config.build()\n",
|
||||
"for i in range(2):\n",
|
||||
" results = algo.train()"
|
||||
"results = algo.train()"
|
||||
]
|
||||
}
|
||||
],
|
||||
@@ -212,7 +209,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.8"
|
||||
"version": "3.10.12"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@@ -24,14 +24,11 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from primaite.game.game import PrimaiteGame\n",
|
||||
"import yaml\n",
|
||||
"\n",
|
||||
"from primaite.session.ray_envs import PrimaiteRayEnv\n",
|
||||
"from primaite import PRIMAITE_PATHS\n",
|
||||
"\n",
|
||||
"import ray\n",
|
||||
"from ray import air, tune\n",
|
||||
"from ray.rllib.algorithms.ppo import PPOConfig\n",
|
||||
"from primaite.session.ray_envs import PrimaiteRayMARLEnv\n",
|
||||
"\n",
|
||||
@@ -72,7 +69,7 @@
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### Set training parameters and start the training\n",
|
||||
"#### Start the training\n",
|
||||
"This example will save outputs to a default Ray directory and use mostly default settings."
|
||||
]
|
||||
},
|
||||
@@ -82,13 +79,8 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"tune.Tuner(\n",
|
||||
" \"PPO\",\n",
|
||||
" run_config=air.RunConfig(\n",
|
||||
" stop={\"timesteps_total\": 5 * 128},\n",
|
||||
" ),\n",
|
||||
" param_space=config\n",
|
||||
").fit()"
|
||||
"algo = config.build()\n",
|
||||
"results = algo.train()"
|
||||
]
|
||||
}
|
||||
],
|
||||
|
||||
@@ -17,12 +17,10 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from primaite.game.game import PrimaiteGame\n",
|
||||
"import yaml\n",
|
||||
"from primaite.config.load import data_manipulation_config_path\n",
|
||||
"\n",
|
||||
"from primaite.session.ray_envs import PrimaiteRayEnv\n",
|
||||
"from ray import air, tune\n",
|
||||
"import ray\n",
|
||||
"from ray.rllib.algorithms.ppo import PPOConfig\n",
|
||||
"\n",
|
||||
@@ -64,7 +62,7 @@
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### Set training parameters and start the training"
|
||||
"#### Start the training"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -73,13 +71,8 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"tune.Tuner(\n",
|
||||
" \"PPO\",\n",
|
||||
" run_config=air.RunConfig(\n",
|
||||
" stop={\"timesteps_total\": 512}\n",
|
||||
" ),\n",
|
||||
" param_space=config\n",
|
||||
").fit()\n"
|
||||
"algo = config.build()\n",
|
||||
"results = algo.train()\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
|
||||
@@ -6,12 +6,11 @@ import secrets
|
||||
from abc import ABC, abstractmethod
|
||||
from ipaddress import IPv4Address, IPv4Network
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional, TypeVar, Union
|
||||
from typing import Any, ClassVar, Dict, List, Optional, Type, TypeVar, Union
|
||||
|
||||
from prettytable import MARKDOWN, PrettyTable
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, validate_call
|
||||
|
||||
import primaite.simulator.network.nmne
|
||||
from primaite import getLogger
|
||||
from primaite.exceptions import NetworkError
|
||||
from primaite.interface.request import RequestResponse
|
||||
@@ -20,17 +19,10 @@ from primaite.simulator.core import RequestFormat, RequestManager, RequestPermis
|
||||
from primaite.simulator.domain.account import Account
|
||||
from primaite.simulator.file_system.file_system import FileSystem
|
||||
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
|
||||
from primaite.simulator.network.nmne import (
|
||||
CAPTURE_BY_DIRECTION,
|
||||
CAPTURE_BY_IP_ADDRESS,
|
||||
CAPTURE_BY_KEYWORD,
|
||||
CAPTURE_BY_PORT,
|
||||
CAPTURE_BY_PROTOCOL,
|
||||
CAPTURE_NMNE,
|
||||
NMNE_CAPTURE_KEYWORDS,
|
||||
)
|
||||
from primaite.simulator.network.nmne import NMNEConfig
|
||||
from primaite.simulator.network.transmission.data_link_layer import Frame
|
||||
from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.system.applications.application import Application
|
||||
from primaite.simulator.system.core.packet_capture import PacketCapture
|
||||
from primaite.simulator.system.core.session_manager import SessionManager
|
||||
@@ -38,7 +30,7 @@ from primaite.simulator.system.core.software_manager import SoftwareManager
|
||||
from primaite.simulator.system.core.sys_log import SysLog
|
||||
from primaite.simulator.system.processes.process import Process
|
||||
from primaite.simulator.system.services.service import Service
|
||||
from primaite.simulator.system.software import IOSoftware
|
||||
from primaite.simulator.system.software import IOSoftware, Software
|
||||
from primaite.utils.converters import convert_dict_enum_keys_to_enum_values
|
||||
from primaite.utils.validators import IPV4Address
|
||||
|
||||
@@ -108,8 +100,11 @@ class NetworkInterface(SimComponent, ABC):
|
||||
pcap: Optional[PacketCapture] = None
|
||||
"A PacketCapture instance for capturing and analysing packets passing through this interface."
|
||||
|
||||
nmne_config: ClassVar[NMNEConfig] = NMNEConfig()
|
||||
"A dataclass defining malicious network events to be captured."
|
||||
|
||||
nmne: Dict = Field(default_factory=lambda: {})
|
||||
"A dict containing details of the number of malicious network events captured."
|
||||
"A dict containing details of the number of malicious events captured."
|
||||
|
||||
traffic: Dict = Field(default_factory=lambda: {})
|
||||
"A dict containing details of the inbound and outbound traffic by port and protocol."
|
||||
@@ -167,8 +162,8 @@ class NetworkInterface(SimComponent, ABC):
|
||||
"enabled": self.enabled,
|
||||
}
|
||||
)
|
||||
if CAPTURE_NMNE:
|
||||
state.update({"nmne": {k: v for k, v in self.nmne.items()}})
|
||||
if self.nmne_config and self.nmne_config.capture_nmne:
|
||||
state.update({"nmne": self.nmne})
|
||||
state.update({"traffic": convert_dict_enum_keys_to_enum_values(self.traffic)})
|
||||
return state
|
||||
|
||||
@@ -201,7 +196,7 @@ class NetworkInterface(SimComponent, ABC):
|
||||
:param inbound: Boolean indicating if the frame direction is inbound. Defaults to True.
|
||||
"""
|
||||
# Exit function if NMNE capturing is disabled
|
||||
if not CAPTURE_NMNE:
|
||||
if not (self.nmne_config and self.nmne_config.capture_nmne):
|
||||
return
|
||||
|
||||
# Initialise basic frame data variables
|
||||
@@ -222,27 +217,27 @@ class NetworkInterface(SimComponent, ABC):
|
||||
frame_str = str(frame.payload)
|
||||
|
||||
# Proceed only if any NMNE keyword is present in the frame payload
|
||||
if any(keyword in frame_str for keyword in NMNE_CAPTURE_KEYWORDS):
|
||||
if any(keyword in frame_str for keyword in self.nmne_config.nmne_capture_keywords):
|
||||
# Start with the root of the NMNE capture structure
|
||||
current_level = self.nmne
|
||||
|
||||
# Update NMNE structure based on enabled settings
|
||||
if CAPTURE_BY_DIRECTION:
|
||||
if self.nmne_config.capture_by_direction:
|
||||
# Set or get the dictionary for the current direction
|
||||
current_level = current_level.setdefault("direction", {})
|
||||
current_level = current_level.setdefault(direction, {})
|
||||
|
||||
if CAPTURE_BY_IP_ADDRESS:
|
||||
if self.nmne_config.capture_by_ip_address:
|
||||
# Set or get the dictionary for the current IP address
|
||||
current_level = current_level.setdefault("ip_address", {})
|
||||
current_level = current_level.setdefault(ip_address, {})
|
||||
|
||||
if CAPTURE_BY_PROTOCOL:
|
||||
if self.nmne_config.capture_by_protocol:
|
||||
# Set or get the dictionary for the current protocol
|
||||
current_level = current_level.setdefault("protocol", {})
|
||||
current_level = current_level.setdefault(protocol, {})
|
||||
|
||||
if CAPTURE_BY_PORT:
|
||||
if self.nmne_config.capture_by_port:
|
||||
# Set or get the dictionary for the current port
|
||||
current_level = current_level.setdefault("port", {})
|
||||
current_level = current_level.setdefault(port, {})
|
||||
@@ -251,8 +246,8 @@ class NetworkInterface(SimComponent, ABC):
|
||||
keyword_level = current_level.setdefault("keywords", {})
|
||||
|
||||
# Increment the count for detected keywords in the payload
|
||||
if CAPTURE_BY_KEYWORD:
|
||||
for keyword in NMNE_CAPTURE_KEYWORDS:
|
||||
if self.nmne_config.capture_by_keyword:
|
||||
for keyword in self.nmne_config.nmne_capture_keywords:
|
||||
if keyword in frame_str:
|
||||
# Update the count for each keyword found
|
||||
keyword_level[keyword] = keyword_level.get(keyword, 0) + 1
|
||||
@@ -794,6 +789,650 @@ class Link(SimComponent):
|
||||
self.current_load = 0.0
|
||||
|
||||
|
||||
class User(SimComponent):
|
||||
"""
|
||||
Represents a user in the PrimAITE system.
|
||||
|
||||
:ivar username: The username of the user
|
||||
:ivar password: The password of the user
|
||||
:ivar disabled: Boolean flag indicating whether the user is disabled
|
||||
:ivar is_admin: Boolean flag indicating whether the user has admin privileges
|
||||
"""
|
||||
|
||||
username: str
|
||||
"""The username of the user"""
|
||||
|
||||
password: str
|
||||
"""The password of the user"""
|
||||
|
||||
disabled: bool = False
|
||||
"""Boolean flag indicating whether the user is disabled"""
|
||||
|
||||
is_admin: bool = False
|
||||
"""Boolean flag indicating whether the user has admin privileges"""
|
||||
|
||||
num_of_logins: int = 0
|
||||
"""Counts the number of the User has logged in"""
|
||||
|
||||
def describe_state(self) -> Dict:
|
||||
"""
|
||||
Returns a dictionary representing the current state of the user.
|
||||
|
||||
:return: A dict containing the state of the user
|
||||
"""
|
||||
return self.model_dump()
|
||||
|
||||
|
||||
class UserManager(Service):
|
||||
"""
|
||||
Manages users within the PrimAITE system, handling creation, authentication, and administration.
|
||||
|
||||
:param users: A dictionary of all users by their usernames
|
||||
:param admins: A dictionary of admin users by their usernames
|
||||
:param disabled_admins: A dictionary of currently disabled admin users by their usernames
|
||||
"""
|
||||
|
||||
users: Dict[str, User] = {}
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
"""
|
||||
Initializes a UserManager instanc.
|
||||
|
||||
:param username: The username for the default admin user
|
||||
:param password: The password for the default admin user
|
||||
"""
|
||||
kwargs["name"] = "UserManager"
|
||||
kwargs["port"] = Port.NONE
|
||||
kwargs["protocol"] = IPProtocol.NONE
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.start()
|
||||
|
||||
def _init_request_manager(self) -> RequestManager:
|
||||
"""
|
||||
Initialise the request manager.
|
||||
|
||||
More information in user guide and docstring for SimComponent._init_request_manager.
|
||||
"""
|
||||
rm = super()._init_request_manager()
|
||||
|
||||
# todo add doc about requeest schemas
|
||||
rm.add_request(
|
||||
"change_password",
|
||||
RequestType(
|
||||
func=lambda request, context: RequestResponse.from_bool(
|
||||
self.change_user_password(username=request[0], current_password=request[1], new_password=request[2])
|
||||
)
|
||||
),
|
||||
)
|
||||
return rm
|
||||
|
||||
def describe_state(self) -> Dict:
|
||||
"""
|
||||
Returns the state of the UserManager along with the number of users and admins.
|
||||
|
||||
:return: A dict containing detailed state information
|
||||
"""
|
||||
state = super().describe_state()
|
||||
state.update({"total_users": len(self.users), "total_admins": len(self.admins) + len(self.disabled_admins)})
|
||||
state["users"] = {k: v.describe_state() for k, v in self.users.items()}
|
||||
return state
|
||||
|
||||
def show(self, markdown: bool = False):
|
||||
"""
|
||||
Display the Users.
|
||||
|
||||
:param markdown: Whether to display the table in Markdown format or not. Default is `False`.
|
||||
"""
|
||||
table = PrettyTable(["Username", "Admin", "Disabled"])
|
||||
if markdown:
|
||||
table.set_style(MARKDOWN)
|
||||
table.align = "l"
|
||||
table.title = f"{self.sys_log.hostname} User Manager"
|
||||
for user in self.users.values():
|
||||
table.add_row([user.username, user.is_admin, user.disabled])
|
||||
print(table.get_string(sortby="Username"))
|
||||
|
||||
@property
|
||||
def non_admins(self) -> Dict[str, User]:
|
||||
"""
|
||||
Returns a dictionary of all enabled non-admin users.
|
||||
|
||||
:return: A dictionary with usernames as keys and User objects as values for non-admin, non-disabled users.
|
||||
"""
|
||||
return {k: v for k, v in self.users.items() if not v.is_admin and not v.disabled}
|
||||
|
||||
@property
|
||||
def disabled_non_admins(self) -> Dict[str, User]:
|
||||
"""
|
||||
Returns a dictionary of all disabled non-admin users.
|
||||
|
||||
:return: A dictionary with usernames as keys and User objects as values for non-admin, disabled users.
|
||||
"""
|
||||
return {k: v for k, v in self.users.items() if not v.is_admin and v.disabled}
|
||||
|
||||
@property
|
||||
def admins(self) -> Dict[str, User]:
|
||||
"""
|
||||
Returns a dictionary of all enabled admin users.
|
||||
|
||||
:return: A dictionary with usernames as keys and User objects as values for admin, non-disabled users.
|
||||
"""
|
||||
return {k: v for k, v in self.users.items() if v.is_admin and not v.disabled}
|
||||
|
||||
@property
|
||||
def disabled_admins(self) -> Dict[str, User]:
|
||||
"""
|
||||
Returns a dictionary of all disabled admin users.
|
||||
|
||||
:return: A dictionary with usernames as keys and User objects as values for admin, disabled users.
|
||||
"""
|
||||
return {k: v for k, v in self.users.items() if v.is_admin and v.disabled}
|
||||
|
||||
def install(self) -> None:
|
||||
"""Setup default user during first-time installation."""
|
||||
self.add_user(username="admin", password="admin", is_admin=True, bypass_can_perform_action=True)
|
||||
|
||||
def _is_last_admin(self, username: str) -> bool:
|
||||
return username in self.admins and len(self.admins) == 1
|
||||
|
||||
def add_user(
|
||||
self, username: str, password: str, is_admin: bool = False, bypass_can_perform_action: bool = False
|
||||
) -> bool:
|
||||
"""
|
||||
Adds a new user to the system.
|
||||
|
||||
:param username: The username for the new user
|
||||
:param password: The password for the new user
|
||||
:param is_admin: Flag indicating if the new user is an admin
|
||||
:return: True if user was successfully added, False otherwise
|
||||
"""
|
||||
if not bypass_can_perform_action and not self._can_perform_action():
|
||||
return False
|
||||
if username in self.users:
|
||||
self.sys_log.info(f"{self.name}: Failed to create new user {username} as this user name already exists")
|
||||
return False
|
||||
user = User(username=username, password=password, is_admin=is_admin)
|
||||
self.users[username] = user
|
||||
self.sys_log.info(f"{self.name}: Added new {'admin' if is_admin else 'user'}: {username}")
|
||||
return True
|
||||
|
||||
def authenticate_user(self, username: str, password: str) -> Optional[User]:
|
||||
"""
|
||||
Authenticates a user's login attempt.
|
||||
|
||||
:param username: The username of the user trying to log in
|
||||
:param password: The password provided by the user
|
||||
:return: The User object if authentication is successful, None otherwise
|
||||
"""
|
||||
if not self._can_perform_action():
|
||||
return None
|
||||
user = self.users.get(username)
|
||||
if user and not user.disabled and user.password == password:
|
||||
self.sys_log.info(f"{self.name}: User authenticated: {username}")
|
||||
return user
|
||||
self.sys_log.info(f"{self.name}: Authentication failed for: {username}")
|
||||
return None
|
||||
|
||||
def change_user_password(self, username: str, current_password: str, new_password: str) -> bool:
|
||||
"""
|
||||
Changes a user's password.
|
||||
|
||||
:param username: The username of the user changing their password
|
||||
:param current_password: The current password of the user
|
||||
:param new_password: The new password for the user
|
||||
:return: True if the password was changed successfully, False otherwise
|
||||
"""
|
||||
if not self._can_perform_action():
|
||||
return False
|
||||
user = self.users.get(username)
|
||||
if user and user.password == current_password:
|
||||
user.password = new_password
|
||||
self.sys_log.info(f"{self.name}: Password changed for {username}")
|
||||
return True
|
||||
self.sys_log.info(f"{self.name}: Password change failed for {username}")
|
||||
return False
|
||||
|
||||
def disable_user(self, username: str) -> bool:
|
||||
"""
|
||||
Disables a user account, preventing them from logging in.
|
||||
|
||||
:param username: The username of the user to disable
|
||||
:return: True if the user was disabled successfully, False otherwise
|
||||
"""
|
||||
if not self._can_perform_action():
|
||||
return False
|
||||
if username in self.users and not self.users[username].disabled:
|
||||
if self._is_last_admin(username):
|
||||
self.sys_log.info(f"{self.name}: Cannot disable User {username} as they are the only enabled admin")
|
||||
return False
|
||||
self.users[username].disabled = True
|
||||
self.sys_log.info(f"{self.name}: User disabled: {username}")
|
||||
return True
|
||||
self.sys_log.info(f"{self.name}: Failed to disable user: {username}")
|
||||
return False
|
||||
|
||||
def enable_user(self, username: str) -> bool:
|
||||
"""
|
||||
Enables a previously disabled user account.
|
||||
|
||||
:param username: The username of the user to enable
|
||||
:return: True if the user was enabled successfully, False otherwise
|
||||
"""
|
||||
if username in self.users and self.users[username].disabled:
|
||||
self.users[username].disabled = False
|
||||
self.sys_log.info(f"{self.name}: User enabled: {username}")
|
||||
return True
|
||||
self.sys_log.info(f"{self.name}: Failed to enable user: {username}")
|
||||
return False
|
||||
|
||||
|
||||
class UserSession(SimComponent):
|
||||
"""
|
||||
Represents a user session on the Node.
|
||||
|
||||
This class manages the state of a user session, including the user, session start, last active step,
|
||||
and end step. It also indicates whether the session is local.
|
||||
|
||||
:ivar user: The user associated with this session.
|
||||
:ivar start_step: The timestep when the session was started.
|
||||
:ivar last_active_step: The last timestep when the session was active.
|
||||
:ivar end_step: The timestep when the session ended, if applicable.
|
||||
:ivar local: Indicates if the session is local. Defaults to True.
|
||||
"""
|
||||
|
||||
user: User
|
||||
"""The user associated with this session."""
|
||||
|
||||
start_step: int
|
||||
"""The timestep when the session was started."""
|
||||
|
||||
last_active_step: int
|
||||
"""The last timestep when the session was active."""
|
||||
|
||||
end_step: Optional[int] = None
|
||||
"""The timestep when the session ended, if applicable."""
|
||||
|
||||
local: bool = True
|
||||
"""Indicates if the session is a local session or a remote session. Defaults to True as a local session."""
|
||||
|
||||
@classmethod
|
||||
def create(cls, user: User, timestep: int) -> UserSession:
|
||||
"""
|
||||
Creates a new instance of UserSession.
|
||||
|
||||
This class method initialises a user session with the given user and timestep.
|
||||
|
||||
:param user: The user associated with this session.
|
||||
:param timestep: The timestep when the session is created.
|
||||
:return: An instance of UserSession.
|
||||
"""
|
||||
user.num_of_logins += 1
|
||||
return UserSession(user=user, start_step=timestep, last_active_step=timestep)
|
||||
|
||||
def describe_state(self) -> Dict:
|
||||
"""
|
||||
Describes the current state of the user session.
|
||||
|
||||
:return: A dictionary representing the state of the user session.
|
||||
"""
|
||||
return self.model_dump()
|
||||
|
||||
|
||||
class RemoteUserSession(UserSession):
|
||||
"""
|
||||
Represents a remote user session on the Node.
|
||||
|
||||
This class extends the UserSession class to include additional attributes and methods specific to remote sessions.
|
||||
|
||||
:ivar remote_ip_address: The IP address of the remote user.
|
||||
:ivar local: Indicates that this is not a local session. Always set to False.
|
||||
"""
|
||||
|
||||
remote_ip_address: IPV4Address
|
||||
"""The IP address of the remote user."""
|
||||
|
||||
local: bool = False
|
||||
"""Indicates that this is not a local session. Always set to False."""
|
||||
|
||||
@classmethod
|
||||
def create(cls, user: User, timestep: int, remote_ip_address: IPV4Address) -> RemoteUserSession: # noqa
|
||||
"""
|
||||
Creates a new instance of RemoteUserSession.
|
||||
|
||||
This class method initialises a remote user session with the given user, timestep, and remote IP address.
|
||||
|
||||
:param user: The user associated with this session.
|
||||
:param timestep: The timestep when the session is created.
|
||||
:param remote_ip_address: The IP address of the remote user.
|
||||
:return: An instance of RemoteUserSession.
|
||||
"""
|
||||
return RemoteUserSession(
|
||||
user=user, start_step=timestep, last_active_step=timestep, remote_ip_address=remote_ip_address
|
||||
)
|
||||
|
||||
def describe_state(self) -> Dict:
|
||||
"""
|
||||
Describes the current state of the remote user session.
|
||||
|
||||
This method extends the base describe_state method to include the remote IP address.
|
||||
|
||||
:return: A dictionary representing the state of the remote user session.
|
||||
"""
|
||||
state = super().describe_state()
|
||||
state["remote_ip_address"] = str(self.remote_ip_address)
|
||||
return state
|
||||
|
||||
|
||||
class UserSessionManager(Service):
|
||||
"""
|
||||
Manages user sessions on a Node, including local and remote sessions.
|
||||
|
||||
This class handles authentication, session management, and session timeouts for users interacting with the Node.
|
||||
"""
|
||||
|
||||
local_session: Optional[UserSession] = None
|
||||
"""The current local user session, if any."""
|
||||
|
||||
remote_sessions: Dict[str, RemoteUserSession] = {}
|
||||
"""A dictionary of active remote user sessions."""
|
||||
|
||||
historic_sessions: List[UserSession] = Field(default_factory=list)
|
||||
"""A list of historic user sessions."""
|
||||
|
||||
local_session_timeout_steps: int = 30
|
||||
"""The number of steps before a local session times out due to inactivity."""
|
||||
|
||||
remote_session_timeout_steps: int = 5
|
||||
"""The number of steps before a remote session times out due to inactivity."""
|
||||
|
||||
max_remote_sessions: int = 3
|
||||
"""The maximum number of concurrent remote sessions allowed."""
|
||||
|
||||
current_timestep: int = 0
|
||||
"""The current timestep in the simulation."""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
"""
|
||||
Initializes a UserSessionManager instance.
|
||||
|
||||
:param username: The username for the default admin user
|
||||
:param password: The password for the default admin user
|
||||
"""
|
||||
kwargs["name"] = "UserSessionManager"
|
||||
kwargs["port"] = Port.NONE
|
||||
kwargs["protocol"] = IPProtocol.NONE
|
||||
super().__init__(**kwargs)
|
||||
self.start()
|
||||
|
||||
def _init_request_manager(self) -> RequestManager:
|
||||
"""
|
||||
Initialise the request manager.
|
||||
|
||||
More information in user guide and docstring for SimComponent._init_request_manager.
|
||||
"""
|
||||
rm = super()._init_request_manager()
|
||||
|
||||
# todo add doc about requeest schemas
|
||||
rm.add_request(
|
||||
"remote_login",
|
||||
RequestType(
|
||||
func=lambda request, context: RequestResponse.from_bool(
|
||||
self.remote_login(username=request[0], password=request[1], remote_ip_address=request[2])
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
rm.add_request(
|
||||
"remote_logout",
|
||||
RequestType(
|
||||
func=lambda request, context: RequestResponse.from_bool(
|
||||
self.remote_logout(remote_session_id=request[0])
|
||||
)
|
||||
),
|
||||
)
|
||||
return rm
|
||||
|
||||
def show(self, markdown: bool = False, include_session_id: bool = False, include_historic: bool = False):
|
||||
"""
|
||||
Displays a table of the user sessions on the Node.
|
||||
|
||||
:param markdown: Whether to display the table in markdown format.
|
||||
:param include_session_id: Whether to include session IDs in the table.
|
||||
:param include_historic: Whether to include historic sessions in the table.
|
||||
"""
|
||||
headers = ["Session ID", "Username", "Type", "Remote IP", "Start Step", "Step Last Active", "End Step"]
|
||||
|
||||
if not include_session_id:
|
||||
headers = headers[1:]
|
||||
|
||||
table = PrettyTable(headers)
|
||||
|
||||
if markdown:
|
||||
table.set_style(MARKDOWN)
|
||||
table.align = "l"
|
||||
table.title = f"{self.parent.hostname} User Sessions"
|
||||
|
||||
def _add_session_to_table(user_session: UserSession):
|
||||
"""
|
||||
Adds a user session to the table for display.
|
||||
|
||||
This helper function determines whether the session is local or remote and formats the session data
|
||||
accordingly. It then adds the session data to the table.
|
||||
|
||||
:param user_session: The user session to add to the table.
|
||||
"""
|
||||
session_type = "local"
|
||||
remote_ip = ""
|
||||
if isinstance(user_session, RemoteUserSession):
|
||||
session_type = "remote"
|
||||
remote_ip = str(user_session.remote_ip_address)
|
||||
data = [
|
||||
user_session.uuid,
|
||||
user_session.user.username,
|
||||
session_type,
|
||||
remote_ip,
|
||||
user_session.start_step,
|
||||
user_session.last_active_step,
|
||||
user_session.end_step if user_session.end_step else "",
|
||||
]
|
||||
if not include_session_id:
|
||||
data = data[1:]
|
||||
table.add_row(data)
|
||||
|
||||
if self.local_session is not None:
|
||||
_add_session_to_table(self.local_session)
|
||||
|
||||
for user_session in self.remote_sessions.values():
|
||||
_add_session_to_table(user_session)
|
||||
|
||||
if include_historic:
|
||||
for user_session in self.historic_sessions:
|
||||
_add_session_to_table(user_session)
|
||||
|
||||
print(table.get_string(sortby="Step Last Active", reversesort=True))
|
||||
|
||||
def describe_state(self) -> Dict:
|
||||
"""
|
||||
Describes the current state of the UserSessionManager.
|
||||
|
||||
:return: A dictionary representing the current state.
|
||||
"""
|
||||
state = super().describe_state()
|
||||
state["active_remote_logins"] = len(self.remote_sessions)
|
||||
return state
|
||||
|
||||
@property
|
||||
def _user_manager(self) -> UserManager:
|
||||
"""
|
||||
Returns the UserManager instance.
|
||||
|
||||
:return: The UserManager instance.
|
||||
"""
|
||||
return self.software_manager.software["UserManager"] # noqa
|
||||
|
||||
def pre_timestep(self, timestep: int) -> None:
|
||||
"""Apply any pre-timestep logic that helps make sure we have the correct observations."""
|
||||
self.current_timestep = timestep
|
||||
if self.local_session:
|
||||
if self.local_session.last_active_step + self.local_session_timeout_steps <= timestep:
|
||||
self._timeout_session(self.local_session)
|
||||
|
||||
def _timeout_session(self, session: UserSession) -> None:
|
||||
"""
|
||||
Handles session timeout logic.
|
||||
|
||||
:param session: The session to be timed out.
|
||||
"""
|
||||
session.end_step = self.current_timestep
|
||||
session_identity = session.user.username
|
||||
if session.local:
|
||||
self.local_session = None
|
||||
session_type = "Local"
|
||||
else:
|
||||
self.remote_sessions.pop(session.uuid)
|
||||
session_type = "Remote"
|
||||
session_identity = f"{session_identity} {session.remote_ip_address}"
|
||||
|
||||
self.sys_log.info(f"{self.name}: {session_type} {session_identity} session timeout due to inactivity")
|
||||
|
||||
@property
|
||||
def remote_session_limit_reached(self) -> bool:
|
||||
"""
|
||||
Checks if the maximum number of remote sessions has been reached.
|
||||
|
||||
:return: True if the limit is reached, otherwise False.
|
||||
"""
|
||||
return len(self.remote_sessions) >= self.max_remote_sessions
|
||||
|
||||
def validate_remote_session_uuid(self, remote_session_id: str) -> bool:
|
||||
"""
|
||||
Validates if a given remote session ID exists.
|
||||
|
||||
:param remote_session_id: The remote session ID to validate.
|
||||
:return: True if the session ID exists, otherwise False.
|
||||
"""
|
||||
return remote_session_id in self.remote_sessions
|
||||
|
||||
def _login(
|
||||
self, username: str, password: str, local: bool = True, remote_ip_address: Optional[IPv4Address] = None
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Logs a user in either locally or remotely.
|
||||
|
||||
:param username: The username of the account.
|
||||
:param password: The password of the account.
|
||||
:param local: Whether the login is local or remote.
|
||||
:param remote_ip_address: The remote IP address for remote login.
|
||||
:return: The session ID if login is successful, otherwise None.
|
||||
"""
|
||||
if not self._can_perform_action():
|
||||
return None
|
||||
|
||||
user = self._user_manager.authenticate_user(username=username, password=password)
|
||||
|
||||
if not user:
|
||||
self.sys_log.info(f"{self.name}: Incorrect username or password")
|
||||
return None
|
||||
|
||||
session_id = None
|
||||
if local:
|
||||
create_new_session = True
|
||||
if self.local_session:
|
||||
if self.local_session.user != user:
|
||||
# logout the current user
|
||||
self.local_logout()
|
||||
else:
|
||||
# not required as existing logged-in user attempting to re-login
|
||||
create_new_session = False
|
||||
|
||||
if create_new_session:
|
||||
self.local_session = UserSession.create(user=user, timestep=self.current_timestep)
|
||||
|
||||
session_id = self.local_session.uuid
|
||||
else:
|
||||
if not self.remote_session_limit_reached:
|
||||
remote_session = RemoteUserSession.create(
|
||||
user=user, timestep=self.current_timestep, remote_ip_address=remote_ip_address
|
||||
)
|
||||
session_id = remote_session.uuid
|
||||
self.remote_sessions[session_id] = remote_session
|
||||
self.sys_log.info(f"{self.name}: User {user.username} logged in")
|
||||
return session_id
|
||||
|
||||
def local_login(self, username: str, password: str) -> Optional[str]:
|
||||
"""
|
||||
Logs a user in locally.
|
||||
|
||||
:param username: The username of the account.
|
||||
:param password: The password of the account.
|
||||
:return: The session ID if login is successful, otherwise None.
|
||||
"""
|
||||
return self._login(username=username, password=password, local=True)
|
||||
|
||||
@validate_call()
|
||||
def remote_login(self, username: str, password: str, remote_ip_address: IPV4Address) -> Optional[str]:
|
||||
"""
|
||||
Logs a user in remotely.
|
||||
|
||||
:param username: The username of the account.
|
||||
:param password: The password of the account.
|
||||
:param remote_ip_address: The remote IP address for the remote login.
|
||||
:return: The session ID if login is successful, otherwise None.
|
||||
"""
|
||||
return self._login(username=username, password=password, local=False, remote_ip_address=remote_ip_address)
|
||||
|
||||
def _logout(self, local: bool = True, remote_session_id: Optional[str] = None) -> bool:
|
||||
"""
|
||||
Logs a user out either locally or remotely.
|
||||
|
||||
:param local: Whether the logout is local or remote.
|
||||
:param remote_session_id: The remote session ID for remote logout.
|
||||
:return: True if logout successful, otherwise False.
|
||||
"""
|
||||
if not self._can_perform_action():
|
||||
return False
|
||||
session = None
|
||||
if local and self.local_session:
|
||||
session = self.local_session
|
||||
session.end_step = self.current_timestep
|
||||
self.local_session = None
|
||||
|
||||
if not local and remote_session_id:
|
||||
session = self.remote_sessions.pop(remote_session_id)
|
||||
if session:
|
||||
self.historic_sessions.append(session)
|
||||
self.sys_log.info(f"{self.name}: User {session.user.username} logged out")
|
||||
return True
|
||||
return False
|
||||
|
||||
def local_logout(self) -> bool:
|
||||
"""
|
||||
Logs out the current local user.
|
||||
|
||||
:return: True if logout successful, otherwise False.
|
||||
"""
|
||||
return self._logout(local=True)
|
||||
|
||||
def remote_logout(self, remote_session_id: str) -> bool:
|
||||
"""
|
||||
Logs out a remote user by session ID.
|
||||
|
||||
:param remote_session_id: The remote session ID.
|
||||
:return: True if logout successful, otherwise False.
|
||||
"""
|
||||
return self._logout(local=False, remote_session_id=remote_session_id)
|
||||
|
||||
@property
|
||||
def local_user_logged_in(self) -> bool:
|
||||
"""
|
||||
Checks if a local user is currently logged in.
|
||||
|
||||
:return: True if a local user is logged in, otherwise False.
|
||||
"""
|
||||
return self.local_session is not None
|
||||
|
||||
|
||||
class Node(SimComponent):
|
||||
"""
|
||||
A basic Node class that represents a node on the network.
|
||||
@@ -861,11 +1500,14 @@ class Node(SimComponent):
|
||||
red_scan_countdown: int = 0
|
||||
"Time steps until reveal to red scan is complete."
|
||||
|
||||
SYSTEM_SOFTWARE: ClassVar[Dict[str, Type[Software]]] = {}
|
||||
"Base system software that must be preinstalled."
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
"""
|
||||
Initialize the Node with various components and managers.
|
||||
|
||||
This method initializes the ARP cache, ICMP handler, session manager, and software manager if they are not
|
||||
This method initialises the ARP cache, ICMP handler, session manager, and software manager if they are not
|
||||
provided.
|
||||
"""
|
||||
if not kwargs.get("sys_log"):
|
||||
@@ -885,9 +1527,40 @@ class Node(SimComponent):
|
||||
dns_server=kwargs.get("dns_server"),
|
||||
)
|
||||
super().__init__(**kwargs)
|
||||
self._install_system_software()
|
||||
self.session_manager.node = self
|
||||
self.session_manager.software_manager = self.software_manager
|
||||
self._install_system_software()
|
||||
|
||||
@property
|
||||
def user_manager(self) -> Optional[UserManager]:
|
||||
"""The Nodes User Manager."""
|
||||
return self.software_manager.software.get("UserManager") # noqa
|
||||
|
||||
@property
|
||||
def user_session_manager(self) -> Optional[UserSessionManager]:
|
||||
"""The Nodes User Session Manager."""
|
||||
return self.software_manager.software.get("UserSessionManager") # noqa
|
||||
|
||||
def local_login(self, username: str, password: str) -> Optional[str]:
|
||||
"""
|
||||
Attempt to log in to the node uas a local user.
|
||||
|
||||
This method attempts to authenticate a local user with the given username and password. If successful, it
|
||||
returns a session token. If authentication fails, it returns None.
|
||||
|
||||
:param username: The username of the account attempting to log in.
|
||||
:param password: The password of the account attempting to log in.
|
||||
:return: A session token if the login is successful, otherwise None.
|
||||
"""
|
||||
return self.user_session_manager.local_login(username, password)
|
||||
|
||||
def local_logout(self) -> None:
|
||||
"""
|
||||
Log out the current local user from the node.
|
||||
|
||||
This method ends the current local user's session and invalidates the session token.
|
||||
"""
|
||||
return self.user_session_manager.local_logout()
|
||||
|
||||
def ip_is_network_interface(self, ip_address: IPv4Address, enabled_only: bool = False) -> bool:
|
||||
"""
|
||||
@@ -942,7 +1615,7 @@ class Node(SimComponent):
|
||||
@property
|
||||
def fail_message(self) -> str:
|
||||
"""Message that is reported when a request is rejected by this validator."""
|
||||
return f"Cannot perform request on node '{self.node.hostname}' because it is not turned on."
|
||||
return f"Cannot perform request on node '{self.node.hostname}' because it is not powered on."
|
||||
|
||||
class _NodeIsOffValidator(RequestPermissionValidator):
|
||||
"""
|
||||
@@ -984,7 +1657,7 @@ class Node(SimComponent):
|
||||
application_name = request[0]
|
||||
if self.software_manager.software.get(application_name):
|
||||
self.sys_log.warning(f"Can't install {application_name}. It's already installed.")
|
||||
return RequestResponse.from_bool(False)
|
||||
return RequestResponse(status="success", data={"reason": "already installed"})
|
||||
application_class = Application._application_registry[application_name]
|
||||
self.software_manager.install(application_class)
|
||||
application_instance = self.software_manager.software.get(application_name)
|
||||
@@ -1091,10 +1764,6 @@ class Node(SimComponent):
|
||||
|
||||
return rm
|
||||
|
||||
def _install_system_software(self):
|
||||
"""Install System Software - software that is usually provided with the OS."""
|
||||
pass
|
||||
|
||||
def describe_state(self) -> Dict:
|
||||
"""
|
||||
Produce a dictionary describing the current state of this object.
|
||||
@@ -1173,7 +1842,7 @@ class Node(SimComponent):
|
||||
ip_address,
|
||||
network_interface.speed,
|
||||
"Enabled" if network_interface.enabled else "Disabled",
|
||||
network_interface.nmne if primaite.simulator.network.nmne.CAPTURE_NMNE else "Disabled",
|
||||
network_interface.nmne if network_interface.nmne_config.capture_nmne else "Disabled",
|
||||
]
|
||||
)
|
||||
print(table)
|
||||
@@ -1483,6 +2152,11 @@ class Node(SimComponent):
|
||||
# for process_id in self.processes:
|
||||
# self.processes[process_id]
|
||||
|
||||
def _install_system_software(self) -> None:
|
||||
"""Preinstall required software."""
|
||||
for _, software_class in self.SYSTEM_SOFTWARE.items():
|
||||
self.software_manager.install(software_class)
|
||||
|
||||
def __contains__(self, item: Any) -> bool:
|
||||
if isinstance(item, Service):
|
||||
return item.uuid in self.services
|
||||
|
||||
@@ -5,7 +5,13 @@ from ipaddress import IPv4Address
|
||||
from typing import Any, ClassVar, Dict, Optional
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.simulator.network.hardware.base import IPWiredNetworkInterface, Link, Node
|
||||
from primaite.simulator.network.hardware.base import (
|
||||
IPWiredNetworkInterface,
|
||||
Link,
|
||||
Node,
|
||||
UserManager,
|
||||
UserSessionManager,
|
||||
)
|
||||
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
|
||||
from primaite.simulator.network.transmission.data_link_layer import Frame
|
||||
from primaite.simulator.system.applications.application import ApplicationOperatingState
|
||||
@@ -306,6 +312,8 @@ class HostNode(Node):
|
||||
"NTPClient": NTPClient,
|
||||
"WebBrowser": WebBrowser,
|
||||
"NMAP": NMAP,
|
||||
"UserSessionManager": UserSessionManager,
|
||||
"UserManager": UserManager,
|
||||
}
|
||||
"""List of system software that is automatically installed on nodes."""
|
||||
|
||||
@@ -338,18 +346,6 @@ class HostNode(Node):
|
||||
"""
|
||||
return self.software_manager.software.get("ARP")
|
||||
|
||||
def _install_system_software(self):
|
||||
"""
|
||||
Installs the system software and network services typically found on an operating system.
|
||||
|
||||
This method equips the host with essential network services and applications, preparing it for various
|
||||
network-related tasks and operations.
|
||||
"""
|
||||
for _, software_class in self.SYSTEM_SOFTWARE.items():
|
||||
self.software_manager.install(software_class)
|
||||
|
||||
super()._install_system_software()
|
||||
|
||||
def default_gateway_hello(self):
|
||||
"""
|
||||
Sends a hello message to the default gateway to establish connectivity and resolve the gateway's MAC address.
|
||||
|
||||
@@ -4,14 +4,14 @@ from __future__ import annotations
|
||||
import secrets
|
||||
from enum import Enum
|
||||
from ipaddress import IPv4Address, IPv4Network
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
from typing import Any, ClassVar, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from prettytable import MARKDOWN, PrettyTable
|
||||
from pydantic import validate_call
|
||||
|
||||
from primaite.interface.request import RequestResponse
|
||||
from primaite.simulator.core import RequestManager, RequestType, SimComponent
|
||||
from primaite.simulator.network.hardware.base import IPWiredNetworkInterface
|
||||
from primaite.simulator.network.hardware.base import IPWiredNetworkInterface, UserManager, UserSessionManager
|
||||
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
|
||||
from primaite.simulator.network.hardware.nodes.network.network_node import NetworkNode
|
||||
from primaite.simulator.network.protocols.arp import ARPPacket
|
||||
@@ -1200,6 +1200,11 @@ class Router(NetworkNode):
|
||||
RouteTable, RouterARP, and RouterICMP services.
|
||||
"""
|
||||
|
||||
SYSTEM_SOFTWARE: ClassVar[Dict] = {
|
||||
"UserSessionManager": UserSessionManager,
|
||||
"UserManager": UserManager,
|
||||
}
|
||||
|
||||
num_ports: int
|
||||
network_interfaces: Dict[str, RouterInterface] = {}
|
||||
"The Router Interfaces on the node."
|
||||
@@ -1235,6 +1240,7 @@ class Router(NetworkNode):
|
||||
resolution within the network. These services are crucial for the router's operation, enabling it to manage
|
||||
network traffic efficiently.
|
||||
"""
|
||||
super()._install_system_software()
|
||||
self.software_manager.install(RouterICMP)
|
||||
icmp: RouterICMP = self.software_manager.icmp # noqa
|
||||
icmp.router = self
|
||||
|
||||
@@ -108,6 +108,9 @@ class Switch(NetworkNode):
|
||||
for i in range(1, self.num_ports + 1):
|
||||
self.connect_nic(SwitchPort())
|
||||
|
||||
def _install_system_software(self):
|
||||
pass
|
||||
|
||||
def show(self, markdown: bool = False):
|
||||
"""
|
||||
Prints a table of the SwitchPorts on the Switch.
|
||||
|
||||
@@ -1,48 +1,25 @@
|
||||
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||
from typing import Dict, Final, List
|
||||
from typing import List
|
||||
|
||||
CAPTURE_NMNE: bool = True
|
||||
"""Indicates whether Malicious Network Events (MNEs) should be captured. Default is True."""
|
||||
|
||||
NMNE_CAPTURE_KEYWORDS: List[str] = []
|
||||
"""List of keywords to identify malicious network events."""
|
||||
|
||||
# TODO: Remove final and make configurable after example layout when the NICObservation creates nmne structure dynamically
|
||||
CAPTURE_BY_DIRECTION: Final[bool] = True
|
||||
"""Flag to determine if captures should be organized by traffic direction (inbound/outbound)."""
|
||||
CAPTURE_BY_IP_ADDRESS: Final[bool] = False
|
||||
"""Flag to determine if captures should be organized by source or destination IP address."""
|
||||
CAPTURE_BY_PROTOCOL: Final[bool] = False
|
||||
"""Flag to determine if captures should be organized by network protocol (e.g., TCP, UDP)."""
|
||||
CAPTURE_BY_PORT: Final[bool] = False
|
||||
"""Flag to determine if captures should be organized by source or destination port."""
|
||||
CAPTURE_BY_KEYWORD: Final[bool] = False
|
||||
"""Flag to determine if captures should be filtered and categorised based on specific keywords."""
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
|
||||
def set_nmne_config(nmne_config: Dict):
|
||||
"""
|
||||
Sets the configuration for capturing Malicious Network Events (MNEs) based on a provided dictionary.
|
||||
class NMNEConfig(BaseModel):
|
||||
"""Store all the information to perform NMNE operations."""
|
||||
|
||||
This function updates global settings related to NMNE capture, including whether to capture NMNEs and what
|
||||
keywords to use for identifying NMNEs.
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
The function ensures that the settings are updated only if they are provided in the `nmne_config` dictionary,
|
||||
and maintains type integrity by checking the types of the provided values.
|
||||
|
||||
:param nmne_config: A dictionary containing the NMNE configuration settings. Possible keys include:
|
||||
"capture_nmne" (bool) to indicate whether NMNEs should be captured, "nmne_capture_keywords" (list of strings)
|
||||
to specify keywords for NMNE identification.
|
||||
"""
|
||||
global NMNE_CAPTURE_KEYWORDS
|
||||
global CAPTURE_NMNE
|
||||
|
||||
# Update the NMNE capture flag, defaulting to False if not specified or if the type is incorrect
|
||||
CAPTURE_NMNE = nmne_config.get("capture_nmne", False)
|
||||
if not isinstance(CAPTURE_NMNE, bool):
|
||||
CAPTURE_NMNE = True # Revert to default True if the provided value is not a boolean
|
||||
|
||||
# Update the NMNE capture keywords, appending new keywords if provided
|
||||
NMNE_CAPTURE_KEYWORDS += nmne_config.get("nmne_capture_keywords", [])
|
||||
if not isinstance(NMNE_CAPTURE_KEYWORDS, list):
|
||||
NMNE_CAPTURE_KEYWORDS = [] # Reset to empty list if the provided value is not a list
|
||||
capture_nmne: bool = False
|
||||
"""Indicates whether Malicious Network Events (MNEs) should be captured."""
|
||||
nmne_capture_keywords: List[str] = []
|
||||
"""List of keywords to identify malicious network events."""
|
||||
capture_by_direction: bool = True
|
||||
"""Captures should be organized by traffic direction (inbound/outbound)."""
|
||||
capture_by_ip_address: bool = False
|
||||
"""Captures should be organized by source or destination IP address."""
|
||||
capture_by_protocol: bool = False
|
||||
"""Captures should be organized by network protocol (e.g., TCP, UDP)."""
|
||||
capture_by_port: bool = False
|
||||
"""Captures should be organized by source or destination port."""
|
||||
capture_by_keyword: bool = False
|
||||
"""Captures should be filtered and categorised based on specific keywords."""
|
||||
|
||||
@@ -4,7 +4,7 @@ from enum import Enum
|
||||
from typing import Union
|
||||
|
||||
from pydantic import BaseModel, field_validator, validate_call
|
||||
from pydantic_core.core_schema import FieldValidationInfo
|
||||
from pydantic_core.core_schema import ValidationInfo
|
||||
|
||||
from primaite import getLogger
|
||||
|
||||
@@ -96,7 +96,7 @@ class ICMPPacket(BaseModel):
|
||||
|
||||
@field_validator("icmp_code") # noqa
|
||||
@classmethod
|
||||
def _icmp_type_must_have_icmp_code(cls, v: int, info: FieldValidationInfo) -> int:
|
||||
def _icmp_type_must_have_icmp_code(cls, v: int, info: ValidationInfo) -> int:
|
||||
"""Validates the icmp_type and icmp_code."""
|
||||
icmp_type = info.data["icmp_type"]
|
||||
if get_icmp_type_code_description(icmp_type, v):
|
||||
|
||||
@@ -103,7 +103,7 @@ class SoftwareManager:
|
||||
return True
|
||||
return False
|
||||
|
||||
def install(self, software_class: Type[IOSoftware]):
|
||||
def install(self, software_class: Type[IOSoftware], **install_kwargs):
|
||||
"""
|
||||
Install an Application or Service.
|
||||
|
||||
@@ -113,7 +113,11 @@ class SoftwareManager:
|
||||
self.sys_log.warning(f"Cannot install {software_class} as it is already installed")
|
||||
return
|
||||
software = software_class(
|
||||
software_manager=self, sys_log=self.sys_log, file_system=self.file_system, dns_server=self.dns_server
|
||||
software_manager=self,
|
||||
sys_log=self.sys_log,
|
||||
file_system=self.file_system,
|
||||
dns_server=self.dns_server,
|
||||
**install_kwargs,
|
||||
)
|
||||
software.parent = self.node
|
||||
if isinstance(software, Application):
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||
@@ -0,0 +1 @@
|
||||
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||
@@ -0,0 +1 @@
|
||||
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||
@@ -291,7 +291,7 @@ class IOSoftware(Software):
|
||||
"""
|
||||
if self.software_manager and self.software_manager.node.operating_state != NodeOperatingState.ON:
|
||||
self.software_manager.node.sys_log.error(
|
||||
f"{self.name} Error: {self.software_manager.node.hostname} is not online."
|
||||
f"{self.name} Error: {self.software_manager.node.hostname} is not powered on."
|
||||
)
|
||||
return False
|
||||
return True
|
||||
|
||||
Reference in New Issue
Block a user