Merge branch 'dev' into feature/2777_set_RNG_seed

This commit is contained in:
Nick Todd
2024-08-05 11:12:30 +01:00
39 changed files with 1552 additions and 382 deletions

View File

@@ -129,6 +129,10 @@ agents:
simulation:
network:
nmne_config:
capture_nmne: true
nmne_capture_keywords:
- DELETE
nodes:
- hostname: client
type: computer

View File

@@ -294,7 +294,7 @@ class ConfigureDoSBotAction(AbstractAction):
"""Action which sets config parameters for a DoS bot on a node."""
class _Opts(BaseModel):
"""Schema for options that can be passed to this option."""
"""Schema for options that can be passed to this action."""
model_config = ConfigDict(extra="forbid")
target_ip_address: Optional[str] = None

View File

@@ -18,7 +18,7 @@ from primaite.game.agent.scripted_agents.tap001 import TAP001
from primaite.game.science import graph_has_cycle, topological_sort
from primaite.simulator import SIM_OUTPUT
from primaite.simulator.network.airspace import AirSpaceFrequency
from primaite.simulator.network.hardware.base import NodeOperatingState
from primaite.simulator.network.hardware.base import NetworkInterface, NodeOperatingState, UserManager
from primaite.simulator.network.hardware.nodes.host.computer import Computer
from primaite.simulator.network.hardware.nodes.host.host_node import NIC
from primaite.simulator.network.hardware.nodes.host.server import Printer, Server
@@ -26,7 +26,7 @@ from primaite.simulator.network.hardware.nodes.network.firewall import Firewall
from primaite.simulator.network.hardware.nodes.network.router import Router
from primaite.simulator.network.hardware.nodes.network.switch import Switch
from primaite.simulator.network.hardware.nodes.network.wireless_router import WirelessRouter
from primaite.simulator.network.nmne import set_nmne_config
from primaite.simulator.network.nmne import NMNEConfig
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.sim_container import Simulation
from primaite.simulator.system.applications.application import Application
@@ -266,9 +266,12 @@ class PrimaiteGame:
nodes_cfg = network_config.get("nodes", [])
links_cfg = network_config.get("links", [])
# Set the NMNE capture config
NetworkInterface.nmne_config = NMNEConfig(**network_config.get("nmne_config", {}))
for node_cfg in nodes_cfg:
n_type = node_cfg["type"]
new_node = None
if n_type == "computer":
new_node = Computer(
hostname=node_cfg["hostname"],
@@ -318,6 +321,11 @@ class PrimaiteGame:
msg = f"invalid node type {n_type} in config"
_LOGGER.error(msg)
raise ValueError(msg)
if "users" in node_cfg and new_node.software_manager.software.get("UserManager"):
user_manager: UserManager = new_node.software_manager.software["UserManager"] # noqa
for user_cfg in node_cfg["users"]:
user_manager.add_user(**user_cfg, bypass_can_perform_action=True)
if "services" in node_cfg:
for service_cfg in node_cfg["services"]:
new_service = None
@@ -535,10 +543,7 @@ class PrimaiteGame:
# Validate that if any agents are sharing rewards, they aren't forming an infinite loop.
game.setup_reward_sharing()
# Set the NMNE capture config
set_nmne_config(network_config.get("nmne_config", {}))
game.update_agents(game.get_sim_state())
return game
def setup_reward_sharing(self):

View File

@@ -101,7 +101,6 @@
"from primaite.session.ray_envs import PrimaiteRayEnv\n",
"from ray.rllib.algorithms.ppo import PPOConfig\n",
"import yaml\n",
"from ray import air, tune\n",
"from ray.rllib.examples.rl_modules.classes.action_masking_rlm import ActionMaskingTorchRLModule\n",
"from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec\n"
]
@@ -135,8 +134,7 @@
" .training(train_batch_size=128)\n",
")\n",
"algo = config.build()\n",
"for i in range(2):\n",
" results = algo.train()"
"results = algo.train()"
]
},
{
@@ -191,8 +189,7 @@
" .training(train_batch_size=128)\n",
")\n",
"algo = config.build()\n",
"for i in range(2):\n",
" results = algo.train()"
"results = algo.train()"
]
}
],
@@ -212,7 +209,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.8"
"version": "3.10.12"
}
},
"nbformat": 4,

View File

@@ -24,14 +24,11 @@
"metadata": {},
"outputs": [],
"source": [
"from primaite.game.game import PrimaiteGame\n",
"import yaml\n",
"\n",
"from primaite.session.ray_envs import PrimaiteRayEnv\n",
"from primaite import PRIMAITE_PATHS\n",
"\n",
"import ray\n",
"from ray import air, tune\n",
"from ray.rllib.algorithms.ppo import PPOConfig\n",
"from primaite.session.ray_envs import PrimaiteRayMARLEnv\n",
"\n",
@@ -72,7 +69,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Set training parameters and start the training\n",
"#### Start the training\n",
"This example will save outputs to a default Ray directory and use mostly default settings."
]
},
@@ -82,13 +79,8 @@
"metadata": {},
"outputs": [],
"source": [
"tune.Tuner(\n",
" \"PPO\",\n",
" run_config=air.RunConfig(\n",
" stop={\"timesteps_total\": 5 * 128},\n",
" ),\n",
" param_space=config\n",
").fit()"
"algo = config.build()\n",
"results = algo.train()"
]
}
],

