diff --git a/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_paged.py b/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_paged.py index 6e7321452..a93e4de13 100644 --- a/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_paged.py +++ b/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_paged.py @@ -2,7 +2,6 @@ import torch import torch.nn.functional as F import tilelang -from tilelang.autotuner import * import tilelang.language as T from einops import rearrange, einsum import argparse @@ -13,160 +12,159 @@ from heuristic import num_splits_heuristic -def flashattn(batch, heads, heads_kv, dim, dim_v): +@tilelang.jit( + out_idx=[-1], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, +) +def flashattn(batch, heads, heads_kv, dim, dim_v, block_N, block_H, page_block_size, num_stages, threads, num_pages): scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) dtype = T.float16 accum_dtype = T.float32 kv_group_num = heads // heads_kv - @tilelang.jit( - out_idx=[-1], - pass_configs={ - tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }, - ) - def kernel_func( - block_N, block_H, page_block_size, num_split, num_stages, threads, num_pages, max_num_blocks_per_seq, max_selected_blocks + num_split = T.dynamic("num_split") + max_num_blocks_per_seq = T.dynamic("max_num_blocks_per_seq") + max_selected_blocks = T.dynamic("max_selected_blocks") + + shape_q = [batch, heads, dim] + shape_k = [num_pages, page_block_size, heads_kv, dim] + shape_v = [num_pages, page_block_size, heads_kv, dim_v] + shape_indices = [batch, heads_kv, max_selected_blocks] + shape_block_table = [batch, max_num_blocks_per_seq] + shape_o = [batch, heads, dim_v] + part_shape = [batch, heads, num_split, dim_v] + valid_block_H = min(block_H, kv_group_num) + assert block_N <= page_block_size and page_block_size % block_N == 0 + block_ratio = page_block_size // block_N + + @T.prim_func + def main( + Q: T.Tensor(shape_q, dtype), + K: T.Tensor(shape_k, dtype), + V: T.Tensor(shape_v, dtype), + block_indices: T.Tensor(shape_indices, T.int32), + cache_seqlens: T.Tensor([batch], T.int32), + block_table: T.Tensor(shape_block_table, T.int32), + glse: T.Tensor([batch, heads, num_split], accum_dtype), + Output_partial: T.Tensor(part_shape, accum_dtype), + Output: T.Tensor(shape_o, dtype), ): - shape_q = [batch, heads, dim] - shape_k = [num_pages, page_block_size, heads_kv, dim] - shape_v = [num_pages, page_block_size, heads_kv, dim_v] - shape_indices = [batch, heads_kv, max_selected_blocks] - shape_block_table = [batch, max_num_blocks_per_seq] - shape_o = [batch, heads, dim_v] - part_shape = [batch, heads, num_split, dim_v] - valid_block_H = min(block_H, kv_group_num) - assert block_N <= page_block_size and page_block_size % block_N == 0 - block_ratio = page_block_size // block_N - - @T.prim_func - def main( - Q: T.Tensor(shape_q, dtype), - K: T.Tensor(shape_k, dtype), - V: T.Tensor(shape_v, dtype), - block_indices: T.Tensor(shape_indices, T.int32), - cache_seqlens: T.Tensor([batch], T.int32), - block_table: T.Tensor(shape_block_table, T.int32), - glse: T.Tensor([batch, heads, num_split], accum_dtype), - Output_partial: T.Tensor(part_shape, accum_dtype), - Output: T.Tensor(shape_o, dtype), - ): - # flash_attn_split - with T.Kernel(batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz): - Q_shared = T.alloc_shared([block_H, dim], dtype) - K_shared = T.alloc_shared([block_N, dim], dtype) - V_shared = T.alloc_shared([block_N, dim_v], dtype) - acc_s 
= T.alloc_fragment([block_H, block_N], accum_dtype) - acc_s_cast = T.alloc_fragment([block_H, block_N], dtype) - acc_o = T.alloc_fragment([block_H, dim_v], accum_dtype) - - scores_max = T.alloc_fragment([block_H], accum_dtype) - scores_max_prev = T.alloc_fragment([block_H], accum_dtype) - scores_scale = T.alloc_fragment([block_H], accum_dtype) - scores_sum = T.alloc_fragment([block_H], accum_dtype) - logsum = T.alloc_fragment([block_H], accum_dtype) - has_valid_block = T.alloc_var("bool") - - bid = bx - hid = by - sid = bz - cur_kv_head = hid // (kv_group_num // valid_block_H) - - T.copy(Q[bid, hid * valid_block_H : hid * valid_block_H + block_H, :], Q_shared) - T.fill(acc_o, 0) - T.fill(logsum, 0) - T.fill(scores_max, -T.infinity(accum_dtype)) - - num_blocks = max_selected_blocks - blocks_per_split = T.floordiv(num_blocks, num_split) - remaining_blocks = T.floormod(num_blocks, num_split) - loop_range = blocks_per_split + T.if_then_else(sid < remaining_blocks, 1, 0) - start = blocks_per_split * sid + T.min(sid, remaining_blocks) - has_valid_block = False - for k in T.Pipelined(loop_range, num_stages=num_stages): - logical_block_idx = block_indices[bid, cur_kv_head, start + k] - if logical_block_idx >= 0: - has_valid_block = True - block_table_idx = T.floordiv(logical_block_idx, block_ratio) - block_tile_idx = T.floormod(logical_block_idx, block_ratio) - physical_block_idx = block_table[bid, block_table_idx] - T.copy(K[physical_block_idx, block_tile_idx * block_N : (block_tile_idx + 1) * block_N, cur_kv_head, :], K_shared) - T.clear(acc_s) - T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - if k == 0: # assume block_indices is sorted in reverse order, otherwise, remove this if condition - for i, j in T.Parallel(block_H, block_N): - acc_s[i, j] = T.if_then_else( - logical_block_idx * block_N + j >= cache_seqlens[bid], -T.infinity(accum_dtype), acc_s[i, j] - ) - T.copy(scores_max, scores_max_prev) - T.fill(scores_max, -T.infinity(accum_dtype)) - T.reduce_max(acc_s, scores_max, dim=1, clear=False) - for i in T.Parallel(block_H): - scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) - scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + # flash_attn_split + with T.Kernel(batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz): + Q_shared = T.alloc_shared([block_H, dim], dtype) + K_shared = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_N, dim_v], dtype) + acc_s = T.alloc_fragment([block_H, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_H, block_N], dtype) + acc_o = T.alloc_fragment([block_H, dim_v], accum_dtype) + + scores_max = T.alloc_fragment([block_H], accum_dtype) + scores_max_prev = T.alloc_fragment([block_H], accum_dtype) + scores_scale = T.alloc_fragment([block_H], accum_dtype) + scores_sum = T.alloc_fragment([block_H], accum_dtype) + logsum = T.alloc_fragment([block_H], accum_dtype) + has_valid_block = T.alloc_var(T.bool) + + bid = bx + hid = by + sid = bz + cur_kv_head = hid // (kv_group_num // valid_block_H) + + T.copy(Q[bid, hid * valid_block_H : hid * valid_block_H + block_H, :], Q_shared) + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + + num_blocks = max_selected_blocks + blocks_per_split = T.floordiv(num_blocks, num_split) + remaining_blocks = T.floormod(num_blocks, num_split) + loop_range = blocks_per_split + T.if_then_else(sid < remaining_blocks, 1, 0) + start = blocks_per_split * sid + T.min(sid, 
remaining_blocks) + has_valid_block = False + for k in T.Pipelined(loop_range, num_stages=num_stages): + logical_block_idx = block_indices[bid, cur_kv_head, start + k] + if logical_block_idx >= 0: + has_valid_block = True + block_table_idx = T.floordiv(logical_block_idx, block_ratio) + block_tile_idx = T.floormod(logical_block_idx, block_ratio) + physical_block_idx = block_table[bid, block_table_idx] + T.copy(K[physical_block_idx, block_tile_idx * block_N : (block_tile_idx + 1) * block_N, cur_kv_head, :], K_shared) + T.clear(acc_s) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + if k == 0: # assume block_indices is sorted in reverse order, otherwise, remove this if condition for i, j in T.Parallel(block_H, block_N): - acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) - T.reduce_sum(acc_s, scores_sum, dim=1) - for i in T.Parallel(block_H): - logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] - T.copy(acc_s, acc_s_cast) - for i, j in T.Parallel(block_H, dim_v): - acc_o[i, j] *= scores_scale[i] - T.copy(V[physical_block_idx, block_tile_idx * block_N : (block_tile_idx + 1) * block_N, cur_kv_head, :], V_shared) - T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) - if has_valid_block: - for i, j in T.Parallel(block_H, dim_v): - acc_o[i, j] /= logsum[i] - + acc_s[i, j] = T.if_then_else( + logical_block_idx * block_N + j >= cache_seqlens[bid], -T.infinity(accum_dtype), acc_s[i, j] + ) + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) for i in T.Parallel(block_H): - logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale - - for i in T.Parallel(block_H): - if i < valid_block_H: - glse[bid, hid * valid_block_H + i, sid] = logsum[i] - + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_H, block_N): + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_H): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + T.copy(acc_s, acc_s_cast) + for i, j in T.Parallel(block_H, dim_v): + acc_o[i, j] *= scores_scale[i] + T.copy(V[physical_block_idx, block_tile_idx * block_N : (block_tile_idx + 1) * block_N, cur_kv_head, :], V_shared) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + if has_valid_block: for i, j in T.Parallel(block_H, dim_v): - if i < valid_block_H: - Output_partial[bid, hid * valid_block_H + i, sid, j] = acc_o[i, j] - - # combine - with T.Kernel(heads, batch, threads=128) as (by, bz): - po_local = T.alloc_fragment([dim_v], accum_dtype) - o_accum_local = T.alloc_fragment([dim_v], accum_dtype) - lse_local_split = T.alloc_var(accum_dtype) - lse_logsum_local = T.alloc_var(accum_dtype) - lse_max_local = T.alloc_var(accum_dtype) - scale_local = T.alloc_var(accum_dtype) - max_split = T.alloc_var(T.int32) - - T.clear(lse_logsum_local) - T.clear(o_accum_local) - lse_max_local = -T.infinity(accum_dtype) - for k in T.serial(num_split): + acc_o[i, j] /= logsum[i] + for i in T.Parallel(block_H): + logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale + + # TODO(lei): Support T.Parallel(valid_block_H) + for i in T.Parallel(block_H): + if i < valid_block_H: + glse[bid, hid * valid_block_H + i, sid] = logsum[i] + for i, j in T.Parallel(block_H, dim_v): + if i < valid_block_H: + Output_partial[bid, hid * valid_block_H + i, 
sid, j] = acc_o[i, j] + + # combine + with T.Kernel(heads, batch, threads=128) as (by, bz): + po_local = T.alloc_fragment([dim_v], accum_dtype) + o_accum_local = T.alloc_fragment([dim_v], accum_dtype) + lse_local_split = T.alloc_var(accum_dtype) + lse_logsum_local = T.alloc_var(accum_dtype) + lse_max_local = T.alloc_var(accum_dtype) + scale_local = T.alloc_var(accum_dtype) + max_split = T.alloc_var(T.int32) + + T.clear(lse_logsum_local) + T.clear(o_accum_local) + lse_max_local = -T.infinity(accum_dtype) + for k in T.serial(num_split): + lse_local_split = glse[bz, by, k] + if lse_local_split != 0: + max_split = k + lse_max_local = T.max(lse_max_local, glse[bz, by, k]) + + for k in T.Pipelined(num_split, num_stages=1): + if k <= max_split: lse_local_split = glse[bz, by, k] - if lse_local_split != 0: - max_split = k - lse_max_local = T.max(lse_max_local, glse[bz, by, k]) - - for k in T.Pipelined(num_split, num_stages=1): - if k <= max_split: - lse_local_split = glse[bz, by, k] - lse_logsum_local += T.exp2(lse_local_split - lse_max_local) - lse_logsum_local = T.log2(lse_logsum_local) + lse_max_local - for k in T.serial(num_split): - if k <= max_split: - for i in T.Parallel(dim_v): - po_local[i] = Output_partial[bz, by, k, i] - lse_local_split = glse[bz, by, k] - scale_local = T.exp2(lse_local_split - lse_logsum_local) - for i in T.Parallel(dim_v): - o_accum_local[i] += po_local[i] * scale_local - for i in T.Parallel(dim_v): - Output[bz, by, i] = o_accum_local[i] - - return main - - return kernel_func + lse_logsum_local += T.exp2(lse_local_split - lse_max_local) + lse_logsum_local = T.log2(lse_logsum_local) + lse_max_local + for k in T.serial(num_split): + if k <= max_split: + for i in T.Parallel(dim_v): + po_local[i] = Output_partial[bz, by, k, i] + lse_local_split = glse[bz, by, k] + scale_local = T.exp2(lse_local_split - lse_logsum_local) + for i in T.Parallel(dim_v): + o_accum_local[i] += po_local[i] * scale_local + for i in T.Parallel(dim_v): + Output[bz, by, i] = o_accum_local[i] + + print(main) + return main class SparseFlashAttn(torch.nn.Module): @@ -181,19 +179,6 @@ def __init__(self, batch, heads, heads_kv, dim, dim_v, page_block_size, block_N, self.page_block_size = page_block_size self.num_pages = num_pages self.block_H = 64 - - self.kernel = flashattn(batch, heads, heads_kv, dim, dim_v)( - block_N=block_N, - block_H=self.block_H, - page_block_size=page_block_size, - num_split=T.dynamic("num_split"), - num_stages=2, - threads=128, - num_pages=num_pages, - max_num_blocks_per_seq=T.dynamic("max_num_blocks_per_seq"), - max_selected_blocks=T.dynamic("max_selected_blocks"), - ) - props = torch.cuda.get_device_properties(torch.device("cuda:0")) self.num_sm = props.multi_processor_count @@ -221,16 +206,19 @@ def forward(self, query, key, value, block_indices, cache_seqlens, block_table): glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device="cuda") output_partial = torch.empty((batch, heads, num_split, dim_v), dtype=torch.float32, device="cuda") - output = self.kernel( - query, - key, - value, - block_indices, - cache_seqlens, - block_table, - glse, - output_partial, - ) + output = flashattn( + batch, + heads, + heads_kv, + dim, + dim_v, + block_N=block_size, + block_H=self.block_H, + page_block_size=self.page_block_size, + num_stages=2, + threads=128, + num_pages=self.num_pages, + )(query, key, value, block_indices, cache_seqlens, block_table, glse, output_partial) return output @@ -513,6 +501,8 @@ def main(args): def run_regression_perf(args): + torch.manual_seed(42) 
+ torch.cuda.manual_seed_all(42) batch, heads, heads_kv, max_cache_seqlen, dim, dim_v = ( args.batch, args.heads, @@ -524,15 +514,15 @@ def run_regression_perf(args): sparse_ratio = args.sparse_ratio block_N = args.block_N page_block_size = args.page_block_size - num_blocks = args.num_pages + num_pages = args.num_pages max_selected_blocks = int(math.ceil(max_cache_seqlen / block_N)) dtype = torch.float16 Q = torch.randn((batch, heads, dim), dtype=dtype, device="cuda") cache_seqlens = torch.randint(max_cache_seqlen // 2, max_cache_seqlen + 1, (batch,), dtype=torch.int32, device="cuda") K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device="cuda") V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device="cuda") - K_cache = torch.zeros((num_blocks, page_block_size, heads_kv, dim), dtype=dtype, device="cuda") - V_cache = torch.zeros((num_blocks, page_block_size, heads_kv, dim_v), dtype=dtype, device="cuda") + K_cache = torch.zeros((num_pages, page_block_size, heads_kv, dim), dtype=dtype, device="cuda") + V_cache = torch.zeros((num_pages, page_block_size, heads_kv, dim_v), dtype=dtype, device="cuda") max_num_blocks_per_seq = int(math.ceil(max_cache_seqlen / page_block_size)) block_table = torch.zeros((batch, max_num_blocks_per_seq), dtype=torch.int32, device="cuda") block_indices = torch.zeros((batch, heads_kv, max_selected_blocks), dtype=torch.int32, device="cuda") @@ -596,22 +586,20 @@ def run_regression_perf(args): for i in range(len(selected_blocks), max_selected_blocks): block_indices[seq_idx, head_idx, i] = -1 - sparse_attn = SparseFlashAttn(batch, heads, heads_kv, dim, dim_v, page_block_size, block_N, num_blocks) - kernel = sparse_attn.kernel - batch = sparse_attn.batch - heads = sparse_attn.heads - heads_kv = sparse_attn.heads_kv - dim_v = sparse_attn.dim_v - dim = sparse_attn.dim - block_size = sparse_attn.block_N + sparse_kernel = SparseFlashAttn(batch, heads, heads_kv, dim, dim_v, page_block_size, block_N, num_pages) + batch = sparse_kernel.batch + heads = sparse_kernel.heads + heads_kv = sparse_kernel.heads_kv + dim_v = sparse_kernel.dim_v + dim = sparse_kernel.dim + block_size = sparse_kernel.block_N max_selected_blocks = block_indices.shape[-1] - num_m_blocks = 1 * (heads // heads_kv + sparse_attn.block_H - 1) // sparse_attn.block_H + num_m_blocks = 1 * (heads // heads_kv + sparse_kernel.block_H - 1) // sparse_kernel.block_H num_n_blocks = max_selected_blocks size_one_kv_head = max_selected_blocks * block_size * (dim + dim_v) * 2 total_mblocks = batch * heads_kv * num_m_blocks - - num_sm = sparse_attn.num_sm + num_sm = sparse_kernel.num_sm num_split = num_splits_heuristic( total_mblocks, num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, is_causal_or_local=True, max_splits=128 @@ -619,18 +607,22 @@ def run_regression_perf(args): glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device="cuda") output_partial = torch.empty((batch, heads, num_split, dim_v), dtype=torch.float32, device="cuda") + kernel = flashattn( + batch, + heads, + heads_kv, + dim, + dim_v, + block_N=block_size, + block_H=sparse_kernel.block_H, + page_block_size=sparse_kernel.page_block_size, + num_stages=2, + threads=128, + num_pages=sparse_kernel.num_pages, + ) def run_kernel_only(): - kernel( - Q, - K_cache, - V_cache, - block_indices, - cache_seqlens, - block_table, - glse, - output_partial, - ) + kernel(Q, K_cache, V_cache, block_indices, cache_seqlens, block_table, glse, output_partial) return do_bench(run_kernel_only, backend="cupti") diff 
--git a/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py b/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py index d6cf7d917..f432fe0fa 100644 --- a/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py +++ b/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py @@ -10,153 +10,150 @@ from tilelang.profiler import do_bench -def flashattn(batch, heads, heads_kv, dim, dim_v): +@tilelang.jit( + out_idx=[-1], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, +) +def flashattn(batch, heads, heads_kv, dim, dim_v, block_N, block_H, num_stages, threads): scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) dtype = T.float16 accum_dtype = T.float32 kv_group_num = heads // heads_kv - @tilelang.jit( - out_idx=[-1], - pass_configs={ - tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }, - ) - def kernel_func(block_N, block_H, num_split, num_stages, threads, max_cache_seqlen, max_selected_blocks): - shape_q = [batch, heads, dim] - shape_k = [batch, max_cache_seqlen, heads_kv, dim] - shape_v = [batch, max_cache_seqlen, heads_kv, dim_v] - shape_indices = [batch, heads_kv, max_selected_blocks] - shape_o = [batch, heads, dim_v] - part_shape = [batch, heads, num_split, dim_v] - valid_block_H = min(block_H, kv_group_num) - - @T.prim_func - def main( - Q: T.Tensor(shape_q, dtype), - K: T.Tensor(shape_k, dtype), - V: T.Tensor(shape_v, dtype), - block_indices: T.Tensor(shape_indices, T.int32), - cache_seqlens: T.Tensor([batch], T.int32), - # actual_num_blocks: T.Tensor([batch], T.int32), - glse: T.Tensor([batch, heads, num_split], accum_dtype), - Output_partial: T.Tensor(part_shape, accum_dtype), - Output: T.Tensor(shape_o, dtype), - ): - # flash_attn_split(Q, K, V, block_indices, cache_seqlens, actual_num_blocks, glse, Output_partial) - # flash_attn_split - with T.Kernel(batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz): - Q_shared = T.alloc_shared([block_H, dim], dtype) - K_shared = T.alloc_shared([block_N, dim], dtype) - V_shared = T.alloc_shared([block_N, dim_v], dtype) - # O_shared = T.alloc_shared([valid_block_H, dim_v], dtype) - acc_s = T.alloc_fragment([block_H, block_N], accum_dtype) - acc_s_cast = T.alloc_fragment([block_H, block_N], dtype) - acc_o = T.alloc_fragment([block_H, dim_v], accum_dtype) - - scores_max = T.alloc_fragment([block_H], accum_dtype) - scores_max_prev = T.alloc_fragment([block_H], accum_dtype) - scores_scale = T.alloc_fragment([block_H], accum_dtype) - scores_sum = T.alloc_fragment([block_H], accum_dtype) - logsum = T.alloc_fragment([block_H], accum_dtype) - has_valid_block = T.alloc_var("bool") - - bid = bx - hid = by - sid = bz - cur_kv_head = hid // (kv_group_num // valid_block_H) - - T.copy(Q[bid, hid * valid_block_H : hid * valid_block_H + block_H, :], Q_shared) - T.fill(acc_o, 0) - T.fill(logsum, 0) - T.fill(scores_max, -T.infinity(accum_dtype)) - - num_blocks = max_selected_blocks - blocks_per_split = T.floordiv(num_blocks, num_split) - remaining_blocks = T.floormod(num_blocks, num_split) - loop_range = blocks_per_split + T.if_then_else(sid < remaining_blocks, 1, 0) - start = blocks_per_split * sid + T.min(sid, remaining_blocks) - has_valid_block = False - - for k in T.Pipelined(loop_range, num_stages=num_stages): - i_s = block_indices[bid, cur_kv_head, start + k] - if i_s >= 0: - has_valid_block = True - T.copy(K[bid, i_s * block_N : (i_s + 1) * block_N, cur_kv_head, :], K_shared) - T.clear(acc_s) - 
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - if k == 0: # assume block_indices is sorted in reverse order, otherwise, remove this if condition - for i, j in T.Parallel(block_H, block_N): - acc_s[i, j] = T.if_then_else(i_s * block_N + j >= cache_seqlens[bid], -T.infinity(accum_dtype), acc_s[i, j]) - T.copy(scores_max, scores_max_prev) - T.fill(scores_max, -T.infinity(accum_dtype)) - T.reduce_max(acc_s, scores_max, dim=1, clear=False) - for i in T.Parallel(block_H): - scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) - scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + num_split = T.dynamic("num_split") + max_cache_seqlen = T.dynamic("max_cache_seqlen") + max_selected_blocks = T.dynamic("max_selected_blocks") + + shape_q = [batch, heads, dim] + shape_k = [batch, max_cache_seqlen, heads_kv, dim] + shape_v = [batch, max_cache_seqlen, heads_kv, dim_v] + shape_indices = [batch, heads_kv, max_selected_blocks] + shape_o = [batch, heads, dim_v] + part_shape = [batch, heads, num_split, dim_v] + valid_block_H = min(block_H, kv_group_num) + + @T.prim_func + def main( + Q: T.Tensor(shape_q, dtype), + K: T.Tensor(shape_k, dtype), + V: T.Tensor(shape_v, dtype), + block_indices: T.Tensor(shape_indices, T.int32), + cache_seqlens: T.Tensor([batch], T.int32), + # actual_num_blocks: T.Tensor([batch], T.int32), + glse: T.Tensor([batch, heads, num_split], accum_dtype), + Output_partial: T.Tensor(part_shape, accum_dtype), + Output: T.Tensor(shape_o, dtype), + ): + with T.Kernel(batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz): + Q_shared = T.alloc_shared([block_H, dim], dtype) + K_shared = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_N, dim_v], dtype) + acc_s = T.alloc_fragment([block_H, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_H, block_N], dtype) + acc_o = T.alloc_fragment([block_H, dim_v], accum_dtype) + + scores_max = T.alloc_fragment([block_H], accum_dtype) + scores_max_prev = T.alloc_fragment([block_H], accum_dtype) + scores_scale = T.alloc_fragment([block_H], accum_dtype) + scores_sum = T.alloc_fragment([block_H], accum_dtype) + logsum = T.alloc_fragment([block_H], accum_dtype) + has_valid_block = T.alloc_var(T.bool) + + bid = bx + hid = by + sid = bz + cur_kv_head = hid // (kv_group_num // valid_block_H) + + T.copy(Q[bid, hid * valid_block_H : hid * valid_block_H + block_H, :], Q_shared) + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + + num_blocks = max_selected_blocks + blocks_per_split = T.floordiv(num_blocks, num_split) + remaining_blocks = T.floormod(num_blocks, num_split) + loop_range = blocks_per_split + T.if_then_else(sid < remaining_blocks, 1, 0) + start = blocks_per_split * sid + T.min(sid, remaining_blocks) + has_valid_block = False + + for k in T.Pipelined(loop_range, num_stages=num_stages): + i_s = block_indices[bid, cur_kv_head, start + k] + if i_s >= 0: + has_valid_block = True + T.copy(K[bid, i_s * block_N : (i_s + 1) * block_N, cur_kv_head, :], K_shared) + T.clear(acc_s) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + if k == 0: # assume block_indices is sorted in reverse order, otherwise, remove this if condition for i, j in T.Parallel(block_H, block_N): - acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) - T.reduce_sum(acc_s, scores_sum, dim=1) - for i in T.Parallel(block_H): - logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] - 
T.copy(acc_s, acc_s_cast) - for i, j in T.Parallel(block_H, dim_v): - acc_o[i, j] *= scores_scale[i] - T.copy(V[bid, i_s * block_N : (i_s + 1) * block_N, cur_kv_head, :], V_shared) - T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) - if has_valid_block: - for i, j in T.Parallel(block_H, dim_v): - acc_o[i, j] /= logsum[i] - + acc_s[i, j] = T.if_then_else(i_s * block_N + j >= cache_seqlens[bid], -T.infinity(accum_dtype), acc_s[i, j]) + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) for i in T.Parallel(block_H): - logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale - - for i in T.Parallel(block_H): - if i < valid_block_H: - glse[bid, hid * valid_block_H + i, sid] = logsum[i] - + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_H, block_N): + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_H): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + T.copy(acc_s, acc_s_cast) + for i, j in T.Parallel(block_H, dim_v): + acc_o[i, j] *= scores_scale[i] + T.copy(V[bid, i_s * block_N : (i_s + 1) * block_N, cur_kv_head, :], V_shared) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + if has_valid_block: for i, j in T.Parallel(block_H, dim_v): - if i < valid_block_H: - Output_partial[bid, hid * valid_block_H + i, sid, j] = acc_o[i, j] - - # combine - with T.Kernel(heads, batch, threads=128) as (by, bz): - po_local = T.alloc_fragment([dim_v], accum_dtype) - o_accum_local = T.alloc_fragment([dim_v], accum_dtype) - lse_local_split = T.alloc_var(accum_dtype) - lse_logsum_local = T.alloc_var(accum_dtype) - lse_max_local = T.alloc_var(accum_dtype) - scale_local = T.alloc_var(accum_dtype) - max_split = T.alloc_var(T.int32) - - T.clear(lse_logsum_local) - T.clear(o_accum_local) - lse_max_local = -T.infinity(accum_dtype) - for k in T.serial(num_split): + acc_o[i, j] /= logsum[i] + for i in T.Parallel(block_H): + logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale + + # TODO(lei): Support T.Parallel(valid_block_H) + for i in T.Parallel(block_H): + if i < valid_block_H: + glse[bid, hid * valid_block_H + i, sid] = logsum[i] + for i, j in T.Parallel(block_H, dim_v): + if i < valid_block_H: + Output_partial[bid, hid * valid_block_H + i, sid, j] = acc_o[i, j] + + # combine + with T.Kernel(heads, batch, threads=128) as (by, bz): + po_local = T.alloc_fragment([dim_v], accum_dtype) + o_accum_local = T.alloc_fragment([dim_v], accum_dtype) + lse_local_split = T.alloc_var(accum_dtype) + lse_logsum_local = T.alloc_var(accum_dtype) + lse_max_local = T.alloc_var(accum_dtype) + scale_local = T.alloc_var(accum_dtype) + max_split = T.alloc_var(T.int32) + + T.clear(lse_logsum_local) + T.clear(o_accum_local) + lse_max_local = -T.infinity(accum_dtype) + for k in T.serial(num_split): + lse_local_split = glse[bz, by, k] + if lse_local_split != 0: + max_split = k + lse_max_local = T.max(lse_max_local, glse[bz, by, k]) + + for k in T.Pipelined(num_split, num_stages=1): + if k <= max_split: + lse_local_split = glse[bz, by, k] + lse_logsum_local += T.exp2(lse_local_split - lse_max_local) + lse_logsum_local = T.log2(lse_logsum_local) + lse_max_local + for k in T.serial(num_split): + if k <= max_split: + for i in T.Parallel(dim_v): + po_local[i] = Output_partial[bz, by, k, i] lse_local_split = 
glse[bz, by, k] - if lse_local_split != 0: - max_split = k - lse_max_local = T.max(lse_max_local, glse[bz, by, k]) - - for k in T.Pipelined(num_split, num_stages=1): - if k <= max_split: - lse_local_split = glse[bz, by, k] - lse_logsum_local += T.exp2(lse_local_split - lse_max_local) - lse_logsum_local = T.log2(lse_logsum_local) + lse_max_local - for k in T.serial(num_split): - if k <= max_split: - for i in T.Parallel(dim_v): - po_local[i] = Output_partial[bz, by, k, i] - lse_local_split = glse[bz, by, k] - scale_local = T.exp2(lse_local_split - lse_logsum_local) - for i in T.Parallel(dim_v): - o_accum_local[i] += po_local[i] * scale_local - for i in T.Parallel(dim_v): - Output[bz, by, i] = o_accum_local[i] - - return main - - return kernel_func + scale_local = T.exp2(lse_local_split - lse_logsum_local) + for i in T.Parallel(dim_v): + o_accum_local[i] += po_local[i] * scale_local + for i in T.Parallel(dim_v): + Output[bz, by, i] = o_accum_local[i] + + return main class SparseFlashAttn(torch.nn.Module): @@ -168,19 +165,7 @@ def __init__(self, batch, heads, heads_kv, dim, dim_v, block_size): self.dim = dim self.dim_v = dim_v self.block_size = block_size - self.block_H = 64 - - self.kernel = flashattn(batch, heads, heads_kv, dim, dim_v)( - block_N=block_size, - block_H=self.block_H, - num_split=T.dynamic("num_split"), - num_stages=2, - threads=128, - max_cache_seqlen=T.dynamic("max_cache_seqlen"), - max_selected_blocks=T.dynamic("max_selected_blocks"), - ) - props = torch.cuda.get_device_properties(torch.device("cuda:0")) self.num_sm = props.multi_processor_count @@ -208,7 +193,17 @@ def forward(self, query, key, value, block_indices, cache_seqlens): glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device="cuda") output_partial = torch.empty((batch, heads, num_split, dim_v), dtype=torch.float32, device="cuda") - output = self.kernel(query, key, value, block_indices, cache_seqlens, glse, output_partial) + output = flashattn( + batch, + heads, + heads_kv, + dim, + dim_v, + block_N=block_size, + block_H=self.block_H, + num_stages=2, + threads=128, + )(query, key, value, block_indices, cache_seqlens, glse, output_partial) return output @@ -252,14 +247,16 @@ def sparse_gqa_decode_varlen_indice(query, key, value, block_indices, cache_seql glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device="cuda") Output_partial = torch.empty((batch, heads, num_split, dim_v), dtype=torch.float32, device="cuda") - kernel = flashattn(batch, heads, heads_kv, dim, dim_v)( + kernel = flashattn( + batch, + heads, + heads_kv, + dim, + dim_v, block_N=block_size, block_H=block_H, - num_split=T.dynamic("num_split"), num_stages=2, threads=128, - max_cache_seqlen=T.dynamic("max_cache_seqlen"), - max_selected_blocks=T.dynamic("max_selected_blocks"), ) output = kernel(query, key, value, block_indices, cache_seqlens, glse, Output_partial) @@ -311,7 +308,7 @@ def ref_program_fa(query, key, value, block_indices, cache_seqlens, max_cache_se return output -def debug(name, expect, actual, atol=1e-3, rtol=1e-3): +def assert_close(name, expect, actual, atol=1e-3, rtol=1e-3): all_close = torch.allclose(expect, actual, atol=atol, rtol=rtol) print(name + " all_close={}".format(all_close)) if not all_close: @@ -324,29 +321,17 @@ def debug(name, expect, actual, atol=1e-3, rtol=1e-3): def main(batch=8, heads=32, heads_kv=8, max_cache_seqlen=8192, dim=128, dim_v=128, sparse_ratio=0.8, block_size=32): batch, heads, heads_kv, max_cache_seqlen, dim, dim_v = batch, heads, heads_kv, max_cache_seqlen, dim, 
dim_v + dtype = torch.float16 sparse_ratio = sparse_ratio block_size = block_size max_selected_blocks = int(math.ceil(max_cache_seqlen * (1 - sparse_ratio) / block_size)) - print("max_selected_blocks: ", max_selected_blocks) - dtype = torch.float16 Q = torch.randn((batch, heads, dim), dtype=dtype, device="cuda") K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device="cuda") V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device="cuda") cache_seqlens = torch.randint(1, max_cache_seqlen, (batch,), dtype=torch.int32, device="cuda") - # cache_seqlens = torch.full((batch,), max_cache_seqlen, dtype=torch.int32, device='cuda') - # # Ensure at least one element equals cache_seqlen - # random_index = torch.randint(0, batch, (1,), device='cuda').item() # Select a random index - # # cache_seqlens[random_index] = max_cache_seqlen # Assign cache_seqlen to ensure at least one occurrence - - print("cache_seqlens: ", cache_seqlens) - max_valid_num_blocks = torch.ceil(cache_seqlens / block_size).int() - print("max_valid_num_blocks: ", max_valid_num_blocks) - # Initialize block_indices with -1 (for padding blocks) block_indices = torch.full((batch, heads_kv, max_selected_blocks), -1, dtype=torch.int32, device="cuda") - # max_num_blocks = int((max_cache_seqlen + block_size - 1)/ block_size) - # block_indices = torch.full((batch, heads_kv, max_num_blocks), -1, dtype=torch.int32, device='cuda') # Assign valid indices while ensuring no duplicates within each batch-group for b in range(batch): @@ -354,27 +339,17 @@ def main(batch=8, heads=32, heads_kv=8, max_cache_seqlen=8192, dim=128, dim_v=12 if max_valid_block > 0: # Ensure there's at least one valid block for h in range(heads_kv): valid_indices = torch.randperm(max_valid_block, device="cuda", dtype=torch.int32)[:max_selected_blocks] - # valid_indices = torch.randperm(max_valid_block, device='cuda', dtype=torch.int32)[:max_num_blocks] block_indices[b, h, : len(valid_indices)] = valid_indices - # Sort indices within each batch-group for consistency block_indices, _ = block_indices.sort(dim=-1, descending=True) - # print("block_indices: ", block_indices) - actual_num_blocks = torch.sum(block_indices != -1, dim=-1).to(torch.int32)[:, 0] - print("actual_num_blocks: ", actual_num_blocks) - # print(block_indices.shape, actual_num_blocks.shape) - max_num_blocks = torch.max(max_valid_num_blocks).item() - print("max_num_blocks: ", max_num_blocks) # parity reference ref = ref_program_torch(Q, K, V, block_indices, cache_seqlens, max_cache_seqlen, max_num_blocks, block_size) sparse_kernel = SparseFlashAttn(batch, heads, heads_kv, dim, dim_v, block_size) out = sparse_kernel(Q, K, V, block_indices, cache_seqlens) - debug("output", ref, out, atol=1e-3, rtol=1e-3) - - import flash_attn # noqa: F401 + assert_close("output", ref, out, atol=1e-3, rtol=1e-3) ## latency reference for _ in range(10): @@ -387,12 +362,10 @@ def main(batch=8, heads=32, heads_kv=8, max_cache_seqlen=8192, dim=128, dim_v=12 print("dense time: ", (time.time() - start) / 100 * 1000) for _ in range(10): - # out = sparse_gqa_decode_varlen_indice(Q, K, V, block_indices, cache_seqlens, max_cache_seqlen, block_size) out = sparse_kernel(Q, K, V, block_indices, cache_seqlens) torch.cuda.synchronize() start = time.time() for _ in range(100): - # out = sparse_gqa_decode_varlen_indice(Q, K, V, block_indices, cache_seqlens, max_cache_seqlen, block_size) out = sparse_kernel(Q, K, V, block_indices, cache_seqlens) torch.cuda.synchronize() print("sparse time: ", 
(time.time() - start) / 100 * 1000) diff --git a/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py b/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py index e48428fb8..e588ec54c 100644 --- a/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py +++ b/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py @@ -1,7 +1,6 @@ import torch import torch.nn.functional as F import tilelang -from tilelang.autotuner import * import tilelang.language as T from einops import rearrange, einsum import argparse @@ -11,137 +10,144 @@ from tilelang.profiler import do_bench -def flashattn(batch, heads, heads_kv, dim, dim_v): +@tilelang.jit( + out_idx=[-1], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, +) +def flashattn(batch, heads, heads_kv, dim, dim_v, block_N, block_H, num_stages, threads): scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) dtype = T.float16 accum_dtype = T.float32 kv_group_num = heads // heads_kv - @tilelang.jit( - out_idx=[-1], - pass_configs={ - tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }, - ) - def kernel_func(block_N, block_H, num_split, num_stages, threads, max_cache_seqlen, num_blocks): - shape_q = [batch, heads, dim] - shape_k = [batch, max_cache_seqlen, heads_kv, dim] - shape_v = [batch, max_cache_seqlen, heads_kv, dim_v] - shape_mask = [batch, heads_kv, num_blocks] - shape_o = [batch, heads, dim_v] - part_shape = [batch, heads, num_split, dim_v] - valid_block_H = min(block_H, kv_group_num) - - @T.prim_func - def main( - Q: T.Tensor(shape_q, dtype), - K: T.Tensor(shape_k, dtype), - V: T.Tensor(shape_v, dtype), - block_mask: T.Tensor(shape_mask, T.bool), - cache_seqlens: T.Tensor([batch], T.int32), - glse: T.Tensor([batch, heads, num_split], accum_dtype), - Output_partial: T.Tensor(part_shape, accum_dtype), - Output: T.Tensor(shape_o, dtype), - ): - with T.Kernel(batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz): - Q_shared = T.alloc_shared([block_H, dim], dtype) - K_shared = T.alloc_shared([block_N, dim], dtype) - V_shared = T.alloc_shared([block_N, dim_v], dtype) - acc_s = T.alloc_fragment([block_H, block_N], accum_dtype) - acc_s_cast = T.alloc_fragment([block_H, block_N], dtype) - acc_o = T.alloc_fragment([block_H, dim_v], accum_dtype) - - scores_max = T.alloc_fragment([block_H], accum_dtype) - scores_max_prev = T.alloc_fragment([block_H], accum_dtype) - scores_scale = T.alloc_fragment([block_H], accum_dtype) - scores_sum = T.alloc_fragment([block_H], accum_dtype) - logsum = T.alloc_fragment([block_H], accum_dtype) - has_valid_block = T.alloc_var("bool") - - bid = bx - hid = by - sid = bz - cur_kv_head = hid // (kv_group_num // valid_block_H) - - T.copy(Q[bid, hid * valid_block_H : hid * valid_block_H + block_H, :], Q_shared) - T.fill(acc_o, 0) - T.fill(logsum, 0) - T.fill(scores_max, -T.infinity(accum_dtype)) - blocks_per_split = T.floordiv(num_blocks, num_split) - remaining_blocks = T.floormod(num_blocks, num_split) - loop_range = blocks_per_split + T.if_then_else(sid < remaining_blocks, 1, 0) - start = blocks_per_split * sid + T.min(sid, remaining_blocks) - has_valid_block = False - for k in T.Pipelined(loop_range, num_stages=num_stages): - if block_mask[bid, hid, start + k]: - has_valid_block = True - T.copy(K[bid, (start + k) * block_N : (start + k + 1) * block_N, cur_kv_head, :], K_shared) - T.clear(acc_s) - T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - for 
i, j in T.Parallel(block_H, block_N): - acc_s[i, j] = T.if_then_else( - (start + k) * block_N + j >= cache_seqlens[bx], -T.infinity(accum_dtype), acc_s[i, j] - ) - T.copy(scores_max, scores_max_prev) - T.fill(scores_max, -T.infinity(accum_dtype)) - T.reduce_max(acc_s, scores_max, dim=1, clear=False) - for i in T.Parallel(block_H): - scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) - scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) - for i, j in T.Parallel(block_H, block_N): - acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) - T.reduce_sum(acc_s, scores_sum, dim=1) - for i in T.Parallel(block_H): - logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] - T.copy(acc_s, acc_s_cast) - for i, j in T.Parallel(block_H, dim_v): - acc_o[i, j] *= scores_scale[i] - T.copy(V[bid, (start + k) * block_N : (start + k + 1) * block_N, cur_kv_head, :], V_shared) - T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) - if has_valid_block: - for i, j in T.Parallel(block_H, dim_v): - acc_o[i, j] /= logsum[i] + num_split = T.dynamic("num_split") + max_cache_seqlen = T.dynamic("max_cache_seqlen") + num_blocks = T.dynamic("num_blocks") + + shape_q = [batch, heads, dim] + shape_k = [batch, max_cache_seqlen, heads_kv, dim] + shape_v = [batch, max_cache_seqlen, heads_kv, dim_v] + shape_mask = [batch, heads_kv, num_blocks] + shape_o = [batch, heads, dim_v] + part_shape = [batch, heads, num_split, dim_v] + valid_block_H = min(block_H, kv_group_num) + + @T.prim_func + def main( + Q: T.Tensor(shape_q, dtype), + K: T.Tensor(shape_k, dtype), + V: T.Tensor(shape_v, dtype), + block_mask: T.Tensor(shape_mask, T.bool), + cache_seqlens: T.Tensor([batch], T.int32), + glse: T.Tensor([batch, heads, num_split], accum_dtype), + Output_partial: T.Tensor(part_shape, accum_dtype), + Output: T.Tensor(shape_o, dtype), + ): + with T.Kernel(batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz): + Q_shared = T.alloc_shared([block_H, dim], dtype) + K_shared = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_N, dim_v], dtype) + acc_s = T.alloc_fragment([block_H, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_H, block_N], dtype) + acc_o = T.alloc_fragment([block_H, dim_v], accum_dtype) + + scores_max = T.alloc_fragment([block_H], accum_dtype) + scores_max_prev = T.alloc_fragment([block_H], accum_dtype) + scores_scale = T.alloc_fragment([block_H], accum_dtype) + scores_sum = T.alloc_fragment([block_H], accum_dtype) + logsum = T.alloc_fragment([block_H], accum_dtype) + has_valid_block = T.alloc_var(T.bool) + + bid = bx + hid = by + sid = bz + cur_kv_head = hid // (kv_group_num // valid_block_H) + + T.copy(Q[bid, hid * valid_block_H : hid * valid_block_H + block_H, :], Q_shared) + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + blocks_per_split = T.floordiv(num_blocks, num_split) + remaining_blocks = T.floormod(num_blocks, num_split) + loop_range = blocks_per_split + T.if_then_else(sid < remaining_blocks, 1, 0) + start = blocks_per_split * sid + T.min(sid, remaining_blocks) + has_valid_block = False + for k in T.Pipelined(loop_range, num_stages=num_stages): + if block_mask[bid, hid, start + k]: + has_valid_block = True + T.copy(K[bid, (start + k) * block_N : (start + k + 1) * block_N, cur_kv_head, :], K_shared) + T.clear(acc_s) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + for i, j in T.Parallel(block_H, block_N): + acc_s[i, j] = 
T.if_then_else((start + k) * block_N + j >= cache_seqlens[bx], -T.infinity(accum_dtype), acc_s[i, j]) + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) for i in T.Parallel(block_H): - logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale - - for i in T.Parallel(block_H): - if i < valid_block_H: - glse[bid, hid * valid_block_H + i, sid] = logsum[i] - + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_H, block_N): + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_H): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + T.copy(acc_s, acc_s_cast) + for i, j in T.Parallel(block_H, dim_v): + acc_o[i, j] *= scores_scale[i] + T.copy(V[bid, (start + k) * block_N : (start + k + 1) * block_N, cur_kv_head, :], V_shared) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + if has_valid_block: for i, j in T.Parallel(block_H, dim_v): - if i < valid_block_H: - Output_partial[bid, hid * valid_block_H + i, sid, j] = acc_o[i, j] - - with T.Kernel(heads, batch, threads=128) as (by, bz): - po_local = T.alloc_fragment([dim_v], accum_dtype) - o_accum_local = T.alloc_fragment([dim_v], accum_dtype) - lse_local_split = T.alloc_var(accum_dtype) - lse_logsum_local = T.alloc_var(accum_dtype) - lse_max_local = T.alloc_var(accum_dtype) - scale_local = T.alloc_var(accum_dtype) - - T.clear(lse_logsum_local) - T.clear(o_accum_local) - lse_max_local = -T.infinity(accum_dtype) - for k in T.serial(num_split): + acc_o[i, j] /= logsum[i] + for i in T.Parallel(block_H): + logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale + + # TODO(lei): Support T.Parallel(valid_block_H) + for i in T.Parallel(block_H): + if i < valid_block_H: + glse[bid, hid * valid_block_H + i, sid] = logsum[i] + for i, j in T.Parallel(block_H, dim_v): + if i < valid_block_H: + Output_partial[bid, hid * valid_block_H + i, sid, j] = acc_o[i, j] + + # combine + with T.Kernel(heads, batch, threads=128) as (by, bz): + po_local = T.alloc_fragment([dim_v], accum_dtype) + o_accum_local = T.alloc_fragment([dim_v], accum_dtype) + lse_local_split = T.alloc_var(accum_dtype) + lse_logsum_local = T.alloc_var(accum_dtype) + lse_max_local = T.alloc_var(accum_dtype) + scale_local = T.alloc_var(accum_dtype) + max_split = T.alloc_var(T.int32) + + T.clear(lse_logsum_local) + T.clear(o_accum_local) + lse_max_local = -T.infinity(accum_dtype) + for k in T.serial(num_split): + lse_local_split = glse[bz, by, k] + if lse_local_split != 0: + max_split = k lse_max_local = T.max(lse_max_local, glse[bz, by, k]) - for k in T.Pipelined(num_split, num_stages=1): + + for k in T.Pipelined(num_split, num_stages=1): + if k <= max_split: lse_local_split = glse[bz, by, k] lse_logsum_local += T.exp2(lse_local_split - lse_max_local) - lse_logsum_local = T.log2(lse_logsum_local) + lse_max_local - for k in T.serial(num_split): + lse_logsum_local = T.log2(lse_logsum_local) + lse_max_local + for k in T.serial(num_split): + if k <= max_split: for i in T.Parallel(dim_v): po_local[i] = Output_partial[bz, by, k, i] lse_local_split = glse[bz, by, k] scale_local = T.exp2(lse_local_split - lse_logsum_local) for i in T.Parallel(dim_v): o_accum_local[i] += po_local[i] * scale_local - for i in T.Parallel(dim_v): - Output[bz, by, i] = o_accum_local[i] - - return main + for i in T.Parallel(dim_v): 
+ Output[bz, by, i] = o_accum_local[i] - return kernel_func + return main class SparseFlashAttn(torch.nn.Module): @@ -153,19 +159,7 @@ def __init__(self, batch, heads, heads_kv, dim, dim_v, block_size): self.dim = dim self.dim_v = dim_v self.block_size = block_size - self.block_H = 64 - - self.kernel = flashattn(batch, heads, heads_kv, dim, dim_v)( - block_N=block_size, - block_H=self.block_H, - num_split=T.dynamic("num_split"), - num_stages=2, - threads=128, - max_cache_seqlen=T.dynamic("max_cache_seqlen"), - num_blocks=T.dynamic("num_blocks"), - ) - props = torch.cuda.get_device_properties(torch.device("cuda:0")) self.num_sm = props.multi_processor_count @@ -176,24 +170,33 @@ def forward(self, query, key, value, block_mask, cache_seqlens): dim_v = self.dim_v dim = self.dim block_size = self.block_size - block_H = self.block_H max_cache_seqlen = key.shape[1] # get num_split max_selected_blocks = (max_cache_seqlen + block_size - 1) // block_size - num_m_blocks = 1 * (heads // heads_kv + block_H - 1) // block_H + num_m_blocks = 1 * (heads // heads_kv + self.block_H - 1) // self.block_H num_n_blocks = max_selected_blocks size_one_kv_head = max_selected_blocks * block_size * (dim + dim_v) * 2 # kv_seqlen * (dim + dim_v) * 2 total_mblocks = batch * heads_kv * num_m_blocks - # num_sm = 132 num_sm = self.num_sm num_split = num_splits_heuristic( total_mblocks, num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, is_causal_or_local=True, max_splits=128 ) - # print("num_split: ", num_split) + glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device="cuda") - Output_partial = torch.empty((batch, heads, num_split, dim_v), dtype=torch.float32, device="cuda") - output = self.kernel(query, key, value, block_mask, cache_seqlens, glse, Output_partial) + output_partial = torch.empty((batch, heads, num_split, dim_v), dtype=torch.float32, device="cuda") + + output = flashattn( + batch, + heads, + heads_kv, + dim, + dim_v, + block_N=block_size, + block_H=self.block_H, + num_stages=2, + threads=128, + )(query, key, value, block_mask, cache_seqlens, glse, output_partial) return output @@ -233,21 +236,21 @@ def sparse_gqa_decode_varlen_mask(query, key, value, block_mask, cache_seqlens, total_mblocks, num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, is_causal_or_local=True, max_splits=128 ) - kernel = flashattn(batch, heads, heads_kv, dim, dim_v)( + glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device="cuda") + Output_partial = torch.empty((batch, heads, num_split, dim_v), dtype=torch.float32, device="cuda") + kernel = flashattn( + batch, + heads, + heads_kv, + dim, + dim_v, block_N=block_size, block_H=block_H, - num_split=T.dynamic("num_split"), num_stages=2, threads=128, - max_cache_seqlen=T.dynamic("max_cache_seqlen"), - num_blocks=T.dynamic("num_blocks"), ) - glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device="cuda") - Output_partial = torch.empty((batch, heads, num_split, dim_v), dtype=torch.float32, device="cuda") - # print(kernel.get_kernel_source()) output = kernel(query, key, value, block_mask, cache_seqlens, glse, Output_partial) - return output @@ -297,12 +300,10 @@ def ref_program_fa(query, key, value, block_indices, cache_seqlens, max_cache_se return output -def debug(name, expect, actual, atol=1e-3, rtol=1e-3): +def assert_close(name, expect, actual, atol=1e-3, rtol=1e-3): all_close = torch.allclose(expect, actual, atol=atol, rtol=rtol) print(name + " all_close={}".format(all_close)) if not all_close: - # print(expect[3, 28]) - # 
print(actual[3, 28]) diff = (expect - actual).abs() print("all_close={}, max={}, min={}, mean={}".format(all_close, diff.max().item(), diff.min().item(), diff.mean().item())) max_indices = torch.nonzero(diff == diff.max().item()) @@ -353,7 +354,7 @@ def main(batch=8, heads=32, heads_kv=8, max_cache_seqlen=8192, dim=128, dim_v=12 # out = sparse_gqa_decode_varlen_mask(Q, K, V, block_mask, cache_seqlens, block_size) model = SparseFlashAttn(batch, heads, heads_kv, dim, dim_v, block_size) out = model(Q, K, V, block_mask, cache_seqlens) - debug("output", ref, out, atol=1e-3, rtol=1e-3) + assert_close("output", ref, out, atol=1e-3, rtol=1e-3) import flash_attn # noqa: F401 @@ -381,12 +382,13 @@ def main(batch=8, heads=32, heads_kv=8, max_cache_seqlen=8192, dim=128, dim_v=12 def run_regression_perf(batch=8, heads=32, heads_kv=8, max_cache_seqlen=8192, dim=128, dim_v=128, sparse_ratio=0.8, block_size=32): + torch.manual_seed(42) + torch.cuda.manual_seed_all(42) batch, heads, heads_kv, max_cache_seqlen, dim, dim_v = batch, heads, heads_kv, max_cache_seqlen, dim, dim_v sparse_ratio = sparse_ratio block_size = block_size max_selected_blocks = int(math.ceil(max_cache_seqlen * (1 - sparse_ratio) / block_size)) dtype = torch.float16 - Q = torch.randn((batch, heads, dim), dtype=dtype, device="cuda") K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device="cuda") V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device="cuda") @@ -408,31 +410,41 @@ def run_regression_perf(batch=8, heads=32, heads_kv=8, max_cache_seqlen=8192, di perm = torch.randperm(max_valid_block, device="cuda")[:valid_num_block] block_mask[b, h, perm] = True - model = SparseFlashAttn(batch, heads, heads_kv, dim, dim_v, block_size) - batch = model.batch - heads = model.heads - heads_kv = model.heads_kv - dim_v = model.dim_v - dim = model.dim - block_size = model.block_size - block_H = model.block_H - max_cache_seqlen = K.shape[1] + sparse_kernel = SparseFlashAttn(batch, heads, heads_kv, dim, dim_v, block_size) + batch = sparse_kernel.batch + heads = sparse_kernel.heads + heads_kv = sparse_kernel.heads_kv + dim_v = sparse_kernel.dim_v + dim = sparse_kernel.dim + block_size = sparse_kernel.block_size max_selected_blocks = (max_cache_seqlen + block_size - 1) // block_size - num_m_blocks = 1 * (heads // heads_kv + block_H - 1) // block_H - num_n_blocks = max_selected_blocks + num_m_blocks = 1 * (heads // heads_kv + sparse_kernel.block_H - 1) // sparse_kernel.block_H + num_n_blocks = max_selected_blocks size_one_kv_head = max_selected_blocks * block_size * (dim + dim_v) * 2 total_mblocks = batch * heads_kv * num_m_blocks - num_sm = model.num_sm + num_sm = sparse_kernel.num_sm + num_split = num_splits_heuristic( total_mblocks, num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, is_causal_or_local=True, max_splits=128 ) + glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device="cuda") - Output_partial = torch.empty((batch, heads, num_split, dim_v), dtype=torch.float32, device="cuda") - kernel = model.kernel + output_partial = torch.empty((batch, heads, num_split, dim_v), dtype=torch.float32, device="cuda") + kernel = flashattn( + batch, + heads, + heads_kv, + dim, + dim_v, + block_N=block_size, + block_H=sparse_kernel.block_H, + num_stages=2, + threads=128, + ) def run_kernel_only(): - kernel(Q, K, V, block_mask, cache_seqlens, glse, Output_partial) + kernel(Q, K, V, block_mask, cache_seqlens, glse, output_partial) return do_bench(run_kernel_only, backend="cupti") diff --git 
a/src/op/atomic_add.cc b/src/op/atomic_add.cc index 538f59fa9..8c726353c 100644 --- a/src/op/atomic_add.cc +++ b/src/op/atomic_add.cc @@ -15,7 +15,6 @@ #include "../target/utils.h" #include "../transform/atomicadd_vectorize.h" #include "../transform/common/loop_fusion_utils.h" -#include "../transform/common/loop_parallel_transform_utils.h" #include "../transform/loop_partition.h" #include "builtin.h" @@ -658,8 +657,6 @@ Stmt AtomicAddNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { } auto simt_loop = MakeSIMTLoop(analyzer); auto fused_loop = Downcast(ParallelLoopFuser::Fuse(simt_loop)); - auto transformed_loop = - Downcast(ParallelLoopTransformer::Substitute(fused_loop)); auto GetArchInt = [&](const Target &tgt) -> int { int arch_int = 0; @@ -785,12 +782,12 @@ Stmt AtomicAddNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { return {loop_layout, pred}; }; - auto ret = AtomicAddInferLayout(transformed_loop, - {T.target, T.thread_bounds, T.layout_map, - analyzer, false, T.buffer_remap}); + auto ret = + AtomicAddInferLayout(fused_loop, {T.target, T.thread_bounds, T.layout_map, + analyzer, false, T.buffer_remap}); Fragment loop_layout = ret.loop_layout; auto thread_loop = - PartitionLoop(transformed_loop, T.thread_var, analyzer, loop_layout); + PartitionLoop(fused_loop, T.thread_var, analyzer, loop_layout); auto vectorized_thread_loop = VectorizeAtomicAdd(thread_loop, GetArchInt(target)); return vectorized_thread_loop; diff --git a/src/op/copy.cc b/src/op/copy.cc index 7f91d4c38..981fd1a27 100644 --- a/src/op/copy.cc +++ b/src/op/copy.cc @@ -12,7 +12,6 @@ #include "../layout/tcgen05_layout.h" #include "../target/utils.h" #include "../transform/common/loop_fusion_utils.h" -#include "../transform/common/loop_parallel_transform_utils.h" #include "../transform/loop_partition.h" #include "../transform/loop_vectorize.h" #include "utils.h" @@ -716,11 +715,8 @@ Stmt CopyNode::LowerNormalCopy(const LowerArgs &T, auto simt_loop = MakeSIMTLoop(analyzer); auto fused_loop = Downcast(ParallelLoopFuser::Fuse(simt_loop)); - auto transformed_loop = - Downcast(ParallelLoopTransformer::Substitute(fused_loop)); - For vectorized_thread_loop; - auto par_op = ParallelOp(transformed_loop); + auto par_op = ParallelOp(fused_loop); if (is_cpu_target || IsLocalBuffer(src) || IsLocalBuffer(dst)) { if (IsLocalBuffer(src) && !IsLocalBuffer(dst)) { @@ -728,7 +724,7 @@ Stmt CopyNode::LowerNormalCopy(const LowerArgs &T, << dst.scope() << " buffer `" << dst->name << "` may cause conflicted write."; } - vectorized_thread_loop = VectorizeLoop(transformed_loop); + vectorized_thread_loop = VectorizeLoop(fused_loop); return vectorized_thread_loop; } else { std::vector levels = {InferLevel::kCommon, InferLevel::kStrict, diff --git a/src/op/fill.cc b/src/op/fill.cc index 02962d242..6a1768668 100644 --- a/src/op/fill.cc +++ b/src/op/fill.cc @@ -13,7 +13,6 @@ #include "../layout/tcgen05_layout.h" #include "../target/utils.h" #include "../transform/common/loop_fusion_utils.h" -#include "../transform/common/loop_parallel_transform_utils.h" #include "../transform/loop_partition.h" #include "../transform/loop_vectorize.h" #include "builtin.h" diff --git a/src/transform/common/loop_parallel_transform_utils.h b/src/transform/common/loop_parallel_transform_utils.h deleted file mode 100644 index 52a5a9b97..000000000 --- a/src/transform/common/loop_parallel_transform_utils.h +++ /dev/null @@ -1,170 +0,0 @@ -/*! 
- * \file common.h - * \brief Common utilities for TL transforms - */ - -#include -#include - -#include -#include -#include -#include -#include -#include - -#include "arith/ir_mutator_with_analyzer.h" -#include "arith/ir_visitor_with_analyzer.h" -#include - -#include "../../op/utils.h" - -namespace tvm { -namespace tl { - -using namespace tir; -using arith::IRMutatorWithAnalyzer; -using arith::IRVisitorWithAnalyzer; - -class ParallelLoopTransformer : public IRMutatorWithAnalyzer { -public: - static Stmt Substitute(const Stmt &stmt, bool skip_thread_partition = false) { - arith::Analyzer analyzer; - ParallelLoopTransformer transformer(&analyzer); - return transformer.VisitStmt(stmt); - } - - ParallelLoopTransformer(arith::Analyzer *analyzer) - : IRMutatorWithAnalyzer(analyzer) {} - - Stmt VisitStmt_(const ForNode *op) final { - - if (op->kind != ForKind::kParallel) - return StmtMutator::VisitStmt_(op); - - // Collect loop variables and ranges - auto for_node = tvm::ffi::GetRef(op); - Array loop_vars; - Array loop_extents; - Stmt body = op->body; - - // Bind the range of outer loop variables - analyzer_->Bind(op->loop_var, Range::FromMinExtent(0, op->extent)); - loop_vars.push_back(op->loop_var); - loop_extents.push_back(op->extent); - - // If there are inner loops, bind their ranges as well - while (const ForNode *inner = body.as()) { - analyzer_->Bind(inner->loop_var, Range::FromMinExtent(0, inner->extent)); - loop_vars.push_back(inner->loop_var); - loop_extents.push_back(inner->extent); - body = inner->body; - } - - ICHECK(loop_vars.size() == loop_extents.size()) - << "loop_vars and loop_extents size mismatch"; - - // Collect buffer access information - BufferAccessCollector collector; - collector(op->body); - - PrimExpr condition; - - for (const auto &[buffer, indices] : collector.buffer_indices) { - ICHECK(indices.size() == buffer->shape.size()) - << "indices size mismatch with buffer shape"; - - for (size_t i = 0; i < indices.size(); ++i) { - auto index = indices[i]; - auto bound = analyzer_->const_int_bound(index); - - // Collect the variables that used in the index - std::unordered_set used_vars; - // post order visit the index - PostOrderVisit(index, [&](const ObjectRef &obj) { - if (const VarNode *v = obj.as()) { - used_vars.insert(tvm::ffi::GetRef(v)); - } - }); - if (used_vars.empty()) { - continue; - } - - // find related loop vars - Array related_loop_vars; - for (size_t j = 0; j < loop_vars.size(); ++j) { - auto loop_var = loop_vars[j]; - // if find related, pop the loop_vars and loop_extents - if (used_vars.count(loop_var)) { - related_loop_vars.push_back(loop_var); - } - if (related_loop_vars.size() > 1) { - // Only one related loop var is supported transformation currently. - return for_node; - } - - auto bound = analyzer_->const_int_bound(index); - int64_t upper_bound = bound->max_value + 1; - int64_t shape = Downcast(buffer->shape[i])->value; - if (upper_bound < shape) { - PrimExpr predicate = LT(index, IntImm(index.dtype(), upper_bound)); - condition = - condition.defined() ? 
And(condition, predicate) : predicate; - } - } - } - } - - if (condition.defined()) { - body = IfThenElse(condition, body); - - for (int j = loop_vars.size() - 1; j >= 0; --j) { - auto loop_var = loop_vars[j]; - auto loop_extent = loop_extents[j]; - body = For(loop_var, 0, loop_extent, ForKind::kParallel, body); - } - - return Downcast(body); - } - - // Only traverse the outer loop - return for_node; - } - - // Helper class for collecting buffer access information, only counts fragment - // buffer access - class BufferAccessCollector : public StmtExprVisitor { - public: - void VisitExpr_(const BufferLoadNode *op) final { - if (IsFragmentBuffer(op->buffer)) { - if (buffer_indices.find(op->buffer) == buffer_indices.end()) { - buffer_indices[op->buffer] = op->indices; - } else { - // check equal - ICHECK(StructuralEqual()(buffer_indices[op->buffer], op->indices)) - << "indices mismatch for buffer: " << op->buffer; - } - } - StmtExprVisitor::VisitExpr_(op); - } - - void VisitStmt_(const BufferStoreNode *op) final { - if (IsFragmentBuffer(op->buffer)) { - if (buffer_indices.find(op->buffer) == buffer_indices.end()) { - buffer_indices[op->buffer] = op->indices; - } else { - // check equal - ICHECK(StructuralEqual()(buffer_indices[op->buffer], op->indices)) - << "indices mismatch for buffer: " << op->buffer; - } - } - StmtExprVisitor::VisitStmt_(op); - } - - std::unordered_map, ObjectPtrHash, ObjectPtrEqual> - buffer_indices; - }; -}; - -} // namespace tl -} // namespace tvm diff --git a/src/transform/layout_inference.cc b/src/transform/layout_inference.cc index 15d7f71e2..a622d71f4 100644 --- a/src/transform/layout_inference.cc +++ b/src/transform/layout_inference.cc @@ -26,7 +26,6 @@ #include "arith/ir_mutator_with_analyzer.h" #include "arith/ir_visitor_with_analyzer.h" #include "common/loop_fusion_utils.h" -#include "common/loop_parallel_transform_utils.h" #include "common/union_find.h" #include "layout_reducer.h" #include "parallel_loop_layout_validator.h" @@ -1253,7 +1252,6 @@ class LayoutInferencer : public IRMutatorWithAnalyzer { tvm::transform::Pass LayoutInference() { using namespace tir::transform; auto pass_func = [=](PrimFunc f, const IRModule &m, const PassContext &ctx) { - f.CopyOnWrite()->body = ParallelLoopTransformer::Substitute(f->body); ThreadBindingCollector collector; collector(f->body); bool has_thread_binding = !collector.thread_binding_.empty();
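Usage note (illustrative, not part of the patch): after this refactor the inner `kernel_func(...)` closures are gone, so callers compile and invoke the `@tilelang.jit`-decorated `flashattn` directly, passing the tile parameters (`block_N`, `block_H`, `num_stages`, `threads`, ...) as ordinary arguments, while `num_split`, `max_cache_seqlen`/`max_num_blocks_per_seq`, and `max_selected_blocks` are resolved as `T.dynamic` shapes from the runtime tensors. A minimal sketch for the varlen-indice variant, assuming the example module is importable under its file name and using made-up shapes and a fixed `num_split` (the examples derive it from `num_splits_heuristic`):

```python
import torch

# Hypothetical import path; the example defines flashattn at module scope.
from example_tilelang_sparse_gqa_decode_varlen_indice import flashattn

batch, heads, heads_kv, dim, dim_v = 8, 32, 8, 128, 128
block_size, num_split, max_cache_seqlen = 32, 4, 2048
num_blocks = max_cache_seqlen // block_size          # 64 candidate KV blocks per head
max_selected_blocks = num_blocks // 2                # keep half of them (illustrative sparsity)

dtype = torch.float16
Q = torch.randn((batch, heads, dim), dtype=dtype, device="cuda")
K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device="cuda")
V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device="cuda")
cache_seqlens = torch.full((batch,), max_cache_seqlen, dtype=torch.int32, device="cuda")

# Selected block ids per (batch, kv_head), sorted in descending order as the kernel
# assumes; unused slots would be padded with -1.
ids = torch.arange(num_blocks - 1, num_blocks - 1 - max_selected_blocks, -1,
                   dtype=torch.int32, device="cuda")
block_indices = ids.expand(batch, heads_kv, max_selected_blocks).contiguous()

glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device="cuda")
output_partial = torch.empty((batch, heads, num_split, dim_v), dtype=torch.float32, device="cuda")

# Repeated calls with the same static arguments reuse the cached compiled kernel,
# which is why forward() in the examples can call flashattn(...) on every invocation.
kernel = flashattn(batch, heads, heads_kv, dim, dim_v,
                   block_N=block_size, block_H=64, num_stages=2, threads=128)
# out_idx=[-1] lets the JIT allocate and return the final Output tensor.
output = kernel(Q, K, V, block_indices, cache_seqlens, glse, output_partial)
```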