diff --git a/cpp/tensorrt_llm/kernels/fusedQKNormRopeKernel.cu b/cpp/tensorrt_llm/kernels/fusedQKNormRopeKernel.cu
index 93a838b09f2..c47a393b2fc 100644
--- a/cpp/tensorrt_llm/kernels/fusedQKNormRopeKernel.cu
+++ b/cpp/tensorrt_llm/kernels/fusedQKNormRopeKernel.cu
@@ -312,6 +312,14 @@ void launchFusedQKNormRope(void* qkv, int const num_tokens, int const num_heads_
                 base, position_ids, num_tokens, factor, low, high, attention_factor);
         });
         break;
+    case 256:
+        DISPATCH_INTERLEAVE(interleave, INTERLEAVE, {
+            fusedQKNormRopeKernel<256, INTERLEAVE><<<gridDim, blockDim, 0, stream>>>(
+                reinterpret_cast<__nv_bfloat16*>(qkv), num_heads_q, num_heads_k, num_heads_v, eps,
+                reinterpret_cast<__nv_bfloat16 const*>(q_weight), reinterpret_cast<__nv_bfloat16 const*>(k_weight),
+                base, position_ids, num_tokens, factor, low, high, attention_factor);
+        });
+        break;
     default: TLLM_THROW("Unsupported head dimension for fusedQKNormRope: %d", head_dim);
     }
 }
diff --git a/examples/models/core/qwen/README.md b/examples/models/core/qwen/README.md
index 65f66e66402..cd216a15e29 100644
--- a/examples/models/core/qwen/README.md
+++ b/examples/models/core/qwen/README.md
@@ -21,13 +21,14 @@ This document shows how to build and run a [Qwen](https://huggingface.co/Qwen) m
 - [Quick start](#quick-start)
   - [Run a single inference](#run-a-single-inference)
   - [Evaluation](#evaluation)
-  - [Model Quantization to FP4](#model-quantization-to-fp4)
+  - [Model Quantization](#model-quantization)
 - [Benchmark](#benchmark)
 - [Serving](#serving)
   - [trtllm-serve](#trtllm-serve)
   - [Disaggregated Serving](#disaggregated-serving)
     - [Eagle3](#eagle3)
-    - [Dynamo](#dynamo)
+  - [Dynamo](#dynamo)
+- [Qwen3-Next](#qwen3-next)
 - [Notes and Troubleshooting](#notes-and-troubleshooting)
 - [Credits](#credits)

@@ -926,6 +927,15 @@ For further details, please refer to [speculative-decoding.md](../../../../docs/

 NVIDIA Dynamo is a high-throughput, low-latency inference framework designed for serving generative AI and reasoning models in multi-node distributed environments. Dynamo supports TensorRT LLM as one of its inference engines. For details on how to use TensorRT LLM with Dynamo, please refer to [LLM Deployment Examples using TensorRT-LLM](https://github.com/ai-dynamo/dynamo/blob/main/examples/tensorrt_llm/README.md).

+## Qwen3-Next
+
+Below is an example command to run the Qwen3-Next model:
+
+```bash
+mpirun -n 1 --allow-run-as-root --oversubscribe python3 examples/llm-api/quickstart_advanced.py --model_dir /Qwen3-Next-80B-A3B-Thinking --kv_cache_fraction 0.6 --disable_kv_cache_reuse --max_batch_size 1 --tp_size 4
+
+```
+
 ## Notes and Troubleshooting

 - **Model Directory:** Update `` with the actual path where the model weights reside.
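The next file diffs register FlashInfer's gemma-style RMSNorm (`gemma_rmsnorm`, `gemma_fused_add_rmsnorm`) as TensorRT-LLM custom ops, which the Qwen3-Next layers later select via `use_gemma_rms_norm=True`. As a reference for what that variant computes, here is a minimal PyTorch sketch (an illustration with hypothetical helper names, not the FlashInfer kernels themselves): the gemma-style norm treats the learned weight as zero-centered and scales the normalized activations by `(1 + weight)` instead of `weight`.

```python
import torch


def rmsnorm_ref(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Standard RMSNorm: normalize by the root-mean-square, then scale by `weight`.
    x32 = x.float()
    rrms = torch.rsqrt(x32.pow(2).mean(dim=-1, keepdim=True) + eps)
    return (x32 * rrms * weight.float()).to(x.dtype)


def gemma_rmsnorm_ref(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Gemma-style RMSNorm: the checkpoint stores zero-centered weights,
    # so the effective scale is (1 + weight).
    x32 = x.float()
    rrms = torch.rsqrt(x32.pow(2).mean(dim=-1, keepdim=True) + eps)
    return (x32 * rrms * (1.0 + weight.float())).to(x.dtype)


if __name__ == "__main__":
    x = torch.randn(4, 16, dtype=torch.bfloat16)
    w = torch.zeros(16, dtype=torch.bfloat16)
    # With zero-centered weights of 0, the gemma variant matches a plain
    # RMSNorm whose weights are all ones.
    assert torch.allclose(gemma_rmsnorm_ref(x, w),
                          rmsnorm_ref(x, torch.ones_like(w)),
                          atol=1e-3)
```

Because the fused all-reduce kernel only implements the standard `weight`-scaled norm, the Qwen3-Next layers below keep the gemma-style norm as a separate op rather than fusing it into the all-reduce.
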
diff --git a/tensorrt_llm/_torch/custom_ops/__init__.py b/tensorrt_llm/_torch/custom_ops/__init__.py index 2c59f627062..5fb2927b014 100644 --- a/tensorrt_llm/_torch/custom_ops/__init__.py +++ b/tensorrt_llm/_torch/custom_ops/__init__.py @@ -22,13 +22,15 @@ if IS_FLASHINFER_AVAILABLE: from .flashinfer_custom_ops import ( flashinfer_apply_rope_with_cos_sin_cache_inplace, - flashinfer_fused_add_rmsnorm, flashinfer_rmsnorm, - flashinfer_silu_and_mul) + flashinfer_fused_add_rmsnorm, flashinfer_gemma_fused_add_rmsnorm, + flashinfer_gemma_rmsnorm, flashinfer_rmsnorm, flashinfer_silu_and_mul) __all__ += [ 'flashinfer_silu_and_mul', 'flashinfer_rmsnorm', 'flashinfer_fused_add_rmsnorm', 'flashinfer_apply_rope_with_cos_sin_cache_inplace', + 'flashinfer_gemma_fused_add_rmsnorm', + 'flashinfer_gemma_rmsnorm', ] if IS_CUTLASS_DSL_AVAILABLE: diff --git a/tensorrt_llm/_torch/custom_ops/flashinfer_custom_ops.py b/tensorrt_llm/_torch/custom_ops/flashinfer_custom_ops.py index 223c9bd5b04..2130b4e255c 100644 --- a/tensorrt_llm/_torch/custom_ops/flashinfer_custom_ops.py +++ b/tensorrt_llm/_torch/custom_ops/flashinfer_custom_ops.py @@ -4,7 +4,8 @@ if IS_FLASHINFER_AVAILABLE: from flashinfer.activation import silu_and_mul - from flashinfer.norm import fused_add_rmsnorm, rmsnorm + from flashinfer.norm import (fused_add_rmsnorm, gemma_fused_add_rmsnorm, + gemma_rmsnorm, rmsnorm) from flashinfer.rope import apply_rope_with_cos_sin_cache_inplace # Warp this into custom op since flashinfer didn't warp it properly and we want to avoid graph break between mlp layer for user buffer optimization @@ -27,6 +28,17 @@ def _(input: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor: return torch.empty_like(input) + @torch.library.custom_op("trtllm::flashinfer_gemma_rmsnorm", + mutates_args=()) + def flashinfer_gemma_rmsnorm(input: torch.Tensor, weight: torch.Tensor, + eps: float) -> torch.Tensor: + return gemma_rmsnorm(input, weight, eps, enable_pdl=ENABLE_PDL) + + @flashinfer_gemma_rmsnorm.register_fake + def _(input: torch.Tensor, weight: torch.Tensor, + eps: float) -> torch.Tensor: + return torch.empty_like(input) + @torch.library.custom_op("trtllm::flashinfer_fused_add_rmsnorm", mutates_args=("input", "residual")) def flashinfer_fused_add_rmsnorm(input: torch.Tensor, @@ -34,6 +46,18 @@ def flashinfer_fused_add_rmsnorm(input: torch.Tensor, weight: torch.Tensor, eps: float) -> None: fused_add_rmsnorm(input, residual, weight, eps, enable_pdl=ENABLE_PDL) + @torch.library.custom_op("trtllm::flashinfer_gemma_fused_add_rmsnorm", + mutates_args=("input", "residual")) + def flashinfer_gemma_fused_add_rmsnorm(input: torch.Tensor, + residual: torch.Tensor, + weight: torch.Tensor, + eps: float) -> None: + gemma_fused_add_rmsnorm(input, + residual, + weight, + eps, + enable_pdl=ENABLE_PDL) + @torch.library.custom_op( "trtllm::flashinfer_apply_rope_with_cos_sin_cache_inplace", mutates_args=("query", "key")) diff --git a/tensorrt_llm/_torch/models/__init__.py b/tensorrt_llm/_torch/models/__init__.py index 56421269228..263f0e162c8 100644 --- a/tensorrt_llm/_torch/models/__init__.py +++ b/tensorrt_llm/_torch/models/__init__.py @@ -26,6 +26,7 @@ from .modeling_qwen2vl import Qwen2_5_VLModel, Qwen2VLModel from .modeling_qwen3 import Qwen3ForCausalLM from .modeling_qwen3_moe import Qwen3MoeForCausalLM +from .modeling_qwen3_next import Qwen3NextForCausalLM from .modeling_qwen_moe import Qwen2MoeForCausalLM from .modeling_seedoss import SeedOssForCausalLM from .modeling_siglip import SiglipVisionModel @@ -66,6 +67,7 @@ 
"Qwen2_5_VLModel", "Qwen3ForCausalLM", "Qwen3MoeForCausalLM", + "Qwen3NextForCausalLM", "GptOssForCausalLM", "SeedOssForCausalLM", ] diff --git a/tensorrt_llm/_torch/models/checkpoints/__init__.py b/tensorrt_llm/_torch/models/checkpoints/__init__.py index 6718e574983..716e4f89ddb 100644 --- a/tensorrt_llm/_torch/models/checkpoints/__init__.py +++ b/tensorrt_llm/_torch/models/checkpoints/__init__.py @@ -8,6 +8,7 @@ from .hf.qwen2_moe_weight_mapper import Qwen2MoeHfWeightMapper from .hf.qwen2vl_weight_mapper import Qwen2VLHfWeightMapper from .hf.qwen3_moe_weight_mapper import Qwen3MoeHfWeightMapper +from .hf.qwen3_next_weight_mapper import Qwen3NextHfWeightMapper from .hf.weight_loader import HfWeightLoader from .hf.weight_mapper import HfWeightMapper @@ -15,5 +16,6 @@ "HfConfigLoader", "HfWeightLoader", "HfWeightMapper", "BaseCheckpointLoader", "HfCheckpointLoader", "NemotronHHfWeightMapper", "Gemma3HfWeightMapper", "MixtralHfWeightMapper", "Llama4HfWeightMapper", - "Qwen2MoeHfWeightMapper", "Qwen3MoeHfWeightMapper", "Qwen2VLHfWeightMapper" + "Qwen2MoeHfWeightMapper", "Qwen3MoeHfWeightMapper", "Qwen2VLHfWeightMapper", + "Qwen3NextHfWeightMapper" ] diff --git a/tensorrt_llm/_torch/models/checkpoints/hf/qwen3_next_weight_mapper.py b/tensorrt_llm/_torch/models/checkpoints/hf/qwen3_next_weight_mapper.py new file mode 100644 index 00000000000..e8df2da4156 --- /dev/null +++ b/tensorrt_llm/_torch/models/checkpoints/hf/qwen3_next_weight_mapper.py @@ -0,0 +1,105 @@ +from typing import Union + +import torch +from torch import nn + +from tensorrt_llm._torch.model_config import ModelConfig +from tensorrt_llm._torch.models.checkpoints.hf.qwen2_moe_weight_mapper import \ + Qwen2MoeHfWeightMapper +from tensorrt_llm._torch.models.modeling_nemotron_h import split +from tensorrt_llm._torch.models.modeling_utils import register_mapper +from tensorrt_llm.models.modeling_utils import DecoderModelForCausalLM + + +@register_mapper("HF", "Qwen3NextForCausalLM") +class Qwen3NextHfWeightMapper(Qwen2MoeHfWeightMapper): + + def init_model_and_config(self, model: Union[nn.Module, + DecoderModelForCausalLM], + config: ModelConfig): + super().init_model_and_config(model, config) + self._num_kv_heads = model.config.num_key_value_heads if hasattr( + model.config, 'num_key_value_heads' + ) and model.config.num_key_value_heads is not None else model.config.num_attention_heads + + def should_skip_module(self, module_name: str) -> bool: + if module_name.startswith("draft_model"): + return True + return super().should_skip_module(module_name) + + def _duplicate_kv_weights(self, module: nn.Module, new_name: str, + weights: dict): + tensors_to_duplicate = ["weight", "bias"] + if module.quant_config.quant_mode.has_nvfp4(): + tensors_to_duplicate.append("weight_scale") + if module.quant_config.quant_mode.has_fp8_block_scales(): + tensors_to_duplicate.append("weight_scale_inv") + + if new_name in ['k_proj', 'v_proj']: + num_kv_heads_list = [self._num_kv_heads + ] * len(weights) if isinstance( + self._num_kv_heads, + int) else self._num_kv_heads + processed_weights = { + k: + self._duplicate_kv(weight=v[:], + num_kv_heads=num_kv_heads_list[i], + tensor_parallel_size=self._tp_size) + if k in tensors_to_duplicate else v + for i, (k, v) in enumerate(weights.items()) + } + return processed_weights + + return weights + + def preprocess_weights(self, weights: dict) -> dict: + config = self.config.pretrained_config + tp_size = self.config.mapping.tp_size + tp_rank = self.config.mapping.tp_rank + + # linear_num_value_heads = 
config.linear_num_value_heads + # linear_num_key_heads = config.linear_num_key_heads + # linear_key_head_dim = config.linear_key_head_dim + # linear_value_head_dim = config.linear_value_head_dim + linear_key_dim = config.linear_key_head_dim * config.linear_num_key_heads # 16 * 128 + linear_value_dim = config.linear_value_head_dim * config.linear_num_value_heads # 32 * 128 + + new_weights = {} + for name, _ in weights.items(): + key = name + + if "A_log" in key: + w = split(weights[name], tp_size, tp_rank) + w = w.to(torch.float32) + new_weights[key] = w + elif "dt_bias" in key: + w = split(weights[name], tp_size, tp_rank) + w = w.to(torch.float32) + new_weights[key] = w + elif "in_proj" in key: + # Don't need to split in_proj weight based on the implementation of reference. + # Need to know the reason. + new_weights[key] = weights[name] + elif "conv1d" in key: + w = weights[name] + # removing dim(1) because we are using Linear to store conv1d weights + if "weight" in key: + w = w.squeeze(1) + + conv_q, conv_k, conv_v = torch.split( + w, [linear_key_dim, linear_key_dim, linear_value_dim], + dim=0) + + w = [] + for rank in range(tp_size): + conv_q_rank = split(conv_q, tp_size, rank) + conv_k_rank = split(conv_k, tp_size, rank) + conv_v_rank = split(conv_v, tp_size, rank) + y = torch.concat([conv_q_rank, conv_k_rank, conv_v_rank]) + w.append(y) + w = torch.concat(w).contiguous() + new_weights[key] = w + else: + new_weights[key] = weights[name] + + return new_weights diff --git a/tensorrt_llm/_torch/models/modeling_qwen3.py b/tensorrt_llm/_torch/models/modeling_qwen3.py index d93423a45c4..730c4071b0a 100644 --- a/tensorrt_llm/_torch/models/modeling_qwen3.py +++ b/tensorrt_llm/_torch/models/modeling_qwen3.py @@ -32,8 +32,12 @@ def __init__( model_config: ModelConfig[Qwen3Config], layer_idx: Optional[int] = None, fuse_qk_norm_rope: bool = True, + attn_output_gate: bool = False, + use_gemma_rms_norm: bool = False, ): config = model_config.pretrained_config + self.pretrained_config = config + self.attn_output_gate = attn_output_gate if getattr(config, "rope_scaling", None) is not None: if "type" in config.rope_scaling: @@ -58,13 +62,15 @@ def __init__( num_attention_heads=config.num_attention_heads, num_key_value_heads=config.num_key_value_heads, max_position_embeddings=config.max_position_embeddings, - bias=config.attention_bias, + bias=getattr(config, "attention_bias", None), pos_embd_params=pos_embd_params, fuse_qk_norm_rope=fuse_qk_norm_rope, layer_idx=layer_idx, dtype=config.torch_dtype, - dense_bias=config.attention_bias, + dense_bias=getattr(config, "attention_bias", None), config=model_config, + attn_output_gate=self.attn_output_gate, + use_gemma_rms_norm=use_gemma_rms_norm, ) diff --git a/tensorrt_llm/_torch/models/modeling_qwen3_next.py b/tensorrt_llm/_torch/models/modeling_qwen3_next.py new file mode 100644 index 00000000000..bf68c340feb --- /dev/null +++ b/tensorrt_llm/_torch/models/modeling_qwen3_next.py @@ -0,0 +1,1519 @@ +# Adapted from https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py +# Adapted from https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/configs/qwen3_next.py +# coding=utf-8 +# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from typing import Dict, List, Optional + +import torch +import torch.nn.functional as F +import triton +import triton.language as tl +from torch import nn +from transformers import AutoConfig +from transformers.configuration_utils import PretrainedConfig +from transformers.modeling_rope_utils import rope_config_validation + +from tensorrt_llm._torch.models.checkpoints.base_weight_mapper import \ + BaseWeightMapper +from tensorrt_llm._torch.modules.fla.chunk import chunk_gated_delta_rule +from tensorrt_llm._torch.modules.fla.fused_sigmoid_gating_recurrent import \ + fused_sigmoid_gating_delta_rule_update +from tensorrt_llm._torch.modules.mamba.mamba2_metadata import Mamba2Metadata +from tensorrt_llm.mapping import Mapping + +from ..attention_backend import AttentionMetadata +from ..distributed import (AllReduce, AllReduceFusionOp, AllReduceParams, + MoEAllReduce, MoEAllReduceParams, allgather) +from ..model_config import ModelConfig +from ..modules.decoder_layer import DecoderLayer +from ..modules.embedding import Embedding +from ..modules.fused_moe import (BaseMoeRoutingMethod, + RenormalizeMoeRoutingMethod, + RenormalizeNaiveMoeRoutingMethod, + RoutingMethodType, TRTLLMGenFusedMoE, + create_moe) +from ..modules.gated_mlp import GatedMLP +from ..modules.linear import Linear, TensorParallelMode +from ..modules.mamba.causal_conv1d import causal_conv1d_fn, causal_conv1d_update +from ..modules.mamba.layernorm_gated import RMSNorm as RMSNormGated +from ..modules.rms_norm import RMSNorm +from ..speculative import SpecMetadata +from ..utils import AuxStreamType +from .modeling_qwen3 import Qwen3Attention +from .modeling_speculative import SpecDecOneEngineForCausalLM +from .modeling_utils import DecoderModel, EagerFusionConfig, register_auto_model + + +def ensure_divisibility(numerator, denominator): + """Ensure that numerator is divisible by the denominator.""" + assert numerator % denominator == 0, "{} is not divisible by {}".format( + numerator, denominator) + + +def divide(numerator, denominator): + """Ensure that numerator is divisible by the denominator and return + the division value.""" + ensure_divisibility(numerator, denominator) + return numerator // denominator + + +class Qwen3NextConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Qwen3NextModel`]. It is used to instantiate a + Qwen3-Next model according to the specified arguments, defining the model architecture. + Instantiating a configuration with the defaults will yield a similar configuration to that of + Qwen3-Next-80B-A3B-Instruct [Qwen/Qwen3-Next-80B-A3B-Instruct](https://huggingface.co/Qwen/Qwen3-Next-80B-A3B-Instruct). + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 151936): + Vocabulary size of the model. Defines the number of different tokens that can be represented by the + `inputs_ids`. 
+ hidden_size (`int`, *optional*, defaults to 2048): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 5632): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 48): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 16): + Number of attention heads for each attention layer in the Transformer encoder. + num_key_value_heads (`int`, *optional*, defaults to 2): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`. + hidden_act (`str`, *optional*, defaults to `"silu"`): + The non-linear activation function in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 32768): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether the model's input and output word embeddings should be tied. + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type + and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value + accordingly. + Expected contents: + `rope_type` (`str`): + The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope', + 'llama3'], with 'default' being the original RoPE implementation. + `factor` (`float`, *optional*): + Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In + most scaling types, a `factor` of x will enable the model to handle sequences of length x * + original maximum pre-trained length. + `original_max_position_embeddings` (`int`, *optional*): + Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during + pretraining. + `attention_factor` (`float`, *optional*): + Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention + computation. If unspecified, it defaults to value recommended by the implementation, using the + `factor` field to infer the suggested value. + `beta_fast` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear + ramp function. If unspecified, it defaults to 32. + `beta_slow` (`float`, *optional*): + Only used with 'yarn'. 
Parameter to set the boundary for interpolation (only) in the linear + ramp function. If unspecified, it defaults to 1. + `short_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to short contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `long_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to long contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `low_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE + `high_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE + partial_rotary_factor (`float`, *optional*, defaults to 0.25): + Percentage of the query and keys which will have rotary embedding. + attention_bias (`bool`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + head_dim (`int`, *optional*, defaults to 256): + Projection weights dimension in multi-head attention. + linear_conv_kernel_dim (`int`, *optional*, defaults to 4): + Kernel size of the convolution used in linear attention layers. + linear_key_head_dim (`int`, *optional*, defaults to 128): + Dimension of each key head in linear attention. + linear_value_head_dim (`int`, *optional*, defaults to 128): + Dimension of each value head in linear attention. + linear_num_key_heads (`int`, *optional*, defaults to 16): + Number of key heads used in linear attention layers. + linear_num_value_heads (`int`, *optional*, defaults to 32): + Number of value heads used in linear attention layers. + decoder_sparse_step (`int`, *optional*, defaults to 1): + The frequency of the MoE layer. + moe_intermediate_size (`int`, *optional*, defaults to 512): + Intermediate size of the routed expert. + shared_expert_intermediate_size (`int`, *optional*, defaults to 512): + Intermediate size of the shared expert. + num_experts_per_tok (`int`, *optional*, defaults to 10): + Number of selected experts. + num_experts (`int`, *optional*, defaults to 512): + Number of routed experts. + norm_topk_prob (`bool`, *optional*, defaults to `True`): + Whether to normalize the topk probabilities. + output_router_logits (`bool`, *optional*, defaults to `False`): + Whether or not the router logits should be returned by the model. Enabling this will also + allow the model to output the auxiliary loss, including load balancing loss and router z-loss. + router_aux_loss_coef (`float`, *optional*, defaults to 0.001): + The aux loss factor for the total loss. + mlp_only_layers (`list[int]`, *optional*, defaults to `[]`): + Indicate which layers use Qwen3NextMLP rather than Qwen3NextSparseMoeBlock + The list contains layer index, from 0 to num_layers-1 if we have num_layers layers + If `mlp_only_layers` is empty, `decoder_sparse_step` is used to determine the sparsity. + layer_types (`list[str]`, *optional*): + Types of each layer (attention or linear). 
+ + ```python + >>> from transformers import Qwen3NextModel, Qwen3NextConfig + + >>> # Initializing a Qwen3Next style configuration + >>> configuration = Qwen3NextConfig() + + >>> # Initializing a model from the Qwen3-Next-80B-A3B style configuration + >>> model = Qwen3NextModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ``` + """ + + model_type = "qwen3_next" + keys_to_ignore_at_inference = ["past_key_values"] + + base_model_tp_plan = { + "layers.*.self_attn.q_proj": "colwise", + "layers.*.self_attn.k_proj": "colwise", + "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.o_proj": "rowwise", + "layers.*.mlp.experts.*.gate_proj": "colwise", + "layers.*.mlp.experts.*.up_proj": "colwise", + "layers.*.mlp.experts.*.down_proj": "rowwise", + "layers.*.mlp.shared_experts.gate_proj": "colwise", + "layers.*.mlp.shared_experts.up_proj": "colwise", + "layers.*.mlp.shared_experts.down_proj": "rowwise", + "layers.*.mlp.gate_proj": "colwise", + "layers.*.mlp.up_proj": "colwise", + "layers.*.mlp.down_proj": "rowwise", + } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } + + def __init__( + self, + vocab_size=151936, + hidden_size=2048, + intermediate_size=5632, + num_hidden_layers=48, + num_attention_heads=16, + num_key_value_heads=2, + hidden_act="silu", + max_position_embeddings=32768, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + partial_rotary_factor=0.25, + attention_bias=False, + attention_dropout=0.0, + head_dim=256, + linear_conv_kernel_dim=4, + linear_key_head_dim=128, + linear_value_head_dim=128, + linear_num_key_heads=16, + linear_num_value_heads=32, + decoder_sparse_step=1, + moe_intermediate_size=512, + shared_expert_intermediate_size=512, + num_experts_per_tok=10, + num_experts=512, + norm_topk_prob=True, + output_router_logits=False, + router_aux_loss_coef=0.001, + mlp_only_layers=[], + layer_types=None, + **kwargs, + ): + super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.partial_rotary_factor = partial_rotary_factor + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.head_dim = head_dim + rope_config_validation(self) + + self.layer_types = layer_types + if self.layer_types is None: + interval_pattern = kwargs.get("full_attention_interval", 4) + self.layer_types = [ + "linear_attention" if bool( + (i + 1) % interval_pattern) else "full_attention" + for i in range(self.num_hidden_layers) + ] + # layer_type_validation(self.layer_types, self.num_hidden_layers) + + # linear attention part + self.linear_conv_kernel_dim = linear_conv_kernel_dim + self.linear_key_head_dim = linear_key_head_dim + self.linear_value_head_dim = linear_value_head_dim + self.linear_num_key_heads = linear_num_key_heads + 
self.linear_num_value_heads = linear_num_value_heads + + # MoE arguments + self.decoder_sparse_step = decoder_sparse_step + self.moe_intermediate_size = moe_intermediate_size + self.shared_expert_intermediate_size = shared_expert_intermediate_size + self.num_experts_per_tok = num_experts_per_tok + self.num_experts = num_experts + self.norm_topk_prob = norm_topk_prob + self.output_router_logits = output_router_logits + self.router_aux_loss_coef = router_aux_loss_coef + self.mlp_only_layers = mlp_only_layers + + +AutoConfig.register("qwen3_next", Qwen3NextConfig) + + +class Qwen3NextGate(nn.Module): + + def __init__( + self, + hidden_size: int, + num_experts: int, + top_k: int, + dtype: Optional[torch.dtype] = None, + apply_routing: bool = False, + routing_method_type: RoutingMethodType = RoutingMethodType.Renormalize, + moe_backend: str = "CUTLASS", + ): + super().__init__() + self.top_k = top_k + self.weight = nn.Parameter(torch.empty((num_experts, hidden_size), + dtype=dtype), + requires_grad=False) + self.routing_method_type = routing_method_type + # FIXME: out_dtype=float32 does not work + # self.out_dtype = torch.float32 if moe_backend == "TRTLLM" else dtype + self.out_dtype = dtype + + assert not apply_routing, "Qwen3NextGate routing is called inside MoE" + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + logits: torch.Tensor = torch.ops.trtllm.cublas_mm( + hidden_states, self.weight.t(), bias=None, out_dtype=self.out_dtype) + return logits + + def load_weights(self, weights: List[Dict]): + assert len(weights) == 1 + + self.weight.copy_(weights[0]["weight"][:]) + + @property + def routing_method(self) -> BaseMoeRoutingMethod: + if self.routing_method_type == RoutingMethodType.RenormalizeNaive: + return RenormalizeNaiveMoeRoutingMethod(top_k=self.top_k) + elif self.routing_method_type == RoutingMethodType.Renormalize: + return RenormalizeMoeRoutingMethod(top_k=self.top_k) + else: + raise ValueError( + f"Unsupported routing method: {self.routing_method_type}") + + +class Qwen3NextSparseMoeBlock(nn.Module): + + def __init__( + self, + model_config: ModelConfig[Qwen3NextConfig], + aux_stream: torch.cuda.Stream, + layer_idx: Optional[int] = None, + ): + super().__init__() + config = model_config.pretrained_config + self.model_config = model_config + self.hidden_dim = config.hidden_size + self.ffn_dim = config.intermediate_size + self.moe_intermediate_size = config.moe_intermediate_size + self.num_experts = config.num_experts + self.top_k = config.num_experts_per_tok + self.enable_attention_dp = model_config.mapping.enable_attention_dp + self.mapping = model_config.mapping + self.allreduce = AllReduce(mapping=model_config.mapping, + strategy=model_config.allreduce_strategy) + + self.gate = Qwen3NextGate( + hidden_size=self.hidden_dim, + num_experts=self.num_experts, + top_k=self.top_k, + dtype=config.torch_dtype, + apply_routing=False, + routing_method_type=RoutingMethodType.Renormalize, + moe_backend=model_config.moe_backend, + ) + + self.experts = create_moe( + num_experts=self.num_experts, + routing_method=self.gate.routing_method, + hidden_size=self.hidden_dim, + intermediate_size=self.moe_intermediate_size, + aux_stream_dict={AuxStreamType.MoeChunkingOverlap: aux_stream}, + dtype=config.torch_dtype, + reduce_results=False, + model_config=model_config, + layer_idx=layer_idx, + ) + + self.shared_expert = GatedMLP( + hidden_size=self.hidden_dim, + intermediate_size=config.shared_expert_intermediate_size, + bias=config.mlp_bias if hasattr(config, 'mlp_bias') else False, 
+ dtype=config.torch_dtype, + config=model_config, + reduce_output=False, + ) + + self.shared_expert_gate = Linear(self.hidden_dim, + 1, + bias=False, + dtype=config.torch_dtype, + quant_config=None) + + def forward( + self, + hidden_states: torch.Tensor, + attn_metadata: AttentionMetadata, + all_reduce_params: Optional[AllReduceParams] = None, + do_finalize: Optional[bool] = True, + ) -> torch.Tensor: + assert hidden_states.shape[-1] == self.hidden_dim + orig_shape = hidden_states.shape + hidden_states = hidden_states.view(-1, self.hidden_dim) + use_dp_padding = False + all_rank_num_tokens = attn_metadata.all_rank_num_tokens + + if not do_finalize: + # TODO: support do_finalize == False + raise NotImplementedError( + "do_finalize == False is not supported yet") + + if self.enable_attention_dp and self.mapping.tp_size > 1: + if isinstance(self.experts, TRTLLMGenFusedMoE): + hidden_states = allgather(hidden_states, + self.mapping, + dim=0, + sizes=all_rank_num_tokens) + + router_logits = self.gate(hidden_states) + final_hidden_states = self.experts( + hidden_states, + router_logits, + all_rank_num_tokens=all_rank_num_tokens, + use_dp_padding=use_dp_padding, + do_finalize=do_finalize, + ) + + if not do_finalize: + return final_hidden_states + + shared_expert_output = self.shared_expert(hidden_states) + shared_expert_output = F.sigmoid( + self.shared_expert_gate(hidden_states)) * shared_expert_output + + final_hidden_states = final_hidden_states + shared_expert_output + + if not self.enable_attention_dp and self.mapping.tp_size > 1: + final_hidden_states = self.allreduce( + final_hidden_states, all_reduce_params=all_reduce_params) + + return final_hidden_states.view(orig_shape) + + +@triton.jit +def fused_qkvzba_split_reshape_cat_kernel( + mixed_qkv, + z, + b, + a, + mixed_qkvz, + mixed_ba, + NUM_HEADS_QK: tl.constexpr, + NUM_HEADS_V: tl.constexpr, + HEAD_QK: tl.constexpr, + HEAD_V: tl.constexpr, +): + i_bs, i_qk = tl.program_id(0), tl.program_id(1) + QKVZ_DIM_T: tl.constexpr = HEAD_QK * 2 + NUM_HEADS_V // NUM_HEADS_QK * HEAD_V * 2 + BA_DIM_T: tl.constexpr = NUM_HEADS_V // NUM_HEADS_QK * 2 + QKV_DIM_T: tl.constexpr = HEAD_QK * 2 + NUM_HEADS_V // NUM_HEADS_QK * HEAD_V + q_end: tl.constexpr = HEAD_QK + blk_q_ptr = (mixed_qkvz + i_bs * NUM_HEADS_QK * QKVZ_DIM_T + + i_qk * QKVZ_DIM_T + tl.arange(0, q_end)) + k_end: tl.constexpr = q_end + HEAD_QK + blk_k_ptr = (mixed_qkvz + i_bs * NUM_HEADS_QK * QKVZ_DIM_T + + i_qk * QKVZ_DIM_T + tl.arange(q_end, k_end)) + v_end: tl.constexpr = k_end + NUM_HEADS_V // NUM_HEADS_QK * HEAD_V + blk_v_ptr = (mixed_qkvz + i_bs * NUM_HEADS_QK * QKVZ_DIM_T + + i_qk * QKVZ_DIM_T + tl.arange(k_end, v_end)) + z_end: tl.constexpr = v_end + NUM_HEADS_V // NUM_HEADS_QK * HEAD_V + blk_z_ptr = (mixed_qkvz + i_bs * NUM_HEADS_QK * QKVZ_DIM_T + + i_qk * QKVZ_DIM_T + tl.arange(v_end, z_end)) + blk_q_st_ptr = (mixed_qkv + i_bs * NUM_HEADS_QK * QKV_DIM_T + + i_qk * HEAD_QK + tl.arange(0, HEAD_QK)) + blk_k_st_ptr = (mixed_qkv + i_bs * NUM_HEADS_QK * QKV_DIM_T + + NUM_HEADS_QK * HEAD_QK + i_qk * HEAD_QK + + tl.arange(0, HEAD_QK)) + blk_v_st_ptr = (mixed_qkv + i_bs * NUM_HEADS_QK * QKV_DIM_T + + NUM_HEADS_QK * HEAD_QK * 2 + + i_qk * HEAD_V * NUM_HEADS_V // NUM_HEADS_QK + + tl.arange(0, HEAD_V * NUM_HEADS_V // NUM_HEADS_QK)) + blk_z_st_ptr = (z + i_bs * NUM_HEADS_V * HEAD_V + + i_qk * HEAD_V * NUM_HEADS_V // NUM_HEADS_QK + + tl.arange(0, HEAD_V * NUM_HEADS_V // NUM_HEADS_QK)) + tl.store(blk_q_st_ptr, tl.load(blk_q_ptr)) + tl.store(blk_k_st_ptr, tl.load(blk_k_ptr)) + tl.store(blk_v_st_ptr, 
tl.load(blk_v_ptr)) + tl.store(blk_z_st_ptr, tl.load(blk_z_ptr)) + b_end: tl.constexpr = NUM_HEADS_V // NUM_HEADS_QK + a_end: tl.constexpr = b_end + NUM_HEADS_V // NUM_HEADS_QK + for i in tl.static_range(b_end): + blk_b_ptr = mixed_ba + i_bs * NUM_HEADS_QK * BA_DIM_T + i_qk * BA_DIM_T + i + blk_b_st_ptr = b + i_bs * NUM_HEADS_V + i_qk * NUM_HEADS_V // NUM_HEADS_QK + i + tl.store(blk_b_st_ptr, tl.load(blk_b_ptr)) + for i in tl.static_range(b_end, a_end): + blk_a_ptr = mixed_ba + i_bs * NUM_HEADS_QK * BA_DIM_T + i_qk * BA_DIM_T + i + blk_a_st_ptr = (a + i_bs * NUM_HEADS_V + + i_qk * NUM_HEADS_V // NUM_HEADS_QK + (i - b_end)) + tl.store(blk_a_st_ptr, tl.load(blk_a_ptr)) + + +def fused_qkvzba_split_reshape_cat( + mixed_qkvz, + mixed_ba, + num_heads_qk, + num_heads_v, + head_qk, + head_v, +): + batch, seq_len = mixed_qkvz.shape[0], 1 + qkv_dim_t = num_heads_qk * head_qk * 2 + num_heads_v * head_v + mixed_qkv = torch.empty( + [batch * seq_len, qkv_dim_t], + dtype=mixed_qkvz.dtype, + device=mixed_qkvz.device, + ) + z = torch.empty( + [batch * seq_len, num_heads_v, head_v], + dtype=mixed_qkvz.dtype, + device=mixed_qkvz.device, + ) + b = torch.empty( + [batch * seq_len, num_heads_v], + dtype=mixed_ba.dtype, + device=mixed_ba.device, + ) + a = torch.empty_like(b) + grid = (batch * seq_len, num_heads_qk) + fused_qkvzba_split_reshape_cat_kernel[grid]( + mixed_qkv, + z, + b, + a, + mixed_qkvz, + mixed_ba, + num_heads_qk, + num_heads_v, + head_qk, + head_v, + num_warps=1, + num_stages=3, + ) + return mixed_qkv, z, b, a + + +# g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias) +@triton.jit +def fused_gdn_gating_kernel( + g, + A_log, + a, + dt_bias, + seq_len, + NUM_HEADS: tl.constexpr, + beta: tl.constexpr, + threshold: tl.constexpr, + BLK_HEADS: tl.constexpr, +): + i_b, i_s, i_d = tl.program_id(0), tl.program_id(1), tl.program_id(2) + head_off = i_d * BLK_HEADS + tl.arange(0, BLK_HEADS) + off = i_b * seq_len * NUM_HEADS + i_s * NUM_HEADS + head_off + mask = head_off < NUM_HEADS + blk_A_log = tl.load(A_log + head_off, mask=mask) + blk_a = tl.load(a + off, mask=mask) + blk_bias = tl.load(dt_bias + head_off, mask=mask) + x = blk_a.to(tl.float32) + blk_bias.to(tl.float32) + softplus_x = tl.where(beta * x <= threshold, + (1 / beta) * tl.log(1 + tl.exp(beta * x)), x) + blk_g = -tl.exp(blk_A_log.to(tl.float32)) * softplus_x + tl.store(g + off, blk_g.to(g.dtype.element_ty), mask=mask) + + +def fused_gdn_gating( + A_log: torch.Tensor, + a: torch.Tensor, + dt_bias: torch.Tensor, + beta: float = 1.0, + threshold: float = 20.0, +) -> torch.Tensor: + batch, num_heads = a.shape + seq_len = 1 + grid = (batch, seq_len, triton.cdiv(num_heads, 8)) + g = torch.empty_like(a, dtype=torch.float32) + fused_gdn_gating_kernel[grid](g, + A_log, + a, + dt_bias, + seq_len, + num_heads, + beta, + threshold, + 8, + num_warps=1) + return g + + +class Qwen3NextGatedDeltaNet(nn.Module): + + def __init__( + self, + model_config: ModelConfig[Qwen3NextConfig], + layer_idx: Optional[int] = None, + ): + super().__init__() + config = model_config.pretrained_config + self.model_config = model_config + self.pretrained_config = config + + # tensor parallel + tp_size = model_config.mapping.tp_size + pp_size = model_config.mapping.pp_size + if model_config.mapping.enable_attention_dp: + tp_size = 1 + + mapping = Mapping( + world_size=tp_size * pp_size, + tp_size=tp_size, + pp_size=pp_size, + rank=model_config.mapping.rank, + gpus_per_node=model_config.mapping.gpus_per_node, + 
enable_attention_dp=model_config.mapping.enable_attention_dp, + ) + self.mapping = mapping + + self.attn_tp_rank = mapping.tp_rank + self.attn_tp_size = mapping.tp_size + self.hidden_size = config.hidden_size + self.num_v_heads = config.linear_num_value_heads + self.num_k_heads = config.linear_num_key_heads + self.head_k_dim = config.linear_key_head_dim + self.head_v_dim = config.linear_value_head_dim + self.key_dim = self.head_k_dim * self.num_k_heads + self.value_dim = self.head_v_dim * self.num_v_heads + + self.conv_kernel_size = config.linear_conv_kernel_dim + self.layer_idx = layer_idx + self.activation = config.hidden_act + self.layer_norm_epsilon = config.rms_norm_eps + + # QKV + self.conv_dim = self.key_dim * 2 + self.value_dim + self.conv1d = Linear( + self.conv_kernel_size, + self.conv_dim, + bias=False, + dtype=config.torch_dtype, + mapping=mapping, + tensor_parallel_mode=TensorParallelMode.COLUMN, + quant_config=model_config.get_quant_config(), + reduce_output=False, + skip_create_weights_in_init=model_config. + skip_create_weights_in_init, + allreduce_strategy=model_config.allreduce_strategy, + force_dynamic_quantization=model_config.force_dynamic_quantization, + use_cute_dsl_blockscaling_mm=False) + + self.in_proj_qkvz = Linear( + self.hidden_size, + self.key_dim * 2 + self.value_dim * 2, + bias=False, + dtype=config.torch_dtype, + mapping=mapping, + tensor_parallel_mode=TensorParallelMode.COLUMN, + quant_config=model_config.get_quant_config(), + reduce_output=False, + skip_create_weights_in_init=model_config. + skip_create_weights_in_init, + allreduce_strategy=model_config.allreduce_strategy, + force_dynamic_quantization=model_config.force_dynamic_quantization, + use_cute_dsl_blockscaling_mm=False) + self.in_proj_ba = Linear( + self.hidden_size, + self.num_v_heads * 2, + bias=False, + dtype=config.torch_dtype, + mapping=mapping, + tensor_parallel_mode=TensorParallelMode.COLUMN, + quant_config=model_config.get_quant_config(), + reduce_output=False, + skip_create_weights_in_init=model_config. + skip_create_weights_in_init, + allreduce_strategy=model_config.allreduce_strategy, + force_dynamic_quantization=model_config.force_dynamic_quantization, + use_cute_dsl_blockscaling_mm=False) + + # time step projection (discretization) + # instantiate once and copy inv_dt in init_weights of PretrainedModel + self.dt_bias = nn.Parameter( + torch.ones( + (self.num_v_heads // self.attn_tp_size), + dtype=torch.float32, + ), + requires_grad=False, + ) + + A = torch.empty(divide(self.num_v_heads, self.attn_tp_size), + dtype=torch.float32).uniform_(0, 16) + self.A_log = nn.Parameter( + torch.log(A), + requires_grad=False, + ) + self.A_log._no_weight_decay = True + + self.norm = RMSNormGated( + self.head_v_dim, + eps=self.layer_norm_epsilon, + group_size=None, + norm_before_gate=True, + device=torch.cuda.current_device(), + dtype=config.torch_dtype, + ) + + # gemmaNorm is not supported in fused_all_reduce kernel. + # So, we need to do allReduce in Linear and do gemmaNorm in separate kernel. + self.out_proj = Linear( + self.value_dim, + self.hidden_size, + bias=False, + dtype=config.torch_dtype, + mapping=mapping, + tensor_parallel_mode=TensorParallelMode.ROW, + quant_config=model_config.get_quant_config(), + reduce_output=True, + skip_create_weights_in_init=model_config. 
+ skip_create_weights_in_init, + allreduce_strategy=model_config.allreduce_strategy, + force_dynamic_quantization=model_config.force_dynamic_quantization, + use_cute_dsl_blockscaling_mm=False) + + def fix_query_key_value_ordering(self, mixed_qkvz, mixed_ba): + """ + Derives `query`, `key` and `value` tensors from `mixed_qkvzba`. + """ + new_tensor_shape_qkvz = mixed_qkvz.size()[:-1] + ( + self.num_k_heads // self.attn_tp_size, + (self.head_k_dim + self.head_k_dim + + (self.head_v_dim + self.head_v_dim) * self.num_v_heads // + self.num_k_heads), + ) + new_tensor_shape_ba = mixed_ba.size()[:-1] + ( + self.num_k_heads // self.attn_tp_size, + 2 * self.num_v_heads // self.num_k_heads, + ) + + mixed_qkvz = mixed_qkvz.view(*new_tensor_shape_qkvz) + mixed_ba = mixed_ba.view(*new_tensor_shape_ba) + + split_arg_list_qkvz = [ + self.head_k_dim, + self.head_k_dim, + (self.num_v_heads // self.num_k_heads * self.head_v_dim), + (self.num_v_heads // self.num_k_heads * self.head_v_dim), + ] + split_arg_list_ba = [ + self.num_v_heads // self.num_k_heads, + self.num_v_heads // self.num_k_heads, + ] + + # [b, sq, ng, (hn + hn + np/ng * hn + np/ng + np/ng)] + # --> [b, sq, ng, hn], [b, sq, ng, hn], [b, sq, ng, np/ng * hn], [b, sq, ng, np/ng * hn], [b, sq, ng, np/ng], [b, sq, ng, np/ng] + (query, key, value, z) = torch.split(mixed_qkvz, + split_arg_list_qkvz, + dim=2) + (b, a) = torch.split(mixed_ba, split_arg_list_ba, dim=2) + + # [b, sq, ng, np/ng * hn] -> [b, sq, np, hn] + value = value.reshape(value.size(0), -1, self.head_v_dim) + z = z.reshape(z.size(0), -1, self.head_v_dim) + b = b.reshape(b.size(0), self.num_v_heads // self.attn_tp_size) + a = a.reshape(a.size(0), self.num_v_heads // self.attn_tp_size) + + return query, key, value, z, b, a + + def forward_decode( + self, + conv_states, + ssm_states, + num_decodes, + cu_seqlens, + **kwargs, + ): + mixed_qkv = kwargs["mixed_qkv"] + a = kwargs["a"] + b = kwargs["b"] + cache_indices = kwargs["cache_indices"] + + query_start_loc = torch.arange(0, + num_decodes + 1, + device=cu_seqlens.device).to(torch.long) + + mixed_qkv = causal_conv1d_update( + mixed_qkv, + conv_states, + self.conv1d.weight, + self.conv1d.bias, + self.activation, + conv_state_indices=cache_indices, + ) + + query, key, value = torch.split( + mixed_qkv, + [ + self.key_dim // self.attn_tp_size, + self.key_dim // self.attn_tp_size, + self.value_dim // self.attn_tp_size, + ], + dim=-1, + ) + # Reshape from [l, h*d] to [1, l, h, d] + seq_len = query.shape[0] + num_heads = query.shape[1] // self.head_k_dim + query = query.view(1, seq_len, num_heads, self.head_k_dim) + key = key.view(1, seq_len, num_heads, self.head_k_dim) + value = value.view(1, seq_len, value.shape[1] // self.head_v_dim, + self.head_v_dim) + + core_attn_out = fused_sigmoid_gating_delta_rule_update( + A_log=self.A_log, + dt_bias=self.dt_bias, + q=query, + k=key, + v=value, + a=a, + b=b, + initial_state_source=ssm_states, + initial_state_indices=cache_indices, + cu_seqlens=query_start_loc, + use_qk_l2norm_in_kernel=True, + softplus_beta=1.0, + softplus_threshold=20.0, + ) + + return core_attn_out + + def forward_extend( + self, + conv_states, + ssm_states, + **kwargs, + ): + mixed_qkv = kwargs["mixed_qkv"] + a = kwargs["a"] + b = kwargs["b"] + batch_size = kwargs["batch_size"] + has_initial_states = kwargs["has_initial_states"][:batch_size] + cache_indices = kwargs["cache_indices"] + query_start_loc = kwargs["query_start_loc"][:batch_size + 1] + num_prefill_tokens = kwargs["num_prefill_tokens"] + num_decode_tokens = 
kwargs["num_decode_tokens"] + state_indices_p = kwargs["state_indices_p"] + state_indices_d = kwargs["state_indices_d"] + num_prefill = kwargs["num_prefill"] + num_decode = kwargs["num_decode"] + + conv_states_to_use = conv_states + + seqlen_split_size = [num_prefill_tokens, num_decode_tokens] + if num_decode_tokens > 0: + mixed_qkv_p, mixed_qkv_d = torch.split(mixed_qkv, + seqlen_split_size, + dim=0) + query_start_loc_p = query_start_loc[:num_prefill + 1] + has_initial_states_p = has_initial_states[:num_prefill] + + mixed_qkv_p = causal_conv1d_fn( + mixed_qkv_p.transpose(0, 1), + self.conv1d.weight, + self.conv1d.bias, + activation=self.activation, + conv_states=conv_states_to_use, + has_initial_state=has_initial_states_p, + cache_indices=state_indices_p, + query_start_loc=query_start_loc_p, + ).transpose(0, 1) + + mixed_qkv_d = causal_conv1d_update( + mixed_qkv_d, + conv_states_to_use, + self.conv1d.weight, + self.conv1d.bias, + activation=self.activation, + conv_state_indices=state_indices_d, + ) + mixed_qkv = torch.cat((mixed_qkv_p, mixed_qkv_d), dim=0) + else: + mixed_qkv = causal_conv1d_fn( + mixed_qkv.transpose(0, 1), + self.conv1d.weight, + self.conv1d.bias, + activation=self.activation, + conv_states=conv_states_to_use, + has_initial_state=has_initial_states, + cache_indices=cache_indices, + query_start_loc=query_start_loc, + ).transpose(0, 1) + + key_split_dim = self.key_dim // self.attn_tp_size + value_split_dim = self.value_dim // self.attn_tp_size + + query, key, value = torch.split( + mixed_qkv, + [key_split_dim, key_split_dim, value_split_dim], + dim=-1, + ) + + actual_seq_len = query.shape[0] + num_heads = query.shape[1] // self.head_k_dim + num_value_heads = value.shape[1] // self.head_v_dim + + query = query.view(1, actual_seq_len, num_heads, self.head_k_dim) + key = key.view(1, actual_seq_len, num_heads, self.head_k_dim) + value = value.view(1, actual_seq_len, num_value_heads, self.head_v_dim) + + beta = b.sigmoid() + g = fused_gdn_gating(self.A_log, a, self.dt_bias) + + g = g.unsqueeze(0) + beta = beta.unsqueeze(0) + + recurrent_state = ssm_states[cache_indices] + + if num_decode > 0: + # TODO set it in mambaCacheManager + decode_query_start_loc = torch.arange( + 1, num_decode + 1, + device=query_start_loc.device) # num_decode δΈͺ + decode_query_start_loc = decode_query_start_loc + query_start_loc[ + num_prefill] + new_query_start_loc = torch.cat( + [query_start_loc[:num_prefill + 1], decode_query_start_loc]) + else: + new_query_start_loc = query_start_loc + + new_query_start_loc = new_query_start_loc.to(torch.long) + core_attn_out, last_recurrent_state = chunk_gated_delta_rule( + q=query, + k=key, + v=value, + g=g, + beta=beta, + initial_state=recurrent_state, + output_final_state=True, + cu_seqlens=new_query_start_loc, + head_first=False, + use_qk_l2norm_in_kernel=True, + ) + last_recurrent_state = last_recurrent_state.to(ssm_states.dtype, + copy=False) + ssm_states[cache_indices] = last_recurrent_state + + return core_attn_out + + def forward( + self, + hidden_states: torch.Tensor, + attn_metadata: AttentionMetadata, + mamba_metadata: Mamba2Metadata, + all_reduce_params: Optional[AllReduceParams] = None, + ): + ### sglang linear attn + # has_initial_states = None + # if forward_batch.extend_prefix_lens is not None: + # has_initial_states = forward_batch.extend_prefix_lens > 0 + + # # Set up dimensions for reshapes later + seq_len, _ = hidden_states.shape + conv_state, recurrent_state = None, None + + ### mamba2_mixer layer + # calculate split size + num_prefills = 
attn_metadata.num_contexts + num_decodes = attn_metadata.seq_lens.shape[0] - num_prefills + num_prefill_tokens = attn_metadata.num_ctx_tokens + num_decode_tokens = attn_metadata.num_tokens - num_prefill_tokens + batch_split_size = [num_prefills, num_decodes] + has_initial_states = mamba_metadata.has_initial_states + + state_indices = attn_metadata.kv_cache_manager.get_state_indices( + )[:num_prefills + num_decodes] + + state_indices_p, state_indices_d = torch.split(state_indices, + batch_split_size) + conv_states = attn_metadata.kv_cache_manager.get_conv_states( + self.layer_idx) + ssm_states = attn_metadata.kv_cache_manager.get_ssm_states( + self.layer_idx) + if num_prefills > 0: + ssm_states[state_indices_p] = 0 + # conv_states[state_indices_p] = 0 # not necessary + + projected_states_qkvz = self.in_proj_qkvz(hidden_states) + projected_states_ba = self.in_proj_ba(hidden_states) + query, key, value, z, b, a = self.fix_query_key_value_ordering( + projected_states_qkvz, projected_states_ba) + + if self.num_v_heads // self.num_k_heads in [1, 2, + 4]: # and is_cuda_graph: + mixed_qkv, z, b, a = fused_qkvzba_split_reshape_cat( + projected_states_qkvz, + projected_states_ba, + triton.cdiv(self.num_k_heads, self.attn_tp_size), + triton.cdiv(self.num_v_heads, self.attn_tp_size), + self.head_k_dim, + self.head_v_dim, + ) + else: + query, key, value, z, b, a = self.fix_query_key_value_ordering( + projected_states_qkvz, projected_states_ba) + query, key, value = map(lambda x: x.reshape(x.shape[0], -1), + (query, key, value)) + mixed_qkv = torch.cat((query, key, value), dim=-1) + + kwargs = { + "mixed_qkv": mixed_qkv, + "a": a, + "b": b, + "z": z, + "has_initial_states": has_initial_states, + "cache_indices": state_indices, + "query_start_loc": mamba_metadata.cu_seqlens, + "batch_size": attn_metadata.seq_lens.shape[0], + "num_prefill_tokens": num_prefill_tokens, + "num_decode_tokens": num_decode_tokens, + "state_indices_p": state_indices_p, + "state_indices_d": state_indices_d, + "num_prefill": num_prefills, + "num_decode": num_decodes, + } + + new_implementation = True + if new_implementation: + if num_prefills > 0: + attn_out = self.forward_extend(conv_states, ssm_states, + **kwargs) + else: + attn_out = self.forward_decode(conv_states, ssm_states, + num_decodes, + mamba_metadata.cu_seqlens, + **kwargs) + + z_shape_og = z.shape + # reshape input data into 2D tensor + attn_out = attn_out.reshape(-1, attn_out.shape[-1]) + z = z.reshape(-1, z.shape[-1]) + attn_out = self.norm(attn_out, z) + attn_out = attn_out.reshape(z_shape_og) + attn_out = attn_out.reshape(*attn_out.shape[:-2], -1) + + output = self.out_proj(attn_out, all_reduce_params=all_reduce_params) + return output + + +class Qwen3NextLinearDecoderLayer(nn.Module): + + def __init__( + self, + model_config: ModelConfig[Qwen3NextConfig], + layer_idx: int, + aux_stream: torch.cuda.Stream, + ): + super().__init__() + self.model_config = model_config + config = model_config.pretrained_config + self.linear_attn = Qwen3NextGatedDeltaNet(model_config, layer_idx) + + self.mapping = model_config.mapping + self.enable_attention_dp = self.mapping.enable_attention_dp + + self.mlp = Qwen3NextSparseMoeBlock(model_config, + aux_stream, + layer_idx=layer_idx) + + use_gemma_rms_norm = True + self.input_layernorm = RMSNorm(hidden_size=config.hidden_size, + eps=config.rms_norm_eps, + dtype=config.torch_dtype, + use_gemma_rms_norm=use_gemma_rms_norm) + + self.post_attention_layernorm = RMSNorm( + hidden_size=config.hidden_size, + eps=config.rms_norm_eps, + 
dtype=config.torch_dtype, + use_gemma_rms_norm=use_gemma_rms_norm) + self.layer_idx = layer_idx + + self.allreduce = AllReduce(mapping=model_config.mapping, + strategy=model_config.allreduce_strategy) + self.next_layer_layernorm: RMSNorm = None + + self.fusion_config = EagerFusionConfig() + ### TODO: enable eager_fusion by default + self.enable_fusion = os.environ.get( + "TRTLLM_QWEN3_EAGER_FUSION_DISABLED", "1") == "0" + self.enable_fusion &= not self.enable_attention_dp + + self.mapping.has_tp() + has_pp = self.mapping.has_pp() + + # self.fusion_config.PRE_MOE_FUSION = self.enable_fusion and has_tp + self.fusion_config.PRE_MOE_FUSION = False # the fusion kernel does not support gemmaNorm yet + self.fusion_config.POST_MOE_FUSION = self.fusion_config.PRE_MOE_FUSION and not has_pp + self.disable_attn_allreduce = (self.fusion_config.PRE_MOE_FUSION + or self.mapping.tp_size == 1 + or self.enable_attention_dp) + self.moe_allreduce = MoEAllReduce(mapping=model_config.mapping) + + def forward( + self, + position_ids: torch.IntTensor, + hidden_states: torch.Tensor, + attn_metadata: AttentionMetadata, + residual: Optional[torch.Tensor], + spec_metadata: Optional[SpecMetadata] = None, + **kwargs, + ) -> torch.Tensor: + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + if spec_metadata is not None and spec_metadata.is_layer_capture( + self.layer_idx): + self.fusion_config.POST_MOE_FUSION = False + # Linear Attention + ### FIXME: 1. forward_batch; 2. allreduce + if hidden_states.shape[0] != 0: + hidden_states = self.linear_attn( + hidden_states, + attn_metadata, + all_reduce_params=AllReduceParams( + enable_allreduce=not (self.fusion_config.PRE_MOE_FUSION + or self.mapping.tp_size == 1)), + **kwargs) + if self.fusion_config.PRE_MOE_FUSION: + hidden_states, residual = self.allreduce( + hidden_states, + all_reduce_params=AllReduceParams( + fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM, + residual=residual, + norm_weight=self.post_attention_layernorm.weight, + eps=self.post_attention_layernorm.variance_epsilon, + enable_allreduce=not (self.fusion_config.PRE_MOE_FUSION + or self.mapping.tp_size == 1), + )) + else: + # No fusion + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + + # Note: this fusion pattern is only supported for TRTLLM-nvfp4 backend now + do_finalize = not (hidden_states.shape[0] + <= self.moe_allreduce.max_token + and self.fusion_config.POST_MOE_FUSION + and self.model_config.moe_backend == 'TRTLLM' + and self.mlp.experts.has_nvfp4) + + hidden_states = self.mlp( + hidden_states, + attn_metadata, + all_reduce_params=AllReduceParams( + enable_allreduce=not (self.fusion_config.POST_MOE_FUSION + or self.mapping.tp_size == 1)), + do_finalize=do_finalize, + ) + if self.fusion_config.POST_MOE_FUSION: + if do_finalize: + hidden_states, residual = self.allreduce( + hidden_states, + all_reduce_params=AllReduceParams( + fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM, + residual=residual, + norm_weight=self.next_layer_layernorm.weight, + eps=self.next_layer_layernorm.variance_epsilon, + )) + else: + assert len( + hidden_states + ) == 3, f"hidden_states must have 3 elements, but got {len(hidden_states)}" + + fc2_output = hidden_states[0] + expert_scale_factor = hidden_states[1] + expanded_idx_to_permuted_idx = hidden_states[2] + + moe_all_reduce_params = MoEAllReduceParams( + expanded_idx_to_permuted_idx=expanded_idx_to_permuted_idx, + expert_scale_factor=expert_scale_factor, + shared_expert_output=None, + 
residual=residual, + norm_weight=self.next_layer_layernorm.weight, + eps=self.next_layer_layernorm.variance_epsilon, + is_cutlass_min_latency=False, + ) + hidden_states, residual = self.moe_allreduce( + fc2_output, all_reduce_params=moe_all_reduce_params) + + else: + if spec_metadata and spec_metadata.is_layer_capture(self.layer_idx): + spec_metadata.maybe_capture_hidden_states( + self.layer_idx, hidden_states, residual) + if self.next_layer_layernorm is not None: + hidden_states, residual = self.next_layer_layernorm( + hidden_states, residual) + return hidden_states, residual + + +class Qwen3NextAttention(Qwen3Attention): + + def __init__(self, model_config: ModelConfig[Qwen3NextConfig], + layer_idx: int, fuse_qk_norm_rope: bool): + super().__init__(model_config, + layer_idx, + fuse_qk_norm_rope=fuse_qk_norm_rope, + attn_output_gate=True, + use_gemma_rms_norm=True) + + +class Qwen3NextFullAttentionDecoderLayer(DecoderLayer): + + def __init__(self, model_config: ModelConfig[Qwen3NextConfig], + layer_idx: int, aux_stream: torch.cuda.Stream): + super().__init__() + self.model_config = model_config + config = model_config.pretrained_config + + self.self_attn = Qwen3NextAttention( + model_config, + layer_idx=layer_idx, + fuse_qk_norm_rope=False, + ) + self.mapping = model_config.mapping + self.enable_attention_dp = self.mapping.enable_attention_dp + + self.mlp = Qwen3NextSparseMoeBlock(model_config, + aux_stream, + layer_idx=layer_idx) + + use_gemma_rms_norm = True + self.input_layernorm = RMSNorm(hidden_size=config.hidden_size, + eps=config.rms_norm_eps, + dtype=config.torch_dtype, + use_gemma_rms_norm=use_gemma_rms_norm) + + self.post_attention_layernorm = RMSNorm( + hidden_size=config.hidden_size, + eps=config.rms_norm_eps, + dtype=config.torch_dtype, + use_gemma_rms_norm=use_gemma_rms_norm) + self.layer_idx = layer_idx + + self.allreduce = AllReduce(mapping=model_config.mapping, + strategy=model_config.allreduce_strategy) + self.next_layer_layernorm: RMSNorm = None + + self.fusion_config = EagerFusionConfig() + self.enable_fusion = os.environ.get( + "TRTLLM_QWEN3_EAGER_FUSION_DISABLED", "0") == "0" + self.enable_fusion &= not self.enable_attention_dp + + self.mapping.has_tp() + has_pp = self.mapping.has_pp() + + # self.fusion_config.PRE_MOE_FUSION = self.enable_fusion and has_tp + self.fusion_config.PRE_MOE_FUSION = False + self.fusion_config.POST_MOE_FUSION = self.fusion_config.PRE_MOE_FUSION and not has_pp + self.disable_attn_allreduce = (self.fusion_config.PRE_MOE_FUSION + or self.mapping.tp_size == 1 + or self.enable_attention_dp) + self.moe_allreduce = MoEAllReduce(mapping=model_config.mapping) + + def forward( + self, + position_ids: torch.IntTensor, + hidden_states: torch.Tensor, + attn_metadata: AttentionMetadata, + residual: Optional[torch.Tensor], + spec_metadata: Optional[SpecMetadata] = None, + **kwargs, + ) -> torch.Tensor: + + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + if spec_metadata is not None and spec_metadata.is_layer_capture( + self.layer_idx): + self.fusion_config.POST_MOE_FUSION = False + + # Self Attention + hidden_states = self.self_attn( + position_ids=position_ids, + hidden_states=hidden_states, + attn_metadata=attn_metadata, + all_reduce_params=AllReduceParams( + enable_allreduce=not self.disable_attn_allreduce), + **kwargs, + ) + + if self.fusion_config.PRE_MOE_FUSION: + hidden_states, residual = self.allreduce( + hidden_states, + all_reduce_params=AllReduceParams( + 
fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM, + residual=residual, + norm_weight=self.post_attention_layernorm.weight, + eps=self.post_attention_layernorm.variance_epsilon, + )) + else: + # No fusion + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + + # Note: this fusion pattern is only supported for TRTLLM-nvfp4 backend now + do_finalize = not (hidden_states.shape[0] + <= self.moe_allreduce.max_token + and self.fusion_config.POST_MOE_FUSION + and self.model_config.moe_backend == 'TRTLLM' + and self.mlp.experts.has_nvfp4) + + hidden_states = self.mlp( + hidden_states, + attn_metadata, + all_reduce_params=AllReduceParams( + enable_allreduce=not (self.fusion_config.POST_MOE_FUSION + or self.mapping.tp_size == 1)), + do_finalize=do_finalize, + ) + + if self.fusion_config.POST_MOE_FUSION: + if do_finalize: + hidden_states, residual = self.allreduce( + hidden_states, + all_reduce_params=AllReduceParams( + fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM, + residual=residual, + norm_weight=self.next_layer_layernorm.weight, + eps=self.next_layer_layernorm.variance_epsilon, + )) + else: + assert len( + hidden_states + ) == 3, f"hidden_states must have 3 elements, but got {len(hidden_states)}" + + fc2_output = hidden_states[0] + expert_scale_factor = hidden_states[1] + expanded_idx_to_permuted_idx = hidden_states[2] + + moe_all_reduce_params = MoEAllReduceParams( + expanded_idx_to_permuted_idx=expanded_idx_to_permuted_idx, + expert_scale_factor=expert_scale_factor, + shared_expert_output=None, + residual=residual, + norm_weight=self.next_layer_layernorm.weight, + eps=self.next_layer_layernorm.variance_epsilon, + is_cutlass_min_latency=False, + ) + hidden_states, residual = self.moe_allreduce( + fc2_output, all_reduce_params=moe_all_reduce_params) + + else: + if spec_metadata and spec_metadata.is_layer_capture(self.layer_idx): + spec_metadata.maybe_capture_hidden_states( + self.layer_idx, hidden_states, residual) + if self.next_layer_layernorm is not None: + hidden_states, residual = self.next_layer_layernorm( + hidden_states, residual) + + return hidden_states, residual + + +ALL_DECODER_LAYER_TYPES = { + "full_attention": Qwen3NextFullAttentionDecoderLayer, + "linear_attention": Qwen3NextLinearDecoderLayer, +} + + +class Qwen3NextModel(DecoderModel): + + def __init__(self, model_config: ModelConfig[Qwen3NextConfig]): + super().__init__(model_config) + config = self.model_config + pretrained_config = self.model_config.pretrained_config + self.aux_stream = torch.cuda.Stream() + self.preload_weight_modules = [] + if config.moe_backend == "TRTLLM": + self.preload_weight_modules = [ + "experts", + "routing_method", + "all_reduce", + ] + + if model_config.mapping.enable_attention_dp: + # When attention_dp is enabled, we cannot do all_reduce since + # the problem size of different ranks are different. + # So, we don't do parallelism here. 
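+            # Each rank therefore keeps a full, unsharded copy of the embedding
+            # table (no tensor_parallel_mode or gather is configured below).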
+ self.embed_tokens = Embedding(pretrained_config.vocab_size, + pretrained_config.hidden_size, + dtype=pretrained_config.torch_dtype) + else: + self.embed_tokens = Embedding( + pretrained_config.vocab_size, + pretrained_config.hidden_size, + dtype=pretrained_config.torch_dtype, + mapping=config.mapping, + tensor_parallel_mode=TensorParallelMode.COLUMN, + gather_output=True, + ) + + self.layers = nn.ModuleList([ + ALL_DECODER_LAYER_TYPES[pretrained_config.layer_types[layer_idx]]( + model_config, + layer_idx, + self.aux_stream, + ) for layer_idx in range(pretrained_config.num_hidden_layers) + ]) + + use_gemma_rms_norm = True + self.norm = RMSNorm( + hidden_size=pretrained_config.hidden_size, + eps=pretrained_config.rms_norm_eps, + dtype=pretrained_config.torch_dtype, + use_gemma_rms_norm=use_gemma_rms_norm, + ) + + self.mamba_metadata: Optional[Mamba2Metadata] = None + + def forward( + self, + attn_metadata: AttentionMetadata, + input_ids: Optional[torch.IntTensor] = None, + position_ids: Optional[torch.IntTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + spec_metadata: Optional[SpecMetadata] = None, + **kwargs, + ) -> torch.Tensor: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) + + if self.mamba_metadata is None or self.mamba_metadata.max_batch_size != attn_metadata.max_num_requests: + self.mamba_metadata = Mamba2Metadata( + attn_metadata.max_num_requests, + # chunk_size=self.model_config.pretrained_config.mamba2_chunk_size) + # TODO check how to get the correct chunk_size + chunk_size=128) + self.mamba_metadata.prepare(attn_metadata) + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + hidden_states = inputs_embeds + residual = None + for decoder_layer in self.layers: + hidden_states, residual = decoder_layer( + position_ids=position_ids, + hidden_states=hidden_states, + attn_metadata=attn_metadata, + residual=residual, + spec_metadata=spec_metadata, + mamba_metadata=self.mamba_metadata) + return hidden_states + + +@register_auto_model("Qwen3NextForCausalLM") +class Qwen3NextForCausalLM(SpecDecOneEngineForCausalLM[Qwen3NextModel, + Qwen3NextConfig]): + + def __init__( + self, + model_config: ModelConfig[Qwen3NextConfig], + ): + super().__init__( + Qwen3NextModel(model_config), + model_config, + ) + self.preload_weight_modules = self.model.preload_weight_modules + + def load_weights(self, weights: dict, weight_mapper: BaseWeightMapper): + new_weights = weight_mapper.preprocess_weights(weights) + super().load_weights(new_weights, weight_mapper) + + for idx, layer in enumerate( + self.model.layers[:self.config.num_hidden_layers]): + if idx == self.config.num_hidden_layers - 1: + layer.next_layer_layernorm = self.model.norm + else: + layer.next_layer_layernorm = self.model.layers[ + idx + 1].input_layernorm diff --git a/tensorrt_llm/_torch/models/modeling_utils.py b/tensorrt_llm/_torch/models/modeling_utils.py index c30ffcb89c1..d46d82daa7e 100755 --- a/tensorrt_llm/_torch/models/modeling_utils.py +++ b/tensorrt_llm/_torch/models/modeling_utils.py @@ -955,6 +955,9 @@ def load_single_module(name, module): module, module_name, module_weights) elif hasattr(module, 'load_weights'): + if "linear_attn.conv1d" in name: + module_weights['weight'] = module_weights[ + 'weight'].squeeze(dim=1) module.load_weights(weights=[module_weights]) else: for n, p in module._parameters.items(): diff --git 
a/tensorrt_llm/_torch/modules/attention.py b/tensorrt_llm/_torch/modules/attention.py index 35e5b221fc4..ff357f369d0 100644 --- a/tensorrt_llm/_torch/modules/attention.py +++ b/tensorrt_llm/_torch/modules/attention.py @@ -117,6 +117,7 @@ def __init__( q_scaling: float = 1.0, attention_chunk_size: Optional[int] = None, disable_deep_gemm: bool = False, + attn_output_gate: Optional[bool] = None, ): """ Initialize the Attention module. @@ -136,6 +137,7 @@ def __init__( q_scaling (float): The scaling factor for the qk_scale. The definition is $O = softmax(QK^T * qk_scale) * V, qk_scale = 1 / (sqrt(head_dim) * q_scaling)$. The default value is 1.0. attention_chunk_size (Optional[int]): See [Chunked Attention] below. disable_deep_gemm (bool): Whether to disable the use of DeepGEMM in Linear layers (currently only matters on SM100 + FP8). + attn_output_gate (Optional[bool]): Determines whether to use an output gate in the attention Op. If False, the decision is automatically handled by the attention backend based on its capabilities. """ super().__init__() self.layer_idx = layer_idx @@ -162,6 +164,11 @@ def __init__( self.pos_embd_params = pos_embd_params self.dense_bias = dense_bias self.q_scaling = q_scaling + self.attn_output_gate = attn_output_gate + + if self.attn_output_gate: + logger.warning_once("using attn output gate!", + key="attn_output_gate") # [Chunked Attention] # Chunked attention is applied to context requests only. Chunked attention will be @@ -207,7 +214,8 @@ def __init__( self.qkv_proj = Linear( self.hidden_size, - tp_size * self.q_size + 2 * tp_size * self.kv_size, + tp_size * self.q_size * (1 + (1 if self.attn_output_gate else 0)) + + 2 * tp_size * self.kv_size, bias=bias, dtype=dtype, mapping=mapping, @@ -511,6 +519,17 @@ def forward( if qkv_lora is not None: qkv = qkv + qkv_lora + if self.attn_output_gate: + q_gate, k, v = qkv.split( + [self.q_size * 2, self.kv_size, self.kv_size], dim=-1) + orig_shape = q_gate.shape[:-1] + q_gate = q_gate.view(*orig_shape, self.num_heads, -1) + q, gate = torch.chunk(q_gate, 2, dim=-1) + q = q.reshape(*orig_shape, -1) + gate = gate.reshape(*orig_shape, -1) + ### TODO: avoid the redundant split and concat + qkv = torch.concat([q, k, v], dim=-1) + q, k, v = qkv, None, None q, k, v = self.apply_rope(q, k, v, position_ids) q, k, v = self.convert_qkv(q, k, v) @@ -518,17 +537,21 @@ def forward( if attention_sinks is not None: assert self.attn_backend == "TRTLLM", "Attention sinks are only supported for TRTLLM backend." - output = self.forward_impl(q, - k, - v, - attn_metadata, - attention_mask, - attention_window_size, - attention_mask_data, - mrope_config=mrope_config, - attention_sinks=attention_sinks) - - attn_output = self.o_proj(output, + attn_output = self.forward_impl(q, + k, + v, + attn_metadata, + attention_mask, + attention_window_size, + attention_mask_data, + mrope_config=mrope_config, + attention_sinks=attention_sinks) + + if self.attn_output_gate: + gate = torch.sigmoid(gate) + attn_output = attn_output * gate + + attn_output = self.o_proj(attn_output, all_reduce_params=all_reduce_params, lora_params=lora_params, layer_idx=self.layer_idx) @@ -551,7 +574,8 @@ def apply_rope(self, q: torch.Tensor, k: Optional[torch.Tensor], """ # If RoPE is fused into the attention OP, do not apply RoPE here. 
if not self.rope_fusion and position_ids is not None: - q, k, v = self.split_qkv(q, k, v) + if k is None and v is None: + q, k, v = self.split_qkv(q, k, v) q, k = self.rotary_emb(position_ids, [q, k]) return q, k, v diff --git a/tensorrt_llm/_torch/modules/fla/__init__.py b/tensorrt_llm/_torch/modules/fla/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tensorrt_llm/_torch/modules/fla/chunk.py b/tensorrt_llm/_torch/modules/fla/chunk.py new file mode 100644 index 00000000000..6f908f74077 --- /dev/null +++ b/tensorrt_llm/_torch/modules/fla/chunk.py @@ -0,0 +1,238 @@ +# Adapted from https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/gated_delta_rule/chunk.py +# Adapted from https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/layers/attention/fla/chunk.py +# -*- coding: utf-8 -*- + +from typing import Optional + +import torch +from einops import rearrange + +from tensorrt_llm._torch.modules.fla.chunk_delta_h import \ + chunk_gated_delta_rule_fwd_h +from tensorrt_llm._torch.modules.fla.chunk_o import chunk_fwd_o +from tensorrt_llm._torch.modules.fla.chunk_scaled_dot_kkt import \ + chunk_scaled_dot_kkt_fwd +from tensorrt_llm._torch.modules.fla.cumsum import chunk_local_cumsum +from tensorrt_llm._torch.modules.fla.l2norm import l2norm_fwd +from tensorrt_llm._torch.modules.fla.solve_tril import solve_tril +from tensorrt_llm._torch.modules.fla.utils import (SUPPRESS_LEVEL, + autocast_custom_fwd, + input_guard) +from tensorrt_llm._torch.modules.fla.wy_fast import recompute_w_u_fwd + + +def chunk_gated_delta_rule_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + scale: float, + initial_state: torch.Tensor, + output_final_state: bool, + cu_seqlens: Optional[torch.LongTensor] = None, +): + g = chunk_local_cumsum(g, chunk_size=64, cu_seqlens=cu_seqlens) + # obtain WY representation. u is actually the new v. 
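+    # A short sketch of what the chunked path below computes (per head; S_t is
+    # the [K, V] recurrent state). Token by token, the gated delta rule is
+    #     S'  = exp(g_t) * S_{t-1}
+    #     u_t = beta_t * (v_t - S'^T k_t)
+    #     S_t = S' + k_t u_t^T
+    #     o_t = scale * S_t^T q_t
+    # The kernels below evaluate this in 64-token chunks via a WY-style
+    # factorization: build the beta-scaled, decay-weighted lower-triangular
+    # K K^T system (chunk_scaled_dot_kkt_fwd), solve the associated triangular
+    # system (solve_tril), form the chunk-local transforms w/u
+    # (recompute_w_u_fwd), propagate the state h across chunks
+    # (chunk_gated_delta_rule_fwd_h), and combine with q (chunk_fwd_o).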
+    A = chunk_scaled_dot_kkt_fwd(k=k,
+                                 beta=beta,
+                                 g_cumsum=g,
+                                 cu_seqlens=cu_seqlens,
+                                 output_dtype=torch.float32)
+    A = solve_tril(A=A, cu_seqlens=cu_seqlens, output_dtype=k.dtype)
+    w, u = recompute_w_u_fwd(
+        k=k,
+        v=v,
+        beta=beta,
+        A=A,
+        g_cumsum=g,
+        cu_seqlens=cu_seqlens,
+    )
+    h, v_new, final_state = chunk_gated_delta_rule_fwd_h(
+        k=k,
+        w=w,
+        u=u,
+        g=g,
+        initial_state=initial_state,
+        output_final_state=output_final_state,
+        cu_seqlens=cu_seqlens,
+    )
+    o = chunk_fwd_o(
+        q=q,
+        k=k,
+        v=v_new,
+        h=h,
+        g=g,
+        scale=scale,
+        cu_seqlens=cu_seqlens,
+    )
+    if SUPPRESS_LEVEL < 3:
+        return g, o, A, final_state, None, None, None
+    else:
+        return g, o, A, final_state, w, h, v_new
+
+
+class ChunkGatedDeltaRuleFunction(torch.autograd.Function):
+
+    @staticmethod
+    @input_guard
+    @autocast_custom_fwd
+    def forward(
+        ctx,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        g: torch.Tensor,
+        beta: torch.Tensor,
+        scale: float,
+        initial_state: torch.Tensor,
+        output_final_state: bool,
+        cu_seqlens: Optional[torch.LongTensor] = None,
+        use_qk_l2norm_in_kernel: bool = False,
+    ):
+        if use_qk_l2norm_in_kernel:
+            q = l2norm_fwd(q)
+            k = l2norm_fwd(k)
+
+        g, o, A, final_state, w, h, v_new = chunk_gated_delta_rule_fwd(
+            q=q,
+            k=k,
+            v=v,
+            g=g,
+            beta=beta,
+            scale=scale,
+            initial_state=initial_state,
+            output_final_state=output_final_state,
+            cu_seqlens=cu_seqlens,
+        )
+        return o.to(q.dtype), final_state
+
+
+@torch.compiler.disable
+def chunk_gated_delta_rule(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    g: torch.Tensor,
+    beta: torch.Tensor,
+    scale: float = None,
+    initial_state: torch.Tensor = None,
+    output_final_state: bool = False,
+    cu_seqlens: Optional[torch.LongTensor] = None,
+    head_first: bool = False,
+    use_qk_l2norm_in_kernel: bool = False,
+):
+    r"""
+    Args:
+        q (torch.Tensor):
+            queries of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
+        k (torch.Tensor):
+            keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
+        v (torch.Tensor):
+            values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
+        g (torch.Tensor):
+            (forget) gating tensor (in log space!) of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`.
+        beta (torch.Tensor):
+            betas of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`.
+        scale (Optional[float]):
+            Scale factor for the attention scores.
+            If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
+        initial_state (Optional[torch.Tensor]):
+            Initial state of shape `[N, H, K, V]` for `N` input sequences.
+            For equal-length input sequences, `N` equals the batch size `B`.
+            Default: `None`.
+        output_final_state (Optional[bool]):
+            Whether to output the final state of shape `[N, H, K, V]`. Default: `False`.
+        cu_seqlens (torch.LongTensor):
+            Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
+            consistent with the FlashAttention API.
+        head_first (Optional[bool]):
+            Whether the inputs are in the head-first format, which is not supported for variable-length inputs.
+            Default: `False`.
+
+    Returns:
+        o (torch.Tensor):
+            Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
+        final_state (torch.Tensor):
+            Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`.
+ + Examples:: + >>> import torch + >>> import torch.nn.functional as F + >>> from einops import rearrange + >>> from fla.ops.gated_delta_rule import chunk_gated_delta_rule + # inputs with equal lengths + >>> B, T, H, K, V = 4, 2048, 4, 512, 512 + >>> q = torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda') + >>> k = F.normalize(torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda'), p=2, dim=-1) + >>> v = torch.randn(B, T, H, V, dtype=torch.bfloat16, device='cuda') + >>> beta = torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda').sigmoid() + >>> g = F.logsigmoid(torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda')) + >>> h0 = torch.randn(B, H, K, V, dtype=torch.bfloat16, device='cuda') + >>> o, ht = chunk_gated_delta_rule( + q, k, v, g, beta, + initial_state=h0, + output_final_state=True + ) + # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required + >>> q, k, v, beta, g = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, beta, g)) + # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected + >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long) + >>> o_var, ht_var = chunk_gated_delta_rule( + q, k, v, g, beta, + initial_state=h0, + output_final_state=True, + cu_seqlens=cu_seqlens + ) + """ + assert q.dtype == k.dtype == v.dtype + assert ( + q.dtype != torch.float32 + ), "ChunkGatedDeltaRuleFunction does not support float32. Please use bfloat16." + assert ( + len(beta.shape) == 3 + ), "beta must be of shape [B, T, H] if head_first=False, or [B, H, T] otherwise." + + if head_first: + raise DeprecationWarning( + "head_first is deprecated and will be removed in a future version. " + "Please use head_first=False for now instead.") + q, k, v, beta, g = map(lambda x: rearrange(x, "b h t ... -> b t h ..."), + (q, k, v, beta, g)) + # if not head_first and q.shape[1] < q.shape[2]: + # warnings.warn( + # f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). " + # "This may indicate the inputs were passed in head-first format [B, H, T, ...] " + # "when head_first=False was specified. " + # "Please verify your input tensor format matches the expected shape [B, T, H, ...]." + # ) + if cu_seqlens is not None: + if q.shape[0] != 1: + raise ValueError( + f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`." + f"Please flatten variable-length inputs before processing.") + if initial_state is not None and initial_state.shape[0] != len( + cu_seqlens) - 1: + raise ValueError( + f"The number of initial states is expected to be equal to the number of input sequences, " + f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}." + ) + if scale is None: + scale = k.shape[-1]**-0.5 + o, final_state = ChunkGatedDeltaRuleFunction.apply( + q, + k, + v, + g, + beta, + scale, + initial_state, + output_final_state, + cu_seqlens, + use_qk_l2norm_in_kernel, + ) + if head_first: + o = rearrange(o, "b t h ... 
-> b h t ...") + return o, final_state diff --git a/tensorrt_llm/_torch/modules/fla/chunk_delta_h.py b/tensorrt_llm/_torch/modules/fla/chunk_delta_h.py new file mode 100644 index 00000000000..873a386ecfb --- /dev/null +++ b/tensorrt_llm/_torch/modules/fla/chunk_delta_h.py @@ -0,0 +1,294 @@ +# Adapted from https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/common/chunk_delta_h.py +# Adapted from https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/layers/attention/fla/chunk_delta_h.py +# -*- coding: utf-8 -*- + +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl + +from tensorrt_llm._torch.modules.fla.index import (prepare_chunk_indices, + prepare_chunk_offsets) +from tensorrt_llm._torch.modules.fla.op import exp, safe_exp +from tensorrt_llm._torch.modules.fla.utils import is_nvidia_hopper + +NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8, 16] + + +@triton.heuristics({ + "USE_G": lambda args: args["g"] is not None, + "USE_INITIAL_STATE": lambda args: args["h0"] is not None, + "STORE_FINAL_STATE": lambda args: args["ht"] is not None, + "SAVE_NEW_VALUE": lambda args: args["v_new"] is not None, + "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, +}) +# @triton.autotune( +# configs=[ +# triton.Config({"BV": BV}, num_warps=num_warps, num_stages=num_stages) +# for num_warps in [2, 4] +# for num_stages in [2, 3, 4] +# for BV in [32, 64] +# ], +# key=["H", "K", "V", "BT", "USE_G"], +# use_cuda_graph=use_cuda_graph, +# ) +@triton.jit(do_not_specialize=["T"]) +def chunk_gated_delta_rule_fwd_kernel_h_blockdim64( + k, + v, + w, + v_new, + g, + h, + h0, + ht, + cu_seqlens, + chunk_offsets, + T, + H: tl.constexpr, + Hg: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BV: tl.constexpr, + USE_G: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + STORE_FINAL_STATE: tl.constexpr, + SAVE_NEW_VALUE: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_v, i_nh = tl.program_id(0), tl.program_id(1) + i_n, i_h = i_nh // H, i_nh % H + if IS_VARLEN: + bos, eos = tl.load(cu_seqlens + i_n).to( + tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + NT = tl.cdiv(T, BT) + boh = tl.load(chunk_offsets + i_n).to(tl.int32) + else: + bos, eos = i_n * T, i_n * T + T + NT = tl.cdiv(T, BT) + boh = i_n * NT + + # [BK, BV] + b_h1 = tl.zeros([64, BV], dtype=tl.float32) + if K > 64: + b_h2 = tl.zeros([64, BV], dtype=tl.float32) + if K > 128: + b_h3 = tl.zeros([64, BV], dtype=tl.float32) + if K > 192: + b_h4 = tl.zeros([64, BV], dtype=tl.float32) + + # calculate offset + h += (boh * H + i_h) * K * V + v += (bos * H + i_h) * V + k += (bos * Hg + i_h // (H // Hg)) * K + w += (bos * H + i_h) * K + if SAVE_NEW_VALUE: + v_new += (bos * H + i_h) * V + stride_v = H * V + stride_h = H * K * V + stride_k = Hg * K + stride_w = H * K + if USE_INITIAL_STATE: + h0 = h0 + i_nh * K * V + if STORE_FINAL_STATE: + ht = ht + i_nh * K * V + + # load initial state + if USE_INITIAL_STATE: + p_h0_1 = tl.make_block_ptr(h0, (K, V), (V, 1), (0, i_v * BV), (64, BV), + (1, 0)) + b_h1 += tl.load(p_h0_1, boundary_check=(0, 1)).to(tl.float32) + if K > 64: + p_h0_2 = tl.make_block_ptr(h0, (K, V), (V, 1), (64, i_v * BV), + (64, BV), (1, 0)) + b_h2 += tl.load(p_h0_2, boundary_check=(0, 1)).to(tl.float32) + if K > 128: + p_h0_3 = tl.make_block_ptr(h0, (K, V), (V, 1), (128, i_v * BV), + (64, BV), (1, 0)) + b_h3 += tl.load(p_h0_3, boundary_check=(0, 1)).to(tl.float32) + if K > 192: + p_h0_4 = tl.make_block_ptr(h0, (K, V), (V, 1), (192, i_v * 
BV), + (64, BV), (1, 0)) + b_h4 += tl.load(p_h0_4, boundary_check=(0, 1)).to(tl.float32) + + # main recurrence + for i_t in range(NT): + p_h1 = tl.make_block_ptr(h + i_t * stride_h, (K, V), (V, 1), + (0, i_v * BV), (64, BV), (1, 0)) + tl.store(p_h1, b_h1.to(p_h1.dtype.element_ty), boundary_check=(0, 1)) + if K > 64: + p_h2 = tl.make_block_ptr(h + i_t * stride_h, (K, V), (V, 1), + (64, i_v * BV), (64, BV), (1, 0)) + tl.store(p_h2, + b_h2.to(p_h2.dtype.element_ty), + boundary_check=(0, 1)) + if K > 128: + p_h3 = tl.make_block_ptr(h + i_t * stride_h, (K, V), (V, 1), + (128, i_v * BV), (64, BV), (1, 0)) + tl.store(p_h3, + b_h3.to(p_h3.dtype.element_ty), + boundary_check=(0, 1)) + if K > 192: + p_h4 = tl.make_block_ptr(h + i_t * stride_h, (K, V), (V, 1), + (192, i_v * BV), (64, BV), (1, 0)) + tl.store(p_h4, + b_h4.to(p_h4.dtype.element_ty), + boundary_check=(0, 1)) + + p_v = tl.make_block_ptr(v, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), + (BT, BV), (1, 0)) + p_v_new = (tl.make_block_ptr(v_new, (T, V), (stride_v, 1), + (i_t * BT, i_v * BV), (BT, BV), + (1, 0)) if SAVE_NEW_VALUE else None) + b_v_new = tl.zeros([BT, BV], dtype=tl.float32) + p_w = tl.make_block_ptr(w, (T, K), (stride_w, 1), (i_t * BT, 0), + (BT, 64), (1, 0)) + b_w = tl.load(p_w, boundary_check=(0, 1)) + b_v_new += tl.dot(b_w, b_h1.to(b_w.dtype)) + if K > 64: + p_w = tl.make_block_ptr(w, (T, K), (stride_w, 1), (i_t * BT, 64), + (BT, 64), (1, 0)) + b_w = tl.load(p_w, boundary_check=(0, 1)) + b_v_new += tl.dot(b_w, b_h2.to(b_w.dtype)) + if K > 128: + p_w = tl.make_block_ptr(w, (T, K), (stride_w, 1), (i_t * BT, 128), + (BT, 64), (1, 0)) + b_w = tl.load(p_w, boundary_check=(0, 1)) + b_v_new += tl.dot(b_w, b_h3.to(b_w.dtype)) + if K > 192: + p_w = tl.make_block_ptr(w, (T, K), (stride_w, 1), (i_t * BT, 192), + (BT, 64), (1, 0)) + b_w = tl.load(p_w, boundary_check=(0, 1)) + b_v_new += tl.dot(b_w, b_h4.to(b_w.dtype)) + b_v_new = -b_v_new + tl.load(p_v, boundary_check=(0, 1)) + + if SAVE_NEW_VALUE: + p_v_new = tl.make_block_ptr(v_new, (T, V), (stride_v, 1), + (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + tl.store(p_v_new, + b_v_new.to(p_v_new.dtype.element_ty), + boundary_check=(0, 1)) + + if USE_G: + last_idx = min((i_t + 1) * BT, T) - 1 + b_g_last = tl.load(g + bos * H + last_idx * H + i_h) + p_g = tl.make_block_ptr(g + bos * H + i_h, (T, ), (H, ), + (i_t * BT, ), (BT, ), (0, )) + b_g = tl.load(p_g, boundary_check=(0, )) + b_v_new = b_v_new * safe_exp(b_g_last - b_g)[:, None] + b_g_last = exp(b_g_last) + b_h1 = b_h1 * b_g_last + if K > 64: + b_h2 = b_h2 * b_g_last + if K > 128: + b_h3 = b_h3 * b_g_last + if K > 192: + b_h4 = b_h4 * b_g_last + b_v_new = b_v_new.to(k.dtype.element_ty) + p_k = tl.make_block_ptr(k, (K, T), (1, stride_k), (0, i_t * BT), + (64, BT), (0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_h1 += tl.dot(b_k, b_v_new) + if K > 64: + p_k = tl.make_block_ptr(k, (K, T), (1, stride_k), (64, i_t * BT), + (64, BT), (0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_h2 += tl.dot(b_k, b_v_new) + if K > 128: + p_k = tl.make_block_ptr(k, (K, T), (1, stride_k), (128, i_t * BT), + (64, BT), (0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_h3 += tl.dot(b_k, b_v_new) + if K > 192: + p_k = tl.make_block_ptr(k, (K, T), (1, stride_k), (192, i_t * BT), + (64, BT), (0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_h4 += tl.dot(b_k, b_v_new) + + # epilogue + if STORE_FINAL_STATE: + p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (0, i_v * BV), (64, BV), + (1, 0)) + tl.store(p_ht, b_h1.to(p_ht.dtype.element_ty), 
boundary_check=(0, 1)) + if K > 64: + p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (64, i_v * BV), + (64, BV), (1, 0)) + tl.store(p_ht, + b_h2.to(p_ht.dtype.element_ty), + boundary_check=(0, 1)) + if K > 128: + p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (128, i_v * BV), + (64, BV), (1, 0)) + tl.store(p_ht, + b_h3.to(p_ht.dtype.element_ty), + boundary_check=(0, 1)) + if K > 192: + p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (192, i_v * BV), + (64, BV), (1, 0)) + tl.store(p_ht, + b_h4.to(p_ht.dtype.element_ty), + boundary_check=(0, 1)) + + +def chunk_gated_delta_rule_fwd_h( + k: torch.Tensor, + w: torch.Tensor, + u: torch.Tensor, + g: Optional[torch.Tensor] = None, + initial_state: Optional[torch.Tensor] = None, + output_final_state: bool = False, + chunk_size: int = 64, # SY: remove this argument and force chunk size 64? + save_new_value: bool = True, + cu_seqlens: Optional[torch.LongTensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + B, T, Hg, K, V = *k.shape, u.shape[-1] + H = u.shape[-2] + BT = chunk_size + + chunk_indices = (prepare_chunk_indices(cu_seqlens, chunk_size) + if cu_seqlens is not None else None) + # N: the actual number of sequences in the batch with either equal or variable lengths + if cu_seqlens is None: + N, NT, chunk_offsets = B, triton.cdiv(T, BT), None + else: + N, NT, chunk_offsets = ( + len(cu_seqlens) - 1, + len(chunk_indices), + prepare_chunk_offsets(cu_seqlens, BT), + ) + assert K <= 256, "current kernel does not support head dimension larger than 256." + + h = k.new_empty(B, NT, H, K, V) + final_state = (k.new_empty(N, H, K, V, dtype=torch.float32) + if output_final_state else None) + + v_new = torch.empty_like(u) if save_new_value else None + + def grid(meta): + return (triton.cdiv(V, meta["BV"]), N * H) + + chunk_gated_delta_rule_fwd_kernel_h_blockdim64[grid]( + k=k, + v=u, + w=w, + v_new=v_new, + g=g, + h=h, + h0=initial_state, + ht=final_state, + cu_seqlens=cu_seqlens, + chunk_offsets=chunk_offsets, + T=T, + H=H, + Hg=Hg, + K=K, + V=V, + BT=BT, + BV=32, + num_warps=4, + num_stages=2, + ) + return h, v_new, final_state diff --git a/tensorrt_llm/_torch/modules/fla/chunk_o.py b/tensorrt_llm/_torch/modules/fla/chunk_o.py new file mode 100644 index 00000000000..61be1051c25 --- /dev/null +++ b/tensorrt_llm/_torch/modules/fla/chunk_o.py @@ -0,0 +1,169 @@ +# Adapted from https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/common/chunk_o.py +# Adapted from https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/layers/attention/fla/chunk_o.py +# -*- coding: utf-8 -*- + +from typing import Optional + +import torch +import triton +import triton.language as tl + +from tensorrt_llm._torch.modules.fla.index import prepare_chunk_indices +from tensorrt_llm._torch.modules.fla.op import exp, safe_exp +from tensorrt_llm._torch.modules.fla.utils import (check_shared_mem, + is_nvidia_hopper) + +BKV_LIST = [64, 128] if check_shared_mem() else [32, 64] +NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8] + + +@triton.heuristics({ + "USE_G": lambda args: args["g"] is not None, + "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, +}) +# @triton.autotune( +# configs=[ +# triton.Config({"BK": BK, "BV": BV}, num_warps=num_warps, num_stages=num_stages) +# for BK in BKV_LIST +# for BV in BKV_LIST +# for num_warps in NUM_WARPS +# for num_stages in [2, 3, 4] +# ], +# key=["H", "K", "V", "BT"], +# ) +@triton.jit(do_not_specialize=["T"]) +def chunk_fwd_kernel_o( + q, + k, + v, + h, + g, + o, + cu_seqlens, + chunk_indices, + scale, + T, + H: 
tl.constexpr, + Hg: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_G: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + + if IS_VARLEN: + i_tg = i_t + i_n, i_t = tl.load(chunk_indices + i_t * 2).to( + tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to( + tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + NT = tl.cdiv(T, BT) + else: + NT = tl.cdiv(T, BT) + i_tg = i_b * NT + i_t + bos, eos = i_b * T, i_b * T + T + + # offset calculation + q += (bos * Hg + i_h // (H // Hg)) * K + k += (bos * Hg + i_h // (H // Hg)) * K + v += (bos * H + i_h) * V + o += (bos * H + i_h) * V + h += (i_tg * H + i_h).to(tl.int64) * K * V + + b_o = tl.zeros([BT, BV], dtype=tl.float32) + b_A = tl.zeros([BT, BT], dtype=tl.float32) + + for i_k in range(tl.cdiv(K, BK)): + p_q = tl.make_block_ptr(q, (T, K), (Hg * K, 1), (i_t * BT, i_k * BK), + (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k, (K, T), (1, Hg * K), (i_k * BK, i_t * BT), + (BK, BT), (0, 1)) + p_h = tl.make_block_ptr(h, (K, V), (V, 1), (i_k * BK, i_v * BV), + (BK, BV), (1, 0)) + # [BT, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + # [BK, BT] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BK, BV] + b_h = tl.load(p_h, boundary_check=(0, 1)) + + # [BT, BK] @ [BK, BV] -> [BT, BV] + b_o += tl.dot(b_q, b_h) + # [BT, BK] @ [BK, BT] -> [BT, BT] + b_A += tl.dot(b_q, b_k) + + if USE_G: + g += bos * H + i_h + p_g = tl.make_block_ptr(g, (T, ), (H, ), (i_t * BT, ), (BT, ), (0, )) + b_g = tl.load(p_g, boundary_check=(0, )) + b_o = b_o * exp(b_g)[:, None] + b_A = b_A * safe_exp(b_g[:, None] - b_g[None, :]) + + o_i = tl.arange(0, BT) + m_A = o_i[:, None] >= o_i[None, :] + b_A = tl.where(m_A, b_A, 0) + + p_v = tl.make_block_ptr(v, (T, V), (H * V, 1), (i_t * BT, i_v * BV), + (BT, BV), (1, 0)) + p_o = tl.make_block_ptr(o, (T, V), (H * V, 1), (i_t * BT, i_v * BV), + (BT, BV), (1, 0)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + + # to fix mma -> mma layout conversion + # already solved by triton v3.2 or higher + b_o = b_o * scale + tl.dot(b_A.to(b_v.dtype), b_v) * scale + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + + +def chunk_fwd_o( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + h: torch.Tensor, + g: Optional[torch.Tensor] = None, # cumsum of log decay + scale: Optional[float] = None, + cu_seqlens: Optional[torch.LongTensor] = None, + chunk_size: int = 64, +) -> torch.Tensor: + B, T, Hg, K, V = *q.shape, v.shape[-1] + H = v.shape[-2] + BT = min(chunk_size, max(16, triton.next_power_of_2(T))) + chunk_indices = (prepare_chunk_indices(cu_seqlens, BT) + if cu_seqlens is not None else None) + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + if scale is None: + scale = k.shape[-1]**-0.5 + + o = torch.empty_like(v) + + def grid(meta): + return (triton.cdiv(V, meta["BV"]), NT, B * H) + + chunk_fwd_kernel_o[grid]( + q, + k, + v, + h, + g, + o, + cu_seqlens, + chunk_indices, + scale, + T=T, + H=H, + Hg=Hg, + K=K, + V=V, + BT=BT, + BK=128, + BV=64, + num_warps=4, + num_stages=2, + ) + return o diff --git a/tensorrt_llm/_torch/modules/fla/chunk_scaled_dot_kkt.py b/tensorrt_llm/_torch/modules/fla/chunk_scaled_dot_kkt.py new file mode 100644 index 00000000000..5a6e271de94 --- /dev/null +++ b/tensorrt_llm/_torch/modules/fla/chunk_scaled_dot_kkt.py @@ -0,0 +1,143 @@ +# Adapted from 
https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/common/chunk_scaled_dot_kkt.py +# Adapted from https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +# -*- coding: utf-8 -*- + +from typing import Optional + +import torch +import triton +import triton.language as tl + +from tensorrt_llm._torch.modules.fla.index import prepare_chunk_indices +from tensorrt_llm._torch.modules.fla.op import safe_exp + + +@triton.heuristics({ + "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, + "USE_G": lambda args: args["g_cumsum"] is not None, +}) +# @triton.autotune( +# configs=[ +# triton.Config({"BK": BK}, num_warps=num_warps, num_stages=num_stages) +# for BK in [32, 64, 128] +# for num_warps in [2, 4, 8] +# for num_stages in [2, 3, 4] +# ], +# key=["H", "K", "BT", "IS_VARLEN"], +# ) +@triton.jit(do_not_specialize=["T"]) +def chunk_scaled_dot_kkt_fwd_kernel( + k, + beta, + g_cumsum, + A, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + Hg: tl.constexpr, + K: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + IS_VARLEN: tl.constexpr, + USE_G: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to( + tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to( + tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + o_t = tl.arange(0, BT) + + p_beta = tl.make_block_ptr(beta + bos * H + i_h, (T, ), (H, ), (i_t * BT, ), + (BT, ), (0, )) + b_beta = tl.load(p_beta, boundary_check=(0, )) + + b_A = tl.zeros([BT, BT], dtype=tl.float32) + for i_k in range(tl.cdiv(K, BK)): + p_k = tl.make_block_ptr( + k + (bos * Hg + i_h // (H // Hg)) * K, + (T, K), + (Hg * K, 1), + (i_t * BT, i_k * BK), + (BT, BK), + (1, 0), + ) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_kb = b_k * b_beta[:, None] + b_A += tl.dot(b_kb.to(b_k.dtype), tl.trans(b_k)) + + if USE_G: + p_g = tl.make_block_ptr(g_cumsum + bos * H + i_h, (T, ), (H, ), + (i_t * BT, ), (BT, ), (0, )) + b_g = tl.load(p_g, boundary_check=(0, )) + b_g_diff = b_g[:, None] - b_g[None, :] + b_A = b_A * safe_exp(b_g_diff) + + b_A = tl.where(o_t[:, None] > o_t[None, :], b_A, 0) + p_A = tl.make_block_ptr(A + (bos * H + i_h) * BT, (T, BT), (BT * H, 1), + (i_t * BT, 0), (BT, BT), (1, 0)) + tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1)) + + +def chunk_scaled_dot_kkt_fwd( + k: torch.Tensor, + beta: torch.Tensor, + g_cumsum: Optional[torch.Tensor] = None, + cu_seqlens: Optional[torch.LongTensor] = None, + chunk_size: int = 64, + output_dtype: torch.dtype = torch.float32, +) -> torch.Tensor: + r""" + Compute beta * K * K^T. + + Args: + k (torch.Tensor): + The key tensor of shape `[B, T, H, K]`. + beta (torch.Tensor): + The beta tensor of shape `[B, T, H]`. + g_cumsum (torch.Tensor): + The cumulative sum of the gate tensor of shape `[B, T, H]`. + Default: None + cu_seqlens (torch.LongTensor): + The cumulative sequence lengths of the input tensor. + Default: None + chunk_size (int): + The chunk size. Default: 64. + output_dtype (torch.dtype): + The dtype of the output tensor. Default: `torch.float32` + + Returns: + beta * K * K^T of shape `[B, T, H, BT]` where `BT` is the chunk size. 
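+
+    Example (an illustrative sketch; shapes are arbitrary and a CUDA device is
+    assumed since the kernel is written in Triton)::
+        >>> import torch
+        >>> B, T, H, K = 1, 128, 4, 128
+        >>> k = torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda')
+        >>> beta = torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda').sigmoid()
+        >>> A = chunk_scaled_dot_kkt_fwd(k, beta, chunk_size=64)
+        >>> A.shape
+        torch.Size([1, 128, 4, 64])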
+ """ + + B, T, Hg, K = k.shape + + H = beta.shape[-1] + BT = chunk_size + chunk_indices = (prepare_chunk_indices(cu_seqlens, BT) + if cu_seqlens is not None else None) + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + A = torch.empty(B, T, H, BT, device=k.device, dtype=output_dtype) + chunk_scaled_dot_kkt_fwd_kernel[(NT, B * H)]( + k=k, + beta=beta, + g_cumsum=g_cumsum, + A=A, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + T=T, + H=H, + Hg=Hg, + K=K, + BT=BT, + BK=64, + num_warps=8, + num_stages=3, + ) + return A diff --git a/tensorrt_llm/_torch/modules/fla/cumsum.py b/tensorrt_llm/_torch/modules/fla/cumsum.py new file mode 100644 index 00000000000..a3bd49dcc01 --- /dev/null +++ b/tensorrt_llm/_torch/modules/fla/cumsum.py @@ -0,0 +1,282 @@ +# Adapt from https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/utils/cumsum.py +# Adapted from https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/layers/attention/fla/cumsum.py +# -*- coding: utf-8 -*- + +from typing import Optional + +import torch +import triton +import triton.language as tl + +from tensorrt_llm._torch.modules.fla.index import prepare_chunk_indices +from tensorrt_llm._torch.modules.fla.utils import check_shared_mem, input_guard + +BS_LIST = [32, 64] if check_shared_mem() else [16, 32] + + +@triton.heuristics({ + "HAS_SCALE": lambda args: args["scale"] is not None, + "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, +}) +# @triton.autotune( +# configs=[triton.Config({}, num_warps=num_warps) for num_warps in [1, 2, 4, 8]], +# key=["B", "H", "BT", "IS_VARLEN", "REVERSE"], +# ) +@triton.jit(do_not_specialize=["T"]) +def chunk_local_cumsum_scalar_kernel( + s, + o, + scale, + cu_seqlens, + chunk_indices, + T, + B: tl.constexpr, + H: tl.constexpr, + BT: tl.constexpr, + REVERSE: tl.constexpr, + HAS_SCALE: tl.constexpr, + IS_VARLEN: tl.constexpr, + HEAD_FIRST: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to( + tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to( + tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + if HEAD_FIRST: + p_s = tl.make_block_ptr(s + bos * H + i_h * T, (T, ), (1, ), + (i_t * BT, ), (BT, ), (0, )) + p_o = tl.make_block_ptr(o + bos * H + i_h * T, (T, ), (1, ), + (i_t * BT, ), (BT, ), (0, )) + else: + p_s = tl.make_block_ptr(s + bos * H + i_h, (T, ), (H, ), (i_t * BT, ), + (BT, ), (0, )) + p_o = tl.make_block_ptr(o + bos * H + i_h, (T, ), (H, ), (i_t * BT, ), + (BT, ), (0, )) + # [BT] + b_s = tl.load(p_s, boundary_check=(0, )).to(tl.float32) + b_o = tl.cumsum(b_s, axis=0) + if REVERSE: + b_z = tl.sum(b_s, axis=0) + b_o = -b_o + b_z[None] + b_s + if HAS_SCALE: + b_o *= scale + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, )) + + +@triton.heuristics({ + "HAS_SCALE": lambda args: args["scale"] is not None, + "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, +}) +@triton.autotune( + configs=[ + triton.Config({"BS": BS}, num_warps=num_warps) for BS in BS_LIST + for num_warps in [2, 4, 8] + ], + key=["B", "H", "S", "BT", "IS_VARLEN", "REVERSE"], +) +@triton.jit(do_not_specialize=["T"]) +def chunk_local_cumsum_vector_kernel( + s, + o, + scale, + cu_seqlens, + chunk_indices, + T, + B: tl.constexpr, + H: tl.constexpr, + S: tl.constexpr, + BT: tl.constexpr, + BS: tl.constexpr, + REVERSE: tl.constexpr, + 
HAS_SCALE: tl.constexpr, + IS_VARLEN: tl.constexpr, + HEAD_FIRST: tl.constexpr, +): + i_s, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to( + tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to( + tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + o_i = tl.arange(0, BT) + if REVERSE: + m_s = tl.where(o_i[:, None] <= o_i[None, :], 1.0, 0.0) + else: + m_s = tl.where(o_i[:, None] >= o_i[None, :], 1.0, 0.0) + + if HEAD_FIRST: + p_s = tl.make_block_ptr( + s + (bos * H + i_h * T) * S, + (T, S), + (S, 1), + (i_t * BT, i_s * BS), + (BT, BS), + (1, 0), + ) + p_o = tl.make_block_ptr( + o + (bos * H + i_h * T) * S, + (T, S), + (S, 1), + (i_t * BT, i_s * BS), + (BT, BS), + (1, 0), + ) + else: + p_s = tl.make_block_ptr( + s + (bos * H + i_h) * S, + (T, S), + (H * S, 1), + (i_t * BT, i_s * BS), + (BT, BS), + (1, 0), + ) + p_o = tl.make_block_ptr( + o + (bos * H + i_h) * S, + (T, S), + (H * S, 1), + (i_t * BT, i_s * BS), + (BT, BS), + (1, 0), + ) + # [BT, BS] + b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32) + b_o = tl.dot(m_s, b_s, allow_tf32=False) + if HAS_SCALE: + b_o *= scale + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + + +def chunk_local_cumsum_scalar( + g: torch.Tensor, + chunk_size: int, + reverse: bool = False, + scale: float = None, + cu_seqlens: Optional[torch.Tensor] = None, + head_first: bool = False, + output_dtype: Optional[torch.dtype] = torch.float, +) -> torch.Tensor: + if head_first: + B, H, T = g.shape + else: + B, T, H = g.shape + assert chunk_size == 2**(chunk_size.bit_length() - + 1), "chunk_size must be a power of 2" + BT = chunk_size + chunk_indices = (prepare_chunk_indices(cu_seqlens, BT) + if cu_seqlens is not None else None) + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + g_org, g = g, torch.empty_like(g, dtype=output_dtype or g.dtype) + grid = (NT, B * H) + chunk_local_cumsum_scalar_kernel[grid]( + s=g_org, + o=g, + scale=scale, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + T=T, + B=B, + H=H, + BT=BT, + HEAD_FIRST=head_first, + REVERSE=reverse, + num_warps=8, + num_stages=3, + ) + return g + + +def chunk_local_cumsum_vector( + g: torch.Tensor, + chunk_size: int, + reverse: bool = False, + scale: float = None, + cu_seqlens: Optional[torch.Tensor] = None, + head_first: bool = False, + output_dtype: Optional[torch.dtype] = torch.float, +) -> torch.Tensor: + if head_first: + B, H, T, S = g.shape + else: + B, T, H, S = g.shape + BT = chunk_size + chunk_indices = (prepare_chunk_indices(cu_seqlens, chunk_size) + if cu_seqlens is not None else None) + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + assert chunk_size == 2**(chunk_size.bit_length() - + 1), "chunk_size must be a power of 2" + + g_org, g = g, torch.empty_like(g, dtype=output_dtype or g.dtype) + + def grid(meta): + return (triton.cdiv(meta["S"], meta["BS"]), NT, B * H) + + # keep cumulative normalizer in fp32 + # this kernel is equivalent to + # g = g.view(B, H, NT, BT, -1).cumsum(-2).view(B, H, T, -1) + chunk_local_cumsum_vector_kernel[grid]( + s=g_org, + o=g, + scale=scale, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + T=T, + B=B, + H=H, + S=S, + BT=BT, + HEAD_FIRST=head_first, + REVERSE=reverse, + ) + return g + + +@input_guard +def chunk_local_cumsum( + g: torch.Tensor, + chunk_size: 
int, + reverse: bool = False, + scale: float = None, + cu_seqlens: Optional[torch.Tensor] = None, + head_first: bool = False, + output_dtype: Optional[torch.dtype] = torch.float, + **kwargs, +) -> torch.Tensor: + if cu_seqlens is not None: + assert (g.shape[0] == 1 + ), "Only batch size 1 is supported when cu_seqlens are provided" + if len(g.shape) == 3: + return chunk_local_cumsum_scalar( + g=g, + chunk_size=chunk_size, + reverse=reverse, + scale=scale, + cu_seqlens=cu_seqlens, + head_first=head_first, + output_dtype=output_dtype, + ) + elif len(g.shape) == 4: + return chunk_local_cumsum_vector( + g=g, + chunk_size=chunk_size, + reverse=reverse, + scale=scale, + cu_seqlens=cu_seqlens, + head_first=head_first, + output_dtype=output_dtype, + ) + else: + raise ValueError(f"Unsupported input shape {g.shape}, " + f"which should be (B, T, H, D) if `head_first=False` " + f"or (B, H, T, D) otherwise") diff --git a/tensorrt_llm/_torch/modules/fla/fused_recurrent.py b/tensorrt_llm/_torch/modules/fla/fused_recurrent.py new file mode 100644 index 00000000000..c93d9e512bd --- /dev/null +++ b/tensorrt_llm/_torch/modules/fla/fused_recurrent.py @@ -0,0 +1,622 @@ +# Adapt from https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/gated_delta_rule/fused_recurrent.py +# Adapted from https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/layers/attention/fla/fused_recurrent.py +# -*- coding: utf-8 -*- + +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl + +from tensorrt_llm._torch.modules.fla.op import exp +from tensorrt_llm._torch.modules.fla.utils import input_guard + + +@triton.heuristics({ + "USE_INITIAL_STATE": lambda args: args["h0"] is not None, + "STORE_FINAL_STATE": lambda args: args["ht"] is not None, + "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, +}) +@triton.jit(do_not_specialize=["T"]) +def fused_recurrent_gated_delta_rule_fwd_kernel( + q, + k, + v, + g, + beta, + o, + h0, + ht, + cu_seqlens, + scale, + T, + B: tl.constexpr, + H: tl.constexpr, + HV: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, # whether to use initial state + STORE_FINAL_STATE: tl.constexpr, # whether to store final state + IS_BETA_HEADWISE: tl. 
+ constexpr, # whether beta is headwise vector or scalar, + USE_QK_L2NORM_IN_KERNEL: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_n, i_hv = i_nh // HV, i_nh % HV + i_h = i_hv // (HV // H) + if IS_VARLEN: + bos, eos = tl.load(cu_seqlens + i_n).to( + tl.int64), tl.load(cu_seqlens + i_n + 1).to(tl.int64) + all = T + T = eos - bos + else: + bos, eos = i_n * T, i_n * T + T + all = B * T + o_k = i_k * BK + tl.arange(0, BK) + o_v = i_v * BV + tl.arange(0, BV) + + p_q = q + (bos * H + i_h) * K + o_k + p_k = k + (bos * H + i_h) * K + o_k + p_v = v + (bos * HV + i_hv) * V + o_v + if IS_BETA_HEADWISE: + p_beta = beta + (bos * HV + i_hv) * V + o_v + else: + p_beta = beta + bos * HV + i_hv + p_g = g + bos * HV + i_hv + p_o = o + ((i_k * all + bos) * HV + i_hv) * V + o_v + + mask_k = o_k < K + mask_v = o_v < V + mask_h = mask_k[:, None] & mask_v[None, :] + + b_h = tl.zeros([BK, BV], dtype=tl.float32) + if USE_INITIAL_STATE: + p_h0 = h0 + i_nh * K * V + o_k[:, None] * V + o_v[None, :] + b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32) + + for _ in range(0, T): + b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32) + b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32) + b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32) + b_g = tl.load(p_g).to(tl.float32) + + if USE_QK_L2NORM_IN_KERNEL: + b_q = b_q / (tl.sqrt(tl.sum(b_q * b_q)) + 1e-6) + b_k = b_k / (tl.sqrt(tl.sum(b_k * b_k)) + 1e-6) + b_q = b_q * scale + # [BK, BV] + b_h *= exp(b_g) + # [BV] + b_v -= tl.sum(b_h * b_k[:, None], 0) + if IS_BETA_HEADWISE: + b_beta = tl.load(p_beta, mask=mask_v, other=0).to(tl.float32) + else: + b_beta = tl.load(p_beta).to(tl.float32) + b_v *= b_beta + # [BK, BV] + b_h += b_k[:, None] * b_v[None, :] + # [BV] + b_o = tl.sum(b_h * b_q[:, None], 0) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v) + + p_q += H * K + p_k += H * K + p_o += HV * V + p_v += HV * V + p_g += HV + p_beta += HV * (V if IS_BETA_HEADWISE else 1) + + if STORE_FINAL_STATE: + p_ht = ht + i_nh * K * V + o_k[:, None] * V + o_v[None, :] + tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h) + + +def fused_recurrent_gated_delta_rule_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + scale: float, + initial_state: torch.Tensor, + output_final_state: bool, + use_qk_l2norm_in_kernel: bool = False, + cu_seqlens: Optional[torch.LongTensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + B, T, H, K, V = *k.shape, v.shape[-1] + HV = v.shape[2] + N = B if cu_seqlens is None else len(cu_seqlens) - 1 + BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), 8) + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) + assert NK == 1, "NK > 1 is not supported yet" + num_stages = 3 + num_warps = 1 + + o = q.new_empty(NK, *v.shape) + if output_final_state: + final_state = q.new_empty(N, HV, K, V, dtype=torch.float32) + else: + final_state = None + + grid = (NK, NV, N * HV) + fused_recurrent_gated_delta_rule_fwd_kernel[grid]( + q=q, + k=k, + v=v, + g=g, + beta=beta, + o=o, + h0=initial_state, + ht=final_state, + cu_seqlens=cu_seqlens, + scale=scale, + T=T, + B=B, + H=H, + HV=HV, + K=K, + V=V, + BK=BK, + BV=BV, + IS_BETA_HEADWISE=beta.ndim == v.ndim, + USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel, + num_warps=num_warps, + num_stages=num_stages, + ) + o = o.squeeze(0) + return o, final_state + + +class FusedRecurrentFunction(torch.autograd.Function): + + @staticmethod + @input_guard + def forward( + 
ctx, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + scale: float, + initial_state: torch.Tensor, + output_final_state: bool, + cu_seqlens: Optional[torch.LongTensor] = None, + use_qk_l2norm_in_kernel: bool = False, + ): + o, final_state = fused_recurrent_gated_delta_rule_fwd( + q=q, + k=k, + v=v, + g=g, + beta=beta, + scale=scale, + initial_state=initial_state, + output_final_state=output_final_state, + use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel, + cu_seqlens=cu_seqlens, + ) + + return o, final_state + + @staticmethod + @input_guard + def backward(ctx, do, dht): + raise NotImplementedError( + "Backward pass is not implemented yet and we do not have plans to implement it " + "because we haven't figured out how to compute dg without materializing the full " + "hidden states for all time steps.") + + +def fused_recurrent_gated_delta_rule( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor = None, + scale: float = None, + initial_state: torch.Tensor = None, + output_final_state: bool = False, + cu_seqlens: Optional[torch.LongTensor] = None, + use_qk_l2norm_in_kernel: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + r""" + Args: + q (torch.Tensor): + queries of shape `[B, T, H, K]`. + k (torch.Tensor): + keys of shape `[B, T, H, K]`. + v (torch.Tensor): + values of shape `[B, T, HV, V]`. + GVA is applied if `HV > H`. + g (torch.Tensor): + g (decays) of shape `[B, T, HV]`. + beta (torch.Tensor): + betas of shape `[B, T, HV]`. + scale (Optional[int]): + Scale factor for the RetNet attention scores. + If not provided, it will default to `1 / sqrt(K)`. Default: `None`. + initial_state (Optional[torch.Tensor]): + Initial state of shape `[N, HV, K, V]` for `N` input sequences. + For equal-length input sequences, `N` equals the batch size `B`. + Default: `None`. + output_final_state (Optional[bool]): + Whether to output the final state of shape `[N, HV, K, V]`. Default: `False`. + cu_seqlens (torch.LongTensor): + Cumulative sequence lengths of shape `[N+1]` used for variable-length training, + consistent with the FlashAttention API. + Returns: + o (torch.Tensor): + Outputs of shape `[B, T, HV, V]`. + final_state (torch.Tensor): + Final state of shape `[N, HV, K, V]` if `output_final_state=True` else `None`. + Examples:: + >>> import torch + >>> import torch.nn.functional as F + >>> from einops import rearrange + >>> from fla.ops.gated_delta_rule import fused_recurrent_gated_delta_rule + # inputs with equal lengths + >>> B, T, H, HV, K, V = 4, 2048, 4, 8, 512, 512 + >>> q = torch.randn(B, T, H, K, device='cuda') + >>> k = F.normalize(torch.randn(B, T, H, K, device='cuda'), p=2, dim=-1) + >>> v = torch.randn(B, T, HV, V, device='cuda') + >>> g = F.logsigmoid(torch.rand(B, T, HV, device='cuda')) + >>> beta = torch.rand(B, T, HV, device='cuda').sigmoid() + >>> h0 = torch.randn(B, HV, K, V, device='cuda') + >>> o, ht = fused_gated_recurrent_delta_rule( + q, k, v, g, beta, + initial_state=h0, + output_final_state=True + ) + # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required + >>> q, k, v, g, beta = map(lambda x: rearrange(x, 'b t ... 
-> 1 (b t) ...'), (q, k, v, g, beta)) + # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected + >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long) + >>> o_var, ht_var = fused_gated_recurrent_delta_rule( + q, k, v, g, beta, + initial_state=h0, + output_final_state=True, + cu_seqlens=cu_seqlens + ) + """ + if cu_seqlens is not None: + if q.shape[0] != 1: + raise ValueError( + f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`." + f"Please flatten variable-length inputs before processing.") + if initial_state is not None and initial_state.shape[0] != len( + cu_seqlens) - 1: + raise ValueError( + f"The number of initial states is expected to be equal to the number of input sequences, " + f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}." + ) + if scale is None: + scale = k.shape[-1]**-0.5 + else: + assert scale > 0, "scale must be positive" + if beta is None: + beta = torch.ones_like(q[..., 0]) + o, final_state = FusedRecurrentFunction.apply( + q, + k, + v, + g, + beta, + scale, + initial_state, + output_final_state, + cu_seqlens, + use_qk_l2norm_in_kernel, + ) + return o, final_state + + +@triton.heuristics({ + "USE_INITIAL_STATE": + lambda args: args["h0_source"] is not None, + "IS_VARLEN": + lambda args: args["cu_seqlens"] is not None, + "CACHE_INTERMEDIATE_STATES": + lambda args: args["intermediate_states_buffer"] is not None, +}) +@triton.jit(do_not_specialize=["T"]) +def fused_recurrent_gated_delta_rule_update_fwd_kernel( + q, + k, + v, + g, + beta, + o, + h0_source, + h0_indices, + cu_seqlens, + scale, + intermediate_states_buffer, + cache_steps, + T, + B: tl.constexpr, + H: tl.constexpr, + HV: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, # whether to use initial state + IS_BETA_HEADWISE: tl. + constexpr, # whether beta is headwise vector or scalar, + USE_QK_L2NORM_IN_KERNEL: tl.constexpr, + IS_VARLEN: tl.constexpr, + DISABLE_STATE_UPDATE: tl.constexpr, # whether to disable final state update + DISABLE_OUTPUT_CALCULATION: tl. 
+ constexpr, # whether to disable output calculation + CACHE_INTERMEDIATE_STATES: tl.constexpr, +): + i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_n, i_hv = i_nh // HV, i_nh % HV + i_h = i_hv // (HV // H) + if IS_VARLEN: + bos, eos = tl.load(cu_seqlens + i_n).to( + tl.int64), tl.load(cu_seqlens + i_n + 1).to(tl.int64) + all = T + T = eos - bos + else: + bos, eos = i_n * T, i_n * T + T + all = B * T + o_k = i_k * BK + tl.arange(0, BK) + o_v = i_v * BV + tl.arange(0, BV) + + p_q = q + (bos * H + i_h) * K + o_k + p_k = k + (bos * H + i_h) * K + o_k + p_v = v + (bos * HV + i_hv) * V + o_v + if IS_BETA_HEADWISE: + p_beta = beta + (bos * HV + i_hv) * V + o_v + else: + p_beta = beta + bos * HV + i_hv + p_g = g + bos * HV + i_hv + p_o = o + ((i_k * all + bos) * HV + i_hv) * V + o_v + + mask_k = o_k < K + mask_v = o_v < V + mask_h = mask_k[:, None] & mask_v[None, :] + + b_h = tl.zeros([BK, BV], dtype=tl.float32) + if USE_INITIAL_STATE: + idx = tl.load(h0_indices + i_n) + # Add bounds checking for idx + if idx >= 0: # Assuming negative indices are invalid + p_h0 = (h0_source + idx * HV * K * V + i_hv * K * V + + o_k[:, None] * V + o_v[None, :]) + b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32) + + # Prepare intermediate state cache variables if enabled + cache_idx = -1 + if CACHE_INTERMEDIATE_STATES: + cache_idx = tl.load(h0_indices + i_n) + + step_idx = 0 + for _ in range(0, T): + b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32) + b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32) + b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32) + b_g = tl.load(p_g).to(tl.float32) + + if USE_QK_L2NORM_IN_KERNEL: + b_q = b_q / (tl.sqrt(tl.sum(b_q * b_q)) + 1e-6) + b_k = b_k / (tl.sqrt(tl.sum(b_k * b_k)) + 1e-6) + b_q = b_q * scale + # [BK, BV] + b_h *= exp(b_g) + # [BV] + b_v -= tl.sum(b_h * b_k[:, None], 0) + if IS_BETA_HEADWISE: + b_beta = tl.load(p_beta, mask=mask_v, other=0).to(tl.float32) + else: + b_beta = tl.load(p_beta).to(tl.float32) + b_v *= b_beta + # [BK, BV] + b_h += b_k[:, None] * b_v[None, :] + # [BV] + if not DISABLE_OUTPUT_CALCULATION: + b_o = tl.sum(b_h * b_q[:, None], 0) + # core attn output + tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v) + + # store intermediate states if enabled + if CACHE_INTERMEDIATE_STATES: + if cache_idx >= 0: + # Compute cache pointer for this step + step_offset = step_idx * HV * K * V + cache_ptr = (intermediate_states_buffer + + cache_idx * cache_steps * HV * K * V + + step_offset + i_hv * K * V + o_k[:, None] * V + + o_v[None, :]) + tl.store(cache_ptr, + b_h.to(cache_ptr.dtype.element_ty), + mask=mask_h) + + step_idx += 1 + + p_q += H * K + p_k += H * K + p_o += HV * V + p_v += HV * V + p_g += HV + p_beta += HV * (V if IS_BETA_HEADWISE else 1) + + # Store final state back to h0_source with bounds checking + # ssm states + if not DISABLE_STATE_UPDATE: + idx = tl.load(h0_indices + i_n) + if idx >= 0: # Add bounds checking + p_h0 = (h0_source + idx * HV * K * V + i_hv * K * V + + o_k[:, None] * V + o_v[None, :]) + tl.store(p_h0, b_h.to(p_h0.dtype.element_ty), mask=mask_h) + + +def fused_recurrent_gated_delta_rule_update_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + scale: float, + initial_state_source: torch.Tensor, + initial_state_indices: torch.Tensor, + use_qk_l2norm_in_kernel: bool = False, + cu_seqlens: Optional[torch.LongTensor] = None, + disable_state_update: bool = False, + disable_output_calculation: bool = False, + 
intermediate_states_buffer: Optional[torch.Tensor] = None, + cache_steps: Optional[int] = None, +) -> torch.Tensor: + B, T, H, K, V = *k.shape, v.shape[-1] + HV = v.shape[2] + N = B if cu_seqlens is None else len(cu_seqlens) - 1 + BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), 8) + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) + assert NK == 1, "NK > 1 is not supported yet" + num_stages = 3 + num_warps = 1 + + if disable_output_calculation: + # When output calculation is disabled, allocate minimal tensor + o = q.new_empty(NK, 1, 1, 1, 1) # minimal allocation + else: + o = q.new_empty(NK, *v.shape) + + grid = (NK, NV, N * HV) + + fused_recurrent_gated_delta_rule_update_fwd_kernel[grid]( + q=q, + k=k, + v=v, + g=g, + beta=beta, + o=o, + h0_source=initial_state_source, + h0_indices=initial_state_indices, + cu_seqlens=cu_seqlens, + scale=scale, + intermediate_states_buffer=intermediate_states_buffer, + cache_steps=0 if cache_steps is None else cache_steps, + T=T, + B=B, + H=H, + HV=HV, + K=K, + V=V, + BK=BK, + BV=BV, + IS_BETA_HEADWISE=beta.ndim == v.ndim, + USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel, + DISABLE_STATE_UPDATE=disable_state_update, + DISABLE_OUTPUT_CALCULATION=disable_output_calculation, + num_warps=num_warps, + num_stages=num_stages, + ) + o = o.squeeze(0) + return o + + +class FusedRecurrentUpdateFunction(torch.autograd.Function): + + @staticmethod + @input_guard + def forward( + ctx, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + scale: float, + initial_state_source: torch.Tensor, + initial_state_indices: torch.Tensor, + cu_seqlens: Optional[torch.LongTensor] = None, + use_qk_l2norm_in_kernel: bool = False, + disable_state_update: bool = False, + disable_output_calculation: bool = False, + intermediate_states_buffer: Optional[torch.Tensor] = None, + cache_steps: Optional[int] = None, + ): + o = fused_recurrent_gated_delta_rule_update_fwd( + q=q, + k=k, + v=v, + g=g, + beta=beta, + scale=scale, + initial_state_source=initial_state_source, + initial_state_indices=initial_state_indices, + use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel, + cu_seqlens=cu_seqlens, + disable_state_update=disable_state_update, + disable_output_calculation=disable_output_calculation, + intermediate_states_buffer=intermediate_states_buffer, + cache_steps=cache_steps, + ) + + return o + + @staticmethod + @input_guard + def backward(ctx, do, dht): + raise NotImplementedError( + "Backward pass is not implemented yet and we do not have plans to implement it " + "because we haven't figured out how to compute dg without materializing the full " + "hidden states for all time steps.") + + +def fused_recurrent_gated_delta_rule_update( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor = None, + scale: float = None, + initial_state_source: torch.Tensor = None, + initial_state_indices: torch.Tensor = None, + cu_seqlens: Optional[torch.LongTensor] = None, + use_qk_l2norm_in_kernel: bool = False, + disable_state_update: bool = False, + disable_output_calculation: bool = False, + intermediate_states_buffer: Optional[torch.Tensor] = None, + cache_steps: Optional[int] = None, +) -> torch.Tensor: + if cu_seqlens is not None: + if q.shape[0] != 1: + raise ValueError( + f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`." 
+ f"Please flatten variable-length inputs before processing.") + if (initial_state_source is not None + and initial_state_indices.shape[0] != len(cu_seqlens) - 1): + raise ValueError( + f"The number of initial states is expected to be equal to the number of input sequences, " + f"i.e., {len(cu_seqlens) - 1} rather than {initial_state_indices.shape[0]}." + ) + if scale is None: + scale = k.shape[-1]**-0.5 + else: + assert scale > 0, "scale must be positive" + if beta is None: + beta = torch.ones_like(q[..., 0]) + o = FusedRecurrentUpdateFunction.apply( + q, + k, + v, + g, + beta, + scale, + initial_state_source, + initial_state_indices, + cu_seqlens, + use_qk_l2norm_in_kernel, + disable_state_update, + disable_output_calculation, + intermediate_states_buffer, + cache_steps, + ) + return o diff --git a/tensorrt_llm/_torch/modules/fla/fused_sigmoid_gating_recurrent.py b/tensorrt_llm/_torch/modules/fla/fused_sigmoid_gating_recurrent.py new file mode 100644 index 00000000000..70589b762de --- /dev/null +++ b/tensorrt_llm/_torch/modules/fla/fused_sigmoid_gating_recurrent.py @@ -0,0 +1,222 @@ +# Adapted from https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py + +from typing import Optional + +import torch +import triton +import triton.language as tl + +from tensorrt_llm._torch.modules.fla.utils import input_guard + + +@triton.heuristics({ + "USE_INITIAL_STATE": lambda args: args["h0_source"] is not None, + "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, +}) +@triton.jit(do_not_specialize=["T"]) +def fused_sigmoid_gating_delta_rule_update_kernel( + A_log, + a, + dt_bias, + softplus_beta, + softplus_threshold, + q, + k, + v, + b, + o, + h0_source, + h0_indices, + cu_seqlens, + scale, + T, + B: tl.constexpr, + H: tl.constexpr, + HV: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + USE_QK_L2NORM_IN_KERNEL: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + """ + Fused kernel that combines sigmoid gating computation with recurrent delta rule update. 
+ """ + i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_n, i_hv = i_nh // HV, i_nh % HV + i_h = i_hv // (HV // H) + + if IS_VARLEN: + bos, eos = ( + tl.load(cu_seqlens + i_n).to(tl.int64), + tl.load(cu_seqlens + i_n + 1).to(tl.int64), + ) + all = T + T = eos - bos + else: + bos, eos = i_n * T, i_n * T + T + all = B * T + + o_k = i_k * BK + tl.arange(0, BK) + o_v = i_v * BV + tl.arange(0, BV) + + p_q = q + (bos * H + i_h) * K + o_k + p_k = k + (bos * H + i_h) * K + o_k + p_v = v + (bos * HV + i_hv) * V + o_v + p_b = b + bos * HV + i_hv + p_o = o + ((i_k * all + bos) * HV + i_hv) * V + o_v + + # Gating computation pointers + p_A_log = A_log + i_hv + p_a = a + bos * HV + i_hv + p_dt_bias = dt_bias + i_hv + + mask_k = o_k < K + mask_v = o_v < V + mask_h = mask_k[:, None] & mask_v[None, :] + + b_h = tl.zeros([BK, BV], dtype=tl.float32) + if USE_INITIAL_STATE: + idx = tl.load(h0_indices + i_n) + if idx >= 0: + p_h0 = (h0_source + idx * HV * K * V + i_hv * K * V + + o_k[:, None] * V + o_v[None, :]) + b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32) + + for _ in range(0, T): + # Load inputs + b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32) + b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32) + b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32) + b_b = tl.load(p_b).to(tl.float32) + + # Compute sigmoid gating + # Load gating parameters + b_A_log = tl.load(p_A_log).to(tl.float32) + b_a = tl.load(p_a).to(tl.float32) + b_dt_bias = tl.load(p_dt_bias).to(tl.float32) + + # Compute g = -exp(A_log) * softplus(a + dt_bias) + x = b_a + b_dt_bias + beta_x = softplus_beta * x + # Apply softplus with numerical stability + softplus_x = tl.where( + beta_x <= softplus_threshold, + (1.0 / softplus_beta) * tl.log(1.0 + tl.exp(beta_x)), + x, + ) + b_g = -tl.exp(b_A_log) * softplus_x + + # Compute beta = sigmoid(b) + b_beta = 1.0 / (1.0 + tl.exp(-b_b)) + + # Apply L2 normalization if enabled + if USE_QK_L2NORM_IN_KERNEL: + b_q = b_q / (tl.sqrt(tl.sum(b_q * b_q)) + 1e-6) + b_k = b_k / (tl.sqrt(tl.sum(b_k * b_k)) + 1e-6) + + b_q = b_q * scale + + # Apply gating to hidden state: h *= exp(g) + b_h *= tl.exp(b_g) + + # Delta rule: v -= sum(h * k, dim=0) + b_v -= tl.sum(b_h * b_k[:, None], 0) + + # Apply beta gating: v *= beta + b_v *= b_beta + + # Update hidden state: h += k[:, None] * v[None, :] + b_h += b_k[:, None] * b_v[None, :] + + # Compute output: o = sum(h * q, dim=0) + b_o = tl.sum(b_h * b_q[:, None], 0) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v) + + # Update pointers for next timestep + p_q += H * K + p_k += H * K + p_o += HV * V + p_v += HV * V + p_b += HV + p_a += HV + + # Store final state back to h0_source with bounds checking + if USE_INITIAL_STATE: + idx = tl.load(h0_indices + i_n) + if idx >= 0: + p_h0 = (h0_source + idx * HV * K * V + i_hv * K * V + + o_k[:, None] * V + o_v[None, :]) + tl.store(p_h0, b_h.to(p_h0.dtype.element_ty), mask=mask_h) + + +@input_guard +def fused_sigmoid_gating_delta_rule_update( + A_log: torch.Tensor, + a: torch.Tensor, + dt_bias: torch.Tensor, + softplus_beta: float, + softplus_threshold: float, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + b: torch.Tensor, + initial_state_source: torch.Tensor, + initial_state_indices: torch.Tensor, + scale: Optional[float] = None, + use_qk_l2norm_in_kernel: bool = False, + cu_seqlens: Optional[torch.Tensor] = None, +): + """ + Fused triton implementation of sigmoid gating delta rule update. 
+ This function uses a single fused kernel that combines both sigmoid gating computation + and the recurrent delta rule update for better performance. + """ + B, T, H, K, V = *k.shape, v.shape[-1] + HV = v.shape[2] + N = B if cu_seqlens is None else len(cu_seqlens) - 1 + BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), 8) + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) + assert NK == 1, "NK > 1 is not supported yet" + num_stages = 3 + num_warps = 1 + + if scale is None: + scale = k.shape[-1]**-0.5 + else: + assert scale > 0, "scale must be positive" + + o = q.new_empty(NK, *v.shape) + grid = (NK, NV, N * HV) + + fused_sigmoid_gating_delta_rule_update_kernel[grid]( + A_log=A_log, + a=a, + dt_bias=dt_bias, + softplus_beta=softplus_beta, + softplus_threshold=softplus_threshold, + q=q, + k=k, + v=v, + b=b, + o=o, + h0_source=initial_state_source, + h0_indices=initial_state_indices, + cu_seqlens=cu_seqlens, + scale=scale, + T=T, + B=B, + H=H, + HV=HV, + K=K, + V=V, + BK=BK, + BV=BV, + USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel, + num_warps=num_warps, + num_stages=num_stages, + ) + o = o.squeeze(0) + return o diff --git a/tensorrt_llm/_torch/modules/fla/index.py b/tensorrt_llm/_torch/modules/fla/index.py new file mode 100644 index 00000000000..d320fdf0d07 --- /dev/null +++ b/tensorrt_llm/_torch/modules/fla/index.py @@ -0,0 +1,32 @@ +# Adapt from https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/utils/index.py +# Adapted from https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/layers/attention/fla/index.py +# -*- coding: utf-8 -*- + +import torch +import triton + +from tensorrt_llm._torch.modules.fla.utils import tensor_cache + + +@tensor_cache +def prepare_lens(cu_seqlens: torch.LongTensor) -> torch.LongTensor: + return cu_seqlens[1:] - cu_seqlens[:-1] + + +@tensor_cache +def prepare_chunk_indices(cu_seqlens: torch.LongTensor, + chunk_size: int) -> torch.LongTensor: + indices = torch.cat([ + torch.arange(n) + for n in triton.cdiv(prepare_lens(cu_seqlens), chunk_size).tolist() + ]) + return torch.stack([indices.eq(0).cumsum(0) - 1, indices], 1).to(cu_seqlens) + + +@tensor_cache +def prepare_chunk_offsets(cu_seqlens: torch.LongTensor, + chunk_size: int) -> torch.LongTensor: + return torch.cat([ + cu_seqlens.new_tensor([0]), + triton.cdiv(prepare_lens(cu_seqlens), chunk_size) + ]).cumsum(-1) diff --git a/tensorrt_llm/_torch/modules/fla/l2norm.py b/tensorrt_llm/_torch/modules/fla/l2norm.py new file mode 100644 index 00000000000..01abe7cfea5 --- /dev/null +++ b/tensorrt_llm/_torch/modules/fla/l2norm.py @@ -0,0 +1,152 @@ +# Adapt from https://github.com/fla-org/flash-linear-attention/blob/main/fla/modules/l2norm.py +# Adapted from https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/layers/attention/fla/l2norm.py +# -*- coding: utf-8 -*- + +from typing import Optional + +import torch +import torch.nn as nn +import triton +import triton.language as tl + +from tensorrt_llm._torch.modules.fla.utils import input_guard + +BT_LIST = [8, 16, 32, 64, 128] + + +# @triton.autotune( +# configs=[ +# triton.Config({}, num_warps=num_warps) for num_warps in [1, 2, 4, 8, 16, 32] +# ], +# key=["D"], +# ) +@triton.jit +def l2norm_fwd_kernel1( + x, + y, + D, + BD: tl.constexpr, + eps, +): + i_t = tl.program_id(0) + x += i_t * D + y += i_t * D + # Compute mean and variance + cols = tl.arange(0, BD) + mask = cols < D + b_x = tl.load(x + cols, mask=mask, other=0.0).to(tl.float32) + b_var = tl.sum(b_x * b_x, axis=0) + b_rstd = 1 / tl.sqrt(b_var + eps) + # 
tl.store(Rstd + i_t, rstd) + # Normalize and apply linear transformation + b_y = b_x * b_rstd + tl.store(y + cols, b_y, mask=mask) + + +# @triton.autotune( +# configs=[ +# triton.Config({"BT": BT}, num_warps=num_warps) +# for num_warps in [1, 2, 4, 8, 16] +# for BT in BT_LIST +# ], +# key=["D", "NB"], +# ) +@triton.jit +def l2norm_fwd_kernel( + x, + y, + eps, + NB: tl.constexpr, + T: tl.constexpr, + D: tl.constexpr, + BT: tl.constexpr, + BD: tl.constexpr, +): + i_t = tl.program_id(0) + p_x = tl.make_block_ptr(x, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0)) + b_x = tl.load(p_x, boundary_check=(0, 1)).to(tl.float32) + b_var = tl.sum(b_x * b_x, axis=1) + b_y = b_x / tl.sqrt(b_var + eps)[:, None] + p_y = tl.make_block_ptr(y, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0)) + tl.store(p_y, b_y.to(p_y.dtype.element_ty), boundary_check=(0, 1)) + + +def l2norm_fwd(x: torch.Tensor, + eps: float = 1e-6, + output_dtype: Optional[torch.dtype] = None): + x_shape_og = x.shape + x = x.view(-1, x.shape[-1]) + # allocate output + if output_dtype is None: + y = torch.empty_like(x) + else: + y = torch.empty_like(x, dtype=output_dtype) + assert y.stride(-1) == 1 + T, D = x.shape[0], x.shape[-1] + # rstd = torch.empty((T,), dtype=torch.float32, device=x.device) + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // x.element_size() + BD = min(MAX_FUSED_SIZE, triton.next_power_of_2(D)) + if D > BD: + raise RuntimeError("This layer doesn't support feature dim >= 64KB.") + + if D <= 512: + NB = triton.cdiv(T, 2048) + + def grid(meta): + return (triton.cdiv(T, meta["BT"]), ) + + l2norm_fwd_kernel[grid]( + x, + y, + eps, + NB=NB, + T=T, + D=D, + BD=BD, + BT=16, + num_warps=8, + num_stages=3, + ) + else: + l2norm_fwd_kernel1[(T, )]( + x, + y, + eps=eps, + D=D, + BD=BD, + num_warps=8, + num_stages=3, + ) + + return y.view(x_shape_og) + + +class L2NormFunction(torch.autograd.Function): + + @staticmethod + @input_guard + def forward(ctx, x, eps=1e-6, output_dtype=None): + return l2norm_fwd(x, eps, output_dtype) + + +def l2norm(x: torch.Tensor, + eps: float = 1e-6, + output_dtype: Optional[torch.dtype] = None) -> torch.Tensor: + return L2NormFunction.apply(x, eps, output_dtype) + + +l2_norm = l2norm + + +class L2Norm(nn.Module): + + def __init__(self, + eps: float = 1e-6, + output_dtype: Optional[torch.dtype] = None): + super().__init__() + self.eps = eps + self.output_dtype = output_dtype + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return l2norm(x, self.eps, self.output_dtype) diff --git a/tensorrt_llm/_torch/modules/fla/layernorm_gated.py b/tensorrt_llm/_torch/modules/fla/layernorm_gated.py new file mode 100644 index 00000000000..86b66c941dc --- /dev/null +++ b/tensorrt_llm/_torch/modules/fla/layernorm_gated.py @@ -0,0 +1,330 @@ +# Adapt from https://github.com/fla-org/flash-linear-attention/blob/main/fla/modules/layernorm_gated.py +# Adapted from https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/layers/attention/fla/layernorm_gated.py + +# Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html +# For the backward pass, we keep weight_grad and bias_grad in registers and accumulate. +# This backward pass is faster for dimensions up to 8k, but after that it's much slower due to register spilling. +# The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine. 
+ +import torch +import torch.nn.functional as F +import triton +import triton.language as tl +from einops import rearrange + + +def rms_norm_ref( + x, + weight, + bias, + z=None, + eps=1e-6, + group_size=None, + norm_before_gate=True, + upcast=True, +): + dtype = x.dtype + x.shape[-1] + weight = weight.float() + bias = bias.float() if bias is not None else None + if upcast: + x = x.float() + z = z.float() if z is not None else z + if z is not None and not norm_before_gate: + x = x * F.silu(z) + if group_size is None: + rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps) + out = (x * rstd * weight) + bias if bias is not None else (x * rstd * + weight) + else: + x_group = rearrange(x, "... (g d) -> ... g d", d=group_size) + rstd = 1 / torch.sqrt((x_group.square()).mean(dim=-1, keepdim=True) + + eps) + out = rearrange(x_group * rstd, "... g d -> ... (g d)") * weight + if bias is not None: + out = out + bias + if z is not None and norm_before_gate: + out *= F.silu(z) + return out.to(dtype) + + +@triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None}) +@triton.heuristics({"HAS_Z": lambda args: args["Z"] is not None}) +@triton.jit +def _layer_norm_fwd_1pass_kernel( + X, # pointer to the input + Y, # pointer to the output + W, # pointer to the weights + B, # pointer to the biases + Z, # pointer to the other branch + Mean, # pointer to the mean + Rstd, # pointer to the 1/std + stride_x_row, # how much to increase the pointer when moving by 1 row + stride_y_row, + stride_z_row, + M, # number of rows in X + N, # number of columns in X + eps, # epsilon to avoid division by zero + BLOCK_N: tl.constexpr, + HAS_BIAS: tl.constexpr, + HAS_Z: tl.constexpr, + NORM_BEFORE_GATE: tl.constexpr, + IS_RMS_NORM: tl.constexpr, +): + # Map the program id to the row of X and Y it should compute. 
+ row = tl.program_id(0) + group = tl.program_id(1) + X += row * stride_x_row + group * N + Y += row * stride_y_row + group * N + if HAS_Z: + Z += row * stride_z_row + group * N + if not IS_RMS_NORM: + Mean += group * M + Rstd += group * M + W += group * N + if HAS_BIAS: + B += group * N + # Compute mean and variance + cols = tl.arange(0, BLOCK_N) + x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32) + if HAS_Z and not NORM_BEFORE_GATE: + z = tl.load(Z + cols, mask=cols < N).to(tl.float32) + x *= z * tl.sigmoid(z) + if not IS_RMS_NORM: + mean = tl.sum(x, axis=0) / N + tl.store(Mean + row, mean) + xbar = tl.where(cols < N, x - mean, 0.0) + var = tl.sum(xbar * xbar, axis=0) / N + else: + xbar = tl.where(cols < N, x, 0.0) + var = tl.sum(xbar * xbar, axis=0) / N + rstd = 1 / tl.sqrt(var + eps) + tl.store(Rstd + row, rstd) + # Normalize and apply linear transformation + mask = cols < N + w = tl.load(W + cols, mask=mask).to(tl.float32) + if HAS_BIAS: + b = tl.load(B + cols, mask=mask).to(tl.float32) + x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd + y = x_hat * w + b if HAS_BIAS else x_hat * w + if HAS_Z and NORM_BEFORE_GATE: + z = tl.load(Z + cols, mask=mask).to(tl.float32) + y *= z * tl.sigmoid(z) + # Write output + tl.store(Y + cols, y, mask=mask) + + +def _layer_norm_fwd( + x, + weight, + bias, + eps, + z=None, + out=None, + group_size=None, + norm_before_gate=True, + is_rms_norm=False, +): + M, N = x.shape + if group_size is None: + group_size = N + assert N % group_size == 0 + ngroups = N // group_size + assert x.stride(-1) == 1 + if z is not None: + assert z.stride(-1) == 1 + assert z.shape == (M, N) + assert weight.shape == (N, ) + assert weight.stride(-1) == 1 + if bias is not None: + assert bias.stride(-1) == 1 + assert bias.shape == (N, ) + # allocate output + if out is not None: + assert out.shape == x.shape + else: + out = torch.empty_like(x) + assert out.stride(-1) == 1 + mean = (torch.empty((ngroups * M, ), dtype=torch.float32, device=x.device) + if not is_rms_norm else None) + rstd = torch.empty((ngroups * M, ), dtype=torch.float32, device=x.device) + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size)) + if group_size > BLOCK_N: + raise RuntimeError( + "This layer norm doesn't support feature dim >= 64KB.") + # heuristics for number of warps + num_warps = min(max(BLOCK_N // 256, 1), 8) + grid = (M, ngroups) + with torch.get_device_module(x.device).device(x.device.index): + _layer_norm_fwd_1pass_kernel[grid]( + x, + out, + weight, + bias, + z, + mean, + rstd, + x.stride(0), + out.stride(0), + z.stride(0) if z is not None else 0, + M, + group_size, + eps, + BLOCK_N=BLOCK_N, + NORM_BEFORE_GATE=norm_before_gate, + IS_RMS_NORM=is_rms_norm, + num_warps=num_warps, + ) + return out, mean, rstd + + +class LayerNormFn(torch.autograd.Function): + + @staticmethod + def forward( + ctx, + x, + weight, + bias, + z=None, + eps=1e-6, + group_size=None, + norm_before_gate=True, + is_rms_norm=False, + ): + """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))""" + + x_shape_og = x.shape + # reshape input data into 2D tensor + x = x.reshape(-1, x.shape[-1]) + if x.stride(-1) != 1: + x = x.contiguous() + if z is not None: + assert z.shape == x_shape_og + z = z.reshape(-1, z.shape[-1]) + if z.stride(-1) != 1: + z = z.contiguous() + weight = weight.contiguous() + if bias is not None: + bias = bias.contiguous() + y, mean, rstd = 
_layer_norm_fwd( + x, + weight, + bias, + eps, + z=z, + group_size=group_size, + norm_before_gate=norm_before_gate, + is_rms_norm=is_rms_norm, + ) + return y.reshape(x_shape_og) + + +def layernorm_fn( + x, + weight, + bias, + z=None, + eps=1e-6, + group_size=None, + norm_before_gate=True, + is_rms_norm=False, +): + return LayerNormFn.apply(x, weight, bias, z, eps, group_size, + norm_before_gate, is_rms_norm) + + +def rmsnorm_fn(x, + weight, + bias, + z=None, + eps=1e-6, + group_size=None, + norm_before_gate=True): + return LayerNormFn.apply(x, weight, bias, z, eps, group_size, + norm_before_gate, True) + + +class LayerNorm(torch.nn.Module): + + def __init__( + self, + hidden_size, + eps=1e-5, + group_size=None, + norm_before_gate=True, + device=None, + dtype=None, + ): + """If group_size is not None, we do GroupNorm with each group having group_size elements. + group_size=None is equivalent to group_size=hidden_size (i.e. there's only 1 group). + """ + + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.eps = eps + self.weight = torch.nn.Parameter( + torch.empty(hidden_size, **factory_kwargs)) + self.bias = torch.nn.Parameter( + torch.empty(hidden_size, **factory_kwargs)) + self.group_size = group_size + self.norm_before_gate = norm_before_gate + self.reset_parameters() + + def reset_parameters(self): + torch.nn.init.ones_(self.weight) + torch.nn.init.zeros_(self.bias) + + def forward(self, x, z=None): + """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))""" + return layernorm_fn( + x, + self.weight, + self.bias, + z=z, + group_size=self.group_size, + eps=self.eps, + norm_before_gate=self.norm_before_gate, + ) + + +class RMSNorm(torch.nn.Module): + + def __init__( + self, + hidden_size, + eps=1e-5, + group_size=None, + norm_before_gate=True, + device=None, + dtype=None, + ): + """If group_size is not None, we do GroupNorm with each group having group_size elements. + group_size=None is equivalent to group_size=hidden_size (i.e. there's only 1 group). 
+ """ + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.eps = eps + self.weight = torch.nn.Parameter( + torch.empty(hidden_size, **factory_kwargs)) + self.register_parameter("bias", None) + self.group_size = group_size + self.norm_before_gate = norm_before_gate + self.reset_parameters() + + def reset_parameters(self): + torch.nn.init.ones_(self.weight) + + def forward(self, x, z=None): + """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))""" + return rmsnorm_fn( + x, + self.weight, + self.bias, + z=z, + eps=self.eps, + group_size=self.group_size, + norm_before_gate=self.norm_before_gate, + ) diff --git a/tensorrt_llm/_torch/modules/fla/op.py b/tensorrt_llm/_torch/modules/fla/op.py new file mode 100644 index 00000000000..4f97a9ebf46 --- /dev/null +++ b/tensorrt_llm/_torch/modules/fla/op.py @@ -0,0 +1,65 @@ +# Adapt from https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/utils/op.py +# Adapted from https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/layers/attention/fla/op.py +# -*- coding: utf-8 -*- + +import os + +import triton +import triton.language as tl +import triton.language.extra.libdevice as tldevice + +from tensorrt_llm._torch.modules.fla.utils import is_gather_supported + +if os.environ.get("FLA_USE_FAST_OPS", "0") == "1": + exp = tldevice.fast_expf + exp2 = tldevice.exp2 + log = tldevice.fast_logf + log2 = tldevice.fast_log2f +else: + exp = tl.exp + exp2 = tl.math.exp2 + log = tl.log + log2 = tl.log2 + + +@triton.jit +def safe_exp(x): + return exp(tl.where(x <= 0, x, float("-inf"))) + + +if not is_gather_supported: + + @triton.jit + def gather(src, index, axis, _builder=None): + """ + Gather operation that works when tl.gather is not supported. + This is a fallback implementation that returns None. + Just to make triton compiler happy. + """ + return None + +else: + gather = tl.gather + +if hasattr(triton.language, "_experimental_make_tensor_descriptor"): + # For Triton 3.3.x + make_tensor_descriptor = triton.language._experimental_make_tensor_descriptor +elif hasattr(triton.language, "make_tensor_descriptor"): + # For Triton 3.4.x and later + make_tensor_descriptor = triton.language.make_tensor_descriptor +else: + """ + Fallback implementation when TMA is not supported. + Returns None to indicate TMA descriptors are unavailable. + Just make triton compiler happy. 
+ """ + + @triton.jit + def make_tensor_descriptor( + base, + shape, + strides, + block_shape, + _builder=None, + ): + return None diff --git a/tensorrt_llm/_torch/modules/fla/solve_tril.py b/tensorrt_llm/_torch/modules/fla/solve_tril.py new file mode 100644 index 00000000000..d8bf84cca28 --- /dev/null +++ b/tensorrt_llm/_torch/modules/fla/solve_tril.py @@ -0,0 +1,426 @@ +# Adapt from https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/utils/solve_tril.py +# Adapted from https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/layers/attention/fla/solve_tril.py +# -*- coding: utf-8 -*- + +from typing import Optional + +import torch +import triton +import triton.language as tl + +from tensorrt_llm._torch.modules.fla.index import prepare_chunk_indices +from tensorrt_llm._torch.modules.fla.utils import input_guard + + +@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None}) +# @triton.autotune( +# configs=[ +# triton.Config({}, num_warps=num_warps, num_stages=num_stages) +# for num_warps in [1, 2, 4, 8] +# for num_stages in [2, 3, 4, 5] +# ], +# key=["BT"], +# ) +@triton.jit(do_not_specialize=["T"]) +def solve_tril_16x16_kernel( + A, + Ad, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + BT: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to( + tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to( + tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + A = A + (bos * H + i_h) * BT + Ad = Ad + (bos * H + i_h) * 16 + + offset = (i_t * 16) % BT + p_A = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * 16, offset), + (16, 16), (1, 0)) + p_Ai = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 16, 0), (16, 16), + (1, 0)) + b_A = tl.load(p_A, boundary_check=(0, 1)).to(tl.float32) + b_A = -tl.where( + tl.arange(0, 16)[:, None] > tl.arange(0, 16)[None, :], b_A, 0) + + o_i = tl.arange(0, 16) + for i in range(1, min(16, T - i_t * 16)): + b_a = -tl.load(A + (i_t * 16 + i) * H * BT + o_i + offset) + b_a = b_a + tl.sum(b_a[:, None] * b_A, 0) + mask = o_i == i + b_A = tl.where(mask[:, None], b_a, b_A) + b_A += o_i[:, None] == o_i[None, :] + tl.store( + p_Ai, + b_A.to(p_Ai.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + + +@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None}) +# @triton.autotune( +# configs=[ +# triton.Config({}, num_warps=num_warps, num_stages=num_stages) +# for num_warps in [1, 2, 4, 8] +# for num_stages in [2, 3, 4, 5] +# ], +# key=["H", "BT", "IS_VARLEN"], +# ) +@triton.jit(do_not_specialize=["T"]) +def merge_16x16_to_32x32_inverse_kernel( + A, + Ad, + Ai, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + BT: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to( + tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to( + tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + A += (bos * H + i_h) * 32 + Ad += (bos * H + i_h) * 16 + Ai += (bos * H + i_h) * 32 + + p_A_21 = tl.make_block_ptr(A, (T, 32), (H * 32, 1), (i_t * 32 + 16, 0), + (16, 16), (1, 0)) + p_Ad_11 = 
tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 32, 0), + (16, 16), (1, 0)) + p_Ad_22 = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 32 + 16, 0), + (16, 16), (1, 0)) + p_Ai_11 = tl.make_block_ptr(Ai, (T, 32), (H * 32, 1), (i_t * 32, 0), + (16, 16), (1, 0)) + p_Ai_22 = tl.make_block_ptr(Ai, (T, 32), (H * 32, 1), (i_t * 32 + 16, 16), + (16, 16), (1, 0)) + p_Ai_21 = tl.make_block_ptr(Ai, (T, 32), (H * 32, 1), (i_t * 32 + 16, 0), + (16, 16), (1, 0)) + + A_21 = tl.load(p_A_21, boundary_check=(0, 1)).to(tl.float32) + Ai_11 = tl.load(p_Ad_11, boundary_check=(0, 1)).to(tl.float32) + Ai_22 = tl.load(p_Ad_22, boundary_check=(0, 1)).to(tl.float32) + Ai_21 = -tl.dot(tl.dot(Ai_22, A_21, input_precision="ieee"), + Ai_11, + input_precision="ieee") + tl.store( + p_Ai_11, + Ai_11.to(p_Ai_11.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + tl.store( + p_Ai_22, + Ai_22.to(p_Ai_22.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + tl.store( + p_Ai_21, + Ai_21.to(p_Ai_21.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + + +@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None}) +# @triton.autotune( +# configs=[ +# triton.Config({}, num_warps=num_warps, num_stages=num_stages) +# for num_warps in [2, 4, 8] +# for num_stages in [2, 3, 4, 5] +# ], +# key=["H", "BT", "IS_VARLEN"], +# ) +@triton.jit(do_not_specialize=["T"]) +def merge_16x16_to_64x64_inverse_kernel( + A, + Ad, + Ai, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + BT: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to( + tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to( + tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + A += (bos * H + i_h) * 64 + Ad += (bos * H + i_h) * 16 + Ai += (bos * H + i_h) * 64 + + p_A_21 = tl.make_block_ptr(A, (T, 64), (H * 64, 1), (i_t * 64 + 16, 0), + (16, 16), (1, 0)) + p_A_32 = tl.make_block_ptr(A, (T, 64), (H * 64, 1), (i_t * 64 + 32, 16), + (16, 16), (1, 0)) + p_A_31 = tl.make_block_ptr(A, (T, 64), (H * 64, 1), (i_t * 64 + 32, 0), + (16, 16), (1, 0)) + p_A_43 = tl.make_block_ptr(A, (T, 64), (H * 64, 1), (i_t * 64 + 48, 32), + (16, 16), (1, 0)) + p_A_42 = tl.make_block_ptr(A, (T, 64), (H * 64, 1), (i_t * 64 + 48, 16), + (16, 16), (1, 0)) + p_A_41 = tl.make_block_ptr(A, (T, 64), (H * 64, 1), (i_t * 64 + 48, 0), + (16, 16), (1, 0)) + p_Ad_11 = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 64, 0), + (16, 16), (1, 0)) + p_Ad_22 = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 64 + 16, 0), + (16, 16), (1, 0)) + p_Ad_33 = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 64 + 32, 0), + (16, 16), (1, 0)) + p_Ad_44 = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 64 + 48, 0), + (16, 16), (1, 0)) + + A_21 = tl.load(p_A_21, boundary_check=(0, 1)).to(tl.float32) + A_32 = tl.load(p_A_32, boundary_check=(0, 1)).to(tl.float32) + A_31 = tl.load(p_A_31, boundary_check=(0, 1)).to(tl.float32) + A_43 = tl.load(p_A_43, boundary_check=(0, 1)).to(tl.float32) + A_42 = tl.load(p_A_42, boundary_check=(0, 1)).to(tl.float32) + A_41 = tl.load(p_A_41, boundary_check=(0, 1)).to(tl.float32) + + Ai_11 = tl.load(p_Ad_11, boundary_check=(0, 1)).to(tl.float32) + Ai_22 = tl.load(p_Ad_22, boundary_check=(0, 1)).to(tl.float32) + Ai_33 = tl.load(p_Ad_33, 
boundary_check=(0, 1)).to(tl.float32) + Ai_44 = tl.load(p_Ad_44, boundary_check=(0, 1)).to(tl.float32) + + Ai_21 = -tl.dot(tl.dot(Ai_22, A_21, input_precision="ieee"), + Ai_11, + input_precision="ieee") + Ai_32 = -tl.dot(tl.dot(Ai_33, A_32, input_precision="ieee"), + Ai_22, + input_precision="ieee") + Ai_43 = -tl.dot(tl.dot(Ai_44, A_43, input_precision="ieee"), + Ai_33, + input_precision="ieee") + + Ai_31 = -tl.dot( + Ai_33, + tl.dot(A_31, Ai_11, input_precision="ieee") + + tl.dot(A_32, Ai_21, input_precision="ieee"), + input_precision="ieee", + ) + Ai_42 = -tl.dot( + Ai_44, + tl.dot(A_42, Ai_22, input_precision="ieee") + + tl.dot(A_43, Ai_32, input_precision="ieee"), + input_precision="ieee", + ) + Ai_41 = -tl.dot( + Ai_44, + tl.dot(A_41, Ai_11, input_precision="ieee") + + tl.dot(A_42, Ai_21, input_precision="ieee") + + tl.dot(A_43, Ai_31, input_precision="ieee"), + input_precision="ieee", + ) + + p_Ai_11 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64, 0), + (16, 16), (1, 0)) + p_Ai_22 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 16, 16), + (16, 16), (1, 0)) + p_Ai_33 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 32, 32), + (16, 16), (1, 0)) + p_Ai_44 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 48, 48), + (16, 16), (1, 0)) + p_Ai_21 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 16, 0), + (16, 16), (1, 0)) + p_Ai_31 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 32, 0), + (16, 16), (1, 0)) + p_Ai_32 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 32, 16), + (16, 16), (1, 0)) + p_Ai_41 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 48, 0), + (16, 16), (1, 0)) + p_Ai_42 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 48, 16), + (16, 16), (1, 0)) + p_Ai_43 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 48, 32), + (16, 16), (1, 0)) + tl.store( + p_Ai_11, + Ai_11.to(p_Ai_11.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + tl.store( + p_Ai_22, + Ai_22.to(p_Ai_22.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + tl.store( + p_Ai_33, + Ai_33.to(p_Ai_33.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + tl.store( + p_Ai_44, + Ai_44.to(p_Ai_44.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + tl.store( + p_Ai_21, + Ai_21.to(p_Ai_21.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + tl.store( + p_Ai_31, + Ai_31.to(p_Ai_31.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + tl.store( + p_Ai_32, + Ai_32.to(p_Ai_32.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + tl.store( + p_Ai_41, + Ai_41.to(p_Ai_41.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + tl.store( + p_Ai_42, + Ai_42.to(p_Ai_42.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + tl.store( + p_Ai_43, + Ai_43.to(p_Ai_43.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + + fill_zeros = tl.zeros((16, 16), dtype=tl.float32) + p_Ai_12 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64, 16), + (16, 16), (1, 0)) + p_Ai_13 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64, 32), + (16, 16), (1, 0)) + p_Ai_14 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64, 48), + (16, 16), (1, 0)) + p_Ai_23 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 16, 32), + (16, 16), (1, 0)) + p_Ai_24 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 
+ 16, 48), + (16, 16), (1, 0)) + p_Ai_34 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 32, 48), + (16, 16), (1, 0)) + tl.store( + p_Ai_12, + fill_zeros.to(p_Ai_12.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + tl.store( + p_Ai_13, + fill_zeros.to(p_Ai_13.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + tl.store( + p_Ai_14, + fill_zeros.to(p_Ai_14.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + tl.store( + p_Ai_23, + fill_zeros.to(p_Ai_23.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + tl.store( + p_Ai_24, + fill_zeros.to(p_Ai_24.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + tl.store( + p_Ai_34, + fill_zeros.to(p_Ai_34.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + + +@input_guard +def solve_tril( + A: torch.Tensor, + cu_seqlens: Optional[torch.Tensor] = None, + output_dtype: torch.dtype = torch.float, +) -> torch.Tensor: + """ + Compute the inverse of the lower triangular matrix + A should be strictly lower triangular, i.e., A.triu() == 0. + + Args: + A (torch.Tensor): + [B, T, H, K] + cu_seqlens (torch.Tensor): + The cumulative sequence lengths of the input tensor. + Default: None. + output_dtype (torch.dtype): + The dtype of the output tensor. Default: `torch.float` + + Returns: + (I + A)^-1 with the same shape as A + """ + assert A.shape[-1] in [16, 32, 64] + + B, T, H, BT = A.shape + Ad = torch.empty(B, + T, + H, + 16, + device=A.device, + dtype=torch.float if BT != 16 else output_dtype) + + chunk_indices = (prepare_chunk_indices(cu_seqlens, 16) + if cu_seqlens is not None else None) + NT = len(chunk_indices) if cu_seqlens is not None else triton.cdiv(T, 16) + solve_tril_16x16_kernel[NT, B * H]( + A=A, + Ad=Ad, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + T=T, + H=H, + BT=BT, + num_warps=1, + num_stages=4, + ) + if BT == 16: + return Ad + + Ai = torch.empty(B, T, H, BT, device=A.device, dtype=output_dtype) + merge_fn = (merge_16x16_to_32x32_inverse_kernel + if BT == 32 else merge_16x16_to_64x64_inverse_kernel) + chunk_indices = (prepare_chunk_indices(cu_seqlens, BT) + if cu_seqlens is not None else None) + NT = len(chunk_indices) if cu_seqlens is not None else triton.cdiv(T, BT) + merge_fn[NT, B * H]( + A=A, + Ad=Ad, + Ai=Ai, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + T=T, + H=H, + BT=BT, + num_warps=4, + num_stages=3, + ) + return Ai diff --git a/tensorrt_llm/_torch/modules/fla/utils.py b/tensorrt_llm/_torch/modules/fla/utils.py new file mode 100644 index 00000000000..5358ecaee33 --- /dev/null +++ b/tensorrt_llm/_torch/modules/fla/utils.py @@ -0,0 +1,326 @@ +# Adapt from https://github.com/fla-org/flash-linear-attention/blob/main/fla/utils.py +# Adapted from https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/layers/attention/fla/utils.py +# -*- coding: utf-8 -*- + +import contextlib +import functools +import logging +import os +import sys +from enum import Enum +from functools import lru_cache +from typing import Any, Callable, Dict, Literal, Optional, Tuple + +import torch +import triton +from packaging import version + +logger = logging.getLogger(__name__) + +COMPILER_MODE = os.getenv("FLA_COMPILER_MODE") == "1" +FLA_CI_ENV = os.getenv("FLA_CI_ENV") == "1" + + +@lru_cache(maxsize=1) +def check_environments(): + """ + Checks the current operating system, Triton version, and Python version, + issuing warnings if they don't meet recommendations. 
+ This function's body only runs once due to lru_cache. + """ + # Check Operating System + if sys.platform == "win32": + logger.warning( + "Detected Windows operating system. Triton does not have an official Windows release, " + "thus FLA will not be adapted for Windows, and any potential errors will not be fixed. " + "Please consider using a Linux environment for compatibility.") + + triton_version = version.parse(triton.__version__) + required_triton_version = version.parse("3.2.0") + + if triton_version < required_triton_version: + logger.warning( + f"Current Triton version {triton_version} is below the recommended 3.2.0 version. " + "Errors may occur and these issues will not be fixed. " + "Please consider upgrading Triton.") + + # Check Python version + py_version = version.parse( + f"{sys.version_info.major}.{sys.version_info.minor}") + required_py_version = version.parse("3.11") + + if py_version < required_py_version: + logger.warning( + f"Current Python version {py_version} is below the recommended 3.11 version. " + "It is recommended to upgrade to Python 3.11 or higher for the best experience." + ) + + return None + + +check_environments() + + +def get_abs_err(x, y): + return (x.detach() - y.detach()).flatten().abs().max().item() + + +def get_err_ratio(x, y): + err = (x.detach() - y.detach()).flatten().square().mean().sqrt().item() + base = (x.detach()).flatten().square().mean().sqrt().item() + return err / (base + 1e-8) + + +def assert_close(prefix, ref, tri, ratio, warning=False, err_atol=1e-6): + abs_atol = get_abs_err(ref, tri) + msg = f"{prefix} diff: {abs_atol:.6f} ratio: {get_err_ratio(ref, tri):.6f}" + logger.info(msg) + error_rate = get_err_ratio(ref, tri) + if abs_atol <= err_atol: + return + if warning or (FLA_CI_ENV and (error_rate < 0.01 or abs_atol <= 0.3)): + if error_rate > ratio: + import warnings + + warnings.warn(msg) + else: + assert error_rate < ratio, msg + + +SUPPRESS_LEVEL = int(os.getenv("GDN_RECOMPUTE_SUPPRESS_LEVEL", "0")) + + +def tensor_cache( + fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]: + """ + A decorator that caches the most recent results of a function with tensor inputs. + This decorator will store the output of the decorated function for the most recent set of input tensors. + The cache is limited to a fixed size (default is 4). When the cache is full, the oldest entry will be removed. + Args: + fn (Callable[..., torch.Tensor]): + The function to be decorated. It should take tensor inputs and return tensor outputs. + Returns: + Callable[..., torch.Tensor]: + A wrapped version of the input function with single-entry caching. 
+ """ + + cache_entries: Tuple[Optional[Tuple], Optional[Dict], Any] = [] + cache_size = 4 + + @functools.wraps(fn) + def wrapper(*args: Any, **kwargs: Any) -> Any: + nonlocal cache_entries, cache_size + for i, entry in enumerate(cache_entries): + last_args, last_kwargs, last_result = entry + if len(args) == len(last_args) and len(kwargs) == len(last_kwargs): + if all(a is b for a, b in zip(args, last_args)) and all( + k in last_kwargs and v is last_kwargs[k] + for k, v in kwargs.items()): + cache_entries = (cache_entries[:i] + cache_entries[i + 1:] + + [(args, kwargs, last_result)]) + return last_result + + result = fn(*args, **kwargs) + + if len(cache_entries) >= cache_size: + cache_entries = cache_entries[1:] + cache_entries.append((args, kwargs, result)) + return result + + return wrapper + + +def input_guard(fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]: + """ + A decorator to make sure all input tensors are contiguous and set the device based on input tensors. + """ + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + contiguous_args = (i if not isinstance(i, torch.Tensor) else + i.contiguous() for i in args) + contiguous_kwargs = { + k: (v if not isinstance(v, torch.Tensor) else v.contiguous()) + for k, v in kwargs.items() + } + + tensor = None + for arg in args: + if isinstance(arg, torch.Tensor): + tensor = arg + break + if tensor is None: + for value in kwargs.values(): + if isinstance(value, torch.Tensor): + tensor = value + break + + if tensor is not None: + ctx = custom_device_ctx(tensor.device.index) + else: + ctx = contextlib.nullcontext() + + with ctx: + return fn(*contiguous_args, **contiguous_kwargs) + + return wrapper + + +contiguous = input_guard + + +def require_version(version, hint): + """ + Perform a runtime check of the dependency versions, using the exact same syntax used by pip. 
+ """ + + def decorator(fn): + + @functools.wraps(fn) + def wrapper(ctx, *args, **kwargs): + from transformers.utils.versions import require_version + + require_version(version, hint) + return fn( + ctx, + *(i if not isinstance(i, torch.Tensor) else i.contiguous() + for i in args), + **{ + k: + (v if not isinstance(v, torch.Tensor) else v.contiguous()) + for k, v in kwargs.items() + }, + ) + + return wrapper + + return decorator + + +def checkpoint(fn): + + def wrapper(*args, **kwargs): + return torch.utils.checkpoint.checkpoint(fn, *args, **kwargs) + + return wrapper + + +@lru_cache(maxsize=None) +def check_pytorch_version(version_s: str = "2.4") -> bool: + return version.parse(torch.__version__) >= version.parse(version_s) + + +def _cpu_device_warning(): + import warnings + + warnings.warn( + ("Triton is not supported on current platform, roll back to CPU."), + stacklevel=1) + + +@lru_cache(maxsize=None) +def get_multiprocessor_count(tensor_idx: int = 0) -> int: + try: + return triton.runtime.driver.active.utils.get_device_properties( + tensor_idx)["multiprocessor_count"] + except BaseException: + _cpu_device_warning() + return -1 + + +@lru_cache(maxsize=None) +def get_available_device() -> str: + try: + return triton.runtime.driver.active.get_current_target().backend + except BaseException: + _cpu_device_warning() + return "cpu" + + +@lru_cache(maxsize=None) +def _check_platform() -> Literal["nvidia", "amd", "intel", "musa"]: + device = get_available_device() + if device == "cuda": + return "nvidia" + elif device == "hip": + return "amd" + elif device == "xpu": + return "intel" + else: + return device + + +# For AMD GPUs, the triton backend is 'hip', while for Nvidia GPUs, the triton backend is 'cuda'. +# However, the torch backend is 'cuda' for both Nvidia and AMD GPUs. +# Therefore, we need to check the triton backend to determine the actual GPU vendor. +device = get_available_device() if get_available_device() != "hip" else "cuda" +device_torch_lib = getattr(torch, device) +device_platform = _check_platform() + +is_amd = device_platform == "amd" +is_intel = device_platform == "intel" +is_nvidia = device_platform == "nvidia" +is_intel_alchemist = is_intel and "Intel(R) Arc(TM) A" in torch.xpu.get_device_name( + 0) +is_nvidia_hopper = is_nvidia and ("NVIDIA H" in torch.cuda.get_device_name(0) + or torch.cuda.get_device_capability()[0] >= 9) +use_cuda_graph = is_nvidia and os.environ.get("FLA_USE_CUDA_GRAPH", "0") == "1" + +# Nvidia Ampere or newer, haven't check AMD and intel yet. 
+is_tf32_supported = is_nvidia and torch.cuda.get_device_capability(0)[0] >= 8 +is_gather_supported = hasattr(triton.language, "gather") + + +def get_all_max_shared_mem(): + try: + return [ + triton.runtime.driver.active.utils.get_device_properties(i) + ["max_shared_mem"] for i in range(device_torch_lib.device_count()) + ] + except BaseException: + _cpu_device_warning() + return [-1] + + +class Backend(Enum): + ADA = 101376 # RTX 4090 + AMPERE = 166912 # A100 + HOPPER = 232448 # H100 + DEFAULT = 102400 # Default + + @classmethod + def get_shared_memory(cls, arch: str) -> int: + try: + return cls[arch.upper()].value + except KeyError: + return cls.DEFAULT.value + + +@lru_cache(maxsize=None) +def check_shared_mem(arch: str = "none", tensor_idx: int = 0) -> bool: + try: + device_shared_mem_list = get_all_max_shared_mem() + max_shared_memory = device_shared_mem_list[tensor_idx] + return max_shared_memory >= Backend.get_shared_memory(arch) + except Exception: + return False + + +if check_pytorch_version("2.4"): + device = "cuda" if device == "cpu" else device + autocast_custom_fwd = functools.partial(torch.amp.custom_fwd, + device_type=device) + autocast_custom_bwd = functools.partial(torch.amp.custom_bwd, + device_type=device) + + def custom_device_ctx(index: int): + return device_torch_lib.device(index) + +else: + assert (device == "cuda" + ), "Only cuda device is supported for PyTorch version < 2.4.0." + autocast_custom_fwd = device_torch_lib.amp.custom_fwd + autocast_custom_bwd = device_torch_lib.amp.custom_bwd + + def custom_device_ctx(index: int): + return torch.cuda.device(index) diff --git a/tensorrt_llm/_torch/modules/fla/wy_fast.py b/tensorrt_llm/_torch/modules/fla/wy_fast.py new file mode 100644 index 00000000000..968aa437446 --- /dev/null +++ b/tensorrt_llm/_torch/modules/fla/wy_fast.py @@ -0,0 +1,152 @@ +# Adapt from https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/gated_delta_rule/wy_fast.py +# Adapted from https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/layers/attention/fla/wy_fast.py +# -*- coding: utf-8 -*- + +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl + +from tensorrt_llm._torch.modules.fla.index import prepare_chunk_indices + + +@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None}) +# @triton.autotune( +# configs=[ +# triton.Config({}, num_warps=num_warps, num_stages=num_stages) +# for num_warps in [2, 4, 8] +# for num_stages in [2, 3, 4] +# ], +# key=["H", "K", "V", "BT", "BK", "BV", "IS_VARLEN"], +# ) +@triton.jit(do_not_specialize=["T"]) +def recompute_w_u_fwd_kernel( + k, + v, + beta, + w, + u, + A, + g, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + Hg: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to( + tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to( + tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + p_beta = tl.make_block_ptr(beta + bos * H + i_h, (T, ), (H, ), (i_t * BT, ), + (BT, ), (0, )) + p_g = tl.make_block_ptr(g + (bos * H + i_h), (T, ), (H, ), (i_t * BT, ), + (BT, ), (0, )) + p_A = tl.make_block_ptr(A + (bos * H + i_h) * BT, (T, BT), (H * BT, 1), + (i_t * BT, 0), (BT, BT), (1, 0)) + 
b_beta = tl.load(p_beta, boundary_check=(0, )) + b_A = tl.load(p_A, boundary_check=(0, 1)) + b_g = tl.exp(tl.load(p_g, boundary_check=(0, ))) + + for i_v in range(tl.cdiv(V, BV)): + p_v = tl.make_block_ptr( + v + (bos * H + i_h) * V, + (T, V), + (H * V, 1), + (i_t * BT, i_v * BV), + (BT, BV), + (1, 0), + ) + p_u = tl.make_block_ptr( + u + (bos * H + i_h) * V, + (T, V), + (H * V, 1), + (i_t * BT, i_v * BV), + (BT, BV), + (1, 0), + ) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_vb = (b_v * b_beta[:, None]).to(b_v.dtype) + b_u = tl.dot(b_A, b_vb, allow_tf32=False) + tl.store(p_u, b_u.to(p_u.dtype.element_ty), boundary_check=(0, 1)) + + for i_k in range(tl.cdiv(K, BK)): + p_k = tl.make_block_ptr( + k + (bos * Hg + i_h // (H // Hg)) * K, + (T, K), + (Hg * K, 1), + (i_t * BT, i_k * BK), + (BT, BK), + (1, 0), + ) + p_w = tl.make_block_ptr( + w + (bos * H + i_h) * K, + (T, K), + (H * K, 1), + (i_t * BT, i_k * BK), + (BT, BK), + (1, 0), + ) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_kb = (b_k * b_beta[:, None] * b_g[:, None]).to(b_k.dtype) + b_w = tl.dot(b_A, b_kb) + tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1)) + + +def recompute_w_u_fwd( + k: torch.Tensor, + v: torch.Tensor, + beta: torch.Tensor, + g_cumsum: torch.Tensor, + A: torch.Tensor, + cu_seqlens: Optional[torch.LongTensor], +) -> Tuple[torch.Tensor, torch.Tensor]: + B, T, Hg, K, V = *k.shape, v.shape[-1] + H = v.shape[-2] + BT = A.shape[-1] + + chunk_indices = (prepare_chunk_indices(cu_seqlens, BT) + if cu_seqlens is not None else None) + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + BK = 64 + BV = 64 + u = torch.empty_like(v) + w = k.new_empty(B, T, H, K) + recompute_w_u_fwd_kernel[(NT, B * H)]( + k=k, + v=v, + beta=beta, + w=w, + u=u, + A=A, + g=g_cumsum, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + T=T, + H=H, + Hg=Hg, + K=K, + V=V, + BT=BT, + BK=BK, + BV=BV, + num_warps=4, + num_stages=3, + ) + return w, u + + +fwd_recompute_w_u = recompute_w_u_fwd diff --git a/tensorrt_llm/_torch/modules/qk_norm_attention.py b/tensorrt_llm/_torch/modules/qk_norm_attention.py index c64b24ca693..1ccf02cddad 100644 --- a/tensorrt_llm/_torch/modules/qk_norm_attention.py +++ b/tensorrt_llm/_torch/modules/qk_norm_attention.py @@ -156,11 +156,20 @@ def __init__( config: ModelConfig, q_scaling: float = 1.0, disable_deep_gemm: bool = False, + use_gemma_rms_norm: bool = False, + attn_output_gate: Optional[bool] = None, ): self.pretrained_config = config.pretrained_config self.fuse_qk_norm_rope = fuse_qk_norm_rope self.skip_rope = skip_rope + if use_gemma_rms_norm: + assert fuse_qk_norm_rope is False, "fused_qk_norm_rope is not supported for gemma rms norm." + + # If fuse_qk_norm_rope is true, do not apply fused RoPE in attention OP, and self.rotary_emb will be skipped in the overridden apply_rope. + rope_fusion = not self.fuse_qk_norm_rope and not skip_rope + if attn_output_gate and use_gemma_rms_norm: + rope_fusion = False assert not (fuse_qk_norm_rope and skip_rope ), "Fusing qk norm and skipping rope is not supported" @@ -173,23 +182,26 @@ def __init__( pos_embd_params=pos_embd_params, # If fuse_qk_norm_rope is true, do not apply fused RoPE in attention OP, # and self.rotary_emb will be skipped in the overridden apply_rope. 
- rope_fusion=not self.fuse_qk_norm_rope and not skip_rope, + rope_fusion=rope_fusion, layer_idx=layer_idx, dtype=dtype, dense_bias=dense_bias, config=config, q_scaling=q_scaling, disable_deep_gemm=disable_deep_gemm, + attn_output_gate=attn_output_gate, ) self.q_norm = RMSNorm(hidden_size=self.head_dim, eps=self.pretrained_config.rms_norm_eps, dtype=self.pretrained_config.torch_dtype, - has_weights=True) + has_weights=True, + use_gemma_rms_norm=use_gemma_rms_norm) self.k_norm = RMSNorm(hidden_size=self.head_dim, eps=self.pretrained_config.rms_norm_eps, dtype=self.pretrained_config.torch_dtype, - has_weights=True) + has_weights=True, + use_gemma_rms_norm=use_gemma_rms_norm) self.aux_stream = torch.cuda.Stream() self.ln_events = [torch.cuda.Event(), torch.cuda.Event()] diff --git a/tensorrt_llm/_torch/modules/rms_norm.py b/tensorrt_llm/_torch/modules/rms_norm.py index 4b7a388983b..25b025b9425 100644 --- a/tensorrt_llm/_torch/modules/rms_norm.py +++ b/tensorrt_llm/_torch/modules/rms_norm.py @@ -29,18 +29,23 @@ class RMSNorm(nn.Module): _ArgumentNotSpecifiedSentinelType: TypeAlias = EllipsisType def __init__( - self, - *, - hidden_size: int, - eps: float, - dtype: Optional[torch.dtype] = None, - device: Optional[torch.device] = None, - has_weights: bool = True, + self, + *, + hidden_size: int, + eps: float, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + has_weights: bool = True, + use_gemma_rms_norm: bool = False, # Assume has_weights = True ): super().__init__() if has_weights: - self.weight = nn.Parameter( - torch.ones(hidden_size, dtype=dtype, device=device)) + if not use_gemma_rms_norm: + self.weight = nn.Parameter( + torch.ones(hidden_size, dtype=dtype, device=device)) + else: + self.weight = nn.Parameter( + torch.zeros(hidden_size, dtype=dtype, device=device)) else: self.register_buffer('weight', torch.ones(hidden_size, @@ -48,6 +53,7 @@ def __init__( device=device), persistent=False) self.variance_epsilon = eps + self.use_gemma_rms_norm = use_gemma_rms_norm def forward( self, @@ -63,13 +69,26 @@ def forward( if IS_FLASHINFER_AVAILABLE: from ..custom_ops import (flashinfer_fused_add_rmsnorm, + flashinfer_gemma_fused_add_rmsnorm, + flashinfer_gemma_rmsnorm, flashinfer_rmsnorm) if residual is not None: - flashinfer_fused_add_rmsnorm(hidden_states, residual, - self.weight, self.variance_epsilon) + if not self.use_gemma_rms_norm: + flashinfer_fused_add_rmsnorm(hidden_states, residual, + self.weight, + self.variance_epsilon) + else: + flashinfer_gemma_fused_add_rmsnorm(hidden_states, residual, + self.weight, + self.variance_epsilon) else: - hidden_states = flashinfer_rmsnorm(hidden_states, self.weight, - self.variance_epsilon) + if not self.use_gemma_rms_norm: + hidden_states = flashinfer_rmsnorm(hidden_states, + self.weight, + self.variance_epsilon) + else: + hidden_states = flashinfer_gemma_rmsnorm( + hidden_states, self.weight, self.variance_epsilon) else: input_dtype = hidden_states.dtype hidden_states = hidden_states.to(torch.float32) @@ -80,7 +99,11 @@ def forward( variance = hidden_states.pow(2).mean(-1, keepdim=True) hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - hidden_states = self.weight * hidden_states.to(input_dtype) + if not self.use_gemma_rms_norm: + hidden_states = self.weight * hidden_states.to(input_dtype) + else: + hidden_states = (self.weight + + 1) * hidden_states.to(input_dtype) if return_residual: return hidden_states, cast(Optional[torch.Tensor], residual) diff --git 
a/tensorrt_llm/_torch/pyexecutor/_util.py b/tensorrt_llm/_torch/pyexecutor/_util.py index 2902c5f06b7..f0a74f0fede 100644 --- a/tensorrt_llm/_torch/pyexecutor/_util.py +++ b/tensorrt_llm/_torch/pyexecutor/_util.py @@ -23,7 +23,7 @@ from ..model_config import ModelConfig from ..speculative import get_num_extra_kv_tokens, get_spec_decoder from .config import PyTorchConfig -from .config_utils import is_mla, is_nemotron_hybrid +from .config_utils import is_mla, is_nemotron_hybrid, is_qwen3_next from .guided_decoder import GuidedDecoder from .kv_cache_connector import KvCacheConnectorManager from .kv_cache_transceiver import AttentionTypeCpp, create_kv_cache_transceiver @@ -452,6 +452,56 @@ def _create_kv_cache_manager( dtype=kv_cache_dtype, spec_config=spec_config, ) + elif is_qwen3_next(config): + if self._max_beam_width > 1: + raise ValueError( + "MambaHybridCacheManager + beam search is not supported yet." + ) + + if not estimating_kv_cache and self._kv_connector_manager is not None: + raise NotImplementedError( + "Connector manager is not supported for MambaHybridCacheManager." + ) + config = model_engine.model.model_config.pretrained_config + mamba_layer_mask = [ + True if i % config.full_attention_interval + != config.full_attention_interval - 1 else False + for i in range(num_hidden_layers) + ] + layer_mask = [ + False if i % config.full_attention_interval + != config.full_attention_interval - 1 else True + for i in range(num_hidden_layers) + ] + num_mamba_layers = num_hidden_layers // config.full_attention_interval * ( + config.full_attention_interval - 1) + num_layers = num_hidden_layers - num_mamba_layers + kv_cache_manager = MambaHybridCacheManager( + # mamba cache parameters + config.linear_key_head_dim, + config.linear_conv_kernel_dim, + config.linear_num_value_heads, + config.linear_num_key_heads, + config.linear_value_head_dim, + num_mamba_layers, + mamba_layer_mask, + config.torch_dtype, + model_engine.model.model_config.quant_config. + mamba_ssm_cache_dtype, + # kv cache parameters + self._kv_cache_config, + tensorrt_llm.bindings.internal.batch_manager.CacheType.SELF, + num_layers=num_layers, + layer_mask=layer_mask, + num_kv_heads=num_key_value_heads, + head_dim=head_dim, + tokens_per_block=self._tokens_per_block, + max_seq_len=self._max_seq_len, + max_batch_size=self._max_batch_size, + mapping=mapping, + dtype=kv_cache_dtype, + spec_config=spec_config, + ) else: # NOTE: this is a workaround for VSWA to switch to calculate_max_num_blocks_from_cpp in KVCahceManager is_vswa = self._kv_cache_config.max_attention_window is not None and len( diff --git a/tensorrt_llm/_torch/pyexecutor/config_utils.py b/tensorrt_llm/_torch/pyexecutor/config_utils.py index 6981a593140..8bf201147f0 100644 --- a/tensorrt_llm/_torch/pyexecutor/config_utils.py +++ b/tensorrt_llm/_torch/pyexecutor/config_utils.py @@ -11,3 +11,7 @@ def is_mla(config): config, "qk_rope_head_dim", None): return True return False + + +def is_qwen3_next(config): + return getattr(config, 'linear_key_head_dim', 0) > 0
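
To make the recurrence implemented by `fused_sigmoid_gating_delta_rule_update_kernel` (and the plain gated-delta-rule kernels above it) easier to follow, here is a minimal single-head PyTorch sketch of the same math. It is illustrative only: the function name, the per-head `[T, K]`/`[T, V]` layout, the `Optional` defaults, and the eager per-timestep loop are assumptions made for readability, it fixes `softplus` to its default beta/threshold instead of the kernel's `softplus_beta`/`softplus_threshold` parameters, and it returns the final state rather than writing it back through `h0_source`/`h0_indices` the way the Triton kernels do.

```python
from typing import Optional, Tuple

import torch
import torch.nn.functional as F


def sigmoid_gated_delta_rule_reference(
    q: torch.Tensor,        # [T, K] queries of one head
    k: torch.Tensor,        # [T, K] keys of one head
    v: torch.Tensor,        # [T, V] values of one head
    a: torch.Tensor,        # [T]    raw gate inputs (pre-softplus)
    b: torch.Tensor,        # [T]    raw beta inputs (pre-sigmoid)
    A_log: torch.Tensor,    # scalar per-head log-decay parameter
    dt_bias: torch.Tensor,  # scalar per-head gate bias
    h0: Optional[torch.Tensor] = None,  # [K, V] optional initial state
    scale: Optional[float] = None,
    use_qk_l2norm: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
    T, K = q.shape
    V = v.shape[-1]
    scale = K ** -0.5 if scale is None else scale
    h = (torch.zeros(K, V, dtype=torch.float32, device=q.device)
         if h0 is None else h0.float().clone())
    o = torch.empty(T, V, dtype=torch.float32, device=q.device)
    for t in range(T):
        q_t, k_t, v_t = q[t].float(), k[t].float(), v[t].float()
        if use_qk_l2norm:
            q_t = q_t / (q_t.norm() + 1e-6)
            k_t = k_t / (k_t.norm() + 1e-6)
        q_t = q_t * scale
        # Gating: g = -exp(A_log) * softplus(a + dt_bias), beta = sigmoid(b)
        g_t = -torch.exp(A_log.float()) * F.softplus(a[t].float() + dt_bias.float())
        beta_t = torch.sigmoid(b[t].float())
        h = h * torch.exp(g_t)          # decay the running [K, V] state
        v_t = v_t - k_t @ h             # delta rule: subtract what the state predicts for k_t
        v_t = v_t * beta_t              # scale the correction by the write strength
        h = h + torch.outer(k_t, v_t)   # rank-1 state update
        o[t] = q_t @ h                  # read the state with the (scaled) query
    return o.to(q.dtype), h
```

In the fused kernels the same loop runs once per `(head, value block)` program, the batch dimension is replaced by `cu_seqlens` in the variable-length path, and the state lives in a cache slot selected by `h0_indices`, which is why the wrappers above validate that `initial_state` (or `initial_state_indices`) has one entry per sequence.

The `use_gemma_rms_norm` path added to `RMSNorm` can be summarized the same way; the sketch below is an assumption-labeled simplification (it applies the `(weight + 1)` scale in float32, whereas the module casts back to the input dtype first and dispatches to FlashInfer's `gemma_rmsnorm`/`gemma_fused_add_rmsnorm` ops when available).

```python
import torch


def gemma_style_rms_norm(x: torch.Tensor, weight: torch.Tensor,
                         eps: float = 1e-6) -> torch.Tensor:
    # Gemma-style RMSNorm stores the scale as an offset from 1, so the
    # zero-initialized `weight` used by this variant is an identity scale.
    x32 = x.float()
    x32 = x32 * torch.rsqrt(x32.pow(2).mean(-1, keepdim=True) + eps)
    return ((weight.float() + 1.0) * x32).to(x.dtype)
```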