Skip to content

[Kernel] Tune default fp8 block-scaled Triton config for M<=8 decode#40925

Open
tonyliu312 wants to merge 1 commit intovllm-project:mainfrom
tonyliu312:triton-fp8-decode-config
Open

[Kernel] Tune default fp8 block-scaled Triton config for M<=8 decode#40925
tonyliu312 wants to merge 1 commit intovllm-project:mainfrom
tonyliu312:triton-fp8-decode-config

Conversation

@tonyliu312
Copy link
Copy Markdown

Purpose

w8a8_triton_block_scaled_mm falls back to a hardcoded default config when no pre-tuned configs/N=*,K=*,device_name=*.json file matches the GPU. The default uses BLOCK_SIZE_M=64, which wastes 98% of the M dimension for single-request decode (M=1). GPUs without a pre-tuned JSON for their (N, K) tuple pay this cost.

This PR specializes only the M ≤ 8 branch (single-request decode and short MTP-style draft batches) to use BLOCK_SIZE_M=16, num_stages=3. Larger M is unchanged to keep blast radius small.

M range BLOCK_SIZE_M num_stages
M <= 8 16 3 (new)
else 64 (unchanged) 2 (unchanged)

Pre-tuned JSON configs are unaffected — they short-circuit before this branch. Workloads that already have a tuned JSON for their (N, K, device) keep the exact same kernel as before.

Test Plan

  1. On a host without a pre-tuned JSON for the (N, K, device) of the model under test, run single-request fp8 block-scaled decode at M=1.
  2. Measure median tokens/sec over a steady-state window.
  3. Verify output text is coherent.
  4. Confirm no behavioral change for prefill (large M) or for hosts with a pre-tuned JSON.

Test Result

Hardware: dual DGX Spark (GB10, sm_121), TP=2. Model: DeepSeek-V4-Flash. Mode: single request, decode-dominated.

Before (BLOCK_SIZE_M=64 default) After (M ≤ 8 specialised)
Median decode throughput 5.45 t/s 6.73 t/s
Speedup +23%
Output coherence OK OK

Run config: --moe-backend marlin --kv-cache-dtype fp8_ds_mla --gpu-memory-utilization 0.88 --load-format instanttensor, max_tokens=80, temperature=0, 5 warm runs, statistical median.

No-regression check: for M ≥ 9 the patch falls into the existing else branch with BLOCK_SIZE_M=64, num_stages=2, matching the prior default exactly. For hosts with a tuned JSON, the if configs: branch fires and this code path is never reached.

Notes

  • Scope deliberately narrow: only M <= 8 is changed. M > 8 keeps the previous default. This is intentionally smaller than my first internal draft (which also added an M ≤ 32 branch) to minimise risk on archs where I don't have a way to verify (Hopper / Ampere / Datacenter Blackwell).
  • Cross-arch claim is honest: the win is measured only on GB10 (sm_121). The 98%-of-M-dim waste argument generalises (any decoder with M=1 single-request load lacking a tuned JSON), but I cannot verify on Hopper/Ampere from this hardware. Reviewers on those archs are welcome to push back or confirm. If a regression is found on a specific arch, narrowing to if capability_family == 12 is straightforward.
  • num_stages=3 for the M ≤ 8 branch is the local optimum from a short autotune sweep on GB10. On Hopper, num_stages can interact with shared-mem occupancy. If this regresses, the safe fallback is to keep num_stages=2 in the M ≤ 8 branch (BLOCK_M=16 alone still recovers most of the win).
  • Conservative: I deliberately do not change BLOCK_SIZE_N, BLOCK_SIZE_K, GROUP_SIZE_M, or num_warps, because changing them has second-order effects that vary by arch.

Related

Checklist

  • Tested on real hardware (GB10 / DGX Spark, sm_121, TP=2)
  • No change for M > 8 (existing default preserved)
  • No change when a tuned JSON config is present (short-circuited above)
  • Commit signed-off (DCO)
  • Pre-commit lint run locally (pre-commit run --files vllm/model_executor/layers/quantization/utils/fp8_utils.py)
  • CI passes (will be confirmed once submitted; pending ready label from maintainer per first-time-contributor gate)

Copy link
Copy Markdown

@claude claude Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Claude Code Review

This pull request is from a fork — automated review is disabled. A repository maintainer can comment @claude review to run a one-time review.

@github-actions
Copy link
Copy Markdown

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

PRs do not trigger a full CI run by default. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

If you have any questions, please reach out to us on Slack at https://slack.vllm.ai.

Agent Guidelines

IMPORTANT: If you are an AI agent, you are required to objectively re-evaluate the value of your PR using AGENTS.md, and close the PR if it does not bring significant benefit to the vLLM community. Failure to do so may result in an immediate ban.

🚀

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request optimizes the w8a8_triton_block_scaled_mm Triton kernel for low-M decode scenarios by adjusting BLOCK_SIZE_M and num_stages when M <= 8. A review comment points out that setting num_stages to 3 can lead to resource exhaustion on ROCm platforms and suggests using a platform-specific configuration to maintain stability.

