Skip to content

[DSv4] Improved fused Indexer Q quant kernel#41428

Merged
WoosukKwon merged 21 commits into
vllm-project:mainfrom
gau-nernst:dsv4_indexer_q_mxfp4
May 9, 2026
Merged

[DSv4] Improved fused Indexer Q quant kernel#41428
WoosukKwon merged 21 commits into
vllm-project:mainfrom
gau-nernst:dsv4_indexer_q_mxfp4

Conversation

@gau-nernst

@gau-nernst gau-nernst commented Apr 30, 2026

Copy link
Copy Markdown
Contributor

Purpose

Replace _fused_indexer_q_rope_mxfp4_kernel Triton kernel with a CuteDSL version to utilize 256-bit loads. Initially I wrote this in CUDA C++, but couldn't build vLLM from source, so asked Codex to port it over to CuteDSL. Hopefully this will be the first of many CuteDSL kernels to come in vLLM.

Update: I keep the original Triton implementation for fallback (potentially for ROCm). Also put CuteDSL kernel in a separate file and add import guards for platform doesn't have CuteDSL.

Microbenchmarks

Benchmark script
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch

from vllm.v1.attention.ops.deepseek_v4_ops.fused_indexer_q import (
    fused_indexer_q_rope_quant,
)
from flashinfer.testing import bench_gpu_time_with_cupti
import statistics

NUM_HEADS = 64
HEAD_DIM = 128
ROPE_DIM = 64
MAX_POS = 100_000
TOKENS = [1, 8, 32, 128, 256, 512, 1024, 2048, 4096, 8192, 16384]
ROPE_DTYPE = torch.float32


def make_inputs(num_tokens: int):
    positions = torch.randint(MAX_POS, (num_tokens,), dtype=torch.int64)
    query = torch.randn(num_tokens, NUM_HEADS, HEAD_DIM, dtype=torch.bfloat16)
    cos_sin_cache = torch.randn(MAX_POS, ROPE_DIM, dtype=ROPE_DTYPE)
    weights = torch.randn(num_tokens, NUM_HEADS, dtype=torch.bfloat16)
    return (
        positions,
        query,
        cos_sin_cache,
        weights,
        HEAD_DIM**-0.5,
        NUM_HEADS**-0.5,
        True,
    )


def benchmark(num_tokens: int):
    torch.set_default_device("cuda")

    kernel_args = make_inputs(num_tokens)
    timings = bench_gpu_time_with_cupti(lambda: fused_indexer_q_rope_quant(*kernel_args))
    median_ms = statistics.median(timings)

    bytes_per_token = 8  # position int64
    bytes_per_token += NUM_HEADS * HEAD_DIM * 2  # q in bf16
    bytes_per_token += ROPE_DIM * torch.empty((), dtype=ROPE_DTYPE).element_size()
    bytes_per_token += NUM_HEADS * 2  # weights in bf16
    bytes_per_token += NUM_HEADS * HEAD_DIM // 2  # q out fp4
    bytes_per_token += NUM_HEADS * HEAD_DIM // 32  # q_scale uint8
    bytes_per_token += NUM_HEADS * 4  # weights out fp32
    total_bytes = bytes_per_token * num_tokens

    return median_ms, total_bytes


if __name__ == "__main__":
    for num_tokens in TOKENS:
        median_ms, moved_bytes = benchmark(num_tokens)
        bandwidth_gb_s = moved_bytes / (median_ms * 1e-3) * 1e-9
        print(
            f"T={num_tokens:6d}  "
            f"{median_ms * 1e3:7.2f} us  "
            f"BW {bandwidth_gb_s:7.1f} GB/s  "
        )

Result on GB200

Couldn't quite get to SOL yet (8TB/s), but still should be a good improvement for now.

T (tokens) Before (us) After (us) Before BW (GB/s) After BW (GB/s) Speedup
1 3.71 2.69 5.7 8.0 1.38×
8 3.90 2.50 43.6 68.5 1.56×
32 4.35 3.01 156.3 227.5 1.44×
128 8.48 3.46 320.8 792.0 2.45×
256 13.41 4.00 405.8 1368.6 3.35×
512 22.37 5.66 486.5 1933.0 3.95×
1024 40.06 7.78 543.3 2816.0 5.15×
2048 75.71 12.83 575.0 3412.9 5.90×
4096 147.58 21.50 589.9 4073.1 6.86×
8192 290.43 36.06 599.6 4857.4 8.05×
16384 576.05 63.14 604.6 5549.2 9.12×

E2E benchmarks

DSv4-Flash, 4xGB200, 8k-1k, concurrency 256

Setting Throughput (tok/s) TTFT (P50) TPOT (P50)
Before c7aa186 66877 1.39s 32.62 ms
After 67910 1.35s 32.55ms

Test Plan

pytest tests/kernels/test_fused_indexer_q_rope_quant.py

Test Result

Existing tests pass. GSM8k 0.9484


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.

@mergify mergify Bot added the v1 label May 1, 2026

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request replaces the Triton-based fused indexer kernel with a new implementation using CUTLASS and CuTe DSLs to handle RoPE and MXFP4 quantization for DeepSeek-V4. Review feedback highlights a critical missing bounds check for the global subwarp ID, an invalid PTX vector size of 8 for 32-bit loads which will cause compilation errors, and a logical indexing error in the RoPE sine cache access.

