1818#include " velox/exec/Task.h"
1919#include " velox/expression/FieldReference.h"
2020
21+ #include < iostream>
22+
2123namespace facebook ::velox::exec {
2224
2325MergeJoin::MergeJoin (
@@ -92,7 +94,7 @@ void MergeJoin::initialize() {
9294 joinNode_->isRightJoin () || joinNode_->isFullJoin ()) {
9395 joinTracker_ = JoinTracker (outputBatchSize_, pool ());
9496 }
95- } else if (joinNode_->isAntiJoin ()) {
97+ } else if (joinNode_->isAntiJoin () || joinNode_-> isFullJoin () ) {
9698 // Anti join needs to track the left side rows that have no match on the
9799 // right.
98100 joinTracker_ = JoinTracker (outputBatchSize_, pool ());
@@ -386,7 +388,8 @@ bool MergeJoin::tryAddOutputRow(
386388 const RowVectorPtr& leftBatch,
387389 vector_size_t leftRow,
388390 const RowVectorPtr& rightBatch,
389- vector_size_t rightRow) {
391+ vector_size_t rightRow,
392+ bool isRightJoinForFullOuter) {
390393 if (outputSize_ == outputBatchSize_) {
391394 return false ;
392395 }
@@ -420,12 +423,15 @@ bool MergeJoin::tryAddOutputRow(
420423 filterRightInputProjections_);
421424
422425 if (joinTracker_) {
423- if (isRightJoin (joinType_)) {
426+ if (isRightJoin (joinType_) ||
427+ (isFullJoin (joinType_) && isRightJoinForFullOuter)) {
424428 // Record right-side row with a match on the left-side.
425- joinTracker_->addMatch (rightBatch, rightRow, outputSize_);
429+ joinTracker_->addMatch (
430+ rightBatch, rightRow, outputSize_, isRightJoinForFullOuter);
426431 } else {
427432 // Record left-side row with a match on the right-side.
428- joinTracker_->addMatch (leftBatch, leftRow, outputSize_);
433+ joinTracker_->addMatch (
434+ leftBatch, leftRow, outputSize_, isRightJoinForFullOuter);
429435 }
430436 }
431437 }
@@ -435,7 +441,8 @@ bool MergeJoin::tryAddOutputRow(
435441 if (isAntiJoin (joinType_)) {
436442 VELOX_CHECK (joinTracker_.has_value ());
437443 // Record left-side row with a match on the right-side.
438- joinTracker_->addMatch (leftBatch, leftRow, outputSize_);
444+ joinTracker_->addMatch (
445+ leftBatch, leftRow, outputSize_, isRightJoinForFullOuter);
439446 }
440447
441448 ++outputSize_;
@@ -454,14 +461,14 @@ bool MergeJoin::prepareOutput(
454461 return true ;
455462 }
456463
457- if (isRightJoin (joinType_) && right != currentRight_) {
458- return true ;
459- }
460-
461464 // If there is a new right, we need to flatten the dictionary.
462465 if (!isRightFlattened_ && right && currentRight_ != right) {
463466 flattenRightProjections ();
464467 }
468+
469+ if (right != currentRight_) {
470+ return true ;
471+ }
465472 return false ;
466473 }
467474
@@ -573,6 +580,39 @@ bool MergeJoin::prepareOutput(
573580bool MergeJoin::addToOutput () {
574581 if (isRightJoin (joinType_) || isRightSemiFilterJoin (joinType_)) {
575582 return addToOutputForRightJoin ();
583+ } else if (isFullJoin (joinType_) && filter_) {
584+ if (!leftForRightJoinMatch_) {
585+ leftForRightJoinMatch_ = leftMatch_;
586+ rightForRightJoinMatch_ = rightMatch_;
587+ }
588+
589+ if (leftMatch_ && rightMatch_ && !leftJoinForFullFinished_) {
590+ auto left = addToOutputForLeftJoin ();
591+ if (!leftMatch_) {
592+ leftJoinForFullFinished_ = true ;
593+ }
594+ if (left) {
595+ if (!leftMatch_) {
596+ leftMatch_ = leftForRightJoinMatch_;
597+ rightMatch_ = rightForRightJoinMatch_;
598+ }
599+
600+ return true ;
601+ }
602+ }
603+
604+ if (!leftMatch_ && !rightJoinForFullFinished_) {
605+ leftMatch_ = leftForRightJoinMatch_;
606+ rightMatch_ = rightForRightJoinMatch_;
607+ rightJoinForFullFinished_ = true ;
608+ }
609+
610+ auto right = addToOutputForRightJoin ();
611+
612+ leftForRightJoinMatch_ = leftMatch_;
613+ rightForRightJoinMatch_ = rightMatch_;
614+
615+ return right;
576616 } else {
577617 return addToOutputForLeftJoin ();
578618 }
@@ -719,7 +759,13 @@ bool MergeJoin::addToOutputForRightJoin() {
719759 }
720760
721761 for (auto j = leftStartRow; j < leftEndRow; ++j) {
722- if (!tryAddOutputRow (leftBatch, j, rightBatch, i)) {
762+ auto isRightJoinForFullOuter = false ;
763+ if (isFullJoin (joinType_)) {
764+ isRightJoinForFullOuter = true ;
765+ }
766+
767+ if (!tryAddOutputRow (
768+ leftBatch, j, rightBatch, i, isRightJoinForFullOuter)) {
723769 // If we run out of space in the current output_, we will need to
724770 // produce a buffer and continue processing left later. In this
725771 // case, we cannot leave left as a lazy vector, since we cannot have
@@ -818,7 +864,7 @@ RowVectorPtr MergeJoin::getOutput() {
818864 continue ;
819865 } else if (isAntiJoin (joinType_)) {
820866 output = filterOutputForAntiJoin (output);
821- if (output) {
867+ if (output != nullptr && output-> size () > 0 ) {
822868 return output;
823869 }
824870
@@ -904,6 +950,8 @@ RowVectorPtr MergeJoin::doGetOutput() {
904950 // results from the current match.
905951 if (addToOutput ()) {
906952 return std::move (output_);
953+ } else {
954+ previousLeftMatch_ = leftMatch_;
907955 }
908956 }
909957
@@ -968,6 +1016,8 @@ RowVectorPtr MergeJoin::doGetOutput() {
9681016
9691017 if (addToOutput ()) {
9701018 return std::move (output_);
1019+ } else {
1020+ previousLeftMatch_ = leftMatch_;
9711021 }
9721022 }
9731023
@@ -1107,7 +1157,7 @@ RowVectorPtr MergeJoin::doGetOutput() {
11071157 isFullJoin (joinType_)) {
11081158 // If output_ is currently wrapping a different buffer, return it
11091159 // first.
1110- if (prepareOutput (input_, nullptr )) {
1160+ if (prepareOutput (input_, rightInput_ )) {
11111161 output_->resize (outputSize_);
11121162 return std::move (output_);
11131163 }
@@ -1132,7 +1182,7 @@ RowVectorPtr MergeJoin::doGetOutput() {
11321182 if (isRightJoin (joinType_) || isFullJoin (joinType_)) {
11331183 // If output_ is currently wrapping a different buffer, return it
11341184 // first.
1135- if (prepareOutput (nullptr , rightInput_)) {
1185+ if (prepareOutput (input_ , rightInput_)) {
11361186 output_->resize (outputSize_);
11371187 return std::move (output_);
11381188 }
@@ -1184,6 +1234,8 @@ RowVectorPtr MergeJoin::doGetOutput() {
11841234 endRightRow < rightInput_->size (),
11851235 std::nullopt };
11861236
1237+ leftJoinForFullFinished_ = false ;
1238+ rightJoinForFullFinished_ = false ;
11871239 if (!leftMatch_->complete || !rightMatch_->complete ) {
11881240 if (!leftMatch_->complete ) {
11891241 // Need to continue looking for the end of match.
@@ -1212,6 +1264,8 @@ RowVectorPtr MergeJoin::doGetOutput() {
12121264
12131265 if (addToOutput ()) {
12141266 return std::move (output_);
1267+ } else {
1268+ previousLeftMatch_ = leftMatch_;
12151269 }
12161270
12171271 if (!rightInput_) {
@@ -1228,8 +1282,6 @@ RowVectorPtr MergeJoin::doGetOutput() {
12281282RowVectorPtr MergeJoin::applyFilter (const RowVectorPtr& output) {
12291283 const auto numRows = output->size ();
12301284
1231- RowVectorPtr fullOuterOutput = nullptr ;
1232-
12331285 BufferPtr indices = allocateIndices (numRows, pool ());
12341286 auto * rawIndices = indices->asMutable <vector_size_t >();
12351287 vector_size_t numPassed = 0 ;
@@ -1246,76 +1298,29 @@ RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) {
12461298
12471299 // If all matches for a given left-side row fail the filter, add a row to
12481300 // the output with nulls for the right-side columns.
1249- const auto onMiss = [&](auto row) {
1250- if (isAntiJoin (joinType_)) {
1251- return ;
1252- }
1253- rawIndices[numPassed++] = row;
1254-
1255- if (isFullJoin (joinType_)) {
1256- // For filtered rows, it is necessary to insert additional data
1257- // to ensure the result set is complete. Specifically, we
1258- // need to generate two records: one record containing the
1259- // columns from the left table along with nulls for the
1260- // right table, and another record containing the columns
1261- // from the right table along with nulls for the left table.
1262- // For instance, the current output is filtered based on the condition
1263- // t > 1.
1264-
1265- // 1, 1
1266- // 2, 2
1267- // 3, 3
1268-
1269- // In this scenario, we need to additionally insert a record 1, 1.
1270- // Subsequently, we will set the values of the columns on the left to
1271- // null and the values of the columns on the right to null as well. By
1272- // doing so, we will obtain the final result set.
1273-
1274- // 1, null
1275- // null, 1
1276- // 2, 2
1277- // 3, 3
1278- fullOuterOutput = BaseVector::create<RowVector>(
1279- output->type (), output->size () + 1 , pool ());
1280-
1281- for (auto i = 0 ; i < row + 1 ; ++i) {
1282- for (auto j = 0 ; j < output->type ()->size (); ++j) {
1283- fullOuterOutput->childAt (j)->copy (
1284- output->childAt (j).get (), i, i, 1 );
1301+ auto onMiss = [&](auto row, bool flag) {
1302+ if (!isLeftSemiFilterJoin (joinType_) &&
1303+ !isRightSemiFilterJoin (joinType_)) {
1304+ rawIndices[numPassed++] = row;
1305+
1306+ if (!isRightJoin (joinType_)) {
1307+ if (isFullJoin (joinType_) && flag) {
1308+ for (auto & projection : leftProjections_) {
1309+ auto target = output->childAt (projection.outputChannel );
1310+ target->setNull (row, true );
1311+ }
1312+ } else {
1313+ for (auto & projection : rightProjections_) {
1314+ auto target = output->childAt (projection.outputChannel );
1315+ target->setNull (row, true );
1316+ }
12851317 }
1286- }
1287-
1288- for (auto j = 0 ; j < output->type ()->size (); ++j) {
1289- fullOuterOutput->childAt (j)->copy (
1290- output->childAt (j).get (), row + 1 , row, 1 );
1291- }
1292-
1293- for (auto i = row + 1 ; i < output->size (); ++i) {
1294- for (auto j = 0 ; j < output->type ()->size (); ++j) {
1295- fullOuterOutput->childAt (j)->copy (
1296- output->childAt (j).get (), i + 1 , i, 1 );
1318+ } else {
1319+ for (auto & projection : leftProjections_) {
1320+ auto target = output->childAt (projection.outputChannel );
1321+ target->setNull (row, true );
12971322 }
12981323 }
1299-
1300- for (auto & projection : leftProjections_) {
1301- auto & target = fullOuterOutput->childAt (projection.outputChannel );
1302- target->setNull (row, true );
1303- }
1304-
1305- for (auto & projection : rightProjections_) {
1306- auto & target = fullOuterOutput->childAt (projection.outputChannel );
1307- target->setNull (row + 1 , true );
1308- }
1309- } else if (!isRightJoin (joinType_)) {
1310- for (auto & projection : rightProjections_) {
1311- auto & target = output->childAt (projection.outputChannel );
1312- target->setNull (row, true );
1313- }
1314- } else {
1315- for (auto & projection : leftProjections_) {
1316- auto & target = output->childAt (projection.outputChannel );
1317- target->setNull (row, true );
1318- }
13191324 }
13201325 };
13211326
@@ -1326,12 +1331,8 @@ RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) {
13261331
13271332 joinTracker_->processFilterResult (i, passed, onMiss);
13281333
1329- if (isAntiJoin (joinType_)) {
1330- if (!passed) {
1331- rawIndices[numPassed++] = i;
1332- }
1333- } else {
1334- if (passed) {
1334+ if (!isAntiJoin (joinType_)) {
1335+ if (passed && !joinTracker_->isRightJoinForFullOuter (i)) {
13351336 rawIndices[numPassed++] = i;
13361337 }
13371338 }
@@ -1344,26 +1345,30 @@ RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) {
13441345
13451346 // Every time we start a new left key match, `processFilterResult()` will
13461347 // check if at least one row from the previous match passed the filter. If
1347- // none did, it calls onMiss to add a record with null right projections to
1348- // the output.
1348+ // none did, it calls onMiss to add a record with null right projections
1349+ // to the output.
13491350 //
13501351 // Before we leave the current buffer, since we may not have seen the next
1351- // left key match yet, the last key match may still be pending to produce a
1352- // row (because `processFilterResult()` was not called yet).
1352+ // left key match yet, the last key match may still be pending to produce
1353+ // a row (because `processFilterResult()` was not called yet).
13531354 //
13541355 // To handle this, we need to call `noMoreFilterResults()` unless the
1355- // same current left key match may continue in the next buffer. So there are
1356- // two cases to check:
1356+ // same current left key match may continue in the next buffer. So there
1357+ // are two cases to check:
13571358 //
1358- // 1. If leftMatch_ is nullopt, there for sure the next buffer will contain
1359- // a different key match.
1359+ // 1. If leftMatch_ is nullopt, then for sure the next buffer will
1360+ // contain a different key match.
13601361 //
13611362 // 2. leftMatch_ may not be nullopt, but may be related to a different
13621363 // (subsequent) left key. So we check if the last row in the batch has the
13631364 // same left row number as the last key match.
13641365 if (!leftMatch_ || !joinTracker_->isCurrentLeftMatch (numRows - 1 )) {
13651366 joinTracker_->noMoreFilterResults (onMiss);
13661367 }
1368+
1369+ if (isAntiJoin (joinType_) && leftMatch_ && !previousLeftMatch_) {
1370+ joinTracker_->noMoreFilterResults (onMiss);
1371+ }
13671372 } else {
13681373 filterRows_.resize (numRows);
13691374 filterRows_.setAll ();
@@ -1385,17 +1390,10 @@ RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) {
13851390
13861391 if (numPassed == numRows) {
13871392 // All rows passed.
1388- if (fullOuterOutput) {
1389- return fullOuterOutput;
1390- }
13911393 return output;
13921394 }
13931395
13941396 // Some, but not all rows passed.
1395- if (fullOuterOutput) {
1396- return wrap (numPassed, indices, fullOuterOutput);
1397- }
1398-
13991397 return wrap (numPassed, indices, output);
14001398}
14011399
0 commit comments