Skip to content

Commit deeef02

Browse files
committed
Fix the semi join result mismatch issue with filter and multi duplicated rows
1 parent c5c804a commit deeef02

File tree

3 files changed

+100
-22
lines changed

3 files changed

+100
-22
lines changed

velox/exec/MergeJoin.cpp

Lines changed: 34 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,9 @@ void MergeJoin::initialize() {
9090
initializeFilter(joinNode_->filter(), leftType, rightType);
9191

9292
if (joinNode_->isLeftJoin() || joinNode_->isAntiJoin() ||
93-
joinNode_->isRightJoin() || joinNode_->isFullJoin()) {
93+
joinNode_->isRightJoin() || joinNode_->isFullJoin() ||
94+
joinNode_->isLeftSemiFilterJoin() ||
95+
joinNode_->isRightSemiFilterJoin()) {
9496
joinTracker_ = JoinTracker(outputBatchSize_, pool());
9597
}
9698
} else if (joinNode_->isAntiJoin()) {
@@ -421,7 +423,7 @@ bool MergeJoin::tryAddOutputRow(
421423
filterRightInputProjections_);
422424

423425
if (joinTracker_) {
424-
if (isRightJoin(joinType_)) {
426+
if (isRightJoin(joinType_) || isRightSemiFilterJoin(joinType_)) {
425427
// Record right-side row with a match on the left-side.
426428
joinTracker_->addMatch(rightBatch, rightRow, outputSize_);
427429
} else {
@@ -613,16 +615,17 @@ bool MergeJoin::addToOutputForLeftJoin() {
613615
// one match on the other side, we could explore specialized algorithms
614616
// or data structures that short-circuit the join process once a match
615617
// is found.
616-
for (size_t r = isLeftSemiFilterJoin(joinType_) ? numRightBatches - 1
617-
: firstRightBatch;
618+
for (size_t r = (isLeftSemiFilterJoin(joinType_) && !filter_)
619+
? numRightBatches - 1
620+
: firstRightBatch;
618621
r < numRightBatches;
619622
++r) {
620623
const auto rightBatch = rightMatch_->inputs[r];
621624
auto rightStartRow = r == firstRightBatch ? rightStartRowIndex : 0;
622625
const auto rightEndRow = r == numRightBatches - 1
623626
? rightMatch_->endRowIndex
624627
: rightBatch->size();
625-
if (isLeftSemiFilterJoin(joinType_)) {
628+
if (isLeftSemiFilterJoin(joinType_) && !filter_) {
626629
rightStartRow = rightEndRow - 1;
627630
}
628631
if (prepareOutput(leftBatch, rightBatch)) {
@@ -693,16 +696,17 @@ bool MergeJoin::addToOutputForRightJoin() {
693696
// one match on the other side, we could explore specialized algorithms
694697
// or data structures that short-circuit the join process once a match
695698
// is found.
696-
for (size_t l = isRightSemiFilterJoin(joinType_) ? numLeftBatches - 1
697-
: firstLeftBatch;
699+
for (size_t l = (isRightSemiFilterJoin(joinType_) && !filter_)
700+
? numLeftBatches - 1
701+
: firstLeftBatch;
698702
l < numLeftBatches;
699703
++l) {
700704
const auto leftBatch = leftMatch_->inputs[l];
701705
auto leftStartRow = l == firstLeftBatch ? leftStartRowIndex : 0;
702706
const auto leftEndRow = l == numLeftBatches - 1
703707
? leftMatch_->endRowIndex
704708
: leftBatch->size();
705-
if (isRightSemiFilterJoin(joinType_)) {
709+
if (isRightSemiFilterJoin(joinType_) && !filter_) {
706710
// RightSemiFilter produce each row from the right at most once.
707711
leftStartRow = leftEndRow - 1;
708712
}
@@ -818,7 +822,7 @@ RowVectorPtr MergeJoin::getOutput() {
818822
continue;
819823
} else if (isAntiJoin(joinType_)) {
820824
output = filterOutputForAntiJoin(output);
821-
if (output) {
825+
if (output != nullptr && output->size() > 0) {
822826
return output;
823827
}
824828

@@ -1274,7 +1278,8 @@ RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) {
12741278
// If all matches for a given left-side row fail the filter, add a row to
12751279
// the output with nulls for the right-side columns.
12761280
const auto onMiss = [&](auto row) {
1277-
if (isAntiJoin(joinType_)) {
1281+
if (isAntiJoin(joinType_) || isLeftSemiFilterJoin(joinType_) ||
1282+
isRightSemiFilterJoin(joinType_)) {
12781283
return;
12791284
}
12801285
rawIndices[numPassed++] = row;
@@ -1346,18 +1351,26 @@ RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) {
13461351
}
13471352
};
13481353

1354+
auto onMatch = [&](auto row) {
1355+
if (isLeftSemiFilterJoin(joinType_) || isRightSemiFilterJoin(joinType_)) {
1356+
rawIndices[numPassed++] = row;
1357+
}
1358+
};
1359+
13491360
for (auto i = 0; i < numRows; ++i) {
13501361
if (filterRows.isValid(i)) {
13511362
const bool passed = !decodedFilterResult_.isNullAt(i) &&
13521363
decodedFilterResult_.valueAt<bool>(i);
13531364

1354-
joinTracker_->processFilterResult(i, passed, onMiss);
1365+
joinTracker_->processFilterResult(i, passed, onMiss, onMatch);
13551366

13561367
if (isAntiJoin(joinType_)) {
13571368
if (!passed) {
13581369
rawIndices[numPassed++] = i;
13591370
}
1360-
} else {
1371+
} else if (
1372+
!isLeftSemiFilterJoin(joinType_) &&
1373+
!isRightSemiFilterJoin(joinType_)) {
13611374
if (passed) {
13621375
rawIndices[numPassed++] = i;
13631376
}
@@ -1371,26 +1384,27 @@ RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) {
13711384

13721385
// Every time we start a new left key match, `processFilterResult()` will
13731386
// check if at least one row from the previous match passed the filter. If
1374-
// none did, it calls onMiss to add a record with null right projections to
1375-
// the output.
1387+
// none did, it calls onMiss to add a record with null right projections
1388+
// to the output.
13761389
//
13771390
// Before we leave the current buffer, since we may not have seen the next
1378-
// left key match yet, the last key match may still be pending to produce a
1379-
// row (because `processFilterResult()` was not called yet).
1391+
// left key match yet, the last key match may still be pending to produce
1392+
// a row (because `processFilterResult()` was not called yet).
13801393
//
13811394
// To handle this, we need to call `noMoreFilterResults()` unless the
1382-
// same current left key match may continue in the next buffer. So there are
1383-
// two cases to check:
1395+
// same current left key match may continue in the next buffer. So there
1396+
// are two cases to check:
13841397
//
1385-
// 1. If leftMatch_ is nullopt, there for sure the next buffer will contain
1386-
// a different key match.
1398+
// 1. If leftMatch_ is nullopt, there for sure the next buffer will
1399+
// contain a different key match.
13871400
//
13881401
// 2. leftMatch_ may not be nullopt, but may be related to a different
13891402
// (subsequent) left key. So we check if the last row in the batch has the
13901403
// same left row number as the last key match.
13911404
if (!leftMatch_ || !joinTracker_->isCurrentLeftMatch(numRows - 1)) {
13921405
joinTracker_->noMoreFilterResults(onMiss);
13931406
}
1407+
13941408
} else {
13951409
filterRows_.resize(numRows);
13961410
filterRows_.setAll();

velox/exec/MergeJoin.h

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -394,11 +394,12 @@ class MergeJoin : public Operator {
394394
// rows that correspond to a single left-side row. Use
395395
// 'noMoreFilterResults' to make sure 'onMiss' is called for the last
396396
// left-side row.
397-
template <typename TOnMiss>
397+
template <typename TOnMiss, typename TOnMatch>
398398
void processFilterResult(
399399
vector_size_t outputIndex,
400400
bool passed,
401-
TOnMiss onMiss) {
401+
TOnMiss onMiss,
402+
TOnMatch onMatch) {
402403
const auto rowNumber = rawLeftRowNumbers_[outputIndex];
403404
if (currentLeftRowNumber_ != rowNumber) {
404405
if (currentRow_ != -1 && !currentRowPassed_) {
@@ -407,12 +408,18 @@ class MergeJoin : public Operator {
407408
currentRow_ = outputIndex;
408409
currentLeftRowNumber_ = rowNumber;
409410
currentRowPassed_ = false;
411+
firstMatched_ = false;
410412
} else {
411413
currentRow_ = outputIndex;
412414
}
413415

414416
if (passed) {
415417
currentRowPassed_ = true;
418+
419+
if (!firstMatched_) {
420+
onMatch(outputIndex);
421+
firstMatched_ = true;
422+
}
416423
}
417424
}
418425

@@ -434,6 +441,7 @@ class MergeJoin : public Operator {
434441

435442
currentRow_ = -1;
436443
currentRowPassed_ = false;
444+
firstMatched_ = false;
437445
}
438446

439447
void reset();
@@ -470,6 +478,10 @@ class MergeJoin : public Operator {
470478
// True if at least one row in a block of output rows corresponding a single
471479
// left-side row identified by 'currentRowNumber' passed the filter.
472480
bool currentRowPassed_{false};
481+
482+
// Retains only the first matching record for a semi join in scenarios
483+
// involving filters.
484+
bool firstMatched_{false};
473485
};
474486

475487
/// Used to record both left and right join.

velox/exec/tests/MergeJoinTest.cpp

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1016,6 +1016,58 @@ TEST_F(MergeJoinTest, semiJoinWithMultipleMatchVectors) {
10161016
core::JoinType::kLeftSemiFilter);
10171017
}
10181018

1019+
TEST_F(MergeJoinTest, semiJoinWithMultiMatchedRowsWithFilter) {
1020+
auto left = makeRowVector(
1021+
{"t0", "t1"},
1022+
{makeNullableFlatVector<int64_t>({2, 2, 2, 2, 2}),
1023+
makeNullableFlatVector<int64_t>({3, 2, 3, 2, 2})});
1024+
1025+
auto right = makeRowVector(
1026+
{"u0", "u1"},
1027+
{makeNullableFlatVector<int64_t>({2, 2, 2, 2, 2, 2}),
1028+
makeNullableFlatVector<int64_t>({2, 2, 2, 2, 2, 4})});
1029+
1030+
createDuckDbTable("t", {left});
1031+
createDuckDbTable("u", {right});
1032+
1033+
auto testSemiJoin = [&](const std::string& filter,
1034+
const std::string& sql,
1035+
const std::vector<std::string>& outputLayout,
1036+
core::JoinType joinType) {
1037+
auto planNodeIdGenerator = std::make_shared<core::PlanNodeIdGenerator>();
1038+
auto plan = PlanBuilder(planNodeIdGenerator)
1039+
.values(split(left, 2))
1040+
.mergeJoin(
1041+
{"t0"},
1042+
{"u0"},
1043+
PlanBuilder(planNodeIdGenerator)
1044+
.values(split(right, 2))
1045+
.planNode(),
1046+
filter,
1047+
outputLayout,
1048+
joinType)
1049+
.planNode();
1050+
AssertQueryBuilder(plan, duckDbQueryRunner_)
1051+
.config(core::QueryConfig::kPreferredOutputBatchRows, "2")
1052+
.config(core::QueryConfig::kMaxOutputBatchRows, "2")
1053+
.assertResults(sql);
1054+
};
1055+
1056+
// Left Semi join With filter
1057+
testSemiJoin(
1058+
"t1 > u1",
1059+
"SELECT t0, t1 FROM t where t0 IN (SELECT u0 from u where t1 > u1)",
1060+
{"t0", "t1"},
1061+
core::JoinType::kLeftSemiFilter);
1062+
1063+
// Right Semi join With filter
1064+
testSemiJoin(
1065+
"u1 > t1",
1066+
"SELECT u0, u1 FROM u where u0 IN (SELECT t0 from t where u1 > t1)",
1067+
{"u0", "u1"},
1068+
core::JoinType::kRightSemiFilter);
1069+
}
1070+
10191071
TEST_F(MergeJoinTest, rightJoin) {
10201072
auto left = makeRowVector(
10211073
{"t0"},

0 commit comments

Comments
 (0)