[Perf] [Bugfix] Fix Triton autotuning in inference for Qwen3.5#37338
Signed-off-by: Artem Perevedentsev <aperevedents@nvidia.com> Made-with: Cursor
Code Review
This pull request provides a crucial performance fix for Triton autotuning in Qwen3.5 models by aligning the tensor properties used during kernel warmup with those in the actual inference path. The changes correctly address dtype mismatches for the g gate tensor and cu_seqlens, and align the output_final_state parameter, which resolves the autotuner cache misses and leads to significant performance gains as demonstrated by the test results. The accompanying cleanup of cu_seqlens type annotations across multiple files improves code correctness and consistency. The implementation appears solid and well-tested.
LGTM.
ZJY0516
left a comment
Thanks for addressing this. LGTM
vadiklyutiy
left a comment
Except for adding a comment, looks good.
@sighingnow @ywang96 review this PR please
Purpose
Bugfix for PR #36599 + type annotation cleanup: the GDN Triton warmup introduced in that PR uses dummy tensors whose dtypes don't match real inference, so Triton autotuner cache entries created during warmup are never reused — autotuning re-runs on the first inference batch. Additionally, `cu_seqlens` type annotations across FLA ops are corrected from `torch.LongTensor` to `torch.Tensor` to reflect the actual `int32` dtype used at runtime.
Problem
When serving `Qwen/Qwen3.5-397B-A17B-FP8` with `-dp 8 --enable-expert-parallel`, the first inference batch triggers 746 Triton autotuning events across 6 FLA kernels (`chunk_fwd_kernel_o`, `chunk_gated_delta_rule_fwd_kernel_h_blockdim64`, `chunk_local_cumsum_scalar_kernel`, `chunk_scaled_dot_kkt_fwd_kernel`, `merge_16x16_to_64x64_inverse_kernel`, `recompute_w_u_fwd_kernel`). This happens because `_warmup_prefill_kernels()` (added in #36599) produces tensors with different dtypes than the real inference path, causing Triton cache key mismatches at serve time.
Root cause
Three dtype/shape mismatches in `qwen3_next.py` warmup (lines 707–750):

| Tensor / parameter | Warmup | Real inference | Fix |
|---|---|---|---|
| `g` (gate) | `torch.randn(..., dtype=bfloat16)` | `fused_gdn_gating()` output → `float32` | call `fused_gdn_gating()` during warmup |
| `cu_seqlens` | `dtype=torch.long` (int64) | `dtype=torch.int32` | `torch.int32` |
| `output_final_state` | `False` | `True` | `True` |

Each mismatch produces a distinct Triton kernel specialization key, so autotuning reruns at inference.
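A minimal, torch-free sketch of the caching behavior described above. The cache-key function is a stand-in, not Triton's real implementation (real keys also include divisibility hints and constexpr values), but it shows why any dtype or flag mismatch between warmup and inference produces a fresh key and re-triggers autotuning:

```python
# Stand-in for Triton's per-signature specialization cache.
cache = {}

def run_kernel(g_dtype, cu_seqlens_dtype, output_final_state):
    key = (g_dtype, cu_seqlens_dtype, output_final_state)
    if key not in cache:
        cache[key] = "compiled"  # expensive autotuning happens here
        return "autotune"
    return "cache-hit"

# Old warmup: torch.randn gate (bfloat16), torch.long cu_seqlens, False
print(run_kernel("bfloat16", "int64", False))  # autotune (during warmup, fine)
# Real inference: fused_gdn_gating gate (float32), int32, True
print(run_kernel("float32", "int32", True))    # autotune again: warmup entry unused
# Fixed warmup uses the inference signature, so serving hits the cache
print(run_kernel("float32", "int32", True))    # cache-hit
```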
Changes
- `vllm/model_executor/models/qwen3_next.py` — produce warmup tensors matching inference dtypes: compute `g, beta` via `fused_gdn_gating()` instead of `torch.randn` (float32 vs bfloat16); use `dtype=torch.int32` for `cu_seqlens` (matching the inference path); pass `output_final_state=True` (matching the inference path).
- `vllm/model_executor/layers/fla/ops/*.py` (10 files) — cleanup: change type annotations from `torch.LongTensor` to `torch.Tensor` for `cu_seqlens` parameters. These annotations were inaccurate (`cu_seqlens` is actually `int32` at runtime), and while fixing them has no performance impact, it keeps the code consistent with the dtype change above.
A/B experiment on 8× NVIDIA B200 (183 GiB each) with `TRITON_PRINT_AUTOTUNING=1` and a clean Triton cache (`rm -rf ~/.triton/cache/*`) before each run: count `Autotuning` lines in server logs before and after `Application startup complete`.
Test Result
Autotuning distribution
With the patch, all Triton autotuning completes during server warmup. Zero autotuning events leak into inference.
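The counting described in the test plan can be sketched as follows, using a tiny synthetic log; the real runs grep the actual vLLM server logs, and the exact log line formats here are illustrative, not copied from the PR:

```shell
# Build a minimal fake server log with the readiness marker in the middle.
cat > server.log <<'EOF'
Autotuning kernel chunk_fwd_kernel_o ...
Autotuning kernel recompute_w_u_fwd_kernel ...
INFO:     Application startup complete.
Autotuning kernel chunk_fwd_kernel_o ...
EOF
# Autotuning lines before the readiness marker = warmup (acceptable)
warmup=$(sed -n '1,/Application startup complete/p' server.log | grep -c 'Autotuning')
# Autotuning lines after the marker = leaking into live inference (the bug)
inference=$(sed -n '/Application startup complete/,$p' server.log | grep -c 'Autotuning')
echo "warmup=$warmup inference=$inference"   # warmup=2 inference=1
```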
Example: full autotuning cycle for one kernel from Experiment A, happening during live inference:
Benchmark performance (128 prompts, 8K input tokens each)
The dramatic improvement in mean TTFT and throughput is due to eliminating autotuning stalls during inference. Median TTFT is at parity because median captures the typical (non-stalled) request latency.
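A quick numeric illustration of that mean-vs-median effect, with made-up latencies rather than the PR's measurements: a handful of requests stalled by autotuning inflate the mean TTFT while leaving the median (the typical request) untouched.

```python
from statistics import mean, median

typical = [120.0] * 124   # ms, requests unaffected by autotuning
stalled = [9000.0] * 4    # ms, requests that hit an autotuning stall
ttft = typical + stalled  # 128 requests total, as in the benchmark

# Mean is dragged up by the 4 outliers; median stays at the typical value.
print(round(mean(ttft), 1), median(ttft))  # 397.5 120.0
```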