diff --git a/core/trino-main/src/main/java/io/trino/metadata/MetadataManager.java b/core/trino-main/src/main/java/io/trino/metadata/MetadataManager.java index 754653245323..053754bfc013 100644 --- a/core/trino-main/src/main/java/io/trino/metadata/MetadataManager.java +++ b/core/trino-main/src/main/java/io/trino/metadata/MetadataManager.java @@ -2416,6 +2416,10 @@ public AggregationFunctionMetadata getAggregationFunctionMetadata(Session sessio builder.orderSensitive(); } + if (aggregationFunctionMetadata.returnsZeroOnEmptyInput()) { + builder.returnsZeroOnEmptyInput(); + } + if (!aggregationFunctionMetadata.getIntermediateTypes().isEmpty()) { FunctionBinding functionBinding = toFunctionBinding(resolvedFunction.getFunctionId(), resolvedFunction.getSignature(), functionSignature); aggregationFunctionMetadata.getIntermediateTypes().stream() diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/AggregationFromAnnotationsParser.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/AggregationFromAnnotationsParser.java index 37c747c99071..038dd9ceae46 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/AggregationFromAnnotationsParser.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/AggregationFromAnnotationsParser.java @@ -187,6 +187,7 @@ private static AggregationHeader parseHeader(AnnotatedElement aggregationDefinit parseDescription(aggregationDefinition, outputFunction), aggregationAnnotation.decomposable(), aggregationAnnotation.isOrderSensitive(), + aggregationAnnotation.returnsZeroOnEmptyInput(), aggregationAnnotation.hidden(), aggregationDefinition.getAnnotationsByType(Deprecated.class).length > 0); } diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/AggregationHeader.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/AggregationHeader.java index 1168fe061192..a4a1f579ce50 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/AggregationHeader.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/AggregationHeader.java @@ -24,15 +24,17 @@ public class AggregationHeader private final Optional description; private final boolean decomposable; private final boolean orderSensitive; + private final boolean returnsZeroOnEmptyInput; private final boolean hidden; private final boolean deprecated; - public AggregationHeader(String name, Optional description, boolean decomposable, boolean orderSensitive, boolean hidden, boolean deprecated) + public AggregationHeader(String name, Optional description, boolean decomposable, boolean orderSensitive, boolean returnsZeroOnEmptyInput, boolean hidden, boolean deprecated) { this.name = requireNonNull(name, "name cannot be null"); this.description = requireNonNull(description, "description cannot be null"); this.decomposable = decomposable; this.orderSensitive = orderSensitive; + this.returnsZeroOnEmptyInput = returnsZeroOnEmptyInput; this.hidden = hidden; this.deprecated = deprecated; } @@ -57,6 +59,11 @@ public boolean isOrderSensitive() return orderSensitive; } + public boolean returnsZeroOnEmptyInput() + { + return returnsZeroOnEmptyInput; + } + public boolean isHidden() { return hidden; @@ -75,6 +82,7 @@ public String toString() .add("description", description) .add("decomposable", decomposable) .add("orderSensitive", orderSensitive) + .add("returnsZeroOnEmptyInput", returnsZeroOnEmptyInput) .add("hidden", hidden) .toString(); } diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/ApproximateCountDistinctAggregation.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/ApproximateCountDistinctAggregation.java index e28492becf3f..85df5fb579c5 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/ApproximateCountDistinctAggregation.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/ApproximateCountDistinctAggregation.java @@ -41,7 +41,7 @@ import static io.trino.util.Failures.checkCondition; import static io.trino.util.Failures.internalError; -@AggregationFunction("approx_distinct") +@AggregationFunction(value = "approx_distinct", returnsZeroOnEmptyInput = true) public final class ApproximateCountDistinctAggregation { private static final double LOWEST_MAX_STANDARD_ERROR = 0.0040625; diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/BooleanApproximateCountDistinctAggregation.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/BooleanApproximateCountDistinctAggregation.java index 1171a14ea9fa..ad1b37d954b1 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/BooleanApproximateCountDistinctAggregation.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/BooleanApproximateCountDistinctAggregation.java @@ -24,7 +24,7 @@ import static io.trino.spi.type.BigintType.BIGINT; -@AggregationFunction("approx_distinct") +@AggregationFunction(value = "approx_distinct", returnsZeroOnEmptyInput = true) public final class BooleanApproximateCountDistinctAggregation { private BooleanApproximateCountDistinctAggregation() {} diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/BooleanDefaultApproximateCountDistinctAggregation.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/BooleanDefaultApproximateCountDistinctAggregation.java index 4e414e328fe5..99c8a3cdf327 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/BooleanDefaultApproximateCountDistinctAggregation.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/BooleanDefaultApproximateCountDistinctAggregation.java @@ -22,7 +22,7 @@ import io.trino.spi.function.SqlType; import io.trino.spi.type.StandardTypes; -@AggregationFunction("approx_distinct") +@AggregationFunction(value = "approx_distinct", returnsZeroOnEmptyInput = true) public final class BooleanDefaultApproximateCountDistinctAggregation { // this value is ignored for boolean, but this is left here for completeness diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/CountAggregation.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/CountAggregation.java index 0c4ce05a19c6..ee9cdf1d8f80 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/CountAggregation.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/CountAggregation.java @@ -25,7 +25,7 @@ import static io.trino.spi.type.BigintType.BIGINT; -@AggregationFunction("count") +@AggregationFunction(value = "count", returnsZeroOnEmptyInput = true) public final class CountAggregation { private CountAggregation() {} diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/CountColumn.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/CountColumn.java index 9163dab81c84..656563122ff3 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/CountColumn.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/CountColumn.java @@ -30,7 +30,7 @@ import static io.trino.spi.type.BigintType.BIGINT; -@AggregationFunction("count") +@AggregationFunction(value = "count", returnsZeroOnEmptyInput = true) @Description("Counts the non-null values") public final class CountColumn { diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/CountIfAggregation.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/CountIfAggregation.java index b67a853bce3a..9518f78a4708 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/CountIfAggregation.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/CountIfAggregation.java @@ -26,7 +26,7 @@ import static io.trino.spi.type.BigintType.BIGINT; -@AggregationFunction("count_if") +@AggregationFunction(value = "count_if", returnsZeroOnEmptyInput = true) public final class CountIfAggregation { private CountIfAggregation() {} diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/DefaultApproximateCountDistinctAggregation.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/DefaultApproximateCountDistinctAggregation.java index ee7c9e7de10e..ddecf0cce518 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/DefaultApproximateCountDistinctAggregation.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/DefaultApproximateCountDistinctAggregation.java @@ -35,7 +35,7 @@ import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL; import static io.trino.spi.function.OperatorType.XX_HASH_64; -@AggregationFunction("approx_distinct") +@AggregationFunction(value = "approx_distinct", returnsZeroOnEmptyInput = true) public final class DefaultApproximateCountDistinctAggregation { private static final double DEFAULT_STANDARD_ERROR = 0.023; diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/ParametricAggregation.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/ParametricAggregation.java index d447e88ecd35..3081d410be2b 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/ParametricAggregation.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/ParametricAggregation.java @@ -103,6 +103,9 @@ private static AggregationFunctionMetadata createAggregationFunctionMetadata(Agg if (details.isOrderSensitive()) { builder.orderSensitive(); } + if (details.returnsZeroOnEmptyInput()) { + builder.returnsZeroOnEmptyInput(); + } if (details.isDecomposable()) { for (AccumulatorStateDetails stateDetail : stateDetails) { builder.intermediateType(stateDetail.getSerializedType()); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/OptimizeMixedDistinctAggregations.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/OptimizeMixedDistinctAggregations.java index 26d52947d242..7be5c132ab35 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/OptimizeMixedDistinctAggregations.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/OptimizeMixedDistinctAggregations.java @@ -22,6 +22,7 @@ import io.trino.execution.querystats.PlanOptimizersStatsCollector; import io.trino.execution.warnings.WarningCollector; import io.trino.metadata.Metadata; +import io.trino.spi.function.AggregationFunctionMetadata; import io.trino.spi.type.Type; import io.trino.sql.planner.PlanNodeIdAllocator; import io.trino.sql.planner.Symbol; @@ -189,8 +190,9 @@ public PlanNode visitAggregation(AggregationNode node, RewriteContext intermediateTypes; - private AggregationFunctionMetadata(boolean orderSensitive, List intermediateTypes) + private AggregationFunctionMetadata(boolean orderSensitive, boolean returnsZeroOnEmptyInput, List intermediateTypes) { this.orderSensitive = orderSensitive; + this.returnsZeroOnEmptyInput = returnsZeroOnEmptyInput; this.intermediateTypes = List.copyOf(requireNonNull(intermediateTypes, "intermediateTypes is null")); } @@ -40,6 +43,11 @@ public boolean isOrderSensitive() return orderSensitive; } + public boolean returnsZeroOnEmptyInput() + { + return returnsZeroOnEmptyInput; + } + public boolean isDecomposable() { return !intermediateTypes.isEmpty(); @@ -55,6 +63,7 @@ public String toString() { return new StringJoiner(", ", AggregationFunctionMetadata.class.getSimpleName() + "[", "]") .add("orderSensitive=" + orderSensitive) + .add("returnsZeroOnEmptyInput=" + returnsZeroOnEmptyInput) .add("intermediateTypes=" + intermediateTypes) .toString(); } @@ -68,6 +77,7 @@ public static class AggregationFunctionMetadataBuilder { private boolean orderSensitive; private final List intermediateTypes = new ArrayList<>(); + private boolean returnsZeroOnEmptyInput; private AggregationFunctionMetadataBuilder() {} @@ -77,6 +87,12 @@ public AggregationFunctionMetadataBuilder orderSensitive() return this; } + public AggregationFunctionMetadataBuilder returnsZeroOnEmptyInput() + { + this.returnsZeroOnEmptyInput = true; + return this; + } + public AggregationFunctionMetadataBuilder intermediateType(Type type) { this.intermediateTypes.add(type.getTypeSignature()); @@ -91,7 +107,7 @@ public AggregationFunctionMetadataBuilder intermediateType(TypeSignature type) public AggregationFunctionMetadata build() { - return new AggregationFunctionMetadata(orderSensitive, intermediateTypes); + return new AggregationFunctionMetadata(orderSensitive, returnsZeroOnEmptyInput, intermediateTypes); } } }