Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/layout/layout.cc
Original file line number Diff line number Diff line change
Expand Up @@ -250,9 +250,9 @@ std::pair<Layout, arith::IterMapLevel> LayoutNode::InverseWithLevel() const {
if (!is_static_shape) {
// Runtime guards keep dynamic tails safe, so we allow NoCheck here and
// warn.
LOG(WARNING) << "Layout::Inverse on symbolic layout, falling back to "
"NoCheck; symbolic dims: "
<< symbolic_dims;
DLOG(WARNING) << "Layout::Inverse on symbolic layout, falling back to "
"NoCheck; symbolic dims: "
<< symbolic_dims;
}
arith::IterMapResult res =
arith::DetectIterMap(forward_index_, getVarMap(), 1, level, &analyzer);
Expand Down
31 changes: 1 addition & 30 deletions src/op/parallel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -649,37 +649,8 @@ Fragment ParallelOpNode::CompleteBufferFragment(const Buffer &buffer) const {
DivideUnusedIterators(indice_map_[buffer], loop_vars_, &analyzer_));
auto bijective_indice = indice_map_[buffer];
bijective_indice.push_back(rep_b);
Layout layout_before_inv = Layout(loop_vars_, bijective_indice);

// Pre-check cardinality to guard non-bijective combinations after adding
// rep_b.
PrimExpr in_prod = 1;
for (const auto &iv : loop_vars_)
in_prod *= iv->dom->extent;
PrimExpr out_prod = 1;
for (const auto &d : layout_before_inv->OutputShape())
out_prod *= d;

if (!analyzer_.CanProveEqual(in_prod, out_prod)) {
DLOG(WARNING) << " Non-bijective mapping after appending rep_b; falling "
"back to no-rep inversion.";
Layout ind_inv_fallback =
Layout(loop_vars_, indice_map_[buffer])->Inverse();
PrimExpr indice_rep_extent = 1;
PrimExpr loop_rep_extent = loop_layout_->ReplicateExtent();
PrimExpr dest_buffer_rep_extent = indice_rep_extent * loop_rep_extent;
Array<PrimExpr> fwd2;
for (size_t i = 0; i < buffer->shape.size(); i++) {
fwd2.push_back(InputPlaceholder(i));
}
PrimExpr thd_b = loop_layout_->ForwardThread(
ind_inv_fallback->Forward(fwd2), std::nullopt);
return Fragment(buffer->shape, {}, thd_b, dest_buffer_rep_extent,
std::nullopt)
->CondenseReplicateVar();
}
Layout ind_inv = Layout(loop_vars_, bijective_indice)->Inverse();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

Verify that bijective mapping is guaranteed before unconditional inversion.

The removal of the pre-check for bijective mapping and the fallback logic means that Layout::Inverse() is now called unconditionally on line 652. If bijective_indice (created by appending rep_b from DivideUnusedIterators) is not actually bijective, this could lead to runtime failures or incorrect layout computations.

Potential risks:

  • Layout::Inverse() may fail or produce incorrect results for non-bijective mappings
  • Error messages may be less clear without the pre-check
  • Silent correctness issues if the inverse is computed incorrectly

Please verify that DivideUnusedIterators guarantees a bijective mapping in all cases. Run the following script to check for existing test coverage of edge cases:

Additionally, consider:

  1. Adding explicit checks or assertions that validate bijectivity before inversion
  2. Documenting why the unconditional inversion is safe (if it is proven to be)
  3. Adding test cases that exercise non-trivial index patterns to ensure correctness

🏁 Script executed:

#!/bin/bash
# Description: Check for test coverage of CompleteBufferFragment and bijective mapping edge cases

# Search for tests related to CompleteBufferFragment
echo "=== Tests for CompleteBufferFragment ==="
rg -n "CompleteBufferFragment" --type=cpp --type=py

# Search for tests involving DivideUnusedIterators
echo -e "\n=== Tests for DivideUnusedIterators ==="
rg -n "DivideUnusedIterators" --type=cpp --type=py

# Search for bijective-related tests in parallel operations
echo -e "\n=== Bijective mapping tests ==="
rg -nC3 "bijective.*parallel|parallel.*bijective" --type=cpp --type=py -g '!*.md'

# Look for any comments or documentation about bijective guarantees
echo -e "\n=== Documentation about bijective guarantees ==="
rg -nC2 "DivideUnusedIterators" --type=cpp -A5 -B5

Length of output: 4531


🏁 Script executed:

#!/bin/bash
# Check the implementation of DivideUnusedIterators and MakeFlattenedExpression
echo "=== DivideUnusedIterators implementation ==="
sed -n '122,160p' src/layout/utils.cc

echo -e "\n=== MakeFlattenedExpression signature ==="
rg -A10 "^PrimExpr MakeFlattenedExpression" src/layout/utils.cc

echo -e "\n=== Layout::Inverse() function ==="
rg -n "Inverse()" src/layout/layout.h

echo -e "\n=== Layout::Inverse() implementation ==="
rg -A20 "Layout.*::Inverse()" src/layout/layout.cc

echo -e "\n=== Check for bijective-related comments or checks ==="
rg -nC3 "bijective" src/op/parallel.cc

echo -e "\n=== Look for any error handling or assertions around Inverse ==="
rg -B5 -A5 "->Inverse()" src/op/parallel.cc

Length of output: 995


🏁 Script executed:

#!/bin/bash
# Check the implementation of DivideUnusedIterators and MakeFlattenedExpression
echo "=== DivideUnusedIterators implementation ==="
sed -n '122,160p' src/layout/utils.cc

echo -e "\n=== MakeFlattenedExpression signature ==="
rg -A10 "^PrimExpr MakeFlattenedExpression" src/layout/utils.cc

echo -e "\n=== Layout::Inverse() function ==="
rg -n "Inverse()" src/layout/layout.h

echo -e "\n=== Layout::Inverse() implementation ==="
rg -A20 "Layout.*::Inverse()" src/layout/layout.cc

echo -e "\n=== Check for bijective-related comments or checks ==="
rg -nC3 "bijective" src/op/parallel.cc

echo -e "\n=== Look for error handling or assertions around Inverse in parallel.cc ==="
rg -B5 -A5 "Inverse" src/op/parallel.cc | head -50

Length of output: 9042


🏁 Script executed:

#!/bin/bash
# Check the git history of the file to see what was removed
echo "=== Git history of line 652 changes ==="
git log --oneline -n 20 src/op/parallel.cc

echo -e "\n=== Check if InverseWithLevel has parameters for handling non-bijective ==="
rg -A30 "LayoutNode::InverseWithLevel()" src/layout/layout.cc | head -40

echo -e "\n=== Check what level parameter means in DetectIterMap ==="
rg -B3 "IterMapLevel" src/layout/layout.cc | head -30

echo -e "\n=== Check if there's any old fallback code in git ==="
git log -p --follow -S "pre-check\|fallback" src/op/parallel.cc | head -100

Length of output: 3546


🏁 Script executed:

#!/bin/bash
# Get the full diff of the most relevant commit
echo "=== Full diff of commit 2957afc ==="
git show 2957afc -- src/op/parallel.cc | head -200

echo -e "\n=== Look for InverseWithLevel usage in the same file ==="
rg -n "InverseWithLevel\|Inverse()" src/op/parallel.cc

Length of output: 4193


Remove unconditional Inverse() call or restore bijectivity pre-check and fallback handling.

The current code at line 652 removes critical safety mechanisms that were previously in place. The git history (commit 2957afc) shows the code previously included:

  1. Pre-check for bijectivity: CanProveEqual(in_prod, out_prod) to verify the combined mapping was bijective
  2. Fallback path: If the check failed, it would use Layout(loop_vars_, indice_map_[buffer])->Inverse() without rep_b

Without these safeguards, calling Inverse() unconditionally can fail if bijective_indice is not actually bijective (for example, if DivideUnusedIterators output combined with the original indices creates a non-bijective mapping). The 2D path (lines 626–632) validates bijectivity explicitly before inversion; the fallback path should too.

Options:

  • Restore the cardinality pre-check and fallback handling
  • Or add explicit error handling around the Inverse() call with recovery logic
  • Or document why bijectivity is now guaranteed


Layout ind_inv = layout_before_inv->Inverse();
PrimExpr indice_rep_extent =
ind_inv->InputShape().back(); // this is the size of rep_b
PrimExpr loop_rep_extent = loop_layout_->ReplicateExtent();
Expand Down
Loading