Comment on lines +729 to +732
if M <= 8:
block_m, num_stages = 16, 3
else:
block_m, num_stages = 64, 2
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Setting num_stages=3 can lead to OutOfResources errors or performance regressions on ROCm, as seen in other Triton kernels in the codebase (e.g., in vllm/model_executor/layers/fused_moe/fused_moe.py). It is safer to use num_stages=2 for ROCm while keeping the optimized value for NVIDIA architectures.

Suggested change
if M <= 8:
block_m, num_stages = 16, 3
else:
block_m, num_stages = 64, 2
if M <= 8:
block_m, num_stages = 16, (2 if current_platform.is_rocm() else 3)
else:
block_m, num_stages = 64, 2

@tonyliu312
Copy link
Copy Markdown
Author

Hi @mgoin @tlrmchlsmth — this is my first vLLM contribution, so the pre-run-check workflow is currently blocked on the first-time-contributor ready label gate (no merged PRs yet on this account, so CI never runs the actual checks).

Would one of you (or any quantization CODEOWNER) be willing to take a quick look and add the ready label if it looks reasonable? It's a small change (10 lines added, 2 removed) inside the else branch of w8a8_triton_block_scaled_mm's default-config fallback — only applied when no pre-tuned configs/N=*,K=*,device_name=*.json matches the GPU. The branch is intentionally narrow: only M <= 8 is specialised (BLOCK_SIZE_M=16, num_stages=3); larger M keeps the previous default exactly. Verified +23% on GB10 single-request decode (5.45 → 6.73 t/s) for V4-Flash, with no regression possible for M > 8 or for hosts that already hit a tuned JSON.

Happy to address any review feedback. Thanks for your time!

@tonyliu312 tonyliu312 force-pushed the triton-fp8-decode-config branch from b7ed937 to da10c6e Compare April 26, 2026 16:15
@tonyliu312
Copy link
Copy Markdown
Author

Thanks for catching this @gemini-code-assist — you're right.

w8a8_triton_block_scaled_mm is a platform-agnostic op (registered via torch.ops.vllm.w8a8_triton_block_scaled_mm_func), and 3-stage pipelining at the typical [128, 128] block sizes is borderline for MI300/MI250X LDS (64 KB).

Just amended the commit (force-pushed da10c6e) to gate num_stages=3 to non-ROCm:

if M <= 8:
    block_m = 16
    num_stages = 2 if current_platform.is_rocm() else 3
else:
    block_m, num_stages = 64, 2

Rationale: the BLOCK_SIZE_M=16 wave-quantisation win is the dominant effect for low-M decode (most of the +23% on GB10). num_stages=3 is the smaller second-order win specific to CUDA. On ROCm we keep num_stages=2 so the M≤8 branch still recovers the bulk of the gain without LDS pressure. Commit message updated accordingly.

@gemini-code-assist
Copy link
Copy Markdown
Contributor

Thanks for the update, @tonyliu312. The change to gate num_stages=3 to non-ROCm platforms is a prudent adjustment given the LDS constraints on MI300/MI250X. This approach maintains the performance gains for the M <= 8 case while ensuring stability across different hardware architectures. The logic looks sound.

`w8a8_triton_block_scaled_mm` falls back to a hardcoded default config
when no pre-tuned `configs/N=*,K=*,device_name=*.json` file matches the
GPU. The default uses `BLOCK_SIZE_M=64`, which wastes 98% of the M
dimension in single-request decode (M=1). GPUs without a pre-tuned JSON
file for their (N, K, device) tuple pay this cost.

Narrow the change: only specialize the M<=8 case (single-request decode
and short MTP-style draft batches). Larger M keeps the previous default
unchanged so non-decode paths and tuned configs are not perturbed.

  M <= 8 (CUDA)   -> BLOCK_SIZE_M=16, num_stages=3   (new)
  M <= 8 (ROCm)   -> BLOCK_SIZE_M=16, num_stages=2   (new)
  else            -> BLOCK_SIZE_M=64, num_stages=2   (previous default)

num_stages=3 is gated to non-ROCm because MI300/MI250X LDS (64 KB) is
borderline for 3-stage Triton pipelining at typical [128, 128] block
sizes; on ROCm we keep num_stages=2 so the M<=8 branch still gets the
BLOCK_SIZE_M=16 wave-quantisation win without LDS pressure.

Pre-tuned JSON configs are unaffected (they short-circuit before this
branch). Workloads that already have a JSON for their (N, K, device)
get the same kernel as before.

Verified on dual DGX Spark (GB10, sm_121, TP=2) running V4-Flash:
median single-request decode goes from 5.45 t/s to 6.73 t/s (+23%) with
no other changes. Output remains coherent. The win is expected to
generalize to other architectures lacking a pre-tuned JSON for the
target (N, K) pair, but only the GB10 case is verified here; reviewers
on Hopper/Ampere are welcome to confirm or push back.

Refs vllm-project#40860 (V4 rebase), vllm-project#40899 (jasl SM12x scope is orthogonal)

