diff --git a/python/sglang/srt/layers/rotary_embedding.py b/python/sglang/srt/layers/rotary_embedding.py
index 3bd986e8d1e5..83ae6ce1049b 100644
--- a/python/sglang/srt/layers/rotary_embedding.py
+++ b/python/sglang/srt/layers/rotary_embedding.py
@@ -2068,6 +2068,7 @@ def get_rope_wrapper(
     dtype: Optional[torch.dtype] = None,
     partial_rotary_factor: float = 1.0,
     device: Optional[str] = None,
+    dual_chunk_attention_config: Optional[Dict[str, Any]] = None,
 ):
     if device != "cpu":
         wrapper = aiter_get_rope if _use_aiter else get_rope
@@ -2080,6 +2081,7 @@ def get_rope_wrapper(
             rope_scaling,
             dtype,
             partial_rotary_factor,
+            dual_chunk_attention_config,
         )
 
     return get_rope_cpu(
diff --git a/python/sglang/srt/models/qwen3_moe.py b/python/sglang/srt/models/qwen3_moe.py
index f03a51abde05..8a0feaf20b32 100644
--- a/python/sglang/srt/models/qwen3_moe.py
+++ b/python/sglang/srt/models/qwen3_moe.py
@@ -51,7 +51,7 @@
 from sglang.srt.layers.moe.topk import TopK
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.rotary_embedding import MRotaryEmbedding, get_rope
+from sglang.srt.layers.rotary_embedding import MRotaryEmbedding, get_rope_wrapper
 from sglang.srt.layers.utils import get_layer_id
 from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
 from sglang.srt.managers.schedule_batch import global_server_args_dict
@@ -66,6 +66,7 @@
 )
 from sglang.srt.utils import (
     add_prefix,
+    get_bool_env_var,
     is_cuda,
     is_hip,
     get_bool_env_var,
@@ -353,7 +354,11 @@ def __init__(
             prefix=add_prefix("o_proj", prefix),
         )
 
-        self.rotary_emb = get_rope(
+        self.rope_scaling = rope_scaling
+        if _use_aiter and self.rope_scaling is not None:
+            self.rope_scaling["aiter_rope_fused_qknorm"] = True
+
+        self.rotary_emb = get_rope_wrapper(
             self.head_dim,
             rotary_dim=self.head_dim,
             max_position=max_position_embeddings,
@@ -424,23 +429,35 @@ def forward_prepare(
         ):
             return hidden_states, forward_batch, None
         qkv, _ = self.qkv_proj(hidden_states)
-        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
-        q, k = self._apply_qk_norm(q, k)
-        q, k = self.rotary_emb(
-            positions,
-            q,
-            k,
-            fused_set_kv_buffer_arg=(
-                create_fused_set_kv_buffer_arg(
-                    value=v,
-                    layer=self.attn,
-                    forward_batch=forward_batch,
-                )
-                if enable_fused_set_kv_buffer(forward_batch)
-                and self.compatible_with_fused_kv_buffer
-                else None
-            ),
-        )
+        if _use_aiter and self.rope_scaling is not None and "aiter_rope_fused_qknorm" in self.rope_scaling:
+            assert self.k_norm.variance_epsilon == self.q_norm.variance_epsilon
+            q, k, v = self.rotary_emb(
+                qkv,
+                self.q_norm.weight,
+                self.k_norm.weight,
+                positions,
+                self.num_heads,
+                self.num_kv_heads,
+                self.k_norm.variance_epsilon,
+            )
+        else:
+            q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+            q, k = self._apply_qk_norm(q, k)
+            q, k = self.rotary_emb(
+                positions,
+                q,
+                k,
+                fused_set_kv_buffer_arg=(
+                    create_fused_set_kv_buffer_arg(
+                        value=v,
+                        layer=self.attn,
+                        forward_batch=forward_batch,
+                    )
+                    if enable_fused_set_kv_buffer(forward_batch)
+                    and self.compatible_with_fused_kv_buffer
+                    else None
+                ),
+            )
         inner_state = q, k, v, forward_batch
         return None, forward_batch, inner_state