Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
183 changes: 91 additions & 92 deletions velox/exec/MergeJoin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,9 +113,10 @@ void MergeJoin::initialize() {
isSemiFilterJoin(joinType_)) {
joinTracker_ = JoinTracker(outputBatchSize_, pool());
}
} else if (joinNode_->isAntiJoin()) {
} else if (joinNode_->isAntiJoin() || joinNode_->isFullJoin()) {
// Anti join needs to track the left side rows that have no match on the
// right.
// right. Full outer join needs to track the right side rows that have no
// match on the left.
joinTracker_ = JoinTracker(outputBatchSize_, pool());
}

Expand Down Expand Up @@ -383,7 +384,8 @@ bool MergeJoin::tryAddOutputRow(
const RowVectorPtr& leftBatch,
vector_size_t leftRow,
const RowVectorPtr& rightBatch,
vector_size_t rightRow) {
vector_size_t rightRow,
bool isRightJoinForFullOuter) {
if (outputSize_ == outputBatchSize_) {
return false;
}
Expand Down Expand Up @@ -417,12 +419,15 @@ bool MergeJoin::tryAddOutputRow(
filterRightInputProjections_);

if (joinTracker_) {
if (isRightJoin(joinType_) || isRightSemiFilterJoin(joinType_)) {
if (isRightJoin(joinType_) || isRightSemiFilterJoin(joinType_) ||
(isFullJoin(joinType_) && isRightJoinForFullOuter)) {
// Record right-side row with a match on the left-side.
joinTracker_->addMatch(rightBatch, rightRow, outputSize_);
joinTracker_->addMatch(
rightBatch, rightRow, outputSize_, isRightJoinForFullOuter);
} else {
// Record left-side row with a match on the right-side.
joinTracker_->addMatch(leftBatch, leftRow, outputSize_);
joinTracker_->addMatch(
leftBatch, leftRow, outputSize_, isRightJoinForFullOuter);
}
}
}
Expand All @@ -432,7 +437,8 @@ bool MergeJoin::tryAddOutputRow(
if (isAntiJoin(joinType_)) {
VELOX_CHECK(joinTracker_.has_value());
// Record left-side row with a match on the right-side.
joinTracker_->addMatch(leftBatch, leftRow, outputSize_);
joinTracker_->addMatch(
leftBatch, leftRow, outputSize_, isRightJoinForFullOuter);
}

++outputSize_;
Expand All @@ -450,14 +456,15 @@ bool MergeJoin::prepareOutput(
return true;
}

if (isRightJoin(joinType_) && right != currentRight_) {
return true;
}

// If there is a new right, we need to flatten the dictionary.
if (!isRightFlattened_ && right && currentRight_ != right) {
flattenRightProjections();
}

if (right != currentRight_) {
return true;
}

return false;
}

Expand All @@ -480,11 +487,15 @@ bool MergeJoin::prepareOutput(
}
} else {
for (const auto& projection : leftProjections_) {
auto column = left->childAt(projection.inputChannel);
// Flatten the left column if the column already is DictionaryVector.
if (column->wrappedVector()->encoding() ==
VectorEncoding::Simple::DICTIONARY) {
BaseVector::flattenVector(column);
}
column->clearContainingLazyAndWrapped();
localColumns[projection.outputChannel] = BaseVector::wrapInDictionary(
{},
leftOutputIndices_,
outputBatchSize_,
left->childAt(projection.inputChannel));
{}, leftOutputIndices_, outputBatchSize_, column);
}
}
currentLeft_ = left;
Expand All @@ -500,11 +511,10 @@ bool MergeJoin::prepareOutput(
isRightFlattened_ = true;
} else {
for (const auto& projection : rightProjections_) {
auto column = right->childAt(projection.inputChannel);
column->clearContainingLazyAndWrapped();
localColumns[projection.outputChannel] = BaseVector::wrapInDictionary(
{},
rightOutputIndices_,
outputBatchSize_,
right->childAt(projection.inputChannel));
{}, rightOutputIndices_, outputBatchSize_, column);
}
isRightFlattened_ = false;
}
Expand Down Expand Up @@ -568,6 +578,39 @@ bool MergeJoin::prepareOutput(
bool MergeJoin::addToOutput() {
if (isRightJoin(joinType_) || isRightSemiFilterJoin(joinType_)) {
return addToOutputForRightJoin();
} else if (isFullJoin(joinType_) && filter_) {
if (!leftForRightJoinMatch_) {
leftForRightJoinMatch_ = leftMatch_;
rightForRightJoinMatch_ = rightMatch_;
}

if (leftMatch_ && rightMatch_ && !leftJoinForFullFinished_) {
auto left = addToOutputForLeftJoin();
if (!leftMatch_) {
leftJoinForFullFinished_ = true;
}
if (left) {
if (!leftMatch_) {
leftMatch_ = leftForRightJoinMatch_;
rightMatch_ = rightForRightJoinMatch_;
}

return true;
}
}

if (!leftMatch_ && !rightJoinForFullFinished_) {
leftMatch_ = leftForRightJoinMatch_;
rightMatch_ = rightForRightJoinMatch_;
rightJoinForFullFinished_ = true;
}

auto right = addToOutputForRightJoin();

leftForRightJoinMatch_ = leftMatch_;
rightForRightJoinMatch_ = rightMatch_;

return right;
} else {
return addToOutputForLeftJoin();
}
Expand Down Expand Up @@ -660,7 +703,13 @@ bool MergeJoin::addToOutputImpl() {
} else {
for (auto innerRow = innerStartRow; innerRow < innerEndRow;
++innerRow) {
if (!tryAddOutputRow(leftBatch, innerRow, rightBatch, outerRow)) {
const auto isRightJoinForFullOuter = isFullJoin(joinType_);
if (!tryAddOutputRow(
leftBatch,
innerRow,
rightBatch,
outerRow,
isRightJoinForFullOuter)) {
outerMatch->setCursor(outerBatchIndex, outerRow);
innerMatch->setCursor(innerBatchIndex, innerRow);
return true;
Expand Down Expand Up @@ -931,7 +980,7 @@ RowVectorPtr MergeJoin::doGetOutput() {
isFullJoin(joinType_)) {
// If output_ is currently wrapping a different buffer, return it
// first.
if (prepareOutput(input_, nullptr)) {
if (prepareOutput(input_, rightInput_)) {
output_->resize(outputSize_);
return std::move(output_);
}
Expand All @@ -956,7 +1005,7 @@ RowVectorPtr MergeJoin::doGetOutput() {
if (isRightJoin(joinType_) || isFullJoin(joinType_)) {
// If output_ is currently wrapping a different buffer, return it
// first.
if (prepareOutput(nullptr, rightInput_)) {
if (prepareOutput(input_, rightInput_)) {
output_->resize(outputSize_);
return std::move(output_);
}
Expand Down Expand Up @@ -1003,6 +1052,8 @@ RowVectorPtr MergeJoin::doGetOutput() {
endRightRow < rightInput_->size(),
std::nullopt};

leftJoinForFullFinished_ = false;
rightJoinForFullFinished_ = false;
if (!leftMatch_->complete || !rightMatch_->complete) {
if (!leftMatch_->complete) {
// Need to continue looking for the end of match.
Expand Down Expand Up @@ -1264,8 +1315,6 @@ void MergeJoin::clearRightInput() {
RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) {
const auto numRows = output->size();

RowVectorPtr fullOuterOutput = nullptr;

BufferPtr indices = allocateIndices(numRows, pool());
auto* rawIndices = indices->asMutable<vector_size_t>();
vector_size_t numPassed = 0;
Expand All @@ -1282,84 +1331,41 @@ RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) {

// If all matches for a given left-side row fail the filter, add a row to
// the output with nulls for the right-side columns.
const auto onMiss = [&](auto row) {
const auto onMiss = [&](auto row, bool isRightJoinForFullOuter) {
if (isSemiFilterJoin(joinType_)) {
return;
}
rawIndices[numPassed++] = row;

if (isFullJoin(joinType_)) {
// For filtered rows, it is necessary to insert additional data
// to ensure the result set is complete. Specifically, we
// need to generate two records: one record containing the
// columns from the left table along with nulls for the
// right table, and another record containing the columns
// from the right table along with nulls for the left table.
// For instance, the current output is filtered based on the condition
// t > 1.

// 1, 1
// 2, 2
// 3, 3

// In this scenario, we need to additionally insert a record 1, 1.
// Subsequently, we will set the values of the columns on the left to
// null and the values of the columns on the right to null as well. By
// doing so, we will obtain the final result set.

// 1, null
// null, 1
// 2, 2
// 3, 3
fullOuterOutput = BaseVector::create<RowVector>(
output->type(), output->size() + 1, pool());

for (auto i = 0; i < row + 1; ++i) {
for (auto j = 0; j < output->type()->size(); ++j) {
fullOuterOutput->childAt(j)->copy(
output->childAt(j).get(), i, i, 1);
if (!isRightJoin(joinType_)) {
if (isFullJoin(joinType_) && isRightJoinForFullOuter) {
for (auto& projection : leftProjections_) {
auto target = output->childAt(projection.outputChannel);
target->setNull(row, true);
}
}

for (auto j = 0; j < output->type()->size(); ++j) {
fullOuterOutput->childAt(j)->copy(
output->childAt(j).get(), row + 1, row, 1);
}

for (auto i = row + 1; i < output->size(); ++i) {
for (auto j = 0; j < output->type()->size(); ++j) {
fullOuterOutput->childAt(j)->copy(
output->childAt(j).get(), i + 1, i, 1);
} else {
for (auto& projection : rightProjections_) {
auto target = output->childAt(projection.outputChannel);
target->setNull(row, true);
}
}

for (auto& projection : leftProjections_) {
auto& target = fullOuterOutput->childAt(projection.outputChannel);
target->setNull(row, true);
}

for (auto& projection : rightProjections_) {
auto& target = fullOuterOutput->childAt(projection.outputChannel);
target->setNull(row + 1, true);
}
} else if (!isRightJoin(joinType_)) {
for (auto& projection : rightProjections_) {
auto& target = output->childAt(projection.outputChannel);
target->setNull(row, true);
}
} else {
for (auto& projection : leftProjections_) {
auto& target = output->childAt(projection.outputChannel);
auto target = output->childAt(projection.outputChannel);
target->setNull(row, true);
}
}
};

auto onMatch = [&](auto row, bool firstMatch) {
const bool isNonSemiAntiJoin =
!isSemiFilterJoin(joinType_) && !isAntiJoin(joinType_);
const bool isFullLeftJoin =
isFullJoin(joinType_) && !joinTracker_->isRightJoinForFullOuter(row);

const bool isNonSemiAntiFullJoin = !isSemiFilterJoin(joinType_) &&
!isAntiJoin(joinType_) && !isFullJoin(joinType_);

if ((isSemiFilterJoin(joinType_) && firstMatch) || isNonSemiAntiJoin) {
if ((isSemiFilterJoin(joinType_) && firstMatch) ||
isNonSemiAntiFullJoin || isFullLeftJoin) {
rawIndices[numPassed++] = row;
}
};
Expand Down Expand Up @@ -1420,17 +1426,10 @@ RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) {

if (numPassed == numRows) {
// All rows passed.
if (fullOuterOutput) {
return fullOuterOutput;
}
return output;
}

// Some, but not all rows passed.
if (fullOuterOutput) {
return wrap(numPassed, indices, fullOuterOutput);
}

return wrap(numPassed, indices, output);
}

Expand Down
Loading
Loading