diff --git a/core/trino-main/src/main/java/io/trino/sql/analyzer/ExpressionAnalyzer.java b/core/trino-main/src/main/java/io/trino/sql/analyzer/ExpressionAnalyzer.java index 07ceedca2306..2cb1b7c77b2b 100644 --- a/core/trino-main/src/main/java/io/trino/sql/analyzer/ExpressionAnalyzer.java +++ b/core/trino-main/src/main/java/io/trino/sql/analyzer/ExpressionAnalyzer.java @@ -17,6 +17,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.Iterables; +import com.google.common.collect.LinkedHashMultimap; import com.google.common.collect.Multimap; import com.google.common.collect.Streams; import io.trino.Session; @@ -135,6 +136,7 @@ import javax.annotation.Nullable; import java.util.ArrayList; +import java.util.Collection; import java.util.HashSet; import java.util.LinkedHashMap; import java.util.LinkedHashSet; @@ -2218,19 +2220,16 @@ protected Type visitInPredicate(InPredicate node, StackableAstVisitorContextbuilder().add(value).addAll(inListExpression.getValues()).build()); + setExpressionType(inListExpression, type); } else if (valueList instanceof SubqueryExpression) { subqueryInPredicates.add(NodeRef.of(node)); - analyzePredicateWithSubquery(node, declaredValueType, (SubqueryExpression) valueList, context); + analyzePredicateWithSubquery(node, process(value, context), (SubqueryExpression) valueList, context); } else { throw new IllegalArgumentException("Unexpected value list type for InPredicate: " + node.getValueList().getClass().getName()); @@ -2239,15 +2238,6 @@ else if (valueList instanceof SubqueryExpression) { return setExpressionType(node, BOOLEAN); } - @Override - protected Type visitInListExpression(InListExpression node, StackableAstVisitorContext context) - { - Type type = coerceToSingleType(context, "All IN list values must be the same type: %s", node.getValues()); - - setExpressionType(node, type); - return type; // TODO: this really should a be relation type - } - @Override protected Type visitSubqueryExpression(SubqueryExpression node, StackableAstVisitorContext context) { @@ -2568,22 +2558,33 @@ private Type coerceToSingleType(StackableAstVisitorContext context, Str { // determine super type Type superType = UNKNOWN; + + // Use LinkedHashMultimap to preserve order in which expressions are analyzed within IN list + Multimap> typeExpressions = LinkedHashMultimap.create(); for (Expression expression : expressions) { - Optional newSuperType = typeCoercion.getCommonSuperType(superType, process(expression, context)); + // We need to wrap as NodeRef since LinkedHashMultimap does not allow duplicated values + typeExpressions.put(process(expression, context), NodeRef.of(expression)); + } + + Set types = typeExpressions.keySet(); + + for (Type type : types) { + Optional newSuperType = typeCoercion.getCommonSuperType(superType, type); if (newSuperType.isEmpty()) { - throw semanticException(TYPE_MISMATCH, expression, message, superType); + throw semanticException(TYPE_MISMATCH, Iterables.get(typeExpressions.get(type), 0).getNode(), message, superType); } superType = newSuperType.get(); } // verify all expressions can be coerced to the superType - for (Expression expression : expressions) { - Type type = process(expression, context); + for (Type type : types) { + Collection> coercionCandidates = typeExpressions.get(type); + if (!type.equals(superType)) { if (!typeCoercion.canCoerce(type, superType)) { - throw semanticException(TYPE_MISMATCH, expression, message, superType); + throw semanticException(TYPE_MISMATCH, Iterables.get(coercionCandidates, 0).getNode(), message, superType); } - addOrReplaceExpressionCoercion(expression, type, superType); + addOrReplaceExpressionsCoercion(coercionCandidates, type, superType); } } @@ -2592,13 +2593,17 @@ private Type coerceToSingleType(StackableAstVisitorContext context, Str private void addOrReplaceExpressionCoercion(Expression expression, Type type, Type superType) { - NodeRef ref = NodeRef.of(expression); - expressionCoercions.put(ref, superType); + addOrReplaceExpressionsCoercion(ImmutableList.of(NodeRef.of(expression)), type, superType); + } + + private void addOrReplaceExpressionsCoercion(Collection> expressions, Type type, Type superType) + { + expressions.forEach(expression -> expressionCoercions.put(expression, superType)); if (typeCoercion.isTypeOnlyCoercion(type, superType)) { - typeOnlyCoercions.add(ref); + typeOnlyCoercions.addAll(expressions); } else { - typeOnlyCoercions.remove(ref); + typeOnlyCoercions.removeAll(expressions); } } }