Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 35 additions & 44 deletions velox/exec/MergeJoin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -312,12 +312,14 @@ inline void addNull(
target->setNull(index, true);
}

void MergeJoin::addOutputRowForLeftJoin(
const RowVectorPtr& leftBatch,
vector_size_t leftRow) {
bool MergeJoin::tryAddOutputRowForLeftJoin() {
VELOX_USER_CHECK(
isLeftJoin(joinType_) || isAntiJoin(joinType_) || isFullJoin(joinType_));
rawLeftOutputIndices_[outputSize_] = leftRow;
if (outputSize_ == outputBatchSize_) {
return false;
}

rawLeftOutputIndices_[outputSize_] = leftRowIndex_++;

for (const auto& projection : rightProjections_) {
auto& target = output_->childAt(projection.outputChannel);
Expand All @@ -335,13 +337,17 @@ void MergeJoin::addOutputRowForLeftJoin(
}

++outputSize_;

return true;
}

void MergeJoin::addOutputRowForRightJoin(
const RowVectorPtr& right,
vector_size_t rightIndex) {
bool MergeJoin::tryAddOutputRowForRightJoin() {
VELOX_USER_CHECK(isRightJoin(joinType_) || isFullJoin(joinType_));
rawRightOutputIndices_[outputSize_] = rightIndex;
if (outputSize_ == outputBatchSize_) {
return false;
}

rawRightOutputIndices_[outputSize_] = rightRowIndex_++;

for (const auto& projection : leftProjections_) {
auto& target = output_->childAt(projection.outputChannel);
Expand All @@ -359,6 +365,8 @@ void MergeJoin::addOutputRowForRightJoin(
}

++outputSize_;

return true;
}

void MergeJoin::flattenRightProjections() {
Expand All @@ -374,11 +382,15 @@ void MergeJoin::flattenRightProjections() {
isRightFlattened_ = true;
}

void MergeJoin::addOutputRow(
bool MergeJoin::tryAddOutputRow(
const RowVectorPtr& leftBatch,
vector_size_t leftRow,
const RowVectorPtr& rightBatch,
vector_size_t rightRow) {
if (outputSize_ == outputBatchSize_) {
return false;
}

// All left side projections share the same dictionary indices (leftIndices_).
rawLeftOutputIndices_[outputSize_] = leftRow;

Expand Down Expand Up @@ -427,6 +439,8 @@ void MergeJoin::addOutputRow(
}

++outputSize_;

return true;
}

bool MergeJoin::prepareOutput(
Expand Down Expand Up @@ -618,7 +632,7 @@ bool MergeJoin::addToOutputForLeftJoin() {
}

for (auto j = rightStartRow; j < rightEndRow; ++j) {
if (outputSize_ == outputBatchSize_) {
if (!tryAddOutputRow(leftBatch, i, rightBatch, j)) {
// If we run out of space in the current output_, we will need to
// produce a buffer and continue processing left later. In this
// case, we cannot leave left as a lazy vector, since we cannot have
Expand All @@ -628,7 +642,6 @@ bool MergeJoin::addToOutputForLeftJoin() {
rightMatch_->setCursor(r, j);
return true;
}
addOutputRow(leftBatch, i, rightBatch, j);
}
}
}
Expand Down Expand Up @@ -703,7 +716,7 @@ bool MergeJoin::addToOutputForRightJoin() {
}

for (auto j = leftStartRow; j < leftEndRow; ++j) {
if (outputSize_ == outputBatchSize_) {
if (!tryAddOutputRow(leftBatch, j, rightBatch, i)) {
// If we run out of space in the current output_, we will need to
// produce a buffer and continue processing left later. In this
// case, we cannot leave left as a lazy vector, since we cannot have
Expand All @@ -713,7 +726,6 @@ bool MergeJoin::addToOutputForRightJoin() {
leftMatch_->setCursor(l, j);
return true;
}
addOutputRow(leftBatch, j, rightBatch, i);
}
}
}
Expand Down Expand Up @@ -829,18 +841,10 @@ RowVectorPtr MergeJoin::getOutput() {
}

if (rightInput_) {
if (isFullJoin(joinType_)) {
if (isFullJoin(joinType_) || isRightJoin(joinType_)) {
rightRowIndex_ = 0;
} else {
const auto firstNonNullIndex =
firstNonNull(rightInput_, rightKeyChannels_);
if (isRightJoin(joinType_) && firstNonNullIndex > 0) {
prepareOutput(nullptr, rightInput_);
for (auto i = 0; i < firstNonNullIndex; ++i) {
addOutputRowForRightJoin(rightInput_, i);
}
}
rightRowIndex_ = firstNonNullIndex;
rightRowIndex_ = firstNonNull(rightInput_, rightKeyChannels_);
if (finishedRightBatch()) {
// Ran out of rows on the right side.
rightInput_ = nullptr;
Expand Down Expand Up @@ -907,7 +911,7 @@ RowVectorPtr MergeJoin::doGetOutput() {
VELOX_CHECK(rightMatch_->complete);

if (rightMatch_->inputs.back() == rightInput_) {
if (isFullJoin(joinType_)) {
if (isFullJoin(joinType_) || isRightJoin(joinType_)) {
rightRowIndex_ = rightMatch_->endRowIndex;
} else {
rightRowIndex_ = firstNonNull(
Expand Down Expand Up @@ -946,11 +950,9 @@ RowVectorPtr MergeJoin::doGetOutput() {
return std::move(output_);
}
while (true) {
if (outputSize_ == outputBatchSize_) {
if (!tryAddOutputRowForLeftJoin()) {
return std::move(output_);
}
addOutputRowForLeftJoin(input_, leftRowIndex_);
++leftRowIndex_;

if (finishedLeftBatch()) {
input_ = nullptr;
Expand All @@ -974,12 +976,10 @@ RowVectorPtr MergeJoin::doGetOutput() {
}

while (true) {
if (outputSize_ == outputBatchSize_) {
if (!tryAddOutputRowForRightJoin()) {
return std::move(output_);
}
addOutputRowForRightJoin(rightInput_, rightRowIndex_);

++rightRowIndex_;
if (finishedRightBatch()) {
// Ran out of rows on the right side.
rightInput_ = nullptr;
Expand All @@ -1003,11 +1003,9 @@ RowVectorPtr MergeJoin::doGetOutput() {
}

while (true) {
if (outputSize_ == outputBatchSize_) {
if (!tryAddOutputRowForLeftJoin()) {
return std::move(output_);
}
addOutputRowForLeftJoin(input_, leftRowIndex_);
++leftRowIndex_;

if (finishedLeftBatch()) {
input_ = nullptr;
Expand All @@ -1031,13 +1029,10 @@ RowVectorPtr MergeJoin::doGetOutput() {
}

while (true) {
if (outputSize_ == outputBatchSize_) {
if (!tryAddOutputRowForRightJoin()) {
return std::move(output_);
}

addOutputRowForRightJoin(rightInput_, rightRowIndex_);

++rightRowIndex_;
if (finishedRightBatch()) {
// Ran out of rows on the right side.
rightInput_ = nullptr;
Expand Down Expand Up @@ -1080,11 +1075,9 @@ RowVectorPtr MergeJoin::doGetOutput() {
return std::move(output_);
}

if (outputSize_ == outputBatchSize_) {
if (!tryAddOutputRowForLeftJoin()) {
return std::move(output_);
}
addOutputRowForLeftJoin(input_, leftRowIndex_);
++leftRowIndex_;
} else {
leftRowIndex_ =
firstNonNull(input_, leftKeyChannels_, leftRowIndex_ + 1);
Expand All @@ -1107,11 +1100,9 @@ RowVectorPtr MergeJoin::doGetOutput() {
return std::move(output_);
}

if (outputSize_ == outputBatchSize_) {
if (!tryAddOutputRowForRightJoin()) {
return std::move(output_);
}
addOutputRowForRightJoin(rightInput_, rightRowIndex_);
++rightRowIndex_;
} else {
rightRowIndex_ =
firstNonNull(rightInput_, rightKeyChannels_, rightRowIndex_ + 1);
Expand Down Expand Up @@ -1170,7 +1161,7 @@ RowVectorPtr MergeJoin::doGetOutput() {
}

leftRowIndex_ = leftEndRow;
if (isFullJoin(joinType_)) {
if (isFullJoin(joinType_) || isRightJoin(joinType_)) {
rightRowIndex_ = endRightRow;
} else {
rightRowIndex_ =
Expand Down
35 changes: 20 additions & 15 deletions velox/exec/MergeJoin.h
Original file line number Diff line number Diff line change
Expand Up @@ -224,13 +224,14 @@ class MergeJoin : public Operator {
// right.
bool addToOutputForRightJoin();

// Adds one row of output by writing to the indices of the output
// Tries to add one row of output by writing to the indices of the output
// dictionaries. By default, this operator returns dictionaries wrapped around
// the input columns from the left and right. If `isRightFlattened_`, the
// right side projections are copied to the output.
//
// Advances outputSize_. Assumes that dictionary indices in output_ have room.
void addOutputRow(
// If there is space in the output, advances outputSize_ and returns true.
// Otherwise returns false and outputSize_ is unchanged.
bool tryAddOutputRow(
const RowVectorPtr& leftBatch,
vector_size_t leftRow,
const RowVectorPtr& rightBatch,
Expand All @@ -244,19 +245,23 @@ class MergeJoin : public Operator {
// logic is more involved.
void flattenRightProjections();

// Adds one row of output for a left-side row with no right-side match.
// Copies values from the 'leftIndex' row of 'left' and fills in nulls
// Tries to add one row of output for a left-side row with no right-side
// match. Copies values from the 'leftIndex' row of 'left' and fills in nulls
// for columns that correspond to the right side.
void addOutputRowForLeftJoin(
const RowVectorPtr& leftBatch,
vector_size_t leftRow);

// Adds one row of output for a right-side row with no left-side match.
// Copies values from the 'rightIndex' row of 'right' and fills in nulls
// for columns that correspond to the right side.
void addOutputRowForRightJoin(
const RowVectorPtr& right,
vector_size_t rightIndex);
//
// If there is space in the output, advances outputSize_ and leftRowIndex_,
// and returns true. Otherwise returns false and outputSize_ and leftRowIndex_
// are unchanged.
bool tryAddOutputRowForLeftJoin();

// Tries to add one row of output for a right-side row with no left-side
// match. Copies values from the 'rightIndex' row of 'right' and fills in
// nulls for columns that correspond to the right side.
//
// If there is space in the output, advances outputSize_ and rightRowIndex_,
// and returns true. Otherwise returns false and outputSize_ and
// rightRowIndex_ are unchanged.
bool tryAddOutputRowForRightJoin();

// If all rows from the current left batch have been processed.
bool finishedLeftBatch() const {
Expand Down
69 changes: 58 additions & 11 deletions velox/exec/tests/MergeJoinTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -130,23 +130,30 @@ class MergeJoinTest : public HiveConnectorTestBase {
template <typename T>
void testJoin(
std::function<T(vector_size_t /*row*/)> leftKeyAt,
std::function<T(vector_size_t /*row*/)> rightKeyAt) {
std::function<T(vector_size_t /*row*/)> rightKeyAt,
std::function<bool(vector_size_t /*row*/)> leftNullAt = nullptr,
std::function<bool(vector_size_t /*row*/)> rightNullAt = nullptr) {
// Single batch on the left and right sides of the join.
{
auto leftKeys = makeFlatVector<T>(1'234, leftKeyAt);
auto rightKeys = makeFlatVector<T>(1'234, rightKeyAt);
auto leftKeys = makeFlatVector<T>(1'234, leftKeyAt, leftNullAt);
auto rightKeys = makeFlatVector<T>(1'234, rightKeyAt, rightNullAt);

testJoin({leftKeys}, {rightKeys});
}

// Multiple batches on one side. Single batch on the other side.
{
std::vector<VectorPtr> leftKeys = {
makeFlatVector<T>(1024, leftKeyAt),
makeFlatVector<T>(1024, leftKeyAt, leftNullAt),
makeFlatVector<T>(
1024, [&](auto row) { return leftKeyAt(1024 + row); }),
1024,
[&](auto row) { return leftKeyAt(1024 + row); },
[&](auto row) {
return leftNullAt ? leftNullAt(1024 + row) : false;
}),
};
std::vector<VectorPtr> rightKeys = {makeFlatVector<T>(2048, rightKeyAt)};
std::vector<VectorPtr> rightKeys = {
makeFlatVector<T>(2048, rightKeyAt, rightNullAt)};

testJoin(leftKeys, rightKeys);

Expand All @@ -157,18 +164,34 @@ class MergeJoinTest : public HiveConnectorTestBase {
// Multiple batches on each side.
{
std::vector<VectorPtr> leftKeys = {
makeFlatVector<T>(512, leftKeyAt),
makeFlatVector<T>(512, leftKeyAt, leftNullAt),
makeFlatVector<T>(
1024, [&](auto row) { return leftKeyAt(512 + row); }),
1024,
[&](auto row) { return leftKeyAt(512 + row); },
[&](auto row) {
return leftNullAt ? leftNullAt(512 + row) : false;
}),
makeFlatVector<T>(
16, [&](auto row) { return leftKeyAt(512 + 1024 + row); }),
16,
[&](auto row) { return leftKeyAt(512 + 1024 + row); },
[&](auto row) {
return leftNullAt ? leftNullAt(512 + 1024 + row) : false;
}),
};
std::vector<VectorPtr> rightKeys = {
makeFlatVector<T>(123, rightKeyAt),
makeFlatVector<T>(
1024, [&](auto row) { return rightKeyAt(123 + row); }),
1024,
[&](auto row) { return rightKeyAt(123 + row); },
[&](auto row) {
return rightNullAt ? rightNullAt(123 + row) : false;
}),
makeFlatVector<T>(
1234, [&](auto row) { return rightKeyAt(123 + 1024 + row); }),
1234,
[&](auto row) { return rightKeyAt(123 + 1024 + row); },
[&](auto row) {
return rightNullAt ? rightNullAt(123 + 1024 + row) : false;
}),
};

testJoin(leftKeys, rightKeys);
Expand Down Expand Up @@ -365,6 +388,30 @@ TEST_F(MergeJoinTest, duplicateMatch) {
[](auto row) { return row / 2; }, [](auto row) { return row / 3; });
}

TEST_F(MergeJoinTest, someNulls) {
testJoin<int32_t>(
[](auto row) { return row; },
[](auto row) { return row; },
[](auto row) { return row > 7; },
[](auto row) { return false; });
}

TEST_F(MergeJoinTest, someNullsOtherSideFinishesEarly) {
testJoin<int32_t>(
[](auto row) { return row; },
[](auto row) { return std::min(row, 7); },
[](auto row) { return row > 7; },
[](auto row) { return false; });
}

TEST_F(MergeJoinTest, someNullsOnBothSides) {
testJoin<int32_t>(
[](auto row) { return row; },
[](auto row) { return row; },
[](auto row) { return row > 7; },
[](auto row) { return row > 8; });
}

TEST_F(MergeJoinTest, allRowsMatch) {
std::vector<VectorPtr> leftKeys = {
makeFlatVector<int32_t>(2, [](auto /* row */) { return 5; }),
Expand Down
Loading