diff --git a/src/layout/layout.cc b/src/layout/layout.cc index de428fc59..892f13770 100644 --- a/src/layout/layout.cc +++ b/src/layout/layout.cc @@ -250,9 +250,9 @@ std::pair LayoutNode::InverseWithLevel() const { if (!is_static_shape) { // Runtime guards keep dynamic tails safe, so we allow NoCheck here and // warn. - LOG(WARNING) << "Layout::Inverse on symbolic layout, falling back to " - "NoCheck; symbolic dims: " - << symbolic_dims; + DLOG(WARNING) << "Layout::Inverse on symbolic layout, falling back to " + "NoCheck; symbolic dims: " + << symbolic_dims; } arith::IterMapResult res = arith::DetectIterMap(forward_index_, getVarMap(), 1, level, &analyzer); diff --git a/src/op/parallel.cc b/src/op/parallel.cc index 95817d179..81777aa53 100644 --- a/src/op/parallel.cc +++ b/src/op/parallel.cc @@ -649,37 +649,8 @@ Fragment ParallelOpNode::CompleteBufferFragment(const Buffer &buffer) const { DivideUnusedIterators(indice_map_[buffer], loop_vars_, &analyzer_)); auto bijective_indice = indice_map_[buffer]; bijective_indice.push_back(rep_b); - Layout layout_before_inv = Layout(loop_vars_, bijective_indice); - - // Pre-check cardinality to guard non-bijective combinations after adding - // rep_b. - PrimExpr in_prod = 1; - for (const auto &iv : loop_vars_) - in_prod *= iv->dom->extent; - PrimExpr out_prod = 1; - for (const auto &d : layout_before_inv->OutputShape()) - out_prod *= d; - - if (!analyzer_.CanProveEqual(in_prod, out_prod)) { - DLOG(WARNING) << " Non-bijective mapping after appending rep_b; falling " - "back to no-rep inversion."; - Layout ind_inv_fallback = - Layout(loop_vars_, indice_map_[buffer])->Inverse(); - PrimExpr indice_rep_extent = 1; - PrimExpr loop_rep_extent = loop_layout_->ReplicateExtent(); - PrimExpr dest_buffer_rep_extent = indice_rep_extent * loop_rep_extent; - Array fwd2; - for (size_t i = 0; i < buffer->shape.size(); i++) { - fwd2.push_back(InputPlaceholder(i)); - } - PrimExpr thd_b = loop_layout_->ForwardThread( - ind_inv_fallback->Forward(fwd2), std::nullopt); - return Fragment(buffer->shape, {}, thd_b, dest_buffer_rep_extent, - std::nullopt) - ->CondenseReplicateVar(); - } + Layout ind_inv = Layout(loop_vars_, bijective_indice)->Inverse(); - Layout ind_inv = layout_before_inv->Inverse(); PrimExpr indice_rep_extent = ind_inv->InputShape().back(); // this is the size of rep_b PrimExpr loop_rep_extent = loop_layout_->ReplicateExtent();