View File

@@ -17,12 +17,10 @@
"metadata": {},
"outputs": [],
"source": [
"from primaite.game.game import PrimaiteGame\n",
"import yaml\n",
"from primaite.config.load import data_manipulation_config_path\n",
"\n",
"from primaite.session.ray_envs import PrimaiteRayEnv\n",
"from ray import air, tune\n",
"import ray\n",
"from ray.rllib.algorithms.ppo import PPOConfig\n",
"\n",
@@ -64,7 +62,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Set training parameters and start the training"
"#### Start the training"
]
},
{
@@ -73,13 +71,8 @@
"metadata": {},
"outputs": [],
"source": [
"tune.Tuner(\n",
" \"PPO\",\n",
" run_config=air.RunConfig(\n",
" stop={\"timesteps_total\": 512}\n",
" ),\n",
" param_space=config\n",
").fit()\n"
"algo = config.build()\n",
"results = algo.train()\n"
]
}
],

View File

@@ -6,12 +6,11 @@ import secrets
from abc import ABC, abstractmethod
from ipaddress import IPv4Address, IPv4Network
from pathlib import Path
from typing import Any, Dict, Optional, TypeVar, Union
from typing import Any, ClassVar, Dict, List, Optional, Type, TypeVar, Union
from prettytable import MARKDOWN, PrettyTable
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, validate_call
import primaite.simulator.network.nmne
from primaite import getLogger
from primaite.exceptions import NetworkError
from primaite.interface.request import RequestResponse
@@ -20,17 +19,10 @@ from primaite.simulator.core import RequestFormat, RequestManager, RequestPermis
from primaite.simulator.domain.account import Account
from primaite.simulator.file_system.file_system import FileSystem
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
from primaite.simulator.network.nmne import (
CAPTURE_BY_DIRECTION,
CAPTURE_BY_IP_ADDRESS,
CAPTURE_BY_KEYWORD,
CAPTURE_BY_PORT,
CAPTURE_BY_PROTOCOL,
CAPTURE_NMNE,
NMNE_CAPTURE_KEYWORDS,
)
from primaite.simulator.network.nmne import NMNEConfig
from primaite.simulator.network.transmission.data_link_layer import Frame
from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.applications.application import Application
from primaite.simulator.system.core.packet_capture import PacketCapture
from primaite.simulator.system.core.session_manager import SessionManager
@@ -38,7 +30,7 @@ from primaite.simulator.system.core.software_manager import SoftwareManager
from primaite.simulator.system.core.sys_log import SysLog
from primaite.simulator.system.processes.process import Process
from primaite.simulator.system.services.service import Service
from primaite.simulator.system.software import IOSoftware
from primaite.simulator.system.software import IOSoftware, Software
from primaite.utils.converters import convert_dict_enum_keys_to_enum_values
from primaite.utils.validators import IPV4Address
@@ -108,8 +100,11 @@ class NetworkInterface(SimComponent, ABC):
pcap: Optional[PacketCapture] = None
"A PacketCapture instance for capturing and analysing packets passing through this interface."
nmne_config: ClassVar[NMNEConfig] = NMNEConfig()
"A dataclass defining malicious network events to be captured."
nmne: Dict = Field(default_factory=lambda: {})
"A dict containing details of the number of malicious network events captured."
"A dict containing details of the number of malicious events captured."
traffic: Dict = Field(default_factory=lambda: {})
"A dict containing details of the inbound and outbound traffic by port and protocol."
@@ -167,8 +162,8 @@ class NetworkInterface(SimComponent, ABC):
"enabled": self.enabled,
}
)
if CAPTURE_NMNE:
state.update({"nmne": {k: v for k, v in self.nmne.items()}})
if self.nmne_config and self.nmne_config.capture_nmne:
state.update({"nmne": self.nmne})
state.update({"traffic": convert_dict_enum_keys_to_enum_values(self.traffic)})
return state
@@ -201,7 +196,7 @@ class NetworkInterface(SimComponent, ABC):
:param inbound: Boolean indicating if the frame direction is inbound. Defaults to True.
"""
# Exit function if NMNE capturing is disabled
if not CAPTURE_NMNE:
if not (self.nmne_config and self.nmne_config.capture_nmne):
return
# Initialise basic frame data variables
@@ -222,27 +217,27 @@ class NetworkInterface(SimComponent, ABC):
frame_str = str(frame.payload)
# Proceed only if any NMNE keyword is present in the frame payload
if any(keyword in frame_str for keyword in NMNE_CAPTURE_KEYWORDS):
if any(keyword in frame_str for keyword in self.nmne_config.nmne_capture_keywords):
# Start with the root of the NMNE capture structure
current_level = self.nmne
# Update NMNE structure based on enabled settings
if CAPTURE_BY_DIRECTION:
if self.nmne_config.capture_by_direction:
# Set or get the dictionary for the current direction
current_level = current_level.setdefault("direction", {})
current_level = current_level.setdefault(direction, {})
if CAPTURE_BY_IP_ADDRESS:
if self.nmne_config.capture_by_ip_address:
# Set or get the dictionary for the current IP address
current_level = current_level.setdefault("ip_address", {})
current_level = current_level.setdefault(ip_address, {})
if CAPTURE_BY_PROTOCOL:
if self.nmne_config.capture_by_protocol:
# Set or get the dictionary for the current protocol
current_level = current_level.setdefault("protocol", {})
current_level = current_level.setdefault(protocol, {})
if CAPTURE_BY_PORT:
if self.nmne_config.capture_by_port:
# Set or get the dictionary for the current port
current_level = current_level.setdefault("port", {})
current_level = current_level.setdefault(port, {})
@@ -251,8 +246,8 @@ class NetworkInterface(SimComponent, ABC):
keyword_level = current_level.setdefault("keywords", {})
# Increment the count for detected keywords in the payload
if CAPTURE_BY_KEYWORD:
for keyword in NMNE_CAPTURE_KEYWORDS:
if self.nmne_config.capture_by_keyword:
for keyword in self.nmne_config.nmne_capture_keywords:
if keyword in frame_str:
# Update the count for each keyword found
keyword_level[keyword] = keyword_level.get(keyword, 0) + 1
@@ -794,6 +789,650 @@ class Link(SimComponent):
self.current_load = 0.0
class User(SimComponent):
"""
Represents a user in the PrimAITE system.
:ivar username: The username of the user
:ivar password: The password of the user
:ivar disabled: Boolean flag indicating whether the user is disabled
:ivar is_admin: Boolean flag indicating whether the user has admin privileges
"""
username: str
"""The username of the user"""
password: str
"""The password of the user"""
disabled: bool = False
"""Boolean flag indicating whether the user is disabled"""
is_admin: bool = False
"""Boolean flag indicating whether the user has admin privileges"""
num_of_logins: int = 0
"""Counts the number of the User has logged in"""
def describe_state(self) -> Dict:
"""
Returns a dictionary representing the current state of the user.
:return: A dict containing the state of the user
"""
return self.model_dump()
class UserManager(Service):
"""
Manages users within the PrimAITE system, handling creation, authentication, and administration.
:param users: A dictionary of all users by their usernames
:param admins: A dictionary of admin users by their usernames
:param disabled_admins: A dictionary of currently disabled admin users by their usernames
"""
users: Dict[str, User] = {}
def __init__(self, **kwargs):
"""
Initializes a UserManager instanc.
:param username: The username for the default admin user
:param password: The password for the default admin user
"""
kwargs["name"] = "UserManager"
kwargs["port"] = Port.NONE
kwargs["protocol"] = IPProtocol.NONE
super().__init__(**kwargs)
self.start()
def _init_request_manager(self) -> RequestManager:
"""
Initialise the request manager.
More information in user guide and docstring for SimComponent._init_request_manager.
"""
rm = super()._init_request_manager()
# todo add doc about requeest schemas
rm.add_request(
"change_password",
RequestType(
func=lambda request, context: RequestResponse.from_bool(
self.change_user_password(username=request[0], current_password=request[1], new_password=request[2])
)
),
)
return rm
def describe_state(self) -> Dict:
"""
Returns the state of the UserManager along with the number of users and admins.
:return: A dict containing detailed state information
"""
state = super().describe_state()
state.update({"total_users": len(self.users), "total_admins": len(self.admins) + len(self.disabled_admins)})
state["users"] = {k: v.describe_state() for k, v in self.users.items()}
return state
def show(self, markdown: bool = False):
"""
Display the Users.
:param markdown: Whether to display the table in Markdown format or not. Default is `False`.
"""
table = PrettyTable(["Username", "Admin", "Disabled"])
if markdown:
table.set_style(MARKDOWN)
table.align = "l"
table.title = f"{self.sys_log.hostname} User Manager"
for user in self.users.values():
table.add_row([user.username, user.is_admin, user.disabled])
print(table.get_string(sortby="Username"))
@property
def non_admins(self) -> Dict[str, User]:
"""
Returns a dictionary of all enabled non-admin users.
:return: A dictionary with usernames as keys and User objects as values for non-admin, non-disabled users.
"""
return {k: v for k, v in self.users.items() if not v.is_admin and not v.disabled}
@property
def disabled_non_admins(self) -> Dict[str, User]:
"""
Returns a dictionary of all disabled non-admin users.
:return: A dictionary with usernames as keys and User objects as values for non-admin, disabled users.
"""
return {k: v for k, v in self.users.items() if not v.is_admin and v.disabled}
@property
def admins(self) -> Dict[str, User]:
"""
Returns a dictionary of all enabled admin users.
:return: A dictionary with usernames as keys and User objects as values for admin, non-disabled users.
"""
return {k: v for k, v in self.users.items() if v.is_admin and not v.disabled}
@property
def disabled_admins(self) -> Dict[str, User]:
"""
Returns a dictionary of all disabled admin users.
:return: A dictionary with usernames as keys and User objects as values for admin, disabled users.
"""
return {k: v for k, v in self.users.items() if v.is_admin and v.disabled}
def install(self) -> None:
"""Setup default user during first-time installation."""
self.add_user(username="admin", password="admin", is_admin=True, bypass_can_perform_action=True)
def _is_last_admin(self, username: str) -> bool:
return username in self.admins and len(self.admins) == 1
def add_user(
self, username: str, password: str, is_admin: bool = False, bypass_can_perform_action: bool = False
) -> bool:
"""
Adds a new user to the system.
:param username: The username for the new user
:param password: The password for the new user
:param is_admin: Flag indicating if the new user is an admin
:return: True if user was successfully added, False otherwise
"""
if not bypass_can_perform_action and not self._can_perform_action():
return False
if username in self.users:
self.sys_log.info(f"{self.name}: Failed to create new user {username} as this user name already exists")
return False
user = User(username=username, password=password, is_admin=is_admin)
self.users[username] = user
self.sys_log.info(f"{self.name}: Added new {'admin' if is_admin else 'user'}: {username}")
return True
def authenticate_user(self, username: str, password: str) -> Optional[User]:
"""
Authenticates a user's login attempt.
:param username: The username of the user trying to log in
:param password: The password provided by the user
:return: The User object if authentication is successful, None otherwise
"""
if not self._can_perform_action():
return None
user = self.users.get(username)
if user and not user.disabled and user.password == password:
self.sys_log.info(f"{self.name}: User authenticated: {username}")
return user
self.sys_log.info(f"{self.name}: Authentication failed for: {username}")
return None
def change_user_password(self, username: str, current_password: str, new_password: str) -> bool:
"""
Changes a user's password.
:param username: The username of the user changing their password
:param current_password: The current password of the user
:param new_password: The new password for the user
:return: True if the password was changed successfully, False otherwise
"""
if not self._can_perform_action():
return False
user = self.users.get(username)
if user and user.password == current_password:
user.password = new_password
self.sys_log.info(f"{self.name}: Password changed for {username}")
return True
self.sys_log.info(f"{self.name}: Password change failed for {username}")
return False
def disable_user(self, username: str) -> bool:
"""
Disables a user account, preventing them from logging in.
:param username: The username of the user to disable
:return: True if the user was disabled successfully, False otherwise
"""
if not self._can_perform_action():
return False
if username in self.users and not self.users[username].disabled:
if self._is_last_admin(username):
self.sys_log.info(f"{self.name}: Cannot disable User {username} as they are the only enabled admin")
return False
self.users[username].disabled = True
self.sys_log.info(f"{self.name}: User disabled: {username}")
return True
self.sys_log.info(f"{self.name}: Failed to disable user: {username}")
return False
def enable_user(self, username: str) -> bool:
"""
Enables a previously disabled user account.
:param username: The username of the user to enable
:return: True if the user was enabled successfully, False otherwise
"""
if username in self.users and self.users[username].disabled:
self.users[username].disabled = False
self.sys_log.info(f"{self.name}: User enabled: {username}")
return True
self.sys_log.info(f"{self.name}: Failed to enable user: {username}")
return False
class UserSession(SimComponent):
"""
Represents a user session on the Node.
This class manages the state of a user session, including the user, session start, last active step,
and end step. It also indicates whether the session is local.
:ivar user: The user associated with this session.
:ivar start_step: The timestep when the session was started.
:ivar last_active_step: The last timestep when the session was active.
:ivar end_step: The timestep when the session ended, if applicable.
:ivar local: Indicates if the session is local. Defaults to True.
"""
user: User
"""The user associated with this session."""
start_step: int
"""The timestep when the session was started."""
last_active_step: int
"""The last timestep when the session was active."""
end_step: Optional[int] = None
"""The timestep when the session ended, if applicable."""
local: bool = True
"""Indicates if the session is a local session or a remote session. Defaults to True as a local session."""
@classmethod
def create(cls, user: User, timestep: int) -> UserSession:
"""
Creates a new instance of UserSession.
This class method initialises a user session with the given user and timestep.
:param user: The user associated with this session.
:param timestep: The timestep when the session is created.
:return: An instance of UserSession.
"""
user.num_of_logins += 1
return UserSession(user=user, start_step=timestep, last_active_step=timestep)
def describe_state(self) -> Dict:
"""
Describes the current state of the user session.
:return: A dictionary representing the state of the user session.
"""
return self.model_dump()
class RemoteUserSession(UserSession):
"""
Represents a remote user session on the Node.
This class extends the UserSession class to include additional attributes and methods specific to remote sessions.
:ivar remote_ip_address: The IP address of the remote user.
:ivar local: Indicates that this is not a local session. Always set to False.
"""
remote_ip_address: IPV4Address
"""The IP address of the remote user."""
local: bool = False
"""Indicates that this is not a local session. Always set to False."""
@classmethod
def create(cls, user: User, timestep: int, remote_ip_address: IPV4Address) -> RemoteUserSession: # noqa
"""
Creates a new instance of RemoteUserSession.
This class method initialises a remote user session with the given user, timestep, and remote IP address.
:param user: The user associated with this session.
:param timestep: The timestep when the session is created.
:param remote_ip_address: The IP address of the remote user.
:return: An instance of RemoteUserSession.
"""
return RemoteUserSession(
user=user, start_step=timestep, last_active_step=timestep, remote_ip_address=remote_ip_address
)
def describe_state(self) -> Dict:
"""
Describes the current state of the remote user session.
This method extends the base describe_state method to include the remote IP address.
:return: A dictionary representing the state of the remote user session.
"""
state = super().describe_state()
state["remote_ip_address"] = str(self.remote_ip_address)
return state
class UserSessionManager(Service):
"""
Manages user sessions on a Node, including local and remote sessions.
This class handles authentication, session management, and session timeouts for users interacting with the Node.
"""
local_session: Optional[UserSession] = None
"""The current local user session, if any."""
remote_sessions: Dict[str, RemoteUserSession] = {}
"""A dictionary of active remote user sessions."""
historic_sessions: List[UserSession] = Field(default_factory=list)
"""A list of historic user sessions."""
local_session_timeout_steps: int = 30
"""The number of steps before a local session times out due to inactivity."""
remote_session_timeout_steps: int = 5
"""The number of steps before a remote session times out due to inactivity."""
max_remote_sessions: int = 3
"""The maximum number of concurrent remote sessions allowed."""
current_timestep: int = 0
"""The current timestep in the simulation."""
def __init__(self, **kwargs):
"""
Initializes a UserSessionManager instance.
:param username: The username for the default admin user
:param password: The password for the default admin user
"""
kwargs["name"] = "UserSessionManager"
kwargs["port"] = Port.NONE
kwargs["protocol"] = IPProtocol.NONE
super().__init__(**kwargs)
self.start()
def _init_request_manager(self) -> RequestManager:
"""
Initialise the request manager.
More information in user guide and docstring for SimComponent._init_request_manager.
"""
rm = super()._init_request_manager()
# todo add doc about requeest schemas
rm.add_request(
"remote_login",
RequestType(
func=lambda request, context: RequestResponse.from_bool(
self.remote_login(username=request[0], password=request[1], remote_ip_address=request[2])
)
),
)
rm.add_request(
"remote_logout",
RequestType(
func=lambda request, context: RequestResponse.from_bool(
self.remote_logout(remote_session_id=request[0])
)
),
)
return rm
def show(self, markdown: bool = False, include_session_id: bool = False, include_historic: bool = False):
"""
Displays a table of the user sessions on the Node.
:param markdown: Whether to display the table in markdown format.
:param include_session_id: Whether to include session IDs in the table.
:param include_historic: Whether to include historic sessions in the table.
"""
headers = ["Session ID", "Username", "Type", "Remote IP", "Start Step", "Step Last Active", "End Step"]
if not include_session_id:
headers = headers[1:]
table = PrettyTable(headers)
if markdown:
table.set_style(MARKDOWN)
table.align = "l"
table.title = f"{self.parent.hostname} User Sessions"
def _add_session_to_table(user_session: UserSession):
"""
Adds a user session to the table for display.
This helper function determines whether the session is local or remote and formats the session data
accordingly. It then adds the session data to the table.
:param user_session: The user session to add to the table.
"""
session_type = "local"
remote_ip = ""
if isinstance(user_session, RemoteUserSession):
session_type = "remote"
remote_ip = str(user_session.remote_ip_address)
data = [
user_session.uuid,
user_session.user.username,
session_type,
remote_ip,
user_session.start_step,
user_session.last_active_step,
user_session.end_step if user_session.end_step else "",
]
if not include_session_id:
data = data[1:]
table.add_row(data)
if self.local_session is not None:
_add_session_to_table(self.local_session)
for user_session in self.remote_sessions.values():
_add_session_to_table(user_session)
if include_historic:
for user_session in self.historic_sessions:
_add_session_to_table(user_session)
print(table.get_string(sortby="Step Last Active", reversesort=True))
def describe_state(self) -> Dict:
"""
Describes the current state of the UserSessionManager.
:return: A dictionary representing the current state.
"""
state = super().describe_state()
state["active_remote_logins"] = len(self.remote_sessions)
return state
@property
def _user_manager(self) -> UserManager:
"""
Returns the UserManager instance.
:return: The UserManager instance.
"""
return self.software_manager.software["UserManager"] # noqa
def pre_timestep(self, timestep: int) -> None:
"""Apply any pre-timestep logic that helps make sure we have the correct observations."""
self.current_timestep = timestep
if self.local_session:
if self.local_session.last_active_step + self.local_session_timeout_steps <= timestep:
self._timeout_session(self.local_session)
def _timeout_session(self, session: UserSession) -> None:
"""
Handles session timeout logic.
:param session: The session to be timed out.
"""
session.end_step = self.current_timestep
session_identity = session.user.username
if session.local:
self.local_session = None
session_type = "Local"
else:
self.remote_sessions.pop(session.uuid)
session_type = "Remote"
session_identity = f"{session_identity} {session.remote_ip_address}"
self.sys_log.info(f"{self.name}: {session_type} {session_identity} session timeout due to inactivity")
@property
def remote_session_limit_reached(self) -> bool:
"""
Checks if the maximum number of remote sessions has been reached.
:return: True if the limit is reached, otherwise False.
"""
return len(self.remote_sessions) >= self.max_remote_sessions
def validate_remote_session_uuid(self, remote_session_id: str) -> bool:
"""
Validates if a given remote session ID exists.
:param remote_session_id: The remote session ID to validate.
:return: True if the session ID exists, otherwise False.
"""
return remote_session_id in self.remote_sessions
def _login(
self, username: str, password: str, local: bool = True, remote_ip_address: Optional[IPv4Address] = None
) -> Optional[str]:
"""
Logs a user in either locally or remotely.
:param username: The username of the account.
:param password: The password of the account.
:param local: Whether the login is local or remote.
:param remote_ip_address: The remote IP address for remote login.
:return: The session ID if login is successful, otherwise None.
"""
if not self._can_perform_action():
return None
user = self._user_manager.authenticate_user(username=username, password=password)
if not user:
self.sys_log.info(f"{self.name}: Incorrect username or password")
return None
session_id = None
if local:
create_new_session = True
if self.local_session:
if self.local_session.user != user:
# logout the current user
self.local_logout()
else:
# not required as existing logged-in user attempting to re-login
create_new_session = False
if create_new_session:
self.local_session = UserSession.create(user=user, timestep=self.current_timestep)
session_id = self.local_session.uuid
else:
if not self.remote_session_limit_reached:
remote_session = RemoteUserSession.create(
user=user, timestep=self.current_timestep, remote_ip_address=remote_ip_address
)
session_id = remote_session.uuid
self.remote_sessions[session_id] = remote_session
self.sys_log.info(f"{self.name}: User {user.username} logged in")
return session_id
def local_login(self, username: str, password: str) -> Optional[str]:
"""
Logs a user in locally.
:param username: The username of the account.
:param password: The password of the account.
:return: The session ID if login is successful, otherwise None.
"""
return self._login(username=username, password=password, local=True)
@validate_call()
def remote_login(self, username: str, password: str, remote_ip_address: IPV4Address) -> Optional[str]:
"""
Logs a user in remotely.
:param username: The username of the account.
:param password: The password of the account.
:param remote_ip_address: The remote IP address for the remote login.
:return: The session ID if login is successful, otherwise None.
"""
return self._login(username=username, password=password, local=False, remote_ip_address=remote_ip_address)
def _logout(self, local: bool = True, remote_session_id: Optional[str] = None) -> bool:
"""
Logs a user out either locally or remotely.
:param local: Whether the logout is local or remote.
:param remote_session_id: The remote session ID for remote logout.
:return: True if logout successful, otherwise False.
"""
if not self._can_perform_action():
return False
session = None
if local and self.local_session:
session = self.local_session
session.end_step = self.current_timestep
self.local_session = None
if not local and remote_session_id:
session = self.remote_sessions.pop(remote_session_id)
if session:
self.historic_sessions.append(session)
self.sys_log.info(f"{self.name}: User {session.user.username} logged out")
return True
return False
def local_logout(self) -> bool:
"""
Logs out the current local user.
:return: True if logout successful, otherwise False.
"""
return self._logout(local=True)
def remote_logout(self, remote_session_id: str) -> bool:
"""
Logs out a remote user by session ID.
:param remote_session_id: The remote session ID.
:return: True if logout successful, otherwise False.
"""
return self._logout(local=False, remote_session_id=remote_session_id)
@property
def local_user_logged_in(self) -> bool:
"""
Checks if a local user is currently logged in.
:return: True if a local user is logged in, otherwise False.
"""
return self.local_session is not None
class Node(SimComponent):
"""
A basic Node class that represents a node on the network.
@@ -861,11 +1500,14 @@ class Node(SimComponent):
red_scan_countdown: int = 0
"Time steps until reveal to red scan is complete."
SYSTEM_SOFTWARE: ClassVar[Dict[str, Type[Software]]] = {}
"Base system software that must be preinstalled."
def __init__(self, **kwargs):
"""
Initialize the Node with various components and managers.
This method initializes the ARP cache, ICMP handler, session manager, and software manager if they are not
This method initialises the ARP cache, ICMP handler, session manager, and software manager if they are not
provided.
"""
if not kwargs.get("sys_log"):
@@ -885,9 +1527,40 @@ class Node(SimComponent):
dns_server=kwargs.get("dns_server"),
)
super().__init__(**kwargs)
self._install_system_software()
self.session_manager.node = self
self.session_manager.software_manager = self.software_manager
self._install_system_software()
@property
def user_manager(self) -> Optional[UserManager]:
"""The Nodes User Manager."""
return self.software_manager.software.get("UserManager") # noqa
@property
def user_session_manager(self) -> Optional[UserSessionManager]:
"""The Nodes User Session Manager."""
return self.software_manager.software.get("UserSessionManager") # noqa
def local_login(self, username: str, password: str) -> Optional[str]:
"""
Attempt to log in to the node uas a local user.
This method attempts to authenticate a local user with the given username and password. If successful, it
returns a session token. If authentication fails, it returns None.
:param username: The username of the account attempting to log in.
:param password: The password of the account attempting to log in.
:return: A session token if the login is successful, otherwise None.
"""
return self.user_session_manager.local_login(username, password)
def local_logout(self) -> None:
"""
Log out the current local user from the node.
This method ends the current local user's session and invalidates the session token.
"""
return self.user_session_manager.local_logout()
def ip_is_network_interface(self, ip_address: IPv4Address, enabled_only: bool = False) -> bool:
"""
@@ -942,7 +1615,7 @@ class Node(SimComponent):
@property
def fail_message(self) -> str:
"""Message that is reported when a request is rejected by this validator."""
return f"Cannot perform request on node '{self.node.hostname}' because it is not turned on."
return f"Cannot perform request on node '{self.node.hostname}' because it is not powered on."
class _NodeIsOffValidator(RequestPermissionValidator):
"""
@@ -984,7 +1657,7 @@ class Node(SimComponent):
application_name = request[0]
if self.software_manager.software.get(application_name):
self.sys_log.warning(f"Can't install {application_name}. It's already installed.")
return RequestResponse.from_bool(False)
return RequestResponse(status="success", data={"reason": "already installed"})
application_class = Application._application_registry[application_name]
self.software_manager.install(application_class)
application_instance = self.software_manager.software.get(application_name)
@@ -1091,10 +1764,6 @@ class Node(SimComponent):
return rm
def _install_system_software(self):
"""Install System Software - software that is usually provided with the OS."""
pass
def describe_state(self) -> Dict:
"""
Produce a dictionary describing the current state of this object.
@@ -1173,7 +1842,7 @@ class Node(SimComponent):
ip_address,
network_interface.speed,
"Enabled" if network_interface.enabled else "Disabled",
network_interface.nmne if primaite.simulator.network.nmne.CAPTURE_NMNE else "Disabled",
network_interface.nmne if network_interface.nmne_config.capture_nmne else "Disabled",
]
)
print(table)
@@ -1483,6 +2152,11 @@ class Node(SimComponent):
# for process_id in self.processes:
# self.processes[process_id]
def _install_system_software(self) -> None:
"""Preinstall required software."""
for _, software_class in self.SYSTEM_SOFTWARE.items():
self.software_manager.install(software_class)
def __contains__(self, item: Any) -> bool:
if isinstance(item, Service):
return item.uuid in self.services

View File

@@ -5,7 +5,13 @@ from ipaddress import IPv4Address
from typing import Any, ClassVar, Dict, Optional
from primaite import getLogger
from primaite.simulator.network.hardware.base import IPWiredNetworkInterface, Link, Node
from primaite.simulator.network.hardware.base import (
IPWiredNetworkInterface,
Link,
Node,
UserManager,
UserSessionManager,
)
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
from primaite.simulator.network.transmission.data_link_layer import Frame
from primaite.simulator.system.applications.application import ApplicationOperatingState
@@ -306,6 +312,8 @@ class HostNode(Node):
"NTPClient": NTPClient,
"WebBrowser": WebBrowser,
"NMAP": NMAP,
"UserSessionManager": UserSessionManager,
"UserManager": UserManager,
}
"""List of system software that is automatically installed on nodes."""
@@ -338,18 +346,6 @@ class HostNode(Node):
"""
return self.software_manager.software.get("ARP")
def _install_system_software(self):
"""
Installs the system software and network services typically found on an operating system.
This method equips the host with essential network services and applications, preparing it for various
network-related tasks and operations.
"""
for _, software_class in self.SYSTEM_SOFTWARE.items():
self.software_manager.install(software_class)
super()._install_system_software()
def default_gateway_hello(self):
"""
Sends a hello message to the default gateway to establish connectivity and resolve the gateway's MAC address.

View File

@@ -4,14 +4,14 @@ from __future__ import annotations
import secrets
from enum import Enum
from ipaddress import IPv4Address, IPv4Network
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, ClassVar, Dict, List, Optional, Tuple, Union
from prettytable import MARKDOWN, PrettyTable
from pydantic import validate_call
from primaite.interface.request import RequestResponse
from primaite.simulator.core import RequestManager, RequestType, SimComponent
from primaite.simulator.network.hardware.base import IPWiredNetworkInterface
from primaite.simulator.network.hardware.base import IPWiredNetworkInterface, UserManager, UserSessionManager
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
from primaite.simulator.network.hardware.nodes.network.network_node import NetworkNode
from primaite.simulator.network.protocols.arp import ARPPacket
@@ -1200,6 +1200,11 @@ class Router(NetworkNode):
RouteTable, RouterARP, and RouterICMP services.
"""
SYSTEM_SOFTWARE: ClassVar[Dict] = {
"UserSessionManager": UserSessionManager,
"UserManager": UserManager,
}
num_ports: int
network_interfaces: Dict[str, RouterInterface] = {}
"The Router Interfaces on the node."
@@ -1235,6 +1240,7 @@ class Router(NetworkNode):
resolution within the network. These services are crucial for the router's operation, enabling it to manage
network traffic efficiently.
"""
super()._install_system_software()
self.software_manager.install(RouterICMP)
icmp: RouterICMP = self.software_manager.icmp # noqa
icmp.router = self

View File

@@ -108,6 +108,9 @@ class Switch(NetworkNode):
for i in range(1, self.num_ports + 1):
self.connect_nic(SwitchPort())
def _install_system_software(self):
pass
def show(self, markdown: bool = False):
"""
Prints a table of the SwitchPorts on the Switch.

