diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/ApproximateDoublePercentileAggregations.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/ApproximateDoublePercentileAggregations.java index b05347961b05..85d188539d26 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/ApproximateDoublePercentileAggregations.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/ApproximateDoublePercentileAggregations.java @@ -25,6 +25,7 @@ import io.trino.spi.type.StandardTypes; import static com.google.common.base.Preconditions.checkState; +import static io.trino.operator.scalar.TDigestFunctions.verifyValue; import static io.trino.operator.scalar.TDigestFunctions.verifyWeight; import static io.trino.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT; import static io.trino.spi.type.DoubleType.DOUBLE; @@ -38,6 +39,7 @@ private ApproximateDoublePercentileAggregations() {} @InputFunction public static void input(@AggregationState TDigestAndPercentileState state, @SqlType(StandardTypes.DOUBLE) double value, @SqlType(StandardTypes.DOUBLE) double percentile) { + verifyValue(value); TDigest digest = state.getDigest(); if (digest == null) { @@ -57,6 +59,7 @@ public static void input(@AggregationState TDigestAndPercentileState state, @Sql @InputFunction public static void weightedInput(@AggregationState TDigestAndPercentileState state, @SqlType(StandardTypes.DOUBLE) double value, @SqlType(StandardTypes.DOUBLE) double weight, @SqlType(StandardTypes.DOUBLE) double percentile) { + verifyValue(value); verifyWeight(weight); TDigest digest = state.getDigest(); diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/ApproximateDoublePercentileArrayAggregations.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/ApproximateDoublePercentileArrayAggregations.java index b9384fef2f5d..af3beeed4712 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/ApproximateDoublePercentileArrayAggregations.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/ApproximateDoublePercentileArrayAggregations.java @@ -30,6 +30,7 @@ import java.util.List; +import static io.trino.operator.scalar.TDigestFunctions.verifyValue; import static io.trino.operator.scalar.TDigestFunctions.verifyWeight; import static io.trino.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT; import static io.trino.spi.type.DoubleType.DOUBLE; @@ -43,6 +44,8 @@ private ApproximateDoublePercentileArrayAggregations() {} @InputFunction public static void input(@AggregationState TDigestAndPercentileArrayState state, @SqlType(StandardTypes.DOUBLE) double value, @SqlType("array(double)") Block percentilesArrayBlock) { + verifyValue(value); + initializePercentilesArray(state, percentilesArrayBlock); initializeDigest(state); @@ -55,6 +58,7 @@ public static void input(@AggregationState TDigestAndPercentileArrayState state, @InputFunction public static void weightedInput(@AggregationState TDigestAndPercentileArrayState state, @SqlType(StandardTypes.DOUBLE) double value, @SqlType(StandardTypes.DOUBLE) double weight, @SqlType("array(double)") Block percentilesArrayBlock) { + verifyValue(value); verifyWeight(weight); initializePercentilesArray(state, percentilesArrayBlock); diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/ApproximateLongPercentileAggregations.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/ApproximateLongPercentileAggregations.java index 8a44bb37e499..f42ece1598b1 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/ApproximateLongPercentileAggregations.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/ApproximateLongPercentileAggregations.java @@ -70,7 +70,7 @@ public static void output(@AggregationState TDigestAndPercentileState state, Blo public static double toDoubleExact(long value) { double doubleValue = (double) value; - checkCondition((long) doubleValue == value, INVALID_FUNCTION_ARGUMENT, "no exact double representation for long: %s", value); + checkCondition((long) doubleValue == value, INVALID_FUNCTION_ARGUMENT, () -> String.format("no exact double representation for long: %s", value)); return doubleValue; } } diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/LegacyApproximateLongPercentileAggregations.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/LegacyApproximateLongPercentileAggregations.java index a70aa520efaf..bab5529093df 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/LegacyApproximateLongPercentileAggregations.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/LegacyApproximateLongPercentileAggregations.java @@ -15,6 +15,7 @@ import io.airlift.stats.QuantileDigest; import io.trino.operator.aggregation.state.QuantileDigestAndPercentileState; +import io.trino.spi.TrinoException; import io.trino.spi.block.BlockBuilder; import io.trino.spi.function.AggregationFunction; import io.trino.spi.function.AggregationState; @@ -50,7 +51,7 @@ public static void weightedInput(@AggregationState QuantileDigestAndPercentileSt digest = new QuantileDigest(accuracy); } else { - throw new IllegalArgumentException("Percentile accuracy must be strictly between 0 and 1"); + throw new TrinoException(INVALID_FUNCTION_ARGUMENT, "Percentile accuracy must be strictly between 0 and 1"); } state.setDigest(digest); state.addMemoryUsage(digest.estimatedInMemorySizeInBytes()); diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayCombinationsFunction.java b/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayCombinationsFunction.java index 689447160644..6ea939faabe5 100644 --- a/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayCombinationsFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayCombinationsFunction.java @@ -56,8 +56,8 @@ public static Block combinations( { int arrayLength = array.getPositionCount(); int combinationLength = toIntExact(n); - checkCondition(combinationLength >= 0, INVALID_FUNCTION_ARGUMENT, "combination size must not be negative: %s", combinationLength); - checkCondition(combinationLength <= MAX_COMBINATION_LENGTH, INVALID_FUNCTION_ARGUMENT, "combination size must not exceed %s: %s", MAX_COMBINATION_LENGTH, combinationLength); + checkCondition(combinationLength >= 0, INVALID_FUNCTION_ARGUMENT, () -> String.format("combination size must not be negative: %s", combinationLength)); + checkCondition(combinationLength <= MAX_COMBINATION_LENGTH, INVALID_FUNCTION_ARGUMENT, () -> String.format("combination size must not exceed %s: %s", MAX_COMBINATION_LENGTH, combinationLength)); ArrayType arrayType = new ArrayType(elementType); if (combinationLength > arrayLength) { diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayTrimFunction.java b/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayTrimFunction.java index 9b253f1ba8be..d7b5a6436239 100644 --- a/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayTrimFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayTrimFunction.java @@ -38,8 +38,8 @@ public static Block trim( @SqlType("array(E)") Block array, @SqlType(StandardTypes.BIGINT) long size) { - checkCondition(size >= 0, INVALID_FUNCTION_ARGUMENT, "size must not be negative: %s", size); - checkCondition(size <= array.getPositionCount(), INVALID_FUNCTION_ARGUMENT, "size must not exceed array cardinality %s: %s", array.getPositionCount(), size); + checkCondition(size >= 0, INVALID_FUNCTION_ARGUMENT, () -> String.format("size must not be negative: %s", size)); + checkCondition(size <= array.getPositionCount(), INVALID_FUNCTION_ARGUMENT, () -> String.format("size must not exceed array cardinality %s: %s", array.getPositionCount(), size)); return array.getRegion(0, toIntExact(array.getPositionCount() - size)); } diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/QuantileDigestFunctions.java b/core/trino-main/src/main/java/io/trino/operator/scalar/QuantileDigestFunctions.java index ba5e60e94640..fe30ff40897a 100644 --- a/core/trino-main/src/main/java/io/trino/operator/scalar/QuantileDigestFunctions.java +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/QuantileDigestFunctions.java @@ -141,13 +141,13 @@ public static Block valuesAtQuantilesBigint(@SqlType("qdigest(bigint)") Slice in public static double verifyAccuracy(double accuracy) { - checkCondition(accuracy > 0 && accuracy < 1, INVALID_FUNCTION_ARGUMENT, "Percentile accuracy must be exclusively between 0 and 1, was %s", accuracy); + checkCondition(accuracy > 0 && accuracy < 1, INVALID_FUNCTION_ARGUMENT, () -> String.format("Percentile accuracy must be exclusively between 0 and 1, was %s", accuracy)); return accuracy; } public static long verifyWeight(long weight) { - checkCondition(weight > 0, INVALID_FUNCTION_ARGUMENT, "Percentile weight must be > 0, was %s", weight); + checkCondition(weight > 0, INVALID_FUNCTION_ARGUMENT, () -> String.format("Percentile weight must be > 0, was %s", weight)); return weight; } } diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/TDigestFunctions.java b/core/trino-main/src/main/java/io/trino/operator/scalar/TDigestFunctions.java index b07ced731c52..fec032c252c3 100644 --- a/core/trino-main/src/main/java/io/trino/operator/scalar/TDigestFunctions.java +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/TDigestFunctions.java @@ -59,9 +59,15 @@ public static Block valuesAtQuantiles(@SqlType(StandardTypes.TDIGEST) TDigest in return output.build(); } + public static void verifyValue(double value) + { + checkCondition(Double.isFinite(value), INVALID_FUNCTION_ARGUMENT, () -> String.format("value must be finite; was %s", value)); + } + public static double verifyWeight(double weight) { - checkCondition(weight >= 1, INVALID_FUNCTION_ARGUMENT, "weight must be >= 1, was %s", weight); + checkCondition(Double.isFinite(weight), INVALID_FUNCTION_ARGUMENT, () -> String.format("weight must be finite, was %s", weight)); + checkCondition(weight >= 1, INVALID_FUNCTION_ARGUMENT, () -> String.format("weight must be >= 1, was %s", weight)); return weight; } } diff --git a/core/trino-main/src/main/java/io/trino/type/DecimalCasts.java b/core/trino-main/src/main/java/io/trino/type/DecimalCasts.java index 9ef4358c4fc6..b111c7704b42 100644 --- a/core/trino-main/src/main/java/io/trino/type/DecimalCasts.java +++ b/core/trino-main/src/main/java/io/trino/type/DecimalCasts.java @@ -573,7 +573,7 @@ public static Int128 jsonToLongDecimal(Slice json, long precision, long scale, I try (JsonParser parser = createJsonParser(JSON_FACTORY, json)) { parser.nextToken(); Int128 result = currentTokenAsLongDecimal(parser, intPrecision(precision), DecimalConversions.intScale(scale)); - checkCondition(parser.nextToken() == null, INVALID_CAST_ARGUMENT, "Cannot cast input json to DECIMAL(%s,%s)", precision, scale); // check no trailing token + checkCondition(parser.nextToken() == null, INVALID_CAST_ARGUMENT, () -> String.format("Cannot cast input json to DECIMAL(%s,%s)", precision, scale)); // check no trailing token return result; } catch (IOException | NumberFormatException | JsonCastException e) { @@ -587,7 +587,7 @@ public static Long jsonToShortDecimal(Slice json, long precision, long scale, lo try (JsonParser parser = createJsonParser(JSON_FACTORY, json)) { parser.nextToken(); Long result = currentTokenAsShortDecimal(parser, intPrecision(precision), DecimalConversions.intScale(scale)); - checkCondition(parser.nextToken() == null, INVALID_CAST_ARGUMENT, "Cannot cast input json to DECIMAL(%s,%s)", precision, scale); // check no trailing token + checkCondition(parser.nextToken() == null, INVALID_CAST_ARGUMENT, () -> String.format("Cannot cast input json to DECIMAL(%s,%s)", precision, scale)); // check no trailing token return result; } catch (IOException | NumberFormatException | JsonCastException e) { diff --git a/core/trino-main/src/main/java/io/trino/util/Failures.java b/core/trino-main/src/main/java/io/trino/util/Failures.java index 6b17e98c7562..477b019f765b 100644 --- a/core/trino-main/src/main/java/io/trino/util/Failures.java +++ b/core/trino-main/src/main/java/io/trino/util/Failures.java @@ -13,6 +13,7 @@ */ package io.trino.util; +import com.google.common.base.Supplier; import com.google.common.collect.ImmutableList; import com.google.errorprone.annotations.FormatMethod; import io.trino.client.ErrorLocation; @@ -60,9 +61,14 @@ public static ExecutionFailureInfo toFailure(Throwable failure) @FormatMethod public static void checkCondition(boolean condition, ErrorCodeSupplier errorCode, String formatString, Object... args) + { + checkCondition(condition, errorCode, () -> format(formatString, args)); + } + + public static void checkCondition(boolean condition, ErrorCodeSupplier errorCode, Supplier errorMessage) { if (!condition) { - throw new TrinoException(errorCode, format(formatString, args)); + throw new TrinoException(errorCode, errorMessage.get()); } } diff --git a/core/trino-main/src/test/java/io/trino/operator/aggregation/AggregationTestUtils.java b/core/trino-main/src/test/java/io/trino/operator/aggregation/AggregationTestUtils.java index c70805dd551d..07c4f419f652 100644 --- a/core/trino-main/src/test/java/io/trino/operator/aggregation/AggregationTestUtils.java +++ b/core/trino-main/src/test/java/io/trino/operator/aggregation/AggregationTestUtils.java @@ -24,6 +24,7 @@ import io.trino.spi.type.BooleanType; import io.trino.spi.type.Type; import io.trino.sql.analyzer.TypeSignatureProvider; +import io.trino.testing.assertions.TrinoExceptionAssert; import org.apache.commons.math3.util.Precision; import java.util.Arrays; @@ -38,6 +39,7 @@ import static io.trino.sql.planner.plan.AggregationNode.Step.FINAL; import static io.trino.sql.planner.plan.AggregationNode.Step.PARTIAL; import static io.trino.sql.planner.plan.AggregationNode.Step.SINGLE; +import static io.trino.testing.assertions.TrinoExceptionAssert.assertTrinoExceptionThrownBy; import static java.lang.String.format; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Fail.fail; @@ -58,6 +60,11 @@ public static void assertAggregation(TestingFunctionResolution functionResolutio assertAggregation(functionResolution, name, parameterTypes, equalAssertion, null, page, expectedValue); } + public static TrinoExceptionAssert assertAggregationFails(TestingFunctionResolution functionResolution, String name, List parameterTypes, Block... blocks) + { + return assertTrinoExceptionThrownBy(() -> assertAggregation(functionResolution, name, parameterTypes, null, blocks)); + } + public static BiFunction makeValidityAssertion(Object expectedValue) { if (expectedValue instanceof Double && !expectedValue.equals(Double.NaN)) { diff --git a/core/trino-main/src/test/java/io/trino/operator/aggregation/TestApproximatePercentileAggregation.java b/core/trino-main/src/test/java/io/trino/operator/aggregation/TestApproximatePercentileAggregation.java index d4baf21a14e4..38553ba1593f 100644 --- a/core/trino-main/src/test/java/io/trino/operator/aggregation/TestApproximatePercentileAggregation.java +++ b/core/trino-main/src/test/java/io/trino/operator/aggregation/TestApproximatePercentileAggregation.java @@ -15,6 +15,7 @@ import com.google.common.collect.ImmutableList; import io.trino.metadata.TestingFunctionResolution; +import io.trino.spi.TrinoException; import io.trino.spi.block.ArrayBlockBuilder; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; @@ -33,6 +34,8 @@ import static io.trino.block.BlockAssertions.createLongsBlock; import static io.trino.block.BlockAssertions.createSequenceBlockOfReal; import static io.trino.operator.aggregation.AggregationTestUtils.assertAggregation; +import static io.trino.operator.aggregation.AggregationTestUtils.assertAggregationFails; +import static io.trino.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT; import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.DoubleType.DOUBLE; import static io.trino.spi.type.RealType.REAL; @@ -442,6 +445,72 @@ public void testFloatPartialStep() createBlockOfReals(1.0f, 2.0f, 3.0f), createDoublesBlock(4.0, 2.0, 1.0), createRleBlock(ImmutableList.of(0.5, 0.8), 3)); + + // invalid inputs + for (Float invalidValue : List.of(Float.NaN, Float.NEGATIVE_INFINITY, Float.POSITIVE_INFINITY)) { + assertAggregationFails(FUNCTION_RESOLUTION, + "approx_percentile", + FLOAT_APPROXIMATE_PERCENTILE, + createBlockOfReals(invalidValue), + createRleBlock(0.5, 1)) + .isInstanceOf(TrinoException.class) + .hasErrorCode(INVALID_FUNCTION_ARGUMENT); + + assertAggregationFails(FUNCTION_RESOLUTION, + "approx_percentile", + FLOAT_APPROXIMATE_PERCENTILE_ARRAY, + createBlockOfReals(invalidValue), + createRleBlock(ImmutableList.of(0.5, 0.75), 1)) + .isInstanceOf(TrinoException.class) + .hasErrorCode(INVALID_FUNCTION_ARGUMENT); + + assertAggregationFails(FUNCTION_RESOLUTION, + "approx_percentile", + FLOAT_APPROXIMATE_PERCENTILE_WEIGHTED, + createBlockOfReals(1.0f), + createDoublesBlock((double) invalidValue), + createDoublesBlock(0.5)) + .isInstanceOf(TrinoException.class) + .hasErrorCode(INVALID_FUNCTION_ARGUMENT); + + assertAggregationFails(FUNCTION_RESOLUTION, + "approx_percentile", + FLOAT_APPROXIMATE_PERCENTILE_WEIGHTED, + createBlockOfReals(invalidValue), + createDoublesBlock(1.0), + createRleBlock(0.5, 1)) + .isInstanceOf(TrinoException.class) + .hasErrorCode(INVALID_FUNCTION_ARGUMENT); + + assertAggregationFails(FUNCTION_RESOLUTION, + "approx_percentile", + FLOAT_APPROXIMATE_PERCENTILE_ARRAY_WEIGHTED, + createBlockOfReals(invalidValue), + createDoublesBlock(1.0), + createRleBlock(ImmutableList.of(0.5, 0.75), 1)) + .isInstanceOf(TrinoException.class) + .hasErrorCode(INVALID_FUNCTION_ARGUMENT); + + assertAggregationFails(FUNCTION_RESOLUTION, + "approx_percentile", + FLOAT_APPROXIMATE_PERCENTILE_ARRAY_WEIGHTED, + createBlockOfReals(1.0f), + createDoublesBlock((double) invalidValue), + createRleBlock(ImmutableList.of(0.5, 0.75), 1)) + .isInstanceOf(TrinoException.class) + .hasErrorCode(INVALID_FUNCTION_ARGUMENT); + + // for deprecated approx_percentile with accuracy we only validate accuracy for backward compatibility + assertAggregationFails( + FUNCTION_RESOLUTION, + "approx_percentile", + FLOAT_APPROXIMATE_PERCENTILE_WEIGHTED_WITH_ACCURACY, + createBlockOfReals(1.0f), + createDoublesBlock(1.0), + createDoublesBlock(0.99), + createDoublesBlock((double) invalidValue)) + .hasErrorCode(INVALID_FUNCTION_ARGUMENT); + } } @Test @@ -620,6 +689,72 @@ public void testDoublePartialStep() createDoublesBlock(1.0, 2.0, 3.0), createDoublesBlock(4.0, 2.0, 1.0), createRleBlock(ImmutableList.of(0.5, 0.8), 3)); + + // invalid inputs + for (Double invalidValue : List.of(Double.NaN, Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY)) { + assertAggregationFails(FUNCTION_RESOLUTION, + "approx_percentile", + DOUBLE_APPROXIMATE_PERCENTILE, + createDoublesBlock(invalidValue), + createDoublesBlock(0.5)) + .isInstanceOf(TrinoException.class) + .hasErrorCode(INVALID_FUNCTION_ARGUMENT); + + assertAggregationFails(FUNCTION_RESOLUTION, + "approx_percentile", + DOUBLE_APPROXIMATE_PERCENTILE_ARRAY, + createDoublesBlock(invalidValue), + createRleBlock(ImmutableList.of(0.5, 0.75), 1)) + .isInstanceOf(TrinoException.class) + .hasErrorCode(INVALID_FUNCTION_ARGUMENT); + + assertAggregationFails(FUNCTION_RESOLUTION, + "approx_percentile", + DOUBLE_APPROXIMATE_PERCENTILE_WEIGHTED, + createDoublesBlock(invalidValue), + createDoublesBlock(1.0), + createRleBlock(0.5, 1)) + .isInstanceOf(TrinoException.class) + .hasErrorCode(INVALID_FUNCTION_ARGUMENT); + + assertAggregationFails(FUNCTION_RESOLUTION, + "approx_percentile", + DOUBLE_APPROXIMATE_PERCENTILE_WEIGHTED, + createDoublesBlock(1.0), + createDoublesBlock(invalidValue), + createRleBlock(0.5, 1)) + .isInstanceOf(TrinoException.class) + .hasErrorCode(INVALID_FUNCTION_ARGUMENT); + + assertAggregationFails(FUNCTION_RESOLUTION, + "approx_percentile", + DOUBLE_APPROXIMATE_PERCENTILE_ARRAY_WEIGHTED, + createDoublesBlock(1.0), + createDoublesBlock(invalidValue), + createRleBlock(ImmutableList.of(0.5, 0.75), 1)) + .isInstanceOf(TrinoException.class) + .hasErrorCode(INVALID_FUNCTION_ARGUMENT); + + assertAggregationFails(FUNCTION_RESOLUTION, + "approx_percentile", + DOUBLE_APPROXIMATE_PERCENTILE_ARRAY_WEIGHTED, + createDoublesBlock(invalidValue), + createDoublesBlock(1.0), + createRleBlock(ImmutableList.of(0.5, 0.75), 1)) + .isInstanceOf(TrinoException.class) + .hasErrorCode(INVALID_FUNCTION_ARGUMENT); + + // for deprecated approx_percentile with accuracy we only validate accuracy for backward compatibility + assertAggregationFails( + FUNCTION_RESOLUTION, + "approx_percentile", + DOUBLE_APPROXIMATE_PERCENTILE_WEIGHTED_WITH_ACCURACY, + createDoublesBlock(1.0), + createDoublesBlock(1.0), + createDoublesBlock(0.99), + createDoublesBlock(invalidValue)) + .hasErrorCode(INVALID_FUNCTION_ARGUMENT); + } } private static Block createRleBlock(double percentile, int positionCount)