Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 31 additions & 28 deletions velox/exec/MergeJoin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -608,29 +608,29 @@ bool MergeJoin::addToOutputForLeftJoin() {
: rightMatch_->startRowIndex;

const auto numRightBatches = rightMatch_->inputs.size();
for (size_t r = firstRightBatch; r < numRightBatches; ++r) {
// TODO: Since semi joins only require determining if there is at least
// one match on the other side, we could explore specialized algorithms
// or data structures that short-circuit the join process once a match
// is found.
for (size_t r = isLeftSemiFilterJoin(joinType_) ? numRightBatches - 1
: firstRightBatch;
r < numRightBatches;
++r) {
const auto rightBatch = rightMatch_->inputs[r];
const auto rightStartRow =
r == firstRightBatch ? rightStartRowIndex : 0;
auto rightEndRow = r == numRightBatches - 1 ? rightMatch_->endRowIndex
: rightBatch->size();

auto rightStartRow = r == firstRightBatch ? rightStartRowIndex : 0;
const auto rightEndRow = r == numRightBatches - 1
? rightMatch_->endRowIndex
: rightBatch->size();
if (isLeftSemiFilterJoin(joinType_)) {
rightStartRow = rightEndRow - 1;
}
if (prepareOutput(leftBatch, rightBatch)) {
output_->resize(outputSize_);
leftMatch_->setCursor(l, i);
rightMatch_->setCursor(r, rightStartRow);
return true;
}

// TODO: Since semi joins only require determining if there is at least
// one match on the other side, we could explore specialized algorithms
// or data structures that short-circuit the join process once a match
// is found.
if (isLeftSemiFilterJoin(joinType_)) {
// LeftSemiFilter produces each row from the left at most once.
rightEndRow = rightStartRow + 1;
}

for (auto j = rightStartRow; j < rightEndRow; ++j) {
if (!tryAddOutputRow(leftBatch, i, rightBatch, j)) {
// If we run out of space in the current output_, we will need to
Expand Down Expand Up @@ -688,11 +688,23 @@ bool MergeJoin::addToOutputForRightJoin() {
: leftMatch_->startRowIndex;

const auto numLeftBatches = leftMatch_->inputs.size();
for (size_t l = firstLeftBatch; l < numLeftBatches; ++l) {
// TODO: Since semi joins only require determining if there is at least
// one match on the other side, we could explore specialized algorithms
// or data structures that short-circuit the join process once a match
// is found.
for (size_t l = isRightSemiFilterJoin(joinType_) ? numLeftBatches - 1
: firstLeftBatch;
l < numLeftBatches;
++l) {
const auto leftBatch = leftMatch_->inputs[l];
const auto leftStartRow = l == firstLeftBatch ? leftStartRowIndex : 0;
auto leftEndRow = l == numLeftBatches - 1 ? leftMatch_->endRowIndex
: leftBatch->size();
auto leftStartRow = l == firstLeftBatch ? leftStartRowIndex : 0;
const auto leftEndRow = l == numLeftBatches - 1
? leftMatch_->endRowIndex
: leftBatch->size();
if (isRightSemiFilterJoin(joinType_)) {
// RightSemiFilter produces each row from the right at most once.
leftStartRow = leftEndRow - 1;
}

if (prepareOutput(leftBatch, rightBatch)) {
// Differently from left joins, for right joins we need to load lazies
Expand All @@ -706,15 +718,6 @@ bool MergeJoin::addToOutputForRightJoin() {
return true;
}

// TODO: Since semi joins only require determining if there is at least
// one match on the other side, we could explore specialized algorithms
// or data structures that short-circuit the join process once a match
// is found.
if (isRightSemiFilterJoin(joinType_)) {
// RightSemiFilter produces each row from the right at most once.
leftEndRow = leftStartRow + 1;
}

for (auto j = leftStartRow; j < leftEndRow; ++j) {
if (!tryAddOutputRow(leftBatch, j, rightBatch, i)) {
// If we run out of space in the current output_, we will need to
Expand Down
49 changes: 49 additions & 0 deletions velox/exec/tests/MergeJoinTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -917,6 +917,55 @@ TEST_F(MergeJoinTest, semiJoin) {
core::JoinType::kRightSemiFilter);
}

// Semi merge joins (left and right, with a filter) where each join key is
// spread over several consecutive input vectors on both sides. Results are
// verified against DuckDB.
TEST_F(MergeJoinTest, semiJoinWithMultipleMatchVectors) {
  // Builds ten single-column batches of three rows each. Batch 'b' holds the
  // key b / 2 in every row, so each key 0..4 covers two adjacent batches.
  auto makeBatches = [&](const std::string& column) {
    std::vector<RowVectorPtr> batches;
    for (int b = 0; b < 10; ++b) {
      batches.push_back(makeRowVector(
          {column}, {makeFlatVector<int64_t>({b / 2, b / 2, b / 2})}));
    }
    return batches;
  };

  const std::vector<RowVectorPtr> leftBatches = makeBatches("t0");
  const std::vector<RowVectorPtr> rightBatches = makeBatches("u0");

  createDuckDbTable("t", leftBatches);
  createDuckDbTable("u", rightBatches);

  // Runs a merge join of the given type on t0 = u0 with 'filter' applied and
  // compares the output with the result of 'sql' executed by DuckDB.
  auto runAndVerify = [&](const std::string& filter,
                          const std::string& sql,
                          const std::vector<std::string>& outputLayout,
                          core::JoinType joinType) {
    auto idGenerator = std::make_shared<core::PlanNodeIdGenerator>();
    auto joinPlan =
        PlanBuilder(idGenerator)
            .values(leftBatches)
            .mergeJoin(
                {"t0"},
                {"u0"},
                PlanBuilder(idGenerator).values(rightBatches).planNode(),
                filter,
                outputLayout,
                joinType)
            .planNode();

    // kMaxOutputBatchRows is set to "1"; presumably this makes the operator
    // emit tiny output batches so that it has to resume in the middle of a
    // match.
    AssertQueryBuilder(joinPlan, duckDbQueryRunner_)
        .config(core::QueryConfig::kMaxOutputBatchRows, "1")
        .assertResults(sql);
  };

  // Right semi join: output comes from the right side ('u').
  runAndVerify(
      "u0 > 1",
      "SELECT u0 FROM u where u0 IN (SELECT t0 from t) and u0 > 1",
      {"u0"},
      core::JoinType::kRightSemiFilter);

  // Left semi join: output comes from the left side ('t').
  runAndVerify(
      "t0 >1",
      "SELECT t0 FROM t where t0 IN (SELECT u0 from u) and t0 > 1",
      {"t0"},
      core::JoinType::kLeftSemiFilter);
}

TEST_F(MergeJoinTest, rightJoin) {
auto left = makeRowVector(
{"t0"},
Expand Down
Loading