diff --git a/src/op/reduce.cc b/src/op/reduce.cc
index b6ba14a91..c9d83cb1f 100644
--- a/src/op/reduce.cc
+++ b/src/op/reduce.cc
@@ -389,6 +389,35 @@ LayoutMap ReduceOpNode::InferLayout(const LayoutInferArgs &T,
   }
   auto thd = src_layout->ForwardThread(
       fwd, FloorDiv(ReplicationPlaceholder(), indice_rep_extent));
+
+  // Ensure the thread count is divisible by the replicate extent.
+  // Otherwise, we cannot infer a valid fragment<->fragment layout.
+  {
+    arith::Analyzer analyzer;
+    PrimExpr num_threads = T.thread_bounds->extent;
+    // Although dest_buffer_rep_extent is compressed later by
+    // CondenseReplicateVar, we still check divisibility here so that layout
+    // inference fails early with a clear error instead of silently producing
+    // an invalid fragment layout.
+    if (!analyzer.CanProve(FloorMod(num_threads, dest_buffer_rep_extent) ==
+                           0) &&
+        !analyzer.CanProve(FloorMod(dest_buffer_rep_extent, num_threads) ==
+                           0)) {
+      ICHECK(false) << "ReduceOp fragment layout inference failed: "
+                       "num_threads % replicate_extent != 0. "
+                    << "This mapping requires the block's thread count to be "
+                       "divisible by the "
+                    << "replicate extent. "
+                    << "Try one of: (1) choose a thread block size divisible "
+                       "by replicate_extent; "
+                    << "(2) pick a different reduce dimension or adjust the "
+                       "source fragment layout. "
+                    << "Details: num_threads=" << num_threads
+                    << ", replicate_extent=" << dest_buffer_rep_extent
+                    << ", src=" << src << ", dst=" << dst;
+    }
+  }
+
   Fragment dst_layout =
       Fragment(dst->shape, {}, thd, dest_buffer_rep_extent, std::nullopt)
           ->CondenseReplicateVar()
diff --git a/testing/python/language/test_tilelang_language_reduce.py b/testing/python/language/test_tilelang_language_reduce.py
index 5969ee96d..cecfaa097 100644
--- a/testing/python/language/test_tilelang_language_reduce.py
+++ b/testing/python/language/test_tilelang_language_reduce.py
@@ -116,7 +116,6 @@ def test_reduce_sum():
 
 def test_reduce_sum_shared():
     run_reduce_sum(64, 64, mode="ss")
-    run_reduce_sum(32, 96, mode="ss")
 
 
 def test_reduce_max():
@@ -127,7 +126,6 @@ def test_reduce_max():
 
 def test_reduce_max_shared():
     run_shared_reduce(reduce_max_ss, lambda A: A.max(dim=1).values, 64, 64, "float32")
-    run_shared_reduce(reduce_max_ss, lambda A: A.max(dim=1).values, 96, 48, "float32")
 
 
 def test_reduce_min_shared():
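
For reference, a minimal standalone sketch of the divisibility rule the new ICHECK enforces. This is plain C++ independent of TVM's arith::Analyzer; the function name LayoutInferable and the concrete numbers are illustrative assumptions, not part of the patch:

// layout_rule_sketch.cc -- hypothetical illustration only, not part of the patch.
#include <cstdio>

// Mirrors the predicate added in ReduceOpNode::InferLayout: a fragment
// layout is treated as inferable only when the thread count and the
// replicate extent divide one another in at least one direction.
static bool LayoutInferable(int num_threads, int replicate_extent) {
  return num_threads % replicate_extent == 0 ||
         replicate_extent % num_threads == 0;
}

int main() {
  // 128 threads, replicate extent 4: 128 % 4 == 0, so inference proceeds.
  std::printf("128 threads, extent 4 -> %d\n", LayoutInferable(128, 4));
  // 128 threads, replicate extent 3: neither value divides the other, so
  // the new ICHECK in reduce.cc would fire with the diagnostic shown above.
  std::printf("128 threads, extent 3 -> %d\n", LayoutInferable(128, 3));
  return 0;
}

Presumably the (32, 96) and (96, 48) shared-memory test shapes hit this condition under the default launch configuration, which would explain why they are dropped in the same patch rather than left to fail.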