diff --git a/src/layout/utils.cc b/src/layout/utils.cc
index 4f533c442..a2a788b24 100644
--- a/src/layout/utils.cc
+++ b/src/layout/utils.cc
@@ -115,6 +115,10 @@ Array<IterSplitExpr> get_unused_iters(const IterMark &mark,
   return results;
 }
 
+// Heuristic: detect per-iterator gaps ("unused" pieces) even when the iterator
+// appears in fused forms across multiple index expressions. We first normalize
+// every index into IterSumExpr, collect all splits per source Var, then
+// consolidate them to avoid misclassifying a used split as unused.
 Array<IterSplitExpr> DivideUnusedIterators(const Array<PrimExpr> &exprs,
                                            const Array<IterVar> input_iters,
                                            Analyzer *analyzer) {
@@ -134,17 +138,25 @@
   }
 
   for (const IterVar &iter : input_iters) {
-    IterMark iv_mark;
+    // Merge splits from all IterMark that share the same source Var as `iter`.
+    std::vector<IterSplitExpr> merged_splits;
     for (const IterMark &mark : collector.visited_) {
-      if (mark->source.as<Var>()->same_as(iter->var)) { // NOLINT(*)
-        iv_mark = mark;
-        break;
+      auto vexpr = mark->source.as<Var>();
+      if (vexpr && vexpr.value().same_as(iter->var)) {
+        auto it = collector.mark2splits_.find(mark);
+        if (it != collector.mark2splits_.end()) {
+          const auto &vec = it->second;
+          merged_splits.insert(merged_splits.end(), vec.begin(), vec.end());
+        }
       }
     }
-    if (iv_mark.defined()) {
-      auto splits =
-          get_unused_iters(iv_mark, collector.mark2splits_[iv_mark], analyzer);
-      // Put the small axis last
+
+    if (!merged_splits.empty()) {
+      // Use a unified mark (Var + full extent) to compute the missing pieces
+      // so that fused usages are honored as "used" and not reintroduced.
+      IterMark unified_mark(iter->var, iter->dom->extent);
+      auto splits = get_unused_iters(unified_mark, merged_splits, analyzer);
+      // Put the small axis last for a flattened ordering.
       results.insert(results.end(), splits.rbegin(), splits.rend());
     } else if (!is_one(iter->dom->extent)) {
       auto mark = IterMark(iter->var, iter->dom->extent);
diff --git a/src/op/parallel.cc b/src/op/parallel.cc
index 118a9e74b..95817d179 100644
--- a/src/op/parallel.cc
+++ b/src/op/parallel.cc
@@ -620,11 +620,66 @@ Fragment ParallelOpNode::CompleteBufferFragment(const Buffer &buffer) const {
   if (IsCommonAccessIndice(buffer)) {
     return loop_layout_;
   }
+  // Prefer a simple path: if original 2D indices form a bijective map, invert
+  // them directly and avoid introducing a synthetic replicate dimension.
+  {
+    auto res2d =
+        arith::DetectIterMap(indice_map_[buffer], ToVMap(loop_vars_), 1,
+                             arith::IterMapLevel::Bijective,
+                             const_cast<arith::Analyzer *>(&analyzer_));
+    if (res2d->errors.empty()) {
+      Layout ind_inv2d = Layout(loop_vars_, indice_map_[buffer])->Inverse();
+      PrimExpr indice_rep_extent = 1;
+      PrimExpr loop_rep_extent = loop_layout_->ReplicateExtent();
+      PrimExpr dest_buffer_rep_extent = indice_rep_extent * loop_rep_extent;
+      Array<PrimExpr> fwd2;
+      for (size_t i = 0; i < buffer->shape.size(); i++) {
+        fwd2.push_back(InputPlaceholder(i));
+      }
+      PrimExpr thd_b2 =
+          loop_layout_->ForwardThread(ind_inv2d->Forward(fwd2), std::nullopt);
+      return Fragment(buffer->shape, {}, thd_b2, dest_buffer_rep_extent,
+                      std::nullopt)
+          ->CondenseReplicateVar();
+    }
+  }
+  // Otherwise, infer an extra flattened iterator that captures truly-unused
+  // pieces of the loop space (if any), then try inversion with it.
   PrimExpr rep_b = MakeFlattenedExpression(
       DivideUnusedIterators(indice_map_[buffer], loop_vars_, &analyzer_));
   auto bijective_indice = indice_map_[buffer];
   bijective_indice.push_back(rep_b);
-  Layout ind_inv = Layout(loop_vars_, bijective_indice)->Inverse();
+  Layout layout_before_inv = Layout(loop_vars_, bijective_indice);
+
+  // Pre-check cardinality to guard non-bijective combinations after adding
+  // rep_b.
+  PrimExpr in_prod = 1;
+  for (const auto &iv : loop_vars_)
+    in_prod *= iv->dom->extent;
+  PrimExpr out_prod = 1;
+  for (const auto &d : layout_before_inv->OutputShape())
+    out_prod *= d;
+
+  if (!analyzer_.CanProveEqual(in_prod, out_prod)) {
+    DLOG(WARNING) << " Non-bijective mapping after appending rep_b; falling "
+                     "back to no-rep inversion.";
+    Layout ind_inv_fallback =
+        Layout(loop_vars_, indice_map_[buffer])->Inverse();
+    PrimExpr indice_rep_extent = 1;
+    PrimExpr loop_rep_extent = loop_layout_->ReplicateExtent();
+    PrimExpr dest_buffer_rep_extent = indice_rep_extent * loop_rep_extent;
+    Array<PrimExpr> fwd2;
+    for (size_t i = 0; i < buffer->shape.size(); i++) {
+      fwd2.push_back(InputPlaceholder(i));
+    }
+    PrimExpr thd_b = loop_layout_->ForwardThread(
+        ind_inv_fallback->Forward(fwd2), std::nullopt);
+    return Fragment(buffer->shape, {}, thd_b, dest_buffer_rep_extent,
+                    std::nullopt)
+        ->CondenseReplicateVar();
+  }
+
+  Layout ind_inv = layout_before_inv->Inverse();
   PrimExpr indice_rep_extent =
       ind_inv->InputShape().back(); // this is the size of rep_b
   PrimExpr loop_rep_extent = loop_layout_->ReplicateExtent();
diff --git a/testing/python/layout/test_tilelang_layout_fused_replicate.py b/testing/python/layout/test_tilelang_layout_fused_replicate.py
new file mode 100644
index 000000000..d67a87bc3
--- /dev/null
+++ b/testing/python/layout/test_tilelang_layout_fused_replicate.py
@@ -0,0 +1,63 @@
+import pytest
+import torch
+
+import tilelang
+import tilelang.testing
+import tilelang.language as T
+
+tilelang.testing.set_random_seed()
+
+VEC_SIZE = 32
+
+
+@tilelang.jit
+def fused_index_kernel(B: int, M: int, N: int, BLOCK_MN: int, BLOCK_K: int):
+
+    @T.prim_func
+    def main(
+            a: T.Buffer((B, M, N), "bfloat16"),
+            a_out: T.Buffer((B, M, N), "float32"),
+    ):
+        with T.Kernel(
+                T.ceildiv(M, BLOCK_MN),
+                T.ceildiv(N, BLOCK_K),
+                B,
+                threads=128,
+        ) as (pid_m, pid_n, pid_b):
+            a_fp32_local = T.alloc_fragment((BLOCK_MN * BLOCK_K // VEC_SIZE, VEC_SIZE), "float32")
+            offs_m = pid_m * BLOCK_MN
+            offs_n = pid_n * BLOCK_K
+
+            for i, j in T.Parallel(BLOCK_MN, BLOCK_K):
+                idx = i * BLOCK_K + j
+                a_out[pid_b, offs_m + i, offs_n + j] = a_fp32_local[idx // VEC_SIZE, idx % VEC_SIZE]
+
+    return main
+
+
+def _require_cuda_tensor(shape, dtype):
+    if not torch.cuda.is_available():
+        pytest.skip("CUDA not available")
+    try:
+        return torch.randn(*shape, device="cuda", dtype=dtype)
+    except RuntimeError as err:
+        pytest.skip(f"CUDA runtime unavailable: {err}")
+
+
+def test_layout_infer_compiles_and_runs():
+    B, M, N = 1, 32, 64
+    BLOCK_MN, BLOCK_K = 32, 64
+    kernel = fused_index_kernel(B, M, N, BLOCK_MN, BLOCK_K)
+
+    a = _require_cuda_tensor((B, M, N), torch.bfloat16)
+    a_out = torch.empty((B, M, N), dtype=torch.float32, device=a.device)
+
+    # Ensure kernel compiles and executes without layout inversion failure
+    kernel(a, a_out)
+
+    assert a_out.shape == a.shape
+    assert a_out.dtype == torch.float32
+
+
+if __name__ == "__main__":
+    tilelang.testing.main()
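
Note (not part of the patch): the bijectivity reasoning behind the new fast path and the cardinality pre-check can be sanity-checked with plain arithmetic. The sketch below is illustrative only; it reuses the tile sizes from the included test (BLOCK_MN, BLOCK_K, VEC_SIZE) and verifies that the fused index map (idx // VEC_SIZE, idx % VEC_SIZE) covers an output space whose cardinality equals that of the loop space, which is what the C++ side approximates with analyzer_.CanProveEqual(in_prod, out_prod).

# Illustrative sketch only; mirrors the cardinality pre-check in parallel.cc
# using the tile sizes from the test above. No tilelang APIs are involved.
BLOCK_MN, BLOCK_K, VEC_SIZE = 32, 64, 32

loop_space = [(i, j) for i in range(BLOCK_MN) for j in range(BLOCK_K)]
# Fused index produced by the kernel: idx = i * BLOCK_K + j
image = {(idx // VEC_SIZE, idx % VEC_SIZE) for idx in (i * BLOCK_K + j for i, j in loop_space)}

in_prod = BLOCK_MN * BLOCK_K                            # product of loop extents
out_prod = (BLOCK_MN * BLOCK_K // VEC_SIZE) * VEC_SIZE  # product of the output shape

# Equal cardinalities with no collisions means the 2D map is bijective, so the
# simple inversion path applies and no synthetic replicate dimension is needed.
assert in_prod == out_prod == len(image) == len(loop_space)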