Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -138,22 +138,21 @@ def main(
scores_scale = T.alloc_fragment([block_M], accum_dtype)
scores_sum = T.alloc_fragment([block_M], accum_dtype)
logsum = T.alloc_fragment([block_M], accum_dtype)
block_mask = T.alloc_local([downsample_len], block_mask_dtype)
block_mask = T.alloc_fragment([downsample_len], block_mask_dtype)

T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))

for vj in T.serial(downsample_len):
block_mask[vj] = BlockSparseMask[bz, by, bx, vj]
T.copy(BlockSparseMask[bz, by, bx, :], block_mask)

loop_range = (
T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N)
)

for k in T.Pipelined(loop_range, num_stages=num_stages):
if block_mask[k]:
if block_mask[k] != 0:
MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum)
Rescale(acc_o, scores_scale)
Expand Down
5 changes: 2 additions & 3 deletions examples/attention_sink/example_gqa_sink_bwd_bhsd.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,16 +321,15 @@ def flash_bwd_dsink(
dsinks: T.Tensor(shape, dtype), # type: ignore
):
with T.Kernel(heads, T.ceildiv(seq_len, block), batch, threads=256) as (bx, by, bz):
sink = T.alloc_local([1], dtype)
lse_fragment = T.alloc_fragment([block], accum_dtype)
delta_fragment = T.alloc_fragment([block], accum_dtype)
dsink_fragment = T.alloc_fragment([block], dtype)

sink[0] = Sinks[bx]
sink = Sinks[bx]
T.copy(lse[bz, bx, by * block : (by + 1) * block], lse_fragment)
T.copy(Delta[bz, bx, by * block : (by + 1) * block], delta_fragment)
for i in T.Parallel(block):
dsink_fragment[i] = -T.exp2(Sinks[bx] * 1.44269504 - lse_fragment[i]) * delta_fragment[i]
dsink_fragment[i] = -T.exp2(sink * 1.44269504 - lse_fragment[i]) * delta_fragment[i]
T.copy(dsink_fragment, dsinks[bz, bx, by * block : (by + 1) * block])

return flash_bwd_dsink
Expand Down
5 changes: 2 additions & 3 deletions examples/attention_sink/example_mha_sink_bwd_bhsd.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,16 +327,15 @@ def flash_bwd_dsink(
dsinks: T.Tensor(shape, accum_dtype), # type: ignore
):
with T.Kernel(heads, T.ceildiv(seq_len, block), batch, threads=128) as (bx, by, bz):
sink = T.alloc_local([1], dtype)
lse_fragment = T.alloc_fragment([block], accum_dtype)
delta_fragment = T.alloc_fragment([block], accum_dtype)
dsink_fragment = T.alloc_fragment([block], accum_dtype)

sink[0] = Sinks[bx]
sink = Sinks[bx]
T.copy(lse[bz, bx, by * block : (by + 1) * block], lse_fragment)
T.copy(Delta[bz, bx, by * block : (by + 1) * block], delta_fragment)
for i in T.Parallel(block):
dsink_fragment[i] = -T.exp2(Sinks[bx] * 1.44269504 - lse_fragment[i]) * delta_fragment[i]
dsink_fragment[i] = -T.exp2(sink * 1.44269504 - lse_fragment[i]) * delta_fragment[i]
T.copy(dsink_fragment, dsinks[bz, bx, by * block : (by + 1) * block])

return flash_bwd_dsink
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,15 +137,14 @@ def blocksparse_flashattn(
scores_scale = T.alloc_fragment([block_M], accum_dtype)
scores_sum = T.alloc_fragment([block_M], accum_dtype)
logsum = T.alloc_fragment([block_M], accum_dtype)
block_mask = T.alloc_local([downsample_len], block_mask_dtype)
block_mask = T.alloc_fragment([downsample_len], block_mask_dtype)

T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))

for vj in T.serial(downsample_len):
block_mask[vj] = BlockSparseMask[bz, by, bx, vj]
T.copy(BlockSparseMask[bz, by, bx, :], block_mask)

