From 2195285991e122274e8afa1f7f493bc5d14a5526 Mon Sep 17 00:00:00 2001 From: lukasz-stec Date: Tue, 20 Jun 2023 10:30:26 +0200 Subject: [PATCH 1/2] Add returnsZeroOnEmptyInput to AggregationFunctionMetadata COUNT, COUNT_IF, and APPROX_DISTINCT have non-standard behavior when no input is supplied to the function. They return the value 0, as opposed to the standard NULL. Before this change, OptimizeMixedDistinctAggregations relied on matching the function name as a string to handle this case, but this is brittle and may cause silent correctness issues if a new function with this characteristic is added. --- .../io/trino/metadata/MetadataManager.java | 4 ++++ .../AggregationFromAnnotationsParser.java | 1 + .../aggregation/AggregationHeader.java | 10 +++++++++- .../ApproximateCountDistinctAggregation.java | 2 +- ...anApproximateCountDistinctAggregation.java | 2 +- ...ltApproximateCountDistinctAggregation.java | 2 +- .../aggregation/CountAggregation.java | 2 +- .../operator/aggregation/CountColumn.java | 2 +- .../aggregation/CountIfAggregation.java | 2 +- ...ltApproximateCountDistinctAggregation.java | 2 +- .../aggregation/ParametricAggregation.java | 3 +++ .../OptimizeMixedDistinctAggregations.java | 6 ++++-- .../spi/function/AggregationFunction.java | 7 +++++++ .../function/AggregationFunctionMetadata.java | 20 +++++++++++++++++-- 14 files changed, 53 insertions(+), 12 deletions(-) diff --git a/core/trino-main/src/main/java/io/trino/metadata/MetadataManager.java b/core/trino-main/src/main/java/io/trino/metadata/MetadataManager.java index 754653245323..053754bfc013 100644 --- a/core/trino-main/src/main/java/io/trino/metadata/MetadataManager.java +++ b/core/trino-main/src/main/java/io/trino/metadata/MetadataManager.java @@ -2416,6 +2416,10 @@ public AggregationFunctionMetadata getAggregationFunctionMetadata(Session sessio builder.orderSensitive(); } + if (aggregationFunctionMetadata.returnsZeroOnEmptyInput()) { + builder.returnsZeroOnEmptyInput(); + } + if
(!aggregationFunctionMetadata.getIntermediateTypes().isEmpty()) { FunctionBinding functionBinding = toFunctionBinding(resolvedFunction.getFunctionId(), resolvedFunction.getSignature(), functionSignature); aggregationFunctionMetadata.getIntermediateTypes().stream() diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/AggregationFromAnnotationsParser.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/AggregationFromAnnotationsParser.java index 37c747c99071..038dd9ceae46 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/AggregationFromAnnotationsParser.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/AggregationFromAnnotationsParser.java @@ -187,6 +187,7 @@ private static AggregationHeader parseHeader(AnnotatedElement aggregationDefinit parseDescription(aggregationDefinition, outputFunction), aggregationAnnotation.decomposable(), aggregationAnnotation.isOrderSensitive(), + aggregationAnnotation.returnsZeroOnEmptyInput(), aggregationAnnotation.hidden(), aggregationDefinition.getAnnotationsByType(Deprecated.class).length > 0); } diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/AggregationHeader.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/AggregationHeader.java index 1168fe061192..a4a1f579ce50 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/AggregationHeader.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/AggregationHeader.java @@ -24,15 +24,17 @@ public class AggregationHeader private final Optional description; private final boolean decomposable; private final boolean orderSensitive; + private final boolean returnsZeroOnEmptyInput; private final boolean hidden; private final boolean deprecated; - public AggregationHeader(String name, Optional description, boolean decomposable, boolean orderSensitive, boolean hidden, boolean deprecated) + public AggregationHeader(String name, Optional description, 
boolean decomposable, boolean orderSensitive, boolean returnsZeroOnEmptyInput, boolean hidden, boolean deprecated) { this.name = requireNonNull(name, "name cannot be null"); this.description = requireNonNull(description, "description cannot be null"); this.decomposable = decomposable; this.orderSensitive = orderSensitive; + this.returnsZeroOnEmptyInput = returnsZeroOnEmptyInput; this.hidden = hidden; this.deprecated = deprecated; } @@ -57,6 +59,11 @@ public boolean isOrderSensitive() return orderSensitive; } + public boolean returnsZeroOnEmptyInput() + { + return returnsZeroOnEmptyInput; + } + public boolean isHidden() { return hidden; @@ -75,6 +82,7 @@ public String toString() .add("description", description) .add("decomposable", decomposable) .add("orderSensitive", orderSensitive) + .add("returnsZeroOnEmptyInput", returnsZeroOnEmptyInput) .add("hidden", hidden) .toString(); } diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/ApproximateCountDistinctAggregation.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/ApproximateCountDistinctAggregation.java index e28492becf3f..85df5fb579c5 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/ApproximateCountDistinctAggregation.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/ApproximateCountDistinctAggregation.java @@ -41,7 +41,7 @@ import static io.trino.util.Failures.checkCondition; import static io.trino.util.Failures.internalError; -@AggregationFunction("approx_distinct") +@AggregationFunction(value = "approx_distinct", returnsZeroOnEmptyInput = true) public final class ApproximateCountDistinctAggregation { private static final double LOWEST_MAX_STANDARD_ERROR = 0.0040625; diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/BooleanApproximateCountDistinctAggregation.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/BooleanApproximateCountDistinctAggregation.java index 1171a14ea9fa..ad1b37d954b1 
100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/BooleanApproximateCountDistinctAggregation.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/BooleanApproximateCountDistinctAggregation.java @@ -24,7 +24,7 @@ import static io.trino.spi.type.BigintType.BIGINT; -@AggregationFunction("approx_distinct") +@AggregationFunction(value = "approx_distinct", returnsZeroOnEmptyInput = true) public final class BooleanApproximateCountDistinctAggregation { private BooleanApproximateCountDistinctAggregation() {} diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/BooleanDefaultApproximateCountDistinctAggregation.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/BooleanDefaultApproximateCountDistinctAggregation.java index 4e414e328fe5..99c8a3cdf327 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/BooleanDefaultApproximateCountDistinctAggregation.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/BooleanDefaultApproximateCountDistinctAggregation.java @@ -22,7 +22,7 @@ import io.trino.spi.function.SqlType; import io.trino.spi.type.StandardTypes; -@AggregationFunction("approx_distinct") +@AggregationFunction(value = "approx_distinct", returnsZeroOnEmptyInput = true) public final class BooleanDefaultApproximateCountDistinctAggregation { // this value is ignored for boolean, but this is left here for completeness diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/CountAggregation.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/CountAggregation.java index 0c4ce05a19c6..ee9cdf1d8f80 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/CountAggregation.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/CountAggregation.java @@ -25,7 +25,7 @@ import static io.trino.spi.type.BigintType.BIGINT; -@AggregationFunction("count") +@AggregationFunction(value = "count", returnsZeroOnEmptyInput = 
true) public final class CountAggregation { private CountAggregation() {} diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/CountColumn.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/CountColumn.java index 9163dab81c84..656563122ff3 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/CountColumn.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/CountColumn.java @@ -30,7 +30,7 @@ import static io.trino.spi.type.BigintType.BIGINT; -@AggregationFunction("count") +@AggregationFunction(value = "count", returnsZeroOnEmptyInput = true) @Description("Counts the non-null values") public final class CountColumn { diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/CountIfAggregation.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/CountIfAggregation.java index b67a853bce3a..9518f78a4708 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/CountIfAggregation.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/CountIfAggregation.java @@ -26,7 +26,7 @@ import static io.trino.spi.type.BigintType.BIGINT; -@AggregationFunction("count_if") +@AggregationFunction(value = "count_if", returnsZeroOnEmptyInput = true) public final class CountIfAggregation { private CountIfAggregation() {} diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/DefaultApproximateCountDistinctAggregation.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/DefaultApproximateCountDistinctAggregation.java index ee7c9e7de10e..ddecf0cce518 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/DefaultApproximateCountDistinctAggregation.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/DefaultApproximateCountDistinctAggregation.java @@ -35,7 +35,7 @@ import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL; import static 
io.trino.spi.function.OperatorType.XX_HASH_64; -@AggregationFunction("approx_distinct") +@AggregationFunction(value = "approx_distinct", returnsZeroOnEmptyInput = true) public final class DefaultApproximateCountDistinctAggregation { private static final double DEFAULT_STANDARD_ERROR = 0.023; diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/ParametricAggregation.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/ParametricAggregation.java index d447e88ecd35..3081d410be2b 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/ParametricAggregation.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/ParametricAggregation.java @@ -103,6 +103,9 @@ private static AggregationFunctionMetadata createAggregationFunctionMetadata(Agg if (details.isOrderSensitive()) { builder.orderSensitive(); } + if (details.returnsZeroOnEmptyInput()) { + builder.returnsZeroOnEmptyInput(); + } if (details.isDecomposable()) { for (AccumulatorStateDetails stateDetail : stateDetails) { builder.intermediateType(stateDetail.getSerializedType()); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/OptimizeMixedDistinctAggregations.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/OptimizeMixedDistinctAggregations.java index 26d52947d242..7be5c132ab35 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/OptimizeMixedDistinctAggregations.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/OptimizeMixedDistinctAggregations.java @@ -22,6 +22,7 @@ import io.trino.execution.querystats.PlanOptimizersStatsCollector; import io.trino.execution.warnings.WarningCollector; import io.trino.metadata.Metadata; +import io.trino.spi.function.AggregationFunctionMetadata; import io.trino.spi.type.Type; import io.trino.sql.planner.PlanNodeIdAllocator; import io.trino.sql.planner.Symbol; @@ -189,8 +190,9 @@ public PlanNode 
visitAggregation(AggregationNode node, RewriteContext intermediateTypes; - private AggregationFunctionMetadata(boolean orderSensitive, List intermediateTypes) + private AggregationFunctionMetadata(boolean orderSensitive, boolean returnsZeroOnEmptyInput, List intermediateTypes) { this.orderSensitive = orderSensitive; + this.returnsZeroOnEmptyInput = returnsZeroOnEmptyInput; this.intermediateTypes = List.copyOf(requireNonNull(intermediateTypes, "intermediateTypes is null")); } @@ -40,6 +43,11 @@ public boolean isOrderSensitive() return orderSensitive; } + public boolean returnsZeroOnEmptyInput() + { + return returnsZeroOnEmptyInput; + } + public boolean isDecomposable() { return !intermediateTypes.isEmpty(); @@ -55,6 +63,7 @@ public String toString() { return new StringJoiner(", ", AggregationFunctionMetadata.class.getSimpleName() + "[", "]") .add("orderSensitive=" + orderSensitive) + .add("returnsZeroOnEmptyInput=" + returnsZeroOnEmptyInput) .add("intermediateTypes=" + intermediateTypes) .toString(); } @@ -68,6 +77,7 @@ public static class AggregationFunctionMetadataBuilder { private boolean orderSensitive; private final List intermediateTypes = new ArrayList<>(); + private boolean returnsZeroOnEmptyInput; private AggregationFunctionMetadataBuilder() {} @@ -77,6 +87,12 @@ public AggregationFunctionMetadataBuilder orderSensitive() return this; } + public AggregationFunctionMetadataBuilder returnsZeroOnEmptyInput() + { + this.returnsZeroOnEmptyInput = true; + return this; + } + public AggregationFunctionMetadataBuilder intermediateType(Type type) { this.intermediateTypes.add(type.getTypeSignature()); @@ -91,7 +107,7 @@ public AggregationFunctionMetadataBuilder intermediateType(TypeSignature type) public AggregationFunctionMetadata build() { - return new AggregationFunctionMetadata(orderSensitive, intermediateTypes); + return new AggregationFunctionMetadata(orderSensitive, returnsZeroOnEmptyInput, intermediateTypes); } } } From 082458c9e5d13b2b075c2c5ae86844456a8a243d 
Mon Sep 17 00:00:00 2001 From: lukasz-stec Date: Mon, 26 Jun 2023 15:03:15 +0200 Subject: [PATCH 2/2] Re-run CI