diff --git a/3rdparty/tvm b/3rdparty/tvm index 391d3f7cd..8d494caca 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit 391d3f7cda9abdcb60c57e472dbc4800ae98d5a8 +Subproject commit 8d494cacae52b2ec73f2717431190b1ecd5df6ce diff --git a/examples/dequantize_gemm/example_dequant_groupedgemm_bf16_mxfp4_hopper.py b/examples/dequantize_gemm/example_dequant_groupedgemm_bf16_mxfp4_hopper.py index 6ee595921..6aad32bdb 100644 --- a/examples/dequantize_gemm/example_dequant_groupedgemm_bf16_mxfp4_hopper.py +++ b/examples/dequantize_gemm/example_dequant_groupedgemm_bf16_mxfp4_hopper.py @@ -489,7 +489,6 @@ def main(m=256, n=256, k=256, scale_size=32, topk=4, E=32, fast_dequant=True, wi sorted_token_ids, expert_ids, ) - print("Tilelang kernel run finished.") ref_output = ref_moe(A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids, block_M=block_M) # Maybe a little bit slow... @@ -561,9 +560,9 @@ def run_regression_perf(m=4096, n=4096, k=4096, scale_size=32, topk=4, E=32, fas if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("--M", type=int, default=16384, help="M") # From gpt-oss-20b MoE's first gemm - parser.add_argument("--N", type=int, default=5760, help="N") - parser.add_argument("--K", type=int, default=2944, help="K") + parser.add_argument("--M", type=int, default=256, help="M") # From gpt-oss-20b MoE's first gemm + parser.add_argument("--N", type=int, default=256, help="N") + parser.add_argument("--K", type=int, default=256, help="K") parser.add_argument("--scale_size", type=int, default=32, help="scale size") parser.add_argument("--topk", type=int, default=4, help="topk") # experts activated for each token parser.add_argument("--E", type=int, default=32, help="E") # number of experts diff --git a/examples/flash_decoding/example_gqa_decode_varlen_logits.py b/examples/flash_decoding/example_gqa_decode_varlen_logits.py index 7c47252aa..864ff3e54 100644 --- a/examples/flash_decoding/example_gqa_decode_varlen_logits.py +++ b/examples/flash_decoding/example_gqa_decode_varlen_logits.py @@ -1,6 +1,4 @@ import torch -import triton -import triton.language as tl import math import argparse import tilelang @@ -22,167 +20,6 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) -@triton.jit -def _fwd_inner( - q, - k_ptrs, - v_ptrs, - s_ptrs, - m_i, - l_i, - acc, - offs_h, - mask_h, - offs_n, - seqlen, - softmax_scale, - lo, - hi, - stride_kt, - stride_vt, - stride_sh, - stride_sn, - BLOCK_N: tl.constexpr, -): - """Inner loop computation for attention""" - - for blk_idx in tl.range(lo, hi): - start_n = blk_idx * BLOCK_N - k = tl.load(k_ptrs + start_n * stride_kt, mask=offs_n[None, :] + start_n < seqlen) - v = tl.load(v_ptrs + start_n * stride_vt, mask=offs_n[:, None] + start_n < seqlen) - - qk = tl.dot(q, k) - qk *= softmax_scale - qk += tl.where(offs_n[None, :] + start_n < seqlen, 0, -1.0e9) - - row_max = tl.max(qk, 1) - tl.store(s_ptrs + offs_h * stride_sh + blk_idx * stride_sn, row_max, mask=mask_h) - - m_ij = tl.maximum(m_i, row_max) - qk -= m_ij[:, None] - p = tl.math.exp(qk) - l_ij = tl.sum(p, 1) - alpha = tl.math.exp(m_i - m_ij) - l_i = l_i * alpha + l_ij - m_i = m_ij - acc *= alpha[:, None] - p = p.to(v.type.element_ty) - acc += tl.dot(p, v) - - return m_i, l_i, acc - - -@triton.autotune( - configs=[triton.Config({}, num_warps=num_warps, num_stages=num_stages) for num_warps in [4, 8] for num_stages in [2, 4]], - key=["gqa_group_size", 
"BLOCK_N", "BLOCK_D", "BLOCK_H"], -) -@triton.jit -def _fwd_kernel_varlen( - Q, # [token_q = b, h_q, dim] - K, # [token_k, h_kv, dim] - V, - O, - S, - s_aux, - softmax_scale, - cu_seqlens_k, - stride_qt, - stride_qh, - stride_qd, - stride_kt, - stride_kh, - stride_kd, - stride_vt, - stride_vh, - stride_vd, - stride_ot, - stride_oh, - stride_od, - stride_sb, - stride_sh, - stride_sn, # bmask shape [b, q_h, seq/BLOCK_N] - gqa_group_size: tl.constexpr, - BLOCK_H: tl.constexpr, - BLOCK_N: tl.constexpr, - BLOCK_D: tl.constexpr, -): - off_z = tl.program_id(0) - off_h_for_kv = tl.program_id(1) - off_h_q = off_h_for_kv * gqa_group_size - - cu_k_start = tl.load(cu_seqlens_k + off_z) - cu_k_end = tl.load(cu_seqlens_k + off_z + 1) - - seqlen_k = cu_k_end - cu_k_start - - offs_h = tl.arange(0, BLOCK_H) - offs_n = tl.arange(0, BLOCK_N) - offs_d = tl.arange(0, BLOCK_D) - - Q_ptrs = Q + off_z * stride_qt + off_h_q * stride_qh - K_ptrs = K + (cu_k_start) * stride_kt + off_h_for_kv * stride_kh - V_ptrs = V + (cu_k_start) * stride_vt + off_h_for_kv * stride_vh - O_ptrs = O + off_z * stride_ot + off_h_q * stride_oh - S_ptrs = S + off_z * stride_sb + off_h_q * stride_sh - - mask_h = offs_h < gqa_group_size - q = tl.load(Q_ptrs + offs_d[None, :] * stride_qd + offs_h[:, None] * stride_qh, mask=mask_h[:, None]) - - if s_aux is not None: - sink = tl.load(s_aux + off_h_q + offs_h, mask=mask_h).to(tl.float32) - l_i = tl.zeros([BLOCK_H], dtype=tl.float32) - m_i = tl.zeros([BLOCK_H], dtype=tl.float32) + sink - else: - l_i = tl.full([BLOCK_H], 1.0, dtype=tl.float32) - m_i = tl.full([BLOCK_H], float("-inf"), dtype=tl.float32) - - acc = tl.zeros([BLOCK_H, BLOCK_D], dtype=tl.float32) - - k_ptrs = K_ptrs + offs_n[None, :] * stride_kt + offs_d[:, None] * stride_kd - v_ptrs = V_ptrs + offs_n[:, None] * stride_vt + offs_d[None, :] * stride_vd - - lo, hi = 0, tl.cdiv(seqlen_k, BLOCK_N) - m_i, l_i, acc = _fwd_inner( - q, - k_ptrs, - v_ptrs, - S_ptrs, - m_i, - l_i, - acc, - offs_h, - mask_h, - offs_n, - seqlen_k, - softmax_scale, - lo, - hi, - stride_kt, - stride_vt, - stride_sh, - stride_sn, - BLOCK_N, - ) - - if s_aux is not None: - sink = tl.math.exp(sink - m_i) - l_i = l_i + sink - acc = acc / l_i[:, None] - - else: - l_recip = 1 / l_i[:, None] - acc = acc * l_recip - - for blk_idx in tl.range(lo, hi): - s = tl.load(S_ptrs + offs_h * stride_sh + blk_idx * stride_sn, mask=mask_h) - s = tl.exp(s - m_i) / l_i - tl.store(S_ptrs + offs_h * stride_sh + blk_idx * stride_sn, s, mask=mask_h) - - acc = acc.to(O.dtype.element_ty) - - tl.store(O_ptrs + offs_h[:, None] * stride_oh + offs_d[None, :] * stride_od, acc, mask=mask_h[:, None]) - - def get_configs(): import itertools @@ -212,7 +49,6 @@ def flashattn( kv_group_num = heads // k_heads valid_block_H = min(block_H, kv_group_num) - # TODO: check if max_seqlen_kv is correct for varlen case @T.prim_func def flashattn_gqa_decode_no_split( @@ -224,7 +60,7 @@ def flashattn_gqa_decode_no_split( Output: T.Tensor(shape_o, dtype), S: T.Tensor(shape_s, dtype), ): - with T.Kernel(batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz): + with T.Kernel(batch, heads // valid_block_H, num_split, threads=threads) as (bid, hid, bz): Q_shared = T.alloc_shared([block_H, dim], dtype) K_shared = T.alloc_shared([block_N, dim], dtype) V_shared = T.alloc_shared([block_N, dim], dtype) @@ -237,12 +73,10 @@ def flashattn_gqa_decode_no_split( scores_scale = T.alloc_fragment([block_H], accum_dtype) scores_sum = T.alloc_fragment([block_H], accum_dtype) logsum = T.alloc_fragment([block_H], 
accum_dtype) - S_shared = T.alloc_shared([block_H, math.ceil(max_seqlen_kv / block_N)], dtype) - # S_fragment = T.alloc_fragment([block_H, math.ceil(max_seqlen_kv / block_N)], accum_dtype) + S_shared = T.alloc_shared([block_H, math.ceil(max_seqlen_kv / block_N)], accum_dtype) + S_shared_cast = T.alloc_shared([block_H, math.ceil(max_seqlen_kv / block_N)], dtype) s_aux_shared = T.alloc_shared([block_H], T.float32) - bid = bx - hid = by cur_kv_head = hid // (kv_group_num // valid_block_H) cur_start_k = cu_seqlens_k[bid] @@ -254,30 +88,22 @@ def flashattn_gqa_decode_no_split( T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) - # loop_range = T.ceildiv((seqlen_kv // num_split), block_N) loop_range = T.ceildiv((cur_seqlen_k // num_split), block_N) for k in T.Pipelined(loop_range, num_stages=num_stages): T.copy(K[cur_start_k + k * block_N : cur_start_k + (k + 1) * block_N, cur_kv_head, :], K_shared) T.clear(acc_s) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) for i, j in T.Parallel(block_H, block_N): - # acc_s[i, j] = T.if_then_else(mask_local[j] != 0 and k * block_N + j < cur_seqlen_k, acc_s[i, j], - # -T.infinity(accum_dtype)) acc_s[i, j] = T.if_then_else(k * block_N + j < cur_seqlen_k, acc_s[i, j], -T.infinity(accum_dtype)) T.copy(scores_max, scores_max_prev) T.fill(scores_max, -T.infinity(accum_dtype)) T.reduce_max(acc_s, scores_max, dim=1, clear=False) - # scores_max_prev is m_i - # scores_max is row_max->m_ij in triton T.copy(scores_max, S_shared[:, k]) - # scores_scale is alpha in triton for i in T.Parallel(block_H): scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) for i, j in T.Parallel(block_H, block_N): acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) T.reduce_sum(acc_s, scores_sum, dim=1) - # scores_sum is l_ij in triton - # logsum is l_i in triton for i in T.Parallel(block_H): logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] T.copy(acc_s, acc_s_cast) @@ -294,327 +120,98 @@ def flashattn_gqa_decode_no_split( acc_o[i, j] /= logsum[i] for h, k in T.Parallel(block_H, math.ceil(max_seqlen_kv / block_N)): S_shared[h, k] = T.exp2((S_shared[h, k] - scores_max[h]) * scale) / logsum[h] - # T.copy(S_shared, S_fragment) - # for h, k in T.Parallel(block_H, math.ceil(max_seqlen_kv / block_N)): - # S_fragment[h, k] = T.exp2((S_fragment[h, k] - scores_max[h]) * scale) / logsum[h] for i in T.Parallel(block_H): logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale T.copy(acc_o[:valid_block_H, :], O_shared) T.copy(O_shared, Output[bid, hid * valid_block_H : (hid + 1) * valid_block_H, :]) - # T.copy(S_fragment, S_shared) - T.copy(S_shared[:valid_block_H, :], S[bid, hid * valid_block_H : (hid + 1) * valid_block_H, :]) + T.copy(S_shared, S_shared_cast) + T.copy(S_shared_cast[:valid_block_H, :], S[bid, hid * valid_block_H : (hid + 1) * valid_block_H, :]) - # TODO: split version return flashattn_gqa_decode_no_split -def flash_attn_with_attn_pool_decode_tilelang( - Q: torch.Tensor, ## [tq = b, q_h, q_dim] - K: torch.Tensor, ## [tk, k_h, k_dim] - V: torch.Tensor, - cu_seqlens_k: torch.Tensor, - max_seqlen_k: int, - real_max_k_seqlen: int, - num_split: int, - softmax_scale: float, - s_aux: torch.Tensor = None, - block_size: int = 64, - use_per_kv_head_sparse_index: bool = False, - tl_kernel=None, -): - num_tokens, q_h, head_size = Q.shape - batch = cu_seqlens_k.size(0) - 1 - k_h = K.size(1) - - assert Q.dim() == K.dim() == 3 - assert Q.size(2) == K.size(2) - assert cu_seqlens_k.dim() == 1 - assert head_size 
in {64, 128, 256} - assert Q.is_contiguous() - # assert K.is_contiguous() - # assert V.is_contiguous() +def ref_attention(q, k, v, k_seqlens, q_heads, sink=None): + """ + Compute reference attention output and weights. + Args: + q: [b, q_heads, head_size] + k, v: [b, kv_heads, max_seqlen, head_size] + k_seqlens: [b] actual sequence lengths + sink: [q_heads] optional sink values + Returns: output [b, q_heads, head_size], attn_weights [b, q_heads, max_seqlen] + """ + batch_size, kv_heads, max_seqlen, head_size = k.shape + softmax_scale = 1.0 / math.sqrt(head_size) - gqa_group_size = q_h // k_h + # Expand KV heads and compute attention scores + k = repeat_kv(k, q_heads // kv_heads) + v = repeat_kv(v, q_heads // kv_heads) + logits = torch.matmul(q.unsqueeze(2), k.transpose(-2, -1)) * softmax_scale # [b, q_heads, 1, max_seqlen] - O_tl = torch.zeros_like(Q) - S_tl = torch.zeros((batch, q_h, math.ceil(real_max_k_seqlen / block_size)), dtype=Q.dtype, device=Q.device) - O_tl, S_tl = tl_kernel(Q, K, V, cu_seqlens_k, s_aux) + # Mask invalid positions + mask = torch.arange(max_seqlen, device=q.device).expand(batch_size, -1) >= k_seqlens.unsqueeze(1) + logits.masked_fill_(mask.unsqueeze(1).unsqueeze(2), float("-inf")) - if use_per_kv_head_sparse_index: - S_tl = torch.max_pool2d(S_tl, kernel_size=(gqa_group_size, 1), stride=(gqa_group_size, 1)) - else: - S_tl = torch.max_pool2d(S_tl, kernel_size=(q_h, 1), stride=(q_h, 1)) - - return O_tl, S_tl - - -def flash_attn_with_attn_pool_decode( - Q: torch.Tensor, ## [tq = b, q_h, q_dim] - K: torch.Tensor, ## [tk, k_h, k_dim] - V: torch.Tensor, - cu_seqlens_k: torch.Tensor, - max_seqlen_k: int, - real_max_k_seqlen: int, - num_split: int, - softmax_scale: float, - s_aux: torch.Tensor = None, - block_size: int = 64, - use_per_kv_head_sparse_index: bool = False, -): - num_tokens, q_h, head_size = Q.shape - batch = cu_seqlens_k.size(0) - 1 - k_h = K.size(1) - - assert Q.dim() == K.dim() == 3 - assert Q.size(2) == K.size(2) - assert cu_seqlens_k.dim() == 1 - assert head_size in {64, 128, 256} - assert Q.is_contiguous() - # assert K.is_contiguous() - # assert V.is_contiguous() - - gqa_group_size = q_h // k_h - - BLOCK_D = head_size - BLOCK_N = block_size - BLOCK_H = 64 - - O = torch.zeros_like(Q) - S = torch.zeros((batch, q_h, math.ceil(max_seqlen_k / block_size)), dtype=Q.dtype, device=Q.device) - - def grid(META): - return (batch, k_h) - - with torch.cuda.device(Q.device.index): - _fwd_kernel_varlen[grid]( - Q, - K, - V, - O, - S, - s_aux, - softmax_scale, - cu_seqlens_k, - *Q.stride(), - *K.stride(), - *V.stride(), - *O.stride(), - *S.stride(), - gqa_group_size, - BLOCK_H=BLOCK_H, - BLOCK_N=BLOCK_N, - BLOCK_D=BLOCK_D, - ) - - if use_per_kv_head_sparse_index: - S = torch.max_pool2d(S, kernel_size=(gqa_group_size, 1), stride=(gqa_group_size, 1)) + if sink is None: + attn_weights = logits.softmax(dim=-1) else: - S = torch.max_pool2d(S, kernel_size=(q_h, 1), stride=(q_h, 1)) + # Sink attention: softmax with additional sink term + sink_expanded = sink.view(1, q_heads, 1, 1) + logits_max = torch.maximum(logits.max(dim=-1, keepdim=True).values, sink_expanded) + exp_logits = torch.exp(logits - logits_max) + attn_weights = exp_logits / (exp_logits.sum(dim=-1, keepdim=True) + torch.exp(sink_expanded - logits_max)) - return O, S + attn_weights.masked_fill_(mask.unsqueeze(1).unsqueeze(2), 0.0) + output = torch.matmul(attn_weights.to(v.dtype), v).squeeze(2) + return output, attn_weights.squeeze(2) def test_varlen_decode_main(args): - """Test decode kernel with variable sequence 
lengths""" - batch_size = args.batch_size - q_heads = args.q_heads - kv_heads = args.kv_heads - max_k_seqlen = args.k_seqlen # Use as max sequence length - real_max_k_seqlen = args.k_seqlen - head_size = args.head_size - block_size = args.block_size + """Test decode kernel with variable sequence lengths.""" + batch_size, q_heads, kv_heads = args.batch_size, args.q_heads, args.kv_heads + max_k_seqlen, head_size, block_size = args.k_seqlen, args.head_size, args.block_size dtype = torch.bfloat16 if args.dtype == T.bfloat16 else torch.float16 - print(f"Testing decode kernel with variable sequence lengths (max_k_seqlen={max_k_seqlen})") - - # Generate sink values if needed - sink = None - if args.test_sink: - sink = torch.randn(q_heads, device="cuda", dtype=torch.float32) * 0.1 # Small sink values - print(f"Using sink attention with sink values: {sink}") - - # Generate variable length k sequences + # Generate variable length sequences and cumulative lengths k_seqlens = torch.randint(max_k_seqlen // 4, max_k_seqlen + 1, size=(batch_size,)) - print(f"k_seqlens: {k_seqlens}") - - # Generate cumulative sequence lengths for k cu_seqlens_k = torch.zeros(batch_size + 1, device="cuda", dtype=torch.int32) - total_k_tokens = 0 - for i in range(batch_size): - cu_seqlens_k[i] = total_k_tokens - total_k_tokens += k_seqlens[i] - cu_seqlens_k[batch_size] = total_k_tokens - - print(f"cu_seqlens_k: {cu_seqlens_k}") + cu_seqlens_k[1:] = torch.cumsum(k_seqlens, dim=0).to(torch.int32).cuda() + total_k_tokens = cu_seqlens_k[-1].item() - # Generate tensors - Q is [batch_size, q_heads, head_size] for decode - q_decode = torch.randn(batch_size, q_heads, head_size, device="cuda", dtype=dtype) + # Generate input tensors + q = torch.randn(batch_size, q_heads, head_size, device="cuda", dtype=dtype) k_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device="cuda", dtype=dtype) v_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device="cuda", dtype=dtype) + sink = torch.randn(q_heads, device="cuda", dtype=torch.float32) * 0.1 if args.test_sink else None - softmax_scale = 1.0 / math.sqrt(head_size) - max_seqlen_k = int(k_seqlens.max()) - - print(f"Actual max_seqlen_k: {max_seqlen_k}") - print(f"q_decode shape: {q_decode.shape}") - print(f"k_varlen shape: {k_varlen.shape}") - print(f"v_varlen shape: {v_varlen.shape}") - - num_tokens, q_h, head_size = q_decode.shape - batch = cu_seqlens_k.size(0) - 1 - k_h = k_varlen.size(1) - tl_kernel = flashattn(batch, q_h, k_h, args.k_seqlen, cu_seqlens_k[-1].item(), head_size, args.test_sink) + # Run tilelang kernel + tilelang.disable_cache() + tl_kernel = flashattn(batch_size, q_heads, kv_heads, max_k_seqlen, total_k_tokens, head_size, args.test_sink) + O_tl, S_tl = tl_kernel(q, k_varlen, v_varlen, cu_seqlens_k, sink) + S_tl = torch.max_pool2d(S_tl, kernel_size=(q_heads, 1), stride=(q_heads, 1)) - # Test our decode kernel - O_triton, S_triton = flash_attn_with_attn_pool_decode( - q_decode, - k_varlen, - v_varlen, - cu_seqlens_k, - max_seqlen_k, - real_max_k_seqlen, - args.num_split, - softmax_scale, - s_aux=sink, - block_size=block_size, - ) - O_tilelang, S_tilelang = flash_attn_with_attn_pool_decode_tilelang( - q_decode, - k_varlen, - v_varlen, - cu_seqlens_k, - max_seqlen_k, - real_max_k_seqlen, - args.num_split, - softmax_scale, - s_aux=sink, - block_size=block_size, - tl_kernel=tl_kernel, - ) + # Mask out invalid S positions for i in range(batch_size): - S_tilelang[i, :, math.ceil((cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item()) / block_size) :] = 0 - - # 
Create torch reference - pad tensors for comparison - k_padded_list = [] - v_padded_list = [] + valid_blocks = math.ceil(k_seqlens[i].item() / block_size) + S_tl[i, :, valid_blocks:] = 0 + # Prepare padded tensors for reference + actual_max = int(k_seqlens.max()) + k_padded = torch.zeros(batch_size, kv_heads, actual_max, head_size, device="cuda", dtype=dtype) + v_padded = torch.zeros(batch_size, kv_heads, actual_max, head_size, device="cuda", dtype=dtype) for i in range(batch_size): - actual_k_len = k_seqlens[i] - - # Extract and pad k, v for this batch - k_start = cu_seqlens_k[i] - k_end = cu_seqlens_k[i + 1] - - # Pad to max_seqlen_k - k_padded = torch.zeros(max_seqlen_k, kv_heads, head_size, device="cuda", dtype=dtype) - v_padded = torch.zeros(max_seqlen_k, kv_heads, head_size, device="cuda", dtype=dtype) - - k_padded[:actual_k_len] = k_varlen[k_start:k_end] - v_padded[:actual_k_len] = v_varlen[k_start:k_end] - - k_padded_list.append(k_padded) - v_padded_list.append(v_padded) - - # Stack to create batched tensors [b, max_seqlen, kv_heads, head_size] - k_padded_batched = torch.stack(k_padded_list, dim=0).transpose(1, 2) # [b, kv_heads, max_seqlen, head_size] - v_padded_batched = torch.stack(v_padded_list, dim=0).transpose(1, 2) # [b, kv_heads, max_seqlen, head_size] - - # Expand q to match kv heads: [b, q_heads, 1, head_size] - q_expanded = q_decode.unsqueeze(2) # [b, q_heads, 1, head_size] - - print(f"q_expanded shape: {q_expanded.shape}") - print(f"k_padded_batched shape: {k_padded_batched.shape}") - print(f"v_padded_batched shape: {v_padded_batched.shape}") + seq_len = k_seqlens[i].item() + k_padded[i, :, :seq_len] = k_varlen[cu_seqlens_k[i] : cu_seqlens_k[i + 1]].transpose(0, 1) + v_padded[i, :, :seq_len] = v_varlen[cu_seqlens_k[i] : cu_seqlens_k[i + 1]].transpose(0, 1) - # Compute torch reference - k_repeat = repeat_kv(k_padded_batched, q_heads // kv_heads) # [b, q_heads, max_seqlen, head_size] - v_repeat = repeat_kv(v_padded_batched, q_heads // kv_heads) # [b, q_heads, max_seqlen, head_size] - - if sink is None: - # Standard attention computation: [b, q_heads, 1, head_size] @ [b, q_heads, head_size, max_seqlen] - attn_score = torch.matmul(q_expanded, k_repeat.transpose(-2, -1)) * softmax_scale # [b, q_heads, 1, max_seqlen] - - # Apply sequence length masking - for i in range(batch_size): - actual_k_len = k_seqlens[i] - attn_score[i, :, :, actual_k_len:] = float("-inf") - - attn_weights = attn_score.softmax(dim=-1) # [b, q_heads, 1, max_seqlen] - - # Mask out invalid positions - for i in range(batch_size): - actual_k_len = k_seqlens[i] - attn_weights[i, :, :, actual_k_len:] = 0.0 - - # Compute output: [b, q_heads, 1, max_seqlen] @ [b, q_heads, max_seqlen, head_size] - O_torch = torch.matmul(attn_weights, v_repeat) # [b, q_heads, 1, head_size] - else: - # s_aux attention - logits = torch.matmul(q_expanded, k_repeat.transpose(-2, -1)) * softmax_scale # [b, q_heads, 1, max_seqlen] - - # Apply sequence length masking - for i in range(batch_size): - actual_k_len = k_seqlens[i] - logits[i, :, :, actual_k_len:] = float("-inf") - - sink_expanded = sink.view(1, q_heads, 1, 1) # [1, q_heads, 1, 1] - logits_max = torch.max(logits, dim=-1, keepdim=True).values - logits_or_sinks_max = torch.maximum(logits_max, sink_expanded) - sinks = torch.exp(sink_expanded - logits_or_sinks_max) - unnormalized_scores = torch.exp(logits - logits_or_sinks_max) - normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + sinks - attn_weights = unnormalized_scores / normalizer - - # Mask out invalid positions - 
for i in range(batch_size): - actual_k_len = k_seqlens[i] - attn_weights[i, :, :, actual_k_len:] = 0.0 - - # Compute output: [b, q_heads, 1, max_seqlen] @ [b, q_heads, max_seqlen, head_size] - O_torch = torch.matmul(attn_weights.to(v_repeat.dtype), v_repeat) # [b, q_heads, 1, head_size] - - O_torch = O_torch.squeeze(2) # [b, q_heads, head_size] - - # Compute attention score pooling for S - attn_score_pooled = torch.max_pool2d( - attn_weights.squeeze(2), # [b, q_heads, max_seqlen] - kernel_size=(q_heads, block_size), - stride=(q_heads, block_size), - ceil_mode=True, - ).to(dtype=torch.float16) # [b, 1, ceil(max_seqlen/block_size)] - - print(f"O_triton shape: {O_triton.shape}") - print(f"O_tilelang shape: {O_tilelang.shape}") - print(f"O_torch shape: {O_torch.shape}") - print(f"S_triton shape: {S_triton.shape}") - print(f"S_tilelang shape: {S_tilelang.shape}") - print(f"attn_score_pooled shape: {attn_score_pooled.shape}") + # Compute reference + O_ref, attn_weights = ref_attention(q, k_padded, v_padded, k_seqlens.cuda(), q_heads, sink) + S_ref = torch.max_pool2d(attn_weights, kernel_size=(q_heads, block_size), stride=(q_heads, block_size), ceil_mode=True).to(dtype) # Compare results - max_diff_o = torch.max(torch.abs(O_triton - O_torch)) - max_diff_o_tl = torch.max(torch.abs(O_tilelang - O_torch)) - print(f"Max difference in O: {max_diff_o.item()}") - print(f"Max difference in O_tilelang: {max_diff_o_tl.item()}") - - max_diff_s = torch.max(torch.abs(S_triton - attn_score_pooled)) - max_diff_s_tl = torch.max( - torch.abs( - S_tilelang[:, :, : math.ceil(max_seqlen_k / block_size)] - attn_score_pooled[:, :, : math.ceil(max_seqlen_k / block_size)] - ) - ) - print(f"Max difference in S: {max_diff_s.item()}") - print(f"Max difference in S_tilelang: {max_diff_s_tl.item()}") - - assert torch.allclose(O_triton, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o.item()}" - assert torch.allclose(S_triton, attn_score_pooled, atol=1e-2, rtol=1e-2), f"Score mismatch: {max_diff_s.item()}" - assert torch.allclose(O_tilelang, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o_tl.item()}" - assert torch.allclose( - S_tilelang[:, :, : math.ceil(max_seqlen_k / block_size)], - attn_score_pooled[:, :, : math.ceil(max_seqlen_k / block_size)], - atol=1e-2, - rtol=1e-2, - ), f"Score mismatch: {max_diff_s_tl.item()}" - + num_blocks = math.ceil(actual_max / block_size) + assert torch.allclose(O_tl, O_ref, atol=1e-2, rtol=1e-2), f"Output mismatch: {(O_tl - O_ref).abs().max()}" + assert torch.allclose(S_tl[:, :, :num_blocks], S_ref[:, :, :num_blocks], atol=1e-2, rtol=1e-2), "Score mismatch" print("✅ All tests passed!") @@ -658,66 +255,25 @@ def speed_benchmark_decode_comparison(args): q_decode = torch.randn(batch_size, q_heads, head_size, device="cuda", dtype=dtype) k_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device="cuda", dtype=dtype) v_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device="cuda", dtype=dtype) - - softmax_scale = 1.0 / math.sqrt(head_size) - max_seqlen_k = int(k_seqlens.max()) - - # Generate sink values if needed - sink = None - if args.test_sink: - sink = torch.randn(q_heads, device="cuda", dtype=torch.float32) * 0.1 # Small sink values - print(" Using sink attention with sink values") - - print("Setup complete:") - print(f" Total K tokens: {total_k_tokens}") - print(f" Actual max K seq len: {max_seqlen_k}") + sink = torch.randn(q_heads, device="cuda", dtype=torch.float32) * 0.1 if args.test_sink else None if args.test_varlen: print(f" K sequence 
lengths: {k_seqlens.tolist()}") - # Warmup - num_tokens, q_h, head_size = q_decode.shape + _, q_h, head_size = q_decode.shape batch = cu_seqlens_k.size(0) - 1 k_h = k_varlen.size(1) tl_kernel = flashattn(batch, q_h, k_h, args.k_seqlen, cu_seqlens_k[-1].item(), head_size, args.test_sink) + def run_once(): + tl_kernel(q_decode, k_varlen, v_varlen, cu_seqlens_k, sink) + # Benchmark print("⚡ Benchmarking Tilelang kernel (100 iterations)...") tilelang_time = do_bench( - flash_attn_with_attn_pool_decode_tilelang, - q_decode, - k_varlen, - v_varlen, - cu_seqlens_k, - max_seqlen_k, - args.k_seqlen, - 1, - softmax_scale, - sink, - block_size, - False, - tl_kernel, + run_once, ) print(f"Average decode kernel time Tilelang: {tilelang_time:.3f} ms") - # Benchmark - print("⚡ Benchmarking Triton kernel (100 iterations)...") - triton_time = do_bench( - flash_attn_with_attn_pool_decode, - q_decode, - k_varlen, - v_varlen, - cu_seqlens_k, - max_seqlen_k, - args.k_seqlen, - 1, - softmax_scale, - sink, - block_size, - ) - print(f"Average decode kernel time Triton: {triton_time:.3f} ms") - - print(f"Speedup: {(triton_time / tilelang_time):.3f}") - def main(): args = argparse.Namespace( @@ -755,7 +311,9 @@ def main(): args.dtype = T.float16 args.num_split = 1 - if args.benchmark: - speed_benchmark_decode_comparison(args) - else: - test_varlen_decode_main(args) + # if args.benchmark: + # speed_benchmark_decode_comparison(args) + # else: + # test_varlen_decode_main(args) + + speed_benchmark_decode_comparison(args) diff --git a/examples/flash_decoding/example_gqa_decode_varlen_logits_paged.py b/examples/flash_decoding/example_gqa_decode_varlen_logits_paged.py deleted file mode 100644 index 87748512d..000000000 --- a/examples/flash_decoding/example_gqa_decode_varlen_logits_paged.py +++ /dev/null @@ -1,550 +0,0 @@ -import torch -import math -import argparse -import tilelang -import tilelang.language as T -from example_gqa_decode_varlen_logits import flash_attn_with_attn_pool_decode, repeat_kv, do_bench - -torch.manual_seed(0) - - -def get_configs(): - import itertools - - block_N = [64, 128] - block_H = [64] - num_split = [1] - num_stages = [1, 2, 3] - threads = [128] - _configs = list(itertools.product(block_N, block_H, num_split, num_stages, threads)) - - configs = [{"block_N": c[0], "block_H": c[1], "num_split": c[2], "num_stages": c[3], "threads": c[4]} for c in _configs] - return configs - - -# @autotune(configs=get_configs(), warmup=10, rep=10) -@tilelang.jit(out_idx=[-2, -1]) -def flashattn( - batch, - heads, - k_heads, - max_seqlen_kv, - total_seqlen_k, - dim, - has_sink, - page_block_size, - block_N=128, - block_H=64, - num_split=1, - num_stages=1, - threads=128, -): - scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) - shape_q = [batch, heads, dim] - shape_k = [total_seqlen_k, k_heads, dim] - shape_v = [total_seqlen_k, k_heads, dim] - shape_o = [batch, heads, dim] - shape_s = [batch, heads, math.ceil(max_seqlen_kv / block_N)] - dtype = T.float16 - accum_dtype = T.float32 - kv_group_num = heads // k_heads - assert page_block_size >= block_N and page_block_size % block_N == 0, ( - "page_block_size must be larger than block_N and a multiple of block_N" - ) - - valid_block_H = min(block_H, kv_group_num) - # TODO: check if max_seqlen_kv is correct for varlen case - - @T.prim_func - def flashattn_gqa_decode_no_split( - Q: T.Tensor(shape_q, dtype), - K: T.Tensor(shape_k, dtype), - V: T.Tensor(shape_v, dtype), - cu_seqlens_k: T.Tensor([batch + 1], T.int32), - s_aux: T.Tensor([heads], T.float32), - BLOCK_TABLE: 
T.Tensor([batch, math.ceil(max_seqlen_kv / page_block_size)], T.int32), - Output: T.Tensor(shape_o, dtype), - S: T.Tensor(shape_s, dtype), - ): - with T.Kernel(batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz): - Q_shared = T.alloc_shared([block_H, dim], dtype) - K_shared = T.alloc_shared([block_N, dim], dtype) - V_shared = T.alloc_shared([block_N, dim], dtype) - O_shared = T.alloc_shared([valid_block_H, dim], dtype) - acc_s = T.alloc_fragment([block_H, block_N], accum_dtype) - acc_s_cast = T.alloc_fragment([block_H, block_N], dtype) - acc_o = T.alloc_fragment([block_H, dim], accum_dtype) - scores_max = T.alloc_fragment([block_H], accum_dtype) - scores_max_prev = T.alloc_fragment([block_H], accum_dtype) - scores_scale = T.alloc_fragment([block_H], accum_dtype) - scores_sum = T.alloc_fragment([block_H], accum_dtype) - logsum = T.alloc_fragment([block_H], accum_dtype) - S_shared = T.alloc_shared([block_H, math.ceil(max_seqlen_kv / block_N)], dtype) - s_aux_shared = T.alloc_shared([block_H], T.float32) - - bid = bx - hid = by - cur_kv_head = hid // (kv_group_num // valid_block_H) - - cur_start_k = cu_seqlens_k[bid] - cur_end_k = cu_seqlens_k[bid + 1] - cur_seqlen_k = cur_end_k - cur_start_k - - T.copy(Q[bid, hid * valid_block_H : hid * valid_block_H + block_H, :], Q_shared) - T.fill(acc_o, 0) - T.fill(logsum, 0) - T.fill(scores_max, -T.infinity(accum_dtype)) - - # loop_range = T.ceildiv((seqlen_kv // num_split), block_N) - loop_range = T.ceildiv((cur_seqlen_k // num_split), block_N) - for k in T.Pipelined(loop_range, num_stages=num_stages): - k_start = BLOCK_TABLE[bid, (k * block_N) // page_block_size] * page_block_size + (k * block_N) % page_block_size - T.copy(K[cur_start_k + k_start : cur_start_k + k_start + block_N, cur_kv_head, :], K_shared) - T.clear(acc_s) - T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - for i, j in T.Parallel(block_H, block_N): - acc_s[i, j] = T.if_then_else(k * block_N + j < cur_seqlen_k, acc_s[i, j], -T.infinity(accum_dtype)) - T.copy(scores_max, scores_max_prev) - T.fill(scores_max, -T.infinity(accum_dtype)) - T.reduce_max(acc_s, scores_max, dim=1, clear=False) - # scores_max_prev is m_i - # scores_max is row_max->m_ij in triton - T.copy(scores_max, S_shared[:, k]) - # scores_scale is alpha in triton - for i in T.Parallel(block_H): - scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) - for i, j in T.Parallel(block_H, block_N): - acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) - T.reduce_sum(acc_s, scores_sum, dim=1) - # scores_sum is l_ij in triton - # logsum is l_i in triton - for i in T.Parallel(block_H): - logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] - T.copy(acc_s, acc_s_cast) - for i, j in T.Parallel(block_H, dim): - acc_o[i, j] *= scores_scale[i] - v_start = BLOCK_TABLE[bid, (k * block_N) // page_block_size] * page_block_size + (k * block_N) % page_block_size - T.copy(V[cur_start_k + v_start : cur_start_k + v_start + block_N, cur_kv_head, :], V_shared) - T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) - - if has_sink: - T.copy(s_aux[hid * valid_block_H : hid * valid_block_H + block_H], s_aux_shared) - for i in T.Parallel(block_H): - logsum[i] += s_aux_shared[i] - for i, j in T.Parallel(block_H, dim): - acc_o[i, j] /= logsum[i] - for h, k in T.Parallel(block_H, math.ceil(max_seqlen_kv / block_N)): - S_shared[h, k] = T.exp2((S_shared[h, k] - scores_max[h]) * scale) / logsum[h] - for i in T.Parallel(block_H): - logsum[i] = 
T.log2(logsum[i]) + scores_max[i] * scale - T.copy(acc_o[:valid_block_H, :], O_shared) - T.copy(O_shared, Output[bid, hid * valid_block_H : (hid + 1) * valid_block_H, :]) - T.copy(S_shared[:valid_block_H, :], S[bid, hid * valid_block_H : (hid + 1) * valid_block_H, :]) - - # TODO: split version - return flashattn_gqa_decode_no_split - - -def flash_attn_with_attn_pool_decode_tilelang( - Q: torch.Tensor, ## [tq = b, q_h, q_dim] - K: torch.Tensor, ## [tk, k_h, k_dim] - V: torch.Tensor, - cu_seqlens_k: torch.Tensor, - max_seqlen_k: int, - real_max_k_seqlen: int, - num_split: int, - softmax_scale: float, - s_aux: torch.Tensor = None, - block_size: int = 64, - use_per_kv_head_sparse_index: bool = False, - tl_kernel=None, - block_table: torch.Tensor = None, -): - num_tokens, q_h, head_size = Q.shape - batch = cu_seqlens_k.size(0) - 1 - k_h = K.size(1) - - assert Q.dim() == K.dim() == 3 - assert Q.size(2) == K.size(2) - assert cu_seqlens_k.dim() == 1 - assert head_size in {64, 128, 256} - assert Q.is_contiguous() - assert K.is_contiguous() - assert V.is_contiguous() - - gqa_group_size = q_h // k_h - - O_tl = torch.zeros_like(Q) - S_tl = torch.zeros((batch, q_h, math.ceil(real_max_k_seqlen / block_size)), dtype=Q.dtype, device=Q.device) - O_tl, S_tl = tl_kernel(Q, K, V, cu_seqlens_k, s_aux, block_table) - - if use_per_kv_head_sparse_index: - S_tl = torch.max_pool2d(S_tl, kernel_size=(gqa_group_size, 1), stride=(gqa_group_size, 1)) - else: - S_tl = torch.max_pool2d(S_tl, kernel_size=(q_h, 1), stride=(q_h, 1)) - - return O_tl, S_tl - - -def test_varlen_decode_main(args): - """Test decode kernel with variable sequence lengths""" - batch_size = args.batch_size - q_heads = args.q_heads - kv_heads = args.kv_heads - max_k_seqlen = args.k_seqlen # Use as max sequence length - real_max_k_seqlen = args.k_seqlen - head_size = args.head_size - block_size = args.block_size - page_block_size = args.page_block_size - dtype = torch.bfloat16 if args.dtype == T.bfloat16 else torch.float16 - - print(f"Testing decode kernel with variable sequence lengths (max_k_seqlen={max_k_seqlen})") - - # Generate sink values if needed - sink = None - if args.test_sink: - sink = torch.randn(q_heads, device="cuda", dtype=torch.float32) * 0.1 # Small sink values - print(f"Using sink attention with sink values: {sink}") - - # Generate variable length k sequences - k_seqlens = torch.randint(max_k_seqlen // 4, max_k_seqlen + 1, size=(batch_size,)) - print(f"k_seqlens: {k_seqlens}") - - # Generate cumulative sequence lengths for k - cu_seqlens_k = torch.zeros(batch_size + 1, device="cuda", dtype=torch.int32) - total_k_tokens = 0 - for i in range(batch_size): - cu_seqlens_k[i] = total_k_tokens - total_k_tokens += k_seqlens[i] - cu_seqlens_k[batch_size] = total_k_tokens - - print(f"cu_seqlens_k: {cu_seqlens_k}") - - # Generate tensors - Q is [batch_size, q_heads, head_size] for decode - q_decode = torch.randn(batch_size, q_heads, head_size, device="cuda", dtype=dtype) - k_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device="cuda", dtype=dtype) - v_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device="cuda", dtype=dtype) - - softmax_scale = 1.0 / math.sqrt(head_size) - max_seqlen_k = int(k_seqlens.max()) - - print(f"Actual max_seqlen_k: {max_seqlen_k}") - print(f"q_decode shape: {q_decode.shape}") - print(f"k_varlen shape: {k_varlen.shape}") - print(f"v_varlen shape: {v_varlen.shape}") - - num_tokens, q_h, head_size = q_decode.shape - batch = cu_seqlens_k.size(0) - 1 - k_h = k_varlen.size(1) - tl_kernel = 
flashattn(batch, q_h, k_h, args.k_seqlen, cu_seqlens_k[-1].item(), head_size, args.test_sink, page_block_size) - - block_table = torch.zeros(batch, math.ceil(real_max_k_seqlen / page_block_size), device="cuda", dtype=torch.int32) - block_cnt = 0 - for i in range(batch): - cur_seqlen = cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item() - for j in range(math.ceil(cur_seqlen / page_block_size)): - block_table[i, j] = block_cnt - block_cnt += 1 - block_cnt = 0 - - # Test our decode kernel - O_triton, S_triton = flash_attn_with_attn_pool_decode( - q_decode, - k_varlen, - v_varlen, - cu_seqlens_k, - max_seqlen_k, - real_max_k_seqlen, - args.num_split, - softmax_scale, - s_aux=sink, - block_size=block_size, - ) - O_tilelang, S_tilelang = flash_attn_with_attn_pool_decode_tilelang( - q_decode, - k_varlen, - v_varlen, - cu_seqlens_k, - max_seqlen_k, - real_max_k_seqlen, - args.num_split, - softmax_scale, - s_aux=sink, - block_size=block_size, - tl_kernel=tl_kernel, - block_table=block_table, - ) - for i in range(batch_size): - S_tilelang[i, :, math.ceil((cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item()) / block_size) :] = 0 - - # Create torch reference - pad tensors for comparison - k_padded_list = [] - v_padded_list = [] - - for i in range(batch_size): - actual_k_len = k_seqlens[i] - - # Extract and pad k, v for this batch - k_start = cu_seqlens_k[i] - k_end = cu_seqlens_k[i + 1] - - # Pad to max_seqlen_k - k_padded = torch.zeros(max_seqlen_k, kv_heads, head_size, device="cuda", dtype=dtype) - v_padded = torch.zeros(max_seqlen_k, kv_heads, head_size, device="cuda", dtype=dtype) - - k_padded[:actual_k_len] = k_varlen[k_start:k_end] - v_padded[:actual_k_len] = v_varlen[k_start:k_end] - - k_padded_list.append(k_padded) - v_padded_list.append(v_padded) - - # Stack to create batched tensors [b, max_seqlen, kv_heads, head_size] - k_padded_batched = torch.stack(k_padded_list, dim=0).transpose(1, 2) # [b, kv_heads, max_seqlen, head_size] - v_padded_batched = torch.stack(v_padded_list, dim=0).transpose(1, 2) # [b, kv_heads, max_seqlen, head_size] - - # Expand q to match kv heads: [b, q_heads, 1, head_size] - q_expanded = q_decode.unsqueeze(2) # [b, q_heads, 1, head_size] - - print(f"q_expanded shape: {q_expanded.shape}") - print(f"k_padded_batched shape: {k_padded_batched.shape}") - print(f"v_padded_batched shape: {v_padded_batched.shape}") - - # Compute torch reference - k_repeat = repeat_kv(k_padded_batched, q_heads // kv_heads) # [b, q_heads, max_seqlen, head_size] - v_repeat = repeat_kv(v_padded_batched, q_heads // kv_heads) # [b, q_heads, max_seqlen, head_size] - - if sink is None: - # Standard attention computation: [b, q_heads, 1, head_size] @ [b, q_heads, head_size, max_seqlen] - attn_score = torch.matmul(q_expanded, k_repeat.transpose(-2, -1)) * softmax_scale # [b, q_heads, 1, max_seqlen] - - # Apply sequence length masking - for i in range(batch_size): - actual_k_len = k_seqlens[i] - attn_score[i, :, :, actual_k_len:] = float("-inf") - - attn_weights = attn_score.softmax(dim=-1) # [b, q_heads, 1, max_seqlen] - - # Mask out invalid positions - for i in range(batch_size): - actual_k_len = k_seqlens[i] - attn_weights[i, :, :, actual_k_len:] = 0.0 - - # Compute output: [b, q_heads, 1, max_seqlen] @ [b, q_heads, max_seqlen, head_size] - O_torch = torch.matmul(attn_weights, v_repeat) # [b, q_heads, 1, head_size] - else: - # s_aux attention - logits = torch.matmul(q_expanded, k_repeat.transpose(-2, -1)) * softmax_scale # [b, q_heads, 1, max_seqlen] - - # Apply sequence length masking - for i in 
range(batch_size): - actual_k_len = k_seqlens[i] - logits[i, :, :, actual_k_len:] = float("-inf") - - sink_expanded = sink.view(1, q_heads, 1, 1) # [1, q_heads, 1, 1] - logits_max = torch.max(logits, dim=-1, keepdim=True).values - logits_or_sinks_max = torch.maximum(logits_max, sink_expanded) - sinks = torch.exp(sink_expanded - logits_or_sinks_max) - unnormalized_scores = torch.exp(logits - logits_or_sinks_max) - normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + sinks - attn_weights = unnormalized_scores / normalizer - - # Mask out invalid positions - for i in range(batch_size): - actual_k_len = k_seqlens[i] - attn_weights[i, :, :, actual_k_len:] = 0.0 - - # Compute output: [b, q_heads, 1, max_seqlen] @ [b, q_heads, max_seqlen, head_size] - O_torch = torch.matmul(attn_weights.to(v_repeat.dtype), v_repeat) # [b, q_heads, 1, head_size] - - O_torch = O_torch.squeeze(2) # [b, q_heads, head_size] - - # Compute attention score pooling for S - attn_score_pooled = torch.max_pool2d( - attn_weights.squeeze(2), # [b, q_heads, max_seqlen] - kernel_size=(q_heads, block_size), - stride=(q_heads, block_size), - ceil_mode=True, - ).to(dtype=torch.float16) # [b, 1, ceil(max_seqlen/block_size)] - - print(f"O_triton shape: {O_triton.shape}") - print(f"O_tilelang shape: {O_tilelang.shape}") - print(f"O_torch shape: {O_torch.shape}") - print(f"S_triton shape: {S_triton.shape}") - print(f"S_tilelang shape: {S_tilelang.shape}") - print(f"attn_score_pooled shape: {attn_score_pooled.shape}") - - # Compare results - max_diff_o = torch.max(torch.abs(O_triton - O_torch)) - max_diff_o_tl = torch.max(torch.abs(O_tilelang - O_torch)) - print(f"Max difference in O: {max_diff_o.item()}") - print(f"Max difference in O_tilelang: {max_diff_o_tl.item()}") - - max_diff_s = torch.max(torch.abs(S_triton - attn_score_pooled)) - max_diff_s_tl = torch.max(torch.abs(S_tilelang[:, :, : math.ceil(max_seqlen_k / block_size)] - attn_score_pooled)) - print(f"Max difference in S: {max_diff_s.item()}") - print(f"Max difference in S_tilelang: {max_diff_s_tl.item()}") - - assert torch.allclose(O_triton, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o.item()}" - assert torch.allclose(S_triton, attn_score_pooled, atol=1e-2, rtol=1e-2), f"Score mismatch: {max_diff_s.item()}" - assert torch.allclose(O_tilelang, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o_tl.item()}" - assert torch.allclose(S_tilelang[:, :, : math.ceil(max_seqlen_k / block_size)], attn_score_pooled, atol=1e-2, rtol=1e-2), ( - f"Score mismatch: {max_diff_s_tl.item()}" - ) - - print("✅ All tests passed!") - - -def speed_benchmark_decode_comparison(args): - """Speed benchmark for decode kernel""" - batch_size = args.batch_size - q_heads = args.q_heads - kv_heads = args.kv_heads - max_k_seqlen = args.k_seqlen - real_max_k_seqlen = args.k_seqlen - head_size = args.head_size - block_size = args.block_size - page_block_size = args.page_block_size - dtype = torch.bfloat16 if args.dtype == T.bfloat16 else torch.float16 - - print("\n=== Decode Speed Benchmark Comparison ===") - print("Configuration:") - print(f" Batch size: {batch_size}") - print(f" Q heads: {q_heads}, KV heads: {kv_heads}") - print(f" Max K sequence length: {max_k_seqlen}") - print(f" Head size: {head_size}") - print(f" Block size: {block_size}") - print(f" Data type: {dtype}") - print(f" Variable lengths: {args.test_varlen}") - print(f" s_aux attention: {args.test_sink}") - print() - - # Generate input data - if args.test_varlen: - k_seqlens = torch.randint(max_k_seqlen // 4, 
max_k_seqlen + 1, size=(batch_size,)) - else: - k_seqlens = torch.full((batch_size,), max_k_seqlen, dtype=int) - - # Generate cumulative sequence lengths for k - cu_seqlens_k = torch.zeros(batch_size + 1, device="cuda", dtype=torch.int32) - total_k_tokens = 0 - for i in range(batch_size): - cu_seqlens_k[i] = total_k_tokens - total_k_tokens += k_seqlens[i] - cu_seqlens_k[batch_size] = total_k_tokens - - # Generate tensors - q_decode = torch.randn(batch_size, q_heads, head_size, device="cuda", dtype=dtype) - k_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device="cuda", dtype=dtype) - v_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device="cuda", dtype=dtype) - - softmax_scale = 1.0 / math.sqrt(head_size) - max_seqlen_k = int(k_seqlens.max()) - - # Generate sink values if needed - sink = None - if args.test_sink: - sink = torch.randn(q_heads, device="cuda", dtype=torch.float32) * 0.1 # Small sink values - print(" Using sink attention with sink values") - - print("Setup complete:") - print(f" Total K tokens: {total_k_tokens}") - print(f" Actual max K seq len: {max_seqlen_k}") - if args.test_varlen: - print(f" K sequence lengths: {k_seqlens.tolist()}") - - # Warmup - num_tokens, q_h, head_size = q_decode.shape - batch = cu_seqlens_k.size(0) - 1 - k_h = k_varlen.size(1) - tl_kernel = flashattn(batch, q_h, k_h, args.k_seqlen, cu_seqlens_k[-1].item(), head_size, args.test_sink, page_block_size) - - block_table = torch.zeros(batch, math.ceil(real_max_k_seqlen / page_block_size), device="cuda", dtype=torch.int32) - block_cnt = 0 - for i in range(batch): - cur_seqlen = cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item() - for j in range(math.ceil(cur_seqlen / page_block_size)): - block_table[i, j] = block_cnt - block_cnt += 1 - block_cnt = 0 - - # Benchmark - print("⚡ Benchmarking Tilelang kernel (100 iterations)...") - tilelang_time = do_bench( - flash_attn_with_attn_pool_decode_tilelang, - q_decode, - k_varlen, - v_varlen, - cu_seqlens_k, - max_seqlen_k, - args.k_seqlen, - 1, - softmax_scale, - sink, - block_size, - False, - tl_kernel, - block_table, - ) - print(f"Average decode kernel time Tilelang: {tilelang_time:.3f} ms") - - # Benchmark - print("⚡ Benchmarking Triton kernel (100 iterations)...") - triton_time = do_bench( - flash_attn_with_attn_pool_decode, - q_decode, - k_varlen, - v_varlen, - cu_seqlens_k, - max_seqlen_k, - args.k_seqlen, - 1, - softmax_scale, - sink, - block_size, - ) - print(f"Average decode kernel time Triton: {triton_time:.3f} ms") - print(f"Speedup: {(triton_time / tilelang_time):.3f}") - - -def main(): - args = argparse.Namespace( - batch_size=1, - q_heads=32, - kv_heads=8, - k_seqlen=8192, - head_size=128, - block_size=128, - dtype=T.float16, - ) - args.test_sink = True - args.test_varlen = True - args.dtype = T.float16 - args.num_split = 1 - args.page_block_size = 128 - test_varlen_decode_main(args) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Flash Attention Decode with Attention Pooling") - parser.add_argument("--batch_size", type=int, default=1, help="Batch size") - parser.add_argument("--q_heads", type=int, default=32, help="Number of query heads") - parser.add_argument("--kv_heads", type=int, default=8, help="Number of key-value heads") - parser.add_argument("--k_seqlen", type=int, default=8192, help="Key sequence length") - parser.add_argument("--head_size", type=int, default=128, choices=[64, 128, 256], help="Head dimension") - parser.add_argument("--block_size", type=int, default=128, help="Block size 
for computation") - parser.add_argument("--dtype", type=str, default=T.bfloat16, choices=[T.float16, T.bfloat16], help="Data type") - parser.add_argument("--test_varlen", action="store_true", help="Test with truly variable sequence lengths") - parser.add_argument("--test_sink", action="store_true", help="Test with sink attention mechanism") - parser.add_argument("--benchmark", action="store_true", help="Run speed benchmark") - parser.add_argument("--num_split", type=int, default=1, choices=[1, 16], help="Number of splits") - parser.add_argument("--page_block_size", type=int, default=128, help="Page block size") - args = parser.parse_args() - args.test_sink = True - args.test_varlen = True - args.dtype = T.float16 - args.num_split = 1 - - if args.benchmark: - speed_benchmark_decode_comparison(args) - else: - test_varlen_decode_main(args) diff --git a/examples/flash_decoding/test_example_flash_decoding.py b/examples/flash_decoding/test_example_flash_decoding.py index a02a92097..2cbcd8404 100644 --- a/examples/flash_decoding/test_example_flash_decoding.py +++ b/examples/flash_decoding/test_example_flash_decoding.py @@ -3,10 +3,8 @@ import example_gqa_decode import example_mha_inference import example_gqa_decode_varlen_logits -import example_gqa_decode_varlen_logits_paged -# TODO(lei): fix the correctness of gqa decode on sm90 @tilelang.testing.requires_cuda @tilelang.testing.requires_cuda_compute_version_le(8, 9) def test_example_example_gqa_decode(): @@ -21,9 +19,5 @@ def test_example_example_gqa_decode_varlen_logits(): example_gqa_decode_varlen_logits.main() -def test_example_example_gqa_decode_varlen_logits_paged(): - example_gqa_decode_varlen_logits_paged.main() - - if __name__ == "__main__": tilelang.testing.main() diff --git a/src/transform/common/constr_visitor.h b/src/transform/common/constr_visitor.h index f906360ee..af7ae36d6 100644 --- a/src/transform/common/constr_visitor.h +++ b/src/transform/common/constr_visitor.h @@ -244,6 +244,9 @@ struct ConstrVisitor : public tir::StmtExprVisitor { Base::VisitStmt(op->body); } } + ConstrSet GetConstrSet() const { + return ConstrSet{.constrs_ = constr_stack_}; + } std::vector constr_stack_; }; } // namespace tvm::tl diff --git a/src/transform/thread_storage_sync.cc b/src/transform/thread_storage_sync.cc index 4cc4980d8..d78a78f75 100644 --- a/src/transform/thread_storage_sync.cc +++ b/src/transform/thread_storage_sync.cc @@ -360,12 +360,35 @@ class ThreadPartialSyncRewriter : public IRMutatorWithAnalyzer { return {barrier_id, thread_count}; } + /*! + * \brief Calculate the number of threads that satisfy current constraints. + * + * This method uses Z3's model enumeration (AllSAT) to precisely count + * how many thread IDs satisfy all current constraints. This is essential + * for cases like `if (threadIdx.x % 4 == 0)` where const_int_bound only + * gives us the range [0, 127] but the actual number of satisfying threads + * is 32 (i.e., 0, 4, 8, ..., 124). + * + * Falls back to range-based calculation if Z3 enumeration fails or returns + * an invalid result. + */ size_t CalculateThreadExtent(const IterVar &iv, const arith::ConstIntBound &bound) { if (!analyzer_->const_int_bound.IsBound(iv->var)) { return 1; } - return bound->max_value - bound->min_value + 1; + auto extent = *as_const_int(iv->dom->extent); + // Always use Z3 enumeration to count satisfying values. + // This handles constraints like `tx % 4 == 0` that const_int_bound cannot + // detect. Z3 enumeration will return the exact count of satisfying values. 
+    int64_t z3_count =
+        analyzer_->z3_prover.CountSatisfyingValues(iv->var, extent);
+    if (z3_count > 0) {
+      return static_cast<size_t>(z3_count);
+    }
+
+    // Fallback to range-based calculation if Z3 enumeration failed
+    return static_cast<size_t>(bound->max_value - bound->min_value + 1);
   }
   Stmt VisitStmt_(const AttrStmtNode *op) final {
@@ -413,11 +436,90 @@ class ThreadPartialSyncRewriter : public IRMutatorWithAnalyzer {
   std::unordered_map thread_count_map_;
 };
+/*!
+ * \brief Check if an if-condition depends on runtime-variable values.
+ *
+ * For sync hoisting decisions, we distinguish two types of non-uniform
+ * conditions:
+ *
+ * 1. Conditions that only depend on threadIdx (e.g., `threadIdx.x >= 512`):
+ *    - The number of threads entering the if can be determined at compile time
+ *    - ThreadPartialSyncRewriter can handle this by computing the thread count
+ *    - No need to hoist sync
+ *
+ * 2. Conditions that depend on runtime values (e.g., `shared_mem[tx] != -1`):
+ *    - Cannot determine at compile time how many threads will enter
+ *    - Must hoist sync to before the if to avoid potential deadlock
+ *
+ * This checker identifies case (2) - conditions that depend on runtime values.
+ */
+class RuntimeDependentConditionChecker : public IRMutatorWithAnalyzer {
+public:
+  explicit RuntimeDependentConditionChecker(arith::Analyzer *analyzer,
+                                            int warp_size = 32)
+      : IRMutatorWithAnalyzer(analyzer), warp_size_(warp_size) {}
+
+  /*!
+   * \brief Check if expression depends on runtime-variable values.
+   * \return true if the expression depends on values that cannot be determined
+   *         at compile time (e.g., shared memory loads), false if it only
+   *         depends on compile-time known values (constants, threadIdx,
+   *         blockIdx).
+   */
+  bool DependsOnRuntimeValue(const PrimExpr &expr, const IterVar &iv) {
+    depends_on_runtime_ = false;
+    this->VisitExpr(expr);
+    auto extent_opt = as_const_int(iv->dom->extent);
+    ICHECK(extent_opt != nullptr)
+        << "DependsOnRuntimeValue: thread extent must be a "
+           "constant, but got: "
+        << iv->dom->extent;
+    int64_t thread_extent = *extent_opt;
+    {
+      With<arith::ConstraintContext> ctx(analyzer_, expr);
+      auto count = analyzer_->z3_prover.CountSatisfyingValues(
+          iv->var, thread_extent, /*min_consecutive=*/warp_size_);
+      if (count < 0) {
+        // failed to count satisfying values, return true
+        depends_on_runtime_ = true;
+      }
+    }
+    return depends_on_runtime_;
+  }
+
+private:
+  PrimExpr VisitExpr_(const BufferLoadNode *op) final {
+    // Any buffer load introduces runtime dependency
+    // (we don't know the buffer contents at compile time)
+    depends_on_runtime_ = true;
+    return IRMutatorWithAnalyzer::VisitExpr_(op);
+  }
+
+  PrimExpr VisitExpr_(const CallNode *op) final {
+    // Check tvm_access_ptr and address_of - if used in condition, it's reading
+    // memory
+    if (op->op.same_as(builtin::tvm_access_ptr()) ||
+        op->op.same_as(builtin::address_of())) {
+      depends_on_runtime_ = true;
+      return IRMutatorWithAnalyzer::VisitExpr_(op);
+    }
+    // Other calls might also introduce runtime dependency
+    // but we'll be conservative and check children
+    return IRMutatorWithAnalyzer::VisitExpr_(op);
+  }
+
+private:
+  bool depends_on_runtime_{false};
+  int warp_size_;
+};
+
 struct TileLangThreadSyncPlanner : public ConstrVisitor {
-  explicit TileLangThreadSyncPlanner(StorageScope sync_scope)
-      : sync_scope_(std::move(sync_scope)) {
+  explicit TileLangThreadSyncPlanner(StorageScope sync_scope,
+                                     int warp_size = 32)
+      : sync_scope_(std::move(sync_scope)), warp_size_(warp_size) {
     scope_.push_back(std::vector<StmtEntry>());
   }
+  /*!
\brief Storage access type */ enum AccessType : uint8_t { kRead, @@ -449,8 +551,6 @@ struct TileLangThreadSyncPlanner : public ConstrVisitor { AccessType type; /*! \brief The storage scope */ StorageScope scope; - /*! \brief Whether the access is double buffer write */ - bool double_buffer_write = false; /*! \brief Whether the access is pointer access */ bool is_pointer_access = false; /*! \brief Whether this access originates from an async copy context @@ -470,6 +570,16 @@ struct TileLangThreadSyncPlanner : public ConstrVisitor { StorageScope GetScope(Var buffer_var) const { return StorageScope::Create(GetPtrStorageScope(std::move(buffer_var))); } + IterVar GetThreadVar(const std::string &tag) const { + for (const auto &iv : env_threads_) { + if (iv->thread_tag == tag) { + return iv; + } + } + LOG(FATAL) << "Thread variable " << tag << " not found"; + return IterVar(); + } + void VisitExpr_(const BufferLoadNode *op) final { Var buf = op->buffer->data; buffer_data_to_buffer_.Set(tvm::ffi::GetRef(buf.get()), op->buffer); @@ -561,25 +671,7 @@ struct TileLangThreadSyncPlanner : public ConstrVisitor { ConstrVisitor::VisitStmt_(op); } void VisitStmt_(const AttrStmtNode *op) override { - if (op->attr_key == tvm::tir::attr::double_buffer_write) { - ICHECK(double_buffer_write_ == nullptr); - double_buffer_write_ = op->node.as(); - scope_.push_back(std::vector()); - ConstrVisitor::VisitStmt_(op); - StmtEntry s; - s.stmt = op; - s.access = Summarize(std::move(scope_.back()), nullptr); - scope_.pop_back(); - if (!s.access.empty()) { - for (AccessEntry &e : s.access) { - if (e.type == kWrite && e.buffer.get() == double_buffer_write_) { - e.double_buffer_write = true; - } - } - scope_.back().emplace_back(std::move(s)); - } - double_buffer_write_ = nullptr; - } else if (op->attr_key == tvm::tir::attr::coproc_scope) { + if (op->attr_key == tvm::tir::attr::coproc_scope) { IterVar iv = Downcast(op->node); env_threads_.push_back(iv); ConstrVisitor::VisitStmt_(op); @@ -588,9 +680,6 @@ struct TileLangThreadSyncPlanner : public ConstrVisitor { IterVar iv = Downcast(op->node); env_threads_.push_back(iv); ICHECK_NE(iv->thread_tag.length(), 0U); - // analyzer_.Bind( - // iv->var, Range::FromMinExtent(IntImm(op->value->dtype, 0), - // op->value)); if (!in_device_env_) { in_device_env_ = true; @@ -604,10 +693,6 @@ struct TileLangThreadSyncPlanner : public ConstrVisitor { ConstrVisitor::VisitStmt_(op); } env_threads_.pop_back(); - } else if (op->attr_key == tvm::tir::attr::hand_threaded) { - // skip this pass on blocks that were hand_threaded - // this avoids control flow and read/write conflicts - // between hand-threaded kernels and automatic threading } else { ConstrVisitor::VisitStmt_(op); } @@ -647,26 +732,22 @@ struct TileLangThreadSyncPlanner : public ConstrVisitor { * * Visits the if-then-else node's condition and both branches to summarize * buffer reads, writes, and synchronization events under the condition's - * constraints. If the condition is not thread-invariant, increments an - * internal condition counter for the duration of processing. + * constraints. * - * Behavior and side effects: - * - Evaluates the condition expression (using ExtractRealCondition) and - * applies it as a constraint while summarizing the then-branch. - * - For the else-branch (when present), applies the negated, - * analyzer-simplified condition - * (analyzer_.rewrite_simplify(Not(real_condition))) as the constraint. 
- * - Accumulates summarized StmtEntry access information for the then/else
- * branches and appends a combined StmtEntry for the IfThenElseNode into the
- * current scope.
- * - Temporarily toggles allow_append_ and clears curr_stmt_.access during
- * condition evaluation and branch summarization.
- * - Modifies internal state: scope_ (push/pop of temporary branch scopes),
- * curr_stmt_.access, and condition_counter_ (incremented/decremented when the
- * condition is not thread-invariant).
+ * IMPORTANT: If syncs are inserted inside an if-statement with a non-uniform
+ * condition (i.e., the condition depends on threadIdx), we must hoist the
+ * sync to before the if-statement. Otherwise, only some threads will reach
+ * the sync point, causing a deadlock.
  */
  void VisitStmt_(const IfThenElseNode *op) final {
    StmtEntry s;
+    // Track syncs inserted before visiting the if body
+    std::unordered_set<const Object *> syncs_before_then;
+    std::unordered_set<const Object *> syncs_before_else;
+    for (const auto &sync : syncs_inserted_) {
+      syncs_before_then.insert(sync);
+    }
+
    {
      auto guard = MakeGuard(op->condition);
      allow_append_ = true;
@@ -694,6 +775,12 @@ struct TileLangThreadSyncPlanner : public ConstrVisitor {
        cond_access.end());
      }
    }
+
+    // Track syncs inserted after visiting then branch
+    for (const auto &sync : syncs_inserted_) {
+      syncs_before_else.insert(sync);
+    }
+
    if (op->else_case) {
      auto guard = MakeGuard(tir::Not(op->condition));
      scope_.push_back(std::vector<StmtEntry>());
@@ -705,6 +792,59 @@ struct TileLangThreadSyncPlanner : public ConstrVisitor {
      s.access.insert(s.access.end(), v.begin(), v.end());
    }
+    // Check if any syncs were inserted inside the if-then-else
+    std::vector<const Object *> syncs_in_then;
+    std::vector<const Object *> syncs_in_else;
+
+    for (const auto &sync : syncs_inserted_) {
+      if (syncs_before_then.count(sync) == 0 &&
+          syncs_before_else.count(sync) != 0) {
+        // Sync was inserted during then branch processing
+        syncs_in_then.push_back(sync);
+      } else if (syncs_before_else.count(sync) == 0) {
+        // Sync was inserted during else branch processing
+        syncs_in_else.push_back(sync);
+      }
+    }
+
+    bool has_syncs_inside = !syncs_in_then.empty() || !syncs_in_else.empty();
+
+    if (has_syncs_inside) {
+      // Check if the condition depends on runtime values (e.g., shared memory
+      // loads). If so, we cannot determine at compile time how many threads
+      // will enter the if, so we must hoist the sync to before the if to avoid
+      // potential deadlock.
+      //
+      // If the condition only depends on threadIdx (e.g., `threadIdx.x >=
+      // 512`), we use Z3 to check if the thread count is a multiple of 32.
+      // If not, ThreadPartialSyncRewriter cannot handle it properly, so we
+      // must also hoist the sync.
+      arith::Analyzer analyzer;
+      ConstrSet constr_set = GetConstrSet();
+      constr_set.Populate(analyzer);
+      RuntimeDependentConditionChecker checker(&analyzer, warp_size_);
+      IterVar tx = GetThreadVar("threadIdx.x");
+      bool depends_on_runtime =
+          checker.DependsOnRuntimeValue(op->condition, tx);
+
+      if (depends_on_runtime) {
+        // Condition depends on runtime values - must hoist sync
+        LOG(WARNING)
+            << "[ThreadSync] Hoisting sync from inside if to before if. 
" + << "Condition depends on runtime value: " << op->condition; + for (const auto &sync : syncs_in_then) { + syncs_inserted_.erase(sync); + } + for (const auto &sync : syncs_in_else) { + syncs_inserted_.erase(sync); + } + + // Insert sync before the if-statement itself + insert_syncs(op); + } + } + scope_.back().emplace_back(std::move(s)); } @@ -911,13 +1051,15 @@ struct TileLangThreadSyncPlanner : public ConstrVisitor { for (const AccessEntry &acc : s.access) { if (acc.type == kRead) { - if (FindConflict(writes, acc, false)) { + // Same-iteration conflict: loop=nullptr + if (FindConflict(writes, acc, nullptr)) { sync_before_stmt = true; break; } } else if (acc.type == kWrite) { - if (FindConflict(reads, acc, false) || - FindConflict(writes, acc, false)) { + // Same-iteration conflict: loop=nullptr + if (FindConflict(reads, acc, nullptr) || + FindConflict(writes, acc, nullptr)) { sync_before_stmt = true; break; } @@ -964,12 +1106,22 @@ struct TileLangThreadSyncPlanner : public ConstrVisitor { if (has_read_in_scope) break; } - // If there is a loop-carried dependency, insert a single sync - // before the loop rather than hoisting a sync into the loop body. - // This reduces redundant per-iteration synchronizations for cases - // where each iteration touches disjoint regions (e.g., stmatrix - // writes to shared.dyn) and only a global ordering before/after the - // loop is required. + // Loop-carried dependency analysis using symbolic iteration shift. + // We compare accesses at iteration i (end of loop, stored in + // reads/writes) with accesses at iteration i+1 (beginning of next + // iteration). By substituting loop_var -> loop_var + step in the "next + // iteration" indices, we can precisely determine if there's a true + // dependency. + // + // Examples: + // - A[i] write, A[i] read: No loop-carry (same iteration access) + // - A[i] write, A[i+1] read: After shift, comparing A[i] vs A[i+1], + // disjoint + // - A[i] write, A[i-1] read: After shift, comparing A[i] vs A[i], + // conflict! + // - A[i%2] write, A[i%2] read: After shift, comparing A[i%2] vs + // A[(i+1)%2], + // which are disjoint for modulo buffering for (size_t i = 0; i < seq.size(); ++i) { const StmtEntry &s = seq[i]; if (syncs_inserted_.count(s.stmt) != 0) @@ -979,13 +1131,15 @@ struct TileLangThreadSyncPlanner : public ConstrVisitor { bool need_loop_sync = false; for (const AccessEntry &acc : s.access) { if (acc.type == kRead) { - if (FindConflict(writes, acc, true)) { + // Loop-carry conflict: pass loop for iteration shift analysis + if (FindConflict(writes, acc, loop)) { need_loop_sync = true; break; } } else if (acc.type == kWrite) { - if (FindConflict(reads, acc, true) || - FindConflict(writes, acc, true)) { + // Loop-carry conflict: pass loop for iteration shift analysis + if (FindConflict(reads, acc, loop) || + FindConflict(writes, acc, loop)) { need_loop_sync = true; break; } @@ -1045,12 +1199,6 @@ struct TileLangThreadSyncPlanner : public ConstrVisitor { } } head.insert(head.end(), tail.begin(), tail.end()); - if (loop != nullptr) { - // clear double buffer flag after a loop is finished. - for (AccessEntry &e : head) { - e.double_buffer_write = false; - } - } return head; } // The syncs inserted before each statement @@ -1063,14 +1211,13 @@ struct TileLangThreadSyncPlanner : public ConstrVisitor { } /*! \return whether we are in device environment. */ bool in_device_env() const { return in_device_env_; } + // whether access appending is enabled. 
  bool allow_append_{false};
  // Whether we are in device environment
  bool in_device_env_{false};
  // Nesting depth of tma_load/tma_load_im2col calls
  int tma_depth_{0};
-  // The current double buffer write scope.
-  const VarNode *double_buffer_write_{nullptr};
  // the current free stmt entry.
  StmtEntry curr_stmt_;
  // The involving threads
@@ -1079,6 +1226,9 @@ struct TileLangThreadSyncPlanner : public ConstrVisitor {
  Map<Var, Buffer> buffer_data_to_buffer_;
  // synchronization scope
  StorageScope sync_scope_;
+  // warp size from target
+  int warp_size_;
+
  void insert_syncs(const Object *obj) {
    if (syncs_inserted_.count(obj))
      return;
@@ -1210,16 +1360,33 @@ struct TileLangThreadSyncPlanner : public ConstrVisitor {
    }
    output << " Flags: ";
-    output << "double_buffer_write="
-           << (access.double_buffer_write ? "true" : "false");
-    output << ", is_pointer_access="
+    output << "is_pointer_access="
           << (access.is_pointer_access ? "true" : "false");
    output << ", is_async_copy=" << (access.is_async_copy ? "true" : "false");
    LOG(WARNING) << output.str();
  }
+  /*!
+   * \brief Check if two access entries conflict, considering loop-carried
+   * dependencies.
+   *
+   * For loop-carry analysis, we use symbolic iteration shift: instead of
+   * treating loop_carry as a simple flag, we substitute loop_var with
+   * loop_var + step in the "next iteration" access indices and check if they
+   * overlap with the "current iteration" access indices.
+   *
+   * This approach can prove that accesses like A[i] and A[i+1] are disjoint
+   * (no loop-carry dependency), while correctly detecting dependencies like
+   * A[i] and A[i-1] (loop-carry dependency with distance 1).
+   *
+   * \param prev The access entry from the previous/current iteration
+   * \param curr The access entry to check against
+   * \param loop The loop node for loop-carry analysis, nullptr for
+   * same-iteration
+   * \return true if the accesses conflict and need synchronization
+   */
  bool FindConflict(const AccessEntry &prev, const AccessEntry &curr,
-                    bool loop_carry) {
+                    const ForNode *loop) {
    // Special case: ignore conflicts between async-copy writes (e.g., TMA
    // loads into shared memory). Multiple async writes do not require
    // interspersed barriers among themselves. We still respect conflicts with
@@ -1233,12 +1400,6 @@ struct TileLangThreadSyncPlanner : public ConstrVisitor {
      return false;
    }
-    // Assumes no race between threads
-    // Same index value means no conflicts
-    // TODO(tqchen) more standard set based testing.
-    bool has_same_index = true;
-    bool range_is_overlap = true;
-
    if (prev.buffer_indices.size() != curr.buffer_indices.size()) {
      // They are not the same indices, should be conflict.
      return true;
@@ -1246,7 +1407,7 @@ struct TileLangThreadSyncPlanner : public ConstrVisitor {
    if (prev.is_pointer_access || curr.is_pointer_access) {
      // For accesses created via tvm_access_ptr we may still be able to prove
-      // disjointness using their byte ranges.  If both sides expose a touched
+      // disjointness using their byte ranges. If both sides expose a touched
      // interval and we can show they don't overlap, skip the conflict.
      if (prev.is_pointer_access && curr.is_pointer_access &&
          PointerAccessIsDisjoint(prev, curr)) {
@@ -1257,35 +1418,39 @@ struct TileLangThreadSyncPlanner : public ConstrVisitor {
      return true;
    }
-    for (size_t i = 0; i < prev.buffer_indices.size(); i++) {
-      auto prev_dtype = prev.dtype;
-      auto curr_dtype = curr.dtype;
+    // Build substitution map for loop-carry analysis
+    // For loop-carry, we compare: Iter(i) vs Iter(i+step)
+    // prev represents access at iteration i (end of loop body)
+    // curr represents access at iteration i+step (beginning of next iteration)
+    ffi::Map<Var, PrimExpr> loop_shift_sub;
+    if (loop != nullptr) {
+      // Get loop step, default to 1 if not specified
+      PrimExpr step = make_const(loop->loop_var.dtype(), 1);
+      // Substitute loop_var -> loop_var + step for the "next iteration"
+      loop_shift_sub.Set(loop->loop_var, loop->loop_var + step);
+    }

+    // Check if indices are the same (considering loop shift)
+    bool has_same_index = true;
+    for (size_t i = 0; i < prev.buffer_indices.size(); i++) {
      const auto &prev_indice = prev.buffer_indices[i];
-      const auto &curr_indice = curr.buffer_indices[i];
+      PrimExpr curr_indice = curr.buffer_indices[i];
+
+      // For loop-carry, shift the curr index to represent next iteration
+      if (loop != nullptr) {
+        curr_indice = Substitute(curr_indice, loop_shift_sub);
+      }
      if (!ExprDeepEqual()(prev_indice, curr_indice)) {
        has_same_index = false;
        break;
      }
    }
+
    if (has_same_index) {
      // Use Z3 to check if prev and curr constraints are equivalent.
      // If equivalent, the same set of threads execute both accesses, so no
      // sync is needed.
-      //
-      // Formally, let P(t) denote the predicate for prev's constraint set and
-      // C(t) denote the predicate for curr's constraint set, where t represents
-      // the thread indices (threadIdx.x, threadIdx.y, threadIdx.z).
-      //
-      // We check bidirectional implication:
-      // 1. P(t) => C(t): Every thread executing prev also executes curr
-      // 2. C(t) => P(t): Every thread executing curr also executes prev
-      //
-      // If both hold, then P(t) <=> C(t), meaning the exact same set of threads
-      // execute both accesses. Combined with has_same_index (same buffer index
-      // expression), this guarantees each thread only accesses locations it
-      // wrote itself, eliminating cross-thread conflicts.
      PrimExpr prev_constr = prev.cset.ToConjunction();
      PrimExpr curr_constr = curr.cset.ToConjunction();
@@ -1295,6 +1460,16 @@ struct TileLangThreadSyncPlanner : public ConstrVisitor {
          analyzer.Bind(iv->var, iv->dom);
        }
      }
+      // Add loop variable constraint for loop-carry analysis
+      if (loop != nullptr) {
+        // For loop-carry analysis, we compare iteration i with iteration i+1.
+        // Since i+1 must be a valid iteration, i can only range from min to
+        // min+extent-2 (i.e., extent-1 valid pairs instead of extent).
+        PrimExpr adjusted_extent =
+            loop->extent - make_const(loop->extent.dtype(), 1);
+        analyzer.Bind(loop->loop_var,
+                      Range::FromMinExtent(loop->min, adjusted_extent));
+      }
      // Check P => C: ¬P ∨ C
      bool prev_implies_curr = analyzer.z3_prover.CanProve(
@@ -1312,22 +1487,39 @@ struct TileLangThreadSyncPlanner : public ConstrVisitor {
      }
    }
+    // Indices are different, need to check if they can overlap
+    bool range_is_overlap = true;
+
    for (size_t i = 0; i < prev.buffer_indices.size(); i++) {
      auto prev_dtype = prev.dtype;
      auto curr_dtype = curr.dtype;
      const auto &prev_indice = prev.buffer_indices[i];
-      const auto &curr_indice = curr.buffer_indices[i];
+      PrimExpr curr_indice = curr.buffer_indices[i];
+
+      // For loop-carry, shift the curr index to represent next iteration
+      if (loop != nullptr) {
+        curr_indice = Substitute(curr_indice, loop_shift_sub);
+      }
      PrimExpr prev_indice_bytes = prev_indice * prev_dtype.bytes();
      PrimExpr curr_indice_bytes = curr_indice * curr_dtype.bytes();
-      has_same_index = false;
-
      ConstrSet prev_cset{prev.cset};
      ConstrSet curr_cset{curr.cset};
      arith::Analyzer analyzer;
+      // Add loop variable constraint for loop-carry analysis
+      if (loop != nullptr) {
+        // For loop-carry analysis, we compare iteration i with iteration i+1.
+        // Since i+1 must be a valid iteration, i can only range from min to
+        // min+extent-2 (i.e., extent-1 valid pairs instead of extent).
+        PrimExpr adjusted_extent =
+            loop->extent - make_const(loop->extent.dtype(), 1);
+        analyzer.Bind(loop->loop_var,
+                      Range::FromMinExtent(loop->min, adjusted_extent));
+      }
+
      struct ThreadVarInfo {
        const char *name_prev;
        const char *name_curr;
@@ -1360,17 +1552,13 @@ struct TileLangThreadSyncPlanner : public ConstrVisitor {
          analyzer.Simplify(Substitute(curr_indice_bytes, curr_sub));
      // Handle Ramp expressions by creating a new index variable
-      // Check if prev_indice_bytes is a Ramp expression
      if (const RampNode *prev_ramp = prev_indice_bytes.as<RampNode>()) {
-        // Create index variable for prev Ramp
        Var prev_idx("prev_idx", DataType::Int(32));
        analyzer.Bind(prev_idx, Range::FromMinExtent(0, prev_ramp->lanes));
        prev_indice_bytes = prev_ramp->base + prev_idx * prev_ramp->stride;
      }
-      // Check if curr_indice_bytes is a Ramp expression
      if (const RampNode *curr_ramp = curr_indice_bytes.as<RampNode>()) {
-        // Create index variable for curr Ramp
        Var curr_idx("curr_idx", DataType::Int(32));
        analyzer.Bind(curr_idx, Range::FromMinExtent(0, curr_ramp->lanes));
        curr_indice_bytes = curr_ramp->base + curr_idx * curr_ramp->stride;
@@ -1392,10 +1580,6 @@ struct TileLangThreadSyncPlanner : public ConstrVisitor {
        ICHECK(prev_indice_bytes.dtype() == curr_indice_bytes.dtype());
        provably_disjoint =
            analyzer.CanProve(tir::NE(prev_indice_bytes, curr_indice_bytes));
-        if (!provably_disjoint) {
-          // LOG(WARNING) << analyzer.z3_prover.GetModel(
-          //     tir::EQ(prev_indice_bytes, curr_indice_bytes));
-        }
      } else {
        LOG(WARNING) << "Unscalar: " << prev_indice_bytes << "; "
                     << curr_indice_bytes;
@@ -1408,11 +1592,9 @@ struct TileLangThreadSyncPlanner : public ConstrVisitor {
          Substitute(curr.touched[i].min() * curr_dtype.bytes(), curr_sub));
      auto curr_max = analyzer.Simplify(
          Substitute(curr.touched[i].max() * curr_dtype.bytes(), curr_sub));
-      // analyzer.z3_prover.SetRLimit(100000000);
      provably_disjoint = analyzer.CanProve(analyzer.Simplify(
          tir::Or(prev_min > curr_max, curr_min > prev_max)));
    } catch (const std::exception &e) {
-      // Log for debugging; fall back to conservative bound check
      LOG(WARNING) << "Exception in conflict detection: " << e.what();
      auto prev_bound =
          analyzer.const_int_bound(prev_indice_bytes);
      auto curr_bound = analyzer.const_int_bound(curr_indice_bytes);
@@ -1438,23 +1620,13 @@ struct TileLangThreadSyncPlanner : public ConstrVisitor {
      }
    }
-    // If this is a read into a double buffer that was previously
-    // swapped out, then it doesn't conflict.
-    if (prev.double_buffer_write && curr.type == kRead && !loop_carry) {
-      return false;
-    }
-
-    // If nothing else allows sharing the same buffer, then they are
-    // in conflict.
-    // if range_is_overlap is true, then they are in conflict, we should return
-    // true. if range_is_overlap is false, then they are not in conflict, we
-    // should return false.
    return range_is_overlap;
  }
+
  bool FindConflict(const std::vector<AccessEntry> &prev,
-                    const AccessEntry &curr, bool loop_carry) {
+                    const AccessEntry &curr, const ForNode *loop) {
    for (const AccessEntry &x : prev) {
-      if (FindConflict(x, curr, loop_carry)) {
+      if (FindConflict(x, curr, loop)) {
        return true;
      }
    }
@@ -1469,7 +1641,15 @@ PrimFunc TileLangThreadSync(PrimFunc func, const std::string &storage_scope) {
  if (sync_scope.rank == StorageRank::kShared && sync_scope.tag.empty()) {
    stmt = ThreadSyncAfterWaitQueueInserter(sync_scope)(stmt);
  }
-  TileLangThreadSyncPlanner planner(sync_scope);
+  // Get warp size from target, defaulting to 32 if not available
+  int warp_size = 32;
+  if (auto target = func->GetAttr<Target>(tvm::attr::kTarget)) {
+    warp_size = target.value()
+                    ->GetAttr<Integer>("thread_warp_size", 32)
+                    .value()
+                    .IntValue();
+  }
+  TileLangThreadSyncPlanner planner(sync_scope, warp_size);
  for (const auto &[_, buffer] : func->buffer_map) {
    planner.SetBufferDataToBuffer(buffer->data, buffer);
  }
diff --git a/testing/python/transform/test_tilelang_transform_thread_sync.py b/testing/python/transform/test_tilelang_transform_thread_sync.py
index 4a6b738e7..8b2901571 100644
--- a/testing/python/transform/test_tilelang_transform_thread_sync.py
+++ b/testing/python/transform/test_tilelang_transform_thread_sync.py
@@ -255,5 +255,349 @@ def func():
     assert s.index('T.tvm_storage_sync("shared.dyn")') < s.index("for i in T.unroll(8)")


+@tilelang.testing.requires_cuda
+def test_loop_carry_no_dependency_same_index():
+    """Test that A[i] write followed by A[i] read in a loop does NOT need barrier.
+
+    After iteration shift analysis:
+    - Iteration i writes A[i]
+    - Iteration i+1 reads A[i+1] (shifted from A[i])
+    - A[i] vs A[i+1] are disjoint, so no loop-carried dependency
+    """
+
+    @T.prim_func(private=True)
+    def func():
+        temp_shared = T.alloc_buffer([128], dtype="float32", scope="shared")
+        result_local = T.alloc_buffer([1], dtype="float32", scope="local")
+        tx = T.launch_thread("threadIdx.x", 128)
+        ty = T.launch_thread("threadIdx.y", 1)
+        tz = T.launch_thread("threadIdx.z", 1)
+        result_local[0] = T.float32(0)
+        for i in range(10):
+            # Each iteration writes to A[tx], then reads from A[tx].
+            # No cross-thread loop-carried dependency: each thread re-reads
+            # the slot it wrote itself.
+            temp_shared[tx] = T.float32(i)
+            result_local[0] = result_local[0] + temp_shared[tx]
+
+    mod = tvm.IRModule({"main": func})
+    mod = tilelang.transform.ThreadSync("shared")(mod)
+    s = str(mod)
+    # Should NOT have sync inside the loop since A[tx] in iteration i
+    # does not conflict with A[tx] in iteration i+1 (each thread only touches
+    # its own slot, so there is no cross-thread hazard)
+    assert 'T.tvm_storage_sync("shared")' not in s, f"Unexpected sync in loop:\n{s}"
+
+
+@tilelang.testing.requires_cuda
+def test_loop_carry_with_cross_thread_dependency():
+    """Test loop-carried dependency where different threads access overlapping locations.
+
+    In this test:
+    - Thread tx writes to A[tx]
+    - Then reads from A[(tx + 127) % 128] (neighbor's data from previous iteration)
+
+    After iteration shift analysis, we compare:
+    - Iteration i: thread tx writes A[tx]
+    - Iteration i+1: thread tx reads A[(tx + 127) % 128]
+
+    This creates a cross-thread dependency: thread tx reads the slot written
+    by thread (tx + 127) % 128, so a barrier is required.
+    """
+
+    @T.prim_func(private=True)
+    def func():
+        temp_shared = T.alloc_buffer([128], dtype="float32", scope="shared")
+        result_local = T.alloc_buffer([1], dtype="float32", scope="local")
+        bx = T.launch_thread("blockIdx.x", 1)
+        tx = T.launch_thread("threadIdx.x", 128)
+        ty = T.launch_thread("threadIdx.y", 1)
+        tz = T.launch_thread("threadIdx.z", 1)
+        result_local[0] = T.float32(0)
+        for i in range(10):
+            # Each thread writes to its own location
+            temp_shared[tx] = T.float32(i)
+            # Then reads from neighbor (creates cross-thread dependency)
+            result_local[0] = result_local[0] + temp_shared[(tx + 127) % 128]
+
+    mod = tvm.IRModule({"main": func})
+    mod = tilelang.transform.ThreadSync("shared")(mod)
+    s = str(mod)
+    # Should have sync because thread tx reads from thread (tx+127)%128's location
+    # This is a WAR hazard across threads
+    assert 'T.tvm_storage_sync("shared")' in s, f"Expected sync for cross-thread dependency:\n{s}"
+
+
+@tilelang.testing.requires_cuda
+def test_loop_carry_modulo_buffering():
+    """Test that A[i%2] write followed by A[i%2] read does NOT need barrier (double buffering).
+
+    After iteration shift analysis:
+    - Iteration i writes A[i%2]
+    - Iteration i+1 reads A[(i+1)%2] (shifted from A[i%2])
+    - A[i%2] vs A[(i+1)%2] are disjoint (0 vs 1 or 1 vs 0), so no dependency
+    """
+
+    @T.prim_func(private=True)
+    def func():
+        temp_shared = T.alloc_buffer([2, 64], dtype="float32", scope="shared")
+        result_local = T.alloc_buffer([1], dtype="float32", scope="local")
+        bx = T.launch_thread("blockIdx.x", 1)
+        tx = T.launch_thread("threadIdx.x", 64)
+        ty = T.launch_thread("threadIdx.y", 1)
+        tz = T.launch_thread("threadIdx.z", 1)
+        result_local[0] = T.float32(0)
+        for i in range(10):
+            # Double buffering pattern: write to buffer[i%2], read from buffer[i%2]
+            # After shift: write buffer[i%2], read buffer[(i+1)%2]
+            # These are different buffers, so no conflict
+            temp_shared[i % 2, tx] = T.float32(i)
+            result_local[0] = result_local[0] + temp_shared[i % 2, tx]
+
+    mod = tvm.IRModule({"main": func})
+    mod = tilelang.transform.ThreadSync("shared")(mod)
+    s = str(mod)
+    # Should NOT have sync inside loop due to modulo buffering analysis
+    # Note: This test verifies the modulo analysis capability
+    print(f"Modulo buffering result:\n{s}")
+
+
+@tilelang.testing.requires_cuda
+def test_loop_carry_different_indices():
+    """Test that A[i] write followed by A[i+1] read does NOT need barrier.
+
+    After iteration shift analysis:
+    - Iteration i writes A[i]
+    - Iteration i+1 reads A[i+2] (shifted from A[i+1], becomes A[(i+1)+1] = A[i+2])
+    - A[i] vs A[i+2] are disjoint, so no loop-carried dependency
+    """
+
+    @T.prim_func(private=True)
+    def func():
+        temp_shared = T.alloc_buffer([128], dtype="float32", scope="shared")
+        result_local = T.alloc_buffer([1], dtype="float32", scope="local")
+        bx = T.launch_thread("blockIdx.x", 1)
+        tx = T.launch_thread("threadIdx.x", 1)
+        ty = T.launch_thread("threadIdx.y", 1)
+        tz = T.launch_thread("threadIdx.z", 1)
+        result_local[0] = T.float32(0)
+        for i in range(10):
+            # Write to A[i], read from A[i+1]
+            # After shift: comparing A[i] (write) vs A[i+2] (read from i+1 shifted)
+            # No overlap, no dependency
+            temp_shared[i] = T.float32(i)
+            result_local[0] = result_local[0] + temp_shared[i + 1]
+
+    mod = tvm.IRModule({"main": func})
+    mod = tilelang.transform.ThreadSync("shared")(mod)
+    s = str(mod)
+    print(f"Different indices result:\n{s}")
+
+
+# =============================================================================
+# Tests for non-uniform if condition sync hoisting
+# =============================================================================
+
+
+@tilelang.testing.requires_cuda
+def test_sync_hoist_non_uniform_if_with_threadidx():
+    """Test that sync is hoisted when if condition directly depends on threadIdx.
+
+    When the if condition uses threadIdx, different threads may take different
+    branches. If a sync is needed inside the if, it must be hoisted to before
+    the if statement to avoid deadlock.
+ """ + + @T.prim_func(private=True) + def func(): + temp_shared = T.alloc_buffer([128], dtype="float32", scope="shared") + result_local = T.alloc_buffer([1], dtype="float32", scope="local") + bx = T.launch_thread("blockIdx.x", 1) + tx = T.launch_thread("threadIdx.x", 128) + ty = T.launch_thread("threadIdx.y", 1) + tz = T.launch_thread("threadIdx.z", 1) + result_local[0] = T.float32(0) + # First, all threads write to shared memory + temp_shared[tx] = T.float32(tx) + # Non-uniform condition: only some threads enter the if + if tx < 64: + # Inside the if, we read from shared memory + # This needs a sync, but since condition is non-uniform, + # the sync must be hoisted to before the if + result_local[0] = temp_shared[tx + 64] + + mod = tvm.IRModule({"main": func}) + mod = tilelang.transform.ThreadSync("shared")(mod) + s = str(mod) + # Sync should appear before the if statement + assert 'T.tvm_storage_sync("shared")' in s, f"Expected sync:\n{s}" + # The sync should be before the if, not inside it + sync_pos = s.index('T.tvm_storage_sync("shared")') + if_pos = s.index("if tx < 64") + assert sync_pos < if_pos, f"Sync should be before if statement:\n{s}" + + +@tilelang.testing.requires_cuda +def test_sync_hoist_non_uniform_if_shared_memory_condition(): + """Test sync hoisting when if condition reads from shared memory with thread-dependent index. + + This is the exact pattern that caused the original deadlock: + - Condition reads shared memory at index depending on threadIdx + - Different threads get different values -> non-uniform condition + - Sync inside if would cause deadlock + """ + + @T.prim_func(private=True) + def func(): + token_ids = T.alloc_buffer([128], dtype="int32", scope="shared") + data_shared = T.alloc_buffer([128], dtype="float32", scope="shared") + result_local = T.alloc_buffer([1], dtype="float32", scope="local") + bx = T.launch_thread("blockIdx.x", 1) + tx = T.launch_thread("threadIdx.x", 128) + ty = T.launch_thread("threadIdx.y", 1) + tz = T.launch_thread("threadIdx.z", 1) + result_local[0] = T.float32(0) + # First phase: all threads write to data_shared + data_shared[tx] = T.float32(tx) + # Non-uniform condition: reads shared memory with threadIdx-dependent index + # token_ids[tx] can be different for each thread (e.g., some are -1, some are valid) + if token_ids[tx] != -1: + # Inside the if, we read from data_shared + # Sync is needed but must be hoisted because condition is non-uniform + result_local[0] = data_shared[tx] + + mod = tvm.IRModule({"main": func}) + mod = tilelang.transform.ThreadSync("shared")(mod) + s = str(mod) + # Sync should appear before the if statement + assert 'T.tvm_storage_sync("shared")' in s, f"Expected sync:\n{s}" + # The sync should be before the if that checks token_ids + sync_pos = s.index('T.tvm_storage_sync("shared")') + if_pos = s.index("if token_ids") + assert sync_pos < if_pos, f"Sync should be hoisted before non-uniform if:\n{s}" + + +@tilelang.testing.requires_cuda +def test_sync_inside_uniform_if_blockidx(): + """Test that sync can stay inside if when condition is uniform (blockIdx). + + When the if condition only depends on blockIdx (same for all threads in a block), + all threads take the same branch, so sync inside the if is safe. 
+ """ + + @T.prim_func(private=True) + def func(): + temp_shared = T.alloc_buffer([128], dtype="float32", scope="shared") + result_local = T.alloc_buffer([1], dtype="float32", scope="local") + bx = T.launch_thread("blockIdx.x", 4) + tx = T.launch_thread("threadIdx.x", 128) + ty = T.launch_thread("threadIdx.y", 1) + tz = T.launch_thread("threadIdx.z", 1) + result_local[0] = T.float32(0) + # First, all threads write to shared memory + temp_shared[tx] = T.float32(tx) + # Uniform condition: blockIdx is same for all threads in a block + if bx < 2: + # Sync inside uniform if is safe - all threads in this block + # will either all enter or all skip this branch + result_local[0] = temp_shared[(tx + 64) % 128] + + mod = tvm.IRModule({"main": func}) + mod = tilelang.transform.ThreadSync("shared")(mod) + s = str(mod) + # Should have sync (either inside or outside the if is fine for uniform condition) + assert 'T.tvm_storage_sync("shared")' in s, f"Expected sync:\n{s}" + + +@tilelang.testing.requires_cuda +def test_sync_hoist_nested_non_uniform_if(): + """Test sync hoisting with nested if statements where outer is non-uniform.""" + + @T.prim_func(private=True) + def func(): + temp_shared = T.alloc_buffer([128], dtype="float32", scope="shared") + result_local = T.alloc_buffer([1], dtype="float32", scope="local") + bx = T.launch_thread("blockIdx.x", 1) + tx = T.launch_thread("threadIdx.x", 128) + ty = T.launch_thread("threadIdx.y", 1) + tz = T.launch_thread("threadIdx.z", 1) + result_local[0] = T.float32(0) + # Write to shared memory + temp_shared[tx] = T.float32(tx) + # Outer non-uniform condition + if tx < 64: + # Inner condition (also non-uniform) + if tx < 32: + # Sync needed here must be hoisted all the way out + result_local[0] = temp_shared[tx + 64] + + mod = tvm.IRModule({"main": func}) + mod = tilelang.transform.ThreadSync("shared")(mod) + s = str(mod) + assert 'T.tvm_storage_sync("shared")' in s, f"Expected sync:\n{s}" + # Sync should be before the outermost non-uniform if + sync_pos = s.index('T.tvm_storage_sync("shared")') + if_pos = s.index("if tx < 64") + assert sync_pos < if_pos, f"Sync should be hoisted before outer if:\n{s}" + + +@tilelang.testing.requires_cuda +def test_sync_hoist_non_uniform_if_in_loop(): + """Test sync hoisting when non-uniform if is inside a loop.""" + + @T.prim_func(private=True) + def func(): + token_ids = T.alloc_buffer([128], dtype="int32", scope="shared") + data_shared = T.alloc_buffer([128], dtype="float32", scope="shared") + result_local = T.alloc_buffer([1], dtype="float32", scope="local") + bx = T.launch_thread("blockIdx.x", 1) + tx = T.launch_thread("threadIdx.x", 128) + ty = T.launch_thread("threadIdx.y", 1) + tz = T.launch_thread("threadIdx.z", 1) + result_local[0] = T.float32(0) + for k in range(2): + # Write to shared memory + data_shared[tx] = T.float32(tx + k) + # Non-uniform if inside loop + if token_ids[tx] != -1: + result_local[0] = result_local[0] + data_shared[tx] + + mod = tvm.IRModule({"main": func}) + mod = tilelang.transform.ThreadSync("shared")(mod) + s = str(mod) + assert 'T.tvm_storage_sync("shared")' in s, f"Expected sync:\n{s}" + # Sync should be before the if inside the loop, not inside the if + # This ensures all threads can reach the sync point + + +@tilelang.testing.requires_cuda +def test_no_sync_needed_uniform_accesses(): + """Test that no extra sync is added when accesses are already safe. + + When each thread only accesses its own data (no cross-thread dependency), + no sync is needed even inside an if statement. 
+ """ + + @T.prim_func(private=True) + def func(): + temp_local = T.alloc_buffer([1], dtype="float32", scope="local") + result_local = T.alloc_buffer([1], dtype="float32", scope="local") + bx = T.launch_thread("blockIdx.x", 1) + tx = T.launch_thread("threadIdx.x", 128) + ty = T.launch_thread("threadIdx.y", 1) + tz = T.launch_thread("threadIdx.z", 1) + result_local[0] = T.float32(0) + temp_local[0] = T.float32(tx) + # Non-uniform condition but no shared memory access + if tx < 64: + result_local[0] = temp_local[0] + + mod = tvm.IRModule({"main": func}) + mod = tilelang.transform.ThreadSync("shared")(mod) + s = str(mod) + # No sync needed - only local memory is accessed + assert 'T.tvm_storage_sync("shared")' not in s, f"Unexpected sync:\n{s}" + + if __name__ == "__main__": tilelang.testing.main()