Merged
28 changes: 20 additions & 8 deletions src/layout/utils.cc
@@ -115,6 +115,10 @@ Array<IterSplitExpr> get_unused_iters(const IterMark &mark,
   return results;
 }
 
+// Heuristic: detect per-iterator gaps ("unused" pieces) even when the iterator
+// appears in fused forms across multiple index expressions. We first normalize
+// every index into IterSumExpr, collect all splits per source Var, then
+// consolidate them to avoid misclassifying a used split as unused.
 Array<IterSplitExpr> DivideUnusedIterators(const Array<PrimExpr> &exprs,
                                            const Array<IterVar> input_iters,
                                            Analyzer *analyzer) {
@@ -134,17 +138,25 @@ Array<IterSplitExpr> DivideUnusedIterators(const Array<PrimExpr> &exprs,
   }
 
   for (const IterVar &iter : input_iters) {
-    IterMark iv_mark;
+    // Merge splits from all IterMark that share the same source Var as `iter`.
+    std::vector<IterSplitExpr> merged_splits;
     for (const IterMark &mark : collector.visited_) {
-      if (mark->source.as<Var>()->same_as(iter->var)) { // NOLINT(*)
-        iv_mark = mark;
-        break;
+      auto vexpr = mark->source.as<Var>();
+      if (vexpr && vexpr.value().same_as(iter->var)) {
+        auto it = collector.mark2splits_.find(mark);
+        if (it != collector.mark2splits_.end()) {
+          const auto &vec = it->second;
+          merged_splits.insert(merged_splits.end(), vec.begin(), vec.end());
+        }
       }
     }
-    if (iv_mark.defined()) {
-      auto splits =
-          get_unused_iters(iv_mark, collector.mark2splits_[iv_mark], analyzer);
-      // Put the small axis last
+
+    if (!merged_splits.empty()) {
+      // Use a unified mark (Var + full extent) to compute the missing pieces
+      // so that fused usages are honored as "used" and not reintroduced.
+      IterMark unified_mark(iter->var, iter->dom->extent);
+      auto splits = get_unused_iters(unified_mark, merged_splits, analyzer);
+      // Put the small axis last for a flattened ordering.
       results.insert(results.end(), splits.rbegin(), splits.rend());
     } else if (!is_one(iter->dom->extent)) {
       auto mark = IterMark(iter->var, iter->dom->extent);
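The consolidation step above can be pictured with a small standalone sketch. This is plain Python rather than the TVM iterator types, and both the helper name unused_pieces and the (lower_factor, extent) pair encoding are made up for illustration: it only shows why merging every split that shares a source Var before computing gaps keeps a fused usage from being misreported as unused.

# Hypothetical, simplified model of "merge splits per Var, then find gaps".
# A split of iterator i (full extent `extent`) is a (lower_factor, ext) pair
# covering the portion of i's range between lower_factor and lower_factor * ext.
def unused_pieces(extent, splits):
    pieces, covered = [], 1
    for lower, ext in sorted(splits):
        if lower > covered:  # the gap below this split has no user
            pieces.append((covered, lower // covered))
        covered = max(covered, lower * ext)
    if covered < extent:  # leftover high-order piece
        pieces.append((covered, extent // covered))
    return pieces

# i of extent 64 used as both (i // 32) and (i % 32) across two indices:
# merging the splits covers everything, so nothing is reported as unused.
assert unused_pieces(64, [(1, 32), (32, 2)]) == []
# Only (i // 32) is used anywhere: the low 32-wide piece really is unused.
assert unused_pieces(64, [(32, 2)]) == [(1, 32)]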
57 changes: 56 additions & 1 deletion src/op/parallel.cc
@@ -620,11 +620,66 @@ Fragment ParallelOpNode::CompleteBufferFragment(const Buffer &buffer) const {
   if (IsCommonAccessIndice(buffer)) {
     return loop_layout_;
   }
+  // Prefer a simple path: if original 2D indices form a bijective map, invert
+  // them directly and avoid introducing a synthetic replicate dimension.
+  {
+    auto res2d =
+        arith::DetectIterMap(indice_map_[buffer], ToVMap(loop_vars_), 1,
+                             arith::IterMapLevel::Bijective,
+                             const_cast<arith::Analyzer *>(&analyzer_));
+    if (res2d->errors.empty()) {
+      Layout ind_inv2d = Layout(loop_vars_, indice_map_[buffer])->Inverse();
+      PrimExpr indice_rep_extent = 1;
+      PrimExpr loop_rep_extent = loop_layout_->ReplicateExtent();
+      PrimExpr dest_buffer_rep_extent = indice_rep_extent * loop_rep_extent;
+      Array<PrimExpr> fwd2;
+      for (size_t i = 0; i < buffer->shape.size(); i++) {
+        fwd2.push_back(InputPlaceholder(i));
+      }
+      PrimExpr thd_b2 =
+          loop_layout_->ForwardThread(ind_inv2d->Forward(fwd2), std::nullopt);
+      return Fragment(buffer->shape, {}, thd_b2, dest_buffer_rep_extent,
+                      std::nullopt)
+          ->CondenseReplicateVar();
+    }
+  }
+  // Otherwise, infer an extra flattened iterator that captures truly-unused
+  // pieces of the loop space (if any), then try inversion with it.
   PrimExpr rep_b = MakeFlattenedExpression(
       DivideUnusedIterators(indice_map_[buffer], loop_vars_, &analyzer_));
   auto bijective_indice = indice_map_[buffer];
   bijective_indice.push_back(rep_b);
-  Layout ind_inv = Layout(loop_vars_, bijective_indice)->Inverse();
+  Layout layout_before_inv = Layout(loop_vars_, bijective_indice);
+
+  // Pre-check cardinality to guard non-bijective combinations after adding
+  // rep_b.
+  PrimExpr in_prod = 1;
+  for (const auto &iv : loop_vars_)
+    in_prod *= iv->dom->extent;
+  PrimExpr out_prod = 1;
+  for (const auto &d : layout_before_inv->OutputShape())
+    out_prod *= d;
+
+  if (!analyzer_.CanProveEqual(in_prod, out_prod)) {
+    DLOG(WARNING) << " Non-bijective mapping after appending rep_b; falling "
+                     "back to no-rep inversion.";
+    Layout ind_inv_fallback =
+        Layout(loop_vars_, indice_map_[buffer])->Inverse();
+    PrimExpr indice_rep_extent = 1;
+    PrimExpr loop_rep_extent = loop_layout_->ReplicateExtent();
+    PrimExpr dest_buffer_rep_extent = indice_rep_extent * loop_rep_extent;
+    Array<PrimExpr> fwd2;
+    for (size_t i = 0; i < buffer->shape.size(); i++) {
+      fwd2.push_back(InputPlaceholder(i));
+    }
+    PrimExpr thd_b = loop_layout_->ForwardThread(
+        ind_inv_fallback->Forward(fwd2), std::nullopt);
+    return Fragment(buffer->shape, {}, thd_b, dest_buffer_rep_extent,
+                    std::nullopt)
+        ->CondenseReplicateVar();
+  }
+
+  Layout ind_inv = layout_before_inv->Inverse();
   PrimExpr indice_rep_extent =
       ind_inv->InputShape().back(); // this is the size of rep_b
   PrimExpr loop_rep_extent = loop_layout_->ReplicateExtent();
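The cardinality pre-check in the hunk above has a simple reading: the combined map can only be bijective, and therefore invertible, if the loop domain and the layout's output shape enumerate the same number of points. A rough sketch of that guard in plain Python, with made-up integer extents standing in for the real PrimExpr products:

from math import prod

def can_attempt_inverse(loop_extents, output_shape):
    # Mirrors the in_prod / out_prod comparison over plain integers.
    return prod(loop_extents) == prod(output_shape)

# A 32x64 loop space mapped onto a (64, 32) buffer plus a rep_b of extent 1:
# point counts match, so inversion with rep_b appended is worth attempting.
assert can_attempt_inverse([32, 64], [64, 32, 1])
# If appending rep_b inflated the output space, take the no-rep fallback path.
assert not can_attempt_inverse([32, 64], [64, 32, 2])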
63 changes: 63 additions & 0 deletions testing/python/layout/test_tilelang_layout_fused_replicate.py
@@ -0,0 +1,63 @@
import pytest
import torch

import tilelang
import tilelang.testing
import tilelang.language as T

tilelang.testing.set_random_seed()

VEC_SIZE = 32


@tilelang.jit
def fused_index_kernel(B: int, M: int, N: int, BLOCK_MN: int, BLOCK_K: int):

@T.prim_func
def main(
a: T.Buffer((B, M, N), "bfloat16"),
a_out: T.Buffer((B, M, N), "float32"),
):
with T.Kernel(
T.ceildiv(M, BLOCK_MN),
T.ceildiv(N, BLOCK_K),
B,
threads=128,
) as (pid_m, pid_n, pid_b):
a_fp32_local = T.alloc_fragment((BLOCK_MN * BLOCK_K // VEC_SIZE, VEC_SIZE), "float32")
offs_m = pid_m * BLOCK_MN
offs_n = pid_n * BLOCK_K

for i, j in T.Parallel(BLOCK_MN, BLOCK_K):
idx = i * BLOCK_K + j
a_out[pid_b, offs_m + i, offs_n + j] = a_fp32_local[idx // VEC_SIZE, idx % VEC_SIZE]

Comment on lines +31 to +34
⚠️ Potential issue | 🔴 Critical

Initialize the fragment before reading from it

a_out is filled from a_fp32_local, but that fragment is never written—a is ignored entirely. This leaves the store sourcing undefined data from uninitialized memory. Please load (or otherwise initialize) the fragment before using it and consume the input tensor.

Apply this diff to populate the fragment from a:

         for i, j in T.Parallel(BLOCK_MN, BLOCK_K):
             idx = i * BLOCK_K + j
-            a_out[pid_b, offs_m + i, offs_n + j] = a_fp32_local[idx // VEC_SIZE, idx % VEC_SIZE]
+            a_fp32_local[idx // VEC_SIZE, idx % VEC_SIZE] = T.cast(
+                a[pid_b, offs_m + i, offs_n + j], "float32"
+            )
+            a_out[pid_b, offs_m + i, offs_n + j] = a_fp32_local[idx // VEC_SIZE, idx % VEC_SIZE]

return main


def _require_cuda_tensor(shape, dtype):
if not torch.cuda.is_available():
pytest.skip("CUDA not available")
try:
return torch.randn(*shape, device="cuda", dtype=dtype)
except RuntimeError as err:
pytest.skip(f"CUDA runtime unavailable: {err}")


def test_layout_infer_compiles_and_runs():
B, M, N = 1, 32, 64
BLOCK_MN, BLOCK_K = 32, 64
kernel = fused_index_kernel(B, M, N, BLOCK_MN, BLOCK_K)

a = _require_cuda_tensor((B, M, N), torch.bfloat16)
a_out = torch.empty((B, M, N), dtype=torch.float32, device=a.device)

# Ensure kernel compiles and executes without layout inversion failure
kernel(a, a_out)

assert a_out.shape == a.shape
assert a_out.dtype == torch.float32


if __name__ == "__main__":
tilelang.testing.main()
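As a sanity check on what this test exercises: the fused access idx = i * BLOCK_K + j together with the (idx // VEC_SIZE, idx % VEC_SIZE) fragment indexing is a bijection from the (BLOCK_MN, BLOCK_K) loop space onto the fragment shape, which is exactly the case the layout-inference change must invert without inventing a replicate dimension. A small standalone check in plain Python, independent of TileLang:

BLOCK_MN, BLOCK_K, VEC_SIZE = 32, 64, 32

seen = set()
for i in range(BLOCK_MN):
    for j in range(BLOCK_K):
        idx = i * BLOCK_K + j
        seen.add((idx // VEC_SIZE, idx % VEC_SIZE))

# Every (row, col) of the (BLOCK_MN * BLOCK_K // VEC_SIZE, VEC_SIZE) fragment
# is hit exactly once, i.e. the fused index map is bijective.
assert len(seen) == BLOCK_MN * BLOCK_K
assert seen == {(r, c)
                for r in range(BLOCK_MN * BLOCK_K // VEC_SIZE)
                for c in range(VEC_SIZE)}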