-
Notifications
You must be signed in to change notification settings - Fork 469
[BugFix] Fix bugs of varlen attention forward examples caused by S_q != S_kv
#1530
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
39846c3
cb1e36b
3e6ff1a
a5bb44f
9bc5cd4
675fc54
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -4,55 +4,10 @@ | |
| import tilelang | ||
| import tilelang.language as T | ||
| import tilelang.testing | ||
| from einops import rearrange, repeat | ||
| from tilelang.profiler import do_bench | ||
| from varlen_utils import generate_random_padding_mask, generate_qkv | ||
|
|
||
|
|
||
| def attention_ref( | ||
| q, | ||
| k, | ||
| v, | ||
| query_padding_mask=None, | ||
| key_padding_mask=None, | ||
| causal=False, | ||
| window_size=(-1, -1), | ||
| upcast=True, | ||
| ): | ||
| if causal: | ||
| window_size = (window_size[0], 0) | ||
| dtype_og = q.dtype | ||
| if upcast: | ||
| q, k, v = q.float(), k.float(), v.float() | ||
| b, T, Hq, D = q.shape | ||
| S = k.shape[1] | ||
| scale = (1.0 / D) ** 0.5 | ||
| k = repeat(k, "b s h d -> b s (h g) d", g=Hq // k.shape[2]) | ||
| v = repeat(v, "b s h d -> b s (h g) d", g=Hq // v.shape[2]) | ||
| scores = torch.einsum("bthd,bshd->bhts", q, k) | ||
| left, right = window_size | ||
| left = S if left is None or left < 0 else int(left) | ||
| right = S if right is None or right < 0 else int(right) | ||
| t_idx = torch.arange(T, device=scores.device)[:, None] | ||
| s_idx = torch.arange(S, device=scores.device)[None, :] | ||
| visible_ts = (s_idx >= (t_idx - left)) & (s_idx <= (t_idx + right)) | ||
| visible_mask = visible_ts.unsqueeze(0).unsqueeze(0) | ||
| if key_padding_mask is not None: | ||
| k_keep = rearrange(key_padding_mask, "b s -> b 1 1 s") | ||
| visible_mask = visible_mask & k_keep | ||
| neg_inf = torch.finfo(scores.dtype).min | ||
| scores = scores * scale | ||
| scores = scores.masked_fill(~visible_mask, neg_inf) | ||
| attention = torch.softmax(scores, dim=-1).to(v.dtype) | ||
| if query_padding_mask is not None: | ||
| q_keep = rearrange(query_padding_mask, "b t -> b 1 t 1") | ||
| attention = attention.masked_fill(~q_keep, 0.0) | ||
| output = torch.einsum("bhts,bshd->bthd", attention, v) | ||
| if query_padding_mask is not None: | ||
| output = output.masked_fill(rearrange(~query_padding_mask, "b t -> b t 1 1"), 0.0) | ||
| return output.to(dtype=dtype_og), attention.to(dtype=dtype_og) | ||
|
|
||
|
|
||
| @tilelang.jit( | ||
| out_idx=[6], | ||
| pass_configs={ | ||
|
|
@@ -110,8 +65,10 @@ def main( | |
| T.fill(logsum, 0) | ||
| T.fill(scores_max, -T.infinity(accum_dtype)) | ||
|
|
||
| offset = kv_current_seqlen - q_current_seqlen # always align on the right | ||
| max_visible_k_idx = offset + (bx + 1) * block_M | ||
| loop_range = ( | ||
| T.min(T.ceildiv(q_current_seqlen + (bx + 1) * block_M, block_N), T.ceildiv(kv_current_seqlen, block_N)) | ||
| T.min(T.ceildiv(max_visible_k_idx, block_N), T.ceildiv(kv_current_seqlen, block_N)) | ||
| if is_causal | ||
| else T.ceildiv(kv_current_seqlen, block_N) | ||
| ) | ||
|
|
@@ -122,7 +79,7 @@ def main( | |
| if is_causal: | ||
| for i, j in T.Parallel(block_M, block_N): | ||
| acc_s[i, j] = T.if_then_else( | ||
| (bx * block_M + i < k * block_N + j) | ||
| (bx * block_M + i + offset < k * block_N + j) | ||
| or (bx * block_M + i >= q_current_seqlen or k * block_N + j >= kv_current_seqlen), | ||
| -1e9, | ||
| 0, | ||
|
|
@@ -158,9 +115,10 @@ def main( | |
| T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) | ||
|
|
||
| for i, j in T.Parallel(block_M, dim): | ||
| acc_o[i, j] /= logsum[i] | ||
| T.copy(acc_o, O_shared) | ||
| # When sq > skv, some tokens can see nothing | ||
| acc_o[i, j] = 0 if is_causal and bx * block_M + i + offset < 0 else acc_o[i, j] / logsum[i] | ||
|
Comment on lines
117
to
+119
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Add guard for zero logsum to prevent division by zero. While the condition `is_causal and bx * block_M + i + offset < 0` zeroes causal rows that can see no keys, the proposed fix below additionally guards the division for any row where `logsum[i]` is 0.
🔎 Proposed fix to add logsum guard:
for i, j in T.Parallel(block_M, dim):
# When sq > skv, some tokens can see nothing
- acc_o[i, j] = 0 if is_causal and bx * block_M + i + offset < 0 else acc_o[i, j] / logsum[i]
+ acc_o[i, j] = 0 if (is_causal and bx * block_M + i + offset < 0) or logsum[i] == 0 else acc_o[i, j] / logsum[i]
🤖 Prompt for AI Agents |
||
|
|
||
| T.copy(acc_o, O_shared) | ||
| for i, d in T.Parallel(block_M, dim): | ||
| if bx * block_M + i < q_current_seqlen: | ||
| Output_unpad[q_start_idx + bx * block_M + i, head_idx, d] = O_shared[i, d] | ||
|
|
@@ -218,15 +176,22 @@ def main( | |
| out_unpad = kernel(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q) | ||
| out = output_pad_fn(out_unpad) | ||
|
|
||
| out_ref, _ = attention_ref( | ||
| q, | ||
| k, | ||
| v, | ||
| query_padding_mask=query_padding_mask, | ||
| key_padding_mask=key_padding_mask, | ||
| import flash_attn | ||
|
|
||
| fa_out_unpad = flash_attn.flash_attn_varlen_func( | ||
| q_unpad, | ||
| k_unpad, | ||
| v_unpad, | ||
| cu_seqlens_q, | ||
| cu_seqlens_k, | ||
| max_seqlen_q, | ||
| max_seqlen_k, | ||
| 0.0, | ||
| causal=is_causal, | ||
| ) | ||
| torch.testing.assert_close(out, out_ref, rtol=1e-2, atol=1e-2) | ||
| fa_out = output_pad_fn(fa_out_unpad) | ||
| torch.testing.assert_close(out, fa_out, rtol=1e-2, atol=1e-2) | ||
|
|
||
| print("All checks passed.✅") | ||
| latency = do_bench(lambda: kernel(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q), _n_warmup=5, _n_repeat=5) | ||
| print("Tile-lang: {:.2f} ms".format(latency)) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Verify loop_range behavior when max_visible_k_idx is negative.
When `q_current_seqlen > kv_current_seqlen` (offset < 0) and `bx` is small, `max_visible_k_idx` can be negative or very small. The behavior of `T.ceildiv(max_visible_k_idx, block_N)` with a negative numerator may be implementation-dependent and could lead to:
- `loop_range = 0`, causing the loop at line 76 not to execute
- `logsum[i]` remaining 0 for some positions
Consider adding an explicit guard to ensure `loop_range >= 0` and that positions with `loop_range = 0` are handled correctly:
🔎 Suggested fix to add explicit bounds