Skip to content

Commit bc4c61c

Browse files
committed
test for empty domain/range and symbols
1 parent f968bbb commit bc4c61c

File tree

2 files changed

+62
-7
lines changed

2 files changed

+62
-7
lines changed

mlir/lib/Analysis/Presburger/IntegerRelation.cpp

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2490,29 +2490,35 @@ IntegerRelation IntegerRelation::rangeProduct(const IntegerRelation &rel) {
24902490

24912491
// explicit copy of `this`
24922492
IntegerRelation result = *this;
2493+
2494+
// An explicit copy of `rel` is needed to merge and align symbols, since that
2495+
// function mutates both relations.
2496+
IntegerRelation relCopy = rel;
2497+
result.mergeAndAlignSymbols(relCopy);
2498+
24932499
unsigned relRangeVarStart = rel.getVarKindOffset(VarKind::Range);
2494-
unsigned numRelRangeVars = rel.getNumRangeVars();
24952500
unsigned numThisRangeVars = getNumRangeVars();
2501+
unsigned numNewSymbolVars = result.getNumSymbolVars() - getNumSymbolVars();
24962502

2497-
result.appendVar(VarKind::Range, numRelRangeVars);
2503+
result.appendVar(VarKind::Range, rel.getNumRangeVars());
24982504

24992505
// Copy each equality from `rel` and update the copy to account for range
25002506
// variables from `this`. The `rel` equality is a list of coefficients of the
25012507
// variables from `rel`, and so the range variables need to be shifted right
2502-
// by the number of `this` range variables.
2508+
// by the number of `this` range variables and symbols.
25032509
for (unsigned i = 0; i < rel.getNumEqualities(); ++i) {
25042510
SmallVector<DynamicAPInt> copy =
25052511
SmallVector<DynamicAPInt>(rel.getEquality(i));
2506-
copy.insert(copy.begin() + relRangeVarStart, numThisRangeVars,
2507-
DynamicAPInt(0));
2512+
copy.insert(copy.begin() + relRangeVarStart,
2513+
numThisRangeVars + numNewSymbolVars, DynamicAPInt(0));
25082514
result.addEquality(copy);
25092515
}
25102516

25112517
for (unsigned i = 0; i < rel.getNumInequalities(); ++i) {
25122518
SmallVector<DynamicAPInt> copy =
25132519
SmallVector<DynamicAPInt>(rel.getInequality(i));
2514-
copy.insert(copy.begin() + relRangeVarStart, numThisRangeVars,
2515-
DynamicAPInt(0));
2520+
copy.insert(copy.begin() + relRangeVarStart,
2521+
numThisRangeVars + numNewSymbolVars, DynamicAPInt(0));
25162522
result.addInequality(copy);
25172523
}
25182524

mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -653,3 +653,52 @@ TEST(IntegerRelationTest, rangeProductMultdimRangeSwapped) {
653653

654654
EXPECT_TRUE(expected.isEqual(rangeProd));
655655
}
656+
657+
TEST(IntegerRelationTest, rangeProductEmptyDomain) {
658+
IntegerRelation r1 =
659+
parseRelationFromSet("(i, j) : (4*i + 9*j == 0, i >= 0, j >= 0)", 0);
660+
IntegerRelation r2 =
661+
parseRelationFromSet("(k, l) : (2*k + 3*l == 0, k >= 0, l >= 0)", 0);
662+
IntegerRelation rangeProd = r1.rangeProduct(r2);
663+
IntegerRelation expected =
664+
parseRelationFromSet("(i, j, k, l) : (2*k + 3*l == 0, 4*i + 9*j == "
665+
"0, i >= 0, j >= 0, k >= 0, l >= 0)",
666+
0);
667+
EXPECT_TRUE(expected.isEqual(rangeProd));
668+
}
669+
670+
TEST(IntegerRelationTest, rangeProductEmptyRange) {
671+
IntegerRelation r1 =
672+
parseRelationFromSet("(i, j) : (4*i + 9*j == 0, i >= 0, j >= 0)", 2);
673+
IntegerRelation r2 =
674+
parseRelationFromSet("(i, j) : (2*i + 3*j == 0, i >= 0, j >= 0)", 2);
675+
IntegerRelation rangeProd = r1.rangeProduct(r2);
676+
IntegerRelation expected =
677+
parseRelationFromSet("(i, j) : (2*i + 3*j == 0, 4*i + 9*j == "
678+
"0, i >= 0, j >= 0)",
679+
2);
680+
EXPECT_TRUE(expected.isEqual(rangeProd));
681+
}
682+
683+
TEST(IntegerRelationTest, rangeProductEmptyDomainAndRange) {
684+
IntegerRelation r1 = parseRelationFromSet("() : ()", 0);
685+
IntegerRelation r2 = parseRelationFromSet("() : ()", 0);
686+
IntegerRelation rangeProd = r1.rangeProduct(r2);
687+
IntegerRelation expected = parseRelationFromSet("() : ()", 0);
688+
EXPECT_TRUE(expected.isEqual(rangeProd));
689+
}
690+
691+
TEST(IntegerRelationTest, rangeProductSymbols) {
692+
IntegerRelation r1 = parseRelationFromSet(
693+
"(i, j)[s] : (2*i + 3*j + s == 0, i >= 0, j >= 0)", 1);
694+
IntegerRelation r2 = parseRelationFromSet(
695+
"(i, l)[t] : (3*i + 4*l + t == 0, i >= 0, l >= 0)", 1);
696+
697+
IntegerRelation rangeProd = r1.rangeProduct(r2);
698+
IntegerRelation expected = parseRelationFromSet(
699+
"(i, j, k, l)[s, t] : (2*i + 3*j + s == 0, 3*i + 4*l + t == "
700+
"0, i >= 0, j >= 0, l >= 0)",
701+
1);
702+
703+
EXPECT_TRUE(expected.isEqual(rangeProd));
704+
}

0 commit comments

Comments
 (0)