
Conversation

@hukongyi (Contributor) commented Dec 25, 2025

In example_gqa_fwd_varlen.py, the bx index is already relative to each sequence's start. Adding q_current_seqlen to the loop range calculation was redundant and potentially incorrect for causal masking logic. This PR cleans up the calculation to correctly reflect the block range.
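For reference, the change described here amounts to dropping q_current_seqlen from the causal block bound. A rough sketch of the intent (not the exact diff; the version eventually merged uses an explicit offset instead, see the discussion below):

    # Before: q_current_seqlen was folded into the causal bound
    loop_range = T.min(
        T.ceildiv(q_current_seqlen + (bx + 1) * block_M, block_N),
        T.ceildiv(kv_current_seqlen, block_N),
    )

    # After the cleanup: bx is already relative to the sequence start
    loop_range = T.min(
        T.ceildiv((bx + 1) * block_M, block_N),
        T.ceildiv(kv_current_seqlen, block_N),
    )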

Summary by CodeRabbit

  • New Features
    • Added optional causal mode with a CLI flag to run causal attention variants.
  • Refactor
    • Validation now uses Flash Attention as the reference and removes the legacy reference comparison.
    • Unified key/value handling for variable-length workloads; improved causal masking and iteration logic.
    • Numerical stability improvements in causal paths (replaced extreme infinities with stable values).
  • Tests
    • Expanded coverage to include GQA example and both causal/non-causal MHA variants.


Signed-off-by: hukongyi <hukongyi@cmbchina.com>
@github-actions

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀

@coderabbitai (coderabbitai bot) commented Dec 25, 2025

📝 Walkthrough

Removed the reference attention implementations and switched the examples to call flash_attn for varlen unpadded Q/K/V; unified K/V handling into a kv-packed path; added a causal mode with per-block offset logic affecting loop ranges and masking; updated the CLI and tests to exercise both causal and non-causal runs.

Changes

  • GQA varlen example (examples/flash_attention/example_gqa_fwd_varlen.py):
    Removed attention_ref; inline-import and call flash_attn.flash_attn_varlen_func for unpadded Q/K/V; replaced out_ref validation with Flash Attention outputs (fa_out_unpad/fa_out); introduced offset = kv_current_seqlen - q_current_seqlen to adjust loop_range and inner masking for causal mode; guarded division in causal accumulation; removed unused einops imports.
  • MHA varlen example & CLI (examples/flash_attention/example_mha_fwd_varlen.py):
    Removed attention_ref; unified K/V indexing to kv_start_idx/kv_end_idx/kv_current_seqlen; added causal: bool = False parameter and --is_causal CLI flag; applied offset-based causal masking and guards (including replacing some -inf with -1e9); ensured writes only for valid unpadded Q positions; removed reference comparison in favor of flash_attn.
  • Tests (examples/flash_attention/test_example_flash_attention.py):
    Added import of example_gqa_fwd_varlen; extended test_example_mha_fwd_varlen to run both causal=False and causal=True; added test_example_gqa_fwd_varlen running both causal modes.

Sequence Diagram(s)

(omitted)

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

Possibly related PRs

Suggested reviewers

  • LeiWang1999

Poem

🐰 I hopped through loops both wide and small,
Replaced the old ref with flash for all,
Offset in tow, causal flags unfurled,
Tests and kernels prance through the world,
A tiny rabbit cheered — attention for all! 🎉

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
  • Docstring Coverage ⚠️ Warning: Docstring coverage is 0.00%, which is insufficient; the required threshold is 80.00%. You can run @coderabbitai generate docstrings to improve docstring coverage.

✅ Passed checks (2 passed)
  • Description Check ✅ Passed: Check skipped because CodeRabbit's high-level summary is enabled.
  • Title Check ✅ Passed: The title accurately describes the main change: fixing bugs in the varlen attention forward examples that occur when the sequence length of queries differs from the sequence length of keys/values.


@LeiWang1999 (Member)

cc @Rachmanino

@Rachmanino (Collaborator) commented Dec 25, 2025

Hi @hukongyi, thanks for your feedback. Actually, instead of removing q_current_seqlen here, I think we should replace it with kv_current_seqlen - q_current_seqlen to correctly handle cases where seqlen_q != seqlen_kv. This is because we usually use right alignment (see FA2's https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/cute/mask.py#L197 as an example). I'll double-check this later.
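For intuition, right alignment pins the last query token to the last key token, so query position i may attend to key positions up to i + (seqlen_kv - seqlen_q). A minimal PyTorch sketch of that mask (illustrative only; the helper name and shapes are made up here, not taken from the examples):

    import torch

    def right_aligned_causal_mask(seqlen_q: int, seqlen_kv: int) -> torch.Tensor:
        # offset shifts query indices so both sequences end at the same position,
        # matching the FA2 convention linked above.
        offset = seqlen_kv - seqlen_q
        q_idx = torch.arange(seqlen_q).unsqueeze(1)   # (seqlen_q, 1)
        k_idx = torch.arange(seqlen_kv).unsqueeze(0)  # (1, seqlen_kv)
        return k_idx <= q_idx + offset  # True where the query may attend to the key

    # seqlen_q=2, seqlen_kv=4: row 0 sees keys 0..2, row 1 sees keys 0..3
    print(right_aligned_causal_mask(2, 4))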

@hukongyi hukongyi marked this pull request as draft December 25, 2025 10:42
Signed-off-by: hukongyi <hukongyi@cmbchina.com>
@LeiWang1999 (Member)

And I think we should also add a test covering the causal implementation :)

@hukongyi (Contributor, Author)

Hi @Rachmanino,
I have converted this PR to a Draft and pushed my latest changes so you can review the full implementation directly.
I attempted to implement the right-alignment logic as discussed (using offset = kv_len - q_len) and updated both the TileLang kernel and the PyTorch reference accordingly. However, the correctness tests are currently failing.
As I am new to TileLang, I suspect I might be missing some nuances regarding the loop bounds or masking syntax. I would really appreciate it if you could take a look at the code when you have a moment and point me in the right direction.
Thanks for your patience!

@Rachmanino (Collaborator)

Sure, I really appreciate your contribution to tilelang! I'll check this and add a test for you later.

@Rachmanino Rachmanino marked this pull request as ready for review December 25, 2025 15:45
@Rachmanino Rachmanino changed the title from "fix(examples): correct causal loop range in GQA varlen example" to "[Bug] Fix bugs of varlen attention forward examples caused by S_q != S_kv" on Dec 25, 2025
@Rachmanino (Collaborator)

This is supposed to also fix #1138

@LeiWang1999 previously approved these changes Dec 25, 2025

@LeiWang1999 (Member) left a comment:
LGTM

@coderabbitai (coderabbitai bot) left a comment:

Actionable comments posted: 0

🧹 Nitpick comments (3)
examples/flash_attention/example_gqa_fwd_varlen.py (1)

137-142: Remove duplicate random seed call.

tilelang.testing.set_random_seed(0) is called twice (lines 137 and 142). One call is sufficient.

🔎 Proposed fix
     tilelang.testing.set_random_seed(0)
 
     if is_causal:
         total_flops *= 0.5
 
-    tilelang.testing.set_random_seed(0)
-
     dtype = torch.float16
examples/flash_attention/example_mha_fwd_varlen.py (2)

3-12: Remove duplicate torch import.

torch is imported twice (lines 3 and 10). Remove one of them.

🔎 Proposed fix
 import torch
 import tilelang
 import tilelang.language as T
 import tilelang.testing
 import argparse
 from tilelang.profiler import do_bench
 from tilelang.autotuner import set_autotune_inputs, autotune
 
-import torch
 from varlen_utils import generate_random_padding_mask, generate_qkv
 import itertools

154-154: Causal parameter wiring is complete but naming is inconsistent.

The CLI flag --is_causal maps to the parameter causal in main() and run_regression_perf(), while the related example example_mha_fwd_bshd_wgmma_pipelined.py uses is_causal as the parameter name. Consider aligning the two for consistency across examples.

🔎 Optional: Align parameter naming
-def main(batch: int = 8, heads: int = 64, seq_len: int = 2048, dim: int = 128, causal: bool = False, tune: bool = False):
+def main(batch: int = 8, heads: int = 64, seq_len: int = 2048, dim: int = 128, is_causal: bool = False, tune: bool = False):

Then update all references from causal to is_causal within the function body.

Also applies to: 232-232, 278-282
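For context, the wiring in question is roughly the following pattern (a sketch only, assuming argparse as in the other examples; the exact code in the file may differ):

    import argparse

    def main(batch: int = 8, heads: int = 64, seq_len: int = 2048, dim: int = 128,
             causal: bool = False, tune: bool = False):
        ...  # build and run the varlen kernel; `causal` selects the masked path

    if __name__ == "__main__":
        parser = argparse.ArgumentParser()
        parser.add_argument("--is_causal", action="store_true", help="run the causal variant")
        args = parser.parse_args()
        # The CLI flag --is_causal feeds the parameter named `causal`, hence the nitpick.
        main(causal=args.is_causal)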

📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between a5bb44f and 9bc5cd4.

📒 Files selected for processing (2)
  • examples/flash_attention/example_gqa_fwd_varlen.py
  • examples/flash_attention/example_mha_fwd_varlen.py
🧰 Additional context used
🧬 Code graph analysis (1)
examples/flash_attention/example_mha_fwd_varlen.py (1)
examples/flash_attention/example_mha_fwd_bshd_wgmma_pipelined.py (1)
  • run_regression_perf (153-156)
🔇 Additional comments (9)
examples/flash_attention/example_gqa_fwd_varlen.py (4)

68-74: Right-alignment offset logic looks correct.

The offset calculation kv_current_seqlen - q_current_seqlen correctly implements right-alignment for causal attention where S_q != S_kv. The loop range correctly limits iteration to blocks containing visible KV positions.


79-86: Causal masking with offset is correct.

The condition (bx * block_M + i + offset < k * block_N + j) correctly masks KV positions that the query token cannot attend to under right-aligned causal semantics.


117-120: Guard for tokens with no visible context is correct.

When sq > skv, some query tokens cannot see any KV positions. The guard bx * block_M + i + offset < 0 correctly identifies these tokens and zeros their output instead of dividing by zero (since logsum would be 0).
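A quick way to see which rows that guard zeroes out, reusing the kernel's expressions with toy numbers (plain Python, illustrative values only):

    # More query tokens than KV tokens, so offset is negative.
    block_M, q_current_seqlen, kv_current_seqlen = 4, 8, 5
    offset = kv_current_seqlen - q_current_seqlen  # -3

    for bx in range((q_current_seqlen + block_M - 1) // block_M):
        for i in range(block_M):
            q_pos = bx * block_M + i
            if q_pos >= q_current_seqlen:
                continue  # padding rows of the last block
            # Same condition as the kernel guard: these tokens attend to no KV position.
            print(q_pos, "sees nothing" if q_pos + offset < 0 else "has context")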


179-193: Flash Attention validation integration is correct.

The code properly uses flash_attn.flash_attn_varlen_func with correct parameter order and types. All required parameters (q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, causal) are correctly passed and align with the documented API. The tolerance of rtol=1e-2, atol=1e-2 is appropriate for float16 attention computations.
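The validation pattern being described is roughly the following (a sketch assuming the flash-attn package is installed; out_unpad stands in for the TileLang kernel's output, and the exact variable names in the example may differ):

    import torch
    import flash_attn

    def check_against_flash_attn(out_unpad, q_unpad, k_unpad, v_unpad,
                                 cu_seqlens_q, cu_seqlens_k,
                                 max_seqlen_q, max_seqlen_k, is_causal):
        # Reference output from flash-attn's varlen kernel on the unpadded layout.
        fa_out_unpad = flash_attn.flash_attn_varlen_func(
            q_unpad, k_unpad, v_unpad,
            cu_seqlens_q, cu_seqlens_k,
            max_seqlen_q, max_seqlen_k,
            dropout_p=0.0,
            causal=is_causal,
        )
        # Same float16 tolerance as the examples.
        torch.testing.assert_close(out_unpad, fa_out_unpad, rtol=1e-2, atol=1e-2)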

examples/flash_attention/example_mha_fwd_varlen.py (5)

64-70: Index variable naming is consistent and clear.

Using kv_start_idx, kv_end_idx, and kv_current_seqlen clearly indicates these indices apply to both K and V tensors.


80-85: Offset and loop range logic is consistent with GQA implementation.

The right-alignment offset calculation and loop range limiting are correctly implemented, matching the pattern in example_gqa_fwd_varlen.py.


92-104: Causal masking implementation is correct and consistent.

The offset-adjusted causal mask and OOB handling match the GQA implementation, ensuring consistent behavior across attention variants.


142-149: Output guard and per-element write are correct.

The zero-output guard for tokens that can see nothing and the bounds-checked per-element write are correctly implemented.


201-215: Flash Attention validation is correctly integrated.

The validation against flash_attn.flash_attn_varlen_func with the same tolerance as the GQA example ensures consistent correctness verification.

@coderabbitai (coderabbitai bot) left a comment:

Actionable comments posted: 2

🧹 Nitpick comments (2)
examples/flash_attention/example_gqa_fwd_varlen.py (2)

179-179: Consider moving import to top of file.

The inline import of flash_attn at line 179 works but moving it to the top of the file would improve code organization and make dependencies more visible.

🔎 Proposed refactor

At the top of the file, add:

 # ruff: noqa
 import argparse
 import torch
 import tilelang
 import tilelang.language as T
 import tilelang.testing
 from tilelang.profiler import do_bench
 from varlen_utils import generate_random_padding_mask, generate_qkv
+import flash_attn

Then remove the import at line 179:

-    import flash_attn
-
     fa_out_unpad = flash_attn.flash_attn_varlen_func(

179-193: Move flash_attn import to the top of the file.

The API usage is correct and fully supported—the parameter order and types match the documented flash_attn_varlen_func API, and causal attention with variable-length sequences is confirmed to work. However, the inline import flash_attn at line 179 should be moved to the module-level imports at the top (lines 1–8) to follow Python conventions.

📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 9bc5cd4 and 675fc54.

📒 Files selected for processing (1)
  • examples/flash_attention/example_gqa_fwd_varlen.py
🔇 Additional comments (1)
examples/flash_attention/example_gqa_fwd_varlen.py (1)

79-91: Masking logic correctly implements right-aligned causal attention.

The condition at line 82 correctly masks positions where the query token (adjusted by offset for right-alignment) would see future key tokens. The additional boundary check at line 83 properly handles out-of-bounds positions.

Comment on lines +68 to 74

+offset = kv_current_seqlen - q_current_seqlen  # always align on the right
+max_visible_k_idx = offset + (bx + 1) * block_M
 loop_range = (
-    T.min(T.ceildiv(q_current_seqlen + (bx + 1) * block_M, block_N), T.ceildiv(kv_current_seqlen, block_N))
+    T.min(T.ceildiv(max_visible_k_idx, block_N), T.ceildiv(kv_current_seqlen, block_N))
     if is_causal
     else T.ceildiv(kv_current_seqlen, block_N)
 )

⚠️ Potential issue | 🔴 Critical

Verify loop_range behavior when max_visible_k_idx is negative.

When q_current_seqlen > kv_current_seqlen (offset < 0) and bx is small, max_visible_k_idx can be negative or very small. The behavior of T.ceildiv(max_visible_k_idx, block_N) with a negative numerator may be implementation-dependent and could lead to:

  1. loop_range = 0, causing the loop at line 76 not to execute
  2. logsum[i] remaining 0 for some positions
  3. Potential division by zero at line 119 if the guard doesn't catch all cases

Consider adding an explicit guard to ensure loop_range >= 0 and that positions with loop_range = 0 are handled correctly:

🔎 Suggested fix to add explicit bounds
 offset = kv_current_seqlen - q_current_seqlen  # always align on the right
 max_visible_k_idx = offset + (bx + 1) * block_M
 loop_range = (
-    T.min(T.ceildiv(max_visible_k_idx, block_N), T.ceildiv(kv_current_seqlen, block_N))
+    T.max(0, T.min(T.ceildiv(max_visible_k_idx, block_N), T.ceildiv(kv_current_seqlen, block_N)))
     if is_causal
     else T.ceildiv(kv_current_seqlen, block_N)
 )
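The implementation-dependence concern can be illustrated with the two common integer ceildiv formulations (plain Python; how T.ceildiv actually lowers negative inputs is not verified here):

    import math

    def ceildiv_floor(a: int, b: int) -> int:
        # Floor-division based: -(-a // b) matches math.ceil for all signs.
        return -(-a // b)

    def ceildiv_trunc(a: int, b: int) -> int:
        # (a + b - 1) / b with truncation toward zero, as C-style integer division does.
        return int((a + b - 1) / b)

    a, b = -5, 4                 # e.g. max_visible_k_idx = -5, block_N = 4
    print(math.ceil(a / b))      # -1
    print(ceildiv_floor(a, b))   # -1
    print(ceildiv_trunc(a, b))   # 0 -- a different bound, so the explicit guard above helps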

Comment on lines 117 to +119

 for i, j in T.Parallel(block_M, dim):
-    acc_o[i, j] /= logsum[i]
-T.copy(acc_o, O_shared)
+    # When sq > skv, some tokens can see nothing
+    acc_o[i, j] = 0 if is_causal and bx * block_M + i + offset < 0 else acc_o[i, j] / logsum[i]

⚠️ Potential issue | 🔴 Critical

Add guard for zero logsum to prevent division by zero.

While the condition bx * block_M + i + offset < 0 correctly handles tokens before the KV sequence start, there may be edge cases where offset >= 0 but the loop didn't execute (e.g., due to boundary conditions), leaving logsum[i] = 0. This would cause division by zero.

🔎 Proposed fix to add logsum guard
 for i, j in T.Parallel(block_M, dim):
     # When sq > skv, some tokens can see nothing
-    acc_o[i, j] = 0 if is_causal and bx * block_M + i + offset < 0 else acc_o[i, j] / logsum[i]
+    acc_o[i, j] = 0 if (is_causal and bx * block_M + i + offset < 0) or logsum[i] == 0 else acc_o[i, j] / logsum[i]
🤖 Prompt for AI Agents
In examples/flash_attention/example_gqa_fwd_varlen.py around lines 117-119, the
final assignment divides acc_o[i, j] by logsum[i] without guarding against
logsum being zero; update the conditional so that if logsum[i] is zero (or below
a tiny epsilon) you set acc_o[i, j] = 0 (or the same branch as tokens that see
nothing), otherwise perform the division; implement the check in the same
if/else expression to avoid division-by-zero and consider using a small epsilon
for numerical safety.

@LeiWang1999 (Member) left a comment:

Awesome contributions, LGTM now.

@LeiWang1999 LeiWang1999 merged commit 2e82f37 into tile-ai:main Dec 26, 2025
9 of 10 checks passed
@Rachmanino Rachmanino changed the title from "[Bug] Fix bugs of varlen attention forward examples caused by S_q != S_kv" to "[BugFix] Fix bugs of varlen attention forward examples caused by S_q != S_kv" on Dec 26, 2025