Signed-off-by: Tony Liu <tonyliu0512@gmail.com>
@tonyliu312 tonyliu312 force-pushed the triton-fp8-decode-config branch from da10c6e to c76758e Compare April 27, 2026 01:36
@tonyliu312 tonyliu312 mentioned this pull request Apr 27, 2026
4 tasks
wyjBot added a commit to wyjBot/vllm that referenced this pull request Apr 29, 2026
…V4-Flash expert shapes

Adds Triton autotuned configs for NVIDIA H20-3e (SM90) for the two
primary expert GEMM shapes in DeepSeek-V4-Flash:

  N=2048, K=7168  (gate / up projection, per expert)
  N=7168, K=2048  (down projection, per expert)

Without these configs the fallback uses BLOCK_SIZE_M=64 regardless of
actual batch size.  For decode batches where M=1–32 this wastes ~60 % of
Tensor Core capacity through excessive CTA padding.

Autotune results on NVIDIA H20-3e (SM90, CUDA 9.0):

  Shape N=2048,K=7168 — speedup vs fallback (BM=64,BN=128,BK=128):
    M=1  : 1.79×   M=4  : 1.64×   M=8  : 1.62×
    M=16 : 1.63×   M=32 : 1.57×   M=48+: ~1.0× (fallback already optimal)

  Shape N=7168,K=2048:
    M=1  : 1.60×   M=4  : 1.23×   M=8  : 1.17×   M=16 : 1.16×

Key finding: BLOCK_SIZE_M=16 (instead of 64) is optimal for M≤32 on H20
because it avoids ~75% CTA padding that otherwise kills Tensor Core utilisation.

Also adds N=2048,K=7168 and N=7168,K=2048 to the benchmark script's
get_weight_shapes() so these shapes are included in future re-tuning runs.

Related: PR vllm-project#40925 (same fix for GB10/SM120 — "Hopper reviewers welcome").
Made-with: Cursor
wyjBot added a commit to wyjBot/vllm that referenced this pull request Apr 29, 2026
…V4-Flash expert shapes

Adds Triton autotuned configs for NVIDIA H20-3e (SM90) for the two
primary expert GEMM shapes in DeepSeek-V4-Flash:

  N=2048, K=7168  (gate / up projection, per expert)
  N=7168, K=2048  (down projection, per expert)

Without these configs the fallback uses BLOCK_SIZE_M=64 regardless of
actual batch size.  For decode batches where M=1–32 this wastes ~60 % of
Tensor Core capacity through excessive CTA padding.

Autotune results on NVIDIA H20-3e (SM90, CUDA 9.0):

  Shape N=2048,K=7168 — speedup vs fallback (BM=64,BN=128,BK=128):
    M=1  : 1.79×   M=4  : 1.64×   M=8  : 1.62×
    M=16 : 1.63×   M=32 : 1.57×   M=48+: ~1.0× (fallback already optimal)

  Shape N=7168,K=2048:
    M=1  : 1.60×   M=4  : 1.23×   M=8  : 1.17×   M=16 : 1.16×

Key finding: BLOCK_SIZE_M=16 (instead of 64) is optimal for M≤32 on H20
because it avoids ~75% CTA padding that otherwise kills Tensor Core utilisation.

Also adds N=2048,K=7168 and N=7168,K=2048 to the benchmark script's
get_weight_shapes() so these shapes are included in future re-tuning runs.

Related: PR vllm-project#40925 (same fix for GB10/SM120 — "Hopper reviewers welcome").
Made-with: Cursor
wyjBot added a commit to wyjBot/vllm that referenced this pull request Apr 29, 2026
…V4-Flash expert shapes

Adds Triton autotuned configs for NVIDIA H20-3e (SM90) for the two
primary expert GEMM shapes in DeepSeek-V4-Flash:

  N=2048, K=7168  (gate / up projection, per expert)
  N=7168, K=2048  (down projection, per expert)

Without these configs the fallback uses BLOCK_SIZE_M=64 regardless of
actual batch size.  For decode batches where M=1–32 this wastes ~60 % of
Tensor Core capacity through excessive CTA padding.

Autotune results on NVIDIA H20-3e (SM90, CUDA 9.0):

  Shape N=2048,K=7168 — speedup vs fallback (BM=64,BN=128,BK=128):
    M=1  : 1.79×   M=4  : 1.64×   M=8  : 1.62×
    M=16 : 1.63×   M=32 : 1.57×   M=48+: ~1.0× (fallback already optimal)

  Shape N=7168,K=2048:
    M=1  : 1.60×   M=4  : 1.23×   M=8  : 1.17×   M=16 : 1.16×

Key finding: BLOCK_SIZE_M=16 (instead of 64) is optimal for M≤32 on H20
because it avoids ~75% CTA padding that otherwise kills Tensor Core utilisation.

Also adds N=2048,K=7168 and N=7168,K=2048 to the benchmark script's
get_weight_shapes() so these shapes are included in future re-tuning runs.

Related: PR vllm-project#40925 (same fix for GB10/SM120 — "Hopper reviewers welcome").
Made-with: Cursor
Signed-off-by: wyjBot <fkeryj@outlook.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant