diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/ExpressionInterpreter.java b/core/trino-main/src/main/java/io/trino/sql/planner/ExpressionInterpreter.java index cd432878cf71..4bf05fb2072e 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/ExpressionInterpreter.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/ExpressionInterpreter.java @@ -602,27 +602,29 @@ protected Object visitInPredicate(InPredicate node, Object context) return null; } - Set set = inListCache.get(valueList); - - // We use the presence of the node in the map to indicate that we've already done - // the analysis below. If the value is null, it means that we can't apply the HashSet - // optimization - if (!inListCache.containsKey(valueList)) { - if (valueList.getValues().stream().allMatch(Literal.class::isInstance) && - valueList.getValues().stream().noneMatch(NullLiteral.class::isInstance)) { - Set objectSet = valueList.getValues().stream().map(expression -> processWithExceptionHandling(expression, context)).collect(Collectors.toSet()); - Type type = type(node.getValue()); - set = FastutilSetHelper.toFastutilHashSet( - objectSet, - type, - plannerContext.getFunctionManager().getScalarFunctionInvoker(metadata.resolveOperator(session, HASH_CODE, ImmutableList.of(type)), simpleConvention(FAIL_ON_NULL, NEVER_NULL)).getMethodHandle(), - plannerContext.getFunctionManager().getScalarFunctionInvoker(metadata.resolveOperator(session, EQUAL, ImmutableList.of(type, type)), simpleConvention(NULLABLE_RETURN, NEVER_NULL, NEVER_NULL)).getMethodHandle()); + if (!(value instanceof Expression)) { + Set set = inListCache.get(valueList); + + // We use the presence of the node in the map to indicate that we've already done + // the analysis below. If the value is null, it means that we can't apply the HashSet + // optimization + if (!inListCache.containsKey(valueList)) { + if (valueList.getValues().stream().allMatch(Literal.class::isInstance) && + valueList.getValues().stream().noneMatch(NullLiteral.class::isInstance)) { + Set objectSet = valueList.getValues().stream().map(expression -> processWithExceptionHandling(expression, context)).collect(Collectors.toSet()); + Type type = type(node.getValue()); + set = FastutilSetHelper.toFastutilHashSet( + objectSet, + type, + plannerContext.getFunctionManager().getScalarFunctionInvoker(metadata.resolveOperator(session, HASH_CODE, ImmutableList.of(type)), simpleConvention(FAIL_ON_NULL, NEVER_NULL)).getMethodHandle(), + plannerContext.getFunctionManager().getScalarFunctionInvoker(metadata.resolveOperator(session, EQUAL, ImmutableList.of(type, type)), simpleConvention(NULLABLE_RETURN, NEVER_NULL, NEVER_NULL)).getMethodHandle()); + } + inListCache.put(valueList, set); } - inListCache.put(valueList, set); - } - if (set != null && !(value instanceof Expression)) { - return set.contains(value); + if (set != null) { + return set.contains(value); + } } boolean hasUnresolvedValue = value instanceof Expression; @@ -633,6 +635,14 @@ protected Object visitInPredicate(InPredicate node, Object context) ResolvedFunction equalsOperator = metadata.resolveOperator(session, OperatorType.EQUAL, types(node.getValue(), valueList)); for (Expression expression : valueList.getValues()) { + if (value instanceof Expression && expression instanceof Literal) { + // skip interpreting of literal IN term since it cannot be compared + // with unresolved "value" and it cannot be simplified further + values.add(expression); + types.add(type(expression)); + continue; + } + // Use process() instead of processWithExceptionHandling() for processing in-list items. // Do not handle exceptions thrown while processing a single in-list expression, // but fail the whole in-predicate evaluation. @@ -668,11 +678,11 @@ else if (!found && result) { Type type = type(node.getValue()); List expressionValues = toExpressions(values, types); List simplifiedExpressionValues = Stream.concat( - expressionValues.stream() - .filter(expression -> isDeterministic(expression, metadata)) - .distinct(), - expressionValues.stream() - .filter((expression -> !isDeterministic(expression, metadata)))) + expressionValues.stream() + .filter(expression -> isDeterministic(expression, metadata)) + .distinct(), + expressionValues.stream() + .filter((expression -> !isDeterministic(expression, metadata)))) .collect(toImmutableList()); if (simplifiedExpressionValues.size() == 1) {