[MLAAttention] Clear Cudagraph padded region of FI decode Attention kernel#37815
varun-sundar-rabindranath wants to merge 3 commits into vllm-project:main from
Conversation
Code Review
This pull request addresses NaN issues in the flashinfer MLA kernel by explicitly zeroing out padded regions in the output. It also reverts changes from a previous PR that introduced a pre-allocated output buffer, which was causing out-of-memory errors.
The fix in flashinfer_mla.py to use o[attn_metadata.decode.seq_lens == 0] = 0 is clear and correct.
However, I've identified a potential issue in cutlass_mla.py where the removal of the pre-allocated buffer logic also changed the output buffer initialization from new_zeros to new_empty. This could re-introduce NaN issues for padded requests in the cutlass backend if the kernel doesn't write to those memory locations. My review includes a suggestion to restore the zero-initialization to prevent this.
Force-pushed from 5308d59 to 845bb31 (Compare)
        f"output shape {o.size()} != "
        f"seq_lens shape {attn_metadata.decode.seq_lens.size()}"
    )
    o[attn_metadata.decode.seq_lens == 0] = 0
nit: I think nan_to_num could be faster (a single kernel instead of two), or we should precompute attn_metadata.decode.seq_lens == 0
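For context, the two candidates produce identical results whenever NaNs appear only in the padded rows; a NumPy emulation of the torch operations (hypothetical shapes, not the actual vLLM tensors):

```python
import numpy as np

# Hypothetical decode output: 4 live requests padded to 6 slots
# (rows 4-5 are the CUDA-graph padded region).
num_tokens, num_heads, head_dim = 6, 2, 4
o = np.random.rand(num_tokens, num_heads, head_dim).astype(np.float32)
o[4:] = np.nan                           # kernel leaves NaN/garbage in padding
seq_lens = np.array([3, 1, 7, 2, 0, 0])  # seq_len == 0 marks a padded slot

# Option 1: masked assignment -- zeros only the padded rows.
o_mask = o.copy()
o_mask[seq_lens == 0] = 0

# Option 2: nan_to_num -- one pass over the whole tensor.
o_ntn = np.nan_to_num(o, nan=0.0)

# Identical here, since NaNs are confined to the padded region.
assert np.array_equal(o_mask, o_ntn)
```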
Quick claude benchmark:
┌───────────────────────────┬─────────────────┬──────────────────┬─────────┐
│ Shape │ nan_to_num (µs) │ mask assign (µs) │ Speedup │
├───────────────────────────┼─────────────────┼──────────────────┼─────────┤
│ (64, 128, 512, pad=8) │ 5.95 │ 10.64 │ 1.79x │
├───────────────────────────┼─────────────────┼──────────────────┼─────────┤
│ (128, 128, 512, pad=16) │ 7.77 │ 18.57 │ 2.39x │
├───────────────────────────┼─────────────────┼──────────────────┼─────────┤
│ (256, 128, 512, pad=32) │ 13.53 │ 34.75 │ 2.57x │
├───────────────────────────┼─────────────────┼──────────────────┼─────────┤
│ (512, 128, 512, pad=64) │ 35.26 │ 90.36 │ 2.56x │
├───────────────────────────┼─────────────────┼──────────────────┼─────────┤
│ (1024, 128, 512, pad=128) │ 66.94 │ 176.25 │ 2.63x │
├───────────────────────────┼─────────────────┼──────────────────┼─────────┤
│ (2048, 128, 512, pad=256) │ 166.12 │ 348.30 │ 2.10x │
└───────────────────────────┴─────────────────┴──────────────────┴─────────┘
Just worried that nan_to_num on the whole tensor might mask future bugs where the kernel itself produces nans. Let me see if there is a better way that masks only the padded region (precomputing the padding mask, like you mentioned).
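To illustrate the concern with a toy NumPy example (hypothetical shapes): if the kernel ever produces a NaN in a *valid* row, nan_to_num silently repairs it, while zeroing only the padded rows keeps the bug visible.

```python
import numpy as np

# Hypothetical output where the kernel has a real bug: a NaN in a valid row.
o = np.ones((4, 2), dtype=np.float32)
seq_lens = np.array([5, 3, 0, 0])   # rows 2-3 are padding
o[2:] = np.nan                      # expected garbage in the padded region
o[0, 1] = np.nan                    # unexpected kernel bug in a live request

# nan_to_num over the whole tensor silently masks the buggy value...
assert not np.isnan(np.nan_to_num(o, nan=0.0)).any()

# ...while zeroing only the padded rows leaves the bug detectable.
masked = o.copy()
masked[seq_lens == 0] = 0
assert np.isnan(masked[0, 1])
```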
Focusing on padding <= 16 (which is how the cudagraph sizes are computed), I have the following numbers:
┌───────────────────────────┬─────────────────┬──────────────────┬───────────────────┐
│ Shape │ nan_to_num (µs) │ mask assign (µs) │ pre-mask assign (µs) │
├───────────────────────────┼─────────────────┼──────────────────┼───────────────────┤
│ (1, 128, 512, pad=0) │ 9.82 │ 36.07 │ 16.60 │
├───────────────────────────┼─────────────────┼──────────────────┼───────────────────┤
│ (2, 128, 512, pad=1) │ 12.25 │ 38.10 │ 14.92 │
├───────────────────────────┼─────────────────┼──────────────────┼───────────────────┤
│ (4, 128, 512, pad=1) │ 13.31 │ 37.40 │ 15.16 │
├───────────────────────────┼─────────────────┼──────────────────┼───────────────────┤
│ (8, 128, 512, pad=2) │ 12.93 │ 37.42 │ 16.47 │
├───────────────────────────┼─────────────────┼──────────────────┼───────────────────┤
│ (12, 128, 512, pad=4) │ 12.63 │ 38.09 │ 16.44 │
├───────────────────────────┼─────────────────┼──────────────────┼───────────────────┤
│ (16, 128, 512, pad=8) │ 11.54 │ 35.09 │ 14.38 │
├───────────────────────────┼─────────────────┼──────────────────┼───────────────────┤
│ (24, 128, 512, pad=12) │ 12.12 │ 36.36 │ 15.24 │
├───────────────────────────┼─────────────────┼──────────────────┼───────────────────┤
│ (32, 128, 512, pad=16) │ 11.75 │ 34.67 │ 16.12 │
├───────────────────────────┼─────────────────┼──────────────────┼───────────────────┤
│ (64, 128, 512, pad=16) │ 12.16 │ 35.35 │ 15.55 │
├───────────────────────────┼─────────────────┼──────────────────┼───────────────────┤
│ (128, 128, 512, pad=16) │ 9.78 │ 34.79 │ 16.19 │
├───────────────────────────┼─────────────────┼──────────────────┼───────────────────┤
│ (192, 128, 512, pad=16) │ 10.28 │ 29.60 │ 22.55 │
├───────────────────────────┼─────────────────┼──────────────────┼───────────────────┤
│ (256, 128, 512, pad=16) │ 12.22 │ 32.64 │ 28.67 │
├───────────────────────────┼─────────────────┼──────────────────┼───────────────────┤
│ (320, 128, 512, pad=16) │ 14.29 │ 39.11 │ 34.82 │
├───────────────────────────┼─────────────────┼──────────────────┼───────────────────┤
│ (384, 128, 512, pad=16) │ 16.48 │ 46.70 │ 41.01 │
├───────────────────────────┼─────────────────┼──────────────────┼───────────────────┤
│ (512, 128, 512, pad=16) │ 20.46 │ 59.32 │ 55.25 │
├───────────────────────────┼─────────────────┼──────────────────┼───────────────────┤
│ (640, 128, 512, pad=16) │ 26.36 │ 75.84 │ 71.70 │
├───────────────────────────┼─────────────────┼──────────────────┼───────────────────┤
│ (768, 128, 512, pad=16) │ 33.09 │ 106.07 │ 102.25 │
├───────────────────────────┼─────────────────┼──────────────────┼───────────────────┤
│ (896, 128, 512, pad=16) │ 41.23 │ 131.10 │ 127.05 │
├───────────────────────────┼─────────────────┼──────────────────┼───────────────────┤
│ (1024, 128, 512, pad=16) │ 47.12 │ 149.54 │ 145.51 │
└───────────────────────────┴─────────────────┴──────────────────┴───────────────────┘
Precomputing the mask helps -- this is the current implementation. However, nan_to_num is the clear winner.
I am not using nan_to_num for the reason mentioned above (in #37815 (comment)).
What do you think? Thanks.
I couldn't find a better way to get the desired effect with just torch operations, so I added a zero-out kernel instead.
The following are the numbers from running the kernels inside a cudagraph (so as not to include the Triton launch overhead):
┌───────────────────────────┬─────────────────┬──────────────────┬───────────────────┬───────────────────────┐
│ Shape │ nan_to_num (µs) │ mask assign (µs) │ pre-mask assign (µs) │ triton zero rows (µs) │
├───────────────────────────┼─────────────────┼──────────────────┼───────────────────┼───────────────────────┤
│ (1, 128, 512, pad=0) │ 2.03 │ 3.82 │ 2.85 │ 1.67 │
├───────────────────────────┼─────────────────┼──────────────────┼───────────────────┼───────────────────────┤
│ (2, 128, 512, pad=1) │ 2.81 │ 5.04 │ 3.66 │ 4.40 │
├───────────────────────────┼─────────────────┼──────────────────┼───────────────────┼───────────────────────┤
│ (4, 128, 512, pad=1) │ 2.96 │ 5.04 │ 3.54 │ 4.23 │
├───────────────────────────┼─────────────────┼──────────────────┼───────────────────┼───────────────────────┤
│ (8, 128, 512, pad=2) │ 2.84 │ 5.01 │ 3.69 │ 4.27 │
├───────────────────────────┼─────────────────┼──────────────────┼───────────────────┼───────────────────────┤
│ (12, 128, 512, pad=4) │ 2.85 │ 5.16 │ 3.65 │ 4.38 │
├───────────────────────────┼─────────────────┼──────────────────┼───────────────────┼───────────────────────┤
│ (16, 128, 512, pad=8) │ 3.01 │ 5.47 │ 4.02 │ 4.66 │
├───────────────────────────┼─────────────────┼──────────────────┼───────────────────┼───────────────────────┤
│ (24, 128, 512, pad=12) │ 3.07 │ 6.28 │ 5.14 │ 4.62 │
├───────────────────────────┼─────────────────┼──────────────────┼───────────────────┼───────────────────────┤
│ (32, 128, 512, pad=16) │ 3.38 │ 7.06 │ 5.99 │ 4.50 │
├───────────────────────────┼─────────────────┼──────────────────┼───────────────────┼───────────────────────┤
│ (64, 128, 512, pad=16) │ 5.08 │ 10.55 │ 9.44 │ 4.64 │
├───────────────────────────┼─────────────────┼──────────────────┼───────────────────┼───────────────────────┤
│ (128, 128, 512, pad=16) │ 8.02 │ 21.09 │ 19.78 │ 4.50 │
├───────────────────────────┼─────────────────┼──────────────────┼───────────────────┼───────────────────────┤
│ (192, 128, 512, pad=16) │ 10.66 │ 30.18 │ 28.75 │ 4.44 │
├───────────────────────────┼─────────────────┼──────────────────┼───────────────────┼───────────────────────┤
│ (256, 128, 512, pad=16) │ 13.41 │ 39.01 │ 37.71 │ 4.31 │
├───────────────────────────┼─────────────────┼──────────────────┼───────────────────┼───────────────────────┤
│ (320, 128, 512, pad=16) │ 16.04 │ 47.97 │ 46.62 │ 4.43 │
├───────────────────────────┼─────────────────┼──────────────────┼───────────────────┼───────────────────────┤
│ (384, 128, 512, pad=16) │ 18.57 │ 56.96 │ 55.38 │ 4.46 │
├───────────────────────────┼─────────────────┼──────────────────┼───────────────────┼───────────────────────┤
│ (512, 128, 512, pad=16) │ 23.82 │ 74.72 │ 73.24 │ 4.19 │
├───────────────────────────┼─────────────────┼──────────────────┼───────────────────┼───────────────────────┤
│ (640, 128, 512, pad=16) │ 29.21 │ 92.60 │ 91.06 │ 4.34 │
├───────────────────────────┼─────────────────┼──────────────────┼───────────────────┼───────────────────────┤
│ (768, 128, 512, pad=16) │ 34.37 │ 110.67 │ 108.88 │ 4.45 │
├───────────────────────────┼─────────────────┼──────────────────┼───────────────────┼───────────────────────┤
│ (896, 128, 512, pad=16) │ 39.54 │ 128.35 │ 126.90 │ 4.45 │
├───────────────────────────┼─────────────────┼──────────────────┼───────────────────┼───────────────────────┤
│ (1024, 128, 512, pad=16) │ 45.05 │ 145.93 │ 144.68 │ 4.47 │
└───────────────────────────┴─────────────────┴──────────────────┴───────────────────┴───────────────────────┘
Force-pushed from 4be71db to 39bc2d8 (Compare)
Hi @varun-sundar-rabindranath, the pre-commit checks have failed. Please run:
uv pip install pre-commit>=4.5.1
pre-commit install
pre-commit run --all-files
Then, commit the changes and push to your branch.
Force-pushed from 53eb14b to fb30501 (Compare)
        tl.zeros([BLOCK_SIZE], dtype=out_ptr.dtype.element_ty),
        mask=mask,
    )
    out_ptrs += BLOCK_SIZE
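For reviewers, here is a minimal NumPy emulation of what such a row-zeroing kernel does: one "program" per padded row, zeroing the row in BLOCK_SIZE chunks with a mask guarding the final partial block (the names and shapes are illustrative, not the actual vLLM kernel):

```python
import numpy as np

BLOCK_SIZE = 8

def zero_rows(out: np.ndarray, seq_lens: np.ndarray) -> None:
    """Emulate the Triton kernel: for each row whose seq_len is 0,
    zero the row in BLOCK_SIZE chunks, masking the tail."""
    row_elems = out[0].size
    for row in np.flatnonzero(seq_lens == 0):
        flat = out[row].reshape(-1)          # view into the row
        for start in range(0, row_elems, BLOCK_SIZE):
            offs = start + np.arange(BLOCK_SIZE)
            mask = offs < row_elems          # guard the final partial block
            flat[offs[mask]] = 0.0

# Toy check: rows 1-2 are padding and get zeroed; row 0 is untouched.
o = np.full((3, 2, 5), np.nan, dtype=np.float32)
o[0] = 1.0
seq_lens = np.array([4, 0, 0])
zero_rows(o, seq_lens)
assert (o[1:] == 0).all() and (o[0] == 1.0).all()
```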
@LucasWilkinson @elvircrn the kernel can use a fresh pair of eyes. Thanks 🙌
Force-pushed from 8cee1cd to 96b2d37 (Compare)
I see a crash after rebase - testing it now.
After some multi-node testing locally - I see EPLB hanging and the server crashing both on this PR and on main.
…N from CUDA graph padding (vllm-project#37442)" This reverts commit ef2c4f7.
Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Force-pushed from 3c02656 to ff4ed50 (Compare)
Purpose
We see the flashinfer MLA kernel introducing NaNs in the padded cudagraph region. This PR zeros out the output of the trtllm_batch_decode_with_kv_cache_mla kernel.
Additionally, it reverts PR #37442, as we observe some OOMs.
Fixes: #37777
Flashinfer issue: flashinfer-ai/flashinfer#2883
Test Plan
Run nvidia/DeepSeek-R1-0528-NVFP4-v2 in a wide-ep setup and run lm_evals.
Test Result
I see that lm_evals pass locally.