diff --git a/velox/exec/MergeJoin.cpp b/velox/exec/MergeJoin.cpp index 6d580bfebf32..d961a95b1022 100644 --- a/velox/exec/MergeJoin.cpp +++ b/velox/exec/MergeJoin.cpp @@ -608,13 +608,22 @@ bool MergeJoin::addToOutputForLeftJoin() { : rightMatch_->startRowIndex; const auto numRightBatches = rightMatch_->inputs.size(); - for (size_t r = firstRightBatch; r < numRightBatches; ++r) { + // TODO: Since semi joins only require determining if there is at least + // one match on the other side, we could explore specialized algorithms + // or data structures that short-circuit the join process once a match + // is found. + for (size_t r = isLeftSemiFilterJoin(joinType_) ? numRightBatches - 1 + : firstRightBatch; + r < numRightBatches; + ++r) { const auto rightBatch = rightMatch_->inputs[r]; - const auto rightStartRow = - r == firstRightBatch ? rightStartRowIndex : 0; - auto rightEndRow = r == numRightBatches - 1 ? rightMatch_->endRowIndex - : rightBatch->size(); - + auto rightStartRow = r == firstRightBatch ? rightStartRowIndex : 0; + const auto rightEndRow = r == numRightBatches - 1 + ? rightMatch_->endRowIndex + : rightBatch->size(); + if (isLeftSemiFilterJoin(joinType_)) { + rightStartRow = rightEndRow - 1; + } if (prepareOutput(leftBatch, rightBatch)) { output_->resize(outputSize_); leftMatch_->setCursor(l, i); @@ -622,15 +631,6 @@ bool MergeJoin::addToOutputForLeftJoin() { return true; } - // TODO: Since semi joins only require determining if there is at least - // one match on the other side, we could explore specialized algorithms - // or data structures that short-circuit the join process once a match - // is found. - if (isLeftSemiFilterJoin(joinType_)) { - // LeftSemiFilter produce each row from the left at most once. - rightEndRow = rightStartRow + 1; - } - for (auto j = rightStartRow; j < rightEndRow; ++j) { if (!tryAddOutputRow(leftBatch, i, rightBatch, j)) { // If we run out of space in the current output_, we will need to @@ -688,11 +688,23 @@ bool MergeJoin::addToOutputForRightJoin() { : leftMatch_->startRowIndex; const auto numLeftBatches = leftMatch_->inputs.size(); - for (size_t l = firstLeftBatch; l < numLeftBatches; ++l) { + // TODO: Since semi joins only require determining if there is at least + // one match on the other side, we could explore specialized algorithms + // or data structures that short-circuit the join process once a match + // is found. + for (size_t l = isRightSemiFilterJoin(joinType_) ? numLeftBatches - 1 + : firstLeftBatch; + l < numLeftBatches; + ++l) { const auto leftBatch = leftMatch_->inputs[l]; - const auto leftStartRow = l == firstLeftBatch ? leftStartRowIndex : 0; - auto leftEndRow = l == numLeftBatches - 1 ? leftMatch_->endRowIndex - : leftBatch->size(); + auto leftStartRow = l == firstLeftBatch ? leftStartRowIndex : 0; + const auto leftEndRow = l == numLeftBatches - 1 + ? leftMatch_->endRowIndex + : leftBatch->size(); + if (isRightSemiFilterJoin(joinType_)) { + // RightSemiFilter produce each row from the right at most once. + leftStartRow = leftEndRow - 1; + } if (prepareOutput(leftBatch, rightBatch)) { // Differently from left joins, for right joins we need to load lazies @@ -706,15 +718,6 @@ bool MergeJoin::addToOutputForRightJoin() { return true; } - // TODO: Since semi joins only require determining if there is at least - // one match on the other side, we could explore specialized algorithms - // or data structures that short-circuit the join process once a match - // is found. - if (isRightSemiFilterJoin(joinType_)) { - // RightSemiFilter produce each row from the right at most once. - leftEndRow = leftStartRow + 1; - } - for (auto j = leftStartRow; j < leftEndRow; ++j) { if (!tryAddOutputRow(leftBatch, j, rightBatch, i)) { // If we run out of space in the current output_, we will need to diff --git a/velox/exec/tests/MergeJoinTest.cpp b/velox/exec/tests/MergeJoinTest.cpp index c6351f554e33..f485948a50a9 100644 --- a/velox/exec/tests/MergeJoinTest.cpp +++ b/velox/exec/tests/MergeJoinTest.cpp @@ -917,6 +917,55 @@ TEST_F(MergeJoinTest, semiJoin) { core::JoinType::kRightSemiFilter); } +TEST_F(MergeJoinTest, semiJoinWithMultipleMatchVectors) { + std::vector leftVectors; + for (int i = 0; i < 10; ++i) { + leftVectors.push_back(makeRowVector( + {"t0"}, {makeFlatVector({i / 2, i / 2, i / 2})})); + } + std::vector rightVectors; + for (int i = 0; i < 10; ++i) { + rightVectors.push_back(makeRowVector( + {"u0"}, {makeFlatVector({i / 2, i / 2, i / 2})})); + } + + createDuckDbTable("t", leftVectors); + createDuckDbTable("u", rightVectors); + + auto testSemiJoin = [&](const std::string& filter, + const std::string& sql, + const std::vector& outputLayout, + core::JoinType joinType) { + auto planNodeIdGenerator = std::make_shared(); + auto plan = PlanBuilder(planNodeIdGenerator) + .values(leftVectors) + .mergeJoin( + {"t0"}, + {"u0"}, + PlanBuilder(planNodeIdGenerator) + .values(rightVectors) + .planNode(), + filter, + outputLayout, + joinType) + .planNode(); + AssertQueryBuilder(plan, duckDbQueryRunner_) + .config(core::QueryConfig::kMaxOutputBatchRows, "1") + .assertResults(sql); + }; + + testSemiJoin( + "u0 > 1", + "SELECT u0 FROM u where u0 IN (SELECT t0 from t) and u0 > 1", + {"u0"}, + core::JoinType::kRightSemiFilter); + testSemiJoin( + "t0 >1", + "SELECT t0 FROM t where t0 IN (SELECT u0 from u) and t0 > 1", + {"t0"}, + core::JoinType::kLeftSemiFilter); +} + TEST_F(MergeJoinTest, rightJoin) { auto left = makeRowVector( {"t0"},