Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -772,8 +772,7 @@ public SolverReturnStatus update(BoundVariables.Builder bindings)
Type actualReturnType = ((FunctionType) actualLambdaType).getReturnType();

ImmutableList.Builder<TypeConstraintSolver> constraintsBuilder = ImmutableList.builder();
// Coercion on function type is not supported yet.
if (!appendTypeRelationshipConstraintSolver(constraintsBuilder, formalLambdaReturnTypeSignature, new TypeSignatureProvider(actualReturnType.getTypeSignature()), false)) {
if (!appendTypeRelationshipConstraintSolver(constraintsBuilder, formalLambdaReturnTypeSignature, new TypeSignatureProvider(actualReturnType.getTypeSignature()), allowCoercion)) {
return SolverReturnStatus.UNSOLVABLE;
}
if (!appendConstraintSolvers(constraintsBuilder, formalLambdaReturnTypeSignature, new TypeSignatureProvider(actualReturnType.getTypeSignature()), allowCoercion)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1079,32 +1079,38 @@ else if (frame.getType() == GROUPS) {
if (expression instanceof LambdaExpression || expression instanceof BindExpression) {
argumentTypesBuilder.add(new TypeSignatureProvider(
types -> {
ExpressionAnalyzer innerExpressionAnalyzer = new ExpressionAnalyzer(
functionAndTypeResolver,
statementAnalyzerFactory,
sessionFunctions,
transactionId,
sqlFunctionProperties,
symbolTypes,
parameters,
warningCollector,
isDescribe,
outerScopeSymbolTypes);
if (context.getContext().isInLambda()) {
for (LambdaArgumentDeclaration argument : context.getContext().getFieldToLambdaArgumentDeclaration().values()) {
innerExpressionAnalyzer.setExpressionType(argument, getExpressionType(argument));
try {
ExpressionAnalyzer innerExpressionAnalyzer = new ExpressionAnalyzer(
functionAndTypeResolver,
statementAnalyzerFactory,
sessionFunctions,
transactionId,
sqlFunctionProperties,
symbolTypes,
parameters,
warningCollector,
isDescribe,
outerScopeSymbolTypes);
if (context.getContext().isInLambda()) {
for (LambdaArgumentDeclaration argument : context.getContext().getFieldToLambdaArgumentDeclaration().values()) {
innerExpressionAnalyzer.setExpressionType(argument, getExpressionType(argument));
}
}
Type type = innerExpressionAnalyzer.analyze(expression, baseScope, context.getContext().expectingLambda(types, ImmutableMap.of()));
if (expression instanceof LambdaExpression) {
verifyNoAggregateWindowOrGroupingFunctions(
innerExpressionAnalyzer.getResolvedFunctions(),
functionAndTypeResolver,
((LambdaExpression) expression).getBody(),
"Lambda expression");
verifyNoExternalFunctions(innerExpressionAnalyzer.getResolvedFunctions(), functionAndTypeResolver, ((LambdaExpression) expression).getBody(), "Lambda expression");
}
return type.getTypeSignature();
}
Type type = innerExpressionAnalyzer.analyze(expression, baseScope, context.getContext().expectingLambda(types, ImmutableMap.of()));
if (expression instanceof LambdaExpression) {
verifyNoAggregateWindowOrGroupingFunctions(
innerExpressionAnalyzer.getResolvedFunctions(),
functionAndTypeResolver,
((LambdaExpression) expression).getBody(),
"Lambda expression");
verifyNoExternalFunctions(innerExpressionAnalyzer.getResolvedFunctions(), functionAndTypeResolver, ((LambdaExpression) expression).getBody(), "Lambda expression");
catch (LambdaArgumentCountMismatchException e) {
// Return non-function type for SignatureBinder to skip invalid lambda function signatures
return new TypeSignature("unknown");
}
return type.getTypeSignature();
}));
}
else {
Expand Down Expand Up @@ -1175,7 +1181,22 @@ else if (arguments.size() == 1) {
}
if (argumentTypes.get(i).hasDependency()) {
FunctionType expectedFunctionType = (FunctionType) expectedType;
process(expression, new StackableAstVisitorContext<>(context.getContext().expectingLambda(expectedFunctionType.getArgumentTypes(), resolvedLambdaArguments)));
Type actualLambdaType = process(expression, new StackableAstVisitorContext<>(context.getContext().expectingLambda(expectedFunctionType.getArgumentTypes(), resolvedLambdaArguments)));

// Apply coercion to lambda return type if needed
if (actualLambdaType instanceof FunctionType) {
FunctionType actualFunctionType = (FunctionType) actualLambdaType;
Type actualReturnType = actualFunctionType.getReturnType();
Type expectedReturnType = expectedFunctionType.getReturnType();

if (!actualReturnType.equals(expectedReturnType) && functionAndTypeResolver.canCoerce(actualReturnType, expectedReturnType)) {
if (expression instanceof LambdaExpression) {
LambdaExpression lambda = (LambdaExpression) expression;
addOrReplaceExpressionCoercion(lambda.getBody(), actualReturnType, expectedReturnType);
setExpressionType(expression, expectedFunctionType);
}
}
}
}
else {
Type actualType = functionAndTypeResolver.getType(argumentTypes.get(i).getTypeSignature());
Expand Down Expand Up @@ -1529,8 +1550,7 @@ protected Type visitLambdaExpression(LambdaExpression node, StackableAstVisitorC
List<LambdaArgumentDeclaration> lambdaArguments = node.getArguments();

if (types.size() != lambdaArguments.size()) {
throw new SemanticException(INVALID_PARAMETER_USAGE, node,
format("Expected a lambda that takes %s argument(s) but got %s", types.size(), lambdaArguments.size()));
throw new LambdaArgumentCountMismatchException(node, format("Expected a lambda that takes %s argument(s) but got %s", types.size(), lambdaArguments.size()));
}

ImmutableList.Builder<Field> fields = ImmutableList.builder();
Expand Down Expand Up @@ -2199,4 +2219,13 @@ public static boolean isNumericType(Type type)
type.equals(REAL) ||
type instanceof DecimalType;
}

private static class LambdaArgumentCountMismatchException
extends SemanticException
{
public LambdaArgumentCountMismatchException(Node node, String message)
{
super(INVALID_PARAMETER_USAGE, node, message);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -1154,6 +1154,68 @@ public void testSort()
assertCachedInstanceHasBoundedRetainedSize("ARRAY_SORT(ARRAY[2, 3, 4, 1])");
}

@Test
public void testArraySortLambdaReturnTypeCoercion()
{
assertFunction(
"ARRAY_SORT(ARRAY[2, 3, 1], " +
"(x, y) -> IF(x < y, -1, IF(x = y, 0, 1)))",
new ArrayType(INTEGER),
ImmutableList.of(1, 2, 3));

assertFunction(
"ARRAY_SORT(ARRAY[3, 1, 2], " +
"(x, y) -> CASE WHEN x < y THEN -1 WHEN x = y THEN 0 ELSE 1 END)",
new ArrayType(INTEGER),
ImmutableList.of(1, 2, 3));

assertFunction(
"ARRAY_SORT(ARRAY[5, 3, 1], " +
"(x, y) -> SIGN(x - y))",
new ArrayType(INTEGER),
ImmutableList.of(1, 3, 5));

assertFunction(
"ARRAY_SORT(ARRAY[3, null, 1, null, 2], " +
"(x, y) -> CASE " +
"WHEN x IS NULL AND y IS NULL THEN 0 " +
"WHEN x IS NULL THEN -1 " +
"WHEN y IS NULL THEN 1 " +
"WHEN x < y THEN -1 " +
"WHEN x = y THEN 0 " +
"ELSE 1 END)",
new ArrayType(INTEGER),
asList(null, null, 1, 2, 3));

assertFunction(
"ARRAY_SORT(ARRAY['apple', 'pie', 'banana', 'a'], " +
"(x, y) -> SIGN(LENGTH(x) - LENGTH(y)))",
new ArrayType(createVarcharType(6)),
ImmutableList.of("a", "pie", "apple", "banana"));

assertFunction(
"ARRAY_SORT(ARRAY[2.7E0, 1.2E0, 3.9E0, 2.1E0], " +
"(x, y) -> SIGN(CAST(FLOOR(x) AS INTEGER) - CAST(FLOOR(y) AS INTEGER)))",
new ArrayType(DOUBLE),
ImmutableList.of(1.2, 2.7, 2.1, 3.9));

assertFunction(
"ARRAY_SORT(ARRAY[5, 10, 3, 15, 7], " +
"(x, y) -> CASE " +
"WHEN x % 5 = 0 AND y % 5 = 0 THEN SIGN(x - y) " +
"WHEN x % 5 = 0 THEN -1 " +
"WHEN y % 5 = 0 THEN 1 " +
"ELSE SIGN(x - y) END)",
new ArrayType(INTEGER),
ImmutableList.of(5, 10, 15, 3, 7));

assertFunction(
"ARRAY_SORT(ARRAY[10, 0, 5, -5], " +
"(x, y) -> IF(x = 0, 1, IF(y = 0, -1, SIGN(x - y))))",
new ArrayType(INTEGER),
ImmutableList.of(-5, 5, 10, 0));
}

@Test
public void testReverse()
{
Expand Down
Loading