diff --git a/presto-main-base/src/main/java/com/facebook/presto/metadata/SignatureBinder.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/SignatureBinder.java index bd383aa6ef87b..23001664f8039 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/metadata/SignatureBinder.java +++ b/presto-main-base/src/main/java/com/facebook/presto/metadata/SignatureBinder.java @@ -772,8 +772,7 @@ public SolverReturnStatus update(BoundVariables.Builder bindings) Type actualReturnType = ((FunctionType) actualLambdaType).getReturnType(); ImmutableList.Builder constraintsBuilder = ImmutableList.builder(); - // Coercion on function type is not supported yet. - if (!appendTypeRelationshipConstraintSolver(constraintsBuilder, formalLambdaReturnTypeSignature, new TypeSignatureProvider(actualReturnType.getTypeSignature()), false)) { + if (!appendTypeRelationshipConstraintSolver(constraintsBuilder, formalLambdaReturnTypeSignature, new TypeSignatureProvider(actualReturnType.getTypeSignature()), allowCoercion)) { return SolverReturnStatus.UNSOLVABLE; } if (!appendConstraintSolvers(constraintsBuilder, formalLambdaReturnTypeSignature, new TypeSignatureProvider(actualReturnType.getTypeSignature()), allowCoercion)) { diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/ExpressionAnalyzer.java b/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/ExpressionAnalyzer.java index e2739aca9a69b..00ac1482320bf 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/ExpressionAnalyzer.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/ExpressionAnalyzer.java @@ -1079,32 +1079,38 @@ else if (frame.getType() == GROUPS) { if (expression instanceof LambdaExpression || expression instanceof BindExpression) { argumentTypesBuilder.add(new TypeSignatureProvider( types -> { - ExpressionAnalyzer innerExpressionAnalyzer = new ExpressionAnalyzer( - functionAndTypeResolver, - statementAnalyzerFactory, - sessionFunctions, - transactionId, - sqlFunctionProperties, - symbolTypes, - parameters, - warningCollector, - isDescribe, - outerScopeSymbolTypes); - if (context.getContext().isInLambda()) { - for (LambdaArgumentDeclaration argument : context.getContext().getFieldToLambdaArgumentDeclaration().values()) { - innerExpressionAnalyzer.setExpressionType(argument, getExpressionType(argument)); + try { + ExpressionAnalyzer innerExpressionAnalyzer = new ExpressionAnalyzer( + functionAndTypeResolver, + statementAnalyzerFactory, + sessionFunctions, + transactionId, + sqlFunctionProperties, + symbolTypes, + parameters, + warningCollector, + isDescribe, + outerScopeSymbolTypes); + if (context.getContext().isInLambda()) { + for (LambdaArgumentDeclaration argument : context.getContext().getFieldToLambdaArgumentDeclaration().values()) { + innerExpressionAnalyzer.setExpressionType(argument, getExpressionType(argument)); + } + } + Type type = innerExpressionAnalyzer.analyze(expression, baseScope, context.getContext().expectingLambda(types, ImmutableMap.of())); + if (expression instanceof LambdaExpression) { + verifyNoAggregateWindowOrGroupingFunctions( + innerExpressionAnalyzer.getResolvedFunctions(), + functionAndTypeResolver, + ((LambdaExpression) expression).getBody(), + "Lambda expression"); + verifyNoExternalFunctions(innerExpressionAnalyzer.getResolvedFunctions(), functionAndTypeResolver, ((LambdaExpression) expression).getBody(), "Lambda expression"); } + return type.getTypeSignature(); } - Type type = innerExpressionAnalyzer.analyze(expression, baseScope, context.getContext().expectingLambda(types, ImmutableMap.of())); - if (expression instanceof LambdaExpression) { - verifyNoAggregateWindowOrGroupingFunctions( - innerExpressionAnalyzer.getResolvedFunctions(), - functionAndTypeResolver, - ((LambdaExpression) expression).getBody(), - "Lambda expression"); - verifyNoExternalFunctions(innerExpressionAnalyzer.getResolvedFunctions(), functionAndTypeResolver, ((LambdaExpression) expression).getBody(), "Lambda expression"); + catch (LambdaArgumentCountMismatchException e) { + // Return non-function type for SignatureBinder to skip invalid lambda function signatures + return new TypeSignature("unknown"); } - return type.getTypeSignature(); })); } else { @@ -1175,7 +1181,22 @@ else if (arguments.size() == 1) { } if (argumentTypes.get(i).hasDependency()) { FunctionType expectedFunctionType = (FunctionType) expectedType; - process(expression, new StackableAstVisitorContext<>(context.getContext().expectingLambda(expectedFunctionType.getArgumentTypes(), resolvedLambdaArguments))); + Type actualLambdaType = process(expression, new StackableAstVisitorContext<>(context.getContext().expectingLambda(expectedFunctionType.getArgumentTypes(), resolvedLambdaArguments))); + + // Apply coercion to lambda return type if needed + if (actualLambdaType instanceof FunctionType) { + FunctionType actualFunctionType = (FunctionType) actualLambdaType; + Type actualReturnType = actualFunctionType.getReturnType(); + Type expectedReturnType = expectedFunctionType.getReturnType(); + + if (!actualReturnType.equals(expectedReturnType) && functionAndTypeResolver.canCoerce(actualReturnType, expectedReturnType)) { + if (expression instanceof LambdaExpression) { + LambdaExpression lambda = (LambdaExpression) expression; + addOrReplaceExpressionCoercion(lambda.getBody(), actualReturnType, expectedReturnType); + setExpressionType(expression, expectedFunctionType); + } + } + } } else { Type actualType = functionAndTypeResolver.getType(argumentTypes.get(i).getTypeSignature()); @@ -1529,8 +1550,7 @@ protected Type visitLambdaExpression(LambdaExpression node, StackableAstVisitorC List lambdaArguments = node.getArguments(); if (types.size() != lambdaArguments.size()) { - throw new SemanticException(INVALID_PARAMETER_USAGE, node, - format("Expected a lambda that takes %s argument(s) but got %s", types.size(), lambdaArguments.size())); + throw new LambdaArgumentCountMismatchException(node, format("Expected a lambda that takes %s argument(s) but got %s", types.size(), lambdaArguments.size())); } ImmutableList.Builder fields = ImmutableList.builder(); @@ -2199,4 +2219,13 @@ public static boolean isNumericType(Type type) type.equals(REAL) || type instanceof DecimalType; } + + private static class LambdaArgumentCountMismatchException + extends SemanticException + { + public LambdaArgumentCountMismatchException(Node node, String message) + { + super(INVALID_PARAMETER_USAGE, node, message); + } + } } diff --git a/presto-main-base/src/test/java/com/facebook/presto/type/TestArrayOperators.java b/presto-main-base/src/test/java/com/facebook/presto/type/TestArrayOperators.java index 5d37e94329b63..5c05cf4b07e6a 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/type/TestArrayOperators.java +++ b/presto-main-base/src/test/java/com/facebook/presto/type/TestArrayOperators.java @@ -1154,6 +1154,68 @@ public void testSort() assertCachedInstanceHasBoundedRetainedSize("ARRAY_SORT(ARRAY[2, 3, 4, 1])"); } + @Test + public void testArraySortLambdaReturnTypeCoercion() + { + assertFunction( + "ARRAY_SORT(ARRAY[2, 3, 1], " + + "(x, y) -> IF(x < y, -1, IF(x = y, 0, 1)))", + new ArrayType(INTEGER), + ImmutableList.of(1, 2, 3)); + + assertFunction( + "ARRAY_SORT(ARRAY[3, 1, 2], " + + "(x, y) -> CASE WHEN x < y THEN -1 WHEN x = y THEN 0 ELSE 1 END)", + new ArrayType(INTEGER), + ImmutableList.of(1, 2, 3)); + + assertFunction( + "ARRAY_SORT(ARRAY[5, 3, 1], " + + "(x, y) -> SIGN(x - y))", + new ArrayType(INTEGER), + ImmutableList.of(1, 3, 5)); + + assertFunction( + "ARRAY_SORT(ARRAY[3, null, 1, null, 2], " + + "(x, y) -> CASE " + + "WHEN x IS NULL AND y IS NULL THEN 0 " + + "WHEN x IS NULL THEN -1 " + + "WHEN y IS NULL THEN 1 " + + "WHEN x < y THEN -1 " + + "WHEN x = y THEN 0 " + + "ELSE 1 END)", + new ArrayType(INTEGER), + asList(null, null, 1, 2, 3)); + + assertFunction( + "ARRAY_SORT(ARRAY['apple', 'pie', 'banana', 'a'], " + + "(x, y) -> SIGN(LENGTH(x) - LENGTH(y)))", + new ArrayType(createVarcharType(6)), + ImmutableList.of("a", "pie", "apple", "banana")); + + assertFunction( + "ARRAY_SORT(ARRAY[2.7E0, 1.2E0, 3.9E0, 2.1E0], " + + "(x, y) -> SIGN(CAST(FLOOR(x) AS INTEGER) - CAST(FLOOR(y) AS INTEGER)))", + new ArrayType(DOUBLE), + ImmutableList.of(1.2, 2.7, 2.1, 3.9)); + + assertFunction( + "ARRAY_SORT(ARRAY[5, 10, 3, 15, 7], " + + "(x, y) -> CASE " + + "WHEN x % 5 = 0 AND y % 5 = 0 THEN SIGN(x - y) " + + "WHEN x % 5 = 0 THEN -1 " + + "WHEN y % 5 = 0 THEN 1 " + + "ELSE SIGN(x - y) END)", + new ArrayType(INTEGER), + ImmutableList.of(5, 10, 15, 3, 7)); + + assertFunction( + "ARRAY_SORT(ARRAY[10, 0, 5, -5], " + + "(x, y) -> IF(x = 0, 1, IF(y = 0, -1, SIGN(x - y))))", + new ArrayType(INTEGER), + ImmutableList.of(-5, 5, 10, 0)); + } + @Test public void testReverse() {