View File

@@ -1,48 +1,25 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
from typing import Dict, Final, List
from typing import List
CAPTURE_NMNE: bool = True
"""Indicates whether Malicious Network Events (MNEs) should be captured. Default is True."""
NMNE_CAPTURE_KEYWORDS: List[str] = []
"""List of keywords to identify malicious network events."""
# TODO: Remove final and make configurable after example layout when the NICObservation creates nmne structure dynamically
CAPTURE_BY_DIRECTION: Final[bool] = True
"""Flag to determine if captures should be organized by traffic direction (inbound/outbound)."""
CAPTURE_BY_IP_ADDRESS: Final[bool] = False
"""Flag to determine if captures should be organized by source or destination IP address."""
CAPTURE_BY_PROTOCOL: Final[bool] = False
"""Flag to determine if captures should be organized by network protocol (e.g., TCP, UDP)."""
CAPTURE_BY_PORT: Final[bool] = False
"""Flag to determine if captures should be organized by source or destination port."""
CAPTURE_BY_KEYWORD: Final[bool] = False
"""Flag to determine if captures should be filtered and categorised based on specific keywords."""
from pydantic import BaseModel, ConfigDict
def set_nmne_config(nmne_config: Dict):
"""
Sets the configuration for capturing Malicious Network Events (MNEs) based on a provided dictionary.
class NMNEConfig(BaseModel):
"""Store all the information to perform NMNE operations."""
This function updates global settings related to NMNE capture, including whether to capture NMNEs and what
keywords to use for identifying NMNEs.
model_config = ConfigDict(extra="forbid")
The function ensures that the settings are updated only if they are provided in the `nmne_config` dictionary,
and maintains type integrity by checking the types of the provided values.
:param nmne_config: A dictionary containing the NMNE configuration settings. Possible keys include:
"capture_nmne" (bool) to indicate whether NMNEs should be captured, "nmne_capture_keywords" (list of strings)
to specify keywords for NMNE identification.
"""
global NMNE_CAPTURE_KEYWORDS
global CAPTURE_NMNE
# Update the NMNE capture flag, defaulting to False if not specified or if the type is incorrect
CAPTURE_NMNE = nmne_config.get("capture_nmne", False)
if not isinstance(CAPTURE_NMNE, bool):
CAPTURE_NMNE = True # Revert to default True if the provided value is not a boolean
# Update the NMNE capture keywords, appending new keywords if provided
NMNE_CAPTURE_KEYWORDS += nmne_config.get("nmne_capture_keywords", [])
if not isinstance(NMNE_CAPTURE_KEYWORDS, list):
NMNE_CAPTURE_KEYWORDS = [] # Reset to empty list if the provided value is not a list
capture_nmne: bool = False
"""Indicates whether Malicious Network Events (MNEs) should be captured."""
nmne_capture_keywords: List[str] = []
"""List of keywords to identify malicious network events."""
capture_by_direction: bool = True
"""Captures should be organized by traffic direction (inbound/outbound)."""
capture_by_ip_address: bool = False
"""Captures should be organized by source or destination IP address."""
capture_by_protocol: bool = False
"""Captures should be organized by network protocol (e.g., TCP, UDP)."""
capture_by_port: bool = False
"""Captures should be organized by source or destination port."""
capture_by_keyword: bool = False
"""Captures should be filtered and categorised based on specific keywords."""

