diff --git a/python/sglang/srt/models/qwen3_moe.py b/python/sglang/srt/models/qwen3_moe.py index 7caec156e700..9737ac7197a8 100644 --- a/python/sglang/srt/models/qwen3_moe.py +++ b/python/sglang/srt/models/qwen3_moe.py @@ -18,10 +18,12 @@ """Inference-only Qwen3MoE model compatible with HuggingFace weights.""" import logging -from typing import Any, Dict, Iterable, List, Optional, Tuple +import math +from typing import Any, Dict, Iterable, List, Optional, Tuple, TypeVar import torch from torch import nn +from transformers import PretrainedConfig from sglang.srt.distributed import ( get_moe_expert_parallel_world_size, @@ -73,6 +75,13 @@ is_npu, ) +_is_cuda = is_cuda() + +if _is_cuda: + from sgl_kernel import fused_qk_norm_rope + +TConfig = TypeVar("TConfig", bound=PretrainedConfig) + Qwen3MoeConfig = None _is_flashinfer_available = is_flashinfer_available() @@ -85,6 +94,118 @@ from sgl_kernel_npu.norm.split_qkv_rmsnorm_rope import split_qkv_rmsnorm_rope +def compute_yarn_parameters( + config: PretrainedConfig, +) -> tuple[float, float, float, float]: + """ + Refer to https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_rope_utils.py#L197C1-L288C1 + Computes the inverse frequencies with NTK scaling. Please refer to the + [original paper](https://huggingface.co/papers/2309.00071) + Args: + config ([`~transformers.PretrainedConfig`]): + The model configuration. 
+ Returns: + factor: float, the scaling factor for the RoPE embeddings + low: float, the lower bound of the dimension range + high: float, the upper bound of the dimension range + attention_factor: float, the post-processing scaling factor applied to the computed cos/sin + """ + + # The config does not contain rope_scaling, which means the model is not using yarn + rope_scaling = getattr(config, "rope_scaling", None) + if rope_scaling is None: + return 1.0, 0, 0, 1.0 + + base = config.rope_theta + partial_rotary_factor = ( + config.partial_rotary_factor + if hasattr(config, "partial_rotary_factor") + else 1.0 + ) + head_dim = getattr( + config, "head_dim", config.hidden_size // config.num_attention_heads + ) + dim = int(head_dim * partial_rotary_factor) + factor = rope_scaling.get("factor", 1.0) + attention_factor = rope_scaling.get("attention_factor") + mscale = rope_scaling.get("mscale") + mscale_all_dim = rope_scaling.get("mscale_all_dim") + + if "original_max_position_embeddings" in rope_scaling: + original_max_position_embeddings = rope_scaling[ + "original_max_position_embeddings" + ] + factor = config.max_position_embeddings / original_max_position_embeddings + else: + original_max_position_embeddings = config.max_position_embeddings + + def get_mscale(scale, mscale=1): + if scale <= 1: + return 1.0 + return 0.1 * mscale * math.log(scale) + 1.0 + + # Sets the attention factor as suggested in the paper + if attention_factor is None: + if mscale and mscale_all_dim: + attention_factor = float( + get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dim) + ) + else: + attention_factor = get_mscale(factor) + + # Optional config options + # beta_fast/beta_slow: as suggested in the paper, default to 32/1 (correspondingly) + beta_fast = rope_scaling.get("beta_fast") or 32 + beta_slow = rope_scaling.get("beta_slow") or 1 + + # Compute the inverse frequencies + def find_correction_dim(num_rotations, dim, base, max_position_embeddings): + """Inverse dimension 
formula to find the dimension based on the number of rotations""" + return ( + dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi)) + ) / (2 * math.log(base)) + + def find_correction_range( + low_rot, high_rot, dim, base, max_position_embeddings, truncate + ): + """Find dimension range bounds based on rotations""" + low = find_correction_dim(low_rot, dim, base, max_position_embeddings) + high = find_correction_dim(high_rot, dim, base, max_position_embeddings) + if truncate: + low = math.floor(low) + high = math.ceil(high) + return max(low, 0), min(high, dim - 1) + + truncate = rope_scaling.get("truncate", True) + low, high = find_correction_range( + beta_fast, beta_slow, dim, base, original_max_position_embeddings, truncate + ) + + # These parts are implemented in the fusedQKNormRopeKernel.cu + # # def linear_ramp_factor(min, max, dim): + # # if min == max: + # # max += 0.001 # Prevent singularity + + # # linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min) + # # ramp_func = torch.clamp(linear_func, 0, 1) + # # return ramp_func + + # # Note on variable naming: "interpolation" comes from the original technique, where we interpolate the position IDs + # # to expand the possible context length. In other words, interpolation = apply scaling factor. 
+ # # pos_freqs = base ** (torch.arange(0, dim, 2).to(device=device, dtype=torch.float) / dim) + # # inv_freq_extrapolation = 1.0 / pos_freqs + # # inv_freq_interpolation = 1.0 / (factor * pos_freqs) + + # # # Get n-dimensional rotational scaling corrected for extrapolation + # # inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, dim // 2).to(device=device, dtype=torch.float) + # # inv_freq = ( + # # inv_freq_interpolation * (1 - inv_freq_extrapolation_factor) + # # + inv_freq_extrapolation * inv_freq_extrapolation_factor + # # ) + # # return inv_freq, attention_factor + return factor, low, high, attention_factor + + class Qwen3MoeSparseMoeBlock(nn.Module): def __init__( self, @@ -286,6 +407,7 @@ def __init__( head_dim: Optional[int] = None, rms_norm_eps: float = 1e-06, attention_bias: bool = False, + config: Optional[TConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", dual_chunk_attention_config: Optional[dict[str, Any]] = None, @@ -297,6 +419,7 @@ def __init__( attn_tp_rank = get_attention_tp_rank() attn_tp_size = get_attention_tp_size() + self.config = config self.total_num_heads = num_heads assert self.total_num_heads % attn_tp_size == 0 self.num_heads = self.total_num_heads // attn_tp_size @@ -352,6 +475,14 @@ def __init__( self.compatible_with_fused_kv_buffer = ( False if isinstance(self.rotary_emb, MRotaryEmbedding) else True ) + self.compatible_with_fused_qk_norm_rope = ( + not isinstance(self.rotary_emb, MRotaryEmbedding) + ) and self.head_dim in (64, 128, 256) + self.use_fused_qk_norm_rope = ( + get_global_server_args().enable_fused_qk_norm_rope + and self.compatible_with_fused_qk_norm_rope + ) + self._used_fused_qk_norm_rope_last_call = False self.attn = RadixAttention( self.num_heads, @@ -379,6 +510,9 @@ def _apply_qk_norm( k_by_head = k.reshape(-1, self.head_dim) k_by_head = self.k_norm(k_by_head) current_stream.wait_stream(self.alt_stream) + q = q_by_head.view(q.shape) + k = k_by_head.view(k.shape) + 
return q, k else: q_by_head = q.reshape(-1, self.head_dim) q_by_head = self.q_norm(q_by_head) @@ -433,27 +567,61 @@ def forward_prepare_native( forward_batch: ForwardBatch, ): qkv, _ = self.qkv_proj(hidden_states) - q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - q, k = self._apply_qk_norm(q, k) - q, k = self.rotary_emb( - positions, - q, - k, - fused_set_kv_buffer_arg=( - create_fused_set_kv_buffer_arg( - value=v, - layer=self.attn, - forward_batch=forward_batch, - ) - if enable_fused_set_kv_buffer(forward_batch) - and self.compatible_with_fused_kv_buffer - else None - ), - ) + + q, k, v = self.apply_qk_norm_rope(qkv, positions, forward_batch) inner_state = q, k, v, forward_batch return None, forward_batch, inner_state + def apply_qk_norm_rope(self, qkv, positions, forward_batch): + use_fused = self.use_fused_qk_norm_rope and qkv.dtype == torch.bfloat16 + if use_fused: + theta = getattr(self.config, "rope_theta", 10000.0) + positions = ( + positions.view(-1).to(dtype=torch.int32, device=qkv.device).contiguous() + ) + factor, low, high, attention_factor = compute_yarn_parameters(self.config) + fused_qk_norm_rope( + qkv, + self.num_heads, + self.num_kv_heads, + self.num_kv_heads, + self.head_dim, + self.q_norm.variance_epsilon, + self.q_norm.weight, + self.k_norm.weight, + theta, + self.rotary_emb.is_neox_style, + positions, + factor, + low, + high, + attention_factor, + ) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + self._used_fused_qk_norm_rope_last_call = True + else: + # Fallback to non-fused QK Norm & RoPE implementation + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self._apply_qk_norm(q, k) + q, k = self.rotary_emb( + positions, + q, + k, + fused_set_kv_buffer_arg=( + create_fused_set_kv_buffer_arg( + value=v, + layer=self.attn, + forward_batch=forward_batch, + ) + if enable_fused_set_kv_buffer(forward_batch) + and self.compatible_with_fused_kv_buffer + else None + ), + 
) + self._used_fused_qk_norm_rope_last_call = False + return q, k, v + def forward_prepare( self, positions: torch.Tensor, @@ -482,15 +650,17 @@ def forward_core(self, intermediate_state): q, k, v, fb = inner_state + must_save_kv = self._used_fused_qk_norm_rope_last_call + save_kv_cache = must_save_kv or not ( + enable_fused_set_kv_buffer(forward_batch) + and self.compatible_with_fused_kv_buffer + ) attn_output = self.attn( q, k, v, fb, - save_kv_cache=not ( - enable_fused_set_kv_buffer(forward_batch) - and self.compatible_with_fused_kv_buffer - ), + save_kv_cache=save_kv_cache, ) output, _ = self.o_proj(attn_output) return output @@ -543,6 +713,7 @@ def __init__( head_dim=head_dim, rms_norm_eps=rms_norm_eps, attention_bias=attention_bias, + config=config, quant_config=quant_config, prefix=add_prefix("self_attn", prefix), dual_chunk_attention_config=dual_chunk_attention_config, diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 6e3af4623e56..03b3747f9dda 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -542,6 +542,7 @@ class ServerArgs: enable_attn_tp_input_scattered: bool = False # Context parallelism used in the long sequence prefill phase of DeepSeek v3.2 enable_nsa_prefill_context_parallel: bool = False + enable_fused_qk_norm_rope: bool = False # Dynamic batch tokenizer enable_dynamic_batch_tokenizer: bool = False @@ -3738,6 +3739,11 @@ def add_cli_args(parser: argparse.ArgumentParser): action="store_true", help="Enable context parallelism used in the long sequence prefill phase of DeepSeek v3.2.", ) + parser.add_argument( + "--enable-fused-qk-norm-rope", + action="store_true", + help="Enable fused qk normalization and rope rotary embedding.", + ) # Dynamic batch tokenizer parser.add_argument(