[BugFix] Fix bugs of varlen attention forward examples caused by S_q != S_kv
#1530
Conversation
Signed-off-by: hukongyi <hukongyi@cmbchina.com>
👋 Hi! Thank you for contributing to the TileLang project. Please remember to run the repository's format checks before pushing. We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀
📝 Walkthrough

Removed reference attention implementations and switched examples to call flash_attn for varlen unpadded Q/K/V; unified K/V handling into a kv-packed path, added a causal mode with per-block offset logic affecting loop ranges and masking, and updated CLI/tests to exercise causal and non-causal runs.

Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes
cc @Rachmanino
Hi @hukongyi, thanks for your feedback. Actually, instead of removing …
Signed-off-by: hukongyi <hukongyi@cmbchina.com>
And I think we should also add a test covering the causal implementation :)
Hi @Rachmanino,
Sure, I really appreciate your contribution to tilelang! I'll check this and add a test for you later.
This is supposed to also fix #1138.
LeiWang1999
left a comment
LGTM
Actionable comments posted: 0
🧹 Nitpick comments (3)
examples/flash_attention/example_gqa_fwd_varlen.py (1)
137-142: Remove duplicate random seed call.
`tilelang.testing.set_random_seed(0)` is called twice (lines 137 and 142). One call is sufficient.

🔎 Proposed fix

```diff
 tilelang.testing.set_random_seed(0)
 if is_causal:
     total_flops *= 0.5
-tilelang.testing.set_random_seed(0)
 dtype = torch.float16
```

examples/flash_attention/example_mha_fwd_varlen.py (2)
3-12: Remove duplicate torch import.
`torch` is imported twice (lines 3 and 10). Remove one of them.

🔎 Proposed fix

```diff
 import torch
 import tilelang
 import tilelang.language as T
 import tilelang.testing
 import argparse
 from tilelang.profiler import do_bench
 from tilelang.autotuner import set_autotune_inputs, autotune
-import torch
 from varlen_utils import generate_random_padding_mask, generate_qkv
 import itertools
```
154-154: Causal parameter wiring is complete but naming is inconsistent.

The CLI flag `--is_causal` maps to parameter `causal` in `main()` and `run_regression_perf()`. The relevant code snippet shows `example_mha_fwd_bshd_wgmma_pipelined.py` uses `is_causal` as the parameter name. Consider aligning for consistency across examples.

🔎 Optional: Align parameter naming

```diff
-def main(batch: int = 8, heads: int = 64, seq_len: int = 2048, dim: int = 128, causal: bool = False, tune: bool = False):
+def main(batch: int = 8, heads: int = 64, seq_len: int = 2048, dim: int = 128, is_causal: bool = False, tune: bool = False):
```

Then update all references from `causal` to `is_causal` within the function body.

Also applies to: 232-232, 278-282
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
examples/flash_attention/example_gqa_fwd_varlen.py
examples/flash_attention/example_mha_fwd_varlen.py
🧰 Additional context used
🧬 Code graph analysis (1)
examples/flash_attention/example_mha_fwd_varlen.py (1)
examples/flash_attention/example_mha_fwd_bshd_wgmma_pipelined.py (1)
`run_regression_perf` (153-156)
🔇 Additional comments (9)
examples/flash_attention/example_gqa_fwd_varlen.py (4)
68-74: Right-alignment offset logic looks correct.

The offset calculation `kv_current_seqlen - q_current_seqlen` correctly implements right-alignment for causal attention where `S_q != S_kv`. The loop range correctly limits iteration to blocks containing visible KV positions.
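As a plain-Python illustration of this right-alignment rule (a standalone sketch; the `visible` helper and its names are illustrative, not taken from the kernel):

```python
# Plain-Python sketch of right-aligned causal visibility for S_q != S_kv,
# independent of the TileLang kernel and its block tiling.
def visible(q_idx: int, k_idx: int, s_q: int, s_kv: int) -> bool:
    offset = s_kv - s_q  # align query tokens to the right of the KV sequence
    return k_idx <= q_idx + offset  # keys "in the future" of the query are masked

# s_q=2, s_kv=5 -> offset=3: the first query sees keys 0..3, the last sees all 5.
assert [k for k in range(5) if visible(0, k, 2, 5)] == [0, 1, 2, 3]
assert [k for k in range(5) if visible(1, k, 2, 5)] == [0, 1, 2, 3, 4]
# s_q=5, s_kv=3 -> offset=-2: the earliest queries see no keys at all,
# which is exactly the case the 117-120 guard below zeros out.
assert [k for k in range(3) if visible(0, k, 5, 3)] == []
assert [k for k in range(3) if visible(1, k, 5, 3)] == []
```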
79-86: Causal masking with offset is correct.

The condition `(bx * block_M + i + offset < k * block_N + j)` correctly masks KV positions that the query token cannot attend to under right-aligned causal semantics.
117-120: Guard for tokens with no visible context is correct.

When `sq > skv`, some query tokens cannot see any KV positions. The guard `bx * block_M + i + offset < 0` correctly identifies these tokens and zeros their output instead of dividing by zero (since `logsum` would be 0).
179-193: Flash Attention validation integration is correct.

The code properly uses `flash_attn.flash_attn_varlen_func` with correct parameter order and types. All required parameters (q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, causal) are correctly passed and align with the documented API. The tolerance of `rtol=1e-2, atol=1e-2` is appropriate for float16 attention computations.

examples/flash_attention/example_mha_fwd_varlen.py (5)
64-70: Index variable naming is consistent and clear.

Using `kv_start_idx`, `kv_end_idx`, and `kv_current_seqlen` clearly indicates these indices apply to both K and V tensors.
80-85: Offset and loop range logic is consistent with GQA implementation.

The right-alignment offset calculation and loop range limiting are correctly implemented, matching the pattern in `example_gqa_fwd_varlen.py`.
92-104: Causal masking implementation is correct and consistent.

The offset-adjusted causal mask and OOB handling match the GQA implementation, ensuring consistent behavior across attention variants.

142-149: Output guard and per-element write are correct.

The zero-output guard for tokens that can see nothing and the bounds-checked per-element write are correctly implemented.
201-215: Flash Attention validation is correctly integrated.

The validation against `flash_attn.flash_attn_varlen_func` with the same tolerance as the GQA example ensures consistent correctness verification.
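For reference, the validation pattern both comments describe looks roughly like the sketch below (a minimal sketch assuming the tensors and `cu_seqlens` metadata produced by `generate_qkv`; the wrapper function and its name are mine, not a verbatim excerpt from the examples):

```python
import flash_attn
import torch

def validate_against_flash_attn(out_unpad, q_unpad, k_unpad, v_unpad,
                                cu_seqlens_q, cu_seqlens_k,
                                max_seqlen_q, max_seqlen_k, is_causal):
    # Reference output from the official FlashAttention varlen kernel.
    fa_out_unpad = flash_attn.flash_attn_varlen_func(
        q_unpad, k_unpad, v_unpad,   # unpadded (total_tokens, heads, head_dim) tensors
        cu_seqlens_q, cu_seqlens_k,  # int32 prefix sums of per-sequence lengths
        max_seqlen_q, max_seqlen_k,  # max per-sequence lengths across the batch
        dropout_p=0.0,
        causal=is_causal,
    )
    # Same float16 tolerance the review comments cite.
    torch.testing.assert_close(out_unpad, fa_out_unpad, rtol=1e-2, atol=1e-2)
```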
Actionable comments posted: 2
🧹 Nitpick comments (2)
examples/flash_attention/example_gqa_fwd_varlen.py (2)
179-179: Consider moving import to top of file.

The inline import of `flash_attn` at line 179 works, but moving it to the top of the file would improve code organization and make dependencies more visible.

🔎 Proposed refactor
At the top of the file, add:

```diff
 # ruff: noqa
 import argparse
 import torch
 import tilelang
 import tilelang.language as T
 import tilelang.testing
 from tilelang.profiler import do_bench
 from varlen_utils import generate_random_padding_mask, generate_qkv
+import flash_attn
```

Then remove the import at line 179:

```diff
-import flash_attn
 fa_out_unpad = flash_attn.flash_attn_varlen_func(
```
179-193: Move flash_attn import to the top of the file.

The API usage is correct and fully supported: the parameter order and types match the documented `flash_attn_varlen_func` API, and causal attention with variable-length sequences is confirmed to work. However, the inline `import flash_attn` at line 179 should be moved to the module-level imports at the top (lines 1-8) to follow Python conventions.
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
examples/flash_attention/example_gqa_fwd_varlen.py
🔇 Additional comments (1)
examples/flash_attention/example_gqa_fwd_varlen.py (1)
79-91: Masking logic correctly implements right-aligned causal attention.

The condition at line 82 correctly masks positions where the query token (adjusted by offset for right-alignment) would see future key tokens. The additional boundary check at line 83 properly handles out-of-bounds positions.
```diff
 offset = kv_current_seqlen - q_current_seqlen  # always align on the right
 max_visible_k_idx = offset + (bx + 1) * block_M
 loop_range = (
-    T.min(T.ceildiv(q_current_seqlen + (bx + 1) * block_M, block_N), T.ceildiv(kv_current_seqlen, block_N))
+    T.min(T.ceildiv(max_visible_k_idx, block_N), T.ceildiv(kv_current_seqlen, block_N))
     if is_causal
     else T.ceildiv(kv_current_seqlen, block_N)
 )
```
Verify `loop_range` behavior when `max_visible_k_idx` is negative.

When `q_current_seqlen > kv_current_seqlen` (offset < 0) and `bx` is small, `max_visible_k_idx` can be negative or very small. The behavior of `T.ceildiv(max_visible_k_idx, block_N)` with a negative numerator may be implementation-dependent and could lead to:

- `loop_range = 0`, causing the loop at line 76 not to execute
- `logsum[i]` remaining 0 for some positions
- Potential division by zero at line 119 if the guard doesn't catch all cases
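To make the rounding risk concrete, here is the standard plain-Python ceildiv idiom evaluated on a negative numerator (an illustration only; whether `T.ceildiv` matches this behavior is exactly what needs verifying):

```python
# Standard integer ceildiv idiom; shown to illustrate why a negative
# numerator is risky. T.ceildiv's semantics on negative inputs may differ.
def ceildiv(a: int, b: int) -> int:
    return -(-a // b)

print(ceildiv(5, 4))   # 2, as expected for positive inputs
print(ceildiv(-5, 4))  # -1: a negative loop_range, not merely 0
print(ceildiv(0, 4))   # 0: the loop body never runs, leaving logsum at 0
```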
Consider adding an explicit guard to ensure `loop_range >= 0` and that positions with `loop_range = 0` are handled correctly:
🔎 Suggested fix to add explicit bounds
```diff
 offset = kv_current_seqlen - q_current_seqlen  # always align on the right
 max_visible_k_idx = offset + (bx + 1) * block_M
 loop_range = (
-    T.min(T.ceildiv(max_visible_k_idx, block_N), T.ceildiv(kv_current_seqlen, block_N))
+    T.max(0, T.min(T.ceildiv(max_visible_k_idx, block_N), T.ceildiv(kv_current_seqlen, block_N)))
     if is_causal
     else T.ceildiv(kv_current_seqlen, block_N)
 )
```

```diff
 for i, j in T.Parallel(block_M, dim):
-    acc_o[i, j] /= logsum[i]
+    # When sq > skv, some tokens can see nothing
+    acc_o[i, j] = 0 if is_causal and bx * block_M + i + offset < 0 else acc_o[i, j] / logsum[i]
 T.copy(acc_o, O_shared)
```
Add guard for zero `logsum` to prevent division by zero.

While the condition `bx * block_M + i + offset < 0` correctly handles tokens before the KV sequence start, there may be edge cases where `offset >= 0` but the loop didn't execute (e.g., due to boundary conditions), leaving `logsum[i] = 0`. This would cause division by zero.
🔎 Proposed fix to add logsum guard
```diff
 for i, j in T.Parallel(block_M, dim):
     # When sq > skv, some tokens can see nothing
-    acc_o[i, j] = 0 if is_causal and bx * block_M + i + offset < 0 else acc_o[i, j] / logsum[i]
+    acc_o[i, j] = 0 if (is_causal and bx * block_M + i + offset < 0) or logsum[i] == 0 else acc_o[i, j] / logsum[i]
```

🤖 Prompt for AI Agents
```
In examples/flash_attention/example_gqa_fwd_varlen.py around lines 117-119, the
final assignment divides acc_o[i, j] by logsum[i] without guarding against
logsum being zero; update the conditional so that if logsum[i] is zero (or below
a tiny epsilon) you set acc_o[i, j] = 0 (or the same branch as tokens that see
nothing), otherwise perform the division; implement the check in the same
if/else expression to avoid division-by-zero and consider using a small epsilon
for numerical safety.
```
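A minimal plain-Python rendering of the epsilon variant the prompt suggests (hypothetical `EPS` and names; the real change would live in the kernel expression shown in the diff above):

```python
EPS = 1e-6  # threshold is a judgment call; float16 sums can underflow near zero

def normalize(acc: float, logsum_i: float, sees_nothing: bool) -> float:
    # Zero the output for tokens with no visible context or a degenerate
    # softmax sum, instead of dividing by (near-)zero.
    return 0.0 if sees_nothing or logsum_i <= EPS else acc / logsum_i
```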
LeiWang1999
left a comment
Awesome contributions, LGTM now.
In `example_gqa_fwd_varlen.py`, the `bx` index is already relative to each sequence's start. Adding `q_current_seqlen` to the loop range calculation was redundant and potentially incorrect for causal masking logic. This PR cleans up the calculation to correctly reflect the block range.
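To make the fix concrete with illustrative numbers (not taken from the PR): with block_M = block_N = 1, q_current_seqlen = 2, and kv_current_seqlen = 5 (so offset = 3), the first query token (bx = 0) may attend keys 0..3. The corrected range min(ceildiv(offset + 1, 1), ceildiv(5, 1)) = 4 scans exactly those keys, while the old expression min(ceildiv(q_current_seqlen + 1, 1), ceildiv(5, 1)) = 3 would have dropped key 3.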
Summary by CodeRabbit

- New Features
- Refactor
- Tests