From a46d9dce7f6bea28783847fafc190a729ad34ea4 Mon Sep 17 00:00:00 2001 From: radek-starburst <94364205+radek-starburst@users.noreply.github.com> Date: Wed, 21 Jun 2023 13:53:10 +0200 Subject: [PATCH] Disallow usage of aggreagtion function as UNNEST parameter --- .../io/trino/metadata/InternalFunctionBundle.java | 15 +++++++++------ .../io/trino/sql/analyzer/StatementAnalyzer.java | 1 + .../java/io/trino/sql/analyzer/TestAnalyzer.java | 8 ++++++++ 3 files changed, 18 insertions(+), 6 deletions(-) diff --git a/core/trino-main/src/main/java/io/trino/metadata/InternalFunctionBundle.java b/core/trino-main/src/main/java/io/trino/metadata/InternalFunctionBundle.java index 39d3fc685486..cbf995d739eb 100644 --- a/core/trino-main/src/main/java/io/trino/metadata/InternalFunctionBundle.java +++ b/core/trino-main/src/main/java/io/trino/metadata/InternalFunctionBundle.java @@ -138,8 +138,9 @@ public ScalarFunctionImplementation getScalarFunctionImplementation( private SpecializedSqlScalarFunction specializeScalarFunction(FunctionId functionId, BoundSignature boundSignature, FunctionDependencies functionDependencies) { - SqlScalarFunction function = (SqlScalarFunction) getSqlFunction(functionId); - return function.specialize(boundSignature, functionDependencies); + SqlFunction function = getSqlFunction(functionId); + checkArgument(function instanceof SqlScalarFunction, "%s is not a scalar function", function.getFunctionMetadata().getSignature()); + return ((SqlScalarFunction) function).specialize(boundSignature, functionDependencies); } @Override @@ -156,8 +157,9 @@ public AggregationImplementation getAggregationImplementation(FunctionId functio private AggregationImplementation specializedAggregation(FunctionId functionId, BoundSignature boundSignature, FunctionDependencies functionDependencies) { - SqlAggregationFunction aggregationFunction = (SqlAggregationFunction) functions.get(functionId); - return aggregationFunction.specialize(boundSignature, functionDependencies); + SqlFunction function = getSqlFunction(functionId); + checkArgument(function instanceof SqlAggregationFunction, "%s is not an aggregation function", function.getFunctionMetadata().getSignature()); + return ((SqlAggregationFunction) function).specialize(boundSignature, functionDependencies); } @Override @@ -174,8 +176,9 @@ public WindowFunctionSupplier getWindowFunctionSupplier(FunctionId functionId, B private WindowFunctionSupplier specializeWindow(FunctionId functionId, BoundSignature boundSignature, FunctionDependencies functionDependencies) { - SqlWindowFunction function = (SqlWindowFunction) functions.get(functionId); - return function.specialize(boundSignature, functionDependencies); + SqlFunction function = functions.get(functionId); + checkArgument(function instanceof SqlWindowFunction, "%s is not a window function", function.getFunctionMetadata().getSignature()); + return ((SqlWindowFunction) function).specialize(boundSignature, functionDependencies); } private SqlFunction getSqlFunction(FunctionId functionId) diff --git a/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java b/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java index c32f3c164d86..651df71741d0 100644 --- a/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java +++ b/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java @@ -1512,6 +1512,7 @@ protected Scope visitUnnest(Unnest node, Optional scope) ImmutableList.Builder outputFields = ImmutableList.builder(); for (Expression expression : node.getExpressions()) { + verifyNoAggregateWindowOrGroupingFunctions(session, metadata, expression, "UNNEST"); List expressionOutputs = new ArrayList<>(); ExpressionAnalysis expressionAnalysis = analyzeExpression(expression, createScope(scope)); diff --git a/core/trino-main/src/test/java/io/trino/sql/analyzer/TestAnalyzer.java b/core/trino-main/src/test/java/io/trino/sql/analyzer/TestAnalyzer.java index dbd8278883c5..fdf7e5119e40 100644 --- a/core/trino-main/src/test/java/io/trino/sql/analyzer/TestAnalyzer.java +++ b/core/trino-main/src/test/java/io/trino/sql/analyzer/TestAnalyzer.java @@ -6681,6 +6681,14 @@ public void testJsonTable() .hasMessage("line 1:15: JSON_TABLE is not yet supported"); } + @Test + public void testDisallowAggregationFunctionInUnnest() + { + assertFails("SELECT a FROM (VALUES (1), (2)) t(a), UNNEST(ARRAY[COUNT(t.a)])") + .hasErrorCode(EXPRESSION_NOT_SCALAR) + .hasMessage("line 1:46: UNNEST cannot contain aggregations, window functions or grouping operations: [COUNT(t.a)]"); + } + @BeforeClass public void setup() {