diff --git a/core/trino-main/src/main/java/io/trino/Session.java b/core/trino-main/src/main/java/io/trino/Session.java index 8f6a3168afdc..fcd3bbab54b7 100644 --- a/core/trino-main/src/main/java/io/trino/Session.java +++ b/core/trino-main/src/main/java/io/trino/Session.java @@ -751,6 +751,17 @@ public SessionBuilder setSystemProperties(Map systemProperties) return this; } + /** + * Sets catalog session properties, discarding any catalog properties previously set. + */ + public SessionBuilder setCatalogProperties(Map> catalogProperties) + { + requireNonNull(catalogProperties, "catalogProperties is null"); + this.catalogSessionProperties.clear(); + this.catalogSessionProperties.putAll(catalogProperties); + return this; + } + /** * Sets a catalog property for the session. The property name and value must * only contain characters from US-ASCII and must not be for '='. diff --git a/core/trino-main/src/main/java/io/trino/sql/analyzer/ExpressionAnalyzer.java b/core/trino-main/src/main/java/io/trino/sql/analyzer/ExpressionAnalyzer.java index 07ceedca2306..5565d55c1f5d 100644 --- a/core/trino-main/src/main/java/io/trino/sql/analyzer/ExpressionAnalyzer.java +++ b/core/trino-main/src/main/java/io/trino/sql/analyzer/ExpressionAnalyzer.java @@ -13,10 +13,12 @@ */ package io.trino.sql.analyzer; +import com.google.common.collect.ArrayListMultimap; import com.google.common.collect.HashMultimap; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.Iterables; +import com.google.common.collect.ListMultimap; import com.google.common.collect.Multimap; import com.google.common.collect.Streams; import io.trino.Session; @@ -150,6 +152,7 @@ import static com.google.common.base.Preconditions.checkState; import static com.google.common.base.Verify.verify; import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.ImmutableMap.toImmutableMap; import static com.google.common.collect.ImmutableSet.toImmutableSet; import static com.google.common.collect.Iterables.getOnlyElement; import static io.trino.spi.StandardErrorCode.AMBIGUOUS_NAME; @@ -297,7 +300,7 @@ public class ExpressionAnalyzer private final Function getResolvedWindow; private final List sourceFields = new ArrayList<>(); - private ExpressionAnalyzer( + public ExpressionAnalyzer( PlannerContext plannerContext, AccessControl accessControl, StatementAnalyzerFactory statementAnalyzerFactory, @@ -2218,19 +2221,16 @@ protected Type visitInPredicate(InPredicate node, StackableAstVisitorContextbuilder().add(value).addAll(inListExpression.getValues()).build()); + setExpressionType(inListExpression, type); } else if (valueList instanceof SubqueryExpression) { subqueryInPredicates.add(NodeRef.of(node)); - analyzePredicateWithSubquery(node, declaredValueType, (SubqueryExpression) valueList, context); + analyzePredicateWithSubquery(node, process(value, context), (SubqueryExpression) valueList, context); } else { throw new IllegalArgumentException("Unexpected value list type for InPredicate: " + node.getValueList().getClass().getName()); @@ -2239,15 +2239,6 @@ else if (valueList instanceof SubqueryExpression) { return setExpressionType(node, BOOLEAN); } - @Override - protected Type visitInListExpression(InListExpression node, StackableAstVisitorContext context) - { - Type type = coerceToSingleType(context, "All IN list values must be the same type: %s", node.getValues()); - - setExpressionType(node, type); - return type; // TODO: this really should a be relation type - } - @Override protected Type visitSubqueryExpression(SubqueryExpression node, StackableAstVisitorContext context) { @@ -2568,22 +2559,32 @@ private Type coerceToSingleType(StackableAstVisitorContext context, Str { // determine super type Type superType = UNKNOWN; + + ListMultimap typeExpressions = ArrayListMultimap.create(); for (Expression expression : expressions) { - Optional newSuperType = typeCoercion.getCommonSuperType(superType, process(expression, context)); + typeExpressions.put(process(expression, context), expression); + } + + // We need an explicit copy to avoid ConcurrentModificationException + Set types = typeExpressions.keySet(); + + for (Type type : types) { + Optional newSuperType = typeCoercion.getCommonSuperType(superType, type); if (newSuperType.isEmpty()) { - throw semanticException(TYPE_MISMATCH, expression, message, superType); + throw semanticException(TYPE_MISMATCH, typeExpressions.get(type).get(0), message, superType); } superType = newSuperType.get(); } // verify all expressions can be coerced to the superType - for (Expression expression : expressions) { - Type type = process(expression, context); + for (Type type : types) { + List coercionCandidates = typeExpressions.get(type); + if (!type.equals(superType)) { if (!typeCoercion.canCoerce(type, superType)) { - throw semanticException(TYPE_MISMATCH, expression, message, superType); + throw semanticException(TYPE_MISMATCH, coercionCandidates.get(0), message, superType); } - addOrReplaceExpressionCoercion(expression, type, superType); + addOrReplaceExpressionsCoercion(coercionCandidates, type, superType); } } @@ -2592,13 +2593,20 @@ private Type coerceToSingleType(StackableAstVisitorContext context, Str private void addOrReplaceExpressionCoercion(Expression expression, Type type, Type superType) { - NodeRef ref = NodeRef.of(expression); - expressionCoercions.put(ref, superType); + addOrReplaceExpressionsCoercion(List.of(expression), type, superType); + } + + private void addOrReplaceExpressionsCoercion(List expressions, Type type, Type superType) + { + Map, Type> expressionRefTypes = expressions.stream() + .collect(toImmutableMap(NodeRef::of, expression -> superType)); + + expressionCoercions.putAll(expressionRefTypes); if (typeCoercion.isTypeOnlyCoercion(type, superType)) { - typeOnlyCoercions.add(ref); + typeOnlyCoercions.addAll(expressionRefTypes.keySet()); } else { - typeOnlyCoercions.remove(ref); + expressionRefTypes.keySet().forEach(typeOnlyCoercions::remove); } } } @@ -2773,8 +2781,13 @@ public static ExpressionAnalysis analyzeExpressions( WarningCollector warningCollector, QueryType queryType) { - Analysis analysis = new Analysis(null, parameters, queryType); - ExpressionAnalyzer analyzer = new ExpressionAnalyzer(plannerContext, accessControl, statementAnalyzerFactory, analysis, session, types, warningCollector); + return analyzeExpressions( + new ExpressionAnalyzer(plannerContext, accessControl, statementAnalyzerFactory, new Analysis(null, parameters, queryType), session, types, warningCollector), + expressions); + } + + public static ExpressionAnalysis analyzeExpressions(ExpressionAnalyzer analyzer, Iterable expressions) + { for (Expression expression : expressions) { analyzer.analyze( expression, diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/ExpressionInterpreter.java b/core/trino-main/src/main/java/io/trino/sql/planner/ExpressionInterpreter.java index b7484064d4ca..d4512878e120 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/ExpressionInterpreter.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/ExpressionInterpreter.java @@ -634,6 +634,14 @@ protected Object visitInPredicate(InPredicate node, Object context) ResolvedFunction equalsOperator = metadata.resolveOperator(session, OperatorType.EQUAL, types(node.getValue(), valueList)); for (Expression expression : valueList.getValues()) { + if (value instanceof Expression && expression instanceof Literal) { + // skip interpreting of literal IN term since it cannot be compared + // with unresolved "value" and it cannot be simplified further + values.add(expression); + types.add(type(expression)); + continue; + } + // Use process() instead of processWithExceptionHandling() for processing in-list items. // Do not handle exceptions thrown while processing a single in-list expression, // but fail the whole in-predicate evaluation. @@ -680,7 +688,16 @@ else if (!found && result) { return new ComparisonExpression(ComparisonExpression.Operator.EQUAL, toExpression(value, type), simplifiedExpressionValues.get(0)); } - return new InPredicate(toExpression(value, type), new InListExpression(simplifiedExpressionValues)); + Expression simplifiedValue = toExpression(value, type); + Expression simplifiedValueList = new InListExpression(simplifiedExpressionValues); + if (simplifiedValueList.equals(node.getValueList()) && simplifiedValue.equals(node.getValue())) { + // Do not create a new instance of InPredicate expression if it would be same as original expression. + // Creating a new instance of InPredicate would cause expression type cache miss, which + // is using node reference as a cache key. + return node; + } + + return new InPredicate(simplifiedValue, simplifiedValueList); } if (hasNullValue) { return null; diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/TypeAnalyzer.java b/core/trino-main/src/main/java/io/trino/sql/planner/TypeAnalyzer.java index 4e9dd5a63c33..4abb26ef87af 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/TypeAnalyzer.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/TypeAnalyzer.java @@ -13,23 +13,34 @@ */ package io.trino.sql.planner; +import com.google.common.cache.Cache; +import com.google.common.cache.CacheBuilder; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import io.trino.Session; +import io.trino.collect.cache.NonEvictableCache; import io.trino.execution.warnings.WarningCollector; import io.trino.metadata.AnalyzePropertyManager; import io.trino.metadata.TablePropertyManager; import io.trino.security.AllowAllAccessControl; +import io.trino.spi.QueryId; import io.trino.spi.type.Type; import io.trino.sql.PlannerContext; +import io.trino.sql.analyzer.Analysis; +import io.trino.sql.analyzer.ExpressionAnalyzer; import io.trino.sql.analyzer.StatementAnalyzerFactory; import io.trino.sql.tree.Expression; +import io.trino.sql.tree.Node; import io.trino.sql.tree.NodeRef; import javax.inject.Inject; +import java.util.List; import java.util.Map; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import static io.trino.collect.cache.SafeCaches.buildNonEvictableCache; import static io.trino.sql.analyzer.ExpressionAnalyzer.analyzeExpressions; import static io.trino.sql.analyzer.QueryType.OTHERS; import static io.trino.sql.analyzer.StatementAnalyzerFactory.createTestingStatementAnalyzerFactory; @@ -45,6 +56,14 @@ public class TypeAnalyzer private final PlannerContext plannerContext; private final StatementAnalyzerFactory statementAnalyzerFactory; + private final NonEvictableCache typeAnalyzersCache = buildNonEvictableCache( + CacheBuilder.newBuilder() + // Try to evict queries cache as soon as possible to keep cache relatively small + .expireAfterAccess(15, TimeUnit.SECONDS) + .maximumSize(256) + .softValues() + .recordStats()); + @Inject public TypeAnalyzer(PlannerContext plannerContext, StatementAnalyzerFactory statementAnalyzerFactory) { @@ -54,17 +73,13 @@ public TypeAnalyzer(PlannerContext plannerContext, StatementAnalyzerFactory stat public Map, Type> getTypes(Session session, TypeProvider inputTypes, Iterable expressions) { - return analyzeExpressions( - session, - plannerContext, - statementAnalyzerFactory, - new AllowAllAccessControl(), - inputTypes, - expressions, - ImmutableMap.of(), - WarningCollector.NOOP, - OTHERS) - .getExpressionTypes(); + try { + return typeAnalyzersCache.get(session.getQueryId(), () -> new QueryScopedCachedTypeAnalyzer(plannerContext, statementAnalyzerFactory)) + .getTypes(session, inputTypes, ImmutableList.copyOf(expressions)); + } + catch (ExecutionException e) { + throw new RuntimeException(e); + } } public Map, Type> getTypes(Session session, TypeProvider inputTypes, Expression expression) @@ -87,4 +102,59 @@ public static TypeAnalyzer createTestingTypeAnalyzer(PlannerContext plannerConte new TablePropertyManager(), new AnalyzePropertyManager())); } + + private static class QueryScopedCachedTypeAnalyzer + { + private final Cache, Type> typesCache = buildNonEvictableCache(CacheBuilder.newBuilder()); + private PlannerContext plannerContext; + private StatementAnalyzerFactory statementAnalyzerFactory; + + private QueryScopedCachedTypeAnalyzer(PlannerContext plannerContext, StatementAnalyzerFactory statementAnalyzerFactory) + { + this.plannerContext = requireNonNull(plannerContext, "plannerContext is null"); + this.statementAnalyzerFactory = requireNonNull(statementAnalyzerFactory, "statementAnalyzerFactory is null"); + } + + private Map, Type> getTypes(Session session, TypeProvider inputTypes, List expressions) + { + List> expressionsToResolve = collectExpressions(expressions); + Map, Type> cachedTypes = typesCache.getAllPresent(expressionsToResolve); + + // All expressions were resolved from cache + if (cachedTypes.size() == expressionsToResolve.size()) { + return cachedTypes; + } + + Map, Type> resolvedTypes = analyzeExpressions(createExpressionAnalyzer(session, plannerContext, statementAnalyzerFactory, inputTypes), expressions) + .getExpressionTypes(); + + typesCache.putAll(resolvedTypes); + return resolvedTypes; + } + + private static ExpressionAnalyzer createExpressionAnalyzer(Session session, + PlannerContext plannerContext, + StatementAnalyzerFactory statementAnalyzerFactory, + TypeProvider types) + { + return new ExpressionAnalyzer(plannerContext, new AllowAllAccessControl(), statementAnalyzerFactory, new Analysis(null, ImmutableMap.of(), OTHERS), session, types, WarningCollector.NOOP); + } + + private static ImmutableList> collectExpressions(Iterable expressions) + { + ImmutableList.Builder> builder = ImmutableList.builder(); + + for (Node expression : expressions) { + if (expression instanceof Expression) { + builder.add(NodeRef.of((Expression) expression)); + } + + if (!expression.getChildren().isEmpty()) { + builder.addAll(collectExpressions(expression.getChildren())); + } + } + + return builder.build(); + } + } } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/plan/SimplePlanRewriter.java b/core/trino-main/src/main/java/io/trino/sql/planner/plan/SimplePlanRewriter.java index 80804113b91b..5432ca778512 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/plan/SimplePlanRewriter.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/plan/SimplePlanRewriter.java @@ -13,10 +13,9 @@ */ package io.trino.sql.planner.plan; -import java.util.List; +import com.google.common.collect.ImmutableList; import static com.google.common.base.Verify.verifyNotNull; -import static com.google.common.collect.ImmutableList.toImmutableList; import static io.trino.sql.planner.plan.ChildReplacer.replaceChildren; public abstract class SimplePlanRewriter @@ -69,11 +68,9 @@ public PlanNode defaultRewrite(PlanNode node) */ public PlanNode defaultRewrite(PlanNode node, C context) { - List children = node.getSources().stream() - .map(child -> rewrite(child, context)) - .collect(toImmutableList()); - - return replaceChildren(node, children); + ImmutableList.Builder children = ImmutableList.builderWithExpectedSize(node.getSources().size()); + node.getSources().forEach(source -> children.add(rewrite(source, context))); + return replaceChildren(node, children.build()); } /** diff --git a/core/trino-main/src/main/java/io/trino/testing/LocalQueryRunner.java b/core/trino-main/src/main/java/io/trino/testing/LocalQueryRunner.java index e434af5d1978..6a01d8d3f5cb 100644 --- a/core/trino-main/src/main/java/io/trino/testing/LocalQueryRunner.java +++ b/core/trino-main/src/main/java/io/trino/testing/LocalQueryRunner.java @@ -55,6 +55,7 @@ import io.trino.execution.FailureInjector.InjectedFailureType; import io.trino.execution.Lifespan; import io.trino.execution.NodeTaskMap; +import io.trino.execution.QueryIdGenerator; import io.trino.execution.QueryManagerConfig; import io.trino.execution.QueryPreparer; import io.trino.execution.QueryPreparer.PreparedQuery; @@ -230,6 +231,8 @@ public class LocalQueryRunner implements QueryRunner { + private static final QueryIdGenerator queryIdGenerator = new QueryIdGenerator(); + private final EventListenerManager eventListenerManager = new EventListenerManager(new EventListenerConfig()); private final Session defaultSession; @@ -283,6 +286,7 @@ public class LocalQueryRunner private final PlanOptimizersProvider planOptimizersProvider; private final OperatorFactories operatorFactories; private final StatementAnalyzerFactory statementAnalyzerFactory; + private final TypeAnalyzer typeAnalyzer; private boolean printPlan; private final ReadWriteLock lock = new ReentrantReadWriteLock(); @@ -386,7 +390,7 @@ private LocalQueryRunner( tablePropertyManager, analyzePropertyManager, tableProceduresPropertyManager); - TypeAnalyzer typeAnalyzer = new TypeAnalyzer(plannerContext, statementAnalyzerFactory); + this.typeAnalyzer = new TypeAnalyzer(plannerContext, statementAnalyzerFactory); this.statsCalculator = createNewStatsCalculator(plannerContext, typeAnalyzer); this.scalarStatsCalculator = new ScalarStatsCalculator(plannerContext, typeAnalyzer); this.taskCountEstimator = new TaskCountEstimator(() -> nodeCountForStats); @@ -690,7 +694,10 @@ public ScheduledExecutorService getScheduler() @Override public Session getDefaultSession() { - return defaultSession; + return TestingSession + .testSessionBuilder(sessionPropertyManager, defaultSession, Optional.empty()) + .setQueryId(queryIdGenerator.createNextQueryId()) + .build(); } public ExpressionCompiler getExpressionCompiler() @@ -906,7 +913,7 @@ private List createDrivers(Session session, Plan plan, OutputFactory out tableExecuteContextManager.registerTableExecuteContextForQuery(taskContext.getQueryContext().getQueryId()); LocalExecutionPlanner executionPlanner = new LocalExecutionPlanner( plannerContext, - new TypeAnalyzer(plannerContext, statementAnalyzerFactory), + typeAnalyzer, Optional.empty(), pageSourceManager, indexManager, @@ -1025,7 +1032,7 @@ public List getPlanOptimizers(boolean forceSingleNode) { return planOptimizersProvider.getPlanOptimizers( plannerContext, - new TypeAnalyzer(plannerContext, statementAnalyzerFactory), + typeAnalyzer, taskManagerConfig, forceSingleNode, splitManager, @@ -1066,7 +1073,7 @@ public Plan createPlan(Session session, @Language("SQL") String sql, List transactionId) + { + SessionBuilder builder = TestingSession.testSessionBuilder(sessionPropertyManager) + .setQueryId(session.getQueryId()) + .setIdentity(session.getIdentity()) + .setSource(session.getSource()) + .setCatalog(session.getCatalog()) + .setSchema(session.getSchema()) + .setPath(session.getPath()) + .setTraceToken(session.getTraceToken()) + .setTimeZoneKey(session.getTimeZoneKey()) + .setLocale(session.getLocale()) + .setRemoteUserAddress(session.getRemoteUserAddress()) + .setUserAgent(session.getUserAgent()) + .setClientInfo(session.getClientInfo()) + .setClientTags(session.getClientTags()) + .setClientCapabilities(session.getClientCapabilities()) + .setResourceEstimates(session.getResourceEstimates()) + .setStart(session.getStart()) + .setSystemProperties(session.getSystemProperties()) + .setCatalogProperties(session.getCatalogProperties()) + .setProtocolHeaders(session.getProtocolHeaders()); + + session.getPreparedStatements().forEach(builder::addPreparedStatement); + transactionId.ifPresent(builder::setTransactionId); + + if (session.isClientTransactionSupport()) { + builder.setClientTransactionSupport(); + } + + return builder; + } } diff --git a/core/trino-main/src/test/java/io/trino/sql/TestExpressionInterpreter.java b/core/trino-main/src/test/java/io/trino/sql/TestExpressionInterpreter.java index a7c8104076d4..3109d761f598 100644 --- a/core/trino-main/src/test/java/io/trino/sql/TestExpressionInterpreter.java +++ b/core/trino-main/src/test/java/io/trino/sql/TestExpressionInterpreter.java @@ -14,6 +14,7 @@ package io.trino.sql; import com.google.common.base.Joiner; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import io.airlift.slice.Slice; import io.airlift.slice.Slices; @@ -30,8 +31,12 @@ import io.trino.sql.planner.TypeProvider; import io.trino.sql.planner.assertions.SymbolAliases; import io.trino.sql.planner.iterative.rule.CanonicalizeExpressionRewriter; +import io.trino.sql.tree.Cast; import io.trino.sql.tree.Expression; +import io.trino.sql.tree.InListExpression; +import io.trino.sql.tree.InPredicate; import io.trino.sql.tree.LikePredicate; +import io.trino.sql.tree.LongLiteral; import io.trino.sql.tree.NodeRef; import io.trino.sql.tree.StringLiteral; import io.trino.sql.tree.SymbolReference; @@ -71,6 +76,7 @@ import static io.trino.sql.ExpressionTestUtils.resolveFunctionCalls; import static io.trino.sql.ExpressionUtils.rewriteIdentifiersToSymbolReferences; import static io.trino.sql.ParsingUtil.createParsingOptions; +import static io.trino.sql.analyzer.TypeSignatureTranslator.toSqlType; import static io.trino.sql.planner.TestingPlannerContext.PLANNER_CONTEXT; import static io.trino.sql.planner.TypeAnalyzer.createTestingTypeAnalyzer; import static io.trino.testing.assertions.TrinoExceptionAssert.assertTrinoExceptionThrownBy; @@ -81,6 +87,7 @@ import static java.util.function.Function.identity; import static org.joda.time.DateTimeZone.UTC; import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertSame; import static org.testng.Assert.assertTrue; public class TestExpressionInterpreter @@ -520,6 +527,18 @@ public void testIn() assertOptimizedEquals("0 / 0 in (2, 2)", "0 / 0 = 2"); } + @Test + public void testUnsimplifiedIn() + { + // should not create a new instance of InPredicate when IN list consists of literals only + InPredicate literalsInList = new InPredicate(new SymbolReference("unbound_integer"), new InListExpression(ImmutableList.of(new LongLiteral("42"), new LongLiteral("43")))); + assertSame(optimize(literalsInList), literalsInList); + + // should not create a new instance of InPredicate when IN list consists of unbounded expressions that cannot be simplified + InPredicate unboundedInList = new InPredicate(new SymbolReference("unbound_integer"), new InListExpression(ImmutableList.of(new SymbolReference("unbound_integer"), new Cast(new SymbolReference("unbound_long"), toSqlType(INTEGER))))); + assertSame(optimize(unboundedInList), unboundedInList); + } + @Test public void testInComplexTypes() { diff --git a/testing/trino-tests/src/test/java/io/trino/tests/TestLocalQueries.java b/testing/trino-tests/src/test/java/io/trino/tests/TestLocalQueries.java index 1370c4f8d099..b488811742ff 100644 --- a/testing/trino-tests/src/test/java/io/trino/tests/TestLocalQueries.java +++ b/testing/trino-tests/src/test/java/io/trino/tests/TestLocalQueries.java @@ -29,6 +29,8 @@ import static io.trino.testing.MaterializedResult.resultBuilder; import static io.trino.testing.TestingSession.testSessionBuilder; import static io.trino.testing.assertions.Assert.assertEquals; +import static java.util.stream.Collectors.joining; +import static java.util.stream.IntStream.range; public class TestLocalQueries extends AbstractTestQueries @@ -117,4 +119,14 @@ public void testTransformValuesInTry() "FROM (VALUES map(ARRAY[1, 2], ARRAY[0, 0]), map(ARRAY[28], ARRAY[2]), map(ARRAY[18], ARRAY[2]), map(ARRAY[4, 5], ARRAY[1, 0]), map(ARRAY[12], ARRAY[3])) AS t(m)", "VALUES NULL, '{\"28\":14}', '{\"18\":9}', NULL, '{\"12\":4}'"); } + + @Test + public void testExtremelyLargeIn() + { + // query should not fail + String longValues = range(0, 100000) + .mapToObj(Integer::toString) + .collect(joining(", ")); + computeActual("SELECT orderkey FROM orders WHERE orderkey IS NOT NULL OR orderkey IN (" + longValues + ")"); + } }