[DSv4] Improved fused Indexer Q quant kernel#41428
Conversation
There was a problem hiding this comment.
Code Review
This pull request replaces the Triton-based fused indexer kernel with a new implementation using CUTLASS and CuTe DSLs to handle RoPE and MXFP4 quantization for DeepSeek-V4. Review feedback highlights a critical missing bounds check for the global subwarp ID, an invalid PTX vector size of 8 for 32-bit loads which will cause compilation errors, and a logical indexing error in the RoPE sine cache access.
6a841f3 to
8a8541e
Compare
| num_index_q_heads, | ||
| _TORCH_TO_CUTE[index_q_cos_sin_cache.dtype], | ||
| ) | ||
| scale = float(index_weights_softmax_scale * index_weights_head_scale) |
There was a problem hiding this comment.
This is one more kernel launch right? I inclined to do the compute inside the kernel
There was a problem hiding this comment.
index_weights_softmax_scale and index_weights_head_scale are python floats, it will be computed in Python on CPU. also, since we take topk immediately after the logits, i don't even think scaling the weights is necessary 😆
There was a problem hiding this comment.
I think scaling weight is for numeric stability. Just like we scale attention masks.
b0713aa to
3ce9b08
Compare
|
|
||
| from importlib.util import find_spec | ||
|
|
||
| import torch | ||
|
|
||
| from vllm.triton_utils import tl, triton | ||
|
|
||
| HAS_CUTEDSL = find_spec("cutlass") is not None | ||
|
|
There was a problem hiding this comment.
can you follow the pattern in vllm/utils/import_utils.py?
There was a problem hiding this comment.
Done. Can you take a look again? Thank you!
becc675 to
17a4e49
Compare
|
Pending #41603 investigation, since this PR introduces even a bigger change (completely new kernel) |
a86b6b2 to
2b1d126
Compare
Signed-off-by: Thien Tran <gau.nernst@yahoo.com.sg>
Signed-off-by: Thien Tran <gau.nernst@yahoo.com.sg>
Signed-off-by: Thien Tran <gau.nernst@yahoo.com.sg>
Signed-off-by: Thien Tran <gau.nernst@yahoo.com.sg>
Signed-off-by: Thien Tran <gau.nernst@yahoo.com.sg>
Signed-off-by: Thien Tran <gau.nernst@yahoo.com.sg>
Signed-off-by: Thien Tran <gau.nernst@yahoo.com.sg>
Signed-off-by: Thien Tran <gau.nernst@yahoo.com.sg>
Signed-off-by: Thien Tran <gau.nernst@yahoo.com.sg>
Signed-off-by: Thien Tran <gau.nernst@yahoo.com.sg>
Signed-off-by: Thien Tran <gau.nernst@yahoo.com.sg>
Signed-off-by: Thien Tran <gau.nernst@yahoo.com.sg>
Signed-off-by: Thien Tran <gau.nernst@yahoo.com.sg>
Signed-off-by: Thien Tran <gau.nernst@yahoo.com.sg>
Signed-off-by: Thien Tran <gau.nernst@yahoo.com.sg>
Signed-off-by: Thien Tran <gau.nernst@yahoo.com.sg>
Signed-off-by: Thien Tran <gau.nernst@yahoo.com.sg>
Signed-off-by: Thien Tran <gau.nernst@yahoo.com.sg>
d28274e to
9d469b8
Compare
Signed-off-by: Matt Van Horn <455140+mvanhorn@users.noreply.github.com>

Purpose
Replace
_fused_indexer_q_rope_mxfp4_kernelTriton kernel with a CuteDSL version to utilize 256-bit loads. Initially I wrote this in CUDA C++, but couldn't build vLLM from source, so asked Codex to port it over to CuteDSL. Hopefully this will be the first of many CuteDSL kernels to come in vLLM.Update: I keep the original Triton implementation for fallback (potentially for ROCm). Also put CuteDSL kernel in a separate file and add import guards for platform doesn't have CuteDSL.
Microbenchmarks
Benchmark script
Result on GB200
Couldn't quite get to SOL yet (8TB/s), but still should be a good improvement for now.
E2E benchmarks
DSv4-Flash, 4xGB200, 8k-1k, concurrency 256
Test Plan
Test Result
Existing tests pass. GSM8k 0.9484
Essential Elements of an Effective PR Description Checklist
supported_models.mdandexamplesfor a new model.