diff --git a/presto-native-execution/presto_cpp/main/types/PrestoToVeloxConnector.cpp b/presto-native-execution/presto_cpp/main/types/PrestoToVeloxConnector.cpp index 150a80d4df3c7..78d88e85fd1d7 100644 --- a/presto-native-execution/presto_cpp/main/types/PrestoToVeloxConnector.cpp +++ b/presto-native-execution/presto_cpp/main/types/PrestoToVeloxConnector.cpp @@ -214,20 +214,13 @@ int64_t dateToInt64( return value.value(); } -double toDouble( - const std::shared_ptr& block, - const VeloxExprConverter& exprConverter, - const TypePtr& type) { - auto variant = exprConverter.getConstantValue(type, *block); - return variant.value(); -} - -float toFloat( +template +T toFloatingPoint( const std::shared_ptr& block, const VeloxExprConverter& exprConverter, const TypePtr& type) { auto variant = exprConverter.getConstantValue(type, *block); - return variant.value(); + return variant.value(); } std::string toString( @@ -393,47 +386,54 @@ std::unique_ptr boolRangeToFilter( VELOX_UNREACHABLE(); } -std::unique_ptr doubleRangeToFilter( +template +std::unique_ptr floatingPointRangeToFilter( const protocol::Range& range, bool nullAllowed, const VeloxExprConverter& exprConverter, const TypePtr& type) { bool lowExclusive = range.low.bound == protocol::Bound::ABOVE; bool lowUnbounded = range.low.valueBlock == nullptr && lowExclusive; - auto low = lowUnbounded ? std::numeric_limits::lowest() - : toDouble(range.low.valueBlock, exprConverter, type); + auto low = lowUnbounded + ? (-1.0 * std::numeric_limits::infinity()) + : toFloatingPoint(range.low.valueBlock, exprConverter, type); bool highExclusive = range.high.bound == protocol::Bound::BELOW; bool highUnbounded = range.high.valueBlock == nullptr && highExclusive; auto high = highUnbounded - ? std::numeric_limits::max() - : toDouble(range.high.valueBlock, exprConverter, type); - return std::make_unique( - low, - lowUnbounded, - lowExclusive, - high, - highUnbounded, - highExclusive, - nullAllowed); -} + ? std::numeric_limits::infinity() + : toFloatingPoint(range.high.valueBlock, exprConverter, type); -std::unique_ptr floatRangeToFilter( - const protocol::Range& range, - bool nullAllowed, - const VeloxExprConverter& exprConverter, - const TypePtr& type) { - bool lowExclusive = range.low.bound == protocol::Bound::ABOVE; - bool lowUnbounded = range.low.valueBlock == nullptr && lowExclusive; - auto low = lowUnbounded ? std::numeric_limits::lowest() - : toFloat(range.low.valueBlock, exprConverter, type); + // Handle NaN cases as NaN is not supported as a limit in Velox Filters + if (!lowUnbounded && std::isnan(low)) { + if (lowExclusive) { + // x > NaN is always false as NaN is considered the largest value. + return std::make_unique(); + } + // Equivalent to x > infinity as only NaN is greater than infinity + // Presto currently converts x >= NaN into the filter with domain + // [NaN, max), so ignoring the high value is fine. + low = std::numeric_limits::infinity(); + lowExclusive = true; + high = std::numeric_limits::infinity(); + highUnbounded = true; + highExclusive = false; + } else if (!highUnbounded && std::isnan(high)) { + high = std::numeric_limits::infinity(); + if (highExclusive) { + // equivalent to x in [low , infinity] or (low , infinity] + highExclusive = false; + } else { + if (lowUnbounded) { + // Anything <= NaN is true as NaN is the largest possible value. + return std::make_unique(); + } + // Equivalent to x > low or x >=low + highUnbounded = true; + } + } - bool highExclusive = range.high.bound == protocol::Bound::BELOW; - bool highUnbounded = range.high.valueBlock == nullptr && highExclusive; - auto high = highUnbounded - ? std::numeric_limits::max() - : toFloat(range.high.valueBlock, exprConverter, type); - return std::make_unique( + return std::make_unique>( low, lowUnbounded, lowExclusive, @@ -653,14 +653,16 @@ std::unique_ptr toFilter( case TypeKind::HUGEINT: return hugeintRangeToFilter(range, nullAllowed, exprConverter, type); case TypeKind::DOUBLE: - return doubleRangeToFilter(range, nullAllowed, exprConverter, type); + return floatingPointRangeToFilter( + range, nullAllowed, exprConverter, type); case TypeKind::VARCHAR: case TypeKind::VARBINARY: return varcharRangeToFilter(range, nullAllowed, exprConverter, type); case TypeKind::BOOLEAN: return boolRangeToFilter(range, nullAllowed, exprConverter, type); case TypeKind::REAL: - return floatRangeToFilter(range, nullAllowed, exprConverter, type); + return floatingPointRangeToFilter( + range, nullAllowed, exprConverter, type); case TypeKind::TIMESTAMP: return timestampRangeToFilter(range, nullAllowed, exprConverter, type); default: