diff --git a/python/sglang/srt/models/qwen3.py b/python/sglang/srt/models/qwen3.py
index 4f018ea52d2f..f424052d1430 100644
--- a/python/sglang/srt/models/qwen3.py
+++ b/python/sglang/srt/models/qwen3.py
@@ -19,6 +19,7 @@
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.rotary_embedding.mrope import MRotaryEmbedding
 from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
 from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
@@ -30,13 +31,25 @@
 from sglang.srt.models.qwen2 import Qwen2Model
 from sglang.srt.models.utils import apply_qk_norm
 from sglang.srt.server_args import get_global_server_args
-from sglang.srt.utils import add_prefix, is_cuda, is_npu
+from sglang.srt.utils import add_prefix, get_bool_env_var, is_cuda, is_hip, is_npu
 
 Qwen3Config = None
 
 logger = logging.getLogger(__name__)
 
 _is_cuda = is_cuda()
+_is_hip = is_hip()
 _is_npu = is_npu()
+_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
+
+_has_fused_qk_norm_mrope = False
+if _use_aiter:
+    try:
+        from aiter import fused_qk_norm_mrope_3d_cache_pts_quant_shuffle
+
+        _has_fused_qk_norm_mrope = True
+        logger.info("aiter fused_qk_norm_mrope_3d kernel available")
+    except ImportError:
+        pass
 if _is_npu:
     from sgl_kernel_npu.norm.split_qkv_rmsnorm_rope import split_qkv_rmsnorm_rope
@@ -138,6 +151,19 @@ def __init__(
         )
         self.alt_stream = alt_stream
 
+        self.use_fused_qk_norm_mrope = (
+            _has_fused_qk_norm_mrope
+            and isinstance(self.rotary_emb, MRotaryEmbedding)
+            and getattr(self.rotary_emb, "mrope_section", None) is not None
+        )
+        if self.use_fused_qk_norm_mrope:
+            # Scale tensors MUST stay on CPU: the C++ kernel uses .item()
+            # which triggers hipMemcpy D2H + sync on CUDA tensors, breaking graph capture.
+            # Explicit device='cpu' is required because SGLang constructs models inside
+            # a `with torch.device('cuda'):` context that changes the default device.
+            self._fused_k_scale = torch.tensor(1.0, dtype=torch.float32, device="cpu")
+            self._fused_v_scale = torch.tensor(1.0, dtype=torch.float32, device="cpu")
+
     def forward_prepare_native(self, positions, hidden_states):
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
@@ -172,6 +198,66 @@ def forward_prepare_npu(self, positions, hidden_states, forward_batch):
         )
         return q, k, v
 
+    def forward_prepare_aiter_fused_mrope(self, positions, hidden_states, forward_batch):
+        """Fused QK-norm + 3D mRoPE + KV cache write for decode (ROCm/aiter).
+
+        The fused HIP kernel replaces split → QK norm → mRoPE → cache write,
+        so KV is already in the paged cache when this returns.
+        Returns (q, None, None); caller must pass save_kv_cache=False to attn.
+        """
+        qkv, _ = self.qkv_proj(hidden_states)
+        num_tokens = qkv.shape[0]
+
+        qkv_3d = qkv.view(num_tokens, -1, self.head_dim)
+
+        token_to_kv_pool = forward_batch.token_to_kv_pool
+        k_cache, v_cache = token_to_kv_pool.get_kv_buffer(self.attn.layer_id)
+        slot_mapping = forward_batch.out_cache_loc
+
+        cos_sin = self.rotary_emb.cos_sin_cache
+        if cos_sin.dtype != qkv.dtype:
+            cos_sin = cos_sin.to(dtype=qkv.dtype)
+
+        q_out = torch.empty(
+            num_tokens,
+            self.num_heads,
+            self.head_dim,
+            dtype=qkv.dtype,
+            device=qkv.device,
+        )
+
+        fused_qk_norm_mrope_3d_cache_pts_quant_shuffle(
+            qkv_3d,
+            self.q_norm.weight,
+            self.k_norm.weight,
+            cos_sin,
+            positions,
+            num_tokens,
+            self.num_heads,
+            self.num_kv_heads,
+            self.num_kv_heads,
+            self.head_dim,
+            self.rotary_emb.is_neox_style,
+            self.rotary_emb.mrope_section,
+            self.rotary_emb.mrope_interleaved,
+            self.q_norm.variance_epsilon,
+            q_out,
+            k_cache,
+            v_cache,
+            slot_mapping,
+            self._fused_k_scale,
+            self._fused_v_scale,
+            None,
+            None,
+            False,
+            False,
+            0,
+            0,
+        )
+
+        q = q_out.reshape(num_tokens, -1)
+        return q, None, None
+
     def forward(
         self,
         positions: torch.Tensor,
@@ -181,7 +267,19 @@
         if get_global_server_args().rl_on_policy_target is not None:
             hidden_states = hidden_states.bfloat16()
 
-        if (
+        save_kv_cache = True
+        use_aiter_fused = (
+            self.use_fused_qk_norm_mrope
+            and forward_batch.forward_mode.is_decode()
+            and get_global_server_args().rl_on_policy_target is None
+        )
+
+        if use_aiter_fused:
+            q, k, v = self.forward_prepare_aiter_fused_mrope(
+                positions, hidden_states, forward_batch
+            )
+            save_kv_cache = False
+        elif (
             not _is_npu
             or forward_batch.forward_mode.is_extend_or_draft_extend_or_mixed()
         ):
@@ -200,7 +298,7 @@
             q = q.to(torch.bfloat16)
             k = k.to(torch.bfloat16)
 
-        attn_output = self.attn(q, k, v, forward_batch)
+        attn_output = self.attn(q, k, v, forward_batch, save_kv_cache=save_kv_cache)
         output, _ = self.o_proj(attn_output)
         return output
 