Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
215 changes: 193 additions & 22 deletions python/sglang/srt/models/qwen3_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,12 @@
"""Inference-only Qwen3MoE model compatible with HuggingFace weights."""

import logging
from typing import Any, Dict, Iterable, List, Optional, Tuple
import math
from typing import Any, Dict, Iterable, List, Optional, Tuple, TypeVar

import torch
from torch import nn
from transformers import PretrainedConfig

from sglang.srt.distributed import (
get_moe_expert_parallel_world_size,
Expand Down Expand Up @@ -73,6 +75,13 @@
is_npu,
)

_is_cuda = is_cuda()

if _is_cuda:
from sgl_kernel import fused_qk_norm_rope

TConfig = TypeVar("TConfig", bound=PretrainedConfig)

Qwen3MoeConfig = None

_is_flashinfer_available = is_flashinfer_available()
Expand All @@ -85,6 +94,118 @@
from sgl_kernel_npu.norm.split_qkv_rmsnorm_rope import split_qkv_rmsnorm_rope


def compute_yarn_parameters(
    config: "PretrainedConfig",
) -> tuple[float, float, float, float]:
    """Compute the YaRN RoPE-scaling parameters consumed by the fused QK-norm/RoPE kernel.

    Mirrors ``_compute_yarn_parameters`` in HF transformers
    (https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_rope_utils.py);
    NTK-by-parts scaling per the YaRN paper (https://huggingface.co/papers/2309.00071).
    The inverse frequencies themselves are derived on-device in
    ``fusedQKNormRopeKernel.cu`` from the values returned here.

    Args:
        config: model configuration; reads ``rope_scaling``, ``rope_theta``,
            ``max_position_embeddings`` and, when present,
            ``partial_rotary_factor`` / ``head_dim``.

    Returns:
        factor: context-extension scaling factor for the RoPE embeddings.
        low: lower dimension-index bound of the interpolation/extrapolation ramp.
        high: upper dimension-index bound of the ramp.
        attention_factor: post-processing scale applied to the computed cos/sin.
    """

    # No rope_scaling entry means the model is not using YaRN: return
    # identity parameters (scale 1.0, empty ramp, no attention scaling).
    rope_scaling = getattr(config, "rope_scaling", None)
    if rope_scaling is None:
        return 1.0, 0, 0, 1.0

    base = config.rope_theta
    partial_rotary_factor = (
        config.partial_rotary_factor
        if hasattr(config, "partial_rotary_factor")
        else 1.0
    )
    head_dim = getattr(
        config, "head_dim", config.hidden_size // config.num_attention_heads
    )
    dim = int(head_dim * partial_rotary_factor)
    # BUG FIX: rope_scaling is a dict, so the previous
    # getattr(rope_scaling, "factor", 1.0) always returned the default 1.0
    # and the configured scaling factor was silently ignored.
    factor = rope_scaling.get("factor", 1.0)
    attention_factor = rope_scaling.get("attention_factor")
    mscale = rope_scaling.get("mscale")
    mscale_all_dim = rope_scaling.get("mscale_all_dim")

    # If the config records the pre-extension context length, derive the
    # factor from it (this overrides any explicit "factor" entry).
    if "original_max_position_embeddings" in rope_scaling:
        original_max_position_embeddings = rope_scaling[
            "original_max_position_embeddings"
        ]
        factor = config.max_position_embeddings / original_max_position_embeddings
    else:
        original_max_position_embeddings = config.max_position_embeddings

    def get_mscale(scale, mscale=1):
        # Attention "temperature" correction suggested in the YaRN paper.
        if scale <= 1:
            return 1.0
        return 0.1 * mscale * math.log(scale) + 1.0

    # Sets the attention factor as suggested in the paper.
    if attention_factor is None:
        if mscale and mscale_all_dim:
            attention_factor = float(
                get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dim)
            )
        else:
            attention_factor = get_mscale(factor)

    # Optional config options.
    # beta_fast/beta_slow: ramp bounds as suggested in the paper,
    # defaulting to 32/1 (correspondingly).
    beta_fast = rope_scaling.get("beta_fast") or 32
    beta_slow = rope_scaling.get("beta_slow") or 1

    def find_correction_dim(num_rotations, dim, base, max_position_embeddings):
        """Inverse dimension formula to find the dimension based on the number of rotations."""
        return (
            dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))
        ) / (2 * math.log(base))

    def find_correction_range(
        low_rot, high_rot, dim, base, max_position_embeddings, truncate
    ):
        """Find dimension range bounds based on rotations."""
        low = find_correction_dim(low_rot, dim, base, max_position_embeddings)
        high = find_correction_dim(high_rot, dim, base, max_position_embeddings)
        if truncate:
            low = math.floor(low)
            high = math.ceil(high)
        return max(low, 0), min(high, dim - 1)

    truncate = rope_scaling.get("truncate", True)
    low, high = find_correction_range(
        beta_fast, beta_slow, dim, base, original_max_position_embeddings, truncate
    )

    # NOTE: the per-dimension linear ramp and the interpolation/extrapolation
    # blend of inverse frequencies (the remainder of the HF reference
    # implementation) are performed inside fusedQKNormRopeKernel.cu from
    # (factor, low, high); only these scalars need to be computed on the host.
    return factor, low, high, attention_factor


