diff --git a/presto-main/src/main/java/io/prestosql/operator/aggregation/AbstractMinMaxAggregationFunction.java b/presto-main/src/main/java/io/prestosql/operator/aggregation/AbstractMinMaxAggregationFunction.java index 2289656bf93c..c9268eb3a211 100644 --- a/presto-main/src/main/java/io/prestosql/operator/aggregation/AbstractMinMaxAggregationFunction.java +++ b/presto-main/src/main/java/io/prestosql/operator/aggregation/AbstractMinMaxAggregationFunction.java @@ -58,8 +58,8 @@ import static io.prestosql.spi.function.InvocationConvention.simpleConvention; import static io.prestosql.spi.function.OperatorType.COMPARISON; import static io.prestosql.util.Failures.internalError; +import static io.prestosql.util.MinMaxCompare.getMinMaxCompare; import static io.prestosql.util.Reflection.methodHandle; -import static java.lang.invoke.MethodHandles.filterReturnValue; public abstract class AbstractMinMaxAggregationFunction extends SqlAggregationFunction @@ -79,10 +79,7 @@ public abstract class AbstractMinMaxAggregationFunction private static final MethodHandle BOOLEAN_COMBINE_FUNCTION = methodHandle(AbstractMinMaxAggregationFunction.class, "combine", MethodHandle.class, NullableBooleanState.class, NullableBooleanState.class); private static final MethodHandle BLOCK_POSITION_COMBINE_FUNCTION = methodHandle(AbstractMinMaxAggregationFunction.class, "combine", MethodHandle.class, BlockPositionState.class, BlockPositionState.class); - private static final MethodHandle MIN_FUNCTION = methodHandle(AbstractMinMaxAggregationFunction.class, "min", long.class); - private static final MethodHandle MAX_FUNCTION = methodHandle(AbstractMinMaxAggregationFunction.class, "max", long.class); - - private final MethodHandle comparisonResultAdapter; + private final boolean min; protected AbstractMinMaxAggregationFunction(String name, boolean min, String description) { @@ -103,7 +100,7 @@ protected AbstractMinMaxAggregationFunction(String name, boolean min, String des AGGREGATE), true, false); - this.comparisonResultAdapter = min ? MIN_FUNCTION : MAX_FUNCTION; + this.min = min; } @Override @@ -142,8 +139,9 @@ public InternalAggregationFunction specialize(FunctionBinding functionBinding, F else { invocationConvention = simpleConvention(FAIL_ON_NULL, BLOCK_POSITION, BLOCK_POSITION); } - MethodHandle compareMethodHandle = functionDependencies.getOperatorInvoker(COMPARISON, ImmutableList.of(type, type), Optional.of(invocationConvention)).getMethodHandle(); - compareMethodHandle = filterReturnValue(compareMethodHandle, comparisonResultAdapter); + + MethodHandle compareMethodHandle = getMinMaxCompare(functionDependencies, type, Optional.of(invocationConvention), min); + return generateAggregation(type, compareMethodHandle); } @@ -341,16 +339,4 @@ private static void compareAndUpdateState(MethodHandle methodHandle, BlockPositi throw internalError(t); } } - - @UsedByGeneratedCode - public static boolean min(long comparisonResult) - { - return comparisonResult < 0; - } - - @UsedByGeneratedCode - public static boolean max(long comparisonResult) - { - return comparisonResult > 0; - } } diff --git a/presto-main/src/main/java/io/prestosql/operator/aggregation/MaxNAggregationFunction.java b/presto-main/src/main/java/io/prestosql/operator/aggregation/MaxNAggregationFunction.java index 43c95f9d66ca..b5691e722656 100644 --- a/presto-main/src/main/java/io/prestosql/operator/aggregation/MaxNAggregationFunction.java +++ b/presto-main/src/main/java/io/prestosql/operator/aggregation/MaxNAggregationFunction.java @@ -15,6 +15,8 @@ import io.prestosql.type.BlockTypeOperators; +import static io.prestosql.util.MinMaxCompare.getMaxCompare; + public class MaxNAggregationFunction extends AbstractMinMaxNAggregationFunction { @@ -22,6 +24,8 @@ public class MaxNAggregationFunction public MaxNAggregationFunction(BlockTypeOperators blockTypeOperators) { - super(NAME, blockTypeOperators::getComparisonOperator, "Returns the maximum values of the argument"); + super(NAME, + type -> getMaxCompare(blockTypeOperators, type), + "Returns the maximum values of the argument"); } } diff --git a/presto-main/src/main/java/io/prestosql/operator/aggregation/minmaxby/AbstractMinMaxBy.java b/presto-main/src/main/java/io/prestosql/operator/aggregation/minmaxby/AbstractMinMaxBy.java index 384b11fd5db5..ba8b57cf8fb5 100644 --- a/presto-main/src/main/java/io/prestosql/operator/aggregation/minmaxby/AbstractMinMaxBy.java +++ b/presto-main/src/main/java/io/prestosql/operator/aggregation/minmaxby/AbstractMinMaxBy.java @@ -23,7 +23,6 @@ import io.airlift.bytecode.Parameter; import io.airlift.bytecode.control.IfStatement; import io.airlift.bytecode.expression.BytecodeExpression; -import io.prestosql.annotation.UsedByGeneratedCode; import io.prestosql.metadata.FunctionArgumentDefinition; import io.prestosql.metadata.FunctionBinding; import io.prestosql.metadata.FunctionDependencies; @@ -46,6 +45,7 @@ import io.prestosql.spi.type.TypeSignature; import io.prestosql.sql.gen.CallSiteBinder; import io.prestosql.sql.gen.SqlTypeBytecodeExpression; +import io.prestosql.util.MinMaxCompare; import java.lang.invoke.MethodHandle; import java.lang.reflect.Method; @@ -83,16 +83,12 @@ import static io.prestosql.util.CompilerUtils.defineClass; import static io.prestosql.util.CompilerUtils.makeClassName; import static io.prestosql.util.Reflection.methodHandle; -import static java.lang.invoke.MethodHandles.filterReturnValue; import static java.util.Arrays.stream; public abstract class AbstractMinMaxBy extends SqlAggregationFunction { - private static final MethodHandle MIN_FUNCTION = methodHandle(AbstractMinMaxBy.class, "min", long.class); - private static final MethodHandle MAX_FUNCTION = methodHandle(AbstractMinMaxBy.class, "max", long.class); - - private final MethodHandle comparisonResultAdapter; + private final boolean min; protected AbstractMinMaxBy(boolean min, String description) { @@ -115,7 +111,7 @@ protected AbstractMinMaxBy(boolean min, String description) AGGREGATE), true, false); - this.comparisonResultAdapter = min ? MIN_FUNCTION : MAX_FUNCTION; + this.min = min; } @Override @@ -182,8 +178,7 @@ private InternalAggregationFunction generateAggregation(Type valueType, Type key List inputTypes = ImmutableList.of(valueType, keyType); CallSiteBinder binder = new CallSiteBinder(); - MethodHandle compareMethod = functionDependencies.getOperatorInvoker(COMPARISON, ImmutableList.of(keyType, keyType), Optional.empty()).getMethodHandle(); - compareMethod = filterReturnValue(compareMethod, comparisonResultAdapter); + MethodHandle compareMethod = MinMaxCompare.getMinMaxCompare(functionDependencies, keyType, Optional.empty(), min); ClassDefinition definition = new ClassDefinition( a(PUBLIC, FINAL), @@ -339,16 +334,4 @@ private static Method getMethod(Class stateClass, String name) .findFirst() .orElseThrow(() -> new IllegalArgumentException("State class does not have a method named " + name)); } - - @UsedByGeneratedCode - public static boolean min(long comparisonResult) - { - return comparisonResult < 0; - } - - @UsedByGeneratedCode - public static boolean max(long comparisonResult) - { - return comparisonResult > 0; - } } diff --git a/presto-main/src/main/java/io/prestosql/operator/aggregation/minmaxby/MaxByNAggregationFunction.java b/presto-main/src/main/java/io/prestosql/operator/aggregation/minmaxby/MaxByNAggregationFunction.java index 81d366d1597e..9fef15aba0a3 100644 --- a/presto-main/src/main/java/io/prestosql/operator/aggregation/minmaxby/MaxByNAggregationFunction.java +++ b/presto-main/src/main/java/io/prestosql/operator/aggregation/minmaxby/MaxByNAggregationFunction.java @@ -15,6 +15,8 @@ import io.prestosql.type.BlockTypeOperators; +import static io.prestosql.util.MinMaxCompare.getMaxCompare; + public class MaxByNAggregationFunction extends AbstractMinMaxByNAggregationFunction { @@ -22,6 +24,8 @@ public class MaxByNAggregationFunction public MaxByNAggregationFunction(BlockTypeOperators blockTypeOperators) { - super(NAME, blockTypeOperators::getComparisonOperator, "Returns the values of the first argument associated with the maximum values of the second argument"); + super(NAME, + type -> getMaxCompare(blockTypeOperators, type), + "Returns the values of the first argument associated with the maximum values of the second argument"); } } diff --git a/presto-main/src/main/java/io/prestosql/operator/scalar/AbstractGreatestLeast.java b/presto-main/src/main/java/io/prestosql/operator/scalar/AbstractGreatestLeast.java index 816f74b175ca..82950c161ff1 100644 --- a/presto-main/src/main/java/io/prestosql/operator/scalar/AbstractGreatestLeast.java +++ b/presto-main/src/main/java/io/prestosql/operator/scalar/AbstractGreatestLeast.java @@ -24,7 +24,6 @@ import io.airlift.bytecode.control.IfStatement; import io.airlift.bytecode.expression.BytecodeExpression; import io.airlift.bytecode.instruction.LabelNode; -import io.prestosql.annotation.UsedByGeneratedCode; import io.prestosql.metadata.FunctionArgumentDefinition; import io.prestosql.metadata.FunctionBinding; import io.prestosql.metadata.FunctionDependencies; @@ -65,8 +64,8 @@ import static io.prestosql.util.CompilerUtils.defineClass; import static io.prestosql.util.CompilerUtils.makeClassName; import static io.prestosql.util.Failures.checkCondition; +import static io.prestosql.util.MinMaxCompare.getMinMaxCompare; import static io.prestosql.util.Reflection.methodHandle; -import static java.lang.invoke.MethodHandles.filterReturnValue; import static java.lang.invoke.MethodType.methodType; import static java.util.Collections.nCopies; import static java.util.stream.Collectors.joining; @@ -74,10 +73,7 @@ public abstract class AbstractGreatestLeast extends SqlScalarFunction { - private static final MethodHandle MIN_FUNCTION = methodHandle(AbstractGreatestLeast.class, "min", long.class); - private static final MethodHandle MAX_FUNCTION = methodHandle(AbstractGreatestLeast.class, "max", long.class); - - private final MethodHandle comparisonResultAdapter; + private final boolean min; protected AbstractGreatestLeast(boolean min, String description) { @@ -95,7 +91,7 @@ protected AbstractGreatestLeast(boolean min, String description) true, description, SCALAR)); - this.comparisonResultAdapter = min ? MIN_FUNCTION : MAX_FUNCTION; + this.min = min; } @Override @@ -112,8 +108,7 @@ public ScalarFunctionImplementation specialize(FunctionBinding functionBinding, Type type = functionBinding.getTypeVariable("E"); checkArgument(type.isOrderable(), "Type must be orderable"); - MethodHandle compareMethod = functionDependencies.getOperatorInvoker(COMPARISON, ImmutableList.of(type, type), Optional.empty()).getMethodHandle(); - compareMethod = filterReturnValue(compareMethod, comparisonResultAdapter); + MethodHandle compareMethod = getMinMaxCompare(functionDependencies, type, Optional.empty(), min); List> javaTypes = IntStream.range(0, functionBinding.getArity()) .mapToObj(i -> wrap(type.getJavaType())) @@ -192,16 +187,4 @@ private Class generate(List> javaTypes, MethodHandle compareMethod) return defineClass(definition, Object.class, binder.getBindings(), new DynamicClassLoader(getClass().getClassLoader())); } - - @UsedByGeneratedCode - public static boolean min(long comparisonResult) - { - return comparisonResult < 0; - } - - @UsedByGeneratedCode - public static boolean max(long comparisonResult) - { - return comparisonResult > 0; - } } diff --git a/presto-main/src/main/java/io/prestosql/operator/scalar/ArrayMaxFunction.java b/presto-main/src/main/java/io/prestosql/operator/scalar/ArrayMaxFunction.java index 28e2a105fd63..efd970a0efeb 100644 --- a/presto-main/src/main/java/io/prestosql/operator/scalar/ArrayMaxFunction.java +++ b/presto-main/src/main/java/io/prestosql/operator/scalar/ArrayMaxFunction.java @@ -28,7 +28,10 @@ import static io.prestosql.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION; import static io.prestosql.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL; import static io.prestosql.spi.function.OperatorType.COMPARISON; +import static io.prestosql.spi.type.DoubleType.DOUBLE; +import static io.prestosql.spi.type.RealType.REAL; import static io.prestosql.util.Failures.internalError; +import static java.lang.Float.intBitsToFloat; @ScalarFunction("array_max") @Description("Get maximum value of array") @@ -126,4 +129,58 @@ private static int findMaxArrayElement(MethodHandle compareMethodHandle, Block b throw internalError(t); } } + + @SqlType("double") + @SqlNullable + public static Double doubleTypeArrayMax(@SqlType("array(double)") Block block) + { + if (block.getPositionCount() == 0) { + return null; + } + int selectedPosition = -1; + for (int position = 0; position < block.getPositionCount(); position++) { + if (block.isNull(position)) { + return null; + } + if (selectedPosition < 0 || doubleGreater(DOUBLE.getDouble(block, position), DOUBLE.getDouble(block, selectedPosition))) { + selectedPosition = position; + } + } + return DOUBLE.getDouble(block, selectedPosition); + } + + private static boolean doubleGreater(double left, double right) + { + return (left > right) || Double.isNaN(right); + } + + @SqlType("real") + @SqlNullable + public static Long realTypeArrayMax(@SqlType("array(real)") Block block) + { + if (block.getPositionCount() == 0) { + return null; + } + int selectedPosition = -1; + for (int position = 0; position < block.getPositionCount(); position++) { + if (block.isNull(position)) { + return null; + } + if (selectedPosition < 0 || floatGreater(getReal(block, position), getReal(block, selectedPosition))) { + selectedPosition = position; + } + } + return REAL.getLong(block, selectedPosition); + } + + @SuppressWarnings("NumericCastThatLosesPrecision") + private static float getReal(Block block, int position) + { + return intBitsToFloat((int) REAL.getLong(block, position)); + } + + private static boolean floatGreater(float left, float right) + { + return (left > right) || Float.isNaN(right); + } } diff --git a/presto-main/src/main/java/io/prestosql/util/MinMaxCompare.java b/presto-main/src/main/java/io/prestosql/util/MinMaxCompare.java new file mode 100644 index 000000000000..cf122e7fd9c4 --- /dev/null +++ b/presto-main/src/main/java/io/prestosql/util/MinMaxCompare.java @@ -0,0 +1,124 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.prestosql.util; + +import io.prestosql.annotation.UsedByGeneratedCode; +import io.prestosql.metadata.FunctionDependencies; +import io.prestosql.spi.function.InvocationConvention; +import io.prestosql.spi.type.Type; +import io.prestosql.type.BlockTypeOperators; +import io.prestosql.type.BlockTypeOperators.BlockPositionComparison; + +import java.lang.invoke.MethodHandle; +import java.util.List; +import java.util.Optional; + +import static io.prestosql.spi.function.OperatorType.COMPARISON; +import static io.prestosql.spi.type.DoubleType.DOUBLE; +import static io.prestosql.spi.type.RealType.REAL; +import static io.prestosql.util.Reflection.methodHandle; +import static java.lang.Float.intBitsToFloat; +import static java.lang.invoke.MethodHandles.filterReturnValue; + +public final class MinMaxCompare +{ + private static final MethodHandle MIN_FUNCTION = methodHandle(MinMaxCompare.class, "min", long.class); + private static final MethodHandle MAX_FUNCTION = methodHandle(MinMaxCompare.class, "max", long.class); + + public static final MethodHandle MAX_REAL_FUNCTION = methodHandle(MinMaxCompare.class, "maxReal", long.class, long.class); + public static final MethodHandle MAX_DOUBLE_FUNCTION = methodHandle(MinMaxCompare.class, "maxDouble", double.class, double.class); + + private MinMaxCompare() {} + + public static MethodHandle getMinMaxCompare(FunctionDependencies dependencies, Type type, Optional convention, boolean min) + { + if (!min && type.equals(REAL)) { + return MAX_REAL_FUNCTION; + } + if (!min && type.equals(DOUBLE)) { + return MAX_DOUBLE_FUNCTION; + } + MethodHandle handle = dependencies.getOperatorInvoker(COMPARISON, List.of(type, type), convention).getMethodHandle(); + return filterReturnValue(handle, min ? MIN_FUNCTION : MAX_FUNCTION); + } + + public static BlockPositionComparison getMaxCompare(BlockTypeOperators operators, Type type) + { + if (type.equals(REAL)) { + return (leftBlock, leftPosition, rightBlock, rightPosition) -> { + float left = toReal(REAL.getLong(leftBlock, leftPosition)); + float right = toReal(REAL.getLong(rightBlock, rightPosition)); + if (Float.isNaN(left) && Float.isNaN(right)) { + return 0; + } + if (Float.isNaN(left)) { + return -1; + } + if (Float.isNaN(right)) { + return 1; + } + return Float.compare(left, right); + }; + } + if (type.equals(DOUBLE)) { + return (leftBlock, leftPosition, rightBlock, rightPosition) -> { + double left = DOUBLE.getDouble(leftBlock, leftPosition); + double right = DOUBLE.getDouble(rightBlock, rightPosition); + if (Double.isNaN(left) && Double.isNaN(right)) { + return 0; + } + if (Double.isNaN(left)) { + return -1; + } + if (Double.isNaN(right)) { + return 1; + } + return Double.compare(left, right); + }; + } + return operators.getComparisonOperator(type); + } + + @UsedByGeneratedCode + public static boolean min(long comparisonResult) + { + return comparisonResult < 0; + } + + @UsedByGeneratedCode + public static boolean max(long comparisonResult) + { + return comparisonResult > 0; + } + + @UsedByGeneratedCode + public static boolean maxReal(long intLeft, long intRight) + { + float left = toReal(intLeft); + float right = toReal(intRight); + return (left > right) || Float.isNaN(right); + } + + @UsedByGeneratedCode + public static boolean maxDouble(double left, double right) + { + return (left > right) || Double.isNaN(right); + } + + @SuppressWarnings("NumericCastThatLosesPrecision") + private static float toReal(long value) + { + return intBitsToFloat((int) value); + } +} diff --git a/presto-main/src/test/java/io/prestosql/operator/aggregation/minmaxby/TestMinMaxByAggregation.java b/presto-main/src/test/java/io/prestosql/operator/aggregation/minmaxby/TestMinMaxByAggregation.java index 409fe47e11ce..0fd2f7070825 100644 --- a/presto-main/src/test/java/io/prestosql/operator/aggregation/minmaxby/TestMinMaxByAggregation.java +++ b/presto-main/src/test/java/io/prestosql/operator/aggregation/minmaxby/TestMinMaxByAggregation.java @@ -28,6 +28,7 @@ import static com.google.common.collect.ImmutableSet.toImmutableSet; import static io.prestosql.block.BlockAssertions.createArrayBigintBlock; +import static io.prestosql.block.BlockAssertions.createBlockOfReals; import static io.prestosql.block.BlockAssertions.createBooleansBlock; import static io.prestosql.block.BlockAssertions.createDoublesBlock; import static io.prestosql.block.BlockAssertions.createIntsBlock; @@ -42,6 +43,7 @@ import static io.prestosql.spi.type.DecimalType.createDecimalType; import static io.prestosql.spi.type.DoubleType.DOUBLE; import static io.prestosql.spi.type.IntegerType.INTEGER; +import static io.prestosql.spi.type.RealType.REAL; import static io.prestosql.spi.type.VarcharType.VARCHAR; import static io.prestosql.sql.analyzer.TypeSignatureProvider.fromTypes; import static io.prestosql.type.UnknownType.UNKNOWN; @@ -226,6 +228,94 @@ public void testMaxDoubleVarchar() "hi", createStringsBlock("zz", "hi", null, "a"), createDoublesBlock(0.0, 1.0, null, -1.0)); + + assertAggregation( + function, + "c", + createStringsBlock("a", "b", "c"), + createDoublesBlock(Double.NaN, 1.0, 2.0)); + + assertAggregation( + function, + "c", + createStringsBlock("a", "b", "c"), + createDoublesBlock(1.0, Double.NaN, 2.0)); + + assertAggregation( + function, + "b", + createStringsBlock("a", "b", "c"), + createDoublesBlock(1.0, 2.0, Double.NaN)); + } + + @Test + public void testMinRealVarchar() + { + InternalAggregationFunction function = METADATA.getAggregateFunctionImplementation(METADATA.resolveFunction(QualifiedName.of("min_by"), fromTypes(VARCHAR, REAL))); + assertAggregation( + function, + "z", + createStringsBlock("z", "a", "x", "b"), + createBlockOfReals(1.0f, 2.0f, 2.0f, 3.0f)); + + assertAggregation( + function, + "a", + createStringsBlock("zz", "hi", "bb", "a"), + createBlockOfReals(0.0f, 1.0f, 2.0f, -1.0f)); + + assertAggregation( + function, + "b", + createStringsBlock("a", "b", "c"), + createBlockOfReals(Float.NaN, 1.0f, 2.0f)); + + assertAggregation( + function, + "a", + createStringsBlock("a", "b", "c"), + createBlockOfReals(1.0f, Float.NaN, 2.0f)); + + assertAggregation( + function, + "a", + createStringsBlock("a", "b", "c"), + createBlockOfReals(1.0f, 2.0f, Float.NaN)); + } + + @Test + public void testMaxRealVarchar() + { + InternalAggregationFunction function = METADATA.getAggregateFunctionImplementation(METADATA.resolveFunction(QualifiedName.of("max_by"), fromTypes(VARCHAR, REAL))); + assertAggregation( + function, + "a", + createStringsBlock("z", "a", null), + createBlockOfReals(1.0f, 2.0f, null)); + + assertAggregation( + function, + "hi", + createStringsBlock("zz", "hi", null, "a"), + createBlockOfReals(0.0f, 1.0f, null, -1.0f)); + + assertAggregation( + function, + "c", + createStringsBlock("a", "b", "c"), + createBlockOfReals(Float.NaN, 1.0f, 2.0f)); + + assertAggregation( + function, + "c", + createStringsBlock("a", "b", "c"), + createBlockOfReals(1.0f, Float.NaN, 2.0f)); + + assertAggregation( + function, + "b", + createStringsBlock("a", "b", "c"), + createBlockOfReals(1.0f, 2.0f, Float.NaN)); } @Test diff --git a/presto-main/src/test/java/io/prestosql/operator/aggregation/minmaxby/TestMinMaxByNAggregation.java b/presto-main/src/test/java/io/prestosql/operator/aggregation/minmaxby/TestMinMaxByNAggregation.java index b431fc507a27..bc48feed862a 100644 --- a/presto-main/src/test/java/io/prestosql/operator/aggregation/minmaxby/TestMinMaxByNAggregation.java +++ b/presto-main/src/test/java/io/prestosql/operator/aggregation/minmaxby/TestMinMaxByNAggregation.java @@ -25,6 +25,7 @@ import java.util.Arrays; import static io.prestosql.block.BlockAssertions.createArrayBigintBlock; +import static io.prestosql.block.BlockAssertions.createBlockOfReals; import static io.prestosql.block.BlockAssertions.createDoublesBlock; import static io.prestosql.block.BlockAssertions.createLongsBlock; import static io.prestosql.block.BlockAssertions.createRLEBlock; @@ -34,6 +35,7 @@ import static io.prestosql.operator.aggregation.AggregationTestUtils.groupedAggregation; import static io.prestosql.spi.type.BigintType.BIGINT; import static io.prestosql.spi.type.DoubleType.DOUBLE; +import static io.prestosql.spi.type.RealType.REAL; import static io.prestosql.spi.type.VarcharType.VARCHAR; import static io.prestosql.sql.analyzer.TypeSignatureProvider.fromTypes; import static org.testng.Assert.assertEquals; @@ -156,6 +158,41 @@ public void testMinDoubleVarchar() createStringsBlock("zz", "hi", null, "a"), createDoublesBlock(0.0, 1.0, null, -1.0), createRLEBlock(2L, 4)); + + assertAggregation( + function, + ImmutableList.of("b", "c"), + createStringsBlock("a", "b", "c", "d"), + createDoublesBlock(Double.NaN, 2.0, 3.0, 4.0), + createRLEBlock(2L, 4)); + + assertAggregation( + function, + ImmutableList.of("a", "c"), + createStringsBlock("a", "b", "c", "d"), + createDoublesBlock(1.0, Double.NaN, 3.0, 4.0), + createRLEBlock(2L, 4)); + + assertAggregation( + function, + ImmutableList.of("a", "b"), + createStringsBlock("a", "b", "c", "d"), + createDoublesBlock(1.0, 2.0, Double.NaN, 4.0), + createRLEBlock(2L, 4)); + + assertAggregation( + function, + ImmutableList.of("a", "b"), + createStringsBlock("a", "b", "c", "d"), + createDoublesBlock(1.0, 2.0, 3.0, Double.NaN), + createRLEBlock(2L, 4)); + + assertAggregation( + function, + ImmutableList.of("a", "b"), + createStringsBlock("a", "b"), + createDoublesBlock(1.0, Double.NaN), + createRLEBlock(2L, 2)); } @Test @@ -183,6 +220,165 @@ public void testMaxDoubleVarchar() createStringsBlock("zz", "hi", null, "a"), createDoublesBlock(0.0, 1.0, null, -1.0), createRLEBlock(2L, 4)); + + assertAggregation( + function, + ImmutableList.of("d", "c"), + createStringsBlock("a", "b", "c", "d"), + createDoublesBlock(Double.NaN, 2.0, 3.0, 4.0), + createRLEBlock(2L, 4)); + + assertAggregation( + function, + ImmutableList.of("d", "c"), + createStringsBlock("a", "b", "c", "d"), + createDoublesBlock(1.0, Double.NaN, 3.0, 4.0), + createRLEBlock(2L, 4)); + + assertAggregation( + function, + ImmutableList.of("d", "b"), + createStringsBlock("a", "b", "c", "d"), + createDoublesBlock(1.0, 2.0, Double.NaN, 4.0), + createRLEBlock(2L, 4)); + + assertAggregation( + function, + ImmutableList.of("c", "b"), + createStringsBlock("a", "b", "c", "d"), + createDoublesBlock(1.0, 2.0, 3.0, Double.NaN), + createRLEBlock(2L, 4)); + + assertAggregation( + function, + ImmutableList.of("a", "b"), + createStringsBlock("a", "b"), + createDoublesBlock(1.0, Double.NaN), + createRLEBlock(2L, 2)); + } + + @Test + public void testMinRealVarchar() + { + InternalAggregationFunction function = METADATA.getAggregateFunctionImplementation( + METADATA.resolveFunction(QualifiedName.of("min_by"), fromTypes(VARCHAR, REAL, BIGINT))); + assertAggregation( + function, + ImmutableList.of("z", "a"), + createStringsBlock("z", "a", "x", "b"), + createBlockOfReals(1.0f, 2.0f, 2.0f, 3.0f), + createRLEBlock(2L, 4)); + + assertAggregation( + function, + ImmutableList.of("a", "zz"), + createStringsBlock("zz", "hi", "bb", "a"), + createBlockOfReals(0.0f, 1.0f, 2.0f, -1.0f), + createRLEBlock(2L, 4)); + + assertAggregation( + function, + ImmutableList.of("a", "zz"), + createStringsBlock("zz", "hi", null, "a"), + createBlockOfReals(0.0f, 1.0f, null, -1.0f), + createRLEBlock(2L, 4)); + + assertAggregation( + function, + ImmutableList.of("b", "c"), + createStringsBlock("a", "b", "c", "d"), + createBlockOfReals(Float.NaN, 2.0f, 3.0f, 4.0f), + createRLEBlock(2L, 4)); + + assertAggregation( + function, + ImmutableList.of("a", "c"), + createStringsBlock("a", "b", "c", "d"), + createBlockOfReals(1.0f, Float.NaN, 3.0f, 4.0f), + createRLEBlock(2L, 4)); + + assertAggregation( + function, + ImmutableList.of("a", "b"), + createStringsBlock("a", "b", "c", "d"), + createBlockOfReals(1.0f, 2.0f, Float.NaN, 4.0f), + createRLEBlock(2L, 4)); + + assertAggregation( + function, + ImmutableList.of("a", "b"), + createStringsBlock("a", "b", "c", "d"), + createBlockOfReals(1.0f, 2.0f, 3.0f, Float.NaN), + createRLEBlock(2L, 4)); + + assertAggregation( + function, + ImmutableList.of("a", "b"), + createStringsBlock("a", "b"), + createBlockOfReals(1.0f, Float.NaN), + createRLEBlock(2L, 2)); + } + + @Test + public void testMaxRealVarchar() + { + InternalAggregationFunction function = METADATA.getAggregateFunctionImplementation( + METADATA.resolveFunction(QualifiedName.of("max_by"), fromTypes(VARCHAR, REAL, BIGINT))); + assertAggregation( + function, + ImmutableList.of("a", "z"), + createStringsBlock("z", "a", null), + createBlockOfReals(1.0f, 2.0f, null), + createRLEBlock(2L, 3)); + + assertAggregation( + function, + ImmutableList.of("bb", "hi"), + createStringsBlock("zz", "hi", "bb", "a"), + createBlockOfReals(0.0f, 1.0f, 2.0f, -1.0f), + createRLEBlock(2L, 4)); + + assertAggregation( + function, + ImmutableList.of("hi", "zz"), + createStringsBlock("zz", "hi", null, "a"), + createBlockOfReals(0.0f, 1.0f, null, -1.0f), + createRLEBlock(2L, 4)); + + assertAggregation( + function, + ImmutableList.of("d", "c"), + createStringsBlock("a", "b", "c", "d"), + createBlockOfReals(Float.NaN, 2.0f, 3.0f, 4.0f), + createRLEBlock(2L, 4)); + + assertAggregation( + function, + ImmutableList.of("d", "c"), + createStringsBlock("a", "b", "c", "d"), + createBlockOfReals(1.0f, Float.NaN, 3.0f, 4.0f), + createRLEBlock(2L, 4)); + + assertAggregation( + function, + ImmutableList.of("d", "b"), + createStringsBlock("a", "b", "c", "d"), + createBlockOfReals(1.0f, 2.0f, Float.NaN, 4.0f), + createRLEBlock(2L, 4)); + + assertAggregation( + function, + ImmutableList.of("c", "b"), + createStringsBlock("a", "b", "c", "d"), + createBlockOfReals(1.0f, 2.0f, 3.0f, Float.NaN), + createRLEBlock(2L, 4)); + + assertAggregation( + function, + ImmutableList.of("a", "b"), + createStringsBlock("a", "b"), + createBlockOfReals(1.0f, Float.NaN), + createRLEBlock(2L, 2)); } @Test diff --git a/presto-main/src/test/java/io/prestosql/operator/scalar/TestMathFunctions.java b/presto-main/src/test/java/io/prestosql/operator/scalar/TestMathFunctions.java index 2691aed57db1..c8dcdb0eeeae 100644 --- a/presto-main/src/test/java/io/prestosql/operator/scalar/TestMathFunctions.java +++ b/presto-main/src/test/java/io/prestosql/operator/scalar/TestMathFunctions.java @@ -1169,15 +1169,25 @@ public void testGreatest() assertFunction("greatest(1.5E0, 2.3E0)", DOUBLE, 2.3); assertFunction("greatest(-1.5E0, -2.3E0)", DOUBLE, -1.5); assertFunction("greatest(-1.5E0, -2.3E0, -5/3)", DOUBLE, -1.0); - assertFunction("greatest(1.5E0, -1.0E0 / 0.0E0, 1.0E0 / 0.0E0)", DOUBLE, Double.POSITIVE_INFINITY); + assertFunction("greatest(1.5E0, -infinity(), infinity())", DOUBLE, Double.POSITIVE_INFINITY); assertFunction("greatest(5, 4, CAST(NULL as DOUBLE), 3)", DOUBLE, null); - - // float - assertFunction("greatest(REAL '1.5', 2.3E0)", DOUBLE, 2.3); - assertFunction("greatest(REAL '-1.5', -2.3E0)", DOUBLE, (double) -1.5f); - assertFunction("greatest(-1.5E0, REAL '-2.3', -5/3)", DOUBLE, -1.0); - assertFunction("greatest(REAL '1.5', REAL '-1.0' / 0.0E0, 1.0E0 / REAL '0.0')", DOUBLE, (double) (1.0f / 0.0f)); - assertFunction("greatest(5, REAL '4', CAST(NULL as DOUBLE), 3)", DOUBLE, null); + assertFunction("greatest(NaN(), 5, 4, 3)", DOUBLE, 5.0); + assertFunction("greatest(5, 4, NaN(), 3)", DOUBLE, 5.0); + assertFunction("greatest(5, 4, 3, NaN())", DOUBLE, 5.0); + assertFunction("greatest(NaN())", DOUBLE, Double.NaN); + assertFunction("greatest(NaN(), NaN(), NaN())", DOUBLE, Double.NaN); + + // real + assertFunction("greatest(REAL '1.5', REAL '2.3')", REAL, 2.3f); + assertFunction("greatest(REAL '-1.5', REAL '-2.3')", REAL, -1.5f); + assertFunction("greatest(REAL '-1.5', REAL '-2.3', CAST(-5/3 AS REAL))", REAL, -1.0f); + assertFunction("greatest(REAL '1.5', CAST(infinity() AS REAL))", REAL, Float.POSITIVE_INFINITY); + assertFunction("greatest(REAL '5', REAL '4', CAST(NULL as REAL), REAL '3')", REAL, null); + assertFunction("greatest(CAST(NaN() as REAL), REAL '5', REAL '4', REAL '3')", REAL, 5.0f); + assertFunction("greatest(REAL '5', REAL '4', CAST(NaN() as REAL), REAL '3')", REAL, 5.0f); + assertFunction("greatest(REAL '5', REAL '4', REAL '3', CAST(NaN() as REAL))", REAL, 5.0f); + assertFunction("greatest(CAST(NaN() as REAL))", REAL, Float.NaN); + assertFunction("greatest(CAST(NaN() as REAL), CAST(NaN() as REAL), CAST(NaN() as REAL))", REAL, Float.NaN); // decimal assertDecimalFunction("greatest(1.0, 2.0)", decimal("2.0")); @@ -1195,9 +1205,6 @@ public void testGreatest() assertFunction("greatest(1.0, 2.0E0)", DOUBLE, 2.0); assertDecimalFunction("greatest(5, 4, 3.0, 2)", decimal("0000000005.0")); - // NaN - assertFunction("greatest(1.5E0, 0.0E0 / 0.0E0)", DOUBLE, Double.NaN); - // argument count limit tryEvaluateWithAll("greatest(" + Joiner.on(", ").join(nCopies(127, "rand()")) + ")", DOUBLE); assertInvalidFunction( @@ -1241,15 +1248,25 @@ public void testLeast() assertFunction("least(1.5E0, 2.3E0)", DOUBLE, 1.5); assertFunction("least(-1.5E0, -2.3E0)", DOUBLE, -2.3); assertFunction("least(-1.5E0, -2.3E0, -5/3)", DOUBLE, -2.3); - assertFunction("least(1.5E0, -1.0E0 / 0.0E0, 1.0E0 / 0.0E0)", DOUBLE, Double.NEGATIVE_INFINITY); + assertFunction("least(1.5E0, -infinity(), infinity())", DOUBLE, Double.NEGATIVE_INFINITY); assertFunction("least(5, 4, CAST(NULL as DOUBLE), 3)", DOUBLE, null); - - // float - assertFunction("least(REAL '1.5', 2.3E0)", DOUBLE, (double) 1.5f); - assertFunction("least(REAL '-1.5', -2.3E0)", DOUBLE, -2.3); - assertFunction("least(-2.3E0, REAL '-0.4', -5/3)", DOUBLE, -2.3); - assertFunction("least(1.5E0, REAL '-1.0' / 0.0E0, 1.0E0 / 0.0E0)", DOUBLE, (double) (-1.0f / 0.0f)); - assertFunction("least(REAL '5', 4, CAST(NULL as DOUBLE), 3)", DOUBLE, null); + assertFunction("least(NaN(), 5, 4, 3)", DOUBLE, 3.0); + assertFunction("least(5, 4, NaN(), 3)", DOUBLE, 3.0); + assertFunction("least(5, 4, 3, NaN())", DOUBLE, 3.0); + assertFunction("least(NaN())", DOUBLE, Double.NaN); + assertFunction("least(NaN(), NaN(), NaN())", DOUBLE, Double.NaN); + + // real + assertFunction("least(REAL '1.5', REAL '2.3')", REAL, 1.5f); + assertFunction("least(REAL '-1.5', REAL '-2.3')", REAL, -2.3f); + assertFunction("least(REAL '-1.5', REAL '-2.3', CAST(-5/3 AS REAL))", REAL, -2.3f); + assertFunction("least(REAL '1.5', CAST(-infinity() AS REAL))", REAL, Float.NEGATIVE_INFINITY); + assertFunction("least(REAL '5', REAL '4', CAST(NULL as REAL), REAL '3')", REAL, null); + assertFunction("least(CAST(NaN() as REAL), REAL '5', REAL '4', REAL '3')", REAL, 3.0f); + assertFunction("least(REAL '5', REAL '4', CAST(NaN() as REAL), REAL '3')", REAL, 3.0f); + assertFunction("least(REAL '5', REAL '4', REAL '3', CAST(NaN() as REAL))", REAL, 3.0f); + assertFunction("least(CAST(NaN() as REAL))", REAL, Float.NaN); + assertFunction("least(CAST(NaN() as REAL), CAST(NaN() as REAL), CAST(NaN() as REAL))", REAL, Float.NaN); // decimal assertDecimalFunction("least(1.0, 2.0)", decimal("1.0")); @@ -1266,20 +1283,6 @@ public void testLeast() assertFunction("least(5.0E0, 4, CAST(NULL as BIGINT), 3)", DOUBLE, null); assertFunction("least(1.0, 2.0E0)", DOUBLE, 1.0); assertDecimalFunction("least(5, 4, 3.0, 2)", decimal("0000000002.0")); - - // NaN - assertFunction("least(1.5E0, 0.0E0 / 0.0E0)", DOUBLE, 1.5); - } - - @Test - public void testGreatestWithNaN() - { - assertFunction("greatest(1.5E0, 0.0E0 / 0.0E0)", DOUBLE, Double.NaN); - assertFunction("greatest(1.5E0, 0.0E0 / 0.0E0, 2.7E0)", DOUBLE, Double.NaN); - assertFunction("greatest(1.5E0, REAL '0.0' / REAL '0.0')", DOUBLE, Double.NaN); - assertFunction("greatest(1.5E0, REAL '0.0' / REAL '0.0', 2.7E0)", DOUBLE, Double.NaN); - assertFunction("greatest(null, REAL '0.0' / REAL '0.0')", REAL, null); - assertFunction("greatest(1.0E0 / 0.0E0, REAL '0.0' / REAL '0.0')", DOUBLE, Double.NaN); } @Test diff --git a/presto-main/src/test/java/io/prestosql/type/TestArrayOperators.java b/presto-main/src/test/java/io/prestosql/type/TestArrayOperators.java index f3493e118c08..74ceb7f15362 100644 --- a/presto-main/src/test/java/io/prestosql/type/TestArrayOperators.java +++ b/presto-main/src/test/java/io/prestosql/type/TestArrayOperators.java @@ -595,14 +595,25 @@ public void testArrayMin() assertFunction("ARRAY_MIN(ARRAY [])", UNKNOWN, null); assertFunction("ARRAY_MIN(ARRAY [NULL])", UNKNOWN, null); assertFunction("ARRAY_MIN(ARRAY [NaN()])", DOUBLE, NaN); + assertFunction("ARRAY_MIN(ARRAY [CAST(NaN() AS REAL)])", REAL, Float.NaN); assertFunction("ARRAY_MIN(ARRAY [NULL, NULL, NULL])", UNKNOWN, null); assertFunction("ARRAY_MIN(ARRAY [NaN(), NaN(), NaN()])", DOUBLE, NaN); + assertFunction("ARRAY_MIN(ARRAY [CAST(NaN() AS REAL), CAST(NaN() AS REAL)])", REAL, Float.NaN); assertFunction("ARRAY_MIN(ARRAY [NULL, 2, 3])", INTEGER, null); assertFunction("ARRAY_MIN(ARRAY [NaN(), 2, 3])", DOUBLE, 2.0); + assertFunction("ARRAY_MIN(ARRAY [2, NaN(), 3])", DOUBLE, 2.0); + assertFunction("ARRAY_MIN(ARRAY [2, 3, NaN()])", DOUBLE, 2.0); assertFunction("ARRAY_MIN(ARRAY [NULL, NaN(), 1])", DOUBLE, null); assertFunction("ARRAY_MIN(ARRAY [NaN(), NULL, 3.0])", DOUBLE, null); assertFunction("ARRAY_MIN(ARRAY [1.0E0, NULL, 3])", DOUBLE, null); assertFunction("ARRAY_MIN(ARRAY [1.0, NaN(), 3])", DOUBLE, 1.0); + assertFunction("ARRAY_MIN(ARRAY [CAST(NaN() AS REAL), REAL '2', REAL '3'])", REAL, 2.0f); + assertFunction("ARRAY_MIN(ARRAY [REAL '2', CAST(NaN() AS REAL), REAL '3'])", REAL, 2.0f); + assertFunction("ARRAY_MIN(ARRAY [REAL '2', REAL '3', CAST(NaN() AS REAL)])", REAL, 2.0f); + assertFunction("ARRAY_MIN(ARRAY [NULL, CAST(NaN() AS REAL), REAL '1'])", REAL, null); + assertFunction("ARRAY_MIN(ARRAY [CAST(NaN() AS REAL), NULL, REAL '3'])", REAL, null); + assertFunction("ARRAY_MIN(ARRAY [REAL '1', NULL, REAL '3'])", REAL, null); + assertFunction("ARRAY_MIN(ARRAY [REAL '1', CAST(NaN() AS REAL), REAL '3'])", REAL, 1.0f); assertFunction("ARRAY_MIN(ARRAY ['1', '2', NULL])", createVarcharType(1), null); assertFunction("ARRAY_MIN(ARRAY [3, 2, 1])", INTEGER, 1); assertFunction("ARRAY_MIN(ARRAY [1, 2, 3])", INTEGER, 1); @@ -628,14 +639,25 @@ public void testArrayMax() assertFunction("ARRAY_MAX(ARRAY [])", UNKNOWN, null); assertFunction("ARRAY_MAX(ARRAY [NULL])", UNKNOWN, null); assertFunction("ARRAY_MAX(ARRAY [NaN()])", DOUBLE, NaN); + assertFunction("ARRAY_MAX(ARRAY [CAST(NaN() AS REAL)])", REAL, Float.NaN); assertFunction("ARRAY_MAX(ARRAY [NULL, NULL, NULL])", UNKNOWN, null); assertFunction("ARRAY_MAX(ARRAY [NaN(), NaN(), NaN()])", DOUBLE, NaN); + assertFunction("ARRAY_MAX(ARRAY [CAST(NaN() AS REAL), CAST(NaN() AS REAL)])", REAL, Float.NaN); assertFunction("ARRAY_MAX(ARRAY [NULL, 2, 3])", INTEGER, null); - assertFunction("ARRAY_MAX(ARRAY [NaN(), 2, 3])", DOUBLE, NaN); + assertFunction("ARRAY_MAX(ARRAY [NaN(), 2, 3])", DOUBLE, 3.0); + assertFunction("ARRAY_MAX(ARRAY [2, NaN(), 3])", DOUBLE, 3.0); + assertFunction("ARRAY_MAX(ARRAY [2, 3, NaN()])", DOUBLE, 3.0); assertFunction("ARRAY_MAX(ARRAY [NULL, NaN(), 1])", DOUBLE, null); assertFunction("ARRAY_MAX(ARRAY [NaN(), NULL, 3.0])", DOUBLE, null); assertFunction("ARRAY_MAX(ARRAY [1.0E0, NULL, 3])", DOUBLE, null); - assertFunction("ARRAY_MAX(ARRAY [1.0, NaN(), 3])", DOUBLE, NaN); + assertFunction("ARRAY_MAX(ARRAY [1.0, NaN(), 3])", DOUBLE, 3.0); + assertFunction("ARRAY_MAX(ARRAY [CAST(NaN() AS REAL), REAL '2', REAL '3'])", REAL, 3.0f); + assertFunction("ARRAY_MAX(ARRAY [REAL '2', CAST(NaN() AS REAL), REAL '3'])", REAL, 3.0f); + assertFunction("ARRAY_MAX(ARRAY [REAL '2', REAL '3', CAST(NaN() AS REAL)])", REAL, 3.0f); + assertFunction("ARRAY_MAX(ARRAY [NULL, CAST(NaN() AS REAL), REAL '1'])", REAL, null); + assertFunction("ARRAY_MAX(ARRAY [CAST(NaN() AS REAL), NULL, REAL '3'])", REAL, null); + assertFunction("ARRAY_MAX(ARRAY [REAL '1', NULL, REAL '3'])", REAL, null); + assertFunction("ARRAY_MAX(ARRAY [REAL '1', CAST(NaN() AS REAL), REAL '3'])", REAL, 3.0f); assertFunction("ARRAY_MAX(ARRAY ['1', '2', NULL])", createVarcharType(1), null); assertFunction("ARRAY_MAX(ARRAY [3, 2, 1])", INTEGER, 3); assertFunction("ARRAY_MAX(ARRAY [1, 2, 3])", INTEGER, 3); diff --git a/presto-product-tests/src/main/java/io/prestosql/tests/hive/TestHiveTableStatistics.java b/presto-product-tests/src/main/java/io/prestosql/tests/hive/TestHiveTableStatistics.java index 3df137e2a88f..d5b08e91e52c 100644 --- a/presto-product-tests/src/main/java/io/prestosql/tests/hive/TestHiveTableStatistics.java +++ b/presto-product-tests/src/main/java/io/prestosql/tests/hive/TestHiveTableStatistics.java @@ -1389,7 +1389,7 @@ public void testComputeFloatingPointStatistics(String dataType) : row("c_minmax", null, 2., 0.33333333333, null, "-1234567.9", "576234.56"), row("c_inf", null, 2., 0.33333333333, null, null, null), // -15, +inf row("c_ninf", null, 2., 0.33333333333, null, null, null), // -inf, 45 - row("c_nan", null, 2., 0.33333333333, null, null, null), // 12345., NaN + row("c_nan", null, 2., 0.33333333333, null, "12345.0", "12345.0"), // NaN is ignored by min/max row("c_nzero", null, 2., 0.33333333333, null, "-47.0", "0.0"), row(null, null, null, null, 3., null, null)); diff --git a/presto-tests/src/test/java/io/prestosql/tests/AbstractTestEngineOnlyQueries.java b/presto-tests/src/test/java/io/prestosql/tests/AbstractTestEngineOnlyQueries.java index 0ca6f5b457af..89dd5440211a 100644 --- a/presto-tests/src/test/java/io/prestosql/tests/AbstractTestEngineOnlyQueries.java +++ b/presto-tests/src/test/java/io/prestosql/tests/AbstractTestEngineOnlyQueries.java @@ -1886,6 +1886,146 @@ public void testSpecialFloatingPointValues() assertEquals(row.getField(2), Double.NEGATIVE_INFINITY); } + @Test + public void testMinMaxFloatingPointNaN() + { + // double with NaN in first, middle, last, only + assertQuery( + "SELECT min(x), max(x) FROM (VALUES CAST(NaN() AS DOUBLE), DOUBLE '5.5', DOUBLE '3.3') t (x)", + "VALUES (CAST(3.3 AS DOUBLE), CAST(5.5 AS DOUBLE))"); + assertQuery( + "SELECT min(x), max(x) FROM (VALUES DOUBLE '5.5', CAST(NaN() AS DOUBLE), DOUBLE '3.3') t (x)", + "VALUES (CAST(3.3 AS DOUBLE), CAST(5.5 AS DOUBLE))"); + assertQuery( + "SELECT min(x), max(x) FROM (VALUES DOUBLE '5.5', DOUBLE '3.3', CAST(NaN() AS DOUBLE)) t (x)", + "VALUES (CAST(3.3 AS DOUBLE), CAST(5.5 AS DOUBLE))"); + assertQuery( + "SELECT min(x), max(x) FROM (VALUES CAST(NaN() AS DOUBLE)) t (x)", + "VALUES (CAST(sqrt(-1) AS DOUBLE), CAST(sqrt(-1) AS DOUBLE))"); + + // real with NaN in first, middle, last, only + assertQuery( + "SELECT min(x), max(x) FROM (VALUES CAST(NaN() AS REAL), REAL '5.5', REAL '3.3') t (x)", + "VALUES (CAST(3.3 AS REAL), CAST(5.5 AS REAL))"); + assertQuery( + "SELECT min(x), max(x) FROM (VALUES REAL '5.5', CAST(NaN() AS REAL), REAL '3.3') t (x)", + "VALUES (CAST(3.3 AS REAL), CAST(5.5 AS REAL))"); + assertQuery( + "SELECT min(x), max(x) FROM (VALUES REAL '5.5', REAL '3.3', CAST(NaN() AS REAL)) t (x)", + "VALUES (CAST(3.3 AS REAL), CAST(5.5 AS REAL))"); + assertQuery( + "SELECT min(x), max(x) FROM (VALUES CAST(NaN() AS REAL)) t (x)", + "VALUES (CAST(sqrt(-1) AS REAL), CAST(sqrt(-1) AS REAL))"); + } + + @Test + public void testMinMaxNFloatingPointNaN() + { + // double with NaN in first, middle, last, only + assertQuery( + "SELECT min(x, 2), max(x, 2) FROM (VALUES " + + "CAST(NaN() AS DOUBLE), DOUBLE '5.5', DOUBLE '3.3', DOUBLE '4.4') t (x)", + "VALUES (" + + "ARRAY[CAST(3.3 AS DOUBLE), CAST(4.4 AS DOUBLE)], " + + "ARRAY[CAST(5.5 AS DOUBLE), CAST(4.4 AS DOUBLE)])"); + assertQuery( + "SELECT min(x, 2), max(x, 2) FROM (VALUES " + + "DOUBLE '5.5', CAST(NaN() AS DOUBLE), DOUBLE '3.3', DOUBLE '4.4') t (x)", + "VALUES (" + + "ARRAY[CAST(3.3 AS DOUBLE), CAST(4.4 AS DOUBLE)], " + + "ARRAY[CAST(5.5 AS DOUBLE), CAST(4.4 AS DOUBLE)])"); + assertQuery( + "SELECT min(x, 2), max(x, 2) FROM (VALUES " + + "DOUBLE '5.5', DOUBLE '3.3', DOUBLE '4.4', CAST(NaN() AS DOUBLE)) t (x)", + "VALUES (" + + "ARRAY[CAST(3.3 AS DOUBLE), CAST(4.4 AS DOUBLE)], " + + "ARRAY[CAST(5.5 AS DOUBLE), CAST(4.4 AS DOUBLE)])"); + assertQuery( + "SELECT min(x, 2), max(x, 2) FROM (VALUES " + + "DOUBLE '8.8', CAST(NaN() AS DOUBLE)) t (x)", + "VALUES (" + + "ARRAY[CAST(8.8 AS DOUBLE), CAST(sqrt(-1) AS DOUBLE)], " + + "ARRAY[CAST(8.8 AS DOUBLE), CAST(sqrt(-1) AS DOUBLE)])"); + + // real with NaN in first, middle, last, only + assertQuery( + "SELECT min(x, 2), max(x, 2) FROM (VALUES " + + "CAST(NaN() AS REAL), REAL '5.5', REAL '3.3', REAL '4.4') t (x)", + "VALUES (" + + "ARRAY[CAST(3.3 AS REAL), CAST(4.4 AS REAL)], " + + "ARRAY[CAST(5.5 AS REAL), CAST(4.4 AS REAL)])"); + assertQuery( + "SELECT min(x, 2), max(x, 2) FROM (VALUES " + + "REAL '5.5', CAST(NaN() AS REAL), REAL '3.3', REAL '4.4') t (x)", + "VALUES (" + + "ARRAY[CAST(3.3 AS REAL), CAST(4.4 AS REAL)], " + + "ARRAY[CAST(5.5 AS REAL), CAST(4.4 AS REAL)])"); + assertQuery( + "SELECT min(x, 2), max(x, 2) FROM (VALUES " + + "REAL '5.5', REAL '3.3', REAL '4.4', CAST(NaN() AS REAL)) t (x)", + "VALUES (" + + "ARRAY[CAST(3.3 AS REAL), CAST(4.4 AS REAL)], " + + "ARRAY[CAST(5.5 AS REAL), CAST(4.4 AS REAL)])"); + assertQuery( + "SELECT min(x, 2), max(x, 2) FROM (VALUES " + + "REAL '8.8', CAST(NaN() AS REAL)) t (x)", + "VALUES (" + + "ARRAY[CAST(8.8 AS REAL), CAST(sqrt(-1) AS REAL)], " + + "ARRAY[CAST(8.8 AS REAL), CAST(sqrt(-1) AS REAL)])"); + } + + @Test + public void testMinMaxByFloatingPointNaN() + { + // double with NaN in first, middle, last, only + assertQuery( + "SELECT min_by(x, y), max_by(x, y) FROM (VALUES" + + "('a', CAST(NaN() AS DOUBLE)), " + + "('b', DOUBLE '5.5'), " + + "('c', DOUBLE '3.3')) t (x, y)", + "VALUES ('c', 'b')"); + assertQuery( + "SELECT min_by(x, y), max_by(x, y) FROM (VALUES" + + "('a', DOUBLE '5.5'), " + + "('b', CAST(NaN() AS DOUBLE)), " + + "('c', DOUBLE '3.3')) t (x, y)", + "VALUES ('c', 'a')"); + assertQuery( + "SELECT min_by(x, y), max_by(x, y) FROM (VALUES" + + "('a', DOUBLE '5.5'), " + + "('b', DOUBLE '3.3'), " + + "('c', CAST(NaN() AS DOUBLE))) t (x, y)", + "VALUES ('b', 'a')"); + assertQuery( + "SELECT min_by(x, y), max_by(x, y) FROM (VALUES" + + "('a', CAST(NaN() AS DOUBLE))) t (x, y)", + "VALUES ('a', 'a')"); + + // real with NaN in first, middle, last, only + assertQuery( + "SELECT min_by(x, y), max_by(x, y) FROM (VALUES" + + "('a', CAST(NaN() AS REAL)), " + + "('b', REAL '5.5'), " + + "('c', REAL '3.3')) t (x, y)", + "VALUES ('c', 'b')"); + assertQuery( + "SELECT min_by(x, y), max_by(x, y) FROM (VALUES" + + "('a', REAL '5.5'), " + + "('b', CAST(NaN() AS REAL)), " + + "('c', REAL '3.3')) t (x, y)", + "VALUES ('c', 'a')"); + assertQuery( + "SELECT min_by(x, y), max_by(x, y) FROM (VALUES" + + "('a', REAL '5.5'), " + + "('b', REAL '3.3'), " + + "('c', CAST(NaN() AS REAL))) t (x, y)", + "VALUES ('b', 'a')"); + assertQuery( + "SELECT min_by(x, y), max_by(x, y) FROM (VALUES" + + "('a', CAST(NaN() AS REAL))) t (x, y)", + "VALUES ('a', 'a')"); + } + @Test public void testOutputInEnforceSingleRow() {