Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 50 additions & 38 deletions velox/exec/MergeJoin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,25 @@

namespace facebook::velox::exec {

namespace {
void copyRow(
const RowVectorPtr& source,
vector_size_t sourceIndex,
const RowVectorPtr& target,
vector_size_t targetIndex,
const std::vector<IdentityProjection>& projections) {
for (const auto& projection : projections) {
const auto& sourceChild = source->childAt(projection.inputChannel);
const auto& targetChild = target->childAt(projection.outputChannel);
targetChild->copy(sourceChild.get(), targetIndex, sourceIndex, 1);
}
}

bool isSemiFilterJoin(core::JoinType joinType) {
return isLeftSemiFilterJoin(joinType) || isRightSemiFilterJoin(joinType);
}
} // namespace

MergeJoin::MergeJoin(
int32_t operatorId,
DriverCtx* driverCtx,
Expand Down Expand Up @@ -90,7 +109,8 @@ void MergeJoin::initialize() {
initializeFilter(joinNode_->filter(), leftType, rightType);

if (joinNode_->isLeftJoin() || joinNode_->isAntiJoin() ||
joinNode_->isRightJoin() || joinNode_->isFullJoin()) {
joinNode_->isRightJoin() || joinNode_->isFullJoin() ||
isSemiFilterJoin(joinType_)) {
joinTracker_ = JoinTracker(outputBatchSize_, pool());
}
} else if (joinNode_->isAntiJoin()) {
Expand Down Expand Up @@ -274,21 +294,6 @@ bool MergeJoin::findEndOfMatch(
return true;
}

namespace {
void copyRow(
const RowVectorPtr& source,
vector_size_t sourceIndex,
const RowVectorPtr& target,
vector_size_t targetIndex,
const std::vector<IdentityProjection>& projections) {
for (const auto& projection : projections) {
const auto& sourceChild = source->childAt(projection.inputChannel);
const auto& targetChild = target->childAt(projection.outputChannel);
targetChild->copy(sourceChild.get(), targetIndex, sourceIndex, 1);
}
}
} // namespace

inline void addNull(
VectorPtr& target,
vector_size_t index,
Expand Down Expand Up @@ -421,7 +426,7 @@ bool MergeJoin::tryAddOutputRow(
filterRightInputProjections_);

if (joinTracker_) {
if (isRightJoin(joinType_)) {
if (isRightJoin(joinType_) || isRightSemiFilterJoin(joinType_)) {
// Record right-side row with a match on the left-side.
joinTracker_->addMatch(rightBatch, rightRow, outputSize_);
} else {
Expand Down Expand Up @@ -613,16 +618,17 @@ bool MergeJoin::addToOutputForLeftJoin() {
// one match on the other side, we could explore specialized algorithms
// or data structures that short-circuit the join process once a match
// is found.
for (size_t r = isLeftSemiFilterJoin(joinType_) ? numRightBatches - 1
: firstRightBatch;
for (size_t r = (isLeftSemiFilterJoin(joinType_) && !filter_)
? numRightBatches - 1
: firstRightBatch;
r < numRightBatches;
++r) {
const auto rightBatch = rightMatch_->inputs[r];
auto rightStartRow = r == firstRightBatch ? rightStartRowIndex : 0;
const auto rightEndRow = r == numRightBatches - 1
? rightMatch_->endRowIndex
: rightBatch->size();
if (isLeftSemiFilterJoin(joinType_)) {
if (isLeftSemiFilterJoin(joinType_) && !filter_) {
rightStartRow = rightEndRow - 1;
}
if (prepareOutput(leftBatch, rightBatch)) {
Expand Down Expand Up @@ -693,16 +699,17 @@ bool MergeJoin::addToOutputForRightJoin() {
// one match on the other side, we could explore specialized algorithms
// or data structures that short-circuit the join process once a match
// is found.
for (size_t l = isRightSemiFilterJoin(joinType_) ? numLeftBatches - 1
: firstLeftBatch;
for (size_t l = (isRightSemiFilterJoin(joinType_) && !filter_)
? numLeftBatches - 1
: firstLeftBatch;
l < numLeftBatches;
++l) {
const auto leftBatch = leftMatch_->inputs[l];
auto leftStartRow = l == firstLeftBatch ? leftStartRowIndex : 0;
const auto leftEndRow = l == numLeftBatches - 1
? leftMatch_->endRowIndex
: leftBatch->size();
if (isRightSemiFilterJoin(joinType_)) {
if (isRightSemiFilterJoin(joinType_) && !filter_) {
// RightSemiFilter produce each row from the right at most once.
leftStartRow = leftEndRow - 1;
}
Expand Down Expand Up @@ -818,7 +825,7 @@ RowVectorPtr MergeJoin::getOutput() {
continue;
} else if (isAntiJoin(joinType_)) {
output = filterOutputForAntiJoin(output);
if (output) {
if (output != nullptr && output->size() > 0) {
return output;
}

Expand Down Expand Up @@ -1274,7 +1281,7 @@ RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) {
// If all matches for a given left-side row fail the filter, add a row to
// the output with nulls for the right-side columns.
const auto onMiss = [&](auto row) {
if (isAntiJoin(joinType_)) {
if (isAntiJoin(joinType_) || isSemiFilterJoin(joinType_)) {
return;
}
rawIndices[numPassed++] = row;
Expand Down Expand Up @@ -1346,21 +1353,26 @@ RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) {
}
};

auto onMatch = [&](auto row, bool firstMatch) {
const bool isNonSemiAntiJoin =
!isSemiFilterJoin(joinType_) && !isAntiJoin(joinType_);

if ((isSemiFilterJoin(joinType_) && firstMatch) || isNonSemiAntiJoin) {
rawIndices[numPassed++] = row;
}
};

for (auto i = 0; i < numRows; ++i) {
if (filterRows.isValid(i)) {
const bool passed = !decodedFilterResult_.isNullAt(i) &&
decodedFilterResult_.valueAt<bool>(i);

joinTracker_->processFilterResult(i, passed, onMiss);
joinTracker_->processFilterResult(i, passed, onMiss, onMatch);

if (isAntiJoin(joinType_)) {
if (!passed) {
rawIndices[numPassed++] = i;
}
} else {
if (passed) {
rawIndices[numPassed++] = i;
}
}
} else {
// This row doesn't have a match on the right side. Keep it
Expand All @@ -1371,19 +1383,19 @@ RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) {

// Every time we start a new left key match, `processFilterResult()` will
// check if at least one row from the previous match passed the filter. If
// none did, it calls onMiss to add a record with null right projections to
// the output.
// none did, it calls onMiss to add a record with null right projections
// to the output.
//
// Before we leave the current buffer, since we may not have seen the next
// left key match yet, the last key match may still be pending to produce a
// row (because `processFilterResult()` was not called yet).
// left key match yet, the last key match may still be pending to produce
// a row (because `processFilterResult()` was not called yet).
//
// To handle this, we need to call `noMoreFilterResults()` unless the
// same current left key match may continue in the next buffer. So there are
// two cases to check:
// same current left key match may continue in the next buffer. So there
// are two cases to check:
//
// 1. If leftMatch_ is nullopt, there for sure the next buffer will contain
// a different key match.
// 1. If leftMatch_ is nullopt, there for sure the next buffer will
// contain a different key match.
//
// 2. leftMatch_ may not be nullopt, but may be related to a different
// (subsequent) left key. So we check if the last row in the batch has the
Expand Down
6 changes: 4 additions & 2 deletions velox/exec/MergeJoin.h
Original file line number Diff line number Diff line change
Expand Up @@ -394,11 +394,12 @@ class MergeJoin : public Operator {
// rows that correspond to a single left-side row. Use
// 'noMoreFilterResults' to make sure 'onMiss' is called for the last
// left-side row.
template <typename TOnMiss>
template <typename TOnMiss, typename TOnMatch>
void processFilterResult(
vector_size_t outputIndex,
bool passed,
TOnMiss onMiss) {
const TOnMiss& onMiss,
const TOnMatch& onMatch) {
const auto rowNumber = rawLeftRowNumbers_[outputIndex];
if (currentLeftRowNumber_ != rowNumber) {
if (currentRow_ != -1 && !currentRowPassed_) {
Expand All @@ -412,6 +413,7 @@ class MergeJoin : public Operator {
}

if (passed) {
onMatch(outputIndex, /*firstMatch=*/!currentRowPassed_);
currentRowPassed_ = true;
}
}
Expand Down
104 changes: 104 additions & 0 deletions velox/exec/tests/MergeJoinTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1016,6 +1016,110 @@ TEST_F(MergeJoinTest, semiJoinWithMultipleMatchVectors) {
core::JoinType::kLeftSemiFilter);
}

TEST_F(MergeJoinTest, semiJoinWithMultiMatchedRowsWithFilter) {
auto left = makeRowVector(
{"t0", "t1"},
{makeNullableFlatVector<int64_t>({2, 2, 2, 2, 2}),
makeNullableFlatVector<int64_t>({3, 2, 3, 2, 2})});

auto right = makeRowVector(
{"u0", "u1"},
{makeNullableFlatVector<int64_t>({2, 2, 2, 2, 2, 2}),
makeNullableFlatVector<int64_t>({2, 2, 2, 2, 2, 4})});

createDuckDbTable("t", {left});
createDuckDbTable("u", {right});

auto testSemiJoin = [&](const std::string& filter,
const std::string& sql,
const std::vector<std::string>& outputLayout,
core::JoinType joinType) {
auto planNodeIdGenerator = std::make_shared<core::PlanNodeIdGenerator>();
auto plan = PlanBuilder(planNodeIdGenerator)
.values(split(left, 2))
.mergeJoin(
{"t0"},
{"u0"},
PlanBuilder(planNodeIdGenerator)
.values(split(right, 2))
.planNode(),
filter,
outputLayout,
joinType)
.planNode();
AssertQueryBuilder(plan, duckDbQueryRunner_)
.config(core::QueryConfig::kPreferredOutputBatchRows, "2")
.config(core::QueryConfig::kMaxOutputBatchRows, "2")
.assertResults(sql);
};

// Left Semi join With filter
testSemiJoin(
"t1 > u1",
"SELECT t0, t1 FROM t where t0 IN (SELECT u0 from u where t1 > u1)",
{"t0", "t1"},
core::JoinType::kLeftSemiFilter);

// Right Semi join With filter
testSemiJoin(
"u1 > t1",
"SELECT u0, u1 FROM u where u0 IN (SELECT t0 from t where u1 > t1)",
{"u0", "u1"},
core::JoinType::kRightSemiFilter);
}

TEST_F(MergeJoinTest, semiJoinWithOneMatchedRowWithFilter) {
auto left = makeRowVector(
{"t0", "t1"},
{makeNullableFlatVector<int64_t>({2, 2}),
makeNullableFlatVector<int64_t>({3, 5})});

auto right = makeRowVector(
{"u0", "u1"},
{makeNullableFlatVector<int64_t>({2, 2}),
makeNullableFlatVector<int64_t>({1, 4})});

createDuckDbTable("t", {left});
createDuckDbTable("u", {right});

auto testSemiJoin = [&](const std::string& filter,
const std::string& sql,
const std::vector<std::string>& outputLayout,
core::JoinType joinType) {
auto planNodeIdGenerator = std::make_shared<core::PlanNodeIdGenerator>();
auto plan = PlanBuilder(planNodeIdGenerator)
.values(split(left, 2))
.mergeJoin(
{"t0"},
{"u0"},
PlanBuilder(planNodeIdGenerator)
.values(split(right, 2))
.planNode(),
filter,
outputLayout,
joinType)
.planNode();
AssertQueryBuilder(plan, duckDbQueryRunner_)
.config(core::QueryConfig::kPreferredOutputBatchRows, "2")
.config(core::QueryConfig::kMaxOutputBatchRows, "2")
.assertResults(sql);
};

// Left Semi join With filter
testSemiJoin(
"t1 > u1",
"SELECT t0, t1 FROM t where t0 IN (SELECT u0 from u where t1 > u1)",
{"t0", "t1"},
core::JoinType::kLeftSemiFilter);

// Right Semi join With filter
testSemiJoin(
"u1 > t1",
"SELECT u0, u1 FROM u where u0 IN (SELECT t0 from t where u1 > t1)",
{"u0", "u1"},
core::JoinType::kRightSemiFilter);
}

TEST_F(MergeJoinTest, rightJoin) {
auto left = makeRowVector(
{"t0"},
Expand Down
Loading