diff --git a/benchmark/blocksparse_attention/benchmark_tilelang_block_sparse_fmha.py b/benchmark/blocksparse_attention/benchmark_tilelang_block_sparse_fmha.py index e645ae147..0018e9c93 100644 --- a/benchmark/blocksparse_attention/benchmark_tilelang_block_sparse_fmha.py +++ b/benchmark/blocksparse_attention/benchmark_tilelang_block_sparse_fmha.py @@ -138,22 +138,21 @@ def main( scores_scale = T.alloc_fragment([block_M], accum_dtype) scores_sum = T.alloc_fragment([block_M], accum_dtype) logsum = T.alloc_fragment([block_M], accum_dtype) - block_mask = T.alloc_local([downsample_len], block_mask_dtype) + block_mask = T.alloc_fragment([downsample_len], block_mask_dtype) T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) - for vj in T.serial(downsample_len): - block_mask[vj] = BlockSparseMask[bz, by, bx, vj] + T.copy(BlockSparseMask[bz, by, bx, :], block_mask) loop_range = ( T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N) ) for k in T.Pipelined(loop_range, num_stages=num_stages): - if block_mask[k]: + if block_mask[k] != 0: MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz) Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum) Rescale(acc_o, scores_scale) diff --git a/examples/attention_sink/example_gqa_sink_bwd_bhsd.py b/examples/attention_sink/example_gqa_sink_bwd_bhsd.py index 2f53499e2..7422105a5 100644 --- a/examples/attention_sink/example_gqa_sink_bwd_bhsd.py +++ b/examples/attention_sink/example_gqa_sink_bwd_bhsd.py @@ -321,16 +321,15 @@ def flash_bwd_dsink( dsinks: T.Tensor(shape, dtype), # type: ignore ): with T.Kernel(heads, T.ceildiv(seq_len, block), batch, threads=256) as (bx, by, bz): - sink = T.alloc_local([1], dtype) lse_fragment = T.alloc_fragment([block], accum_dtype) delta_fragment = T.alloc_fragment([block], accum_dtype) dsink_fragment = T.alloc_fragment([block], dtype) - sink[0] = Sinks[bx] + sink = Sinks[bx] T.copy(lse[bz, bx, by * block : (by + 1) * block], lse_fragment) T.copy(Delta[bz, bx, by * block : (by + 1) * block], delta_fragment) for i in T.Parallel(block): - dsink_fragment[i] = -T.exp2(Sinks[bx] * 1.44269504 - lse_fragment[i]) * delta_fragment[i] + dsink_fragment[i] = -T.exp2(sink * 1.44269504 - lse_fragment[i]) * delta_fragment[i] T.copy(dsink_fragment, dsinks[bz, bx, by * block : (by + 1) * block]) return flash_bwd_dsink diff --git a/examples/attention_sink/example_mha_sink_bwd_bhsd.py b/examples/attention_sink/example_mha_sink_bwd_bhsd.py index a943be25f..c16195e87 100644 --- a/examples/attention_sink/example_mha_sink_bwd_bhsd.py +++ b/examples/attention_sink/example_mha_sink_bwd_bhsd.py @@ -327,16 +327,15 @@ def flash_bwd_dsink( dsinks: T.Tensor(shape, accum_dtype), # type: ignore ): with T.Kernel(heads, T.ceildiv(seq_len, block), batch, threads=128) as (bx, by, bz): - sink = T.alloc_local([1], dtype) lse_fragment = T.alloc_fragment([block], accum_dtype) delta_fragment = T.alloc_fragment([block], accum_dtype) dsink_fragment = T.alloc_fragment([block], accum_dtype) - sink[0] = Sinks[bx] + sink = Sinks[bx] T.copy(lse[bz, bx, by * block : (by + 1) * block], lse_fragment) T.copy(Delta[bz, bx, by * block : (by + 1) * block], delta_fragment) for i in T.Parallel(block): - dsink_fragment[i] = -T.exp2(Sinks[bx] * 1.44269504 - lse_fragment[i]) * delta_fragment[i] + dsink_fragment[i] = -T.exp2(sink * 1.44269504 - lse_fragment[i]) * delta_fragment[i] T.copy(dsink_fragment, 
dsinks[bz, bx, by * block : (by + 1) * block]) return flash_bwd_dsink diff --git a/examples/blocksparse_attention/example_tilelang_block_sparse_attn.py b/examples/blocksparse_attention/example_tilelang_block_sparse_attn.py index ca58a8217..88ec23172 100644 --- a/examples/blocksparse_attention/example_tilelang_block_sparse_attn.py +++ b/examples/blocksparse_attention/example_tilelang_block_sparse_attn.py @@ -137,15 +137,14 @@ def blocksparse_flashattn( scores_scale = T.alloc_fragment([block_M], accum_dtype) scores_sum = T.alloc_fragment([block_M], accum_dtype) logsum = T.alloc_fragment([block_M], accum_dtype) - block_mask = T.alloc_local([downsample_len], block_mask_dtype) + block_mask = T.alloc_fragment([downsample_len], block_mask_dtype) T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) - for vj in T.serial(downsample_len): - block_mask[vj] = BlockSparseMask[bz, by, bx, vj] + T.copy(BlockSparseMask[bz, by, bx, :], block_mask) loop_range = ( T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N) diff --git a/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_paged.py b/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_paged.py index 7154a362e..3556bcee4 100644 --- a/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_paged.py +++ b/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_paged.py @@ -136,40 +136,34 @@ def combine( with T.Kernel(heads, batch, threads=128) as (by, bz): po_local = T.alloc_fragment([dim_v], accum_dtype) o_accum_local = T.alloc_fragment([dim_v], accum_dtype) - lse_local_split = T.alloc_local([1], accum_dtype) - lse_logsum_local = T.alloc_local([1], accum_dtype) - lse_max_local = T.alloc_local([1], accum_dtype) - scale_local = T.alloc_local([1], accum_dtype) - max_split = T.alloc_local([1], T.int32) - - T.annotate_layout( - { - lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i), - } - ) + lse_local_split = T.alloc_var(accum_dtype) + lse_logsum_local = T.alloc_var(accum_dtype) + lse_max_local = T.alloc_var(accum_dtype) + scale_local = T.alloc_var(accum_dtype) + max_split = T.alloc_var(T.int32) T.clear(lse_logsum_local) T.clear(o_accum_local) - lse_max_local[0] = -T.infinity(accum_dtype) + lse_max_local = -T.infinity(accum_dtype) for k in T.serial(num_split): - lse_local_split[0] = glse[bz, by, k] - if lse_local_split[0] != 0: - max_split[0] = k - lse_max_local[0] = T.max(lse_max_local[0], glse[bz, by, k]) + lse_local_split = glse[bz, by, k] + if lse_local_split != 0: + max_split = k + lse_max_local = T.max(lse_max_local, glse[bz, by, k]) for k in T.Pipelined(num_split, num_stages=1): - if k <= max_split[0]: - lse_local_split[0] = glse[bz, by, k] - lse_logsum_local[0] += T.exp2(lse_local_split[0] - lse_max_local[0]) - lse_logsum_local[0] = T.log2(lse_logsum_local[0]) + lse_max_local[0] + if k <= max_split: + lse_local_split = glse[bz, by, k] + lse_logsum_local += T.exp2(lse_local_split - lse_max_local) + lse_logsum_local = T.log2(lse_logsum_local) + lse_max_local for k in T.serial(num_split): - if k <= max_split[0]: + if k <= max_split: for i in T.Parallel(dim_v): po_local[i] = Output_partial[bz, by, k, i] - lse_local_split[0] = glse[bz, by, k] - scale_local[0] = T.exp2(lse_local_split[0] - lse_logsum_local[0]) + lse_local_split = glse[bz, by, k] + scale_local = T.exp2(lse_local_split - lse_logsum_local) for i in 
T.Parallel(dim_v): - o_accum_local[i] += po_local[i] * scale_local[0] + o_accum_local[i] += po_local[i] * scale_local for i in T.Parallel(dim_v): Output[bz, by, i] = o_accum_local[i] diff --git a/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py b/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py index 591560e4c..f2abe6194 100644 --- a/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py +++ b/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py @@ -125,40 +125,34 @@ def combine( with T.Kernel(heads, batch, threads=128) as (by, bz): po_local = T.alloc_fragment([dim_v], accum_dtype) o_accum_local = T.alloc_fragment([dim_v], accum_dtype) - lse_local_split = T.alloc_local([1], accum_dtype) - lse_logsum_local = T.alloc_local([1], accum_dtype) - lse_max_local = T.alloc_local([1], accum_dtype) - scale_local = T.alloc_local([1], accum_dtype) - max_split = T.alloc_local([1], T.int32) - - T.annotate_layout( - { - lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i), - } - ) + lse_local_split = T.alloc_var(accum_dtype) + lse_logsum_local = T.alloc_var(accum_dtype) + lse_max_local = T.alloc_var(accum_dtype) + scale_local = T.alloc_var(accum_dtype) + max_split = T.alloc_var(T.int32) T.clear(lse_logsum_local) T.clear(o_accum_local) - lse_max_local[0] = -T.infinity(accum_dtype) + lse_max_local = -T.infinity(accum_dtype) for k in T.serial(num_split): - lse_local_split[0] = glse[bz, by, k] - if lse_local_split[0] != 0: - max_split[0] = k - lse_max_local[0] = T.max(lse_max_local[0], glse[bz, by, k]) + lse_local_split = glse[bz, by, k] + if lse_local_split != 0: + max_split = k + lse_max_local = T.max(lse_max_local, glse[bz, by, k]) for k in T.Pipelined(num_split, num_stages=1): - if k <= max_split[0]: - lse_local_split[0] = glse[bz, by, k] - lse_logsum_local[0] += T.exp2(lse_local_split[0] - lse_max_local[0]) - lse_logsum_local[0] = T.log2(lse_logsum_local[0]) + lse_max_local[0] + if k <= max_split: + lse_local_split = glse[bz, by, k] + lse_logsum_local += T.exp2(lse_local_split - lse_max_local) + lse_logsum_local = T.log2(lse_logsum_local) + lse_max_local for k in T.serial(num_split): - if k <= max_split[0]: + if k <= max_split: for i in T.Parallel(dim_v): po_local[i] = Output_partial[bz, by, k, i] - lse_local_split[0] = glse[bz, by, k] - scale_local[0] = T.exp2(lse_local_split[0] - lse_logsum_local[0]) + lse_local_split = glse[bz, by, k] + scale_local = T.exp2(lse_local_split - lse_logsum_local) for i in T.Parallel(dim_v): - o_accum_local[i] += po_local[i] * scale_local[0] + o_accum_local[i] += po_local[i] * scale_local for i in T.Parallel(dim_v): Output[bz, by, i] = o_accum_local[i] diff --git a/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py b/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py index 75b60b9d5..112eddcef 100644 --- a/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py +++ b/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py @@ -121,33 +121,27 @@ def combine( with T.Kernel(heads, batch, threads=128) as (by, bz): po_local = T.alloc_fragment([dim_v], accum_dtype) o_accum_local = T.alloc_fragment([dim_v], accum_dtype) - lse_local_split = T.alloc_local([1], accum_dtype) - lse_logsum_local = T.alloc_local([1], accum_dtype) - lse_max_local = T.alloc_local([1], accum_dtype) - scale_local = T.alloc_local([1], accum_dtype) 
- - T.annotate_layout( - { - lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i), - } - ) + lse_local_split = T.alloc_var(accum_dtype) + lse_logsum_local = T.alloc_var(accum_dtype) + lse_max_local = T.alloc_var(accum_dtype) + scale_local = T.alloc_var(accum_dtype) T.clear(lse_logsum_local) T.clear(o_accum_local) - lse_max_local[0] = -T.infinity(accum_dtype) + lse_max_local = -T.infinity(accum_dtype) for k in T.serial(num_split): - lse_max_local[0] = T.max(lse_max_local[0], glse[bz, by, k]) + lse_max_local = T.max(lse_max_local, glse[bz, by, k]) for k in T.Pipelined(num_split, num_stages=1): - lse_local_split[0] = glse[bz, by, k] - lse_logsum_local[0] += T.exp2(lse_local_split[0] - lse_max_local[0]) - lse_logsum_local[0] = T.log2(lse_logsum_local[0]) + lse_max_local[0] + lse_local_split = glse[bz, by, k] + lse_logsum_local += T.exp2(lse_local_split - lse_max_local) + lse_logsum_local = T.log2(lse_logsum_local) + lse_max_local for k in T.serial(num_split): for i in T.Parallel(dim_v): po_local[i] = Output_partial[bz, by, k, i] - lse_local_split[0] = glse[bz, by, k] - scale_local[0] = T.exp2(lse_local_split[0] - lse_logsum_local[0]) + lse_local_split = glse[bz, by, k] + scale_local = T.exp2(lse_local_split - lse_logsum_local) for i in T.Parallel(dim_v): - o_accum_local[i] += po_local[i] * scale_local[0] + o_accum_local[i] += po_local[i] * scale_local for i in T.Parallel(dim_v): Output[bz, by, i] = o_accum_local[i] diff --git a/examples/deepseek_mla/amd/benchmark_mla_decode_amd_tilelang.py b/examples/deepseek_mla/amd/benchmark_mla_decode_amd_tilelang.py index a9035793b..9f44db83b 100644 --- a/examples/deepseek_mla/amd/benchmark_mla_decode_amd_tilelang.py +++ b/examples/deepseek_mla/amd/benchmark_mla_decode_amd_tilelang.py @@ -173,31 +173,25 @@ def combine( with T.Kernel(heads, batch, threads=128) as (by, bz): po_local = T.alloc_fragment([dim], dtype) o_accum_local = T.alloc_fragment([dim], accum_dtype) - lse_local_split = T.alloc_local([1], accum_dtype) - lse_logsum_local = T.alloc_local([1], accum_dtype) - lse_max_local = T.alloc_local([1], accum_dtype) - scale_local = T.alloc_local([1], accum_dtype) - - T.annotate_layout( - { - lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i), - } - ) + lse_local_split = T.alloc_var(accum_dtype) + lse_logsum_local = T.alloc_var(accum_dtype) + lse_max_local = T.alloc_var(accum_dtype) + scale_local = T.alloc_var(accum_dtype) T.clear(lse_logsum_local) T.clear(o_accum_local) - lse_max_local[0] = -T.infinity(accum_dtype) + lse_max_local = -T.infinity(accum_dtype) for k in T.serial(num_split): - lse_max_local[0] = T.max(lse_max_local[0], glse[bz, by, k]) + lse_max_local = T.max(lse_max_local, glse[bz, by, k]) for k in T.Pipelined(num_split, num_stages=1): - lse_local_split[0] = glse[bz, by, k] - lse_logsum_local[0] += T.exp2(lse_local_split[0] - lse_max_local[0]) - lse_logsum_local[0] = T.log2(lse_logsum_local[0]) + lse_max_local[0] + lse_local_split = glse[bz, by, k] + lse_logsum_local += T.exp2(lse_local_split - lse_max_local) + lse_logsum_local = T.log2(lse_logsum_local) + lse_max_local for k in T.serial(num_split): for i in T.Parallel(dim): po_local[i] = Output_partial[bz, by, k, i] - lse_local_split[0] = glse[bz, by, k] - scale_local[0] = T.exp2(lse_local_split[0] - lse_logsum_local[0]) + lse_local_split = glse[bz, by, k] + scale_local = T.exp2(lse_local_split - lse_logsum_local) for i in T.Parallel(dim): - o_accum_local[i] += po_local[i] * scale_local[0] + o_accum_local[i] += po_local[i] * scale_local for i in T.Parallel(dim): diff --git
a/examples/deepseek_mla/example_mla_decode.py b/examples/deepseek_mla/example_mla_decode.py index 6d3e659b0..7fab4c062 100644 --- a/examples/deepseek_mla/example_mla_decode.py +++ b/examples/deepseek_mla/example_mla_decode.py @@ -167,33 +167,27 @@ def combine( with T.Kernel(heads, batch, threads=128) as (hid, bz): po_local = T.alloc_fragment([dim], dtype) o_accum_local = T.alloc_fragment([dim], accum_dtype) - lse_local_split = T.alloc_local([1], accum_dtype) - lse_logsum_local = T.alloc_local([1], accum_dtype) - lse_max_local = T.alloc_local([1], accum_dtype) - scale_local = T.alloc_local([1], accum_dtype) - - T.annotate_layout( - { - lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i), - } - ) + lse_local_split = T.alloc_var(accum_dtype) + lse_logsum_local = T.alloc_var(accum_dtype) + lse_max_local = T.alloc_var(accum_dtype) + scale_local = T.alloc_var(accum_dtype) T.clear(lse_logsum_local) T.clear(o_accum_local) - lse_max_local[0] = -T.infinity(accum_dtype) + lse_max_local = -T.infinity(accum_dtype) for k in T.serial(num_split): - lse_max_local[0] = T.max(lse_max_local[0], glse[bz, hid, k]) + lse_max_local = T.max(lse_max_local, glse[bz, hid, k]) for k in T.Pipelined(num_split, num_stages=1): - lse_local_split[0] = glse[bz, hid, k] - lse_logsum_local[0] += T.exp2(lse_local_split[0] - lse_max_local[0]) - lse_logsum_local[0] = T.log2(lse_logsum_local[0]) + lse_max_local[0] + lse_local_split = glse[bz, hid, k] + lse_logsum_local += T.exp2(lse_local_split - lse_max_local) + lse_logsum_local = T.log2(lse_logsum_local) + lse_max_local for k in T.serial(num_split): for i in T.Parallel(dim): po_local[i] = Output_partial[bz, hid, k, i] - lse_local_split[0] = glse[bz, hid, k] - scale_local[0] = T.exp2(lse_local_split[0] - lse_logsum_local[0]) + lse_local_split = glse[bz, hid, k] + scale_local = T.exp2(lse_local_split - lse_logsum_local) for i in T.Parallel(dim): - o_accum_local[i] += po_local[i] * scale_local[0] + o_accum_local[i] += po_local[i] * scale_local for i in T.Parallel(dim): Output[bz, hid, i] = o_accum_local[i] diff --git a/examples/deepseek_mla/example_mla_decode_paged.py b/examples/deepseek_mla/example_mla_decode_paged.py index 23001bde8..da22ada96 100644 --- a/examples/deepseek_mla/example_mla_decode_paged.py +++ b/examples/deepseek_mla/example_mla_decode_paged.py @@ -188,33 +188,27 @@ def combine( with T.Kernel(h_q, batch, threads=128) as (by, bz): po_local = T.alloc_fragment([dv], dtype) o_accum_local = T.alloc_fragment([dv], accum_dtype) - lse_local_split = T.alloc_local([1], accum_dtype) - lse_logsum_local = T.alloc_local([1], accum_dtype) - lse_max_local = T.alloc_local([1], accum_dtype) - scale_local = T.alloc_local([1], accum_dtype) - - T.annotate_layout( - { - lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i), - } - ) + lse_local_split = T.alloc_var(accum_dtype) + lse_logsum_local = T.alloc_var(accum_dtype) + lse_max_local = T.alloc_var(accum_dtype) + scale_local = T.alloc_var(accum_dtype) T.clear(lse_logsum_local) T.clear(o_accum_local) - lse_max_local[0] = -T.infinity(accum_dtype) + lse_max_local = -T.infinity(accum_dtype) for k in T.serial(num_split): - lse_max_local[0] = T.max(lse_max_local[0], glse[bz, by, k]) + lse_max_local = T.max(lse_max_local, glse[bz, by, k]) for k in T.Pipelined(num_split, num_stages=1): - lse_local_split[0] = glse[bz, by, k] - lse_logsum_local[0] += T.exp2(lse_local_split[0] - lse_max_local[0]) - lse_logsum_local[0] = T.log2(lse_logsum_local[0]) + lse_max_local[0] + lse_local_split 
= glse[bz, by, k] + lse_logsum_local += T.exp2(lse_local_split - lse_max_local) + lse_logsum_local = T.log2(lse_logsum_local) + lse_max_local for k in T.serial(num_split): for i in T.Parallel(dv): po_local[i] = Output_partial[bz, by, k, i] - lse_local_split[0] = glse[bz, by, k] - scale_local[0] = T.exp2(lse_local_split[0] - lse_logsum_local[0]) + lse_local_split = glse[bz, by, k] + scale_local = T.exp2(lse_local_split - lse_logsum_local) for i in T.Parallel(dv): - o_accum_local[i] += po_local[i] * scale_local[0] + o_accum_local[i] += po_local[i] * scale_local for i in T.Parallel(dv): Output[bz, by, i] = o_accum_local[i] diff --git a/examples/deepseek_mla/example_mla_decode_persistent.py b/examples/deepseek_mla/example_mla_decode_persistent.py index b6a1300a2..b5e3c9f7c 100644 --- a/examples/deepseek_mla/example_mla_decode_persistent.py +++ b/examples/deepseek_mla/example_mla_decode_persistent.py @@ -50,16 +50,14 @@ def main_split_persistent( logsum = T.alloc_fragment([block_H], accum_dtype) po_local = T.alloc_fragment([dim], dtype) o_accum_local = T.alloc_fragment([dim], accum_dtype) - lse_local_split = T.alloc_local([1], accum_dtype) - lse_logsum_local = T.alloc_local([1], accum_dtype) - lse_max_local = T.alloc_local([1], accum_dtype) - scale_local = T.alloc_local([1], accum_dtype) + lse_local_split = T.alloc_var(accum_dtype) + lse_logsum_local = T.alloc_var(accum_dtype) + lse_max_local = T.alloc_var(accum_dtype) + scale_local = T.alloc_var(accum_dtype) T.annotate_layout( { - # O_shared: tilelang.layout.make_swizzled_layout(O_shared), S_shared: tilelang.layout.make_swizzled_layout(S_shared), - lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i), } ) T.use_swizzle(10) @@ -123,20 +121,20 @@ def main_split_persistent( if bid < batch and hid < heads: T.clear(lse_logsum_local) T.clear(o_accum_local) - lse_max_local[0] = -T.infinity(accum_dtype) + lse_max_local = -T.infinity(accum_dtype) for k in T.serial(num_split): - lse_max_local[0] = T.max(lse_max_local[0], glse[bid, hid, k]) + lse_max_local = T.max(lse_max_local, glse[bid, hid, k]) for k in T.Pipelined(num_split, num_stages=1): - lse_local_split[0] = glse[bid, hid, k] - lse_logsum_local[0] += T.exp2(lse_local_split[0] - lse_max_local[0]) - lse_logsum_local[0] = T.log2(lse_logsum_local[0]) + lse_max_local[0] + lse_local_split = glse[bid, hid, k] + lse_logsum_local += T.exp2(lse_local_split - lse_max_local) + lse_logsum_local = T.log2(lse_logsum_local) + lse_max_local for k in T.serial(num_split): for i in T.Parallel(dim): po_local[i] = Output_partial[bid, hid, k, i] - lse_local_split[0] = glse[bid, hid, k] - scale_local[0] = T.exp2(lse_local_split[0] - lse_logsum_local[0]) + lse_local_split = glse[bid, hid, k] + scale_local = T.exp2(lse_local_split - lse_logsum_local) for i in T.Parallel(dim): - o_accum_local[i] += po_local[i] * scale_local[0] + o_accum_local[i] += po_local[i] * scale_local for i in T.Parallel(dim): Output[bid, hid, i] = o_accum_local[i] diff --git a/examples/deepseek_mla/example_mla_decode_ws.py b/examples/deepseek_mla/example_mla_decode_ws.py index 8e317fa00..66651f7dc 100644 --- a/examples/deepseek_mla/example_mla_decode_ws.py +++ b/examples/deepseek_mla/example_mla_decode_ws.py @@ -472,31 +472,25 @@ def combine( with T.Kernel(heads, batch, threads=128) as (hid, bz): po_local = T.alloc_fragment([dim], dtype) o_accum_local = T.alloc_fragment([dim], accum_dtype) - lse_local_split = T.alloc_local([1], accum_dtype) - lse_logsum_local = T.alloc_local([1], accum_dtype) - lse_max_local = 
T.alloc_local([1], accum_dtype) - scale_local = T.alloc_local([1], accum_dtype) - - T.annotate_layout( - { - lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i), - } - ) + lse_local_split = T.alloc_var(accum_dtype) + lse_logsum_local = T.alloc_var(accum_dtype) + lse_max_local = T.alloc_var(accum_dtype) + scale_local = T.alloc_var(accum_dtype) T.clear(lse_logsum_local) T.clear(o_accum_local) - lse_max_local[0] = -T.infinity(accum_dtype) + lse_max_local = -T.infinity(accum_dtype) for k in T.serial(num_split): - lse_max_local[0] = T.max(lse_max_local[0], glse[bz, hid, k]) + lse_max_local = T.max(lse_max_local, glse[bz, hid, k]) for k in T.Pipelined(num_split, num_stages=1): - lse_local_split[0] = glse[bz, hid, k] - lse_logsum_local[0] += T.exp2(lse_local_split[0] - lse_max_local[0]) - lse_logsum_local[0] = T.log2(lse_logsum_local[0]) + lse_max_local[0] + lse_local_split = glse[bz, hid, k] + lse_logsum_local += T.exp2(lse_local_split - lse_max_local) + lse_logsum_local = T.log2(lse_logsum_local) + lse_max_local for k in T.serial(num_split): for i in T.Parallel(dim): po_local[i] = Output_partial[bz, hid, k, i] - lse_local_split[0] = glse[bz, hid, k] - scale_local[0] = T.exp2(lse_local_split[0] - lse_logsum_local[0]) + lse_local_split = glse[bz, hid, k] + scale_local = T.exp2(lse_local_split - lse_logsum_local) for i in T.Parallel(dim): - o_accum_local[i] += po_local[i] * scale_local[0] + o_accum_local[i] += po_local[i] * scale_local for i in T.Parallel(dim): diff --git a/examples/deepseek_v32/fp8_lighting_indexer.py b/examples/deepseek_v32/fp8_lighting_indexer.py index edb401d93..c91996f63 100644 --- a/examples/deepseek_v32/fp8_lighting_indexer.py +++ b/examples/deepseek_v32/fp8_lighting_indexer.py @@ -130,23 +130,23 @@ def mqa_attn_return_logits_kernel( seq_len_i = bx * block_Q - cu_k_s_min = T.alloc_local([1], index_dtype) - cu_k_e_max = T.alloc_local([1], index_dtype) + cu_k_s_min = T.alloc_var(index_dtype) + cu_k_e_max = T.alloc_var(index_dtype) - cu_k_s_min[0] = 2147483647 - cu_k_e_max[0] = -2147483648 + cu_k_s_min = 2147483647 + cu_k_e_max = -2147483648 for bq_i in T.serial(block_Q): - cu_k_s_min[0] = T.min(cu_k_s_min[0], T.min(CuSeqLenKS[seq_len_i + bq_i], seq_len_kv)) + cu_k_s_min = T.min(cu_k_s_min, T.min(CuSeqLenKS[seq_len_i + bq_i], seq_len_kv)) for bq_i in T.serial(block_Q): - cu_k_e_max[0] = T.max(cu_k_e_max[0], T.min(CuSeqLenKE[seq_len_i + bq_i], seq_len_kv)) + cu_k_e_max = T.max(cu_k_e_max, T.min(CuSeqLenKE[seq_len_i + bq_i], seq_len_kv)) T.copy(IndexQ[seq_len_i * heads, 0], index_q_shared) T.copy(Weights[seq_len_i, 0], weights) - for nbn_i in T.Pipelined(T.ceildiv(cu_k_e_max[0] - cu_k_s_min[0], block_N), num_stages=num_stages): - T.copy(IndexK[cu_k_s_min[0] + nbn_i * block_N, 0], index_k_shared) - T.copy(IndexKScale[cu_k_s_min[0] + nbn_i * block_N], index_k_scale_fragment) + for nbn_i in T.Pipelined(T.ceildiv(cu_k_e_max - cu_k_s_min, block_N), num_stages=num_stages): + T.copy(IndexK[cu_k_s_min + nbn_i * block_N, 0], index_k_shared) + T.copy(IndexKScale[cu_k_s_min + nbn_i * block_N], index_k_scale_fragment) T.gemm( index_k_shared, @@ -165,7 +165,7 @@ T.reduce_sum(s_reshaped, logits, dim=-1, clear=True) for bq_i, bn_i in T.Parallel(block_Q, block_N): - Logits[seq_len_i + bq_i, cu_k_s_min[0] + nbn_i * block_N + bn_i] = logits[bn_i, bq_i] + Logits[seq_len_i + bq_i, cu_k_s_min + nbn_i * block_N + bn_i] = logits[bn_i, bq_i] return mqa_attn_return_logits_kernel @@ -189,15 +189,13 @@ ): with T.Kernel(seq_len, threads=threads) as bx: tx =
T.thread_binding(0, threads, thread="threadIdx.x") - cu_k_s = T.alloc_local([1], indices_dtype) - cu_k_e = T.alloc_local([1], indices_dtype) - cu_k_s[0] = CuSeqLenKS[bx] - cu_k_e[0] = CuSeqLenKE[bx] + cu_k_s = CuSeqLenKS[bx] + cu_k_e = CuSeqLenKE[bx] for n_i in T.Pipelined(T.ceildiv(seq_len_kv, block_K)): for k_i in T.serial(block_K // threads): idx = n_i * block_K + k_i * threads + tx - if idx < cu_k_s[0] or idx >= cu_k_e[0]: + if idx < cu_k_s or idx >= cu_k_e: Logits[bx, idx] = -T.infinity(dtype) return clean_logits_kernel diff --git a/examples/deepseek_v32/sparse_mla_fwd_pipelined.py b/examples/deepseek_v32/sparse_mla_fwd_pipelined.py index 63763bfd4..8c29990fa 100644 --- a/examples/deepseek_v32/sparse_mla_fwd_pipelined.py +++ b/examples/deepseek_v32/sparse_mla_fwd_pipelined.py @@ -112,7 +112,7 @@ def main( alpha_local = T.alloc_fragment([H_per_block], accum_dtype) m_i = T.alloc_fragment([H_per_block], accum_dtype) m_i_prev = T.alloc_fragment([H_per_block], accum_dtype) - indices_local = T.alloc_local([1], indices_dtype) + indices_local = T.alloc_var(indices_dtype) # TODO: Multi buffer bar_q = T.alloc_barrier(arrive_count=384) @@ -263,44 +263,44 @@ def main( # Buffer 0 T.barrier_wait(bar_k_0_free[0], ((i_i & 1) ^ 1)) for r in T.serial(4): - indices_local[0] = Indices[b_i, s_i, g_i, (i_i * 2) * BI + r * 16 + (tx - 256) // 8] - is_kv_valid[r * 16 + (tx - 256) // 8] = indices_local[0] <= max_kv_i + indices_local = Indices[b_i, s_i, g_i, (i_i * 2) * BI + r * 16 + (tx - 256) // 8] + is_kv_valid[r * 16 + (tx - 256) // 8] = indices_local <= max_kv_i if is_kv_valid[r * 16 + (tx - 256) // 8]: with T.attr("default", "async_scope", 1): for u in T.serial(4): for v in T.vectorized(8): KV_shared_0_l[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8 + v] = KV[ - b_i, indices_local[0], g_i, 64 * u + (tx - 256) % 8 * 8 + v + b_i, indices_local, g_i, 64 * u + (tx - 256) % 8 * 8 + v ] KV_shared_0_r[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8 + v] = KV[ - b_i, indices_local[0], g_i, D // 2 + 64 * u + (tx - 256) % 8 * 8 + v + b_i, indices_local, g_i, D // 2 + 64 * u + (tx - 256) % 8 * 8 + v ] with T.attr("default", "async_scope", 1): for v in T.vectorized(8): K_tail_shared_0[r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8 + v] = KV[ - b_i, indices_local[0], g_i, D + (tx - 256) % 8 * 8 + v + b_i, indices_local, g_i, D + (tx - 256) % 8 * 8 + v ] T.cp_async_barrier_noinc(bar_k_0_ready[0]) # Buffer 1 T.barrier_wait(bar_k_1_free[0], ((i_i & 1) ^ 1)) for r in T.serial(4): - indices_local[0] = Indices[b_i, s_i, g_i, (i_i * 2 + 1) * BI + r * 16 + (tx - 256) // 8] - is_kv_valid[r * 16 + (tx - 256) // 8] = indices_local[0] <= max_kv_i + indices_local = Indices[b_i, s_i, g_i, (i_i * 2 + 1) * BI + r * 16 + (tx - 256) // 8] + is_kv_valid[r * 16 + (tx - 256) // 8] = indices_local <= max_kv_i if is_kv_valid[r * 16 + (tx - 256) // 8]: with T.attr("default", "async_scope", 1): for u in T.serial(4): for v in T.vectorized(8): KV_shared_1_l[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8 + v] = KV[ - b_i, indices_local[0], g_i, 64 * u + (tx - 256) % 8 * 8 + v + b_i, indices_local, g_i, 64 * u + (tx - 256) % 8 * 8 + v ] KV_shared_1_r[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8 + v] = KV[ - b_i, indices_local[0], g_i, D // 2 + 64 * u + (tx - 256) % 8 * 8 + v + b_i, indices_local, g_i, D // 2 + 64 * u + (tx - 256) % 8 * 8 + v ] with T.attr("default", "async_scope", 1): for v in T.vectorized(8): K_tail_shared_1[r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8 + v] = KV[ - b_i, indices_local[0], g_i, D + (tx 
- 256) % 8 * 8 + v + b_i, indices_local, g_i, D + (tx - 256) % 8 * 8 + v ] T.cp_async_barrier_noinc(bar_k_1_ready[0]) diff --git a/examples/fusedmoe/example_fusedmoe_tilelang.py b/examples/fusedmoe/example_fusedmoe_tilelang.py index 94043c23c..3a696836c 100644 --- a/examples/fusedmoe/example_fusedmoe_tilelang.py +++ b/examples/fusedmoe/example_fusedmoe_tilelang.py @@ -155,18 +155,15 @@ def kernel( gate_logits_local = T.alloc_fragment((block_token, block_dexpert), dtype=accum_dtype) up_logits_local = T.alloc_fragment((block_token, block_dexpert), dtype=accum_dtype) - cur_group_idx = T.alloc_local([1], T.int32) - cur_group_size = T.alloc_local([1], T.int32) - T.use_swizzle(10, enable=True) m_start_padded = bx * block_token - cur_group_idx[0] = group_idx_for_bx[bx] + cur_group_idx = group_idx_for_bx[bx] - cur_group_size[0] = group_sizes[cur_group_idx[0]] - m_start = m_start_padded - group_padded_offsets[cur_group_idx[0]] + group_offsets[cur_group_idx[0]] - actual_rows = T.max(0, T.min(block_token, cur_group_size[0] - (m_start_padded - group_padded_offsets[cur_group_idx[0]]))) + cur_group_size = group_sizes[cur_group_idx] + m_start = m_start_padded - group_padded_offsets[cur_group_idx] + group_offsets[cur_group_idx] + actual_rows = T.max(0, T.min(block_token, cur_group_size - (m_start_padded - group_padded_offsets[cur_group_idx]))) T.clear(gate_logits_local) T.clear(up_logits_local) @@ -179,7 +176,7 @@ def kernel( ) T.copy( routed_expert_gate[ - cur_group_idx[0], by * block_dexpert : (by + 1) * block_dexpert, k * block_dhidden : (k + 1) * block_dhidden + cur_group_idx, by * block_dexpert : (by + 1) * block_dexpert, k * block_dhidden : (k + 1) * block_dhidden ], routed_expert_gate_shared, coalesced_width=coalesced_width, @@ -187,7 +184,7 @@ def kernel( T.gemm(input_shared, routed_expert_gate_shared, gate_logits_local, k_pack=k_pack, transpose_B=True) T.copy( routed_expert_up[ - cur_group_idx[0], by * block_dexpert : (by + 1) * block_dexpert, k * block_dhidden : (k + 1) * block_dhidden + cur_group_idx, by * block_dexpert : (by + 1) * block_dexpert, k * block_dhidden : (k + 1) * block_dhidden ], routed_expert_up_shared, coalesced_width=coalesced_width, @@ -208,18 +205,15 @@ def kernel( routed_expert_down_shared = T.alloc_shared((block_dhidden, block_dexpert), dtype=dtype) output_local = T.alloc_fragment((block_token, block_dhidden), dtype=accum_dtype) - cur_group_idx = T.alloc_local([1], T.int32) - cur_group_size = T.alloc_local([1], T.int32) - T.use_swizzle(10, enable=True) m_start_padded = bx * block_token - cur_group_idx[0] = group_idx_for_bx[bx] + cur_group_idx = group_idx_for_bx[bx] - cur_group_size[0] = group_sizes[cur_group_idx[0]] - m_start = m_start_padded - group_padded_offsets[cur_group_idx[0]] + group_offsets[cur_group_idx[0]] - actual_rows = T.max(0, T.min(block_token, cur_group_size[0] - (m_start_padded - group_padded_offsets[cur_group_idx[0]]))) + cur_group_size = group_sizes[cur_group_idx] + m_start = m_start_padded - group_padded_offsets[cur_group_idx] + group_offsets[cur_group_idx] + actual_rows = T.max(0, T.min(block_token, cur_group_size - (m_start_padded - group_padded_offsets[cur_group_idx]))) T.clear(output_local) @@ -231,7 +225,7 @@ def kernel( ) T.copy( routed_expert_down[ - cur_group_idx[0], by * block_dhidden : (by + 1) * block_dhidden, k * block_dexpert : (k + 1) * block_dexpert + cur_group_idx, by * block_dhidden : (by + 1) * block_dhidden, k * block_dexpert : (k + 1) * block_dexpert ], routed_expert_down_shared, coalesced_width=coalesced_width, diff --git 
a/examples/gdn/example_chunk_delta_bwd.py b/examples/gdn/example_chunk_delta_bwd.py index 39450bc5f..c3c2aabc5 100644 --- a/examples/gdn/example_chunk_delta_bwd.py +++ b/examples/gdn/example_chunk_delta_bwd.py @@ -259,8 +259,6 @@ def kernel( Q_shared_fp32 = T.alloc_shared((block_S, DK), dtype=T.float32) W_shared = T.alloc_shared((block_S, DK), dtype=input_dtype) - G_last_local = T.alloc_local((1), dtype=gate_dtype) - G_last_local_exp = T.alloc_local((1), dtype=gate_dtype) G_shared = T.alloc_shared((block_S), dtype=gate_dtype, scope="shared") G_fragment = T.alloc_fragment((block_S), dtype=gate_dtype) G_fragment_post = T.alloc_fragment((block_S), dtype=gate_dtype) @@ -305,17 +303,14 @@ def kernel( if use_g: T.copy(G[bb, i_s_inv * block_S : (i_s_inv + 1) * block_S, bh], G_shared, disable_tma=True) T.copy(G_shared, G_fragment) - G_last_local[0] = G_shared[block_S - 1] - G_last_local_exp[0] = T.exp(G_last_local[0]) + G_last_local = G_shared[block_S - 1] + G_last_local_exp = T.exp(G_last_local) for i_s2 in T.Parallel(block_S): - G_fragment_post[i_s2] = T.exp(G_last_local[0] - G_fragment[i_s2]) + G_fragment_post[i_s2] = T.exp(G_last_local - G_fragment[i_s2]) for i_s2, i_v in T.Parallel(block_S, block_DV): - # with T.If(G_last_local[0] - G_shared[i_s2] <= 0): - with T.If(G_last_local[0] - G_fragment[i_s2] <= 0): - with T.Then(): - dv_fragment[i_s2, i_v] = dv_fragment[i_s2, i_v] * G_fragment_post[i_s2] - with T.Else(): - dv_fragment[i_s2, i_v] = 0 + dv_fragment[i_s2, i_v] = ( + dv_fragment[i_s2, i_v] * G_fragment_post[i_s2] if G_last_local - G_fragment[i_s2] <= 0 else 0 + ) T.copy(dv[bb, i_s_inv * block_S : (i_s_inv + 1) * block_S, bh, bv * block_DV : (bv + 1) * block_DV], dv_shared) T.copy(dv_shared, dv_fragment_2) @@ -333,12 +328,11 @@ def kernel( T.clear(Q_fragment) if use_g: for i_k, i_v in T.Parallel(DK, block_DV): - b_dh_fragment[i_k, i_v] *= G_last_local_exp[0] + b_dh_fragment[i_k, i_v] *= G_last_local_exp T.copy(Q_shared, Q_fragment) for i_s2 in T.Parallel(block_S): G_fragment_exp[i_s2] = T.exp(G_shared[i_s2]) for i_s2, i_k in T.Parallel(block_S, DK): - # Q_fragment[i_s2, i_k] = Q_fragment[i_s2, i_k] * T.exp(G_shared[i_s2]) * scale Q_fragment[i_s2, i_k] = Q_fragment[i_s2, i_k] * G_fragment_exp[i_s2] * scale else: T.copy(Q_shared, Q_fragment) diff --git a/examples/gdn/example_chunk_delta_h.py b/examples/gdn/example_chunk_delta_h.py index d316a6211..f11fc4dd3 100644 --- a/examples/gdn/example_chunk_delta_h.py +++ b/examples/gdn/example_chunk_delta_h.py @@ -156,7 +156,7 @@ def kernel( V_new_fragment = T.alloc_fragment((block_S, block_DV), dtype=accum_dtype) V_new_shared = T.alloc_shared((block_S, block_DV), dtype=output_dtype) K_shared = T.alloc_shared((block_S, DK), dtype=input_dtype) - G_last_local = T.alloc_local((1), dtype=gate_dtype) + G_last_local = T.alloc_var(T.float32) G_shared = T.alloc_shared((block_S, block_DV), dtype=gate_dtype) G_fragment = T.alloc_fragment((block_S, block_DV), dtype=gate_dtype) @@ -201,21 +201,19 @@ def kernel( T.copy(K[bb, i_s * block_S : (i_s + 1) * block_S, bh, 0:DK], K_shared) # use_g if use_g: - G_last_local[0] = G[bb, (i_s + 1) * block_S - 1, bh] + G_last_local = G[bb, (i_s + 1) * block_S - 1, bh] for i_s2, i_v in T.Parallel(block_S, block_DV): G_shared[i_s2, i_v] = G[bb, i_s * block_S + i_s2, bh] T.copy(G_shared, G_fragment) for i_s2, i_v in T.Parallel(block_S, block_DV): - with T.If(G_last_local[0] - G_fragment[i_s2, i_v] <= 0): - with T.Then(): - V_new_fragment[i_s2, i_v] = V_new_fragment[i_s2, i_v] * T.exp2( - (G_last_local[0] - G_fragment[i_s2, i_v]) 
* 1.442695 - ) - with T.Else(): - V_new_fragment[i_s2, i_v] = 0 - G_last_local[0] = T.exp2(G_last_local[0] * 1.442695) + V_new_fragment[i_s2, i_v] = ( + V_new_fragment[i_s2, i_v] * T.exp2((G_last_local - G_fragment[i_s2, i_v]) * 1.442695) + if G_last_local - G_fragment[i_s2, i_v] <= 0 + else 0 + ) + G_last_local = T.exp2(G_last_local * 1.442695) for i_k, i_v in T.Parallel(DK, block_DV): - b_h_fragment[i_k, i_v] *= G_last_local[0] + b_h_fragment[i_k, i_v] *= G_last_local # Update intermediate results T.copy(V_new_fragment, V_new_shared) diff --git a/examples/gdn/example_chunk_o_bwd.py b/examples/gdn/example_chunk_o_bwd.py index 97e2f4f01..f11cbebe1 100644 --- a/examples/gdn/example_chunk_o_bwd.py +++ b/examples/gdn/example_chunk_o_bwd.py @@ -199,13 +199,15 @@ def kernel( dg_fragment = T.alloc_fragment((block_S,), dtype=gate_dtype) dg_fragment_2 = T.alloc_fragment((block_S,), dtype=gate_dtype) dg_fragment_final = T.alloc_fragment((block_S,), dtype=gate_dtype) - dg_last_local = T.alloc_local((2,), dtype=gate_dtype) + dg_last_local_0 = T.alloc_var(dtype=gate_dtype) + dg_last_local_1 = T.alloc_var(dtype=gate_dtype) + G_last_local = T.alloc_var(dtype=gate_dtype) + dg_last_fragment = T.alloc_fragment((block_DV * block_DK), dtype=gate_dtype) dg_last_fragment_scalar = T.alloc_fragment((1,), dtype=gate_dtype) dg_last_fragment_2 = T.alloc_fragment((block_S * block_DK), dtype=gate_dtype) dg_last_fragment_scalar_2 = T.alloc_fragment((1,), dtype=gate_dtype) - G_shared = T.alloc_shared((block_S, block_DK), dtype=gate_dtype, scope="shared") - G_last_local = T.alloc_local((1,), dtype=gate_dtype) + G_shared = T.alloc_shared((block_S, block_DK), dtype=gate_dtype) T.use_swizzle(10) @@ -221,7 +223,8 @@ def kernel( } ) - T.clear(dg_last_local) + T.clear(dg_last_local_0) + T.clear(dg_last_local_1) T.clear(G_last_local) T.clear(G_shared) T.clear(q_fragment) @@ -247,7 +250,7 @@ def kernel( for i_kv in T.Parallel(block_DK * block_DV): dg_last_fragment[i_kv] = h_shared[i_kv // block_DV, i_kv % block_DV] * dh_shared[i_kv // block_DV, i_kv % block_DV] T.reduce_sum(dg_last_fragment, dg_last_fragment_scalar, dim=-1, clear=False) - dg_last_local[0] += dg_last_fragment_scalar[0] + dg_last_local_0 = dg_last_local_0 + dg_last_fragment_scalar[0] T.gemm(dO_shared, V_shared, ds_fragment, transpose_B=True) T.gemm(dO_shared, h_shared, dq_fragment, transpose_B=True) @@ -272,9 +275,9 @@ def kernel( T.clear(dg_fragment_2) for i_s, i_k in T.Parallel(block_S, block_DK): G_shared[i_s, i_k] = G[bb, bs * block_S + i_s, bh] - G_last_local[0] = G[bb, bs * block_S + block_S - 1, bh] + G_last_local = G[bb, bs * block_S + block_S - 1, bh] # Use gmem directly instead of local register - dg_last_local[0] = dg_last_local[0] * T.exp(G[bb, bs * block_S + block_S - 1, bh]) + dg_last_local_0 = dg_last_local_0 * T.exp(G[bb, bs * block_S + block_S - 1, bh]) for i_s, i_k in T.Parallel(block_S, block_DK): dq_fragment[i_s, i_k] = dq_fragment[i_s, i_k] * T.exp(G[bb, bs * block_S + i_s, bh]) * scale @@ -285,11 +288,11 @@ def kernel( T.reduce_sum(dg_fragment_reduce_tmp, dg_fragment, dim=-1, clear=False) for i_s, i_k in T.Parallel(block_S, block_DK): - with T.If(G_last_local[0] - G[bb, bs * block_S + i_s, bh] <= 0): - with T.Then(): - dk_fragment[i_s, i_k] = dk_fragment[i_s, i_k] * T.exp(G_last_local[0] - G[bb, bs * block_S + i_s, bh]) - with T.Else(): - dk_fragment[i_s, i_k] = 0 + dk_fragment[i_s, i_k] = ( + dk_fragment[i_s, i_k] * T.exp(G_last_local - G[bb, bs * block_S + i_s, bh]) + if G_last_local - G[bb, bs * block_S + i_s, bh] <= 0 + else 0 + )
T.clear(dg_fragment_reduce_tmp) for i_s, i_k in T.Parallel(block_S, block_DK): dg_fragment_reduce_tmp[i_s, i_k] = dk_fragment[i_s, i_k] * (-k_shared[i_s, i_k]) @@ -303,16 +306,14 @@ def kernel( i_s, i_k = i_sk // block_DK, i_sk % block_DK dg_last_fragment_2[i_sk] = dk_shared[i_s, i_k] * k_shared[i_s, i_k] T.reduce_sum(dg_last_fragment_2, dg_last_fragment_scalar_2, dim=-1, clear=False) - dg_last_local[1] = dg_last_fragment_scalar_2[0] + dg_last_local_1 = dg_last_fragment_scalar_2[0] for i_s1, i_s2 in T.Parallel(block_S, block_S): - with T.If(i_s1 >= i_s2 and G[bb, bs * block_S + i_s1, bh] - G[bb, bs * block_S + i_s2, bh] <= 0): - with T.Then(): - ds_fragment[i_s1, i_s2] = ( - ds_fragment[i_s1, i_s2] * T.exp(G[bb, bs * block_S + i_s1, bh] - G[bb, bs * block_S + i_s2, bh]) * scale - ) - with T.Else(): - ds_fragment[i_s1, i_s2] = 0 + ds_fragment[i_s1, i_s2] = ( + (ds_fragment[i_s1, i_s2] * T.exp(G[bb, bs * block_S + i_s1, bh] - G[bb, bs * block_S + i_s2, bh]) * scale) + if i_s1 >= i_s2 and G[bb, bs * block_S + i_s1, bh] - G[bb, bs * block_S + i_s2, bh] <= 0 + else 0 + ) T.clear(ds_fragment_positive) T.clear(ds_fragment_positive_transpose) @@ -340,9 +341,7 @@ def kernel( T.gemm(ds_shared, q_shared, dk_fragment, transpose_A=True) for i_s in T.Parallel(block_S): - with T.If(i_s >= block_S - 1): # noqa: SIM117 - with T.Then(): - dg_fragment_final[i_s] = dg_fragment_final[i_s] + dg_last_local[0] + dg_last_local[1] + dg_fragment_final[i_s] = dg_fragment_final[i_s] + dg_last_local_0 + dg_last_local_1 if i_s >= block_S - 1 else dg_fragment_final[i_s] T.copy(dq_fragment, dq[bb, bs * block_S : (bs + 1) * block_S, bh, bk * block_DK : (bk + 1) * block_DK]) T.copy(dk_fragment, dk[bb, bs * block_S : (bs + 1) * block_S, bh, bk * block_DK : (bk + 1) * block_DK]) @@ -351,9 +350,7 @@ def kernel( else: for i_s1, i_s2 in T.Parallel(block_S, block_S): - with T.If(i_s1 < i_s2): # noqa: SIM117 - with T.Then(): - ds_fragment[i_s1, i_s2] = 0 + ds_fragment[i_s1, i_s2] = 0 if i_s1 < i_s2 else ds_fragment[i_s1, i_s2] T.clear(dk_fragment_2) T.copy(ds_fragment, ds_shared) T.gemm(ds_shared, k_shared, dq_fragment) diff --git a/examples/gdn/test_example_gdn_compilation.py b/examples/gdn/test_example_gdn_compilation.py index e749fa087..ab6389eaa 100644 --- a/examples/gdn/test_example_gdn_compilation.py +++ b/examples/gdn/test_example_gdn_compilation.py @@ -1,5 +1,4 @@ import torch -import tilelang.testing from tilelang import language as T B = 1 @@ -317,4 +316,5 @@ def test_example_chunk_delta_bwd_compilation(): if __name__ == "__main__": - tilelang.testing.main() + # tilelang.testing.main() + test_example_chunk_delta_bwd_compilation() diff --git a/examples/gemm_sp/example_custom_compress.py b/examples/gemm_sp/example_custom_compress.py index 7b93f2a77..5806caff6 100644 --- a/examples/gemm_sp/example_custom_compress.py +++ b/examples/gemm_sp/example_custom_compress.py @@ -258,28 +258,28 @@ def kernel( T.clear(A_sp_shared) T.clear(E_shared) # TODO: alloc_var seems buggy here - non_zero_cnt = T.alloc_local((1,), dtype=T.uint8) - non_zero_elt_log_idx = T.alloc_local((elem,), dtype=T.uint8) + non_zero_cnt = T.alloc_var(dtype=T.uint8) + non_zero_elt_log_idx = T.alloc_shared((elem,), dtype=T.uint8) T.copy(A[bx * block_M, by * block_K], A_shared) for tm in T.Parallel(block_M): for g_i in range(0, block_K // group): a_k = g_i * group - non_zero_cnt[0] = 0 + non_zero_cnt = 0 for i in range(elem): non_zero_elt_log_idx[i] = 0 for i in range(group): val = A_shared[tm, a_k + i] if val != 0.0: - non_zero_elt_log_idx[non_zero_cnt[0]] = i - A_sp_shared[tm, a_k // 2 + non_zero_cnt[0]] = val - non_zero_cnt[0]
+= 1 + non_zero_elt_log_idx[non_zero_cnt] = i + A_sp_shared[tm, a_k // 2 + non_zero_cnt] = val + non_zero_cnt += 1 # TODO: use T.device_assert(non_zero_cnt <= 2) after rebasing main - if non_zero_cnt[0] == 1 and non_zero_elt_log_idx[0] == 3: + if non_zero_cnt == 1 and non_zero_elt_log_idx[0] == 3: non_zero_elt_log_idx[0] = 0 non_zero_elt_log_idx[1] = 3 A_sp_shared[tm, a_k // 2 + 1] = A_sp_shared[tm, a_k // 2] A_sp_shared[tm, a_k // 2] = 0.0 - elif non_zero_cnt[0] == 1: + elif non_zero_cnt == 1: A_sp_shared[tm, a_k // 2 + 1] = 0 non_zero_elt_log_idx[1] = 3 for i in T.serial(elem): diff --git a/examples/grouped_gemm/example_grouped_gemm_bwd.py b/examples/grouped_gemm/example_grouped_gemm_bwd.py index bb57c6073..49cce0d1d 100644 --- a/examples/grouped_gemm/example_grouped_gemm_bwd.py +++ b/examples/grouped_gemm/example_grouped_gemm_bwd.py @@ -27,27 +27,27 @@ def kernel( A_shared = T.alloc_shared([block_M, block_K], dtype) B_shared = T.alloc_shared([block_K, block_N], dtype) C_local = T.alloc_fragment([block_M, block_N], accum_dtype) - cur_batch_idx = T.alloc_local([1], T.int32) - cur_batch_size = T.alloc_local([1], T.int32) + cur_batch_idx = T.alloc_var(dtype=T.int32) + cur_batch_size = T.alloc_var(dtype=T.int32) m_start_padded = bx * block_M for i in range(batch_count): in_cur_batch_idx = m_start_padded >= batch_padded_offsets[i] - cur_batch_idx[0] = T.if_then_else(in_cur_batch_idx, i, cur_batch_idx[0]) + cur_batch_idx = T.if_then_else(in_cur_batch_idx, i, cur_batch_idx) - cur_batch_size[0] = batch_sizes[cur_batch_idx[0]] - m_start = m_start_padded - batch_padded_offsets[cur_batch_idx[0]] + batch_offsets[cur_batch_idx[0]] - actual_rows = T.max(0, T.min(block_M, cur_batch_size[0] + batch_padded_offsets[cur_batch_idx[0]] - m_start_padded)) + cur_batch_size = batch_sizes[cur_batch_idx] + m_start = m_start_padded - batch_padded_offsets[cur_batch_idx] + batch_offsets[cur_batch_idx] + actual_rows = T.max(0, T.min(block_M, cur_batch_size + batch_padded_offsets[cur_batch_idx] - m_start_padded)) T.clear(C_local) for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): T.copy(A[m_start : m_start + block_M, k * block_K : (k + 1) * block_K], A_shared) - T.copy(B[cur_batch_idx[0], k * block_K : (k + 1) * block_K, by * block_N : (by + 1) * block_N], B_shared) + T.copy(B[cur_batch_idx, k * block_K : (k + 1) * block_K, by * block_N : (by + 1) * block_N], B_shared) T.gemm(A_shared, B_shared, C_local) for i, j in T.Parallel(block_M, block_N): - with T.If(i < actual_rows), T.Then(): + if i < actual_rows: C[m_start + i, by * block_N + j] = C_local[i, j] return kernel diff --git a/examples/grouped_gemm/example_grouped_gemm_fwd.py b/examples/grouped_gemm/example_grouped_gemm_fwd.py index 48d916051..b71472741 100644 --- a/examples/grouped_gemm/example_grouped_gemm_fwd.py +++ b/examples/grouped_gemm/example_grouped_gemm_fwd.py @@ -61,27 +61,27 @@ def kernel( A_shared = T.alloc_shared([block_M, block_K], dtype) B_shared = T.alloc_shared([block_K, block_N], dtype) C_local = T.alloc_fragment([block_M, block_N], accum_dtype) - cur_batch_idx = T.alloc_local([1], T.int32) - cur_batch_size = T.alloc_local([1], T.int32) + cur_batch_idx = T.alloc_var(dtype=T.int32) + cur_batch_size = T.alloc_var(dtype=T.int32) m_start_padded = bx * block_M for i in range(batch_count): in_cur_batch_idx = m_start_padded >= batch_padded_offsets[i] - cur_batch_idx[0] = T.if_then_else(in_cur_batch_idx, i, cur_batch_idx[0]) + cur_batch_idx = T.if_then_else(in_cur_batch_idx, i, cur_batch_idx) - cur_batch_size[0] = 
batch_sizes[cur_batch_idx[0]] - m_start = m_start_padded - batch_padded_offsets[cur_batch_idx[0]] + batch_offsets[cur_batch_idx[0]] - actual_rows = T.max(0, T.min(block_M, cur_batch_size[0] + batch_padded_offsets[cur_batch_idx[0]] - m_start_padded)) + cur_batch_size = batch_sizes[cur_batch_idx] + m_start = m_start_padded - batch_padded_offsets[cur_batch_idx] + batch_offsets[cur_batch_idx] + actual_rows = T.max(0, T.min(block_M, cur_batch_size + batch_padded_offsets[cur_batch_idx] - m_start_padded)) T.clear(C_local) for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): T.copy(A[m_start : m_start + block_M, k * block_K : (k + 1) * block_K], A_shared) - T.copy(B[cur_batch_idx[0], k * block_K : (k + 1) * block_K, by * block_N : (by + 1) * block_N], B_shared) + T.copy(B[cur_batch_idx, k * block_K : (k + 1) * block_K, by * block_N : (by + 1) * block_N], B_shared) T.gemm(A_shared, B_shared, C_local) for i, j in T.Parallel(block_M, block_N): - with T.If(i < actual_rows), T.Then(): + if i < actual_rows: C[m_start + i, by * block_N + j] = C_local[i, j] return kernel diff --git a/examples/minference/example_vertical_slash_sparse_attn.py b/examples/minference/example_vertical_slash_sparse_attn.py index c6e3dfdb0..3bd7e4739 100644 --- a/examples/minference/example_vertical_slash_sparse_attn.py +++ b/examples/minference/example_vertical_slash_sparse_attn.py @@ -130,9 +130,9 @@ def vs_sparse_flashattn_ws( scores_scale = T.alloc_fragment([block_M], accum_dtype) scores_sum = T.alloc_fragment([block_M], accum_dtype) logsum = T.alloc_fragment([block_M], accum_dtype) - block_count = T.alloc_local([1], int_dtype) + block_count = T.alloc_var(dtype=int_dtype) block_offset = T.alloc_shared([slash_size_round], int_dtype, scope="shared") - column_count = T.alloc_local([1], int_dtype) + column_count = T.alloc_var(dtype=int_dtype) column_index = T.alloc_shared([vertical_size_round], int_dtype, scope="shared") T.create_list_of_mbarrier([128] * 9) @@ -143,8 +143,8 @@ def vs_sparse_flashattn_ws( } ) - block_count[0] = BlockCount[bz, by, bx] - column_count[0] = ColumnCount[bz, by, bx] + block_count = BlockCount[bz, by, bx] + column_count = ColumnCount[bz, by, bx] for vi in T.Parallel(slash_size_round): if vi < slash_size: @@ -160,7 +160,7 @@ def vs_sparse_flashattn_ws( T.annotate_producer_reg_dealloc() T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared) T.mbarrier_arrive(mbarrier=8) - for bi in T.serial(block_count[0]): + for bi in T.serial(block_count): k = block_offset[bi] T.mbarrier_wait_parity(mbarrier=bi % 2 + 4, parity=(((bi & 3) >> 1) ^ 1)) T.copy(K[bz, by, k : k + block_N, :], K_shared[bi % 2, :, :]) @@ -174,7 +174,7 @@ def vs_sparse_flashattn_ws( T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) T.mbarrier_wait_parity(mbarrier=8, parity=0) - for bi in T.serial(block_count[0]): + for bi in T.serial(block_count): k = block_offset[bi] for i, j in T.Parallel(block_M, block_N): acc_s[i, j] = T.if_then_else(bx * block_M + i >= k + j, 0, -T.infinity(acc_s.dtype)) @@ -207,12 +207,12 @@ def vs_sparse_flashattn_ws( for i in T.Parallel(block_M): logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] - if column_count[0] != 0: - Prefetch(K, V, K_shared_1, V_shared_1, column_index, column_count[0], 0, bz, by) - for bi in T.serial(T.ceildiv(column_count[0], block_N) - 1): + if column_count != 0: + Prefetch(K, V, K_shared_1, V_shared_1, column_index, column_count, 0, bz, by) + for bi in T.serial(T.ceildiv(column_count, block_N) - 1): k = bi * block_N if bi % 2 == 0: - Prefetch(K, V, 
K_shared_2, V_shared_2, column_index, column_count[0], k + block_N, bz, by) + Prefetch(K, V, K_shared_2, V_shared_2, column_index, column_count, k + block_N, bz, by) Compute( acc_s, @@ -221,7 +221,7 @@ def vs_sparse_flashattn_ws( scores_max, scores_max_prev, k, - column_count[0], + column_count, Q_shared, K_shared_1, V_shared_1, @@ -231,7 +231,7 @@ def vs_sparse_flashattn_ws( 1, ) else: - Prefetch(K, V, K_shared_1, V_shared_1, column_index, column_count[0], k + block_N, bz, by) + Prefetch(K, V, K_shared_1, V_shared_1, column_index, column_count, k + block_N, bz, by) Compute( acc_s, @@ -240,7 +240,7 @@ def vs_sparse_flashattn_ws( scores_max, scores_max_prev, k, - column_count[0], + column_count, Q_shared, K_shared_2, V_shared_2, @@ -249,15 +249,15 @@ def vs_sparse_flashattn_ws( logsum, 1, ) - if T.ceildiv(column_count[0], block_N) % 2 == 0: + if T.ceildiv(column_count, block_N) % 2 == 0: Compute( acc_s, acc_s_cast, acc_o, scores_max, scores_max_prev, - T.ceildiv(column_count[0], block_N) * block_N - block_N, - column_count[0], + T.ceildiv(column_count, block_N) * block_N - block_N, + column_count, Q_shared, K_shared_2, V_shared_2, @@ -273,8 +273,8 @@ def vs_sparse_flashattn_ws( acc_o, scores_max, scores_max_prev, - T.ceildiv(column_count[0], block_N) * block_N - block_N, - column_count[0], + T.ceildiv(column_count, block_N) * block_N - block_N, + column_count, Q_shared, K_shared_1, V_shared_1, diff --git a/examples/seer_attention/block_sparse_attn_tilelang.py b/examples/seer_attention/block_sparse_attn_tilelang.py index 6b27b9d0b..9e9141d15 100644 --- a/examples/seer_attention/block_sparse_attn_tilelang.py +++ b/examples/seer_attention/block_sparse_attn_tilelang.py @@ -108,15 +108,14 @@ def main( scores_scale = T.alloc_fragment([block_M], accum_dtype) scores_sum = T.alloc_fragment([block_M], accum_dtype) logsum = T.alloc_fragment([block_M], accum_dtype) - block_mask = T.alloc_local([downsample_len], block_mask_dtype) + block_mask = T.alloc_fragment([downsample_len], block_mask_dtype) T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) - for vj in T.serial(downsample_len): - block_mask[vj] = BlockSparseMask[bz, by, bx, vj] + T.copy(BlockSparseMask[bz, by, bx, :], block_mask) loop_range = T.ceildiv(seq_kv, block_N) diff --git a/src/op/atomic_add.cc b/src/op/atomic_add.cc index 6fa0c6b53..5925d38fd 100644 --- a/src/op/atomic_add.cc +++ b/src/op/atomic_add.cc @@ -325,7 +325,7 @@ For AtomicAddNode::MakeSIMTLoop(arith::Analyzer *analyzer) const { LayoutMap AtomicAddNode::InferLayout(const LayoutInferArgs &T, InferLevel level) const { if (T.layout_map.count(src) && T.layout_map.count(dst)) { - if (src.scope() == "local.fragment" && dst.scope() == "local.fragment") { + if (IsFragmentBuffer(src) && IsFragmentBuffer(dst)) { const FragmentNode *src_layout = T.layout_map[src].as(); const FragmentNode *dst_layout = T.layout_map[dst].as(); if (src_layout && dst_layout) { @@ -431,14 +431,14 @@ Stmt AtomicAddNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { StmtExprVisitor::VisitStmt_(op); } void VisitStmt_(const BufferStoreNode *op) final { - if (op->buffer.scope() == "local.fragment") { + if (IsFragmentBuffer(op->buffer)) { indice_map.Set(op->buffer, op->indices); writes.insert(op->buffer); } StmtExprVisitor::VisitStmt_(op); } void VisitExpr_(const BufferLoadNode *op) final { - if (op->buffer.scope() == "local.fragment") { + if (IsFragmentBuffer(op->buffer)) { indice_map.Set(op->buffer, 
op->indices); } StmtExprVisitor::VisitExpr_(op); @@ -473,7 +473,7 @@ Stmt AtomicAddNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { int best_rank = -1; for (auto kv : C.indice_map) { const Buffer &buf = kv.first; - if (buf.scope() != "local.fragment") + if (!IsFragmentBuffer(buf)) continue; if (!args.layout_map.count(buf)) continue; diff --git a/src/op/copy.cc b/src/op/copy.cc index 066a09b10..1fb4f5743 100644 --- a/src/op/copy.cc +++ b/src/op/copy.cc @@ -776,7 +776,7 @@ bool CopyNode::CheckBulkStore(Target target, arith::Analyzer *analyzer, bool CopyNode::CheckLDSMCopy(Target target) const { return TargetHasLdmatrix(target) && (src.scope() == "shared.dyn" || src.scope() == "shared") && - dst.scope() == "local.fragment"; + IsFragmentBuffer(dst); } /** @@ -791,7 +791,7 @@ bool CopyNode::CheckLDSMCopy(Target target) const { * otherwise. */ bool CopyNode::CheckSTSMCopy(Target target) const { - return TargetHasStmatrix(target) && src.scope() == "local.fragment" && + return TargetHasStmatrix(target) && IsFragmentBuffer(src) && (dst.scope() == "shared.dyn" || dst.scope() == "shared"); } @@ -807,7 +807,7 @@ bool CopyNode::CheckSTSMCopy(Target target) const { */ bool CopyNode::CheckTMemLoad(Target target) const { return TargetHasTmem(target) && src.scope() == "shared.tmem" && - dst.scope() == "local.fragment"; + IsFragmentBuffer(dst); } /** @@ -821,7 +821,7 @@ bool CopyNode::CheckTMemLoad(Target target) const { * otherwise. */ bool CopyNode::CheckTMemStore(Target target) const { - return TargetHasTmem(target) && src.scope() == "local.fragment" && + return TargetHasTmem(target) && IsFragmentBuffer(src) && dst.scope() == "shared.tmem"; } @@ -950,8 +950,8 @@ Stmt CopyNode::LowerNormalCopy(const LowerArgs &T, For vectorized_thread_loop; auto par_op = ParallelOp(transformed_loop); - if (is_cpu_target || dst.scope() == "local" || src.scope() == "local") { - if (src.scope() == "local" && dst.scope() != "local") { + if (is_cpu_target || IsLocalBuffer(src) || IsLocalBuffer(dst)) { + if (IsLocalBuffer(src) && !IsLocalBuffer(dst)) { LOG(WARNING) << "Copy from local buffer `" << src->name << "` to " << dst.scope() << " buffer `" << dst->name << "` may cause conflicted write."; @@ -1231,9 +1231,9 @@ Stmt CopyNode::LowerTmemCopy(const LowerArgs &T, bool dst_needs_unpack = 16 == dst->dtype.bits(); // if needs .unpack::16b when is_st - if (src.scope() == "shared.tmem" && dst.scope() == "local.fragment") { + if (src.scope() == "shared.tmem" && IsFragmentBuffer(dst)) { is_ld = true; - } else if (src.scope() == "local.fragment" && dst.scope() == "shared.tmem") { + } else if (IsFragmentBuffer(src) && dst.scope() == "shared.tmem") { is_st = true; } else if (src.scope() == "shared.dyn" && dst.scope() == "shared.tmem") { is_cp = true; @@ -2068,8 +2068,7 @@ void CopyNode::CollectFragmentLayouts(const PrimExpr &expr, Map &result_map) const { PostOrderVisit(expr, [&](const ObjectRef &node) { if (auto bl = node.as()) { - if (bl->buffer.scope() == "local.fragment" && - !existing_layouts.count(bl->buffer) && + if (IsFragmentBuffer(bl->buffer) && !existing_layouts.count(bl->buffer) && !result_map.count(bl->buffer)) { auto f = Fragment::FullyReplicated(bl->buffer->shape, thread_extent); result_map.Set(bl->buffer, f->BindThreadRange(thread_bounds)); diff --git a/src/op/fill.cc b/src/op/fill.cc index 794b38401..bc539da93 100644 --- a/src/op/fill.cc +++ b/src/op/fill.cc @@ -156,7 +156,7 @@ For FillNode::MakeSIMTLoop(arith::Analyzer *analyzer) const { * @return Stmt The lowered TIR statement implementing the fill. 
diff --git a/src/op/fill.cc b/src/op/fill.cc
index 794b38401..bc539da93 100644
--- a/src/op/fill.cc
+++ b/src/op/fill.cc
@@ -156,7 +156,7 @@ For FillNode::MakeSIMTLoop(arith::Analyzer *analyzer) const {
  * @return Stmt The lowered TIR statement implementing the fill.
  */
 Stmt FillNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
-  if (dst.scope() == "local.fragment") {
+  if (IsFragmentBuffer(dst)) {
     auto par_op = ParallelOp(MakeSIMTLoop(analyzer));
     par_op->InferLayout({T.target, T.thread_bounds,
@@ -174,12 +174,11 @@ Stmt FillNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
                         vectorized_thread_loop);
     }
     return vectorized_thread_loop;
-  } else if (dst.scope() == "local") {
+  } else if (IsLocalBuffer(dst)) {
     auto init_loop = MakeSIMTLoop(analyzer);
     auto vectorized_thread_loop = VectorizeLoop(init_loop, analyzer);
     return vectorized_thread_loop;
-  } else if (dst.scope() == "shared.dyn" || dst.scope() == "shared" ||
-             dst.scope() == "global") {
+  } else if (IsSharedBuffer(dst) || IsGlobalBuffer(dst)) {
     auto par_op = ParallelOp(MakeSIMTLoop(analyzer));
     par_op->InferLayout({T.target, T.thread_bounds,
diff --git a/src/op/gemm.cc b/src/op/gemm.cc
index 57c02b0b5..5a8fa3070 100644
--- a/src/op/gemm.cc
+++ b/src/op/gemm.cc
@@ -520,17 +520,17 @@ Stmt GemmNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
     }
   }
 
-  if (a_.scope() == "local.fragment") {
-    ICHECK(b_.scope() != "local.fragment");
+  if (IsFragmentBuffer(a_)) {
+    ICHECK(!IsFragmentBuffer(b_));
     ICHECK(!transA_)
         << "gemm_rs requires the A operand to be in non-transposed layout.";
     op_name = "tl::gemm_rs";
-  } else if (b_.scope() == "local.fragment") {
+  } else if (IsFragmentBuffer(b_)) {
     op_name = "tl::gemm_sr";
   } else {
     op_name = "tl::gemm_ss";
   }
-  ICHECK(c_.scope() == "local.fragment");
+  ICHECK(IsFragmentBuffer(c_));
 
   ss << op_name << "<" << m_ << ", " << n_ << ", " << k_ << ", ";
   ss << warp_m << ", " << warp_n << ", ";
@@ -602,7 +602,7 @@ LayoutMap GemmNode::InferLayout(const LayoutInferArgs &T,
   auto [warp_m, warp_n] =
       policy_->computeWarpPartition(m_, n_, block_size, T.target, gemm_inst);
   if (TargetIsVolta(T.target)) {
-    ICHECK(c_.scope() == "local.fragment")
+    ICHECK(IsFragmentBuffer(c_))
         << "Volta gemm only supports C in local.fragment scope, got "
        << c_.scope();
     auto fragment = makeGemmVoltaFragmentC(m_, n_, m_ / warp_m, n_ / warp_n,
@@ -613,7 +613,7 @@ LayoutMap GemmNode::InferLayout(const LayoutInferArgs &T,
       results.Set(a_, makeGemmVoltaABLayout(*as_const_int(a_->shape[dim_A - 2]),
                                             *as_const_int(a_->shape[dim_A - 1]),
                                             true, !transA_));
-    } else if (a_.scope() == "local.fragment") {
+    } else if (IsFragmentBuffer(a_)) {
       ICHECK(transA_ == false);
       auto fragment =
           makeGemmVoltaFragmentA(m_, n_, k_, m_ / warp_m, n_ / warp_n);
@@ -630,7 +630,7 @@ LayoutMap GemmNode::InferLayout(const LayoutInferArgs &T,
   } else if (TargetIsAmpere(T.target) || TargetIsTuring(T.target) ||
              TargetIsSM120(T.target) ||
             (TargetIsSm100(T.target) && gemm_inst == GemmInst::kMMA)) {
-    ICHECK(c_.scope() == "local.fragment")
+    ICHECK(IsFragmentBuffer(c_))
         << "MMA only supports C in local.fragment scope, got " << c_.scope();
     auto fragment =
@@ -644,7 +644,7 @@ LayoutMap GemmNode::InferLayout(const LayoutInferArgs &T,
       results.Set(a_,
                   makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
                                    a_->dtype.bits(), !transA_));
-    } else if (a_.scope() == "local.fragment") {
+    } else if (IsFragmentBuffer(a_)) {
       auto fragment = makeGemmFragmentA(m_, n_, k_, m_ / warp_m, n_ / warp_n,
                                         a_->dtype.bits(), transA_);
       results.Set(a_, fragment->BindThreadRange(thread_range));
@@ -658,7 +658,7 @@ LayoutMap GemmNode::InferLayout(const LayoutInferArgs &T,
       results.Set(b_,
                   makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
                                    b_->dtype.bits(), transB_));
-    } else if (b_.scope() == "local.fragment") {
+    } else if (IsFragmentBuffer(b_)) {
       auto fragment =
           makeGemmFragmentB(m_, n_, k_, m_ / warp_m, n_ / warp_n, transB_);
       results.Set(b_, fragment->BindThreadRange(thread_range));
@@ -666,7 +666,7 @@ LayoutMap GemmNode::InferLayout(const LayoutInferArgs &T,
       ICHECK(0);
     }
   } else if (TargetIsHopper(T.target)) {
-    ICHECK(c_.scope() == "local.fragment")
+    ICHECK(IsFragmentBuffer(c_))
         << (gemm_inst == GemmInst::kWGMMA ? "WGMMA " : "MMA ")
         << "only supports C in local.fragment scope, got " << c_.scope();
     auto fragment = gemm_inst == GemmInst::kWGMMA
@@ -772,7 +772,7 @@ LayoutMap GemmNode::InferLayout(const LayoutInferArgs &T,
       results.Set(c_, res);
     }
   } else if (TargetIsCDNA(T.target)) {
-    ICHECK(c_.scope() == "local.fragment")
+    ICHECK(IsFragmentBuffer(c_))
         << "CDNA gemm (FMMA) only supports C in local.fragment scope, got "
        << c_.scope();
     auto fragment = makeGemmFragmentCCDNA(m_, n_, m_ / warp_m, n_ / warp_n,
@@ -785,7 +785,7 @@ LayoutMap GemmNode::InferLayout(const LayoutInferArgs &T,
          *as_const_int(a_->shape[dim_A - 2]),
          *as_const_int(a_->shape[dim_A - 1]), a_->dtype.bits(), kPack_);
       results.Set(a_, shared_layout);
-    } else if (a_.scope() == "local.fragment") {
+    } else if (IsFragmentBuffer(a_)) {
       auto fragment =
           makeGemmFragmentACDNA(m_, n_, k_, m_ / warp_m, n_ / warp_n,
                                 a_->dtype.bits(), kPack_, transA_);
@@ -800,7 +800,7 @@ LayoutMap GemmNode::InferLayout(const LayoutInferArgs &T,
          *as_const_int(b_->shape[dim_B - 1]), b_->dtype.bits(), kPack_);
       results.Set(b_, shared_layout);
-    } else if (b_.scope() == "local.fragment") {
+    } else if (IsFragmentBuffer(b_)) {
       auto fragment =
           makeGemmFragmentB(m_, n_, k_, m_ / warp_m, n_ / warp_n, transB_);
       results.Set(b_, fragment->BindThreadRange(thread_range));
diff --git a/src/op/gemm_sp.cc b/src/op/gemm_sp.cc
index 4c0ae08b9..828953460 100644
--- a/src/op/gemm_sp.cc
+++ b/src/op/gemm_sp.cc
@@ -221,7 +221,7 @@ LayoutMap GemmSPNode::InferLayout(const LayoutInferArgs &T,
   if (completed_)
     return {};
   LayoutMap results;
-  ICHECK(c_.scope() == "local.fragment");
+  ICHECK(IsFragmentBuffer(c_));
   auto thread_range = T.thread_bounds;
   auto block_size = *as_const_int(thread_range->extent);
   if (TargetIsHopper(T.target)) {
@@ -273,7 +273,7 @@ LayoutMap GemmSPNode::InferLayout(const LayoutInferArgs &T,
       const int64_t mat_continuous = *as_const_int(a_->shape[dim_A - 1]);
       results.Set(a_, makeGemmSparseAmpereABLayout(mat_stride, mat_continuous,
                                                    a_->dtype.bits()));
-    } else if (a_.scope() == "local.fragment") {
+    } else if (IsFragmentBuffer(a_)) {
       // auto fragment = makeGemmFragmentA(M, N, K, M / warp_m, N / warp_n,
       // A->dtype.bits(), trans_A);
       // results.Set(A, fragment->BindThreadRange(thread_range));
@@ -287,7 +287,7 @@ LayoutMap GemmSPNode::InferLayout(const LayoutInferArgs &T,
       const int64_t mat_continuous = *as_const_int(b_->shape[dim_B - 1]);
       results.Set(b_, makeGemmSparseAmpereABLayout(mat_stride, mat_continuous,
                                                    b_->dtype.bits()));
-    } else if (b_.scope() == "local.fragment") {
+    } else if (IsFragmentBuffer(b_)) {
       // auto fragment =
       //     makeGemmFragmentB(M, N, K, M / warp_m, N / warp_n, trans_B);
       // results.Set(B, fragment->BindThreadRange(thread_range));
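For reference, here is a standalone sketch of the operand-scope dispatch that GemmNode::Lower performs above (plain bools stand in for IsFragmentBuffer; this is not the real GemmNode interface): A in registers selects tl::gemm_rs, B in registers selects tl::gemm_sr, otherwise tl::gemm_ss, with the accumulator C always required to be a fragment and at most one of A/B allowed in registers.

// Standalone sketch of the gemm variant selection in GemmNode::Lower.
#include <cassert>
#include <iostream>
#include <string>

std::string SelectGemmVariant(bool a_is_fragment, bool b_is_fragment,
                              bool c_is_fragment) {
  assert(c_is_fragment && "accumulator C must live in local.fragment");
  assert(!(a_is_fragment && b_is_fragment) &&
         "at most one of A/B may come from registers");
  if (a_is_fragment) return "tl::gemm_rs";  // A in registers, B staged in shared
  if (b_is_fragment) return "tl::gemm_sr";  // A staged in shared, B in registers
  return "tl::gemm_ss";                     // both operands staged in shared
}

int main() {
  std::cout << SelectGemmVariant(false, false, true) << "\n";  // tl::gemm_ss
  std::cout << SelectGemmVariant(true, false, true) << "\n";   // tl::gemm_rs
}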
diff --git a/src/op/parallel.cc b/src/op/parallel.cc
index dbc6ea8e2..a5e0b844c 100644
--- a/src/op/parallel.cc
+++ b/src/op/parallel.cc
@@ -12,6 +12,7 @@
 #include "../target/utils.h"
 #include "../transform/loop_partition.h"
 #include "../transform/loop_vectorize.h"
+#include "utils.h"
 
 namespace tvm {
 namespace tl {
 
@@ -147,7 +148,7 @@ void ParallelLoopNestVisitor::VisitStmt_(const ForNode *op) {
 }
 
 void ParallelLoopNestVisitor::VisitStmt_(const BufferStoreNode *op) {
-  if (op->buffer.scope() == "local.fragment") {
+  if (IsFragmentBuffer(op->buffer)) {
     if (p->indice_map_.find(op->buffer) != p->indice_map_.end()) {
       ICHECK(StructuralEqual()(p->indice_map_.at(op->buffer), op->indices))
           << op->buffer << ": " << op->indices << " and "
@@ -161,7 +162,7 @@ void ParallelLoopNestVisitor::VisitStmt_(const BufferStoreNode *op) {
 }
 
 void ParallelLoopNestVisitor::VisitExpr_(const BufferLoadNode *op) {
-  if (op->buffer.scope() == "local.fragment") {
+  if (IsFragmentBuffer(op->buffer)) {
     if (p->indice_map_.find(op->buffer) != p->indice_map_.end()) {
       ICHECK(StructuralEqual()(p->indice_map_.at(op->buffer), op->indices))
           << op->buffer << ": " << op->indices << " and "
@@ -191,8 +192,9 @@ void ParallelOpNode::ExpandLetBindings(
   std::function<void(const PrimExpr &)> expand = [&](const PrimExpr &expr) {
     PostOrderVisit(expr, [&](const ObjectRef &node) {
       if (auto bl = node.as<BufferLoadNode>()) {
-        if (bl->buffer.scope() == "local.fragment" &&
-            !indice_map_.count(bl->buffer)) {
+        if (IsFragmentBuffer(bl->buffer) && !indice_map_.count(bl->buffer)) {
+          DLOG(INFO) << "ExpandLetBindings: set buffer " << bl->buffer
+                     << " with indices " << bl->indices;
           indice_map_.Set(bl->buffer, bl->indices);
         }
       } else if (auto var_node = node.as<VarNode>()) {
@@ -204,9 +206,20 @@ void ParallelOpNode::ExpandLetBindings(
     });
   };
 
-  // Scan all let bindings
+  // Only expand let bindings that are used in root_
+  // First, collect all vars used in root_
+  std::unordered_set<const VarNode *> used_vars;
+  PostOrderVisit(root_, [&](const ObjectRef &node) {
+    if (auto var_node = node.as<VarNode>()) {
+      used_vars.insert(var_node);
+    }
+  });
+
+  // Only expand let bindings for vars that are actually used in root_
   for (const auto &[var, expr] : let_var_to_expr) {
-    expand(expr);
+    if (used_vars.count(var.get())) {
+      expand(expr);
+    }
   }
 }
 
@@ -259,7 +272,7 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
     if (T.layout_map.count(buffer)) {
       continue;
     }
-    if (buffer.scope() != "local.fragment")
+    if (!IsFragmentBuffer(buffer))
       continue;
 
     // Check if all indices are zero
@@ -303,7 +316,7 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
     return results;
   }
   auto buffer_is_completed_replicated = [&](const Buffer &buffer) {
-    if (buffer.scope() != "local.fragment")
+    if (!IsFragmentBuffer(buffer))
       return false;
     auto frag = T.layout_map[buffer].as<Fragment>().value();
     // buffer indices should be IntImm
@@ -319,7 +332,7 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
   // Collect fragment buffers with const index and all fragment_buffers
   std::vector<Buffer> const_index_fragment_buffer, fragment_buffers;
   for (const auto &[buffer, indices] : indice_map_) {
-    if (buffer.scope() != "local.fragment")
+    if (!IsFragmentBuffer(buffer))
       continue;
     fragment_buffers.push_back(buffer);
 
@@ -472,7 +485,7 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
     if (buffer.scope() == "shared" || buffer.scope() == "shared.dyn" ||
         buffer.scope() == "global") {
       store_shared_global_buffers.emplace_back(buffer);
-    } else if (buffer.scope() == "local.fragment") {
+    } else if (IsFragmentBuffer(buffer)) {
       store_fragment_buffers.emplace_back(buffer);
     }
   }
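A minimal sketch of the filtering that ExpandLetBindings now applies, assuming only the shape of the data (std::string stands in for tir::Var and PrimExpr): collect the variables referenced by the loop body first, then expand only the bindings whose variable actually occurs.

// Standalone sketch of the "collect used vars, then expand only those
// bindings" filter added to ExpandLetBindings above.
#include <iostream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

int main() {
  // let bindings: variable -> bound expression
  std::unordered_map<std::string, std::string> let_var_to_expr = {
      {"x", "BlockMask[b, i]"}, {"y", "A_shared[i] * 2"}};

  // variables that actually occur in the parallel-loop body (root_)
  std::unordered_set<std::string> used_vars = {"x", "i"};

  std::vector<std::string> expanded;
  for (const auto &[var, expr] : let_var_to_expr) {
    if (used_vars.count(var)) expanded.push_back(expr);  // skip unused bindings
  }
  for (const auto &e : expanded) std::cout << "expand: " << e << "\n";  // only x's binding
}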
diff --git a/src/op/reduce.cc b/src/op/reduce.cc
index 4458a4f51..91beee492 100644
--- a/src/op/reduce.cc
+++ b/src/op/reduce.cc
@@ -380,7 +380,7 @@ LayoutMap ReduceOpNode::InferLayout(const LayoutInferArgs &T,
   if (level >= InferLevel::kStrict)
     return {};
 
-  if (src.scope() == "local.fragment" && dst.scope() == "local.fragment" &&
+  if (IsFragmentBuffer(src) && IsFragmentBuffer(dst) &&
       T.layout_map.count(src)) {
     auto src_layout = T.layout_map[src].as<Fragment>().value();
@@ -518,8 +518,7 @@ CumSumOp::CumSumOp(Array<PrimExpr> args) {
 }
 
 Stmt CumSumOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
-  if (this->src.scope() == "local.fragment" &&
-      this->dst.scope() == "local.fragment") {
+  if (IsFragmentBuffer(this->src) && IsFragmentBuffer(this->dst)) {
     LOG(FATAL) << "CumSum for fragment not implemented, please raise an issue "
                   "if you need this feature.";
   } else if (this->src.scope() == "shared.dyn" ||
diff --git a/src/op/utils.h b/src/op/utils.h
index d386b1a58..1a0fd30a6 100644
--- a/src/op/utils.h
+++ b/src/op/utils.h
@@ -29,6 +29,29 @@ TVM_DLL BufferRegion NormalizeToBufferRegion(const PrimExpr &arg);
 TVM_DLL PrimExpr MakeAccessPtrFromRegion(const BufferRegion &region,
                                          int rw_mask, bool require_2d = false);
 
+// Check if a buffer is a fragment buffer (scope == "local.fragment")
+inline bool IsFragmentBuffer(const Buffer &buffer) {
+  return buffer.defined() && buffer.scope() == "local.fragment";
+}
+
+inline bool IsSharedBuffer(const Buffer &buffer, bool allow_dynamic = true) {
+  if (allow_dynamic) {
+    return buffer.defined() &&
+           (buffer.scope() == "shared" || buffer.scope() == "shared.dyn");
+  } else {
+    return buffer.defined() && buffer.scope() == "shared";
+  }
+}
+
+inline bool IsGlobalBuffer(const Buffer &buffer) {
+  return buffer.defined() && buffer.scope() == "global";
+}
+
+inline bool IsLocalBuffer(const Buffer &buffer) {
+  return buffer.defined() &&
+         (buffer.scope() == "local" || buffer.scope() == "local.var");
+}
+
 } // namespace tl
 } // namespace tvm
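Below is a self-contained mock of the classification the new helpers in src/op/utils.h perform. A plain struct stands in for tvm::tir::Buffer; only the scope strings and the allow_dynamic default follow the definitions above.

// Mock illustration of the scope predicates; MockBuffer is not the TVM type.
#include <cassert>
#include <string>

struct MockBuffer { std::string scope; };

bool IsFragmentBuffer(const MockBuffer &b) { return b.scope == "local.fragment"; }
bool IsSharedBuffer(const MockBuffer &b, bool allow_dynamic = true) {
  return allow_dynamic ? (b.scope == "shared" || b.scope == "shared.dyn")
                       : (b.scope == "shared");
}
bool IsGlobalBuffer(const MockBuffer &b) { return b.scope == "global"; }
bool IsLocalBuffer(const MockBuffer &b) {
  return b.scope == "local" || b.scope == "local.var";
}

int main() {
  assert(IsFragmentBuffer({"local.fragment"}));
  assert(IsSharedBuffer({"shared.dyn"}));          // dynamic shared accepted by default
  assert(!IsSharedBuffer({"shared.dyn"}, false));  // unless allow_dynamic is off
  assert(IsLocalBuffer({"local.var"}) && !IsLocalBuffer({"local.fragment"}));
  assert(IsGlobalBuffer({"global"}));
  return 0;
}

Note that IsLocalBuffer deliberately covers only "local" and "local.var": unlike the per-pass helper removed later in this patch, it does not treat "local.fragment" as local, so call sites that need both must combine it with IsFragmentBuffer.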
diff --git a/src/transform/common/loop_parallel_transform_utils.h b/src/transform/common/loop_parallel_transform_utils.h
index 1e8d7a350..52a5a9b97 100644
--- a/src/transform/common/loop_parallel_transform_utils.h
+++ b/src/transform/common/loop_parallel_transform_utils.h
@@ -17,6 +17,8 @@
 #include "arith/ir_visitor_with_analyzer.h"
 #include
 
+#include "../../op/utils.h"
+
 namespace tvm {
 namespace tl {
 
@@ -134,7 +136,7 @@ class ParallelLoopTransformer : public IRMutatorWithAnalyzer {
   class BufferAccessCollector : public StmtExprVisitor {
   public:
     void VisitExpr_(const BufferLoadNode *op) final {
-      if (op->buffer.scope() == "local.fragment") {
+      if (IsFragmentBuffer(op->buffer)) {
         if (buffer_indices.find(op->buffer) == buffer_indices.end()) {
           buffer_indices[op->buffer] = op->indices;
         } else {
@@ -147,7 +149,7 @@ class ParallelLoopTransformer : public IRMutatorWithAnalyzer {
     }
 
     void VisitStmt_(const BufferStoreNode *op) final {
-      if (op->buffer.scope() == "local.fragment") {
+      if (IsFragmentBuffer(op->buffer)) {
         if (buffer_indices.find(op->buffer) == buffer_indices.end()) {
           buffer_indices[op->buffer] = op->indices;
         } else {
diff --git a/src/transform/layout_inference.cc b/src/transform/layout_inference.cc
index 1af816147..9efe12f07 100644
--- a/src/transform/layout_inference.cc
+++ b/src/transform/layout_inference.cc
@@ -20,6 +20,7 @@
 #include "../op/copy.h"
 #include "../op/parallel.h"
 #include "../op/region.h"
+#include "../op/utils.h"
 #include "../target/utils.h"
 
 #include "arith/ir_mutator_with_analyzer.h"
@@ -170,8 +171,8 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer {
       if (layout_map.count(buffer)) {
         // If new layout contains the old one, update map
-        if (buffer.scope() == "local.fragment" &&
-            level != InferLevel::kStrict && !strict_layout_map.count(buffer)) {
+        if (IsFragmentBuffer(buffer) && level != InferLevel::kStrict &&
+            !strict_layout_map.count(buffer)) {
          // Actually this test has been done in ParallelOp::InferLayout
          // already. Just do it again to avoid missing implementations in other
          // `TileOperator`s.
@@ -308,6 +309,17 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer {
       q.push_back(i);
     }
 
+    // step 0: set fully replicated layout for floating fragment buffers
+    // Floating buffers are accessed outside TileOps (e.g., in if conditions),
+    // so they must be replicated across all threads.
+    for (const auto &[buffer, thread_bounds] : floating_fragment_buffers_) {
+      if (layout_map.count(buffer))
+        continue;
+      auto frag =
+          Fragment::FullyReplicated(buffer->shape, thread_bounds->extent);
+      layout_map.Set(buffer, frag);
+    }
+
     // step 1: infer strict layout
     for (int i = 0; i < num_infer; i++) {
       RunInferStep(i, InferLevel::kStrict, false, layout_map, strict_layout_map,
@@ -321,7 +333,6 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer {
     // step 2: infer common layout with BFS
     FinishInferQueue(InferLevel::kCommon, layout_map, strict_layout_map, q,
                      in_queue);
-
     // step 3: relax constraints to free and re-run
     InferInFreeMode(layout_map, strict_layout_map);
 
@@ -369,7 +380,7 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer {
 
     // Check that all local.fragment buffers have inferred layouts
     for (const auto &[buffer, _] : use_list_) {
-      if (buffer.scope() == "local.fragment") {
+      if (IsFragmentBuffer(buffer)) {
         ICHECK_NE(layout_map.count(buffer), 0)
             << "The layout for fragment " << buffer
            << " can not be inferred correctly.";
@@ -422,6 +433,8 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer {
         << "Layout_Inference: Require the target attribute";
     target_ = target.value();
     this->operator()(f->body);
+    // Compute floating fragment buffers after collection
+    ComputeFloatingFragmentBuffers(f->body);
   }
 
 private:
@@ -579,7 +592,7 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer {
   void addToUseList(const Buffer &buffer) {
     // buffer scope must be local.fragment
-    if (buffer.scope() != "local.fragment") {
+    if (!IsFragmentBuffer(buffer)) {
       return;
     }
     int infer_idx = infer_list_.size();
@@ -774,7 +787,7 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer {
   void CollectFragmentBuffersFromExpr(const PrimExpr &expr) {
     PostOrderVisit(expr, [this](const ObjectRef &node) {
       if (auto bl = node.as<BufferLoadNode>()) {
-        if (bl->buffer.defined() && bl->buffer.scope() == "local.fragment") {
+        if (IsFragmentBuffer(bl->buffer)) {
          addToUseList(bl->buffer);
         }
       } else if (auto var_node = node.as<VarNode>()) {
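A toy illustration of the step ordering introduced above, assuming a plain string map in place of the real LayoutMap/Fragment types: floating fragment buffers are seeded with a fully replicated layout before the strict/common/free passes run, and the seeding step never overwrites an entry that already exists (the later passes are simplified here; the real ones can also refine an existing layout through the containment check earlier in this file).

// Sketch of "seed replicated layouts first, fill in the rest later".
#include <iostream>
#include <map>
#include <string>
#include <utility>
#include <vector>

int main() {
  std::map<std::string, std::string> layout_map;                 // buffer -> layout
  std::vector<std::string> floating = {"block_mask_f"};          // step 0 inputs
  std::vector<std::pair<std::string, std::string>> inferred = {  // later steps (simplified)
      {"block_mask_f", "strided"}, {"acc_o", "mma_fragment"}};

  for (const auto &buf : floating)
    if (!layout_map.count(buf)) layout_map[buf] = "fully_replicated";  // step 0

  for (const auto &[buf, layout] : inferred)
    if (!layout_map.count(buf)) layout_map[buf] = layout;  // steps 1-3 fill the gaps

  for (const auto &[buf, layout] : layout_map)
    std::cout << buf << " -> " << layout << "\n";
  // block_mask_f keeps the replicated seed; acc_o gets the inferred layout.
}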
@@ -846,11 +859,122 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer {
     IRVisitorWithAnalyzer::VisitStmt_(op);
   }
 
+  // Compute floating fragment buffers after collection is done.
+  //
+  // A "floating" fragment buffer is one that has accesses outside of any
+  // TileOp (Copy, Gemm, Reduce, Parallel, etc.). For example:
+  //
+  //   T.copy(BlockMask[by, :], block_mask_f)  // accessed IN a TileOp
+  //   for i in T.Pipelined(N_S):
+  //     if block_mask_f[i] >= 0:              // accessed OUTSIDE any TileOp (floating!)
+  //       T.copy(A[...], A_shared)
+  //
+  // In this example, `block_mask_f[i]` in the if-condition is a "floating"
+  // access because it's not inside any TileOp. Such buffers need special
+  // handling: they must be fully replicated across all threads since the
+  // access pattern cannot be inferred from TileOp semantics.
+  //
+  // This function identifies these buffers by:
+  //   1. Collecting all IR nodes that are inside TileOps (from infer_list_stmt_)
+  //   2. Scanning the entire function body for fragment buffer accesses
+  //   3. Any access not inside a TileOp means the buffer is "floating"
+  //   4. Recording the thread_bounds at the point of each floating access
+  void ComputeFloatingFragmentBuffers(const Stmt &func_body) {
+    // Step 1: Collect all nodes that are inside TileOps
+    std::unordered_set<const Object *> nodes_in_tileops;
+    for (const auto &stmt : infer_list_stmt_) {
+      PostOrderVisit(stmt, [&](const ObjectRef &node) {
+        nodes_in_tileops.insert(node.get());
+      });
+    }
+
+    // Step 2: Use a visitor to scan for floating accesses while tracking
+    // thread context
+    class FloatingBufferCollector : public IRVisitorWithAnalyzer {
+    public:
+      FloatingBufferCollector(
+          const std::unordered_set<const Object *> &nodes_in_tileops,
+          std::unordered_map<Buffer, Range, ObjectPtrHash, ObjectPtrEqual>
+              &floating_buffers)
+          : nodes_in_tileops_(nodes_in_tileops),
+            floating_buffers_(floating_buffers) {}
+
+      void VisitStmt_(const AttrStmtNode *op) final {
+        if (op->attr_key == tir::attr::thread_extent) {
+          IterVar iv = Downcast<IterVar>(op->node);
+          if (iv->thread_tag == "threadIdx.x") {
+            thread_var_ = iv;
+          }
+        }
+        IRVisitorWithAnalyzer::VisitStmt_(op);
+      }
+
+      void VisitExpr_(const BufferLoadNode *op) final {
+        CheckFloatingAccess(op->buffer, op);
+        IRVisitorWithAnalyzer::VisitExpr_(op);
+      }
+
+      void VisitStmt_(const BufferStoreNode *op) final {
+        CheckFloatingAccess(op->buffer, op);
+        IRVisitorWithAnalyzer::VisitStmt_(op);
+      }
+
+    private:
+      void CheckFloatingAccess(const Buffer &buffer, const Object *node) {
+        if (!IsFragmentBuffer(buffer))
+          return;
+        if (nodes_in_tileops_.find(node) != nodes_in_tileops_.end())
+          return;
+        // This is a floating access - record buffer with current thread_bounds
+        if (floating_buffers_.find(buffer) != floating_buffers_.end())
+          return; // Already recorded
+        Range thread_bounds = Range::FromMinExtent(0, 1);
+        if (thread_var_.defined() &&
+            analyzer_.const_int_bound.IsBound(thread_var_->var)) {
+          auto const_int_bound = analyzer_.const_int_bound(thread_var_);
+          auto dtype = thread_var_->var.dtype();
+          auto extent =
+              const_int_bound->max_value - const_int_bound->min_value + 1;
+          thread_bounds = Range::FromMinExtent(
+              IntImm(dtype, const_int_bound->min_value), IntImm(dtype, extent));
+        }
+        floating_buffers_[buffer] = thread_bounds;
+      }
+
+      const std::unordered_set<const Object *> &nodes_in_tileops_;
+      std::unordered_map<Buffer, Range, ObjectPtrHash, ObjectPtrEqual>
+          &floating_buffers_;
+      IterVar thread_var_;
+    };
+
+    FloatingBufferCollector collector(nodes_in_tileops,
+                                      floating_fragment_buffers_);
+    collector(func_body);
+
+    // Debug log floating fragment buffers
+    if (!floating_fragment_buffers_.empty()) {
+      DLOG(INFO)
+          << "Floating fragment buffers (have accesses outside TileOps):";
+      for (const auto &[buffer, thread_bounds] : floating_fragment_buffers_) {
+        DLOG(INFO) << "  " << buffer
+                   << " with thread_bounds: " << thread_bounds;
+      }
+    }
+  }
+
   Map<Var, Array<Buffer>> buffer_data_to_buffers_;
   // Map from LetStmt variable to its bound expression
   Map<Var, PrimExpr> let_var_to_expr_;
   std::vector<Stmt> infer_list_stmt_;
   std::vector<TileOperator> infer_list_;
+  // Fragment buffers that have accesses outside of TileOps.
+  // These "floating" buffers need fully replicated layouts since their
+  // access patterns cannot be inferred from TileOp semantics.
+  // Maps buffer -> thread_bounds at the point of floating access.
+  // See ComputeFloatingFragmentBuffers() for detailed explanation.
+  std::unordered_map<Buffer, Range, ObjectPtrHash, ObjectPtrEqual>
+      floating_fragment_buffers_;
   std::unordered_map<Buffer, std::vector<int>, ObjectPtrHash, ObjectPtrEqual>
       use_list_;
   // Per-op list of buffers it touches (fragment scope), used for prioritization
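A standalone sketch of the two-step detection that ComputeFloatingFragmentBuffers implements, with integer ids and strings standing in for IR node pointers and Buffers: first record every node that sits inside a TileOp, then flag fragment accesses whose node is not in that set as floating.

// Sketch of the set-difference detection: nodes inside TileOps vs. all
// fragment-buffer accesses in the function body.
#include <iostream>
#include <string>
#include <unordered_set>
#include <vector>

struct Access {
  int node_id;         // stands in for the IR node pointer
  std::string buffer;  // stands in for the fragment Buffer
};

int main() {
  // Step 1: nodes covered by TileOps (e.g. the statements handed to Copy/Gemm/...).
  std::unordered_set<int> nodes_in_tileops = {1, 2, 3, 5};

  // Step 2: every fragment-buffer access found in the function body.
  std::vector<Access> accesses = {
      {2, "block_mask_f"},  // inside a T.copy -> fine
      {7, "block_mask_f"},  // in an if-condition outside any TileOp -> floating
      {5, "acc_o"}};

  std::unordered_set<std::string> floating;
  for (const auto &a : accesses)
    if (!nodes_in_tileops.count(a.node_id)) floating.insert(a.buffer);

  for (const auto &b : floating)
    std::cout << b << " needs a fully replicated layout\n";  // block_mask_f
}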
@@ -1141,7 +1265,7 @@ class LayoutInferencer : public IRMutatorWithAnalyzer {
     bool store_into_local = false;
     PostOrderVisit(root, [&](const ObjectRef &obj) {
       if (const auto *store = obj.as<BufferStoreNode>()) {
-        if (store->buffer.scope() == "local") {
+        if (IsLocalBuffer(store->buffer)) {
           store_into_local = true;
         }
         // if the case is like:
@@ -1161,11 +1285,11 @@ class LayoutInferencer : public IRMutatorWithAnalyzer {
     bool local_register_only = true;
     PostOrderVisit(root, [&](const ObjectRef &obj) {
       if (const auto *store = obj.as<BufferStoreNode>()) {
-        if (store->buffer.scope() != "local") {
+        if (!IsLocalBuffer(store->buffer)) {
           local_register_only = false;
         }
       } else if (const auto *load = obj.as<BufferLoadNode>()) {
-        if (load->buffer.scope() != "local") {
+        if (!IsLocalBuffer(load->buffer)) {
           local_register_only = false;
         }
       }
@@ -1186,12 +1310,12 @@ class LayoutInferencer : public IRMutatorWithAnalyzer {
     PostOrderVisit(for_node->body, [&](const ObjectRef &obj) {
       if (const auto *load = obj.as<BufferLoadNode>()) {
         String scope = load->buffer.scope();
-        if (scope != "local" && scope != "local.fragment") {
+        if (!IsLocalBuffer(load->buffer) && !IsFragmentBuffer(load->buffer)) {
           has_non_local = true;
         }
       } else if (const auto *store = obj.as<BufferStoreNode>()) {
         String scope = store->buffer.scope();
-        if (scope != "local" && scope != "local.fragment") {
+        if (!IsLocalBuffer(store->buffer) && !IsFragmentBuffer(store->buffer)) {
           has_non_local = true;
         }
       }
diff --git a/src/transform/legalize_safe_memory_access.cc b/src/transform/legalize_safe_memory_access.cc
index 1a9da919c..a6f31da7d 100644
--- a/src/transform/legalize_safe_memory_access.cc
+++ b/src/transform/legalize_safe_memory_access.cc
@@ -262,21 +262,6 @@ class SafeMemorysRewriter : public IRMutatorWithAnalyzer {
     return IRMutatorWithAnalyzer::VisitStmt_(op);
   }
 
-  bool IsLocalBuffer(const Buffer &buffer) {
-    String scope = buffer.scope();
-    return scope == "local" || scope == "local.fragment" ||
-           scope == "local.var";
-  }
-
-  bool isSharedBuffer(const Buffer &buffer) {
-    String scope = buffer.scope();
-    return scope == "shared" || scope == "shared.dyn";
-  }
-
-  bool IsGlobalBuffer(const Buffer &buffer) {
-    String scope = buffer.scope();
-    return scope == "global";
-  }
-
   // Get the safe value of the buffer
   PrimExpr GetSafeValue(const Buffer &buffer) {
     if (annotated_safe_value_map_.count(buffer)) {
diff --git a/src/transform/loop_partition.cc b/src/transform/loop_partition.cc
index b4236c6db..186af7340 100644
--- a/src/transform/loop_partition.cc
+++ b/src/transform/loop_partition.cc
@@ -28,6 +28,8 @@
 #include
 
+#include "../op/utils.h"
+
 namespace tvm {
 namespace tl {
 
@@ -218,14 +220,14 @@ class LoopPartitioner : public StmtExprVisitor {
 
 private:
   void VisitExpr_(const BufferLoadNode *op) final {
-    if (op->buffer.scope() == "local.fragment") {
+    if (IsFragmentBuffer(op->buffer)) {
       has_fragment_ = true;
     }
     StmtExprVisitor::VisitExpr_(op);
   }
 
   void VisitStmt_(const BufferStoreNode *op) final {
-    if (op->buffer.scope() == "local.fragment") {
+    if (IsFragmentBuffer(op->buffer)) {
       has_fragment_ = true;
     }
     StmtExprVisitor::VisitStmt_(op);