class Qwen3MoeSparseMoeBlock(nn.Module):
def __init__(
self,
Expand Down Expand Up @@ -286,6 +407,7 @@ def __init__(
head_dim: Optional[int] = None,
rms_norm_eps: float = 1e-06,
attention_bias: bool = False,
config: Optional[TConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
dual_chunk_attention_config: Optional[dict[str, Any]] = None,
Expand All @@ -297,6 +419,7 @@ def __init__(
attn_tp_rank = get_attention_tp_rank()
attn_tp_size = get_attention_tp_size()

self.config = config
self.total_num_heads = num_heads
assert self.total_num_heads % attn_tp_size == 0
self.num_heads = self.total_num_heads // attn_tp_size
Expand Down Expand Up @@ -352,6 +475,14 @@ def __init__(
self.compatible_with_fused_kv_buffer = (
False if isinstance(self.rotary_emb, MRotaryEmbedding) else True
)
self.compatible_with_fused_qk_norm_rope = (
not isinstance(self.rotary_emb, MRotaryEmbedding)
) and self.head_dim in (64, 128, 256)
self.use_fused_qk_norm_rope = (
get_global_server_args().enable_fused_qk_norm_rope
and self.compatible_with_fused_qk_norm_rope
)
self._used_fused_qk_norm_rope_last_call = False

self.attn = RadixAttention(
self.num_heads,
Expand Down Expand Up @@ -379,6 +510,9 @@ def _apply_qk_norm(
k_by_head = k.reshape(-1, self.head_dim)
k_by_head = self.k_norm(k_by_head)
current_stream.wait_stream(self.alt_stream)
q = q_by_head.view(q.shape)
k = k_by_head.view(k.shape)
return q, k
else:
q_by_head = q.reshape(-1, self.head_dim)
q_by_head = self.q_norm(q_by_head)
Expand Down Expand Up @@ -433,27 +567,61 @@ def forward_prepare_native(
forward_batch: ForwardBatch,
):
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self._apply_qk_norm(q, k)
q, k = self.rotary_emb(
positions,
q,
k,
fused_set_kv_buffer_arg=(
create_fused_set_kv_buffer_arg(
value=v,
layer=self.attn,
forward_batch=forward_batch,
)
if enable_fused_set_kv_buffer(forward_batch)
and self.compatible_with_fused_kv_buffer
else None
),
)

q, k, v = self.apply_qk_norm_rope(qkv, positions, forward_batch)

inner_state = q, k, v, forward_batch
return None, forward_batch, inner_state

def apply_qk_norm_rope(self, qkv, positions, forward_batch):
    """Apply per-head QK RMSNorm and rotary embedding to the packed QKV projection.

    Returns the (q, k, v) views split from ``qkv`` after normalization and
    RoPE have been applied to the Q/K portions. Two paths:

    * Fused: a single ``fused_qk_norm_rope`` kernel call that operates on
      ``qkv`` in place (its return value is ignored; q/k/v are split from the
      mutated buffer afterwards). Taken only when the server flag enabled it
      at init AND ``qkv`` is bfloat16.
    * Fallback: separate ``_apply_qk_norm`` + ``rotary_emb`` calls, optionally
      fusing the KV-cache write into the rotary kernel.

    Side effect: records which path ran in
    ``self._used_fused_qk_norm_rope_last_call`` — the fused path does not
    write the KV cache, so ``forward_core`` uses this flag to force
    ``save_kv_cache`` afterwards.
    """
    # Fused kernel path requires bf16 activations; anything else falls back.
    use_fused = self.use_fused_qk_norm_rope and qkv.dtype == torch.bfloat16
    if use_fused:
        theta = getattr(self.config, "rope_theta", 10000.0)
        # Kernel expects flat, contiguous int32 positions on the same device
        # as qkv — presumably a kernel ABI requirement; confirm against
        # sgl_kernel's fused_qk_norm_rope signature.
        positions = (
            positions.view(-1).to(dtype=torch.int32, device=qkv.device).contiguous()
        )
        # Host-side YaRN scalars; the ramp/inv-freq blend happens in-kernel.
        factor, low, high, attention_factor = compute_yarn_parameters(self.config)
        # In-place: normalizes Q and K heads with their RMSNorm weights and
        # applies RoPE, all within the packed qkv buffer.
        fused_qk_norm_rope(
            qkv,
            self.num_heads,
            self.num_kv_heads,  # number of K heads
            self.num_kv_heads,  # number of V heads (same as K here)
            self.head_dim,
            self.q_norm.variance_epsilon,
            self.q_norm.weight,
            self.k_norm.weight,
            theta,
            self.rotary_emb.is_neox_style,
            positions,
            factor,
            low,
            high,
            attention_factor,
        )
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        self._used_fused_qk_norm_rope_last_call = True
    else:
        # Fallback to non-fused QK Norm & RoPE implementation
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self._apply_qk_norm(q, k)
        q, k = self.rotary_emb(
            positions,
            q,
            k,
            # When supported, let the rotary kernel also write K/V into the
            # KV cache (saves a separate save_kv_cache pass in forward_core).
            fused_set_kv_buffer_arg=(
                create_fused_set_kv_buffer_arg(
                    value=v,
                    layer=self.attn,
                    forward_batch=forward_batch,
                )
                if enable_fused_set_kv_buffer(forward_batch)
                and self.compatible_with_fused_kv_buffer
                else None
            ),
        )
        self._used_fused_qk_norm_rope_last_call = False
    return q, k, v

def forward_prepare(
self,
positions: torch.Tensor,
Expand Down Expand Up @@ -482,15 +650,17 @@ def forward_core(self, intermediate_state):

q, k, v, fb = inner_state

must_save_kv = self._used_fused_qk_norm_rope_last_call
save_kv_cache = must_save_kv or not (
enable_fused_set_kv_buffer(forward_batch)
and self.compatible_with_fused_kv_buffer
)
attn_output = self.attn(
q,
k,
v,
fb,
save_kv_cache=not (
enable_fused_set_kv_buffer(forward_batch)
and self.compatible_with_fused_kv_buffer
),
save_kv_cache=save_kv_cache,
)
output, _ = self.o_proj(attn_output)
return output
Expand Down Expand Up @@ -543,6 +713,7 @@ def __init__(
head_dim=head_dim,
rms_norm_eps=rms_norm_eps,
attention_bias=attention_bias,
config=config,
quant_config=quant_config,
prefix=add_prefix("self_attn", prefix),
dual_chunk_attention_config=dual_chunk_attention_config,
Expand Down
6 changes: 6 additions & 0 deletions python/sglang/srt/server_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -542,6 +542,7 @@ class ServerArgs:
enable_attn_tp_input_scattered: bool = False
# Context parallelism used in the long sequence prefill phase of DeepSeek v3.2
enable_nsa_prefill_context_parallel: bool = False
enable_fused_qk_norm_rope: bool = False

# Dynamic batch tokenizer
enable_dynamic_batch_tokenizer: bool = False
Expand Down Expand Up @@ -3738,6 +3739,11 @@ def add_cli_args(parser: argparse.ArgumentParser):
action="store_true",
help="Enable context parallelism used in the long sequence prefill phase of DeepSeek v3.2.",
)
parser.add_argument(
"--enable-fused-qk-norm-rope",
action="store_true",
help="Enable fused qk normalization and rope rotary embedding.",
)

# Dynamic batch tokenizer
parser.add_argument(
Expand Down
Loading