View File

@@ -4,7 +4,7 @@ from enum import Enum
from typing import Union
from pydantic import BaseModel, field_validator, validate_call
from pydantic_core.core_schema import FieldValidationInfo
from pydantic_core.core_schema import ValidationInfo
from primaite import getLogger
@@ -96,7 +96,7 @@ class ICMPPacket(BaseModel):
@field_validator("icmp_code") # noqa
@classmethod
def _icmp_type_must_have_icmp_code(cls, v: int, info: FieldValidationInfo) -> int:
def _icmp_type_must_have_icmp_code(cls, v: int, info: ValidationInfo) -> int:
"""Validates the icmp_type and icmp_code."""
icmp_type = info.data["icmp_type"]
if get_icmp_type_code_description(icmp_type, v):

View File

@@ -103,7 +103,7 @@ class SoftwareManager:
return True
return False
def install(self, software_class: Type[IOSoftware]):
def install(self, software_class: Type[IOSoftware], **install_kwargs):
"""
Install an Application or Service.
@@ -113,7 +113,11 @@ class SoftwareManager:
self.sys_log.warning(f"Cannot install {software_class} as it is already installed")
return
software = software_class(
software_manager=self, sys_log=self.sys_log, file_system=self.file_system, dns_server=self.dns_server
software_manager=self,
sys_log=self.sys_log,
file_system=self.file_system,
dns_server=self.dns_server,
**install_kwargs,
)
software.parent = self.node
if isinstance(software, Application):

View File

@@ -0,0 +1 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK

View File

@@ -0,0 +1 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK

View File

@@ -0,0 +1 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK

View File

@@ -291,7 +291,7 @@ class IOSoftware(Software):
"""
if self.software_manager and self.software_manager.node.operating_state != NodeOperatingState.ON:
self.software_manager.node.sys_log.error(
f"{self.name} Error: {self.software_manager.node.hostname} is not online."
f"{self.name} Error: {self.software_manager.node.hostname} is not powered on."
)
return False
return True