-
-
Notifications
You must be signed in to change notification settings - Fork 15.1k
[MLAAttention] Clear Cudagraph padded region of FI decode Attention kernel #37815
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Closed
varun-sundar-rabindranath
wants to merge
3
commits into
vllm-project:main
from
neuralmagic:varun/zero-out-padding
+140
−58
Closed
Changes from all commits
Commits
Show all changes
3 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
60 changes: 60 additions & 0 deletions
60
tests/kernels/attention/test_mla_zero_out_decode_padding.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,60 @@ | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # SPDX-FileCopyrightText: Copyright contributors to the vLLM project | ||
|
|
||
|
|
||
| import pytest | ||
| import torch | ||
|
|
||
| from vllm.v1.attention.backends.mla.utils import zero_out_decode_padding | ||
|
|
||
|
|
||
| def _assert_zero_out_matches_ref( | ||
| *, | ||
| num_tokens: int, | ||
| num_cols: int, | ||
| pad_positions: tuple[int, ...], | ||
| dtype: torch.dtype = torch.bfloat16, | ||
| num_heads: int = 3, | ||
| ) -> None: | ||
| out = torch.randn(num_tokens, 3, num_cols, dtype=dtype, device="cuda") | ||
| seq_lens = torch.ones(num_tokens, dtype=torch.int32, device="cuda") | ||
| if pad_positions: | ||
| pad_indices = torch.tensor(pad_positions, dtype=torch.long, device="cuda") | ||
| seq_lens[pad_indices] = 0 | ||
| # Match production behavior: padded rows may contain NaNs. | ||
| out[pad_indices] = torch.nan | ||
|
|
||
| ref = out.clone() | ||
| ref[seq_lens == 0] = 0 | ||
|
|
||
| zero_out_decode_padding(out, seq_lens) | ||
| torch.testing.assert_close(out, ref, atol=0, rtol=0) | ||
|
|
||
|
|
||
| @pytest.mark.parametrize("num_tokens", [1, 2, 3, 8]) | ||
| @pytest.mark.parametrize("num_cols", [257, 1024, 1500]) | ||
| def test_zero_out_padding_exhaustive(num_tokens: int, num_cols: int): | ||
| if num_tokens == 1: | ||
| _assert_zero_out_matches_ref( | ||
| num_tokens=1, | ||
| num_cols=num_cols, | ||
| pad_positions=(), | ||
| ) | ||
| return | ||
|
|
||
| for pad_start in range(1, num_tokens): | ||
| _assert_zero_out_matches_ref( | ||
| num_tokens=num_tokens, | ||
| num_cols=num_cols, | ||
| pad_positions=tuple(list(range(pad_start, num_tokens))), | ||
| ) | ||
|
|
||
|
|
||
| @pytest.mark.parametrize("num_tokens", [4, 5, 10, 13, 25] + list(range(55, 64))) | ||
| @pytest.mark.parametrize("num_cols", [257]) | ||
| def test_zero_out_padding(num_tokens: int, num_cols: int) -> None: | ||
| _assert_zero_out_matches_ref( | ||
| num_tokens=num_tokens, | ||
| num_cols=num_cols, | ||
| pad_positions=(num_tokens - 2, num_tokens - 1), | ||
| ) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,70 @@ | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # SPDX-FileCopyrightText: Copyright contributors to the vLLM project | ||
|
|
||
| import torch | ||
|
|
||
| from vllm.triton_utils import HAS_TRITON, tl, triton | ||
|
|
||
| _DEFAULT_BLOCK_SIZE = 1024 | ||
|
|
||
|
|
||
| if HAS_TRITON: | ||
|
|
||
| @triton.jit | ||
| def _zero_out_decode_padding_kernel( | ||
| out_ptr, | ||
| seq_lens_ptr, | ||
| row_stride, | ||
| num_cols, | ||
| BLOCK_SIZE: tl.constexpr, | ||
| ) -> None: | ||
| row = tl.program_id(0) | ||
|
|
||
| if tl.load(seq_lens_ptr + row) != 0: | ||
| return | ||
|
|
||
| col_offsets = tl.arange(0, BLOCK_SIZE) | ||
| out_ptrs = out_ptr + row * row_stride + col_offsets | ||
| for c in tl.range(0, tl.cdiv(num_cols, BLOCK_SIZE)): | ||
| mask = col_offsets + c * BLOCK_SIZE < num_cols | ||
| tl.store( | ||
| out_ptrs, | ||
| tl.zeros([BLOCK_SIZE], dtype=out_ptr.dtype.element_ty), | ||
| mask=mask, | ||
| ) | ||
| out_ptrs += BLOCK_SIZE | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @LucasWilkinson @elvircrn the kernel can use a fresh pair of eyes. Thanks 🙌
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Lgtm! |
||
|
|
||
|
|
||
| def _zero_out_decode_padding_triton( | ||
| out: torch.Tensor, | ||
| seq_lens: torch.Tensor, | ||
| ) -> None: | ||
| """Zero rows in `out` where `seq_lens == 0` using a Triton kernel.""" | ||
| if not out.is_cuda or not seq_lens.is_cuda: | ||
| raise ValueError("out and seq_lens must be CUDA tensors.") | ||
| if out.size(0) != seq_lens.numel(): | ||
| raise ValueError( | ||
| f"out.size(0) {out.size()} must matchseq_lens.numel() ({seq_lens.numel()})." | ||
| ) | ||
| if not out.is_contiguous(): | ||
| raise ValueError("out must be contiguous.") | ||
|
|
||
| BLOCK_SIZE = 1024 | ||
|
|
||
| out_2d = out.view(out.size(0), -1) | ||
| grid = (out_2d.size(0),) | ||
| _zero_out_decode_padding_kernel[grid]( | ||
| out_2d, | ||
| seq_lens, | ||
| out_2d.stride(0), | ||
| out_2d.size(1), | ||
| BLOCK_SIZE=BLOCK_SIZE, | ||
| ) | ||
|
|
||
|
|
||
| def zero_out_decode_padding(out: torch.Tensor, seq_lens: torch.Tensor) -> torch.Tensor: | ||
| if HAS_TRITON: | ||
| _zero_out_decode_padding_triton(out, seq_lens) | ||
| else: | ||
| out[seq_lens == 0] = 0 | ||
| return out | ||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.