[MLAAttention] Clear Cudagraph padded region of FI decode Attention kernel #37815

Open
varun-sundar-rabindranath wants to merge 3 commits into vllm-project:main from neuralmagic:varun/zero-out-padding

Conversation

varun-sundar-rabindranath (Contributor) commented Mar 22, 2026

Purpose

We see the flashinfer MLA kernel introducing NaNs in the padded cudagraph region. This PR zeros out the padded region of the output of the trtllm_batch_decode_with_kv_cache_mla kernel.

Additionally, this PR reverts PR #37442, as we observed some OOMs.
Fixes: #37777

Flashinfer issue : flashinfer-ai/flashinfer#2883
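
For illustration, the shape of the fix can be sketched with numpy stand-ins (this is not the vLLM code itself; the real implementation operates on torch tensors and `attn_metadata.decode.seq_lens`):

```python
import numpy as np

# Sketch of the fix: the attention output `o` has one row per decode slot;
# cudagraph padding slots have seq_len == 0. The kernel may leave garbage
# (even NaNs) in those rows, so we overwrite them with zeros.
def zero_padded_rows(o: np.ndarray, seq_lens: np.ndarray) -> np.ndarray:
    assert o.shape[0] == seq_lens.shape[0], "one seq_len per output row"
    o[seq_lens == 0] = 0.0  # same idea as o[attn_metadata.decode.seq_lens == 0] = 0
    return o

# Padded slots (seq_len == 0) get zeroed even if the kernel wrote NaN there.
o = np.full((4, 3), np.nan, dtype=np.float32)
o[:2] = 1.0                          # rows 0-1 are real requests
seq_lens = np.array([5, 7, 0, 0])    # rows 2-3 are cudagraph padding
zero_padded_rows(o, seq_lens)
```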

Test Plan

Run nvidia/DeepSeek-R1-0528-NVFP4-v2 in a wide-ep setup and run lm_evals.

Test Result

lm_evals pass locally.

gemini-code-assist bot (Contributor) left a comment:

Code Review

This pull request addresses NaN issues in the flashinfer MLA kernel by explicitly zeroing out padded regions in the output. It also reverts changes from a previous PR that introduced a pre-allocated output buffer, which was causing out-of-memory errors.

The fix in flashinfer_mla.py to use o[attn_metadata.decode.seq_lens == 0] = 0 is clear and correct.

However, I've identified a potential issue in cutlass_mla.py where the removal of the pre-allocated buffer logic also changed the output buffer initialization from new_zeros to new_empty. This could re-introduce NaN issues for padded requests in the cutlass backend if the kernel doesn't write to those memory locations. My review includes a suggestion to restore the zero-initialization to prevent this.
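
The distinction the review flags can be illustrated with numpy analogues of `new_zeros` vs `new_empty` (a sketch, not the actual cutlass_mla.py code):

```python
import numpy as np

# np.zeros guarantees every element is 0; np.empty only reserves memory and
# leaves whatever bytes were already there -- analogous to torch's
# new_zeros vs new_empty. If the attention kernel never writes the padded
# rows of an "empty" buffer, those rows hold arbitrary (possibly NaN) data.
safe = np.zeros((4, 8), dtype=np.float32)    # padded rows guaranteed 0
unsafe = np.empty((4, 8), dtype=np.float32)  # padded rows uninitialized

# Simulate the kernel writing only the first two (real) rows of each buffer.
safe[:2] = 1.0
unsafe[:2] = 1.0
# safe[2:] is still all zeros; unsafe[2:] is undefined.
```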

f"output shape {o.size()} != "
f"seq_lens shape {attn_metadata.decode.seq_lens.size()}"
)
o[attn_metadata.decode.seq_lens == 0] = 0
Collaborator:

nit: I think nan_to_num could be faster (single kernel vs 2) or we should precompute attn_metadata.decode.seq_lens == 0
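
The semantic difference between the two candidates can be sketched with numpy stand-ins (torch's `nan_to_num` and boolean-mask assignment behave analogously):

```python
import numpy as np

# Two candidate cleanups for the padded region:
#  1) nan_to_num: one pass over the tensor, replaces every NaN with 0
#  2) mask assign: builds a boolean mask (seq_lens == 0), then zeroes rows
o = np.array([[1.0, np.nan],                         # real row; NaN here is a genuine kernel bug
              [np.nan, np.nan]], dtype=np.float32)   # padded row
seq_lens = np.array([5, 0])

via_nan_to_num = np.nan_to_num(o)   # zeroes NaNs everywhere
via_mask = o.copy()
via_mask[seq_lens == 0] = 0.0       # touches only padded rows

# nan_to_num also hides the bug in the real row; mask assign does not.
```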

Collaborator:

Quick claude benchmark:

  ┌───────────────────────────┬─────────────────┬──────────────────┬─────────┐
  │           Shape           │ nan_to_num (µs) │ mask assign (µs) │ Speedup │
  ├───────────────────────────┼─────────────────┼──────────────────┼─────────┤
  │ (64, 128, 512, pad=8)     │            5.95 │            10.64 │   1.79x │
  ├───────────────────────────┼─────────────────┼──────────────────┼─────────┤
  │ (128, 128, 512, pad=16)   │            7.77 │            18.57 │   2.39x │
  ├───────────────────────────┼─────────────────┼──────────────────┼─────────┤
  │ (256, 128, 512, pad=32)   │           13.53 │            34.75 │   2.57x │
  ├───────────────────────────┼─────────────────┼──────────────────┼─────────┤
  │ (512, 128, 512, pad=64)   │           35.26 │            90.36 │   2.56x │
  ├───────────────────────────┼─────────────────┼──────────────────┼─────────┤
  │ (1024, 128, 512, pad=128) │           66.94 │           176.25 │   2.63x │
  ├───────────────────────────┼─────────────────┼──────────────────┼─────────┤
  │ (2048, 128, 512, pad=256) │          166.12 │           348.30 │   2.10x │
  └───────────────────────────┴─────────────────┴──────────────────┴─────────┘

Contributor Author:

Just worried that nan_to_num on the whole tensor might mask future bugs where the kernel itself produces NaNs. Let me see if there is a better way that masks only the padded region (precomputing the padding mask, as you mentioned).
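
The "precompute the padding mask" idea can be sketched as follows (a hypothetical numpy sketch; `DecodeMetadata` and `clear_padded` are illustrative names, not vLLM APIs):

```python
import numpy as np

# Sketch: the mask (seq_lens == 0) is built once when the decode metadata is
# constructed, then reused by every attention layer, instead of recomputing
# the comparison on each call.
class DecodeMetadata:
    def __init__(self, seq_lens: np.ndarray):
        self.seq_lens = seq_lens
        self.pad_mask = seq_lens == 0  # computed once, reused per layer

def clear_padded(o: np.ndarray, meta: DecodeMetadata) -> np.ndarray:
    o[meta.pad_mask] = 0.0  # no per-call `seq_lens == 0` comparison
    return o

meta = DecodeMetadata(np.array([3, 0, 7, 0]))
o = np.ones((4, 2), dtype=np.float32)
clear_padded(o, meta)
```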

Contributor:

Thanks, good info.

Contributor Author:

Focusing on padding <= 16 (which is how the cudagraph sizes are computed), I have the following numbers:

┌───────────────────────────┬─────────────────┬──────────────────┬──────────────────────┐
│           Shape           │ nan_to_num (µs) │ mask assign (µs) │ pre-mask assign (µs) │
├───────────────────────────┼─────────────────┼──────────────────┼──────────────────────┤
│ (1, 128, 512, pad=0)      │            9.82 │            36.07 │                16.60 │
├───────────────────────────┼─────────────────┼──────────────────┼──────────────────────┤
│ (2, 128, 512, pad=1)      │           12.25 │            38.10 │                14.92 │
├───────────────────────────┼─────────────────┼──────────────────┼──────────────────────┤
│ (4, 128, 512, pad=1)      │           13.31 │            37.40 │                15.16 │
├───────────────────────────┼─────────────────┼──────────────────┼──────────────────────┤
│ (8, 128, 512, pad=2)      │           12.93 │            37.42 │                16.47 │
├───────────────────────────┼─────────────────┼──────────────────┼──────────────────────┤
│ (12, 128, 512, pad=4)     │           12.63 │            38.09 │                16.44 │
├───────────────────────────┼─────────────────┼──────────────────┼──────────────────────┤
│ (16, 128, 512, pad=8)     │           11.54 │            35.09 │                14.38 │
├───────────────────────────┼─────────────────┼──────────────────┼──────────────────────┤
│ (24, 128, 512, pad=12)    │           12.12 │            36.36 │                15.24 │
├───────────────────────────┼─────────────────┼──────────────────┼──────────────────────┤
│ (32, 128, 512, pad=16)    │           11.75 │            34.67 │                16.12 │
├───────────────────────────┼─────────────────┼──────────────────┼──────────────────────┤
│ (64, 128, 512, pad=16)    │           12.16 │            35.35 │                15.55 │
├───────────────────────────┼─────────────────┼──────────────────┼──────────────────────┤
│ (128, 128, 512, pad=16)   │            9.78 │            34.79 │                16.19 │
├───────────────────────────┼─────────────────┼──────────────────┼──────────────────────┤
│ (192, 128, 512, pad=16)   │           10.28 │            29.60 │                22.55 │
├───────────────────────────┼─────────────────┼──────────────────┼──────────────────────┤
│ (256, 128, 512, pad=16)   │           12.22 │            32.64 │                28.67 │
├───────────────────────────┼─────────────────┼──────────────────┼──────────────────────┤
│ (320, 128, 512, pad=16)   │           14.29 │            39.11 │                34.82 │
├───────────────────────────┼─────────────────┼──────────────────┼──────────────────────┤
│ (384, 128, 512, pad=16)   │           16.48 │            46.70 │                41.01 │
├───────────────────────────┼─────────────────┼──────────────────┼──────────────────────┤
│ (512, 128, 512, pad=16)   │           20.46 │            59.32 │                55.25 │
├───────────────────────────┼─────────────────┼──────────────────┼──────────────────────┤
│ (640, 128, 512, pad=16)   │           26.36 │            75.84 │                71.70 │
├───────────────────────────┼─────────────────┼──────────────────┼──────────────────────┤
│ (768, 128, 512, pad=16)   │           33.09 │           106.07 │               102.25 │
├───────────────────────────┼─────────────────┼──────────────────┼──────────────────────┤
│ (896, 128, 512, pad=16)   │           41.23 │           131.10 │               127.05 │
├───────────────────────────┼─────────────────┼──────────────────┼──────────────────────┤
│ (1024, 128, 512, pad=16)  │           47.12 │           149.54 │               145.51 │
└───────────────────────────┴─────────────────┴──────────────────┴──────────────────────┘

Precomputing the mask helps -- this is the current implementation. However, nan_to_num is the clear winner.

I am not using nan_to_num for the reason mentioned above (in #37815 (comment)).
What do you think? Thanks.

Contributor Author:

I couldn't find a better way to get the desired effect using only torch operations, so I added a zero-out kernel instead.
The following are the numbers from running the kernels inside a cudagraph (so as to not include the triton launch overhead):

┌───────────────────────────┬─────────────────┬──────────────────┬──────────────────────┬───────────────────────┐
│           Shape           │ nan_to_num (µs) │ mask assign (µs) │ pre-mask assign (µs) │ triton zero rows (µs) │
├───────────────────────────┼─────────────────┼──────────────────┼──────────────────────┼───────────────────────┤
│ (1, 128, 512, pad=0)      │            2.03 │             3.82 │                 2.85 │                  1.67 │
├───────────────────────────┼─────────────────┼──────────────────┼──────────────────────┼───────────────────────┤
│ (2, 128, 512, pad=1)      │            2.81 │             5.04 │                 3.66 │                  4.40 │
├───────────────────────────┼─────────────────┼──────────────────┼──────────────────────┼───────────────────────┤
│ (4, 128, 512, pad=1)      │            2.96 │             5.04 │                 3.54 │                  4.23 │
├───────────────────────────┼─────────────────┼──────────────────┼──────────────────────┼───────────────────────┤
│ (8, 128, 512, pad=2)      │            2.84 │             5.01 │                 3.69 │                  4.27 │
├───────────────────────────┼─────────────────┼──────────────────┼──────────────────────┼───────────────────────┤
│ (12, 128, 512, pad=4)     │            2.85 │             5.16 │                 3.65 │                  4.38 │
├───────────────────────────┼─────────────────┼──────────────────┼──────────────────────┼───────────────────────┤
│ (16, 128, 512, pad=8)     │            3.01 │             5.47 │                 4.02 │                  4.66 │
├───────────────────────────┼─────────────────┼──────────────────┼──────────────────────┼───────────────────────┤
│ (24, 128, 512, pad=12)    │            3.07 │             6.28 │                 5.14 │                  4.62 │
├───────────────────────────┼─────────────────┼──────────────────┼──────────────────────┼───────────────────────┤
│ (32, 128, 512, pad=16)    │            3.38 │             7.06 │                 5.99 │                  4.50 │
├───────────────────────────┼─────────────────┼──────────────────┼──────────────────────┼───────────────────────┤
│ (64, 128, 512, pad=16)    │            5.08 │            10.55 │                 9.44 │                  4.64 │
├───────────────────────────┼─────────────────┼──────────────────┼──────────────────────┼───────────────────────┤
│ (128, 128, 512, pad=16)   │            8.02 │            21.09 │                19.78 │                  4.50 │
├───────────────────────────┼─────────────────┼──────────────────┼──────────────────────┼───────────────────────┤
│ (192, 128, 512, pad=16)   │           10.66 │            30.18 │                28.75 │                  4.44 │
├───────────────────────────┼─────────────────┼──────────────────┼──────────────────────┼───────────────────────┤
│ (256, 128, 512, pad=16)   │           13.41 │            39.01 │                37.71 │                  4.31 │
├───────────────────────────┼─────────────────┼──────────────────┼──────────────────────┼───────────────────────┤
│ (320, 128, 512, pad=16)   │           16.04 │            47.97 │                46.62 │                  4.43 │
├───────────────────────────┼─────────────────┼──────────────────┼──────────────────────┼───────────────────────┤
│ (384, 128, 512, pad=16)   │           18.57 │            56.96 │                55.38 │                  4.46 │
├───────────────────────────┼─────────────────┼──────────────────┼──────────────────────┼───────────────────────┤
│ (512, 128, 512, pad=16)   │           23.82 │            74.72 │                73.24 │                  4.19 │
├───────────────────────────┼─────────────────┼──────────────────┼──────────────────────┼───────────────────────┤
│ (640, 128, 512, pad=16)   │           29.21 │            92.60 │                91.06 │                  4.34 │
├───────────────────────────┼─────────────────┼──────────────────┼──────────────────────┼───────────────────────┤
│ (768, 128, 512, pad=16)   │           34.37 │           110.67 │               108.88 │                  4.45 │
├───────────────────────────┼─────────────────┼──────────────────┼──────────────────────┼───────────────────────┤
│ (896, 128, 512, pad=16)   │           39.54 │           128.35 │               126.90 │                  4.45 │
├───────────────────────────┼─────────────────┼──────────────────┼──────────────────────┼───────────────────────┤
│ (1024, 128, 512, pad=16)  │           45.05 │           145.93 │               144.68 │                  4.47 │
└───────────────────────────┴─────────────────┴──────────────────┴──────────────────────┴───────────────────────┘
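
For reference, the semantics of a row-zeroing kernel like the triton one can be modeled in numpy (an assumed sketch; the real kernel runs in-place on GPU, storing `tl.zeros([BLOCK_SIZE])` blocks with a tail mask, roughly one program per padded row):

```python
import numpy as np

# Reference semantics for a "zero rows" kernel: given the flattened output
# (num_tokens, num_heads * head_dim) and per-token seq_lens, each padded
# token's whole row is overwritten with zeros, block by block.
def zero_rows_reference(out: np.ndarray, seq_lens: np.ndarray,
                        block_size: int = 128) -> np.ndarray:
    num_tokens, row_elems = out.shape
    for tok in np.flatnonzero(seq_lens == 0):      # one "program" per padded row
        for start in range(0, row_elems, block_size):
            end = min(start + block_size, row_elems)  # mask out the tail block
            out[tok, start:end] = 0.0
    return out

out = np.random.default_rng(0).normal(size=(6, 300)).astype(np.float32)
out[4:] = np.nan                        # padded rows poisoned with NaN
seq_lens = np.array([2, 9, 4, 1, 0, 0])
zero_rows_reference(out, seq_lens)
```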

@varun-sundar-rabindranath varun-sundar-rabindranath marked this pull request as ready for review March 24, 2026 00:45

mergify bot commented Mar 24, 2026

Hi @varun-sundar-rabindranath, the pre-commit checks have failed. Please run:

uv pip install pre-commit>=4.5.1
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy failing?
mypy is run differently in CI. If the failure is related to this check, please use the following command to run it locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10

@varun-sundar-rabindranath varun-sundar-rabindranath changed the title Varun/zero out padding [MLAAttention] Clear Cudagraph padded region of FI decode Attention kernel Mar 24, 2026

tl.zeros([BLOCK_SIZE], dtype=out_ptr.dtype.element_ty),
mask=mask,
)
out_ptrs += BLOCK_SIZE
Contributor Author:

@LucasWilkinson @elvircrn the kernel can use a fresh pair of eyes. Thanks 🙌

Contributor:

LGTM!

varun-sundar-rabindranath (Contributor Author):

I see a crash after rebase - Testing it now.

@varun-sundar-rabindranath varun-sundar-rabindranath marked this pull request as draft March 25, 2026 16:31
@varun-sundar-rabindranath varun-sundar-rabindranath marked this pull request as ready for review March 25, 2026 21:08
varun-sundar-rabindranath (Contributor Author):

I see a crash after rebase - Testing it now.

After some multi-node testing locally, I see EPLB hanging and the server crashing both on this PR and on main.
The PR worked with EPLB before the rebase, so I believe this should be good to land.

MatthewBonanni (Collaborator) left a comment:

LGTM

@github-project-automation github-project-automation bot moved this to Ready in NVIDIA Mar 25, 2026
@MatthewBonanni MatthewBonanni added the ready ONLY add when PR is ready to merge/full CI is needed label Mar 25, 2026
Varun Sundar Rabindranath added 2 commits March 27, 2026 14:06
…N from CUDA graph padding (vllm-project#37442)"

This reverts commit ef2c4f7.

Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Signed-off-by: varun <vrh>
Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>

Labels

nvidia · ready (ONLY add when PR is ready to merge/full CI is needed) · v1

Projects

Status: Ready

Development

Successfully merging this pull request may close these issues.

[Bug]: [OOM] DeepSeek-R1 Out of Memory

4 participants