@@ -90,7 +90,9 @@ void MergeJoin::initialize() {
9090 initializeFilter (joinNode_->filter (), leftType, rightType);
9191
9292 if (joinNode_->isLeftJoin () || joinNode_->isAntiJoin () ||
93- joinNode_->isRightJoin () || joinNode_->isFullJoin ()) {
93+ joinNode_->isRightJoin () || joinNode_->isFullJoin () ||
94+ joinNode_->isLeftSemiFilterJoin () ||
95+ joinNode_->isRightSemiFilterJoin ()) {
9496 joinTracker_ = JoinTracker (outputBatchSize_, pool ());
9597 }
9698 } else if (joinNode_->isAntiJoin ()) {
@@ -421,7 +423,7 @@ bool MergeJoin::tryAddOutputRow(
421423 filterRightInputProjections_);
422424
423425 if (joinTracker_) {
424- if (isRightJoin (joinType_)) {
426+ if (isRightJoin (joinType_) || isRightSemiFilterJoin (joinType_) ) {
425427 // Record right-side row with a match on the left-side.
426428 joinTracker_->addMatch (rightBatch, rightRow, outputSize_);
427429 } else {
@@ -613,16 +615,17 @@ bool MergeJoin::addToOutputForLeftJoin() {
613615 // one match on the other side, we could explore specialized algorithms
614616 // or data structures that short-circuit the join process once a match
615617 // is found.
616- for (size_t r = isLeftSemiFilterJoin (joinType_) ? numRightBatches - 1
617- : firstRightBatch;
618+ for (size_t r = (isLeftSemiFilterJoin (joinType_) && !filter_)
619+ ? numRightBatches - 1
620+ : firstRightBatch;
618621 r < numRightBatches;
619622 ++r) {
620623 const auto rightBatch = rightMatch_->inputs [r];
621624 auto rightStartRow = r == firstRightBatch ? rightStartRowIndex : 0 ;
622625 const auto rightEndRow = r == numRightBatches - 1
623626 ? rightMatch_->endRowIndex
624627 : rightBatch->size ();
625- if (isLeftSemiFilterJoin (joinType_)) {
628+ if (isLeftSemiFilterJoin (joinType_) && !filter_ ) {
626629 rightStartRow = rightEndRow - 1 ;
627630 }
628631 if (prepareOutput (leftBatch, rightBatch)) {
@@ -693,16 +696,17 @@ bool MergeJoin::addToOutputForRightJoin() {
693696 // one match on the other side, we could explore specialized algorithms
694697 // or data structures that short-circuit the join process once a match
695698 // is found.
696- for (size_t l = isRightSemiFilterJoin (joinType_) ? numLeftBatches - 1
697- : firstLeftBatch;
699+ for (size_t l = (isRightSemiFilterJoin (joinType_) && !filter_)
700+ ? numLeftBatches - 1
701+ : firstLeftBatch;
698702 l < numLeftBatches;
699703 ++l) {
700704 const auto leftBatch = leftMatch_->inputs [l];
701705 auto leftStartRow = l == firstLeftBatch ? leftStartRowIndex : 0 ;
702706 const auto leftEndRow = l == numLeftBatches - 1
703707 ? leftMatch_->endRowIndex
704708 : leftBatch->size ();
705- if (isRightSemiFilterJoin (joinType_)) {
709+ if (isRightSemiFilterJoin (joinType_) && !filter_ ) {
706710 // RightSemiFilter produces each row from the right at most once.
707711 leftStartRow = leftEndRow - 1 ;
708712 }
@@ -818,7 +822,7 @@ RowVectorPtr MergeJoin::getOutput() {
818822 continue ;
819823 } else if (isAntiJoin (joinType_)) {
820824 output = filterOutputForAntiJoin (output);
821- if (output) {
825+ if (output != nullptr && output-> size () > 0 ) {
822826 return output;
823827 }
824828
@@ -1274,7 +1278,8 @@ RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) {
12741278 // If all matches for a given left-side row fail the filter, add a row to
12751279 // the output with nulls for the right-side columns.
12761280 const auto onMiss = [&](auto row) {
1277- if (isAntiJoin (joinType_)) {
1281+ if (isAntiJoin (joinType_) || isLeftSemiFilterJoin (joinType_) ||
1282+ isRightSemiFilterJoin (joinType_)) {
12781283 return ;
12791284 }
12801285 rawIndices[numPassed++] = row;
@@ -1346,18 +1351,26 @@ RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) {
13461351 }
13471352 };
13481353
1354+ auto onMatch = [&](auto row) {
1355+ if (isLeftSemiFilterJoin (joinType_) || isRightSemiFilterJoin (joinType_)) {
1356+ rawIndices[numPassed++] = row;
1357+ }
1358+ };
1359+
13491360 for (auto i = 0 ; i < numRows; ++i) {
13501361 if (filterRows.isValid (i)) {
13511362 const bool passed = !decodedFilterResult_.isNullAt (i) &&
13521363 decodedFilterResult_.valueAt <bool >(i);
13531364
1354- joinTracker_->processFilterResult (i, passed, onMiss);
1365+ joinTracker_->processFilterResult (i, passed, onMiss, onMatch );
13551366
13561367 if (isAntiJoin (joinType_)) {
13571368 if (!passed) {
13581369 rawIndices[numPassed++] = i;
13591370 }
1360- } else {
1371+ } else if (
1372+ !isLeftSemiFilterJoin (joinType_) &&
1373+ !isRightSemiFilterJoin (joinType_)) {
13611374 if (passed) {
13621375 rawIndices[numPassed++] = i;
13631376 }
@@ -1371,26 +1384,27 @@ RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) {
13711384
13721385 // Every time we start a new left key match, `processFilterResult()` will
13731386 // check if at least one row from the previous match passed the filter. If
1374- // none did, it calls onMiss to add a record with null right projections to
1375- // the output.
1387+ // none did, it calls onMiss to add a record with null right projections
1388+ // to the output.
13761389 //
13771390 // Before we leave the current buffer, since we may not have seen the next
1378- // left key match yet, the last key match may still be pending to produce a
1379- // row (because `processFilterResult()` was not called yet).
1391+ // left key match yet, the last key match may still be pending to produce
1392+ // a row (because `processFilterResult()` was not called yet).
13801393 //
13811394 // To handle this, we need to call `noMoreFilterResults()` unless the
1382- // same current left key match may continue in the next buffer. So there are
1383- // two cases to check:
1395+ // same current left key match may continue in the next buffer. So there
1396+ // are two cases to check:
13841397 //
1385- // 1. If leftMatch_ is nullopt, there for sure the next buffer will contain
1386- // a different key match.
1398+ // 1. If leftMatch_ is nullopt, then for sure the next buffer will
1399+ // contain a different key match.
13871400 //
13881401 // 2. leftMatch_ may not be nullopt, but may be related to a different
13891402 // (subsequent) left key. So we check if the last row in the batch has the
13901403 // same left row number as the last key match.
13911404 if (!leftMatch_ || !joinTracker_->isCurrentLeftMatch (numRows - 1 )) {
13921405 joinTracker_->noMoreFilterResults (onMiss);
13931406 }
1407+
13941408 } else {
13951409 filterRows_.resize (numRows);
13961410 filterRows_.setAll ();
0 commit comments