Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions python/sglang/srt/layers/rotary_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -2068,6 +2068,7 @@ def get_rope_wrapper(
dtype: Optional[torch.dtype] = None,
partial_rotary_factor: float = 1.0,
device: Optional[str] = None,
dual_chunk_attention_config: Optional[Dict[str, Any]] = None,
):
if device != "cpu":
wrapper = aiter_get_rope if _use_aiter else get_rope
Expand All @@ -2080,6 +2081,7 @@ def get_rope_wrapper(
rope_scaling,
dtype,
partial_rotary_factor,
dual_chunk_attention_config,
)

return get_rope_cpu(
Expand Down
55 changes: 36 additions & 19 deletions python/sglang/srt/models/qwen3_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
from sglang.srt.layers.moe.topk import TopK
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import MRotaryEmbedding, get_rope
from sglang.srt.layers.rotary_embedding import MRotaryEmbedding, get_rope_wrapper
from sglang.srt.layers.utils import get_layer_id
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
from sglang.srt.managers.schedule_batch import global_server_args_dict
Expand All @@ -66,6 +66,7 @@
)
from sglang.srt.utils import (
add_prefix,
get_bool_env_var,
is_cuda,
is_hip,
get_bool_env_var,
Expand Down Expand Up @@ -353,7 +354,11 @@ def __init__(
prefix=add_prefix("o_proj", prefix),
)

self.rotary_emb = get_rope(
self.rope_scaling = rope_scaling
if _use_aiter and self.rope_scaling is not None:
self.rope_scaling["aiter_rope_fused_qknorm"] = True

self.rotary_emb = get_rope_wrapper(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
Expand Down Expand Up @@ -424,23 +429,35 @@ def forward_prepare(
):
return hidden_states, forward_batch, None
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self._apply_qk_norm(q, k)
q, k = self.rotary_emb(
positions,
q,
k,
fused_set_kv_buffer_arg=(
create_fused_set_kv_buffer_arg(
value=v,
layer=self.attn,
forward_batch=forward_batch,
)
if enable_fused_set_kv_buffer(forward_batch)
and self.compatible_with_fused_kv_buffer
else None
),
)
if _use_aiter and self.rope_scaling is not None and "aiter_rope_fused_qknorm" in self.rope_scaling:
assert self.k_norm.variance_epsilon == self.q_norm.variance_epsilon
q, k, v = self.rotary_emb(
qkv,
self.q_norm.weight,
self.k_norm.weight,
positions,
self.num_heads,
self.num_kv_heads,
self.k_norm.variance_epsilon,
)
else:
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self._apply_qk_norm(q, k)
q, k = self.rotary_emb(
positions,
q,
k,
fused_set_kv_buffer_arg=(
create_fused_set_kv_buffer_arg(
value=v,
layer=self.attn,
forward_batch=forward_batch,
)
if enable_fused_set_kv_buffer(forward_batch)
and self.compatible_with_fused_kv_buffer
else None
),
)
inner_state = q, k, v, forward_batch
return None, forward_batch, inner_state

Expand Down
Loading