Comment thread vllm/v1/attention/ops/deepseek_v4_ops/fused_indexer_q.py Outdated
Comment thread vllm/v1/attention/ops/deepseek_v4_ops/fused_indexer_q.py Outdated
Comment thread vllm/v1/attention/ops/deepseek_v4_ops/fused_indexer_q.py Outdated
@gau-nernst gau-nernst force-pushed the dsv4_indexer_q_mxfp4 branch 2 times, most recently from 6a841f3 to 8a8541e Compare May 1, 2026 00:32
@gau-nernst gau-nernst marked this pull request as ready for review May 1, 2026 03:59

@claude claude Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Claude Code Review

This pull request is from a fork — automated review is disabled. A repository maintainer can comment @claude review to run a one-time review.

num_index_q_heads,
_TORCH_TO_CUTE[index_q_cos_sin_cache.dtype],
)
scale = float(index_weights_softmax_scale * index_weights_head_scale)

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is one more kernel launch right? I inclined to do the compute inside the kernel

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

index_weights_softmax_scale and index_weights_head_scale are python floats, it will be computed in Python on CPU. also, since we take topk immediately after the logits, i don't even think scaling the weights is necessary 😆

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think scaling weight is for numeric stability. Just like we scale attention masks.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

image

Indexer is just (relu(Q @ K.T) * w).sum(). no softmax etc... so i don't think having a scalar multiplication will change any topk ordering.

@zyongye zyongye added the ready ONLY add when PR is ready to merge/full CI is needed label May 1, 2026
@gau-nernst gau-nernst force-pushed the dsv4_indexer_q_mxfp4 branch from b0713aa to 3ce9b08 Compare May 1, 2026 08:15
Comment on lines +3 to +11

from importlib.util import find_spec

import torch

from vllm.triton_utils import tl, triton

HAS_CUTEDSL = find_spec("cutlass") is not None

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you follow the pattern in vllm/utils/import_utils.py?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done. Can you take a look again? Thank you!

@gau-nernst gau-nernst force-pushed the dsv4_indexer_q_mxfp4 branch from becc675 to 17a4e49 Compare May 6, 2026 03:49
@gau-nernst

Copy link
Copy Markdown
Contributor Author

Pending #41603 investigation, since this PR introduces even a bigger change (completely new kernel)

gau-nernst added 12 commits May 9, 2026 01:56
Signed-off-by: Thien Tran <gau.nernst@yahoo.com.sg>
Signed-off-by: Thien Tran <gau.nernst@yahoo.com.sg>
Signed-off-by: Thien Tran <gau.nernst@yahoo.com.sg>
Signed-off-by: Thien Tran <gau.nernst@yahoo.com.sg>
Signed-off-by: Thien Tran <gau.nernst@yahoo.com.sg>
Signed-off-by: Thien Tran <gau.nernst@yahoo.com.sg>
Signed-off-by: Thien Tran <gau.nernst@yahoo.com.sg>
Signed-off-by: Thien Tran <gau.nernst@yahoo.com.sg>
Signed-off-by: Thien Tran <gau.nernst@yahoo.com.sg>
Signed-off-by: Thien Tran <gau.nernst@yahoo.com.sg>
Signed-off-by: Thien Tran <gau.nernst@yahoo.com.sg>
Signed-off-by: Thien Tran <gau.nernst@yahoo.com.sg>
gau-nernst added 9 commits May 9, 2026 01:56
Signed-off-by: Thien Tran <gau.nernst@yahoo.com.sg>
Signed-off-by: Thien Tran <gau.nernst@yahoo.com.sg>
Signed-off-by: Thien Tran <gau.nernst@yahoo.com.sg>
Signed-off-by: Thien Tran <gau.nernst@yahoo.com.sg>
Signed-off-by: Thien Tran <gau.nernst@yahoo.com.sg>
Signed-off-by: Thien Tran <gau.nernst@yahoo.com.sg>
Signed-off-by: Thien Tran <gau.nernst@yahoo.com.sg>
Signed-off-by: Thien Tran <gau.nernst@yahoo.com.sg>
Signed-off-by: Thien Tran <gau.nernst@yahoo.com.sg>
@gau-nernst gau-nernst force-pushed the dsv4_indexer_q_mxfp4 branch from d28274e to 9d469b8 Compare May 9, 2026 01:58

@WoosukKwon WoosukKwon left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Lgtm. Amazing!!

@WoosukKwon WoosukKwon merged commit 530d371 into vllm-project:main May 9, 2026
63 of 67 checks passed
@gau-nernst gau-nernst deleted the dsv4_indexer_q_mxfp4 branch May 9, 2026 08:22
weifang231 pushed a commit to weifang231/eb-vllm that referenced this pull request May 13, 2026
mfylcek pushed a commit to mfylcek/vllm that referenced this pull request May 19, 2026
jhu960213 pushed a commit to jhu960213/vllm that referenced this pull request May 20, 2026
mvanhorn pushed a commit to mvanhorn/vllm that referenced this pull request Jun 4, 2026
Signed-off-by: Matt Van Horn <455140+mvanhorn@users.noreply.github.com>
knight0528 pushed a commit to knight0528/vllm that referenced this pull request Jun 8, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ready ONLY add when PR is ready to merge/full CI is needed v1

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants