
[Perf] [Bugfix] Fix Triton autotuning in inference for Qwen3.5#37338

Merged
ywang96 merged 3 commits into vllm-project:main from arpera:artem/qwen35-fix-autotuning-in-inference
Mar 23, 2026

Conversation


@arpera arpera commented Mar 17, 2026

Purpose

Bugfix for PR #36599 + type annotation cleanup: the GDN Triton warmup introduced in that PR uses dummy tensors whose dtypes don't match real inference, so Triton autotuner cache entries created during warmup are never reused — autotuning re-runs on the first inference batch. Additionally, cu_seqlens type annotations across FLA ops are corrected from torch.LongTensor to torch.Tensor to reflect the actual int32 dtype used at runtime.

Problem

When serving Qwen/Qwen3.5-397B-A17B-FP8 with -dp 8 --enable-expert-parallel, the first inference batch triggers 746 Triton autotuning events across 6 FLA kernels (chunk_fwd_kernel_o, chunk_gated_delta_rule_fwd_kernel_h_blockdim64, chunk_local_cumsum_scalar_kernel, chunk_scaled_dot_kkt_fwd_kernel, merge_16x16_to_64x64_inverse_kernel, recompute_w_u_fwd_kernel). This happens because _warmup_prefill_kernels() (added in #36599) produces tensors with different dtypes than the real inference path, causing Triton cache key mismatches at serve time.

Root cause

Three mismatches between warmup and real inference in qwen3_next.py (lines 707–750):

| Tensor / argument | Warmup (before) | Real inference | Fix |
|---|---|---|---|
| `g` (gate) | `torch.randn(..., dtype=bfloat16)` | `fused_gdn_gating()` output → `float32` | Call `fused_gdn_gating()` during warmup |
| `cu_seqlens` | `dtype=torch.long` (int64) | `dtype=torch.int32` | Use `torch.int32` |
| `output_final_state` | `False` | `True` | Set to `True` |

Each mismatch produces a distinct Triton kernel specialization key, so autotuning reruns at inference.
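The cache-miss mechanism can be illustrated with a minimal sketch (this is not Triton's actual implementation; `autotune_cache_key` is a hypothetical stand-in for a dtype-sensitive specialization key):

```python
# Sketch: the autotuner keys its cache on, among other things, argument
# dtypes and constexpr flags, so a warmup run with different dtypes never
# populates the key that inference will look up.
def autotune_cache_key(arg_dtypes, constexpr_flags):
    """Hypothetical stand-in for a dtype-sensitive specialization key."""
    return (tuple(arg_dtypes), tuple(sorted(constexpr_flags.items())))

warmup_key = autotune_cache_key(
    ["torch.bfloat16", "torch.int64"], {"output_final_state": False}
)
inference_key = autotune_cache_key(
    ["torch.float32", "torch.int32"], {"output_final_state": True}
)

# Every mismatch (g dtype, cu_seqlens dtype, output_final_state) changes
# the key, so the warmup entry is a guaranteed miss at serve time.
print(warmup_key == inference_key)  # False
```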

Changes

  1. vllm/model_executor/models/qwen3_next.py — produce warmup tensors matching inference dtypes:

    • Generate g, beta via fused_gdn_gating() instead of torch.randn (float32 vs bfloat16)
    • Use dtype=torch.int32 for cu_seqlens (matching inference path)
    • Set output_final_state=True (matching inference path)
  2. vllm/model_executor/layers/fla/ops/*.py (10 files) — cleanup: change type annotations from torch.LongTensor to torch.Tensor for cu_seqlens parameters. These annotations were inaccurate (cu_seqlens is actually int32 at runtime), and while fixing them has no performance impact, it keeps the code consistent with the dtype change above.
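A rough sketch of the corrected warmup inputs, with illustrative names and shapes (not vLLM's actual code — in the real fix `g` and `beta` come from `fused_gdn_gating()`; a plain float32 tensor stands in here since only dtype parity matters):

```python
import torch

num_tokens, num_heads = 64, 4

# Before: g = torch.randn(..., dtype=torch.bfloat16).
# After: fused_gdn_gating() output, which is float32.
g = torch.zeros(num_tokens, num_heads, dtype=torch.float32)

# Before: dtype=torch.long (int64). After: int32, as on the inference path.
cu_seqlens = torch.tensor([0, num_tokens], dtype=torch.int32)

# Before: output_final_state=False. After: True, as on the inference path.
output_final_state = True
```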

Test Plan

A/B experiment on 8× NVIDIA B200 (183 GiB each) with TRITON_PRINT_AUTOTUNING=1, clean Triton cache (rm -rf ~/.triton/cache/*) before each run:

# Server command (both experiments)
TRITON_PRINT_AUTOTUNING=1 vllm serve "Qwen/Qwen3.5-397B-A17B-FP8" \
    --port 8000 -tp 1 -pp 1 -dp 8 --enable-expert-parallel \
    --language-model-only --reasoning-parser qwen3 \
    --stream-interval 100 --safetensors-load-strategy prefetch

# Benchmark command (both experiments)
vllm bench serve --backend vllm --model Qwen/Qwen3.5-397B-A17B-FP8 \
    --port 8000 --endpoint /v1/completions --dataset-name random \
    --random-input 8000 --random-output 1 --max-concurrency 8 \
    --num-prompt 128 --ignore-eos --temperature 0.0
  • Experiment A: baseline — clear Triton cache, start server, run benchmark
  • Experiment B: baseline + this patch — clear Triton cache, start server, run benchmark
  • Metric: count Autotuning lines in server logs before and after Application startup complete
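Counting the metric can be scripted along these lines (a hypothetical helper, assuming the log format produced by TRITON_PRINT_AUTOTUNING=1 as shown in the excerpt below):

```python
# Split autotuning events into "during startup" vs "during serving" by
# using the "Application startup complete" log line as the boundary.
def count_autotuning(log_lines):
    startup, serving, started = 0, 0, False
    for line in log_lines:
        if "Application startup complete" in line:
            started = True
        elif "Autotuning kernel" in line:
            if started:
                serving += 1
            else:
                startup += 1
    return startup, serving

log = [
    "Autotuning kernel chunk_fwd_kernel_o with config num_warps: 4",
    "Application startup complete",
    "Autotuning kernel chunk_local_cumsum_scalar_kernel with config num_warps: 1",
]
print(count_autotuning(log))  # (1, 1)
```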

Test Result

Autotuning distribution

| Phase | A (without patch) | B (with patch) |
|---|---|---|
| During server startup | 1,162 | 1,162 |
| During benchmark | 746 | 0 |
| Total | 1,908 | 1,162 |

With the patch, all Triton autotuning completes during server warmup. Zero autotuning events leak into inference.

Example: full autotuning cycle for one kernel from Experiment A, happening during live inference:

Autotuning kernel chunk_local_cumsum_scalar_kernel with config num_warps: 1, num_ctas: 1, num_stages: 3, maxnreg: None
Autotuning kernel chunk_local_cumsum_scalar_kernel with config num_warps: 2, num_ctas: 1, num_stages: 3, maxnreg: None
Autotuning kernel chunk_local_cumsum_scalar_kernel with config num_warps: 4, num_ctas: 1, num_stages: 3, maxnreg: None
Autotuning kernel chunk_local_cumsum_scalar_kernel with config num_warps: 8, num_ctas: 1, num_stages: 3, maxnreg: None
Triton autotuning for function chunk_local_cumsum_scalar_kernel,
with key as (1, 64, 64, True, False, 'torch.float32', 'torch.float32', 'torch.int32', 'torch.int32'),
finished after 1.87s,
best config selected: num_warps: 1, num_ctas: 1, num_stages: 3, maxnreg: None;
...

Benchmark performance (128 prompts, 8K input tokens each)

| Metric | A (without patch) | B (with patch) | Improvement |
|---|---|---|---|
| Duration | 109.73 s | 27.34 s | 4.0× faster |
| Total token throughput | 9,333 tok/s | 37,459 tok/s | 4.0× higher |
| Mean TTFT | 6,821 ms | 1,670 ms | 4.1× lower |
| Median TTFT | 929 ms | 999 ms | ~parity |
| Failed requests | 0 | 0 | — |

The dramatic improvement in mean TTFT and throughput is due to eliminating autotuning stalls during inference. Median TTFT is at parity because median captures the typical (non-stalled) request latency.
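The mean/median split is easy to reproduce with made-up numbers (illustrative only, not the benchmark's data): a few stalled requests drag the mean up while the median, which tracks the typical request, barely moves.

```python
import statistics

# 100 normal requests at ~1 s TTFT, 8 stalled by autotuning at ~60 s.
ttft_s = [1.0] * 100 + [60.0] * 8

print(round(statistics.mean(ttft_s), 2))  # 5.37 — dragged up by the stalls
print(statistics.median(ttft_s))          # 1.0 — unchanged
```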



Signed-off-by: Artem Perevedentsev <aperevedents@nvidia.com>
Made-with: Cursor
@arpera arpera requested a review from sighingnow as a code owner March 17, 2026 19:03

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request provides a crucial performance fix for Triton autotuning in Qwen3.5 models by aligning the tensor properties used during kernel warmup with those in the actual inference path. The changes correctly address dtype mismatches for the g gate tensor and cu_seqlens, and align the output_final_state parameter, which resolves the autotuner cache misses and leads to significant performance gains as demonstrated by the test results. The accompanying cleanup of cu_seqlens type annotations across multiple files improves code correctness and consistency. The implementation appears solid and well-tested.

@vadiklyutiy

cc @AuYang261 @ZJY0516

@vadiklyutiy vadiklyutiy requested a review from ywang96 March 17, 2026 21:11
@mergify mergify bot added the qwen (Related to Qwen models) and bug (Something isn't working) labels Mar 17, 2026
@AuYang261

LGTM.


@ZJY0516 ZJY0516 left a comment


Thanks for addressing this. LGTM


@vadiklyutiy vadiklyutiy left a comment


Except for adding a comment, this looks good.

@vadiklyutiy vadiklyutiy added the ready (ONLY add when PR is ready to merge/full CI is needed) label Mar 18, 2026
arpera added 2 commits March 18, 2026 13:52
Signed-off-by: Artem Perevedentsev <aperevedents@nvidia.com>
Made-with: Cursor
Signed-off-by: Artem Perevedentsev <aperevedents@nvidia.com>
Made-with: Cursor

arpera commented Mar 23, 2026

@sighingnow @ywang96 please review this PR.

@ywang96 ywang96 merged commit a16133a into vllm-project:main Mar 23, 2026
56 checks passed
yzong-rh pushed a commit to yzong-rh/vllm that referenced this pull request Mar 23, 2026
…project#37338)

Signed-off-by: Artem Perevedentsev <aperevedents@nvidia.com>

Labels

bug (Something isn't working) · qwen (Related to Qwen models) · ready (ONLY add when PR is ready to merge/full CI is needed)
