Merged
16 changes: 8 additions & 8 deletions examples/gemm_sp/example_custom_compress.py
@@ -258,28 +258,28 @@ def kernel(
             T.clear(A_sp_shared)
             T.clear(E_shared)
             # TODO: alloc_var seems buggy here
-            non_zero_cnt = T.alloc_var(dtype=T.uint8)
-            non_zero_elt_log_idx = T.alloc_shared((elem,), dtype=T.uint8)
+            non_zero_cnt = T.alloc_local((1,), dtype=T.uint8)
+            non_zero_elt_log_idx = T.alloc_local((elem,), dtype=T.uint8)
             T.copy(A[bx * block_M, by * block_K], A_shared)
             for tm in T.Parallel(block_M):
                 for g_i in range(0, block_K // group):
                     a_k = g_i * group
-                    non_zero_cnt = 0
+                    non_zero_cnt[0] = 0
                     for i in range(elem):
                         non_zero_elt_log_idx[i] = 0
                     for i in range(group):
                         val = A_shared[tm, a_k + i]
                         if val != 0.0:
-                            non_zero_elt_log_idx[non_zero_cnt] = i
-                            A_sp_shared[tm, a_k // 2 + non_zero_cnt] = val
-                            non_zero_cnt += 1
+                            non_zero_elt_log_idx[non_zero_cnt[0]] = i
+                            A_sp_shared[tm, a_k // 2 + non_zero_cnt[0]] = val
+                            non_zero_cnt[0] += 1
                     # TODO: use T.device_assert(non_zero_cnt <= 2) after rebasing main
-                    if non_zero_cnt == 1 and non_zero_elt_log_idx[0] == 3:
+                    if non_zero_cnt[0] == 1 and non_zero_elt_log_idx[0] == 3:
                         non_zero_elt_log_idx[0] = 0
                         non_zero_elt_log_idx[1] = 3
                         A_sp_shared[tm, a_k // 2 + 1] = A_sp_shared[tm, a_k // 2]
                         A_sp_shared[tm, a_k // 2] = 0.0
-                    elif non_zero_cnt == 1:
+                    elif non_zero_cnt[0] == 1:
                         A_sp_shared[tm, a_k // 2 + 1] = 0
                         non_zero_elt_log_idx[1] = 3
                     for i in T.serial(elem):
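
For checking the kernel against a host-side oracle, the per-group compaction in this hunk can be modeled in plain Python. This is a minimal sketch, not the example's own reference code: `compress_group` is a hypothetical name, it assumes `group = 4` and `elem = 2` as the indexing in the hunk (`a_k // 2`, index 3) suggests, and it stops before the metadata packing loop, whose body is not shown in the diff.

```python
def compress_group(group_vals, elem=2):
    """Model one group of the compaction loop; returns (kept_values, kept_indices).

    Assumption: at most `elem` non-zeros per group, which the kernel's TODO
    says it would like to assert on device.
    """
    non_zero_cnt = 0
    idx = [0] * elem      # plays the role of non_zero_elt_log_idx
    vals = [0.0] * elem   # plays the role of the A_sp_shared slots for this group

    for i, val in enumerate(group_vals):
        if val != 0.0:
            if non_zero_cnt >= elem:
                raise ValueError("group has more than `elem` non-zeros")
            idx[non_zero_cnt] = i
            vals[non_zero_cnt] = val
            non_zero_cnt += 1

    if non_zero_cnt == 1 and idx[0] == 3:
        # Lone non-zero in the last position: move it to slot 1, pad slot 0.
        idx[0], idx[1] = 0, 3
        vals[1] = vals[0]
        vals[0] = 0.0
    elif non_zero_cnt == 1:
        # Lone non-zero elsewhere: pad slot 1 with a zero at index 3.
        vals[1] = 0.0
        idx[1] = 3
    return vals, idx
```

Both single-non-zero branches leave the two stored indices in increasing order, which is presumably the reason for the special case when the only non-zero sits at position 3.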
2 changes: 1 addition & 1 deletion src/op/fill.cc
@@ -174,7 +174,7 @@ Stmt FillNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
                           vectorized_thread_loop);
     }
     return vectorized_thread_loop;
-  } else if (IsLocalBuffer(dst)) {
+  } else if (IsLocalBuffer(dst) || IsLocalVarBuffer(dst)) {
     auto init_loop = MakeSIMTLoop(analyzer);
     auto vectorized_thread_loop = VectorizeLoop(init_loop, analyzer);
     return vectorized_thread_loop;
Comment on lines +177 to 180
⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Search for IsLocalBuffer usage throughout the codebase
rg -nP '\bIsLocalBuffer\s*\(' --type=cpp -C3

Repository: tile-ai/tilelang

Length of output: 3716


🏁 Script executed:

#!/bin/bash
# Search for IsLocalVarBuffer definition and usage
rg -nP '\bIsLocalVarBuffer\s*\(' --type=cpp -C3

Repository: tile-ai/tilelang

Length of output: 796


🏁 Script executed:

#!/bin/bash
# Look at copy.cc around the IsLocalBuffer checks to understand the context
sed -n '770,790p' src/op/copy.cc

Repository: tile-ai/tilelang

Length of output: 944


Update copy operations and layout inference to handle local.var buffers consistently.

The fill.cc change correctly extends local buffer handling to include "local.var" buffers by checking IsLocalBuffer(dst) || IsLocalVarBuffer(dst). However, other operations in the codebase do not follow this same pattern:

  • src/op/copy.cc (lines 779-780) checks only IsLocalBuffer() for both source and destination buffers without considering IsLocalVarBuffer()
  • src/transform/layout_inference.cc has multiple IsLocalBuffer() checks (lines 1268, 1288, 1292, 1313, 1318) that similarly lack IsLocalVarBuffer() checks

These operations should be updated to treat "local.var" buffers the same as "local" buffers where appropriate, consistent with the fill.cc changes.

🤖 Prompt for AI Agents
In src/op/fill.cc lines 177-180 you expanded handling to treat local.var like
local buffers; mirror that change in src/op/copy.cc (around lines 779-780) and
src/transform/layout_inference.cc (around lines 1268, 1288, 1292, 1313, 1318) by
updating all IsLocalBuffer(...) checks to also accept IsLocalVarBuffer(...).
Replace patterns that only call IsLocalBuffer(x) with a combined predicate
(IsLocalBuffer(x) || IsLocalVarBuffer(x)) for both source and destination
checks, ensuring logic and short-circuit behavior remain the same and update any
related comments/tests if present.
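
The audit proposed above could be approximated with a small line-based scan. The sketch below is a heuristic, and `find_unpaired_local_checks` is a hypothetical helper, not part of tilelang: it flags lines that call `IsLocalBuffer(` without `IsLocalVarBuffer(` on the same line, so conditions split across several lines, and the predicate's own definition in `utils.h`, will show up as false positives to be reviewed by hand.

```python
import re

# \b keeps the pattern from matching inside longer identifiers.
_LOCAL = re.compile(r"\bIsLocalBuffer\s*\(")
_LOCAL_VAR = re.compile(r"\bIsLocalVarBuffer\s*\(")


def find_unpaired_local_checks(source):
    """Return (line_number, stripped_line) for each line that tests
    IsLocalBuffer(...) without also testing IsLocalVarBuffer(...)."""
    hits = []
    for lineno, line in enumerate(source.splitlines(), start=1):
        if _LOCAL.search(line) and not _LOCAL_VAR.search(line):
            hits.append((lineno, line.strip()))
    return hits


if __name__ == "__main__":
    sample = (
        "} else if (IsLocalBuffer(dst)) {\n"
        "} else if (IsLocalBuffer(dst) || IsLocalVarBuffer(dst)) {\n"
    )
    for lineno, text in find_unpaired_local_checks(sample):
        print(f"{lineno}: {text}")
```

Each hit still needs a manual decision, since the comment itself only asks for the combined check "where appropriate".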

4 changes: 4 additions & 0 deletions src/op/utils.h
@@ -51,6 +51,10 @@ inline bool IsLocalBuffer(const Buffer &buffer) {
   return buffer.defined() && buffer.scope() == "local";
 }
 
+inline bool IsLocalVarBuffer(const Buffer &buffer) {
+  return buffer.defined() && buffer.scope() == "local.var";
+}
+
 } // namespace tl
 } // namespace tvm
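
Both predicates compare the scope string exactly, so a `"local.var"` buffer never satisfies `IsLocalBuffer`. The Python model below illustrates why `FillNode::Lower` needed the extra check. The `Buffer` dataclass and the `takes_simt_fill_path` helper are stand-ins invented for illustration, not tilelang or TVM APIs.

```python
from dataclasses import dataclass


@dataclass
class Buffer:
    """Hypothetical stand-in for the C++ Buffer handle."""
    scope: str
    defined: bool = True


def is_local_buffer(buf):
    # Exact match: "local.var" is a different scope string.
    return buf.defined and buf.scope == "local"


def is_local_var_buffer(buf):
    return buf.defined and buf.scope == "local.var"


def takes_simt_fill_path(buf):
    # Models the updated branch condition in the fill.cc hunk.
    return is_local_buffer(buf) or is_local_var_buffer(buf)


if __name__ == "__main__":
    for scope in ("local", "local.var", "shared"):
        buf = Buffer(scope)
        print(scope, is_local_buffer(buf), takes_simt_fill_path(buf))
```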
