Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 101 additions & 3 deletions python/sglang/srt/models/qwen3.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope
from sglang.srt.layers.rotary_embedding.mrope import MRotaryEmbedding
from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
Expand All @@ -30,13 +31,25 @@
from sglang.srt.models.qwen2 import Qwen2Model
from sglang.srt.models.utils import apply_qk_norm
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import add_prefix, is_cuda, is_npu
from sglang.srt.utils import add_prefix, get_bool_env_var, is_cuda, is_hip, is_npu

# Placeholder: the real Qwen3Config is resolved lazily/elsewhere (transformers
# may not provide it on older versions) — TODO confirm against full file.
Qwen3Config = None

logger = logging.getLogger(__name__)
# Hardware/platform feature flags, evaluated once at import time.
_is_cuda = is_cuda()
_is_hip = is_hip()
_is_npu = is_npu()
# aiter (AMD's inference kernel library) is opt-in via env var and only
# meaningful on ROCm/HIP builds.
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip

# True only when the specific fused QK-norm + mRoPE + KV-cache-write kernel is
# importable from the installed aiter package; older aiter releases lack it.
_has_fused_qk_norm_mrope = False
if _use_aiter:
    try:
        from aiter import fused_qk_norm_mrope_3d_cache_pts_quant_shuffle

        _has_fused_qk_norm_mrope = True
        logger.info("aiter fused_qk_norm_mrope_3d kernel available")
    except ImportError:
        # Best-effort: silently fall back to the unfused path when the kernel
        # is absent from this aiter build.
        pass

if _is_npu:
from sgl_kernel_npu.norm.split_qkv_rmsnorm_rope import split_qkv_rmsnorm_rope
Expand Down Expand Up @@ -138,6 +151,19 @@ def __init__(
)
self.alt_stream = alt_stream

self.use_fused_qk_norm_mrope = (
_has_fused_qk_norm_mrope
and isinstance(self.rotary_emb, MRotaryEmbedding)
and getattr(self.rotary_emb, "mrope_section", None) is not None
)
if self.use_fused_qk_norm_mrope:
# Scale tensors MUST stay on CPU: the C++ kernel uses .item<float>()
# which triggers hipMemcpy D2H + sync on CUDA tensors, breaking graph capture.
# Explicit device='cpu' is required because SGLang constructs models inside
# a `with torch.device('cuda'):` context that changes the default device.
self._fused_k_scale = torch.tensor(1.0, dtype=torch.float32, device="cpu")
self._fused_v_scale = torch.tensor(1.0, dtype=torch.float32, device="cpu")

def forward_prepare_native(self, positions, hidden_states):
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
Expand Down Expand Up @@ -172,6 +198,66 @@ def forward_prepare_npu(self, positions, hidden_states, forward_batch):
)
return q, k, v

def forward_prepare_aiter_fused_mrope(self, positions, hidden_states, forward_batch):
"""Fused QK-norm + 3D mRoPE + KV cache write for decode (ROCm/aiter).

The fused HIP kernel replaces split → QK norm → mRoPE → cache write,
so KV is already in the paged cache when this returns.
Returns (q, None, None); caller must pass save_kv_cache=False to attn.
"""
qkv, _ = self.qkv_proj(hidden_states)
num_tokens = qkv.shape[0]

qkv_3d = qkv.view(num_tokens, -1, self.head_dim)

token_to_kv_pool = forward_batch.token_to_kv_pool
k_cache, v_cache = token_to_kv_pool.get_kv_buffer(self.attn.layer_id)
slot_mapping = forward_batch.out_cache_loc

cos_sin = self.rotary_emb.cos_sin_cache
if cos_sin.dtype != qkv.dtype:
cos_sin = cos_sin.to(dtype=qkv.dtype)

q_out = torch.empty(
num_tokens,
self.num_heads,
self.head_dim,
dtype=qkv.dtype,
device=qkv.device,
)

fused_qk_norm_mrope_3d_cache_pts_quant_shuffle(
qkv_3d,
self.q_norm.weight,
self.k_norm.weight,
cos_sin,
positions,
num_tokens,
self.num_heads,
self.num_kv_heads,
self.num_kv_heads,
self.head_dim,
self.rotary_emb.is_neox_style,
self.rotary_emb.mrope_section,
self.rotary_emb.mrope_interleaved,
self.q_norm.variance_epsilon,
q_out,
k_cache,
v_cache,
slot_mapping,
self._fused_k_scale,
self._fused_v_scale,
None,
None,
False,
False,
0,
0,
)

q = q_out.reshape(num_tokens, -1)
return q, None, None

def forward(
self,
positions: torch.Tensor,
Expand All @@ -181,7 +267,19 @@ def forward(
if get_global_server_args().rl_on_policy_target is not None:
hidden_states = hidden_states.bfloat16()

if (
save_kv_cache = True
use_aiter_fused = (
self.use_fused_qk_norm_mrope
and forward_batch.forward_mode.is_decode()
and get_global_server_args().rl_on_policy_target is None
)

if use_aiter_fused:
q, k, v = self.forward_prepare_aiter_fused_mrope(
positions, hidden_states, forward_batch
)
save_kv_cache = False
elif (
not _is_npu
or forward_batch.forward_mode.is_extend_or_draft_extend_or_mixed()
):
Expand All @@ -200,7 +298,7 @@ def forward(
q = q.to(torch.bfloat16)
k = k.to(torch.bfloat16)
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The guard at line 274
is needed because there is a downstream `k.to(torch.bfloat16)` cast in the RL on-policy path; without the guard, the fused prepare would return `k=None` and that `.to()` call would crash.


attn_output = self.attn(q, k, v, forward_batch)
attn_output = self.attn(q, k, v, forward_batch, save_kv_cache=save_kv_cache)
output, _ = self.o_proj(attn_output)
return output

Expand Down
Loading