loop_range = (
T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -136,40 +136,34 @@ def combine(
with T.Kernel(heads, batch, threads=128) as (by, bz):
po_local = T.alloc_fragment([dim_v], accum_dtype)
o_accum_local = T.alloc_fragment([dim_v], accum_dtype)
lse_local_split = T.alloc_local([1], accum_dtype)
lse_logsum_local = T.alloc_local([1], accum_dtype)
lse_max_local = T.alloc_local([1], accum_dtype)
scale_local = T.alloc_local([1], accum_dtype)
max_split = T.alloc_local([1], T.int32)

T.annotate_layout(
{
lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i),
}
)
lse_local_split = T.alloc_var(accum_dtype)
lse_logsum_local = T.alloc_var(accum_dtype)
lse_max_local = T.alloc_var(accum_dtype)
scale_local = T.alloc_var(accum_dtype)
max_split = T.alloc_var(T.int32)

T.clear(lse_logsum_local)
T.clear(o_accum_local)
lse_max_local[0] = -T.infinity(accum_dtype)
lse_max_local = -T.infinity(accum_dtype)
for k in T.serial(num_split):
lse_local_split[0] = glse[bz, by, k]
if lse_local_split[0] != 0:
max_split[0] = k
lse_max_local[0] = T.max(lse_max_local[0], glse[bz, by, k])
lse_local_split = glse[bz, by, k]
if lse_local_split != 0:
max_split = k
lse_max_local = T.max(lse_max_local, glse[bz, by, k])
Comment on lines 148 to +152
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Critical: max_split used uninitialized when all lse values are zero.

At line 150, max_split is only assigned inside the conditional if lse_local_split != 0. If all log-sum-exp values are zero (which can occur when no valid blocks are processed in any split), max_split is never initialized but is subsequently used in conditionals at lines 154 and 159.

In the previous array-based implementation, max_split[0] would have been zero-initialized, providing defined behavior. The scalar refactor introduces undefined behavior in this edge case.

🔎 Proposed fix: Initialize max_split before the loop
 T.clear(lse_logsum_local)
 T.clear(o_accum_local)
 lse_max_local = -T.infinity(accum_dtype)
+max_split = -1
 for k in T.serial(num_split):
     lse_local_split = glse[bz, by, k]
     if lse_local_split != 0:
         max_split = k
         lse_max_local = T.max(lse_max_local, glse[bz, by, k])

Initializing to -1 ensures that if no valid splits are found, the conditions k <= max_split at lines 154 and 159 will be false for all k ≥ 0, correctly skipping the accumulation loops.

🤖 Prompt for AI Agents
In examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_paged.py
around lines 147 to 151, max_split is only assigned inside the if and can remain
uninitialized when all lse values are zero; initialize max_split to -1
immediately before the for k in T.serial(num_split) loop so that when no valid
splits are found the later checks (k <= max_split) are false and accumulation
loops are skipped, ensuring defined behavior in the all-zero case.


for k in T.Pipelined(num_split, num_stages=1):
if k <= max_split[0]:
lse_local_split[0] = glse[bz, by, k]
lse_logsum_local[0] += T.exp2(lse_local_split[0] - lse_max_local[0])
lse_logsum_local[0] = T.log2(lse_logsum_local[0]) + lse_max_local[0]
if k <= max_split:
lse_local_split = glse[bz, by, k]
lse_logsum_local += T.exp2(lse_local_split - lse_max_local)
lse_logsum_local = T.log2(lse_logsum_local) + lse_max_local
for k in T.serial(num_split):
if k <= max_split[0]:
if k <= max_split:
for i in T.Parallel(dim_v):
po_local[i] = Output_partial[bz, by, k, i]
lse_local_split[0] = glse[bz, by, k]
scale_local[0] = T.exp2(lse_local_split[0] - lse_logsum_local[0])
lse_local_split = glse[bz, by, k]
scale_local = T.exp2(lse_local_split - lse_logsum_local)
for i in T.Parallel(dim_v):
o_accum_local[i] += po_local[i] * scale_local[0]
o_accum_local[i] += po_local[i] * scale_local
for i in T.Parallel(dim_v):
Output[bz, by, i] = o_accum_local[i]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,40 +125,34 @@ def combine(
with T.Kernel(heads, batch, threads=128) as (by, bz):
po_local = T.alloc_fragment([dim_v], accum_dtype)
o_accum_local = T.alloc_fragment([dim_v], accum_dtype)
lse_local_split = T.alloc_local([1], accum_dtype)
lse_logsum_local = T.alloc_local([1], accum_dtype)
lse_max_local = T.alloc_local([1], accum_dtype)
scale_local = T.alloc_local([1], accum_dtype)
max_split = T.alloc_local([1], T.int32)

T.annotate_layout(
{
lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i),
}
)
lse_local_split = T.alloc_var(accum_dtype)
lse_logsum_local = T.alloc_var(accum_dtype)
lse_max_local = T.alloc_var(accum_dtype)
scale_local = T.alloc_var(accum_dtype)
max_split = T.alloc_var(T.int32)

T.clear(lse_logsum_local)
T.clear(o_accum_local)
lse_max_local[0] = -T.infinity(accum_dtype)
lse_max_local = -T.infinity(accum_dtype)
for k in T.serial(num_split):
lse_local_split[0] = glse[bz, by, k]
if lse_local_split[0] != 0:
max_split[0] = k
lse_max_local[0] = T.max(lse_max_local[0], glse[bz, by, k])
lse_local_split = glse[bz, by, k]
if lse_local_split != 0:
max_split = k
lse_max_local = T.max(lse_max_local, glse[bz, by, k])

for k in T.Pipelined(num_split, num_stages=1):
if k <= max_split[0]:
lse_local_split[0] = glse[bz, by, k]
lse_logsum_local[0] += T.exp2(lse_local_split[0] - lse_max_local[0])
lse_logsum_local[0] = T.log2(lse_logsum_local[0]) + lse_max_local[0]
if k <= max_split:
lse_local_split = glse[bz, by, k]
lse_logsum_local += T.exp2(lse_local_split - lse_max_local)
lse_logsum_local = T.log2(lse_logsum_local) + lse_max_local
for k in T.serial(num_split):
if k <= max_split[0]:
if k <= max_split:
for i in T.Parallel(dim_v):
po_local[i] = Output_partial[bz, by, k, i]
lse_local_split[0] = glse[bz, by, k]
scale_local[0] = T.exp2(lse_local_split[0] - lse_logsum_local[0])
lse_local_split = glse[bz, by, k]
scale_local = T.exp2(lse_local_split - lse_logsum_local)
for i in T.Parallel(dim_v):
o_accum_local[i] += po_local[i] * scale_local[0]
o_accum_local[i] += po_local[i] * scale_local
for i in T.Parallel(dim_v):
Output[bz, by, i] = o_accum_local[i]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,33 +121,27 @@ def combine(
with T.Kernel(heads, batch, threads=128) as (by, bz):
po_local = T.alloc_fragment([dim_v], accum_dtype)
o_accum_local = T.alloc_fragment([dim_v], accum_dtype)
lse_local_split = T.alloc_local([1], accum_dtype)
lse_logsum_local = T.alloc_local([1], accum_dtype)
lse_max_local = T.alloc_local([1], accum_dtype)
scale_local = T.alloc_local([1], accum_dtype)

T.annotate_layout(
{
lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i),
}
)
lse_local_split = T.alloc_var(accum_dtype)
lse_logsum_local = T.alloc_var(accum_dtype)
lse_max_local = T.alloc_var(accum_dtype)
scale_local = T.alloc_var(accum_dtype)

T.clear(lse_logsum_local)
T.clear(o_accum_local)
lse_max_local[0] = -T.infinity(accum_dtype)
lse_max_local = -T.infinity(accum_dtype)
for k in T.serial(num_split):
lse_max_local[0] = T.max(lse_max_local[0], glse[bz, by, k])
lse_max_local = T.max(lse_max_local, glse[bz, by, k])
for k in T.Pipelined(num_split, num_stages=1):
lse_local_split[0] = glse[bz, by, k]
lse_logsum_local[0] += T.exp2(lse_local_split[0] - lse_max_local[0])
lse_logsum_local[0] = T.log2(lse_logsum_local[0]) + lse_max_local[0]
lse_local_split = glse[bz, by, k]
lse_logsum_local += T.exp2(lse_local_split - lse_max_local)
lse_logsum_local = T.log2(lse_logsum_local) + lse_max_local
for k in T.serial(num_split):
for i in T.Parallel(dim_v):
po_local[i] = Output_partial[bz, by, k, i]
lse_local_split[0] = glse[bz, by, k]
scale_local[0] = T.exp2(lse_local_split[0] - lse_logsum_local[0])
lse_local_split = glse[bz, by, k]
scale_local = T.exp2(lse_local_split - lse_logsum_local)
for i in T.Parallel(dim_v):
o_accum_local[i] += po_local[i] * scale_local[0]
o_accum_local[i] += po_local[i] * scale_local
for i in T.Parallel(dim_v):
Output[bz, by, i] = o_accum_local[i]

Expand Down
28 changes: 11 additions & 17 deletions examples/deepseek_mla/amd/benchmark_mla_decode_amd_tilelang.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,31 +173,25 @@ def combine(
with T.Kernel(heads, batch, threads=128) as (by, bz):
po_local = T.alloc_fragment([dim], dtype)
o_accum_local = T.alloc_fragment([dim], accum_dtype)
lse_local_split = T.alloc_local([1], accum_dtype)
lse_logsum_local = T.alloc_local([1], accum_dtype)
lse_max_local = T.alloc_local([1], accum_dtype)
scale_local = T.alloc_local([1], accum_dtype)

T.annotate_layout(
{
lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i),
}
)
lse_local_split = T.alloc_var(accum_dtype)
lse_logsum_local = T.alloc_var(accum_dtype)
lse_max_local = T.alloc_var(accum_dtype)
scale_local = T.alloc_var(accum_dtype)

T.clear(lse_logsum_local)
T.clear(o_accum_local)
lse_max_local[0] = -T.infinity(accum_dtype)
lse_max_local = -T.infinity(accum_dtype)
for k in T.serial(num_split):
lse_max_local[0] = T.max(lse_max_local[0], glse[bz, by, k])
lse_max_local = T.max(lse_max_local, glse[bz, by, k])
for k in T.Pipelined(num_split, num_stages=1):
lse_local_split[0] = glse[bz, by, k]
lse_logsum_local[0] += T.exp2(lse_local_split[0] - lse_max_local[0])
lse_logsum_local[0] = T.log2(lse_logsum_local[0]) + lse_max_local[0]
lse_local_split = glse[bz, by, k]
lse_logsum_local += T.exp2(lse_local_split - lse_max_local)
lse_logsum_local = T.log2(lse_logsum_local) + lse_max_local
for k in T.serial(num_split):
for i in T.Parallel(dim):
po_local[i] = Output_partial[bz, by, k, i]
lse_local_split[0] = glse[bz, by, k]
scale_local[0] = T.exp2(lse_local_split[0] - lse_logsum_local[0])
lse_local_split = glse[bz, by, k]
scale_local = T.exp2(lse_local_split - lse_logsum_local)
for i in T.Parallel(dim):
o_accum_local[i] += po_local[i] * scale_local[0]
Comment on lines +176 to 196
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Bug: scale_local[0] indexing on a scalar variable.

scale_local is allocated as a scalar via T.alloc_var(accum_dtype) at line 179, but is incorrectly indexed as scale_local[0] at line 196. This will cause an error since scalars cannot be indexed.

🔎 Proposed fix
                 for i in T.Parallel(dim):
-                    o_accum_local[i] += po_local[i] * scale_local[0]
+                    o_accum_local[i] += po_local[i] * scale_local
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
lse_local_split = T.alloc_var(accum_dtype)
lse_logsum_local = T.alloc_var(accum_dtype)
lse_max_local = T.alloc_var(accum_dtype)
scale_local = T.alloc_var(accum_dtype)
T.clear(lse_logsum_local)
T.clear(o_accum_local)
lse_max_local[0] = -T.infinity(accum_dtype)
lse_max_local = -T.infinity(accum_dtype)
for k in T.serial(num_split):
lse_max_local[0] = T.max(lse_max_local[0], glse[bz, by, k])
lse_max_local = T.max(lse_max_local, glse[bz, by, k])
for k in T.Pipelined(num_split, num_stages=1):
lse_local_split[0] = glse[bz, by, k]
lse_logsum_local[0] += T.exp2(lse_local_split[0] - lse_max_local[0])
lse_logsum_local[0] = T.log2(lse_logsum_local[0]) + lse_max_local[0]
lse_local_split = glse[bz, by, k]
lse_logsum_local += T.exp2(lse_local_split - lse_max_local)
lse_logsum_local = T.log2(lse_logsum_local) + lse_max_local
for k in T.serial(num_split):
for i in T.Parallel(dim):
po_local[i] = Output_partial[bz, by, k, i]
lse_local_split[0] = glse[bz, by, k]
scale_local[0] = T.exp2(lse_local_split[0] - lse_logsum_local[0])
lse_local_split = glse[bz, by, k]
scale_local = T.exp2(lse_local_split - lse_logsum_local)
for i in T.Parallel(dim):
o_accum_local[i] += po_local[i] * scale_local[0]
lse_local_split = T.alloc_var(accum_dtype)
lse_logsum_local = T.alloc_var(accum_dtype)
lse_max_local = T.alloc_var(accum_dtype)
scale_local = T.alloc_var(accum_dtype)
T.clear(lse_logsum_local)
T.clear(o_accum_local)
lse_max_local = -T.infinity(accum_dtype)
for k in T.serial(num_split):
lse_max_local = T.max(lse_max_local, glse[bz, by, k])
for k in T.Pipelined(num_split, num_stages=1):
lse_local_split = glse[bz, by, k]
lse_logsum_local += T.exp2(lse_local_split - lse_max_local)
lse_logsum_local = T.log2(lse_logsum_local) + lse_max_local
for k in T.serial(num_split):
for i in T.Parallel(dim):
po_local[i] = Output_partial[bz, by, k, i]
lse_local_split = glse[bz, by, k]
scale_local = T.exp2(lse_local_split - lse_logsum_local)
for i in T.Parallel(dim):
o_accum_local[i] += po_local[i] * scale_local
🤖 Prompt for AI Agents
In examples/deepseek_mla/amd/benchmark_mla_decode_amd_tilelang.py around lines
176 to 196, scale_local is allocated as a scalar (T.alloc_var(accum_dtype)) but
later used with an index (scale_local[0]) which is invalid; change the usage to
treat scale_local as a scalar (remove the [0] indexing) wherever it is used
(e.g., replace scale_local[0] with scale_local) so the multiplication into
o_accum_local uses the scalar value, or alternatively allocate scale_local as a
1-element buffer if indexing is required—prefer the scalar fix for minimal
change.

for i in T.Parallel(dim):
Expand Down
30 changes: 12 additions & 18 deletions examples/deepseek_mla/example_mla_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,33 +167,27 @@ def combine(
with T.Kernel(heads, batch, threads=128) as (hid, bz):
po_local = T.alloc_fragment([dim], dtype)
o_accum_local = T.alloc_fragment([dim], accum_dtype)
lse_local_split = T.alloc_local([1], accum_dtype)
lse_logsum_local = T.alloc_local([1], accum_dtype)
lse_max_local = T.alloc_local([1], accum_dtype)
scale_local = T.alloc_local([1], accum_dtype)

T.annotate_layout(
{
lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i),
}
)
lse_local_split = T.alloc_var(accum_dtype)
lse_logsum_local = T.alloc_var(accum_dtype)
lse_max_local = T.alloc_var(accum_dtype)
scale_local = T.alloc_var(accum_dtype)

T.clear(lse_logsum_local)
T.clear(o_accum_local)
lse_max_local[0] = -T.infinity(accum_dtype)
lse_max_local = -T.infinity(accum_dtype)
for k in T.serial(num_split):
lse_max_local[0] = T.max(lse_max_local[0], glse[bz, hid, k])
lse_max_local = T.max(lse_max_local, glse[bz, hid, k])
for k in T.Pipelined(num_split, num_stages=1):
lse_local_split[0] = glse[bz, hid, k]
lse_logsum_local[0] += T.exp2(lse_local_split[0] - lse_max_local[0])
lse_logsum_local[0] = T.log2(lse_logsum_local[0]) + lse_max_local[0]
lse_local_split = glse[bz, hid, k]
lse_logsum_local += T.exp2(lse_local_split - lse_max_local)
lse_logsum_local = T.log2(lse_logsum_local) + lse_max_local
for k in T.serial(num_split):
for i in T.Parallel(dim):
po_local[i] = Output_partial[bz, hid, k, i]
lse_local_split[0] = glse[bz, hid, k]
scale_local[0] = T.exp2(lse_local_split[0] - lse_logsum_local[0])
lse_local_split = glse[bz, hid, k]
scale_local = T.exp2(lse_local_split - lse_logsum_local)
for i in T.Parallel(dim):
o_accum_local[i] += po_local[i] * scale_local[0]
o_accum_local[i] += po_local[i] * scale_local
for i in T.Parallel(dim):
Output[bz, hid, i] = o_accum_local[i]

Expand Down
Loading
Loading