Skip to content

Commit 5a746f7

Browse files
mbasmanovameta-codesync[bot]
authored andcommitted
feat: Enhance ExprToSubfieldFilterParser::makeOrFilter (facebookincubator#15564)
Summary: Pull Request resolved: facebookincubator#15564 - Detect overlapping ranges of bigint and floating point values. - Detect a list of single-value bigint filters and combine these into a single IN list. Reviewed By: Yuhta Differential Revision: D87438547 fbshipit-source-id: 3fa3e9763f1043ed33bed2e05747aa0f7f007892
1 parent d757c52 commit 5a746f7

File tree

5 files changed

+386
-31
lines changed

5 files changed

+386
-31
lines changed

velox/expression/ExprToSubfieldFilter.cpp

Lines changed: 234 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -439,35 +439,249 @@ std::unique_ptr<common::Filter> ExprToSubfieldFilterParser::makeBetweenFilter(
439439
}
440440
}
441441

442-
// static
443-
std::unique_ptr<common::Filter> ExprToSubfieldFilterParser::makeOrFilter(
444-
std::unique_ptr<common::Filter> a,
445-
std::unique_ptr<common::Filter> b) {
446-
if (isBigintRange(a) && isBigintRange(b)) {
447-
return bigintOr(asBigintRange(a), asBigintRange(b));
442+
namespace {
443+
444+
bool isNullAllowed(
445+
const std::vector<std::unique_ptr<common::Filter>>& disjuncts) {
446+
return std::any_of(
447+
disjuncts.begin(), disjuncts.end(), [](const auto& filter) {
448+
return filter->nullAllowed();
449+
});
450+
}
451+
452+
// Combines overlapping ranges into one using OR semantic. Returns nullptr if
453+
// ranges do not overlap. Ignores nullAllowed flag.
454+
// @pre a.lower() <= b.lower()
455+
std::unique_ptr<common::BigintRange> tryMergeOverlappingRanges(
456+
const common::BigintRange& a,
457+
const common::BigintRange& b,
458+
bool& alwaysTrue) {
459+
static constexpr auto kMax = std::numeric_limits<int64_t>::max();
460+
static constexpr auto kMin = std::numeric_limits<int64_t>::min();
461+
462+
if (a.upper() == kMax || a.upper() + 1 >= b.lower()) {
463+
if (a.lower() == kMin && (a.upper() == kMax || b.upper() == kMax)) {
464+
alwaysTrue = true;
465+
return nullptr;
466+
}
467+
468+
return std::make_unique<common::BigintRange>(
469+
a.lower(), std::max(a.upper(), b.upper()), /*nullAllowed=*/false);
470+
}
471+
return nullptr;
472+
}
473+
474+
// Returns a single range that represents "a OR b" or nullptr if no such range
475+
// exists.
476+
// @pre a.lower() <= b.lower()
477+
template <typename T>
478+
std::unique_ptr<common::FloatingPointRange<T>> tryMergeOverlappingRanges(
479+
const common::FloatingPointRange<T>& a,
480+
const common::FloatingPointRange<T>& b,
481+
bool& alwaysTrue) {
482+
if (!a.upperUnbounded() && !b.lowerUnbounded() &&
483+
(a.upper() < b.lower() ||
484+
(a.upper() == b.lower() && a.upperExclusive() && b.lowerExclusive()))) {
485+
return nullptr;
448486
}
449487

450-
if (isBigintRange(a) && isBigintMultiRange(b)) {
451-
std::vector<std::unique_ptr<common::BigintRange>> newRanges;
452-
newRanges.emplace_back(asBigintRange(a));
453-
for (const auto& range : b->as<common::BigintMultiRange>()->ranges()) {
454-
newRanges.emplace_back(std::make_unique<common::BigintRange>(*range));
488+
const bool lowerUnbounded = a.lowerUnbounded() || b.lowerUnbounded();
489+
const bool upperUnbounded = a.upperUnbounded() || b.upperUnbounded();
490+
491+
const T lower = lowerUnbounded ? std::numeric_limits<T>::lowest()
492+
: std::min(a.lower(), b.lower());
493+
494+
bool lowerExclusive = lowerUnbounded;
495+
if (!lowerUnbounded) {
496+
if (a.lower() < b.lower()) {
497+
lowerExclusive = a.lowerExclusive();
498+
} else {
499+
lowerExclusive = a.lowerExclusive() && b.lowerExclusive();
455500
}
501+
}
456502

457-
std::sort(
458-
newRanges.begin(), newRanges.end(), [](const auto& a, const auto& b) {
459-
return a->lower() < b->lower();
460-
});
503+
const T upper = upperUnbounded ? std::numeric_limits<T>::max()
504+
: std::max(a.upper(), b.upper());
461505

462-
return std::make_unique<common::BigintMultiRange>(
463-
std::move(newRanges), false);
506+
bool upperExclusive = upperUnbounded;
507+
if (!upperUnbounded) {
508+
if (a.upper() > b.upper()) {
509+
upperExclusive = a.upperExclusive();
510+
} else if (a.upper() < b.upper()) {
511+
upperExclusive = b.upperExclusive();
512+
} else {
513+
upperExclusive = a.upperExclusive() && b.upperExclusive();
514+
}
515+
}
516+
517+
if (lowerUnbounded && upperUnbounded) {
518+
alwaysTrue = true;
519+
return nullptr;
520+
}
521+
522+
return std::make_unique<common::FloatingPointRange<T>>(
523+
lower,
524+
lowerUnbounded,
525+
lowerExclusive,
526+
upper,
527+
upperUnbounded,
528+
upperExclusive,
529+
/*nullAllowed=*/false);
530+
}
531+
532+
template <typename T, typename TToMultiRange>
533+
std::unique_ptr<common::Filter> mergeOverlappingDisjuncts(
534+
std::vector<std::unique_ptr<T>>& ranges,
535+
bool nullAllowed,
536+
const TToMultiRange& toMultiRange) {
537+
std::vector<std::unique_ptr<T>> newRanges;
538+
newRanges.emplace_back(asUniquePtr<T>(ranges.front()->clone(nullAllowed)));
539+
540+
for (auto i = 1; i < ranges.size(); i++) {
541+
bool alwaysTrue = false;
542+
if (auto merged = tryMergeOverlappingRanges(
543+
*newRanges.back(), *ranges[i], alwaysTrue)) {
544+
newRanges.back() = std::move(merged);
545+
} else {
546+
if (alwaysTrue) {
547+
if (nullAllowed) {
548+
return std::make_unique<common::AlwaysTrue>();
549+
}
550+
return isNotNull();
551+
}
552+
newRanges.emplace_back(std::move(ranges[i]));
553+
}
554+
}
555+
556+
if (newRanges.size() == 1) {
557+
return std::move(newRanges.front());
558+
}
559+
560+
return toMultiRange(newRanges, nullAllowed);
561+
}
562+
563+
std::unique_ptr<common::Filter> tryMergeBigintRanges(
564+
std::vector<std::unique_ptr<common::Filter>>& disjuncts) {
565+
// Check if all filters are single-value equalities: a = 5. Convert these to
566+
// an IN list.
567+
if (std::all_of(disjuncts.begin(), disjuncts.end(), [](const auto& filter) {
568+
return isBigintRange(filter) &&
569+
filter->template as<common::BigintRange>()->isSingleValue();
570+
})) {
571+
std::vector<int64_t> values;
572+
values.reserve(disjuncts.size());
573+
574+
for (auto& filter : disjuncts) {
575+
values.emplace_back(filter->as<common::BigintRange>()->lower());
576+
}
577+
578+
return common::createBigintValues(values, isNullAllowed(disjuncts));
464579
}
465580

466-
if (isBigintMultiRange(a) && isBigintRange(b)) {
467-
return makeOrFilter(std::move(b), std::move(a));
581+
if (!std::all_of(disjuncts.begin(), disjuncts.end(), [](const auto& filter) {
582+
return isBigintRange(filter) || isBigintMultiRange(filter);
583+
})) {
584+
return nullptr;
585+
}
586+
587+
const bool nullAllowed = isNullAllowed(disjuncts);
588+
589+
std::vector<std::unique_ptr<common::BigintRange>> ranges;
590+
for (auto& filter : disjuncts) {
591+
if (isBigintRange(filter)) {
592+
ranges.emplace_back(asBigintRange(filter));
593+
} else {
594+
for (const auto& range :
595+
filter->as<common::BigintMultiRange>()->ranges()) {
596+
ranges.emplace_back(std::make_unique<common::BigintRange>(*range));
597+
}
598+
}
599+
}
600+
601+
std::sort(ranges.begin(), ranges.end(), [](const auto& a, const auto& b) {
602+
return a->lower() < b->lower();
603+
});
604+
605+
return mergeOverlappingDisjuncts(
606+
ranges, nullAllowed, [](auto& newRanges, bool nullAllowed) {
607+
return std::make_unique<common::BigintMultiRange>(
608+
std::move(newRanges), nullAllowed);
609+
});
610+
}
611+
612+
template <typename T>
613+
std::unique_ptr<common::Filter> tryMergeFloatingPointRanges(
614+
std::vector<std::unique_ptr<common::Filter>>& disjuncts) {
615+
constexpr auto filterKind = std::is_same_v<T, double>
616+
? common::FilterKind::kDoubleRange
617+
: common::FilterKind::kFloatRange;
618+
619+
if (!std::all_of(disjuncts.begin(), disjuncts.end(), [](const auto& filter) {
620+
return filter->is(filterKind);
621+
})) {
622+
return nullptr;
623+
}
624+
625+
const bool nullAllowed = isNullAllowed(disjuncts);
626+
627+
std::vector<std::unique_ptr<common::FloatingPointRange<T>>> ranges;
628+
ranges.reserve(disjuncts.size());
629+
for (auto& filter : disjuncts) {
630+
ranges.emplace_back(
631+
asUniquePtr<common::FloatingPointRange<T>>(std::move(filter)));
468632
}
469633

470-
return orFilter(std::move(a), std::move(b));
634+
std::sort(ranges.begin(), ranges.end(), [](const auto& a, const auto& b) {
635+
if (a->lowerUnbounded() && b->lowerUnbounded()) {
636+
return false;
637+
}
638+
639+
if (a->lowerUnbounded()) {
640+
return true;
641+
}
642+
643+
if (b->lowerUnbounded()) {
644+
return false;
645+
}
646+
647+
return a->lower() < b->lower();
648+
});
649+
650+
return mergeOverlappingDisjuncts(
651+
ranges, nullAllowed, [](auto& newRanges, bool nullAllowed) {
652+
std::vector<std::unique_ptr<common::Filter>> filters;
653+
filters.reserve(newRanges.size());
654+
for (auto& range : newRanges) {
655+
filters.emplace_back(std::move(range));
656+
}
657+
return std::make_unique<common::MultiRange>(
658+
std::move(filters), nullAllowed);
659+
});
660+
}
661+
662+
} // namespace
663+
664+
// static
665+
std::unique_ptr<common::Filter> ExprToSubfieldFilterParser::makeOrFilter(
666+
std::vector<std::unique_ptr<common::Filter>> disjuncts) {
667+
VELOX_CHECK_GE(disjuncts.size(), 2);
668+
669+
if (auto merged = tryMergeBigintRanges(disjuncts)) {
670+
return merged;
671+
}
672+
673+
if (auto merged = tryMergeFloatingPointRanges<double>(disjuncts)) {
674+
return merged;
675+
}
676+
677+
if (auto merged = tryMergeFloatingPointRanges<float>(disjuncts)) {
678+
return merged;
679+
}
680+
681+
const bool nullAllowed = isNullAllowed(disjuncts);
682+
683+
return std::make_unique<common::MultiRange>(
684+
std::move(disjuncts), nullAllowed);
471685
}
472686

473687
namespace {

velox/expression/ExprToSubfieldFilter.h

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -338,6 +338,10 @@ inline std::unique_ptr<common::IsNotNull> isNotNull() {
338338
return std::make_unique<common::IsNotNull>();
339339
}
340340

341+
inline std::unique_ptr<common::AlwaysTrue> alwaysTrue() {
342+
return std::make_unique<common::AlwaysTrue>();
343+
}
344+
341345
template <typename T>
342346
std::unique_ptr<common::MultiRange>
343347
orFilter(std::unique_ptr<T> a, std::unique_ptr<T> b, bool nullAllowed = false) {
@@ -461,9 +465,21 @@ class ExprToSubfieldFilterParser {
461465
core::ExpressionEvaluator* evaluator,
462466
bool negated = false) = 0;
463467

468+
/// Combines 2 or more filters with an OR.
469+
/// Detects overlapping ranges of bigint and floating point values.
470+
/// Detects a list of single-value bigint filters and combines them into a
471+
/// single IN list.
464472
static std::unique_ptr<common::Filter> makeOrFilter(
465-
std::unique_ptr<common::Filter> a,
466-
std::unique_ptr<common::Filter> b);
473+
std::vector<std::unique_ptr<common::Filter>> disjuncts);
474+
475+
template <typename... Disjuncts>
476+
static std::unique_ptr<common::Filter> makeOrFilter(
477+
Disjuncts&&... disjuncts) {
478+
std::vector<std::unique_ptr<common::Filter>> filters;
479+
filters.reserve(sizeof...(Disjuncts));
480+
(filters.emplace_back(std::forward<Disjuncts>(disjuncts)), ...);
481+
return makeOrFilter(std::move(filters));
482+
}
467483

468484
protected:
469485
// Converts an expression into a subfield. Returns false if the expression

0 commit comments

Comments
 (0)