Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Iterables;
import com.google.common.collect.LinkedHashMultimap;
import com.google.common.collect.Multimap;
import com.google.common.collect.Streams;
import io.trino.Session;
Expand Down Expand Up @@ -135,6 +136,7 @@
import javax.annotation.Nullable;

import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
Expand Down Expand Up @@ -2218,19 +2220,16 @@ protected Type visitInPredicate(InPredicate node, StackableAstVisitorContext<Con
});
}

Type declaredValueType = process(value, context);

if (valueList instanceof InListExpression) {
process(valueList, context);
InListExpression inListExpression = (InListExpression) valueList;

coerceToSingleType(context,
Type type = coerceToSingleType(context,
"IN value and list items must be the same type: %s",
ImmutableList.<Expression>builder().add(value).addAll(inListExpression.getValues()).build());
setExpressionType(inListExpression, type);
}
else if (valueList instanceof SubqueryExpression) {
subqueryInPredicates.add(NodeRef.of(node));
analyzePredicateWithSubquery(node, declaredValueType, (SubqueryExpression) valueList, context);
analyzePredicateWithSubquery(node, process(value, context), (SubqueryExpression) valueList, context);
}
else {
throw new IllegalArgumentException("Unexpected value list type for InPredicate: " + node.getValueList().getClass().getName());
Expand All @@ -2239,15 +2238,6 @@ else if (valueList instanceof SubqueryExpression) {
return setExpressionType(node, BOOLEAN);
}

@Override
protected Type visitInListExpression(InListExpression node, StackableAstVisitorContext<Context> context)
{
Type type = coerceToSingleType(context, "All IN list values must be the same type: %s", node.getValues());

setExpressionType(node, type);
return type; // TODO: this really should a be relation type
}

@Override
protected Type visitSubqueryExpression(SubqueryExpression node, StackableAstVisitorContext<Context> context)
{
Expand Down Expand Up @@ -2568,22 +2558,33 @@ private Type coerceToSingleType(StackableAstVisitorContext<Context> context, Str
{
// determine super type
Type superType = UNKNOWN;

// Use LinkedHashMultimap to preserve order in which expressions are analyzed within IN list
Multimap<Type, NodeRef<Expression>> typeExpressions = LinkedHashMultimap.create();
for (Expression expression : expressions) {
Optional<Type> newSuperType = typeCoercion.getCommonSuperType(superType, process(expression, context));
// We need to wrap as NodeRef since LinkedHashMultimap does not allow duplicated values
typeExpressions.put(process(expression, context), NodeRef.of(expression));
}

Set<Type> types = typeExpressions.keySet();

for (Type type : types) {
Optional<Type> newSuperType = typeCoercion.getCommonSuperType(superType, type);
if (newSuperType.isEmpty()) {
throw semanticException(TYPE_MISMATCH, expression, message, superType);
throw semanticException(TYPE_MISMATCH, Iterables.get(typeExpressions.get(type), 0).getNode(), message, superType);
}
superType = newSuperType.get();
}

// verify all expressions can be coerced to the superType
for (Expression expression : expressions) {
Type type = process(expression, context);
for (Type type : types) {
Collection<NodeRef<Expression>> coercionCandidates = typeExpressions.get(type);

if (!type.equals(superType)) {
if (!typeCoercion.canCoerce(type, superType)) {
throw semanticException(TYPE_MISMATCH, expression, message, superType);
throw semanticException(TYPE_MISMATCH, Iterables.get(coercionCandidates, 0).getNode(), message, superType);
}
addOrReplaceExpressionCoercion(expression, type, superType);
addOrReplaceExpressionsCoercion(coercionCandidates, type, superType);
}
}

Expand All @@ -2592,13 +2593,17 @@ private Type coerceToSingleType(StackableAstVisitorContext<Context> context, Str

private void addOrReplaceExpressionCoercion(Expression expression, Type type, Type superType)
{
NodeRef<Expression> ref = NodeRef.of(expression);
expressionCoercions.put(ref, superType);
addOrReplaceExpressionsCoercion(ImmutableList.of(NodeRef.of(expression)), type, superType);
}

private void addOrReplaceExpressionsCoercion(Collection<NodeRef<Expression>> expressions, Type type, Type superType)
{
expressions.forEach(expression -> expressionCoercions.put(expression, superType));
if (typeCoercion.isTypeOnlyCoercion(type, superType)) {
typeOnlyCoercions.add(ref);
typeOnlyCoercions.addAll(expressions);
}
else {
typeOnlyCoercions.remove(ref);
typeOnlyCoercions.removeAll(expressions);
}
}
}
Expand Down