From 638e0af1f3bd0c9e02bcbe472b66ae76d8738b9c Mon Sep 17 00:00:00 2001 From: vasqu Date: Fri, 20 Feb 2026 12:33:50 +0100 Subject: [PATCH 1/4] fix --- .../falcon_mamba/modeling_falcon_mamba.py | 20 ++++++------------- .../models/mamba/modeling_mamba.py | 16 ++++++--------- 2 files changed, 12 insertions(+), 24 deletions(-) diff --git a/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py b/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py index 7ceec731c45e..4e8f6762822b 100644 --- a/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py +++ b/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py @@ -220,22 +220,14 @@ def __init__(self, config: FalconMambaConfig, layer_idx: int): global causal_conv1d, causal_conv1d_update, causal_conv1d_fn causal_conv1d = lazy_load_kernel("causal-conv1d") - causal_conv1d_update, causal_conv1d_fn = ( - (causal_conv1d.causal_conv1d_update, causal_conv1d.causal_conv1d_fn) - if causal_conv1d is not None - else (None, None) - ) + causal_conv1d_update = getattr(causal_conv1d, "causal_conv1d_update", None) + causal_conv1d_fn = getattr(causal_conv1d, "causal_conv1d_fn", None) + global falcon_mamba_ssm, selective_state_update, selective_scan_fn, falcon_mamba_inner_fn falcon_mamba_ssm = lazy_load_kernel("falcon_mamba-ssm") - selective_state_update, selective_scan_fn, falcon_mamba_inner_fn = ( - ( - falcon_mamba_ssm.selective_state_update, - falcon_mamba_ssm.selective_scan_fn, - falcon_mamba_ssm.falcon_mamba_inner_fn, - ) - if falcon_mamba_ssm is not None - else (None, None, None) - ) + selective_state_update = getattr(falcon_mamba_ssm, "selective_state_update", None) + selective_scan_fn = getattr(falcon_mamba_ssm, "selective_scan_fn", None) + falcon_mamba_inner_fn = getattr(falcon_mamba_ssm, "falcon_mamba_inner_fn", None) self.warn_slow_implementation() diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index e80059a0b6cb..1b2dbc932b9c 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -199,18 +199,14 @@ def __init__(self, config: MambaConfig, layer_idx: int): global causal_conv1d, causal_conv1d_update, causal_conv1d_fn causal_conv1d = lazy_load_kernel("causal-conv1d") - causal_conv1d_update, causal_conv1d_fn = ( - (causal_conv1d.causal_conv1d_update, causal_conv1d.causal_conv1d_fn) - if causal_conv1d is not None - else (None, None) - ) + causal_conv1d_update = getattr(causal_conv1d, "causal_conv1d_update", None) + causal_conv1d_fn = getattr(causal_conv1d, "causal_conv1d_fn", None) + global mamba_ssm, selective_state_update, selective_scan_fn, mamba_inner_fn mamba_ssm = lazy_load_kernel("mamba-ssm") - selective_state_update, selective_scan_fn, mamba_inner_fn = ( - (mamba_ssm.selective_state_update, mamba_ssm.selective_scan_fn, mamba_ssm.mamba_inner_fn) - if mamba_ssm is not None - else (None, None, None) - ) + selective_state_update = getattr(mamba_ssm, "selective_state_update", None) + selective_scan_fn = getattr(mamba_ssm, "selective_scan_fn", None) + mamba_inner_fn = getattr(mamba_ssm, "mamba_inner_fn", None) self.warn_slow_implementation() From ea418f5c16b33af37e0bfd303f8c806a02ae3e72 Mon Sep 17 00:00:00 2001 From: vasqu Date: Fri, 20 Feb 2026 16:25:16 +0100 Subject: [PATCH 2/4] fix all imports --- .../models/bamba/modeling_bamba.py | 13 ++++++-- .../models/bamba/modular_bamba.py | 13 ++++++-- .../falcon_mamba/modeling_falcon_mamba.py | 10 +++++-- .../modeling_granitemoehybrid.py | 13 ++++++-- 
.../models/jamba/modeling_jamba.py | 7 +++-- .../models/jamba/modular_jamba.py | 7 +++-- .../models/mamba/modeling_mamba.py | 6 ++-- .../models/mamba2/modeling_mamba2.py | 13 ++++++-- src/transformers/utils/import_utils.py | 30 +++++++++++++++++++ 9 files changed, 92 insertions(+), 20 deletions(-) diff --git a/src/transformers/models/bamba/modeling_bamba.py b/src/transformers/models/bamba/modeling_bamba.py index 6511eb5bca0e..a91395460f3f 100644 --- a/src/transformers/models/bamba/modeling_bamba.py +++ b/src/transformers/models/bamba/modeling_bamba.py @@ -44,6 +44,7 @@ from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torchdynamo_compiling, logging from ...utils.generic import maybe_autocast +from ...utils.import_utils import resolve_internal_import from .configuration_bamba import BambaConfig @@ -561,9 +562,15 @@ def __init__(self, config: BambaConfig, layer_idx: int): global selective_state_update, mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined mamba_ssm = lazy_load_kernel("mamba-ssm") - selective_state_update = getattr(mamba_ssm, "selective_state_update", None) - mamba_chunk_scan_combined = getattr(mamba_ssm, "mamba_chunk_scan_combined", None) - mamba_split_conv1d_scan_combined = getattr(mamba_ssm, "mamba_split_conv1d_scan_combined", None) + selective_state_update = resolve_internal_import( + mamba_ssm, chained_path="ops.triton.selective_state_update.selective_state_update" + ) + mamba_chunk_scan_combined = resolve_internal_import( + mamba_ssm, chained_path="ops.triton.ssd_combined.mamba_chunk_scan_combined" + ) + mamba_split_conv1d_scan_combined = resolve_internal_import( + mamba_ssm, chained_path="ops.triton.ssd_combined.mamba_split_conv1d_scan_combined" + ) global is_fast_path_available is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update)) diff --git a/src/transformers/models/bamba/modular_bamba.py b/src/transformers/models/bamba/modular_bamba.py index 477e940752df..4ca828173d19 100644 --- a/src/transformers/models/bamba/modular_bamba.py +++ b/src/transformers/models/bamba/modular_bamba.py @@ -48,6 +48,7 @@ from ...modeling_utils import PreTrainedModel from ...processing_utils import Unpack from ...utils import auto_docstring, can_return_tuple, is_torchdynamo_compiling, logging +from ...utils.import_utils import resolve_internal_import from .configuration_bamba import BambaConfig @@ -260,9 +261,15 @@ def __init__(self, config: BambaConfig, layer_idx: int): global selective_state_update, mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined mamba_ssm = lazy_load_kernel("mamba-ssm") - selective_state_update = getattr(mamba_ssm, "selective_state_update", None) - mamba_chunk_scan_combined = getattr(mamba_ssm, "mamba_chunk_scan_combined", None) - mamba_split_conv1d_scan_combined = getattr(mamba_ssm, "mamba_split_conv1d_scan_combined", None) + selective_state_update = resolve_internal_import( + mamba_ssm, chained_path="ops.triton.selective_state_update.selective_state_update" + ) + mamba_chunk_scan_combined = resolve_internal_import( + mamba_ssm, chained_path="ops.triton.ssd_combined.mamba_chunk_scan_combined" + ) + mamba_split_conv1d_scan_combined = resolve_internal_import( + mamba_ssm, chained_path="ops.triton.ssd_combined.mamba_split_conv1d_scan_combined" + ) global is_fast_path_available is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update)) diff --git a/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py 
b/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py index 4e8f6762822b..e44174b23839 100644 --- a/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py +++ b/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py @@ -34,7 +34,11 @@ from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...utils import ModelOutput, auto_docstring, logging -from ...utils.import_utils import is_mambapy_available, is_torchdynamo_compiling +from ...utils.import_utils import ( + is_mambapy_available, + is_torchdynamo_compiling, + resolve_internal_import, +) from .configuration_falcon_mamba import FalconMambaConfig @@ -225,7 +229,9 @@ def __init__(self, config: FalconMambaConfig, layer_idx: int): global falcon_mamba_ssm, selective_state_update, selective_scan_fn, falcon_mamba_inner_fn falcon_mamba_ssm = lazy_load_kernel("falcon_mamba-ssm") - selective_state_update = getattr(falcon_mamba_ssm, "selective_state_update", None) + selective_state_update = resolve_internal_import( + falcon_mamba_ssm, chained_path="ops.triton.selective_state_update.selective_state_update" + ) selective_scan_fn = getattr(falcon_mamba_ssm, "selective_scan_fn", None) falcon_mamba_inner_fn = getattr(falcon_mamba_ssm, "falcon_mamba_inner_fn", None) diff --git a/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py b/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py index 91cc7c6b9610..298eb06f1d9c 100644 --- a/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +++ b/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py @@ -40,6 +40,7 @@ from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torchdynamo_compiling, logging from ...utils.generic import maybe_autocast, merge_with_config_defaults +from ...utils.import_utils import resolve_internal_import from ...utils.output_capturing import capture_outputs from .configuration_granitemoehybrid import GraniteMoeHybridConfig @@ -438,9 +439,15 @@ def __init__(self, config: GraniteMoeHybridConfig, layer_idx: int): global selective_state_update, mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined mamba_ssm = lazy_load_kernel("mamba-ssm") - selective_state_update = getattr(mamba_ssm, "selective_state_update", None) - mamba_chunk_scan_combined = getattr(mamba_ssm, "mamba_chunk_scan_combined", None) - mamba_split_conv1d_scan_combined = getattr(mamba_ssm, "mamba_split_conv1d_scan_combined", None) + selective_state_update = resolve_internal_import( + mamba_ssm, chained_path="ops.triton.selective_state_update.selective_state_update" + ) + mamba_chunk_scan_combined = resolve_internal_import( + mamba_ssm, chained_path="ops.triton.ssd_combined.mamba_chunk_scan_combined" + ) + mamba_split_conv1d_scan_combined = resolve_internal_import( + mamba_ssm, chained_path="ops.triton.ssd_combined.mamba_split_conv1d_scan_combined" + ) global is_fast_path_available is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update)) diff --git a/src/transformers/models/jamba/modeling_jamba.py b/src/transformers/models/jamba/modeling_jamba.py index 9c5b513cf3f3..c3a72cd916fd 100755 --- a/src/transformers/models/jamba/modeling_jamba.py +++ b/src/transformers/models/jamba/modeling_jamba.py @@ -45,6 +45,7 @@ from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging from ...utils.generic import merge_with_config_defaults 
+from ...utils.import_utils import resolve_internal_import from ...utils.output_capturing import OutputRecorder, capture_outputs from .configuration_jamba import JambaConfig @@ -354,9 +355,11 @@ def __init__(self, config: JambaConfig, layer_idx): global selective_state_update, mamba_inner_fn, selective_scan_fn mamba_ssm = lazy_load_kernel("mamba-ssm") - selective_state_update = getattr(mamba_ssm, "selective_state_update", None) - mamba_inner_fn = getattr(mamba_ssm, "mamba_inner_fn", None) + selective_state_update = resolve_internal_import( + mamba_ssm, chained_path="ops.triton.selective_state_update.selective_state_update" + ) selective_scan_fn = getattr(mamba_ssm, "selective_scan_fn", None) + mamba_inner_fn = getattr(mamba_ssm, "mamba_inner_fn", None) global is_fast_path_available is_fast_path_available = all( diff --git a/src/transformers/models/jamba/modular_jamba.py b/src/transformers/models/jamba/modular_jamba.py index a0bb7c18d453..3959dac3a08b 100644 --- a/src/transformers/models/jamba/modular_jamba.py +++ b/src/transformers/models/jamba/modular_jamba.py @@ -32,6 +32,7 @@ from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, logging from ...utils.generic import merge_with_config_defaults +from ...utils.import_utils import resolve_internal_import from ...utils.output_capturing import OutputRecorder, capture_outputs from ..llama.modeling_llama import LlamaAttention, LlamaRMSNorm, eager_attention_forward from ..mistral.modeling_mistral import MistralMLP @@ -247,9 +248,11 @@ def __init__(self, config: JambaConfig, layer_idx): global selective_state_update, mamba_inner_fn, selective_scan_fn mamba_ssm = lazy_load_kernel("mamba-ssm") - selective_state_update = getattr(mamba_ssm, "selective_state_update", None) - mamba_inner_fn = getattr(mamba_ssm, "mamba_inner_fn", None) + selective_state_update = resolve_internal_import( + mamba_ssm, chained_path="ops.triton.selective_state_update.selective_state_update" + ) selective_scan_fn = getattr(mamba_ssm, "selective_scan_fn", None) + mamba_inner_fn = getattr(mamba_ssm, "mamba_inner_fn", None) global is_fast_path_available is_fast_path_available = all( diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index 1b2dbc932b9c..512543928c7f 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -33,7 +33,7 @@ auto_docstring, logging, ) -from ...utils.import_utils import is_mambapy_available, is_torchdynamo_compiling +from ...utils.import_utils import is_mambapy_available, is_torchdynamo_compiling, resolve_internal_import from .configuration_mamba import MambaConfig @@ -204,7 +204,9 @@ def __init__(self, config: MambaConfig, layer_idx: int): global mamba_ssm, selective_state_update, selective_scan_fn, mamba_inner_fn mamba_ssm = lazy_load_kernel("mamba-ssm") - selective_state_update = getattr(mamba_ssm, "selective_state_update", None) + selective_state_update = resolve_internal_import( + mamba_ssm, chained_path="ops.triton.selective_state_update.selective_state_update" + ) selective_scan_fn = getattr(mamba_ssm, "selective_scan_fn", None) mamba_inner_fn = getattr(mamba_ssm, "mamba_inner_fn", None) diff --git a/src/transformers/models/mamba2/modeling_mamba2.py b/src/transformers/models/mamba2/modeling_mamba2.py index 6376ea1b39f9..441a29fc0213 100644 --- a/src/transformers/models/mamba2/modeling_mamba2.py +++ b/src/transformers/models/mamba2/modeling_mamba2.py @@ -26,6 +26,7 @@ from 
...modeling_layers import GradientCheckpointingLayer
 from ...modeling_utils import PreTrainedModel
 from ...utils import ModelOutput, auto_docstring, is_torchdynamo_compiling, logging
+from ...utils.import_utils import resolve_internal_import
 from .configuration_mamba2 import Mamba2Config
 
 
@@ -265,9 +266,15 @@ def __init__(self, config: Mamba2Config, layer_idx: int):
 
         global selective_state_update, mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined
         mamba_ssm = lazy_load_kernel("mamba-ssm")
-        selective_state_update = getattr(mamba_ssm, "selective_state_update", None)
-        mamba_chunk_scan_combined = getattr(mamba_ssm, "mamba_chunk_scan_combined", None)
-        mamba_split_conv1d_scan_combined = getattr(mamba_ssm, "mamba_split_conv1d_scan_combined", None)
+        selective_state_update = resolve_internal_import(
+            mamba_ssm, chained_path="ops.triton.selective_state_update.selective_state_update"
+        )
+        mamba_chunk_scan_combined = resolve_internal_import(
+            mamba_ssm, chained_path="ops.triton.ssd_combined.mamba_chunk_scan_combined"
+        )
+        mamba_split_conv1d_scan_combined = resolve_internal_import(
+            mamba_ssm, chained_path="ops.triton.ssd_combined.mamba_split_conv1d_scan_combined"
+        )
 
         global is_fast_path_available
         is_fast_path_available = all(
diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py
index 1e527203e917..a98b3f49dcb5 100644
--- a/src/transformers/utils/import_utils.py
+++ b/src/transformers/utils/import_utils.py
@@ -77,6 +77,40 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> tuple[
     return package_exists
 
 
+def resolve_internal_import(module: ModuleType | None, chained_path: str) -> Callable | ModuleType | None:
+    """
+    Check if a given `module` has an internal import path as defined by the `chained_path`.
+    This can either be the full path (not exposed in `__init__`) OR the last part of the chain (exposed in `__init__`).
+
+    This is an important helper function for kernels-based modules to resolve the import from the module
+    itself, i.e. to stay compatible with the original libraries in certain cases.
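+
+    Returns `None` if `module` is `None` (e.g. the kernel could not be loaded) or if the path cannot
+    be resolved. This lets callers feed the result directly into availability checks such as
+    `is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update))`.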
+ + Example: + Module: `mamba_ssm` + Chained Path: `ops.triton.selective_state_update.selective_state_update` + Resulting import attempt at: + - `mamba_ssm.selective_state_update` + - `mamba_ssm.ops.triton.selective_state_update.selective_state_update` + """ + if not module: + return None + + if final_module := getattr(module, chained_path.split(".")[-1], None): + return final_module + + final_module = module + for path in chained_path.split("."): + final_module = getattr(final_module, path, None) + if not final_module: + return None + + return final_module + + def is_env_variable_true(env_variable: str) -> bool: """Detect whether `env_variable` has been set to a true value in the environment""" return os.getenv(env_variable, "false").lower() in ("true", "1", "y", "yes", "on") From 7f4d20b13302e272e025bb8599624f231461d033 Mon Sep 17 00:00:00 2001 From: vasqu Date: Fri, 20 Feb 2026 16:34:03 +0100 Subject: [PATCH 3/4] falcon h1 --- .../models/falcon_h1/modeling_falcon_h1.py | 38 +++++++++++-------- .../models/falcon_h1/modular_falcon_h1.py | 37 ++++++++++-------- 2 files changed, 44 insertions(+), 31 deletions(-) diff --git a/src/transformers/models/falcon_h1/modeling_falcon_h1.py b/src/transformers/models/falcon_h1/modeling_falcon_h1.py index 94af79bab2d6..34f3c069cf03 100644 --- a/src/transformers/models/falcon_h1/modeling_falcon_h1.py +++ b/src/transformers/models/falcon_h1/modeling_falcon_h1.py @@ -36,6 +36,7 @@ from ...cache_utils import Cache from ...generation import GenerationMixin from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func +from ...integrations.hub_kernels import lazy_load_kernel from ...masking_utils import create_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import GradientCheckpointingLayer @@ -45,22 +46,10 @@ from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torchdynamo_compiling, logging from ...utils.generic import maybe_autocast -from ...utils.import_utils import is_causal_conv1d_available, is_mamba_2_ssm_available +from ...utils.import_utils import resolve_internal_import from .configuration_falcon_h1 import FalconH1Config -if is_mamba_2_ssm_available(): - from mamba_ssm.ops.triton.selective_state_update import selective_state_update - from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined -else: - selective_state_update = None - -if is_causal_conv1d_available(): - from causal_conv1d import causal_conv1d_fn, causal_conv1d_update -else: - causal_conv1d_update, causal_conv1d_fn = None, None - - logger = logging.get_logger(__name__) @@ -533,9 +522,6 @@ def apply_mask_to_padding_states(hidden_states, attention_mask): return hidden_states -is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update)) - - # Adapted from transformers.models.mamba2.modeling_mamba2.Mamba2Mixer class FalconH1Mixer(nn.Module): """ @@ -610,6 +596,26 @@ def __init__(self, config: FalconH1Config, layer_idx: int): self.out_proj = nn.Linear(self.intermediate_size, config.hidden_size, bias=config.projectors_bias) + global causal_conv1d_update, causal_conv1d_fn + causal_conv1d = lazy_load_kernel("causal-conv1d") + causal_conv1d_update = getattr(causal_conv1d, "causal_conv1d_update", None) + causal_conv1d_fn = getattr(causal_conv1d, "causal_conv1d_fn", None) + + global selective_state_update, mamba_chunk_scan_combined, 
mamba_split_conv1d_scan_combined + mamba_ssm = lazy_load_kernel("mamba-ssm") + selective_state_update = resolve_internal_import( + mamba_ssm, chained_path="ops.triton.selective_state_update.selective_state_update" + ) + mamba_chunk_scan_combined = resolve_internal_import( + mamba_ssm, chained_path="ops.triton.ssd_combined.mamba_chunk_scan_combined" + ) + mamba_split_conv1d_scan_combined = resolve_internal_import( + mamba_ssm, chained_path="ops.triton.ssd_combined.mamba_split_conv1d_scan_combined" + ) + + global is_fast_path_available + is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update)) + if not is_fast_path_available: logger.warning_once( "The fast path is not available because one of `(selective_state_update, causal_conv1d_fn, causal_conv1d_update)`" diff --git a/src/transformers/models/falcon_h1/modular_falcon_h1.py b/src/transformers/models/falcon_h1/modular_falcon_h1.py index f87bd6b8ce57..aa37ba98962a 100644 --- a/src/transformers/models/falcon_h1/modular_falcon_h1.py +++ b/src/transformers/models/falcon_h1/modular_falcon_h1.py @@ -46,6 +46,7 @@ from ... import initialization as init from ...cache_utils import Cache +from ...integrations.hub_kernels import lazy_load_kernel from ...masking_utils import create_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import GradientCheckpointingLayer @@ -53,24 +54,10 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import auto_docstring, can_return_tuple, is_torchdynamo_compiling, logging -from ...utils.import_utils import is_causal_conv1d_available, is_mamba_2_ssm_available +from ...utils.import_utils import resolve_internal_import from .configuration_falcon_h1 import FalconH1Config -if is_mamba_2_ssm_available(): - from mamba_ssm.ops.triton.selective_state_update import selective_state_update - from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined -else: - selective_state_update = None - -if is_causal_conv1d_available(): - from causal_conv1d import causal_conv1d_fn, causal_conv1d_update -else: - causal_conv1d_update, causal_conv1d_fn = None, None - -is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update)) - - logger = logging.get_logger(__name__) @@ -360,6 +347,26 @@ def __init__(self, config: FalconH1Config, layer_idx: int): self.out_proj = nn.Linear(self.intermediate_size, config.hidden_size, bias=config.projectors_bias) + global causal_conv1d_update, causal_conv1d_fn + causal_conv1d = lazy_load_kernel("causal-conv1d") + causal_conv1d_update = getattr(causal_conv1d, "causal_conv1d_update", None) + causal_conv1d_fn = getattr(causal_conv1d, "causal_conv1d_fn", None) + + global selective_state_update, mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined + mamba_ssm = lazy_load_kernel("mamba-ssm") + selective_state_update = resolve_internal_import( + mamba_ssm, chained_path="ops.triton.selective_state_update.selective_state_update" + ) + mamba_chunk_scan_combined = resolve_internal_import( + mamba_ssm, chained_path="ops.triton.ssd_combined.mamba_chunk_scan_combined" + ) + mamba_split_conv1d_scan_combined = resolve_internal_import( + mamba_ssm, chained_path="ops.triton.ssd_combined.mamba_split_conv1d_scan_combined" + ) + + global is_fast_path_available + is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update)) + if not 
is_fast_path_available: logger.warning_once( "The fast path is not available because one of `(selective_state_update, causal_conv1d_fn, causal_conv1d_update)`" From e6f3c2cc0f4e5c850a46ab1fd43b5110811da5c8 Mon Sep 17 00:00:00 2001 From: vasqu Date: Fri, 20 Feb 2026 16:47:03 +0100 Subject: [PATCH 4/4] fix check conditions and fixup zamba --- .../models/bamba/modeling_bamba.py | 10 +++- .../models/bamba/modular_bamba.py | 10 +++- .../models/falcon_h1/modeling_falcon_h1.py | 10 +++- .../models/falcon_h1/modular_falcon_h1.py | 10 +++- .../modeling_granitemoehybrid.py | 10 +++- .../models/zamba/modeling_zamba.py | 41 +++++++++------- .../models/zamba2/modeling_zamba2.py | 45 +++++++++++------ .../models/zamba2/modular_zamba2.py | 48 ++++++++++++------- 8 files changed, 129 insertions(+), 55 deletions(-) diff --git a/src/transformers/models/bamba/modeling_bamba.py b/src/transformers/models/bamba/modeling_bamba.py index a91395460f3f..489eafded406 100644 --- a/src/transformers/models/bamba/modeling_bamba.py +++ b/src/transformers/models/bamba/modeling_bamba.py @@ -573,7 +573,15 @@ def __init__(self, config: BambaConfig, layer_idx: int): ) global is_fast_path_available - is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update)) + is_fast_path_available = all( + ( + selective_state_update, + mamba_chunk_scan_combined, + mamba_split_conv1d_scan_combined, + causal_conv1d_fn, + causal_conv1d_update, + ) + ) if not is_fast_path_available: logger.warning_once( diff --git a/src/transformers/models/bamba/modular_bamba.py b/src/transformers/models/bamba/modular_bamba.py index 4ca828173d19..da6223044e23 100644 --- a/src/transformers/models/bamba/modular_bamba.py +++ b/src/transformers/models/bamba/modular_bamba.py @@ -272,7 +272,15 @@ def __init__(self, config: BambaConfig, layer_idx: int): ) global is_fast_path_available - is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update)) + is_fast_path_available = all( + ( + selective_state_update, + mamba_chunk_scan_combined, + mamba_split_conv1d_scan_combined, + causal_conv1d_fn, + causal_conv1d_update, + ) + ) if not is_fast_path_available: logger.warning_once( diff --git a/src/transformers/models/falcon_h1/modeling_falcon_h1.py b/src/transformers/models/falcon_h1/modeling_falcon_h1.py index 34f3c069cf03..3bdd8427da30 100644 --- a/src/transformers/models/falcon_h1/modeling_falcon_h1.py +++ b/src/transformers/models/falcon_h1/modeling_falcon_h1.py @@ -614,7 +614,15 @@ def __init__(self, config: FalconH1Config, layer_idx: int): ) global is_fast_path_available - is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update)) + is_fast_path_available = all( + ( + selective_state_update, + mamba_chunk_scan_combined, + mamba_split_conv1d_scan_combined, + causal_conv1d_fn, + causal_conv1d_update, + ) + ) if not is_fast_path_available: logger.warning_once( diff --git a/src/transformers/models/falcon_h1/modular_falcon_h1.py b/src/transformers/models/falcon_h1/modular_falcon_h1.py index aa37ba98962a..5b6e40194c4c 100644 --- a/src/transformers/models/falcon_h1/modular_falcon_h1.py +++ b/src/transformers/models/falcon_h1/modular_falcon_h1.py @@ -365,7 +365,15 @@ def __init__(self, config: FalconH1Config, layer_idx: int): ) global is_fast_path_available - is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update)) + is_fast_path_available = all( + ( + selective_state_update, + mamba_chunk_scan_combined, + 
mamba_split_conv1d_scan_combined, + causal_conv1d_fn, + causal_conv1d_update, + ) + ) if not is_fast_path_available: logger.warning_once( diff --git a/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py b/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py index 298eb06f1d9c..2e1625742cce 100644 --- a/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +++ b/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py @@ -450,7 +450,15 @@ def __init__(self, config: GraniteMoeHybridConfig, layer_idx: int): ) global is_fast_path_available - is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update)) + is_fast_path_available = all( + ( + selective_state_update, + mamba_chunk_scan_combined, + mamba_split_conv1d_scan_combined, + causal_conv1d_fn, + causal_conv1d_update, + ) + ) if not is_fast_path_available: logger.warning_once( diff --git a/src/transformers/models/zamba/modeling_zamba.py b/src/transformers/models/zamba/modeling_zamba.py index 3ed8531c13e0..2a59ad3901ec 100644 --- a/src/transformers/models/zamba/modeling_zamba.py +++ b/src/transformers/models/zamba/modeling_zamba.py @@ -30,6 +30,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache from ...generation import GenerationMixin +from ...integrations.hub_kernels import lazy_load_kernel from ...masking_utils import create_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import GradientCheckpointingLayer @@ -37,26 +38,10 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import auto_docstring, logging -from ...utils.import_utils import is_causal_conv1d_available, is_mamba_ssm_available +from ...utils.import_utils import resolve_internal_import from .configuration_zamba import ZambaConfig -if is_mamba_ssm_available(): - from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn, selective_scan_fn - from mamba_ssm.ops.triton.selective_state_update import selective_state_update -else: - selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None - -if is_causal_conv1d_available(): - from causal_conv1d import causal_conv1d_fn, causal_conv1d_update -else: - causal_conv1d_update, causal_conv1d_fn = None, None - -is_fast_path_available = all( - (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn) -) - - logger = logging.get_logger(__name__) @@ -358,6 +343,24 @@ def __init__(self, config: ZambaConfig, layer_idx): self.D = nn.Parameter(torch.ones(self.n_mamba_heads, self.mamba_head_dim)) self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=self.use_bias) + global causal_conv1d, causal_conv1d_update, causal_conv1d_fn + causal_conv1d = lazy_load_kernel("causal-conv1d") + causal_conv1d_update = getattr(causal_conv1d, "causal_conv1d_update", None) + causal_conv1d_fn = getattr(causal_conv1d, "causal_conv1d_fn", None) + + global mamba_ssm, selective_state_update, selective_scan_fn, mamba_inner_fn + mamba_ssm = lazy_load_kernel("mamba-ssm") + selective_state_update = resolve_internal_import( + mamba_ssm, chained_path="ops.triton.selective_state_update.selective_state_update" + ) + selective_scan_fn = getattr(mamba_ssm, "selective_scan_fn", None) + mamba_inner_fn = getattr(mamba_ssm, "mamba_inner_fn", None) + + global is_fast_path_available + is_fast_path_available = all( + (selective_state_update, selective_scan_fn, 
causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn) + ) + if not is_fast_path_available: logger.warning_once( "The fast path is not available because one of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)`" @@ -556,6 +559,10 @@ def slow_forward(self, input_states, cache_params: ZambaHybridDynamicCache = Non return contextualized_states def forward(self, hidden_states, cache_params: ZambaHybridDynamicCache = None, attention_mask=None): + is_fast_path_available = all( + (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn) + ) + if self.use_fast_kernels: if not is_fast_path_available or "cuda" not in self.x_proj_weight.device.type: raise ValueError( diff --git a/src/transformers/models/zamba2/modeling_zamba2.py b/src/transformers/models/zamba2/modeling_zamba2.py index 99ff7d260756..05f9b9090933 100644 --- a/src/transformers/models/zamba2/modeling_zamba2.py +++ b/src/transformers/models/zamba2/modeling_zamba2.py @@ -32,6 +32,7 @@ from ...cache_utils import Cache from ...generation import GenerationMixin from ...integrations import use_kernel_func_from_hub +from ...integrations.hub_kernels import lazy_load_kernel from ...masking_utils import create_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import GradientCheckpointingLayer @@ -41,21 +42,10 @@ from ...processing_utils import Unpack from ...utils import auto_docstring, is_torchdynamo_compiling, logging from ...utils.generic import maybe_autocast -from ...utils.import_utils import is_causal_conv1d_available, is_mamba_ssm_available +from ...utils.import_utils import resolve_internal_import from .configuration_zamba2 import Zamba2Config -if is_mamba_ssm_available(): - from mamba_ssm.ops.triton.selective_state_update import selective_state_update - from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined -else: - selective_state_update, mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined = None, None, None - -if is_causal_conv1d_available(): - from causal_conv1d import causal_conv1d_fn, causal_conv1d_update -else: - causal_conv1d_update, causal_conv1d_fn = None, None - logger = logging.get_logger(__name__) @@ -527,9 +517,6 @@ def segment_sum(input_tensor): return tensor_segsum -is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update)) - - class Zamba2MambaMixer(nn.Module): """ Compute ∆, A, B, C, and D the state space parameters and compute the `contextualized_states`. 
@@ -594,6 +581,34 @@ def __init__(self, config: Zamba2Config, layer_idx: int | None = None): self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.add_bias_linear) + global causal_conv1d_update, causal_conv1d_fn + causal_conv1d = lazy_load_kernel("causal-conv1d") + causal_conv1d_update = getattr(causal_conv1d, "causal_conv1d_update", None) + causal_conv1d_fn = getattr(causal_conv1d, "causal_conv1d_fn", None) + + global selective_state_update, mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined + mamba_ssm = lazy_load_kernel("mamba-ssm") + selective_state_update = resolve_internal_import( + mamba_ssm, chained_path="ops.triton.selective_state_update.selective_state_update" + ) + mamba_chunk_scan_combined = resolve_internal_import( + mamba_ssm, chained_path="ops.triton.ssd_combined.mamba_chunk_scan_combined" + ) + mamba_split_conv1d_scan_combined = resolve_internal_import( + mamba_ssm, chained_path="ops.triton.ssd_combined.mamba_split_conv1d_scan_combined" + ) + + global is_fast_path_available + is_fast_path_available = all( + ( + selective_state_update, + mamba_chunk_scan_combined, + mamba_split_conv1d_scan_combined, + causal_conv1d_fn, + causal_conv1d_update, + ) + ) + if not is_fast_path_available: logger.warning_once( "The fast path is not available because one of `(selective_state_update, causal_conv1d_fn, causal_conv1d_update)`" diff --git a/src/transformers/models/zamba2/modular_zamba2.py b/src/transformers/models/zamba2/modular_zamba2.py index b56ebbd3895e..6e557fdf3cc2 100644 --- a/src/transformers/models/zamba2/modular_zamba2.py +++ b/src/transformers/models/zamba2/modular_zamba2.py @@ -21,16 +21,14 @@ from ... import initialization as init from ...activations import ACT2FN +from ...integrations.hub_kernels import lazy_load_kernel from ...masking_utils import create_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import BaseModelOutputWithPast from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import is_torchdynamo_compiling, logging -from ...utils.import_utils import ( - is_causal_conv1d_available, - is_mamba_ssm_available, -) +from ...utils.import_utils import resolve_internal_import from ..llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb from ..mamba2.modeling_mamba2 import pad_tensor_by_size, reshape_into_chunks, segment_sum from ..zamba.modeling_zamba import ( @@ -48,20 +46,6 @@ from .configuration_zamba2 import Zamba2Config -if is_mamba_ssm_available(): - from mamba_ssm.ops.triton.selective_state_update import selective_state_update - from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined -else: - selective_state_update, mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined = None, None, None - -if is_causal_conv1d_available(): - from causal_conv1d import causal_conv1d_fn, causal_conv1d_update -else: - causal_conv1d_update, causal_conv1d_fn = None, None - -is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update)) - - _CONFIG_FOR_DOC = "Zyphra/Zamba2-2.7B" logger = logging.get_logger(__name__) @@ -347,6 +331,34 @@ def __init__(self, config: Zamba2Config, layer_idx: int | None = None): self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.add_bias_linear) + global causal_conv1d_update, causal_conv1d_fn + causal_conv1d = lazy_load_kernel("causal-conv1d") + causal_conv1d_update = 
getattr(causal_conv1d, "causal_conv1d_update", None) + causal_conv1d_fn = getattr(causal_conv1d, "causal_conv1d_fn", None) + + global selective_state_update, mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined + mamba_ssm = lazy_load_kernel("mamba-ssm") + selective_state_update = resolve_internal_import( + mamba_ssm, chained_path="ops.triton.selective_state_update.selective_state_update" + ) + mamba_chunk_scan_combined = resolve_internal_import( + mamba_ssm, chained_path="ops.triton.ssd_combined.mamba_chunk_scan_combined" + ) + mamba_split_conv1d_scan_combined = resolve_internal_import( + mamba_ssm, chained_path="ops.triton.ssd_combined.mamba_split_conv1d_scan_combined" + ) + + global is_fast_path_available + is_fast_path_available = all( + ( + selective_state_update, + mamba_chunk_scan_combined, + mamba_split_conv1d_scan_combined, + causal_conv1d_fn, + causal_conv1d_update, + ) + ) + if not is_fast_path_available: logger.warning_once( "The fast path is not available because one of `(selective_state_update, causal_conv1d_fn, causal_conv1d_update)`"