23 changes: 19 additions & 4 deletions src/transformers/models/bamba/modeling_bamba.py

@@ -44,6 +44,7 @@
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torchdynamo_compiling, logging
 from ...utils.generic import maybe_autocast
+from ...utils.import_utils import resolve_internal_import
 from .configuration_bamba import BambaConfig

@@ -561,12 +562,26 @@ def __init__(self, config: BambaConfig, layer_idx: int):

         global selective_state_update, mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined
         mamba_ssm = lazy_load_kernel("mamba-ssm")
-        selective_state_update = getattr(mamba_ssm, "selective_state_update", None)
-        mamba_chunk_scan_combined = getattr(mamba_ssm, "mamba_chunk_scan_combined", None)
-        mamba_split_conv1d_scan_combined = getattr(mamba_ssm, "mamba_split_conv1d_scan_combined", None)
+        selective_state_update = resolve_internal_import(
+            mamba_ssm, chained_path="ops.triton.selective_state_update.selective_state_update"
+        )
+        mamba_chunk_scan_combined = resolve_internal_import(
+            mamba_ssm, chained_path="ops.triton.ssd_combined.mamba_chunk_scan_combined"
+        )
+        mamba_split_conv1d_scan_combined = resolve_internal_import(
+            mamba_ssm, chained_path="ops.triton.ssd_combined.mamba_split_conv1d_scan_combined"
+        )

         global is_fast_path_available
-        is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update))
+        is_fast_path_available = all(
+            (
+                selective_state_update,
+                mamba_chunk_scan_combined,
+                mamba_split_conv1d_scan_combined,
+                causal_conv1d_fn,
+                causal_conv1d_update,
+            )
+        )

         if not is_fast_path_available:
             logger.warning_once(
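A note on the new helper: `resolve_internal_import` lives in `...utils.import_utils` and its implementation is not part of this diff. From the call sites, its contract appears to be: given a lazily loaded kernel module (which may be None) and a dotted path, walk the path and return the target, falling back to None when the module failed to load or any segment is missing. Below is a minimal sketch of that contract, an assumption drawn from usage rather than the actual implementation (it also assumes each submodule is already reachable as an attribute):

from typing import Any, Optional

def resolve_internal_import(module: Optional[Any], chained_path: str) -> Optional[Any]:
    # Walk "a.b.c" from `module`; a missing kernel (module is None) or a
    # missing attribute anywhere along the path yields None, which the
    # `is_fast_path_available` gate then reads as "kernel unavailable".
    target = module
    for name in chained_path.split("."):
        if target is None:
            return None
        target = getattr(target, name, None)
    return target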
23 changes: 19 additions & 4 deletions src/transformers/models/bamba/modular_bamba.py

@@ -48,6 +48,7 @@
 from ...modeling_utils import PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import auto_docstring, can_return_tuple, is_torchdynamo_compiling, logging
+from ...utils.import_utils import resolve_internal_import
 from .configuration_bamba import BambaConfig

@@ -260,12 +261,26 @@ def __init__(self, config: BambaConfig, layer_idx: int):

         global selective_state_update, mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined
         mamba_ssm = lazy_load_kernel("mamba-ssm")
-        selective_state_update = getattr(mamba_ssm, "selective_state_update", None)
-        mamba_chunk_scan_combined = getattr(mamba_ssm, "mamba_chunk_scan_combined", None)
-        mamba_split_conv1d_scan_combined = getattr(mamba_ssm, "mamba_split_conv1d_scan_combined", None)
+        selective_state_update = resolve_internal_import(
+            mamba_ssm, chained_path="ops.triton.selective_state_update.selective_state_update"
+        )
+        mamba_chunk_scan_combined = resolve_internal_import(
+            mamba_ssm, chained_path="ops.triton.ssd_combined.mamba_chunk_scan_combined"
+        )
+        mamba_split_conv1d_scan_combined = resolve_internal_import(
+            mamba_ssm, chained_path="ops.triton.ssd_combined.mamba_split_conv1d_scan_combined"
+        )

         global is_fast_path_available
-        is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update))
+        is_fast_path_available = all(
+            (
+                selective_state_update,
+                mamba_chunk_scan_combined,
+                mamba_split_conv1d_scan_combined,
+                causal_conv1d_fn,
+                causal_conv1d_update,
+            )
+        )

         if not is_fast_path_available:
             logger.warning_once(
46 changes: 30 additions & 16 deletions src/transformers/models/falcon_h1/modeling_falcon_h1.py

@@ -36,6 +36,7 @@
 from ...cache_utils import Cache
 from ...generation import GenerationMixin
 from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func
+from ...integrations.hub_kernels import lazy_load_kernel
 from ...masking_utils import create_causal_mask
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_layers import GradientCheckpointingLayer
@@ -45,22 +46,10 @@
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torchdynamo_compiling, logging
 from ...utils.generic import maybe_autocast
-from ...utils.import_utils import is_causal_conv1d_available, is_mamba_2_ssm_available
+from ...utils.import_utils import resolve_internal_import
 from .configuration_falcon_h1 import FalconH1Config


-if is_mamba_2_ssm_available():
-    from mamba_ssm.ops.triton.selective_state_update import selective_state_update
-    from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined
-else:
-    selective_state_update = None
-
-if is_causal_conv1d_available():
-    from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
-else:
-    causal_conv1d_update, causal_conv1d_fn = None, None
-
-
 logger = logging.get_logger(__name__)

@@ -533,9 +522,6 @@ def apply_mask_to_padding_states(hidden_states, attention_mask):
     return hidden_states


-is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update))
-
-
 # Adapted from transformers.models.mamba2.modeling_mamba2.Mamba2Mixer
 class FalconH1Mixer(nn.Module):
     """
@@ -610,6 +596,34 @@ def __init__(self, config: FalconH1Config, layer_idx: int):

         self.out_proj = nn.Linear(self.intermediate_size, config.hidden_size, bias=config.projectors_bias)

+        global causal_conv1d_update, causal_conv1d_fn
+        causal_conv1d = lazy_load_kernel("causal-conv1d")
+        causal_conv1d_update = getattr(causal_conv1d, "causal_conv1d_update", None)
+        causal_conv1d_fn = getattr(causal_conv1d, "causal_conv1d_fn", None)
+
+        global selective_state_update, mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined
+        mamba_ssm = lazy_load_kernel("mamba-ssm")
+        selective_state_update = resolve_internal_import(
+            mamba_ssm, chained_path="ops.triton.selective_state_update.selective_state_update"
+        )
+        mamba_chunk_scan_combined = resolve_internal_import(
+            mamba_ssm, chained_path="ops.triton.ssd_combined.mamba_chunk_scan_combined"
+        )
+        mamba_split_conv1d_scan_combined = resolve_internal_import(
+            mamba_ssm, chained_path="ops.triton.ssd_combined.mamba_split_conv1d_scan_combined"
+        )
+
+        global is_fast_path_available
+        is_fast_path_available = all(
+            (
+                selective_state_update,
+                mamba_chunk_scan_combined,
+                mamba_split_conv1d_scan_combined,
+                causal_conv1d_fn,
+                causal_conv1d_update,
+            )
+        )
+
         if not is_fast_path_available:
             logger.warning_once(
                 "The fast path is not available because one of `(selective_state_update, causal_conv1d_fn, causal_conv1d_update)`"
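FalconH1 gets the largest rewrite in this set: the module-level `is_mamba_2_ssm_available()` / `is_causal_conv1d_available()` guards and the module-level `is_fast_path_available` flag are deleted, and kernel resolution moves into `FalconH1Mixer.__init__`, which rebinds module globals on first construction; importing the modeling file no longer touches `mamba_ssm` or `causal_conv1d` at all. Note also that the gate now requires the two chunk-scan kernels, while the warning string below it still names only the original three functions. A stripped-down sketch of the deferred-loading lifecycle, with a stubbed `lazy_load_kernel` standing in for the hub helper from `...integrations.hub_kernels`:

import importlib

# Module-level placeholders keep `import modeling_falcon_h1` cheap: nothing
# is resolved until the first mixer layer is built.
selective_state_update = None
is_fast_path_available = False

def lazy_load_kernel(name: str):
    # Stub: the real helper pulls kernels from the Hub; here we just try a
    # local import and return None when the package is absent.
    try:
        return importlib.import_module(name.replace("-", "_"))
    except ImportError:
        return None

class Mixer:
    def __init__(self):
        # First construction resolves the kernels and rebinds the globals,
        # the same pattern the diff above uses in FalconH1Mixer.__init__.
        global selective_state_update, is_fast_path_available
        mamba_ssm = lazy_load_kernel("mamba-ssm")
        selective_state_update = getattr(mamba_ssm, "selective_state_update", None)
        is_fast_path_available = selective_state_update is not None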
45 changes: 30 additions & 15 deletions src/transformers/models/falcon_h1/modular_falcon_h1.py

@@ -46,31 +46,18 @@
 from ... import initialization as init
 from ...cache_utils import Cache
+from ...integrations.hub_kernels import lazy_load_kernel
 from ...masking_utils import create_causal_mask
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import auto_docstring, can_return_tuple, is_torchdynamo_compiling, logging
-from ...utils.import_utils import is_causal_conv1d_available, is_mamba_2_ssm_available
+from ...utils.import_utils import resolve_internal_import
 from .configuration_falcon_h1 import FalconH1Config


-if is_mamba_2_ssm_available():
-    from mamba_ssm.ops.triton.selective_state_update import selective_state_update
-    from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined
-else:
-    selective_state_update = None
-
-if is_causal_conv1d_available():
-    from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
-else:
-    causal_conv1d_update, causal_conv1d_fn = None, None
-
-is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update))
-
-
 logger = logging.get_logger(__name__)

@@ -360,6 +347,34 @@ def __init__(self, config: FalconH1Config, layer_idx: int):

         self.out_proj = nn.Linear(self.intermediate_size, config.hidden_size, bias=config.projectors_bias)

+        global causal_conv1d_update, causal_conv1d_fn
+        causal_conv1d = lazy_load_kernel("causal-conv1d")
+        causal_conv1d_update = getattr(causal_conv1d, "causal_conv1d_update", None)
+        causal_conv1d_fn = getattr(causal_conv1d, "causal_conv1d_fn", None)
+
+        global selective_state_update, mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined
+        mamba_ssm = lazy_load_kernel("mamba-ssm")
+        selective_state_update = resolve_internal_import(
+            mamba_ssm, chained_path="ops.triton.selective_state_update.selective_state_update"
+        )
+        mamba_chunk_scan_combined = resolve_internal_import(
+            mamba_ssm, chained_path="ops.triton.ssd_combined.mamba_chunk_scan_combined"
+        )
+        mamba_split_conv1d_scan_combined = resolve_internal_import(
+            mamba_ssm, chained_path="ops.triton.ssd_combined.mamba_split_conv1d_scan_combined"
+        )
+
+        global is_fast_path_available
+        is_fast_path_available = all(
+            (
+                selective_state_update,
+                mamba_chunk_scan_combined,
+                mamba_split_conv1d_scan_combined,
+                causal_conv1d_fn,
+                causal_conv1d_update,
+            )
+        )
+
         if not is_fast_path_available:
             logger.warning_once(
                 "The fast path is not available because one of `(selective_state_update, causal_conv1d_fn, causal_conv1d_update)`"
26 changes: 12 additions & 14 deletions src/transformers/models/falcon_mamba/modeling_falcon_mamba.py

@@ -34,7 +34,11 @@
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_utils import PreTrainedModel
 from ...utils import ModelOutput, auto_docstring, logging
-from ...utils.import_utils import is_mambapy_available, is_torchdynamo_compiling
+from ...utils.import_utils import (
+    is_mambapy_available,
+    is_torchdynamo_compiling,
+    resolve_internal_import,
+)
 from .configuration_falcon_mamba import FalconMambaConfig

@@ -220,22 +224,16 @@ def __init__(self, config: FalconMambaConfig, layer_idx: int):

         global causal_conv1d, causal_conv1d_update, causal_conv1d_fn
         causal_conv1d = lazy_load_kernel("causal-conv1d")
-        causal_conv1d_update, causal_conv1d_fn = (
-            (causal_conv1d.causal_conv1d_update, causal_conv1d.causal_conv1d_fn)
-            if causal_conv1d is not None
-            else (None, None)
-        )
+        causal_conv1d_update = getattr(causal_conv1d, "causal_conv1d_update", None)
+        causal_conv1d_fn = getattr(causal_conv1d, "causal_conv1d_fn", None)

         global falcon_mamba_ssm, selective_state_update, selective_scan_fn, falcon_mamba_inner_fn
         falcon_mamba_ssm = lazy_load_kernel("falcon_mamba-ssm")
-        selective_state_update, selective_scan_fn, falcon_mamba_inner_fn = (
-            (
-                falcon_mamba_ssm.selective_state_update,
-                falcon_mamba_ssm.selective_scan_fn,
-                falcon_mamba_ssm.falcon_mamba_inner_fn,
-            )
-            if falcon_mamba_ssm is not None
-            else (None, None, None)
+        selective_state_update = resolve_internal_import(
+            falcon_mamba_ssm, chained_path="ops.triton.selective_state_update.selective_state_update"
         )
+        selective_scan_fn = getattr(falcon_mamba_ssm, "selective_scan_fn", None)
+        falcon_mamba_inner_fn = getattr(falcon_mamba_ssm, "falcon_mamba_inner_fn", None)

         self.warn_slow_implementation()
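The falcon_mamba hunk above also shows why the five-line conditional tuple assignment collapses to plain `getattr` calls: `getattr` accepts None as its object and returns the supplied default instead of raising, so the explicit `if ... is not None` branch is unnecessary. A quick illustration:

causal_conv1d = None  # what lazy_load_kernel yields when the kernel is absent
fn = getattr(causal_conv1d, "causal_conv1d_fn", None)
assert fn is None  # no AttributeError and no explicit None check needed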
src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py

@@ -40,6 +40,7 @@
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torchdynamo_compiling, logging
 from ...utils.generic import maybe_autocast, merge_with_config_defaults
+from ...utils.import_utils import resolve_internal_import
 from ...utils.output_capturing import capture_outputs
 from .configuration_granitemoehybrid import GraniteMoeHybridConfig

@@ -438,12 +439,26 @@ def __init__(self, config: GraniteMoeHybridConfig, layer_idx: int):

         global selective_state_update, mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined
         mamba_ssm = lazy_load_kernel("mamba-ssm")
-        selective_state_update = getattr(mamba_ssm, "selective_state_update", None)
-        mamba_chunk_scan_combined = getattr(mamba_ssm, "mamba_chunk_scan_combined", None)
-        mamba_split_conv1d_scan_combined = getattr(mamba_ssm, "mamba_split_conv1d_scan_combined", None)
+        selective_state_update = resolve_internal_import(
+            mamba_ssm, chained_path="ops.triton.selective_state_update.selective_state_update"
+        )
+        mamba_chunk_scan_combined = resolve_internal_import(
+            mamba_ssm, chained_path="ops.triton.ssd_combined.mamba_chunk_scan_combined"
+        )
+        mamba_split_conv1d_scan_combined = resolve_internal_import(
+            mamba_ssm, chained_path="ops.triton.ssd_combined.mamba_split_conv1d_scan_combined"
+        )

         global is_fast_path_available
-        is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update))
+        is_fast_path_available = all(
+            (
+                selective_state_update,
+                mamba_chunk_scan_combined,
+                mamba_split_conv1d_scan_combined,
+                causal_conv1d_fn,
+                causal_conv1d_update,
+            )
+        )

         if not is_fast_path_available:
             logger.warning_once(
7 changes: 5 additions & 2 deletions src/transformers/models/jamba/modeling_jamba.py

@@ -45,6 +45,7 @@
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
 from ...utils.generic import merge_with_config_defaults
+from ...utils.import_utils import resolve_internal_import
 from ...utils.output_capturing import OutputRecorder, capture_outputs
 from .configuration_jamba import JambaConfig

@@ -354,9 +355,11 @@ def __init__(self, config: JambaConfig, layer_idx):

         global selective_state_update, mamba_inner_fn, selective_scan_fn
         mamba_ssm = lazy_load_kernel("mamba-ssm")
-        selective_state_update = getattr(mamba_ssm, "selective_state_update", None)
-        mamba_inner_fn = getattr(mamba_ssm, "mamba_inner_fn", None)
+        selective_state_update = resolve_internal_import(
+            mamba_ssm, chained_path="ops.triton.selective_state_update.selective_state_update"
+        )
         selective_scan_fn = getattr(mamba_ssm, "selective_scan_fn", None)
+        mamba_inner_fn = getattr(mamba_ssm, "mamba_inner_fn", None)

         global is_fast_path_available
         is_fast_path_available = all(
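Jamba ends up using both access styles side by side. The likely reason (inferred from the chosen paths, not stated in the diff) is that `selective_state_update` lives deep in a Triton submodule, while `selective_scan_fn` and `mamba_inner_fn` are re-exported at the package root, where a plain attribute lookup with a None fallback suffices:

# Both styles together, using the helpers this PR relies on.
from transformers.integrations.hub_kernels import lazy_load_kernel
from transformers.utils.import_utils import resolve_internal_import

mamba_ssm = lazy_load_kernel("mamba-ssm")

# Deep symbol: resolved through its full submodule path.
selective_state_update = resolve_internal_import(
    mamba_ssm, chained_path="ops.triton.selective_state_update.selective_state_update"
)

# Top-level re-exports: a plain getattr with a None default is enough.
selective_scan_fn = getattr(mamba_ssm, "selective_scan_fn", None)
mamba_inner_fn = getattr(mamba_ssm, "mamba_inner_fn", None)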
7 changes: 5 additions & 2 deletions src/transformers/models/jamba/modular_jamba.py

@@ -32,6 +32,7 @@
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, logging
 from ...utils.generic import merge_with_config_defaults
+from ...utils.import_utils import resolve_internal_import
 from ...utils.output_capturing import OutputRecorder, capture_outputs
 from ..llama.modeling_llama import LlamaAttention, LlamaRMSNorm, eager_attention_forward
 from ..mistral.modeling_mistral import MistralMLP

@@ -247,9 +248,11 @@ def __init__(self, config: JambaConfig, layer_idx):

         global selective_state_update, mamba_inner_fn, selective_scan_fn
         mamba_ssm = lazy_load_kernel("mamba-ssm")
-        selective_state_update = getattr(mamba_ssm, "selective_state_update", None)
-        mamba_inner_fn = getattr(mamba_ssm, "mamba_inner_fn", None)
+        selective_state_update = resolve_internal_import(
+            mamba_ssm, chained_path="ops.triton.selective_state_update.selective_state_update"
+        )
         selective_scan_fn = getattr(mamba_ssm, "selective_scan_fn", None)
+        mamba_inner_fn = getattr(mamba_ssm, "mamba_inner_fn", None)

         global is_fast_path_available
         is_fast_path_available = all(
18 changes: 8 additions & 10 deletions src/transformers/models/mamba/modeling_mamba.py

@@ -33,7 +33,7 @@
     auto_docstring,
     logging,
 )
-from ...utils.import_utils import is_mambapy_available, is_torchdynamo_compiling
+from ...utils.import_utils import is_mambapy_available, is_torchdynamo_compiling, resolve_internal_import
 from .configuration_mamba import MambaConfig

@@ -199,18 +199,16 @@ def __init__(self, config: MambaConfig, layer_idx: int):

         global causal_conv1d, causal_conv1d_update, causal_conv1d_fn
         causal_conv1d = lazy_load_kernel("causal-conv1d")
-        causal_conv1d_update, causal_conv1d_fn = (
-            (causal_conv1d.causal_conv1d_update, causal_conv1d.causal_conv1d_fn)
-            if causal_conv1d is not None
-            else (None, None)
-        )
+        causal_conv1d_update = getattr(causal_conv1d, "causal_conv1d_update", None)
+        causal_conv1d_fn = getattr(causal_conv1d, "causal_conv1d_fn", None)

         global mamba_ssm, selective_state_update, selective_scan_fn, mamba_inner_fn
         mamba_ssm = lazy_load_kernel("mamba-ssm")
-        selective_state_update, selective_scan_fn, mamba_inner_fn = (
-            (mamba_ssm.selective_state_update, mamba_ssm.selective_scan_fn, mamba_ssm.mamba_inner_fn)
-            if mamba_ssm is not None
-            else (None, None, None)
+        selective_state_update = resolve_internal_import(
+            mamba_ssm, chained_path="ops.triton.selective_state_update.selective_state_update"
         )
+        selective_scan_fn = getattr(mamba_ssm, "selective_scan_fn", None)
+        mamba_inner_fn = getattr(mamba_ssm, "mamba_inner_fn", None)

         self.warn_slow_implementation()