diff --git a/core/trino-main/src/main/java/io/trino/cost/FilterStatsCalculator.java b/core/trino-main/src/main/java/io/trino/cost/FilterStatsCalculator.java index f80055075ea2..01602660c666 100644 --- a/core/trino-main/src/main/java/io/trino/cost/FilterStatsCalculator.java +++ b/core/trino-main/src/main/java/io/trino/cost/FilterStatsCalculator.java @@ -119,8 +119,9 @@ private Expression simplifyExpression(Session session, Expression predicate, Typ // TODO reuse io.trino.sql.planner.iterative.rule.SimplifyExpressions.rewrite Map, Type> expressionTypes = getExpressionTypes(plannerContext, session, predicate, types); - ExpressionInterpreter interpreter = new ExpressionInterpreter(predicate, plannerContext, session, expressionTypes); - Object value = interpreter.optimize(NoOpSymbolResolver.INSTANCE); + // TODO - Use the same instance of ExpressionInterpreter create per planning once StatsRule has context + ExpressionInterpreter interpreter = new ExpressionInterpreter(plannerContext, session); + Object value = interpreter.optimize(predicate, expressionTypes, NoOpSymbolResolver.INSTANCE); if (value == null) { // Expression evaluates to SQL null, which in Filter is equivalent to false. This assumes the expression is a top-level expression (eg. not in NOT). diff --git a/core/trino-main/src/main/java/io/trino/cost/ScalarStatsCalculator.java b/core/trino-main/src/main/java/io/trino/cost/ScalarStatsCalculator.java index ec6ab4c6b196..eca3f89899d1 100644 --- a/core/trino-main/src/main/java/io/trino/cost/ScalarStatsCalculator.java +++ b/core/trino-main/src/main/java/io/trino/cost/ScalarStatsCalculator.java @@ -136,8 +136,9 @@ protected SymbolStatsEstimate visitLiteral(Literal node, Void context) protected SymbolStatsEstimate visitFunctionCall(FunctionCall node, Void context) { Map, Type> expressionTypes = getExpressionTypes(plannerContext, session, node, types); - ExpressionInterpreter interpreter = new ExpressionInterpreter(node, plannerContext, session, expressionTypes); - Object value = interpreter.optimize(NoOpSymbolResolver.INSTANCE); + // TODO - Use the same instance of ExpressionInterpreter create per planning once StatsRule has context + ExpressionInterpreter interpreter = new ExpressionInterpreter(plannerContext, session); + Object value = interpreter.optimize(node, expressionTypes, NoOpSymbolResolver.INSTANCE); if (value == null || value instanceof NullLiteral) { return nullStatsEstimate(); diff --git a/core/trino-main/src/main/java/io/trino/sql/ExpressionUtils.java b/core/trino-main/src/main/java/io/trino/sql/ExpressionUtils.java index bb0427c258ec..67bb333dabcd 100644 --- a/core/trino-main/src/main/java/io/trino/sql/ExpressionUtils.java +++ b/core/trino-main/src/main/java/io/trino/sql/ExpressionUtils.java @@ -304,8 +304,8 @@ public static boolean isEffectivelyLiteral(PlannerContext plannerContext, Sessio private static boolean constantExpressionEvaluatesSuccessfully(PlannerContext plannerContext, Session session, Expression constantExpression) { Map, Type> types = getExpressionTypes(plannerContext, session, constantExpression, TypeProvider.empty()); - ExpressionInterpreter interpreter = new ExpressionInterpreter(constantExpression, plannerContext, session, types); - Object literalValue = interpreter.optimize(NoOpSymbolResolver.INSTANCE); + ExpressionInterpreter interpreter = new ExpressionInterpreter(plannerContext, session); + Object literalValue = interpreter.optimize(constantExpression, types, NoOpSymbolResolver.INSTANCE); return !(literalValue instanceof Expression); } diff --git a/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java b/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java index f119639d8d36..f816ca75f715 100644 --- a/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java +++ b/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java @@ -2094,9 +2094,9 @@ protected Scope visitSampledRelation(SampledRelation relation, Optional s throw semanticException(TYPE_MISMATCH, samplePercentage, "Sample percentage should be a numeric expression"); } - ExpressionInterpreter samplePercentageEval = new ExpressionInterpreter(samplePercentage, plannerContext, session, expressionTypes); + ExpressionInterpreter samplePercentageEval = new ExpressionInterpreter(plannerContext, session); - Object samplePercentageObject = samplePercentageEval.optimize(symbol -> { + Object samplePercentageObject = samplePercentageEval.optimize(samplePercentage, expressionTypes, symbol -> { throw semanticException(EXPRESSION_NOT_CONSTANT, samplePercentage, "Sample percentage cannot contain column references"); }); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/DomainTranslator.java b/core/trino-main/src/main/java/io/trino/sql/planner/DomainTranslator.java index 8823a74140ca..243525e7ea1c 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/DomainTranslator.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/DomainTranslator.java @@ -301,7 +301,12 @@ private static boolean isBetween(Range range) * 2) An Expression fragment which represents the part of the original Expression that will need to be re-evaluated * after filtering with the TupleDomain. */ - public static ExtractionResult getExtractionResult(PlannerContext plannerContext, Session session, Expression predicate, TypeProvider types) + public static ExtractionResult getExtractionResult( + PlannerContext plannerContext, + Session session, + Expression predicate, + TypeProvider types, + ExpressionInterpreter expressionInterpreter) { // This is a limited type analyzer for the simple expressions used in this method TypeAnalyzer typeAnalyzer = new TypeAnalyzer( @@ -316,7 +321,7 @@ public static ExtractionResult getExtractionResult(PlannerContext plannerContext new TablePropertyManager(), new AnalyzePropertyManager(), new TableProceduresPropertyManager())); - return new Visitor(plannerContext, session, types, typeAnalyzer).process(predicate, false); + return new Visitor(plannerContext, session, types, typeAnalyzer, expressionInterpreter).process(predicate, false); } private static class Visitor @@ -326,16 +331,23 @@ private static class Visitor private final LiteralEncoder literalEncoder; private final Session session; private final TypeProvider types; + private final ExpressionInterpreter expressionInterpreter; private final InterpretedFunctionInvoker functionInvoker; private final TypeAnalyzer typeAnalyzer; private final TypeCoercion typeCoercion; - private Visitor(PlannerContext plannerContext, Session session, TypeProvider types, TypeAnalyzer typeAnalyzer) + private Visitor( + PlannerContext plannerContext, + Session session, + TypeProvider types, + TypeAnalyzer typeAnalyzer, + ExpressionInterpreter expressionInterpreter) { this.plannerContext = requireNonNull(plannerContext, "plannerContext is null"); this.literalEncoder = new LiteralEncoder(plannerContext); this.session = requireNonNull(session, "session is null"); this.types = requireNonNull(types, "types is null"); + this.expressionInterpreter = requireNonNull(expressionInterpreter, "expressionInterpreter is null"); this.functionInvoker = new InterpretedFunctionInvoker(plannerContext.getFunctionManager()); this.typeAnalyzer = requireNonNull(typeAnalyzer, "typeAnalyzer is null"); this.typeCoercion = new TypeCoercion(plannerContext.getTypeManager()::getType); @@ -539,8 +551,8 @@ protected ExtractionResult visitComparisonExpression(ComparisonExpression node, private Optional toNormalizedSimpleComparison(ComparisonExpression comparison) { Map, Type> expressionTypes = analyzeExpression(comparison); - Object left = new ExpressionInterpreter(comparison.getLeft(), plannerContext, session, expressionTypes).optimize(NoOpSymbolResolver.INSTANCE); - Object right = new ExpressionInterpreter(comparison.getRight(), plannerContext, session, expressionTypes).optimize(NoOpSymbolResolver.INSTANCE); + Object left = expressionInterpreter.optimize(comparison.getLeft(), expressionTypes, NoOpSymbolResolver.INSTANCE); + Object right = expressionInterpreter.optimize(comparison.getRight(), expressionTypes, NoOpSymbolResolver.INSTANCE); Type leftType = expressionTypes.get(NodeRef.of(comparison.getLeft())); Type rightType = expressionTypes.get(NodeRef.of(comparison.getRight())); @@ -908,8 +920,7 @@ private Optional processSimpleInPredicate(InPredicate node, Bo List excludedExpressions = new ArrayList<>(); for (Expression expression : valueList.getValues()) { - Object value = new ExpressionInterpreter(expression, plannerContext, session, expressionTypes) - .optimize(NoOpSymbolResolver.INSTANCE); + Object value = expressionInterpreter.optimize(expression, expressionTypes, NoOpSymbolResolver.INSTANCE); if (value == null || value instanceof NullLiteral) { if (!complement) { // in case of IN, NULL on the right results with NULL comparison result (effectively false in predicate context), so can be ignored, as the diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/EffectivePredicateExtractor.java b/core/trino-main/src/main/java/io/trino/sql/planner/EffectivePredicateExtractor.java index b94094323971..b1814739ff3f 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/EffectivePredicateExtractor.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/EffectivePredicateExtractor.java @@ -109,29 +109,45 @@ public EffectivePredicateExtractor(DomainTranslator domainTranslator, PlannerCon this.useTableProperties = useTableProperties; } - public Expression extract(Session session, PlanNode node, TypeProvider types, TypeAnalyzer typeAnalyzer) + public Expression extract(Session session, PlanNode node, TypeProvider types, TypeAnalyzer typeAnalyzer, ExpressionInterpreter expressionInterpreter) { - return node.accept(new Visitor(domainTranslator, plannerContext, session, types, typeAnalyzer, useTableProperties), null); + return node.accept( + new Visitor( + domainTranslator, + plannerContext.getMetadata(), + session, + types, + expressionInterpreter, + typeAnalyzer, + useTableProperties), + null); } private static class Visitor extends PlanVisitor { private final DomainTranslator domainTranslator; - private final PlannerContext plannerContext; private final Metadata metadata; private final Session session; private final TypeProvider types; + private final ExpressionInterpreter expressionInterpreter; private final TypeAnalyzer typeAnalyzer; private final boolean useTableProperties; - public Visitor(DomainTranslator domainTranslator, PlannerContext plannerContext, Session session, TypeProvider types, TypeAnalyzer typeAnalyzer, boolean useTableProperties) + public Visitor( + DomainTranslator domainTranslator, + Metadata metadata, + Session session, + TypeProvider types, + ExpressionInterpreter expressionInterpreter, + TypeAnalyzer typeAnalyzer, + boolean useTableProperties) { this.domainTranslator = requireNonNull(domainTranslator, "domainTranslator is null"); - this.plannerContext = requireNonNull(plannerContext, "plannerContext is null"); - this.metadata = plannerContext.getMetadata(); + this.metadata = requireNonNull(metadata, "metadata is null"); this.session = requireNonNull(session, "session is null"); this.types = requireNonNull(types, "types is null"); + this.expressionInterpreter = requireNonNull(expressionInterpreter, "expressionInterpreter is null"); this.typeAnalyzer = requireNonNull(typeAnalyzer, "typeAnalyzer is null"); this.useTableProperties = useTableProperties; } @@ -386,8 +402,7 @@ public Expression visitValues(ValuesNode node, Void context) nonDeterministic[i] = true; } else { - ExpressionInterpreter interpreter = new ExpressionInterpreter(value, plannerContext, session, expressionTypes); - Object item = interpreter.optimize(NoOpSymbolResolver.INSTANCE); + Object item = expressionInterpreter.optimize(value, expressionTypes, NoOpSymbolResolver.INSTANCE); if (item instanceof Expression) { return TRUE_LITERAL; } @@ -413,8 +428,7 @@ public Expression visitValues(ValuesNode node, Void context) if (!DeterminismEvaluator.isDeterministic(row, metadata)) { return TRUE_LITERAL; } - ExpressionInterpreter interpreter = new ExpressionInterpreter(row, plannerContext, session, expressionTypes); - Object evaluated = interpreter.optimize(NoOpSymbolResolver.INSTANCE); + Object evaluated = expressionInterpreter.optimize(row, expressionTypes, NoOpSymbolResolver.INSTANCE); if (evaluated instanceof Expression) { return TRUE_LITERAL; } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/ExpressionInterpreter.java b/core/trino-main/src/main/java/io/trino/sql/planner/ExpressionInterpreter.java index 1598cdc545ca..30705c704d19 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/ExpressionInterpreter.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/ExpressionInterpreter.java @@ -16,6 +16,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; +import com.google.common.collect.MapMaker; import com.google.common.primitives.Primitives; import io.airlift.slice.Slice; import io.airlift.slice.Slices; @@ -104,7 +105,6 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.HashSet; -import java.util.IdentityHashMap; import java.util.List; import java.util.Map; import java.util.Objects; @@ -158,34 +158,35 @@ public class ExpressionInterpreter { - private final Expression expression; private final PlannerContext plannerContext; private final Metadata metadata; private final LiteralInterpreter literalInterpreter; private final LiteralEncoder literalEncoder; private final Session session; private final ConnectorSession connectorSession; - private final Map, Type> expressionTypes; private final InterpretedFunctionInvoker functionInvoker; private final TypeCoercion typeCoercion; - // identity-based cache for LIKE expressions with constant pattern and escape char - private final IdentityHashMap likePatternCache = new IdentityHashMap<>(); - private final IdentityHashMap> inListCache = new IdentityHashMap<>(); + // weak reference based cache for LIKE expressions with constant pattern and escape char + private final Map likePatternCache; + private final Map> inListCache; + private final Map optimizationCache; - public ExpressionInterpreter(Expression expression, PlannerContext plannerContext, Session session, Map, Type> expressionTypes) + public ExpressionInterpreter(PlannerContext plannerContext, Session session) { - this.expression = requireNonNull(expression, "expression is null"); this.plannerContext = requireNonNull(plannerContext, "plannerContext is null"); this.metadata = plannerContext.getMetadata(); this.literalInterpreter = new LiteralInterpreter(plannerContext, session); this.literalEncoder = new LiteralEncoder(plannerContext); this.session = requireNonNull(session, "session is null"); this.connectorSession = session.toConnectorSession(); - this.expressionTypes = ImmutableMap.copyOf(requireNonNull(expressionTypes, "expressionTypes is null")); - verify((expressionTypes.containsKey(NodeRef.of(expression)))); this.functionInvoker = new InterpretedFunctionInvoker(plannerContext.getFunctionManager()); this.typeCoercion = new TypeCoercion(plannerContext.getTypeManager()::getType); + + MapMaker mapMaker = new MapMaker().weakKeys().concurrencyLevel(1); + this.likePatternCache = mapMaker.makeMap(); + this.inListCache = mapMaker.makeMap(); + this.optimizationCache = mapMaker.makeMap(); } public static Object evaluateConstantExpression( @@ -258,51 +259,57 @@ public static Object evaluateConstantExpression( analyzer.analyze(resolved, Scope.create()); // evaluate the expression - return new ExpressionInterpreter(resolved, plannerContext, session, analyzer.getExpressionTypes()).evaluate(); + return new ExpressionInterpreter(plannerContext, session).evaluate(resolved, analyzer.getExpressionTypes()); } - public Type getType() + public Object evaluate(Expression expression, Map, Type> expressionTypes) { - return expressionTypes.get(NodeRef.of(expression)); - } + requireNonNull(expression, "expression is null"); + verify((expressionTypes.containsKey(NodeRef.of(expression)))); - public Object evaluate() - { - Object result = new Visitor(false).processWithExceptionHandling(expression, new NoPagePositionContext()); + Object result = new Visitor(expressionTypes, false).processWithExceptionHandling(expression, NoOpSymbolResolver.INSTANCE); verify(!(result instanceof Expression), "Expression interpreter returned an unresolved expression"); return result; } - public Object evaluate(SymbolResolver inputs) + public Object evaluate(Expression expression, Map, Type> expressionTypes, SymbolResolver inputs) { - Object result = new Visitor(false).processWithExceptionHandling(expression, inputs); + requireNonNull(expression, "expression is null"); + verify((expressionTypes.containsKey(NodeRef.of(expression)))); + + Object result = new Visitor(expressionTypes, false).processWithExceptionHandling(expression, inputs); verify(!(result instanceof Expression), "Expression interpreter returned an unresolved expression"); return result; } - public Object optimize(SymbolResolver inputs) + public Object optimize(Expression expression, Map, Type> expressionTypes, SymbolResolver inputs) { - return new Visitor(true).processWithExceptionHandling(expression, inputs); + requireNonNull(expression, "expression is null"); + verify((expressionTypes.containsKey(NodeRef.of(expression)))); + + return new Visitor(expressionTypes, true).processWithExceptionHandling(expression, inputs); } private class Visitor - extends AstVisitor + extends AstVisitor { private final boolean optimize; + private final Map, Type> expressionTypes; - private Visitor(boolean optimize) + private Visitor(Map, Type> expressionTypes, boolean optimize) { + this.expressionTypes = requireNonNull(expressionTypes, "expressionTypes is null"); this.optimize = optimize; } - private Object processWithExceptionHandling(Expression expression, Object context) + private Object processWithExceptionHandling(Expression expression, SymbolResolver resolver) { if (expression == null) { return null; } try { - return process(expression, context); + return processWithCaching(expression, resolver); } catch (TrinoException e) { if (optimize) { @@ -317,14 +324,27 @@ private Object processWithExceptionHandling(Expression expression, Object contex } } + private Object processWithCaching(Expression expression, SymbolResolver resolver) + { + if (optimize && resolver instanceof NoOpSymbolResolver) { + // We are using weak reference map as cache, that's why we can't depend on an intermediate object + // that consists of expression as well as symbolResolver as a key. This is because intermediate object could + // get GCed even when expression still exists which will cause cache miss. So, to solve this problem we + // can only cache when context is noop. It'll also give us good performance benefits because in most of + // places optimize is called with noop symbol resolver. + return optimizationCache.computeIfAbsent(expression, key -> process(key, resolver)); + } + return process(expression, resolver); + } + @Override - public Object visitFieldReference(FieldReference node, Object context) + public Object visitFieldReference(FieldReference node, SymbolResolver resolver) { throw new UnsupportedOperationException("Field references not supported in interpreter"); } @Override - protected Object visitDereferenceExpression(DereferenceExpression node, Object context) + protected Object visitDereferenceExpression(DereferenceExpression node, SymbolResolver resolver) { checkArgument(!isQualifiedAllFieldsReference(node), "unexpected expression: all fields labeled reference " + node); Identifier fieldIdentifier = node.getField().orElseThrow(); @@ -336,7 +356,7 @@ protected Object visitDereferenceExpression(DereferenceExpression node, Object c } // Row dereference: process dereference base eagerly, and only then pick the expected field - Object base = processWithExceptionHandling(node.getBase(), context); + Object base = processWithExceptionHandling(node.getBase(), resolver); // if the base part is evaluated to be null, the dereference expression should also be null if (base == null) { return null; @@ -365,37 +385,37 @@ protected Object visitDereferenceExpression(DereferenceExpression node, Object c } @Override - protected Object visitIdentifier(Identifier node, Object context) + protected Object visitIdentifier(Identifier node, SymbolResolver resolver) { // Identifier only exists before planning. // ExpressionInterpreter should only be invoked after planning. // As a result, this method should be unreachable. // However, RelationPlanner.visitUnnest and visitValues invokes evaluateConstantExpression. - return ((SymbolResolver) context).getValue(new Symbol(node.getValue())); + return resolver.getValue(new Symbol(node.getValue())); } @Override - protected Object visitParameter(Parameter node, Object context) + protected Object visitParameter(Parameter node, SymbolResolver resolver) { return node; } @Override - protected Object visitSymbolReference(SymbolReference node, Object context) + protected Object visitSymbolReference(SymbolReference node, SymbolResolver resolver) { - return ((SymbolResolver) context).getValue(Symbol.from(node)); + return resolver.getValue(Symbol.from(node)); } @Override - protected Object visitLiteral(Literal node, Object context) + protected Object visitLiteral(Literal node, SymbolResolver resolver) { return literalInterpreter.evaluate(node, type(node)); } @Override - protected Object visitIsNullPredicate(IsNullPredicate node, Object context) + protected Object visitIsNullPredicate(IsNullPredicate node, SymbolResolver resolver) { - Object value = processWithExceptionHandling(node.getValue(), context); + Object value = processWithExceptionHandling(node.getValue(), resolver); if (value instanceof Expression) { return new IsNullPredicate(toExpression(value, type(node.getValue()))); @@ -405,9 +425,9 @@ protected Object visitIsNullPredicate(IsNullPredicate node, Object context) } @Override - protected Object visitIsNotNullPredicate(IsNotNullPredicate node, Object context) + protected Object visitIsNotNullPredicate(IsNotNullPredicate node, SymbolResolver resolver) { - Object value = processWithExceptionHandling(node.getValue(), context); + Object value = processWithExceptionHandling(node.getValue(), resolver); if (value instanceof Expression) { return new IsNotNullPredicate(toExpression(value, type(node.getValue()))); @@ -417,25 +437,25 @@ protected Object visitIsNotNullPredicate(IsNotNullPredicate node, Object context } @Override - protected Object visitSearchedCaseExpression(SearchedCaseExpression node, Object context) + protected Object visitSearchedCaseExpression(SearchedCaseExpression node, SymbolResolver resolver) { Object newDefault = null; boolean foundNewDefault = false; List whenClauses = new ArrayList<>(); for (WhenClause whenClause : node.getWhenClauses()) { - Object whenOperand = processWithExceptionHandling(whenClause.getOperand(), context); + Object whenOperand = processWithExceptionHandling(whenClause.getOperand(), resolver); if (whenOperand instanceof Expression) { // cannot fully evaluate, add updated whenClause whenClauses.add(new WhenClause( toExpression(whenOperand, type(whenClause.getOperand())), - toExpression(processWithExceptionHandling(whenClause.getResult(), context), type(whenClause.getResult())))); + toExpression(processWithExceptionHandling(whenClause.getResult(), resolver), type(whenClause.getResult())))); } else if (Boolean.TRUE.equals(whenOperand)) { // condition is true, use this as default foundNewDefault = true; - newDefault = processWithExceptionHandling(whenClause.getResult(), context); + newDefault = processWithExceptionHandling(whenClause.getResult(), resolver); break; } } @@ -445,7 +465,7 @@ else if (Boolean.TRUE.equals(whenOperand)) { defaultResult = newDefault; } else { - defaultResult = processWithExceptionHandling(node.getDefaultValue().orElse(null), context); + defaultResult = processWithExceptionHandling(node.getDefaultValue().orElse(null), resolver); } if (whenClauses.isEmpty()) { @@ -457,33 +477,33 @@ else if (Boolean.TRUE.equals(whenOperand)) { } @Override - protected Object visitIfExpression(IfExpression node, Object context) + protected Object visitIfExpression(IfExpression node, SymbolResolver resolver) { - Object condition = processWithExceptionHandling(node.getCondition(), context); + Object condition = processWithExceptionHandling(node.getCondition(), resolver); if (condition instanceof Expression) { - Object trueValue = processWithExceptionHandling(node.getTrueValue(), context); - Object falseValue = processWithExceptionHandling(node.getFalseValue().orElse(null), context); + Object trueValue = processWithExceptionHandling(node.getTrueValue(), resolver); + Object falseValue = processWithExceptionHandling(node.getFalseValue().orElse(null), resolver); return new IfExpression( toExpression(condition, type(node.getCondition())), toExpression(trueValue, type(node.getTrueValue())), (falseValue == null) ? null : toExpression(falseValue, type(node.getFalseValue().get()))); } if (Boolean.TRUE.equals(condition)) { - return processWithExceptionHandling(node.getTrueValue(), context); + return processWithExceptionHandling(node.getTrueValue(), resolver); } - return processWithExceptionHandling(node.getFalseValue().orElse(null), context); + return processWithExceptionHandling(node.getFalseValue().orElse(null), resolver); } @Override - protected Object visitSimpleCaseExpression(SimpleCaseExpression node, Object context) + protected Object visitSimpleCaseExpression(SimpleCaseExpression node, SymbolResolver resolver) { - Object operand = processWithExceptionHandling(node.getOperand(), context); + Object operand = processWithExceptionHandling(node.getOperand(), resolver); Type operandType = type(node.getOperand()); // if operand is null, return defaultValue if (operand == null) { - return processWithExceptionHandling(node.getDefaultValue().orElse(null), context); + return processWithExceptionHandling(node.getDefaultValue().orElse(null), resolver); } Object newDefault = null; @@ -491,18 +511,18 @@ protected Object visitSimpleCaseExpression(SimpleCaseExpression node, Object con List whenClauses = new ArrayList<>(); for (WhenClause whenClause : node.getWhenClauses()) { - Object whenOperand = processWithExceptionHandling(whenClause.getOperand(), context); + Object whenOperand = processWithExceptionHandling(whenClause.getOperand(), resolver); if (whenOperand instanceof Expression || operand instanceof Expression) { // cannot fully evaluate, add updated whenClause whenClauses.add(new WhenClause( toExpression(whenOperand, type(whenClause.getOperand())), - toExpression(processWithExceptionHandling(whenClause.getResult(), context), type(whenClause.getResult())))); + toExpression(processWithExceptionHandling(whenClause.getResult(), resolver), type(whenClause.getResult())))); } else if (whenOperand != null && isEqual(operand, operandType, whenOperand, type(whenClause.getOperand()))) { // condition is true, use this as default foundNewDefault = true; - newDefault = processWithExceptionHandling(whenClause.getResult(), context); + newDefault = processWithExceptionHandling(whenClause.getResult(), resolver); break; } } @@ -512,7 +532,7 @@ else if (whenOperand != null && isEqual(operand, operandType, whenOperand, type( defaultResult = newDefault; } else { - defaultResult = processWithExceptionHandling(node.getDefaultValue().orElse(null), context); + defaultResult = processWithExceptionHandling(node.getDefaultValue().orElse(null), resolver); } if (whenClauses.isEmpty()) { @@ -534,9 +554,9 @@ private Type type(Expression expression) } @Override - protected Object visitCoalesceExpression(CoalesceExpression node, Object context) + protected Object visitCoalesceExpression(CoalesceExpression node, SymbolResolver resolver) { - List newOperands = processOperands(node, context); + List newOperands = processOperands(node, resolver); if (newOperands.isEmpty()) { return null; } @@ -548,12 +568,12 @@ protected Object visitCoalesceExpression(CoalesceExpression node, Object context .collect(toImmutableList())); } - private List processOperands(CoalesceExpression node, Object context) + private List processOperands(CoalesceExpression node, SymbolResolver resolver) { List newOperands = new ArrayList<>(); Set uniqueNewOperands = new HashSet<>(); for (Expression operand : node.getOperands()) { - Object value = processWithExceptionHandling(operand, context); + Object value = processWithExceptionHandling(operand, resolver); if (value instanceof CoalesceExpression) { // The nested CoalesceExpression was recursively processed. It does not contain null. for (Expression nestedOperand : ((CoalesceExpression) value).getOperands()) { @@ -588,9 +608,9 @@ else if (value != null) { } @Override - protected Object visitInPredicate(InPredicate node, Object context) + protected Object visitInPredicate(InPredicate node, SymbolResolver resolver) { - Object value = processWithExceptionHandling(node.getValue(), context); + Object value = processWithExceptionHandling(node.getValue(), resolver); Expression valueListExpression = node.getValueList(); if (!(valueListExpression instanceof InListExpression)) { @@ -611,18 +631,18 @@ protected Object visitInPredicate(InPredicate node, Object context) // We use the presence of the node in the map to indicate that we've already done // the analysis below. If the value is null, it means that we can't apply the HashSet // optimization - if (!inListCache.containsKey(valueList)) { + if (set == null) { if (valueList.getValues().stream().allMatch(Literal.class::isInstance) && valueList.getValues().stream().noneMatch(NullLiteral.class::isInstance)) { - Set objectSet = valueList.getValues().stream().map(expression -> processWithExceptionHandling(expression, context)).collect(Collectors.toSet()); + Set objectSet = valueList.getValues().stream().map(expression -> processWithExceptionHandling(expression, resolver)).collect(Collectors.toSet()); Type type = type(node.getValue()); set = FastutilSetHelper.toFastutilHashSet( objectSet, type, plannerContext.getFunctionManager().getScalarFunctionInvoker(metadata.resolveOperator(session, HASH_CODE, ImmutableList.of(type)), simpleConvention(FAIL_ON_NULL, NEVER_NULL)).getMethodHandle(), plannerContext.getFunctionManager().getScalarFunctionInvoker(metadata.resolveOperator(session, EQUAL, ImmutableList.of(type, type)), simpleConvention(NULLABLE_RETURN, NEVER_NULL, NEVER_NULL)).getMethodHandle()); + inListCache.put(valueList, set); } - inListCache.put(valueList, set); } if (set != null) { @@ -651,7 +671,7 @@ protected Object visitInPredicate(InPredicate node, Object context) // but fail the whole in-predicate evaluation. // According to in-predicate semantics, all in-list items must be successfully evaluated // before a check for the match is performed. - Object inValue = process(expression, context); + Object inValue = process(expression, resolver); if (value instanceof Expression || inValue instanceof Expression) { hasUnresolvedValue = true; values.add(inValue); @@ -692,7 +712,24 @@ else if (!found && result) { return new ComparisonExpression(ComparisonExpression.Operator.EQUAL, toExpression(value, type), simplifiedExpressionValues.get(0)); } - return new InPredicate(toExpression(value, type), new InListExpression(simplifiedExpressionValues)); + Expression simplifiedValue = toExpression(value, type); + if (simplifiedValue.equals(node.getValue())) { + simplifiedValue = node.getValue(); + } + + Expression simplifiedValueList = new InListExpression(simplifiedExpressionValues); + if (simplifiedValueList.equals(node.getValueList())) { + simplifiedValueList = node.getValueList(); + } + + if (simplifiedValue == node.getValue() && simplifiedValueList == node.getValueList()) { + // Do not create a new instance of InPredicate expression if it would be same as original expression. + // Creating a new instance of InPredicate would cause inListCache cache miss, which is using node + // reference as a cache key. + return node; + } + + return new InPredicate(simplifiedValue, simplifiedValueList); } if (hasNullValue) { return null; @@ -701,7 +738,7 @@ else if (!found && result) { } @Override - protected Object visitExists(ExistsPredicate node, Object context) + protected Object visitExists(ExistsPredicate node, SymbolResolver resolver) { if (!optimize) { throw new UnsupportedOperationException("Exists subquery not yet implemented"); @@ -710,7 +747,7 @@ protected Object visitExists(ExistsPredicate node, Object context) } @Override - protected Object visitSubqueryExpression(SubqueryExpression node, Object context) + protected Object visitSubqueryExpression(SubqueryExpression node, SymbolResolver resolver) { if (!optimize) { throw new UnsupportedOperationException("Subquery not yet implemented"); @@ -719,9 +756,9 @@ protected Object visitSubqueryExpression(SubqueryExpression node, Object context } @Override - protected Object visitArithmeticUnary(ArithmeticUnaryExpression node, Object context) + protected Object visitArithmeticUnary(ArithmeticUnaryExpression node, SymbolResolver resolver) { - Object value = processWithExceptionHandling(node.getValue(), context); + Object value = processWithExceptionHandling(node.getValue(), resolver); if (value == null) { return null; } @@ -765,13 +802,13 @@ protected Object visitArithmeticUnary(ArithmeticUnaryExpression node, Object con } @Override - protected Object visitArithmeticBinary(ArithmeticBinaryExpression node, Object context) + protected Object visitArithmeticBinary(ArithmeticBinaryExpression node, SymbolResolver resolver) { - Object left = processWithExceptionHandling(node.getLeft(), context); + Object left = processWithExceptionHandling(node.getLeft(), resolver); if (left == null) { return null; } - Object right = processWithExceptionHandling(node.getRight(), context); + Object right = processWithExceptionHandling(node.getRight(), resolver); if (right == null) { return null; } @@ -784,20 +821,20 @@ protected Object visitArithmeticBinary(ArithmeticBinaryExpression node, Object c } @Override - protected Object visitComparisonExpression(ComparisonExpression node, Object context) + protected Object visitComparisonExpression(ComparisonExpression node, SymbolResolver resolver) { ComparisonExpression.Operator operator = node.getOperator(); Expression left = node.getLeft(); Expression right = node.getRight(); if (operator == Operator.IS_DISTINCT_FROM) { - return processIsDistinctFrom(context, left, right); + return processIsDistinctFrom(resolver, left, right); } // Execution engine does not have not equal and greater than operators, so interpret with // equal or less than, but do not flip operator in result, as many optimizers depend on // operators not flipping if (node.getOperator() == Operator.NOT_EQUAL) { - Object result = visitComparisonExpression(flipComparison(node), context); + Object result = visitComparisonExpression(flipComparison(node), resolver); if (result == null) { return null; } @@ -807,20 +844,20 @@ protected Object visitComparisonExpression(ComparisonExpression node, Object con return !(Boolean) result; } if (node.getOperator() == Operator.GREATER_THAN || node.getOperator() == Operator.GREATER_THAN_OR_EQUAL) { - Object result = visitComparisonExpression(flipComparison(node), context); + Object result = visitComparisonExpression(flipComparison(node), resolver); if (result instanceof ComparisonExpression) { return flipComparison((ComparisonExpression) result); } return result; } - return processComparisonExpression(context, operator, left, right); + return processComparisonExpression(resolver, operator, left, right); } - private Object processIsDistinctFrom(Object context, Expression leftExpression, Expression rightExpression) + private Object processIsDistinctFrom(SymbolResolver resolver, Expression leftExpression, Expression rightExpression) { - Object left = processWithExceptionHandling(leftExpression, context); - Object right = processWithExceptionHandling(rightExpression, context); + Object left = processWithExceptionHandling(leftExpression, resolver); + Object right = processWithExceptionHandling(rightExpression, resolver); if (left == null && right instanceof Expression) { return new IsNotNullPredicate((Expression) right); @@ -837,14 +874,14 @@ private Object processIsDistinctFrom(Object context, Expression leftExpression, return invokeOperator(OperatorType.valueOf(Operator.IS_DISTINCT_FROM.name()), types(leftExpression, rightExpression), Arrays.asList(left, right)); } - private Object processComparisonExpression(Object context, Operator operator, Expression leftExpression, Expression rightExpression) + private Object processComparisonExpression(SymbolResolver resolver, Operator operator, Expression leftExpression, Expression rightExpression) { - Object left = processWithExceptionHandling(leftExpression, context); + Object left = processWithExceptionHandling(leftExpression, resolver); if (left == null) { return null; } - Object right = processWithExceptionHandling(rightExpression, context); + Object right = processWithExceptionHandling(rightExpression, resolver); if (right == null) { return null; } @@ -879,14 +916,14 @@ private ComparisonExpression flipComparison(ComparisonExpression comparisonExpre } @Override - protected Object visitBetweenPredicate(BetweenPredicate node, Object context) + protected Object visitBetweenPredicate(BetweenPredicate node, SymbolResolver resolver) { - Object value = processWithExceptionHandling(node.getValue(), context); + Object value = processWithExceptionHandling(node.getValue(), resolver); if (value == null) { return null; } - Object min = processWithExceptionHandling(node.getMin(), context); - Object max = processWithExceptionHandling(node.getMax(), context); + Object min = processWithExceptionHandling(node.getMin(), resolver); + Object max = processWithExceptionHandling(node.getMax(), resolver); if (value instanceof Expression || min instanceof Expression || max instanceof Expression) { return new BetweenPredicate( @@ -914,13 +951,13 @@ protected Object visitBetweenPredicate(BetweenPredicate node, Object context) } @Override - protected Object visitNullIfExpression(NullIfExpression node, Object context) + protected Object visitNullIfExpression(NullIfExpression node, SymbolResolver resolver) { - Object first = processWithExceptionHandling(node.getFirst(), context); + Object first = processWithExceptionHandling(node.getFirst(), resolver); if (first == null) { return null; } - Object second = processWithExceptionHandling(node.getSecond(), context); + Object second = processWithExceptionHandling(node.getSecond(), resolver); if (second == null) { return first; } @@ -952,9 +989,9 @@ protected Object visitNullIfExpression(NullIfExpression node, Object context) } @Override - protected Object visitNotExpression(NotExpression node, Object context) + protected Object visitNotExpression(NotExpression node, SymbolResolver resolver) { - Object value = processWithExceptionHandling(node.getValue(), context); + Object value = processWithExceptionHandling(node.getValue(), resolver); if (value == null) { return null; } @@ -967,13 +1004,13 @@ protected Object visitNotExpression(NotExpression node, Object context) } @Override - protected Object visitLogicalExpression(LogicalExpression node, Object context) + protected Object visitLogicalExpression(LogicalExpression node, SymbolResolver resolver) { List terms = new ArrayList<>(); List types = new ArrayList<>(); for (Expression term : node.getTerms()) { - Object processed = processWithExceptionHandling(term, context); + Object processed = processWithExceptionHandling(term, resolver); switch (node.getOperator()) { case AND: @@ -1027,18 +1064,18 @@ protected Object visitLogicalExpression(LogicalExpression node, Object context) } @Override - protected Object visitBooleanLiteral(BooleanLiteral node, Object context) + protected Object visitBooleanLiteral(BooleanLiteral node, SymbolResolver resolver) { return node.equals(BooleanLiteral.TRUE_LITERAL); } @Override - protected Object visitFunctionCall(FunctionCall node, Object context) + protected Object visitFunctionCall(FunctionCall node, SymbolResolver resolver) { List argumentTypes = new ArrayList<>(); List argumentValues = new ArrayList<>(); for (Expression expression : node.getArguments()) { - Object value = processWithExceptionHandling(expression, context); + Object value = processWithExceptionHandling(expression, resolver); Type type = type(expression); argumentValues.add(value); argumentTypes.add(type); @@ -1071,7 +1108,7 @@ protected Object visitFunctionCall(FunctionCall node, Object context) } @Override - protected Object visitLambdaExpression(LambdaExpression node, Object context) + protected Object visitLambdaExpression(LambdaExpression node, SymbolResolver resolver) { if (optimize) { // TODO: enable optimization related to lambda expression @@ -1098,12 +1135,12 @@ protected Object visitLambdaExpression(LambdaExpression node, Object context) } @Override - protected Object visitBindExpression(BindExpression node, Object context) + protected Object visitBindExpression(BindExpression node, SymbolResolver resolver) { List values = node.getValues().stream() - .map(value -> processWithExceptionHandling(value, context)) + .map(value -> processWithExceptionHandling(value, resolver)) .collect(toList()); // values are nullable - Object function = processWithExceptionHandling(node.getFunction(), context); + Object function = processWithExceptionHandling(node.getFunction(), resolver); if (hasUnresolvedValue(values) || hasUnresolvedValue(function)) { ImmutableList.Builder builder = ImmutableList.builder(); @@ -1120,9 +1157,9 @@ protected Object visitBindExpression(BindExpression node, Object context) } @Override - protected Object visitLikePredicate(LikePredicate node, Object context) + protected Object visitLikePredicate(LikePredicate node, SymbolResolver resolver) { - Object value = processWithExceptionHandling(node.getValue(), context); + Object value = processWithExceptionHandling(node.getValue(), resolver); if (value == null) { return null; @@ -1135,7 +1172,7 @@ protected Object visitLikePredicate(LikePredicate node, Object context) return evaluateLikePredicate(node, (Slice) value, getConstantPattern(node)); } - Object pattern = processWithExceptionHandling(node.getPattern(), context); + Object pattern = processWithExceptionHandling(node.getPattern(), resolver); if (pattern == null) { return null; @@ -1143,7 +1180,7 @@ protected Object visitLikePredicate(LikePredicate node, Object context) Object escape = null; if (node.getEscape().isPresent()) { - escape = processWithExceptionHandling(node.getEscape().get(), context); + escape = processWithExceptionHandling(node.getEscape().get(), resolver); if (escape == null) { return null; @@ -1238,9 +1275,9 @@ private JoniRegexp getConstantPattern(LikePredicate node) } @Override - public Object visitCast(Cast node, Object context) + public Object visitCast(Cast node, SymbolResolver resolver) { - Object value = processWithExceptionHandling(node.getExpression(), context); + Object value = processWithExceptionHandling(node.getExpression(), resolver); Type targetType = plannerContext.getTypeManager().getType(toTypeSignature(node.getType())); Type sourceType = type(node.getExpression()); if (value instanceof Expression) { @@ -1273,13 +1310,13 @@ public Object visitCast(Cast node, Object context) } @Override - protected Object visitArrayConstructor(ArrayConstructor node, Object context) + protected Object visitArrayConstructor(ArrayConstructor node, SymbolResolver resolver) { Type elementType = ((ArrayType) type(node)).getElementType(); BlockBuilder arrayBlockBuilder = elementType.createBlockBuilder(null, node.getValues().size()); for (Expression expression : node.getValues()) { - Object value = processWithExceptionHandling(expression, context); + Object value = processWithExceptionHandling(expression, resolver); if (value instanceof Expression) { checkCondition(node.getValues().size() <= 254, NOT_SUPPORTED, "Too many arguments for array constructor"); return visitFunctionCall( @@ -1287,7 +1324,7 @@ protected Object visitArrayConstructor(ArrayConstructor node, Object context) .setName(QualifiedName.of(ArrayConstructor.ARRAY_CONSTRUCTOR)) .setArguments(types(node.getValues()), node.getValues()) .build(), - context); + resolver); } writeNativeValue(elementType, arrayBlockBuilder, value); } @@ -1296,31 +1333,31 @@ protected Object visitArrayConstructor(ArrayConstructor node, Object context) } @Override - protected Object visitCurrentCatalog(CurrentCatalog node, Object context) + protected Object visitCurrentCatalog(CurrentCatalog node, SymbolResolver resolver) { - return visitFunctionCall(desugarCurrentCatalog(session, node, metadata), context); + return visitFunctionCall(desugarCurrentCatalog(session, node, metadata), resolver); } @Override - protected Object visitCurrentSchema(CurrentSchema node, Object context) + protected Object visitCurrentSchema(CurrentSchema node, SymbolResolver resolver) { - return visitFunctionCall(desugarCurrentSchema(session, node, metadata), context); + return visitFunctionCall(desugarCurrentSchema(session, node, metadata), resolver); } @Override - protected Object visitCurrentUser(CurrentUser node, Object context) + protected Object visitCurrentUser(CurrentUser node, SymbolResolver resolver) { - return visitFunctionCall(DesugarCurrentUser.getCall(node, metadata, session), context); + return visitFunctionCall(DesugarCurrentUser.getCall(node, metadata, session), resolver); } @Override - protected Object visitCurrentPath(CurrentPath node, Object context) + protected Object visitCurrentPath(CurrentPath node, SymbolResolver resolver) { - return visitFunctionCall(DesugarCurrentPath.getCall(node, metadata, session), context); + return visitFunctionCall(DesugarCurrentPath.getCall(node, metadata, session), resolver); } @Override - protected Object visitRow(Row node, Object context) + protected Object visitRow(Row node, SymbolResolver resolver) { RowType rowType = (RowType) type(node); List parameterTypes = rowType.getTypeParameters(); @@ -1329,7 +1366,7 @@ protected Object visitRow(Row node, Object context) int cardinality = arguments.size(); List values = new ArrayList<>(cardinality); for (Expression argument : arguments) { - values.add(processWithExceptionHandling(argument, context)); + values.add(processWithExceptionHandling(argument, resolver)); } if (hasUnresolvedValue(values)) { return new Row(toExpressions(values, parameterTypes)); @@ -1344,13 +1381,13 @@ protected Object visitRow(Row node, Object context) } @Override - protected Object visitSubscriptExpression(SubscriptExpression node, Object context) + protected Object visitSubscriptExpression(SubscriptExpression node, SymbolResolver resolver) { - Object base = processWithExceptionHandling(node.getBase(), context); + Object base = processWithExceptionHandling(node.getBase(), resolver); if (base == null) { return null; } - Object index = processWithExceptionHandling(node.getIndex(), context); + Object index = processWithExceptionHandling(node.getIndex(), resolver); if (index == null) { return null; } @@ -1378,7 +1415,7 @@ protected Object visitSubscriptExpression(SubscriptExpression node, Object conte } @Override - protected Object visitQuantifiedComparisonExpression(QuantifiedComparisonExpression node, Object context) + protected Object visitQuantifiedComparisonExpression(QuantifiedComparisonExpression node, SymbolResolver resolver) { if (!optimize) { throw new UnsupportedOperationException("QuantifiedComparison not yet implemented"); @@ -1387,13 +1424,13 @@ protected Object visitQuantifiedComparisonExpression(QuantifiedComparisonExpress } @Override - protected Object visitExpression(Expression node, Object context) + protected Object visitExpression(Expression node, SymbolResolver resolver) { throw new TrinoException(NOT_SUPPORTED, "not yet implemented: " + node.getClass().getName()); } @Override - protected Object visitNode(Node node, Object context) + protected Object visitNode(Node node, SymbolResolver resolver) { throw new UnsupportedOperationException("Evaluator visitor can only handle Expression nodes"); } @@ -1441,29 +1478,6 @@ private List toExpressions(List values, List types) } } - private interface PagePositionContext - { - Block getBlock(int channel); - - int getPosition(int channel); - } - - private static class NoPagePositionContext - implements PagePositionContext - { - @Override - public Block getBlock(int channel) - { - throw new IllegalArgumentException("Context does not contain any blocks"); - } - - @Override - public int getPosition(int channel) - { - throw new IllegalArgumentException("Context does not have a position"); - } - } - private static boolean isArray(Type type) { return type instanceof ArrayType; diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/LayoutConstraintEvaluator.java b/core/trino-main/src/main/java/io/trino/sql/planner/LayoutConstraintEvaluator.java index 8b8851c4128d..393cdb107356 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/LayoutConstraintEvaluator.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/LayoutConstraintEvaluator.java @@ -18,8 +18,10 @@ import io.trino.operator.scalar.TryFunction; import io.trino.spi.connector.ColumnHandle; import io.trino.spi.predicate.NullableValue; +import io.trino.spi.type.Type; import io.trino.sql.PlannerContext; import io.trino.sql.tree.Expression; +import io.trino.sql.tree.NodeRef; import io.trino.sql.tree.NullLiteral; import java.util.Map; @@ -33,12 +35,16 @@ public class LayoutConstraintEvaluator { private final Map assignments; private final ExpressionInterpreter evaluator; + private final Expression expression; + private final Map, Type> expressionTypes; private final Set arguments; public LayoutConstraintEvaluator(PlannerContext plannerContext, TypeAnalyzer typeAnalyzer, Session session, TypeProvider types, Map assignments, Expression expression) { this.assignments = ImmutableMap.copyOf(requireNonNull(assignments, "assignments is null")); - evaluator = new ExpressionInterpreter(expression, plannerContext, session, typeAnalyzer.getTypes(session, types, expression)); + this.evaluator = new ExpressionInterpreter(plannerContext, session); + this.expression = requireNonNull(expression, "expression is null"); + this.expressionTypes = typeAnalyzer.getTypes(session, types, expression); arguments = SymbolsExtractor.extractUnique(expression).stream() .map(assignments::get) .collect(toImmutableSet()); @@ -58,7 +64,7 @@ public boolean isCandidate(Map bindings) // Skip pruning if evaluation fails in a recoverable way. Failing here can cause // spurious query failures for partitions that would otherwise be filtered out. - Object optimized = TryFunction.evaluate(() -> evaluator.optimize(inputs), true); + Object optimized = TryFunction.evaluate(() -> evaluator.optimize(expression, expressionTypes, inputs), true); // If any conjuncts evaluate to FALSE or null, then the whole predicate will never be true and so the partition should be pruned return !(Boolean.FALSE.equals(optimized) || optimized == null || optimized instanceof NullLiteral); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java b/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java index c9cec05ec5af..07e49831444a 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java @@ -2059,7 +2059,7 @@ public PhysicalOperation visitValues(ValuesNode node, LocalExecutionPlanContext Map, Type> types = typeAnalyzer.getTypes(session, TypeProvider.empty(), row); checkState(types.get(NodeRef.of(row)) instanceof RowType, "unexpected type of Values row: %s", types); // evaluate the literal value - Object result = new ExpressionInterpreter(row, plannerContext, session, types).evaluate(); + Object result = new ExpressionInterpreter(plannerContext, session).evaluate(row, types); for (int j = 0; j < outputTypes.size(); j++) { // divide row into fields writeNativeValue(outputTypes.get(j), pageBuilder.getBlockBuilder(j), readNativeValue(outputTypes.get(j), (SingleRowBlock) result, j)); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/LogicalPlanner.java b/core/trino-main/src/main/java/io/trino/sql/planner/LogicalPlanner.java index e59ea95741ad..fc03603d7bb6 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/LogicalPlanner.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/LogicalPlanner.java @@ -235,8 +235,9 @@ public Plan plan(Analysis analysis, Stage stage, boolean collectPlanStatistics) planSanityChecker.validateIntermediatePlan(root, session, plannerContext, typeAnalyzer, symbolAllocator.getTypes(), warningCollector); if (stage.ordinal() >= OPTIMIZED.ordinal()) { + PlanOptimizer.Context context = createOptimizerContext(new ExpressionInterpreter(plannerContext, session)); for (PlanOptimizer optimizer : planOptimizers) { - root = optimizer.optimize(root, session, symbolAllocator.getTypes(), symbolAllocator, idAllocator, warningCollector); + root = optimizer.optimize(root, context); requireNonNull(root, format("%s returned a null plan", optimizer.getClass().getName())); if (LOG.isDebugEnabled()) { @@ -269,6 +270,42 @@ public Plan plan(Analysis analysis, Stage stage, boolean collectPlanStatistics) return new Plan(root, types, statsAndCosts); } + private PlanOptimizer.Context createOptimizerContext(ExpressionInterpreter interpreter) + { + return new PlanOptimizer.Context() + { + @Override + public Session getSession() + { + return session; + } + + @Override + public SymbolAllocator getSymbolAllocator() + { + return symbolAllocator; + } + + @Override + public PlanNodeIdAllocator getIdAllocator() + { + return idAllocator; + } + + @Override + public WarningCollector getWarningCollector() + { + return warningCollector; + } + + @Override + public ExpressionInterpreter getExpressionInterpreter() + { + return interpreter; + } + }; + } + public PlanNode planStatement(Analysis analysis, Statement statement) { if ((statement instanceof CreateTableAsSelect && analysis.getCreate().orElseThrow().isCreateTableAsSelectNoOp()) || diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/IterativeOptimizer.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/IterativeOptimizer.java index 2d9bb1d0ba7e..d82094cb77a3 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/IterativeOptimizer.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/IterativeOptimizer.java @@ -31,10 +31,10 @@ import io.trino.matching.Pattern; import io.trino.spi.TrinoException; import io.trino.sql.PlannerContext; +import io.trino.sql.planner.ExpressionInterpreter; import io.trino.sql.planner.PlanNodeIdAllocator; import io.trino.sql.planner.RuleStatsRecorder; import io.trino.sql.planner.SymbolAllocator; -import io.trino.sql.planner.TypeProvider; import io.trino.sql.planner.optimizations.PlanOptimizer; import io.trino.sql.planner.plan.PlanNode; import io.trino.sql.planner.planprinter.PlanPrinter; @@ -98,22 +98,31 @@ public IterativeOptimizer(PlannerContext plannerContext, RuleStatsRecorder stats } @Override - public PlanNode optimize(PlanNode plan, Session session, TypeProvider types, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator, WarningCollector warningCollector) + public PlanNode optimize(PlanNode plan, PlanOptimizer.Context optimizerContext) { // only disable new rules if we have legacy rules to fall back to - if (useLegacyRules.test(session) && !legacyRules.isEmpty()) { + if (useLegacyRules.test(optimizerContext.getSession()) && !legacyRules.isEmpty()) { for (PlanOptimizer optimizer : legacyRules) { - plan = optimizer.optimize(plan, session, symbolAllocator.getTypes(), symbolAllocator, idAllocator, warningCollector); + plan = optimizer.optimize(plan, optimizerContext); } return plan; } - Memo memo = new Memo(idAllocator, plan); + Memo memo = new Memo(optimizerContext.getIdAllocator(), plan); Lookup lookup = Lookup.from(planNode -> Stream.of(memo.resolve(planNode))); - Duration timeout = SystemSessionProperties.getOptimizerTimeout(session); - Context context = new Context(memo, lookup, idAllocator, symbolAllocator, nanoTime(), timeout.toMillis(), session, warningCollector); + Duration timeout = SystemSessionProperties.getOptimizerTimeout(optimizerContext.getSession()); + Context context = new Context( + memo, + lookup, + optimizerContext.getIdAllocator(), + optimizerContext.getSymbolAllocator(), + nanoTime(), + timeout.toMillis(), + optimizerContext.getSession(), + optimizerContext.getWarningCollector(), + optimizerContext.getExpressionInterpreter()); exploreGroup(memo.getRootGroup(), context); return memo.extract(); @@ -310,6 +319,12 @@ public WarningCollector getWarningCollector() { return context.warningCollector; } + + @Override + public ExpressionInterpreter getExpressionInterpreter() + { + return context.expressionInterpreter; + } }; } @@ -323,6 +338,7 @@ private static class Context private final long timeoutInMilliseconds; private final Session session; private final WarningCollector warningCollector; + private final ExpressionInterpreter expressionInterpreter; private final Map, RuleInvocationStats> ruleStats = new HashMap<>(); @@ -334,7 +350,8 @@ public Context( long startTimeInNanos, long timeoutInMilliseconds, Session session, - WarningCollector warningCollector) + WarningCollector warningCollector, + ExpressionInterpreter expressionInterpreter) { checkArgument(timeoutInMilliseconds >= 0, "Timeout has to be a non-negative number [milliseconds]"); @@ -346,6 +363,7 @@ public Context( this.timeoutInMilliseconds = timeoutInMilliseconds; this.session = session; this.warningCollector = warningCollector; + this.expressionInterpreter = expressionInterpreter; } public void checkTimeoutNotExhausted() diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/Rule.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/Rule.java index b021805ba900..8473d5d46bbe 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/Rule.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/Rule.java @@ -19,6 +19,7 @@ import io.trino.execution.warnings.WarningCollector; import io.trino.matching.Captures; import io.trino.matching.Pattern; +import io.trino.sql.planner.ExpressionInterpreter; import io.trino.sql.planner.PlanNodeIdAllocator; import io.trino.sql.planner.SymbolAllocator; import io.trino.sql.planner.plan.PlanNode; @@ -58,6 +59,8 @@ interface Context void checkTimeoutNotExhausted(); WarningCollector getWarningCollector(); + + ExpressionInterpreter getExpressionInterpreter(); } final class Result diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/AddExchangesBelowPartialAggregationOverGroupIdRuleSet.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/AddExchangesBelowPartialAggregationOverGroupIdRuleSet.java index 18b794729b37..cd87493b6dc8 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/AddExchangesBelowPartialAggregationOverGroupIdRuleSet.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/AddExchangesBelowPartialAggregationOverGroupIdRuleSet.java @@ -342,7 +342,14 @@ private StreamProperties derivePropertiesRecursively(PlanNode node, Context cont List inputProperties = resolvedPlanNode.getSources().stream() .map(source -> derivePropertiesRecursively(source, context)) .collect(toImmutableList()); - return deriveProperties(resolvedPlanNode, inputProperties, plannerContext, context.getSession(), context.getSymbolAllocator().getTypes(), typeAnalyzer); + return deriveProperties( + resolvedPlanNode, + inputProperties, + plannerContext, + context.getSession(), + context.getSymbolAllocator().getTypes(), + typeAnalyzer, + context.getExpressionInterpreter()); } } } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushAggregationIntoTableScan.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushAggregationIntoTableScan.java index b38a263eb34e..48933d983d8e 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushAggregationIntoTableScan.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushAggregationIntoTableScan.java @@ -33,7 +33,6 @@ import io.trino.spi.type.Type; import io.trino.sql.PlannerContext; import io.trino.sql.planner.ConnectorExpressionTranslator; -import io.trino.sql.planner.ExpressionInterpreter; import io.trino.sql.planner.LiteralEncoder; import io.trino.sql.planner.NoOpSymbolResolver; import io.trino.sql.planner.OrderingScheme; @@ -197,8 +196,7 @@ public static Optional pushAggregationIntoTableScan( Map, Type> translatedExpressionTypes = typeAnalyzer.getTypes(session, context.getSymbolAllocator().getTypes(), translated); translated = literalEncoder.toExpression( session, - new ExpressionInterpreter(translated, plannerContext, session, translatedExpressionTypes) - .optimize(NoOpSymbolResolver.INSTANCE), + context.getExpressionInterpreter().optimize(translated, translatedExpressionTypes, NoOpSymbolResolver.INSTANCE), translatedExpressionTypes.get(NodeRef.of(translated))); return translated; }) diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushFilterThroughCountAggregation.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushFilterThroughCountAggregation.java index e9f62749b2f1..83cb7c2ae9f4 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushFilterThroughCountAggregation.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushFilterThroughCountAggregation.java @@ -180,7 +180,12 @@ private static Result pushFilter(FilterNode filterNode, AggregationNode aggregat Symbol countSymbol = getOnlyElement(aggregationNode.getAggregations().keySet()); Aggregation aggregation = getOnlyElement(aggregationNode.getAggregations().values()); - DomainTranslator.ExtractionResult extractionResult = getExtractionResult(plannerContext, context.getSession(), filterNode.getPredicate(), context.getSymbolAllocator().getTypes()); + DomainTranslator.ExtractionResult extractionResult = getExtractionResult( + plannerContext, + context.getSession(), + filterNode.getPredicate(), + context.getSymbolAllocator().getTypes(), + context.getExpressionInterpreter()); TupleDomain tupleDomain = extractionResult.getTupleDomain(); if (tupleDomain.isNone()) { diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushPredicateIntoTableScan.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushPredicateIntoTableScan.java index d6428912a049..d56936cc0816 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushPredicateIntoTableScan.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushPredicateIntoTableScan.java @@ -123,7 +123,8 @@ public Result apply(FilterNode filterNode, Captures captures, Context context) plannerContext, typeAnalyzer, context.getStatsProvider(), - new DomainTranslator(plannerContext)); + new DomainTranslator(plannerContext), + context.getExpressionInterpreter()); if (rewritten.isEmpty() || arePlansSame(filterNode, tableScan, rewritten.get())) { return Result.empty(); @@ -162,7 +163,8 @@ public static Optional pushFilterIntoTableScan( PlannerContext plannerContext, TypeAnalyzer typeAnalyzer, StatsProvider statsProvider, - DomainTranslator domainTranslator) + DomainTranslator domainTranslator, + ExpressionInterpreter expressionInterpreter) { if (!isAllowPushdownIntoConnectors(session)) { return Optional.empty(); @@ -174,7 +176,8 @@ public static Optional pushFilterIntoTableScan( plannerContext, session, splitExpression.getDeterministicPredicate(), - symbolAllocator.getTypes()); + symbolAllocator.getTypes(), + expressionInterpreter); TupleDomain newDomain = decomposedPredicate.getTupleDomain() .transformKeys(node.getAssignments()::get) @@ -225,6 +228,7 @@ public static Optional pushFilterIntoTableScan( session, symbolAllocator, typeAnalyzer, + expressionInterpreter, splitExpression.getDynamicFilter(), TRUE_LITERAL, splitExpression.getNonDeterministicPredicate(), @@ -289,8 +293,7 @@ public static Optional pushFilterIntoTableScan( Map, Type> translatedExpressionTypes = typeAnalyzer.getTypes(session, symbolAllocator.getTypes(), translatedExpression); translatedExpression = literalEncoder.toExpression( session, - new ExpressionInterpreter(translatedExpression, plannerContext, session, translatedExpressionTypes) - .optimize(NoOpSymbolResolver.INSTANCE), + expressionInterpreter.optimize(translatedExpression, translatedExpressionTypes, NoOpSymbolResolver.INSTANCE), translatedExpressionTypes.get(NodeRef.of(translatedExpression))); remainingDecomposedPredicate = combineConjuncts(plannerContext.getMetadata(), translatedExpression, expressionTranslation.getRemainingExpression()); } @@ -300,6 +303,7 @@ public static Optional pushFilterIntoTableScan( session, symbolAllocator, typeAnalyzer, + expressionInterpreter, splitExpression.getDynamicFilter(), domainTranslator.toPredicate(session, remainingFilter.transformKeys(assignments::get)), splitExpression.getNonDeterministicPredicate(), @@ -393,6 +397,7 @@ static Expression createResultingPredicate( Session session, SymbolAllocator symbolAllocator, TypeAnalyzer typeAnalyzer, + ExpressionInterpreter expressionInterpreter, Expression dynamicFilter, Expression unenforcedConstraints, Expression nonDeterministicPredicate, @@ -411,7 +416,7 @@ static Expression createResultingPredicate( // Make sure we produce an expression whose terms are consistent with the canonical form used in other optimizations // Otherwise, we'll end up ping-ponging among rules - expression = SimplifyExpressions.rewrite(expression, session, symbolAllocator, plannerContext, typeAnalyzer); + expression = SimplifyExpressions.rewrite(expression, session, symbolAllocator, plannerContext, typeAnalyzer, expressionInterpreter); return expression; } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushPredicateThroughProjectIntoRowNumber.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushPredicateThroughProjectIntoRowNumber.java index 515bf9b5f2d2..f6932308b620 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushPredicateThroughProjectIntoRowNumber.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushPredicateThroughProjectIntoRowNumber.java @@ -108,7 +108,8 @@ public Result apply(FilterNode filter, Captures captures, Context context) plannerContext, context.getSession(), filter.getPredicate(), - context.getSymbolAllocator().getTypes()); + context.getSymbolAllocator().getTypes(), + context.getExpressionInterpreter()); TupleDomain tupleDomain = extractionResult.getTupleDomain(); OptionalInt upperBound = extractUpperBound(tupleDomain, rowNumberSymbol); if (upperBound.isEmpty()) { diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushPredicateThroughProjectIntoWindow.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushPredicateThroughProjectIntoWindow.java index 2d8b14694065..6bb41b98a626 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushPredicateThroughProjectIntoWindow.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushPredicateThroughProjectIntoWindow.java @@ -120,7 +120,8 @@ public Result apply(FilterNode filter, Captures captures, Context context) plannerContext, context.getSession(), filter.getPredicate(), - context.getSymbolAllocator().getTypes()); + context.getSymbolAllocator().getTypes(), + context.getExpressionInterpreter()); TupleDomain tupleDomain = extractionResult.getTupleDomain(); OptionalInt upperBound = extractUpperBound(tupleDomain, rankingSymbol); if (upperBound.isEmpty()) { diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushProjectionIntoTableScan.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushProjectionIntoTableScan.java index 67f0fa56b94c..4420c28b9944 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushProjectionIntoTableScan.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushProjectionIntoTableScan.java @@ -33,7 +33,6 @@ import io.trino.spi.type.Type; import io.trino.sql.PlannerContext; import io.trino.sql.planner.ConnectorExpressionTranslator; -import io.trino.sql.planner.ExpressionInterpreter; import io.trino.sql.planner.LiteralEncoder; import io.trino.sql.planner.NoOpSymbolResolver; import io.trino.sql.planner.Symbol; @@ -158,8 +157,7 @@ public Result apply(ProjectNode project, Captures captures, Context context) Map, Type> translatedExpressionTypes = typeAnalyzer.getTypes(session, context.getSymbolAllocator().getTypes(), translated); translated = literalEncoder.toExpression( session, - new ExpressionInterpreter(translated, plannerContext, session, translatedExpressionTypes) - .optimize(NoOpSymbolResolver.INSTANCE), + context.getExpressionInterpreter().optimize(translated, translatedExpressionTypes, NoOpSymbolResolver.INSTANCE), translatedExpressionTypes.get(NodeRef.of(translated))); return translated; }) diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushdownFilterIntoRowNumber.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushdownFilterIntoRowNumber.java index cfce0e542cfc..80ef35453376 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushdownFilterIntoRowNumber.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushdownFilterIntoRowNumber.java @@ -72,7 +72,12 @@ public Result apply(FilterNode node, Captures captures, Context context) Session session = context.getSession(); TypeProvider types = context.getSymbolAllocator().getTypes(); - DomainTranslator.ExtractionResult extractionResult = DomainTranslator.getExtractionResult(plannerContext, session, node.getPredicate(), types); + DomainTranslator.ExtractionResult extractionResult = DomainTranslator.getExtractionResult( + plannerContext, + session, + node.getPredicate(), + types, + context.getExpressionInterpreter()); TupleDomain tupleDomain = extractionResult.getTupleDomain(); RowNumberNode source = captures.get(CHILD); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushdownFilterIntoWindow.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushdownFilterIntoWindow.java index 7180c332757d..11547de5efcf 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushdownFilterIntoWindow.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushdownFilterIntoWindow.java @@ -90,7 +90,12 @@ public Result apply(FilterNode node, Captures captures, Context context) WindowNode windowNode = captures.get(childCapture); - DomainTranslator.ExtractionResult extractionResult = DomainTranslator.getExtractionResult(plannerContext, session, node.getPredicate(), types); + DomainTranslator.ExtractionResult extractionResult = DomainTranslator.getExtractionResult( + plannerContext, + session, + node.getPredicate(), + types, + context.getExpressionInterpreter()); TupleDomain tupleDomain = extractionResult.getTupleDomain(); Optional rankingType = toTopNRankingType(windowNode); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/RemoveRedundantPredicateAboveTableScan.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/RemoveRedundantPredicateAboveTableScan.java index f43a8dbd6a65..5051ae4b4058 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/RemoveRedundantPredicateAboveTableScan.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/RemoveRedundantPredicateAboveTableScan.java @@ -26,6 +26,7 @@ import io.trino.sql.PlannerContext; import io.trino.sql.planner.DomainTranslator; import io.trino.sql.planner.DomainTranslator.ExtractionResult; +import io.trino.sql.planner.ExpressionInterpreter; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.TypeAnalyzer; import io.trino.sql.planner.TypeProvider; @@ -96,7 +97,8 @@ public Result apply(FilterNode filterNode, Captures captures, Context context) ExtractionResult decomposedPredicate = getFullyExtractedPredicates( session, deterministicPredicate, - context.getSymbolAllocator().getTypes()); + context.getSymbolAllocator().getTypes(), + context.getExpressionInterpreter()); if (decomposedPredicate.getTupleDomain().isAll()) { // no conjunct could be fully converted to tuple domain @@ -141,6 +143,7 @@ public Result apply(FilterNode filterNode, Captures captures, Context context) session, context.getSymbolAllocator(), typeAnalyzer, + context.getExpressionInterpreter(), TRUE_LITERAL, // Dynamic filters are included in decomposedPredicate.getRemainingExpression() new DomainTranslator(plannerContext).toPredicate(session, unenforcedDomain.transformKeys(assignments::get)), nonDeterministicPredicate, @@ -153,10 +156,14 @@ public Result apply(FilterNode filterNode, Captures captures, Context context) return Result.ofPlanNode(node); } - private ExtractionResult getFullyExtractedPredicates(Session session, Expression predicate, TypeProvider types) + private ExtractionResult getFullyExtractedPredicates( + Session session, + Expression predicate, + TypeProvider types, + ExpressionInterpreter expressionInterpreter) { Map> extractedPredicates = extractConjuncts(predicate).stream() - .map(conjunct -> DomainTranslator.getExtractionResult(plannerContext, session, conjunct, types)) + .map(conjunct -> DomainTranslator.getExtractionResult(plannerContext, session, conjunct, types, expressionInterpreter)) .collect(groupingBy(result -> result.getRemainingExpression().equals(TRUE_LITERAL), toList())); return new ExtractionResult( intersect(extractedPredicates.getOrDefault(TRUE, ImmutableList.of()).stream() diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/RemoveUnsupportedDynamicFilters.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/RemoveUnsupportedDynamicFilters.java index 8036942bda29..17794d783af6 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/RemoveUnsupportedDynamicFilters.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/RemoveUnsupportedDynamicFilters.java @@ -16,7 +16,6 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; import io.trino.Session; -import io.trino.execution.warnings.WarningCollector; import io.trino.metadata.AnalyzePropertyManager; import io.trino.metadata.OperatorNotFoundException; import io.trino.metadata.SessionPropertyManager; @@ -29,9 +28,7 @@ import io.trino.sql.PlannerContext; import io.trino.sql.analyzer.StatementAnalyzerFactory; import io.trino.sql.parser.SqlParser; -import io.trino.sql.planner.PlanNodeIdAllocator; import io.trino.sql.planner.Symbol; -import io.trino.sql.planner.SymbolAllocator; import io.trino.sql.planner.TypeAnalyzer; import io.trino.sql.planner.TypeProvider; import io.trino.sql.planner.optimizations.PlanOptimizer; @@ -104,9 +101,11 @@ public RemoveUnsupportedDynamicFilters(PlannerContext plannerContext) } @Override - public PlanNode optimize(PlanNode plan, Session session, TypeProvider types, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator, WarningCollector warningCollector) + public PlanNode optimize(PlanNode plan, Context context) { - PlanWithConsumedDynamicFilters result = plan.accept(new RemoveUnsupportedDynamicFilters.Rewriter(session, types), ImmutableSet.of()); + PlanWithConsumedDynamicFilters result = plan.accept( + new RemoveUnsupportedDynamicFilters.Rewriter(context.getSession(), context.getSymbolAllocator().getTypes()), + ImmutableSet.of()); return result.getNode(); } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/SimplifyExpressions.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/SimplifyExpressions.java index 33a16fc214fe..bb333222a411 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/SimplifyExpressions.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/SimplifyExpressions.java @@ -38,10 +38,17 @@ public class SimplifyExpressions extends ExpressionRewriteRuleSet { - public static Expression rewrite(Expression expression, Session session, SymbolAllocator symbolAllocator, PlannerContext plannerContext, TypeAnalyzer typeAnalyzer) + public static Expression rewrite( + Expression expression, + Session session, + SymbolAllocator symbolAllocator, + PlannerContext plannerContext, + TypeAnalyzer typeAnalyzer, + ExpressionInterpreter expressionInterpreter) { requireNonNull(plannerContext, "plannerContext is null"); requireNonNull(typeAnalyzer, "typeAnalyzer is null"); + requireNonNull(expressionInterpreter, "expressionInterpreter is null"); if (expression instanceof SymbolReference) { return expression; } @@ -50,8 +57,7 @@ public static Expression rewrite(Expression expression, Session session, SymbolA expression = extractCommonPredicates(plannerContext.getMetadata(), expression); expression = normalizeOrExpression(expression); expressionTypes = typeAnalyzer.getTypes(session, symbolAllocator.getTypes(), expression); - ExpressionInterpreter interpreter = new ExpressionInterpreter(expression, plannerContext, session, expressionTypes); - Object optimized = interpreter.optimize(NoOpSymbolResolver.INSTANCE); + Object optimized = expressionInterpreter.optimize(expression, expressionTypes, NoOpSymbolResolver.INSTANCE); return new LiteralEncoder(plannerContext).toExpression(session, optimized, expressionTypes.get(NodeRef.of(expression))); } @@ -76,6 +82,12 @@ private static ExpressionRewriter createRewrite(PlannerContext plannerContext, T requireNonNull(plannerContext, "plannerContext is null"); requireNonNull(typeAnalyzer, "typeAnalyzer is null"); - return (expression, context) -> rewrite(expression, context.getSession(), context.getSymbolAllocator(), plannerContext, typeAnalyzer); + return (expression, context) -> rewrite( + expression, + context.getSession(), + context.getSymbolAllocator(), + plannerContext, + typeAnalyzer, + context.getExpressionInterpreter()); } } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/UnwrapCastInComparison.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/UnwrapCastInComparison.java index e56546fb258a..75649f9f3b3d 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/UnwrapCastInComparison.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/UnwrapCastInComparison.java @@ -131,16 +131,18 @@ private static ExpressionRewriter createRewrite(PlannerContext plannerContext, T requireNonNull(plannerContext, "plannerContext is null"); requireNonNull(typeAnalyzer, "typeAnalyzer is null"); - return (expression, context) -> unwrapCasts(context.getSession(), plannerContext, typeAnalyzer, context.getSymbolAllocator().getTypes(), expression); + return (expression, context) -> unwrapCasts(context.getSession(), plannerContext, typeAnalyzer, context.getSymbolAllocator().getTypes(), expression, context.getExpressionInterpreter()); } - public static Expression unwrapCasts(Session session, + public static Expression unwrapCasts( + Session session, PlannerContext plannerContext, TypeAnalyzer typeAnalyzer, TypeProvider types, - Expression expression) + Expression expression, + ExpressionInterpreter expressionInterpreter) { - return ExpressionTreeRewriter.rewriteWith(new Visitor(plannerContext, typeAnalyzer, session, types), expression); + return ExpressionTreeRewriter.rewriteWith(new Visitor(plannerContext, typeAnalyzer, session, types, expressionInterpreter), expression); } private static class Visitor @@ -150,15 +152,17 @@ private static class Visitor private final TypeAnalyzer typeAnalyzer; private final Session session; private final TypeProvider types; + private final ExpressionInterpreter expressionInterpreter; private final InterpretedFunctionInvoker functionInvoker; private final LiteralEncoder literalEncoder; - public Visitor(PlannerContext plannerContext, TypeAnalyzer typeAnalyzer, Session session, TypeProvider types) + public Visitor(PlannerContext plannerContext, TypeAnalyzer typeAnalyzer, Session session, TypeProvider types, ExpressionInterpreter expressionInterpreter) { this.plannerContext = requireNonNull(plannerContext, "plannerContext is null"); this.typeAnalyzer = requireNonNull(typeAnalyzer, "typeAnalyzer is null"); this.session = requireNonNull(session, "session is null"); this.types = requireNonNull(types, "types is null"); + this.expressionInterpreter = requireNonNull(expressionInterpreter, "expressionInterpreter is null"); this.functionInvoker = new InterpretedFunctionInvoker(plannerContext.getFunctionManager()); this.literalEncoder = new LiteralEncoder(plannerContext); } @@ -177,8 +181,10 @@ private Expression unwrapCast(ComparisonExpression expression) return expression; } - Object right = new ExpressionInterpreter(expression.getRight(), plannerContext, session, typeAnalyzer.getTypes(session, types, expression.getRight())) - .optimize(NoOpSymbolResolver.INSTANCE); + Object right = expressionInterpreter.optimize( + expression.getRight(), + typeAnalyzer.getTypes(session, types, expression.getRight()), + NoOpSymbolResolver.INSTANCE); Cast cast = (Cast) expression.getLeft(); ComparisonExpression.Operator operator = expression.getOperator(); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/UnwrapTimestampToDateCastInComparison.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/UnwrapTimestampToDateCastInComparison.java index f746e3d102b4..af9063db8507 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/UnwrapTimestampToDateCastInComparison.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/UnwrapTimestampToDateCastInComparison.java @@ -111,8 +111,8 @@ private Expression unwrapCast(ComparisonExpression expression) return expression; } - Object right = new ExpressionInterpreter(expression.getRight(), plannerContext, session, typeAnalyzer.getTypes(session, types, expression.getRight())) - .optimize(NoOpSymbolResolver.INSTANCE); + Object right = new ExpressionInterpreter(plannerContext, session) + .optimize(expression.getRight(), typeAnalyzer.getTypes(session, types, expression.getRight()), NoOpSymbolResolver.INSTANCE); Cast cast = (Cast) expression.getLeft(); ComparisonExpression.Operator operator = expression.getOperator(); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/AddExchanges.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/AddExchanges.java index 9a6f3f454019..18447e314694 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/AddExchanges.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/AddExchanges.java @@ -24,11 +24,11 @@ import io.trino.cost.CachingStatsProvider; import io.trino.cost.StatsCalculator; import io.trino.cost.StatsProvider; -import io.trino.execution.warnings.WarningCollector; import io.trino.spi.connector.GroupingProperty; import io.trino.spi.connector.LocalProperty; import io.trino.sql.PlannerContext; import io.trino.sql.planner.DomainTranslator; +import io.trino.sql.planner.ExpressionInterpreter; import io.trino.sql.planner.Partitioning; import io.trino.sql.planner.PartitioningScheme; import io.trino.sql.planner.PlanNodeIdAllocator; @@ -133,9 +133,15 @@ public AddExchanges(PlannerContext plannerContext, TypeAnalyzer typeAnalyzer, St } @Override - public PlanNode optimize(PlanNode plan, Session session, TypeProvider types, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator, WarningCollector warningCollector) + public PlanNode optimize(PlanNode plan, Context context) { - PlanWithProperties result = plan.accept(new Rewriter(idAllocator, symbolAllocator, session), PreferredProperties.any()); + PlanWithProperties result = plan.accept( + new Rewriter( + context.getIdAllocator(), + context.getSymbolAllocator(), + context.getSession(), + context.getExpressionInterpreter()), + PreferredProperties.any()); return result.getNode(); } @@ -147,19 +153,25 @@ private class Rewriter private final TypeProvider types; private final StatsProvider statsProvider; private final Session session; + private final ExpressionInterpreter expressionInterpreter; private final DomainTranslator domainTranslator; private final boolean distributedIndexJoins; private final boolean preferStreamingOperators; private final boolean redistributeWrites; private final boolean scaleWriters; - public Rewriter(PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Session session) + public Rewriter( + PlanNodeIdAllocator idAllocator, + SymbolAllocator symbolAllocator, + Session session, + ExpressionInterpreter expressionInterpreter) { this.idAllocator = idAllocator; this.symbolAllocator = symbolAllocator; this.types = symbolAllocator.getTypes(); this.statsProvider = new CachingStatsProvider(statsCalculator, session, types); this.session = session; + this.expressionInterpreter = expressionInterpreter; this.domainTranslator = new DomainTranslator(plannerContext); this.distributedIndexJoins = SystemSessionProperties.isDistributedIndexJoinEnabled(session); this.redistributeWrites = SystemSessionProperties.isRedistributeWrites(session); @@ -568,7 +580,8 @@ public PlanWithProperties visitFilter(FilterNode node, PreferredProperties prefe plannerContext, typeAnalyzer, statsProvider, - domainTranslator); + domainTranslator, + expressionInterpreter); if (plan.isPresent()) { return new PlanWithProperties(plan.get(), derivePropertiesRecursively(plan.get())); } @@ -1308,7 +1321,14 @@ private ActualProperties deriveProperties(PlanNode result, ActualProperties inpu private ActualProperties deriveProperties(PlanNode result, List inputProperties) { // TODO: move this logic to PlanSanityChecker once PropertyDerivations.deriveProperties fully supports local exchanges - ActualProperties outputProperties = PropertyDerivations.deriveProperties(result, inputProperties, plannerContext, session, types, typeAnalyzer); + ActualProperties outputProperties = PropertyDerivations.deriveProperties( + result, + inputProperties, + plannerContext, + session, + types, + typeAnalyzer, + expressionInterpreter); verify(result instanceof SemiJoinNode || inputProperties.stream().noneMatch(ActualProperties::isNullsAndAnyReplicated) || outputProperties.isNullsAndAnyReplicated(), "SemiJoinNode is the only node that can strip null replication"); return outputProperties; @@ -1316,7 +1336,13 @@ private ActualProperties deriveProperties(PlanNode result, List inputProperties) { - return new PlanWithProperties(result, StreamPropertyDerivations.deriveProperties(result, inputProperties, plannerContext, session, types, typeAnalyzer)); + return new PlanWithProperties(result, StreamPropertyDerivations.deriveProperties( + result, + inputProperties, + plannerContext, + session, + types, + typeAnalyzer, + expressionInterpreter)); } } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/BeginTableWrite.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/BeginTableWrite.java index 82325676a935..4d80aa4c292a 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/BeginTableWrite.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/BeginTableWrite.java @@ -15,15 +15,11 @@ import io.trino.Session; import io.trino.cost.StatsAndCosts; -import io.trino.execution.warnings.WarningCollector; import io.trino.metadata.FunctionManager; import io.trino.metadata.Metadata; import io.trino.metadata.TableExecuteHandle; import io.trino.metadata.TableHandle; import io.trino.spi.connector.BeginTableExecuteResult; -import io.trino.sql.planner.PlanNodeIdAllocator; -import io.trino.sql.planner.SymbolAllocator; -import io.trino.sql.planner.TypeProvider; import io.trino.sql.planner.plan.AssignUniqueId; import io.trino.sql.planner.plan.DeleteNode; import io.trino.sql.planner.plan.ExchangeNode; @@ -83,15 +79,23 @@ public BeginTableWrite(Metadata metadata, FunctionManager functionManager) } @Override - public PlanNode optimize(PlanNode plan, Session session, TypeProvider types, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator, WarningCollector warningCollector) + public PlanNode optimize(PlanNode plan, Context context) { try { - return SimplePlanRewriter.rewriteWith(new Rewriter(session), plan, Optional.empty()); + return SimplePlanRewriter.rewriteWith(new Rewriter(context.getSession()), plan, Optional.empty()); } catch (RuntimeException e) { try { int nestLevel = 4; // so that it renders reasonably within exception stacktrace - String explain = textLogicalPlan(plan, types, metadata, functionManager, StatsAndCosts.empty(), session, nestLevel, false); + String explain = textLogicalPlan( + plan, + context.getSymbolAllocator().getTypes(), + metadata, + functionManager, + StatsAndCosts.empty(), + context.getSession(), + nestLevel, + false); e.addSuppressed(new Exception("Current plan:\n" + explain)); } catch (RuntimeException ignore) { diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/CheckSubqueryNodesAreRewritten.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/CheckSubqueryNodesAreRewritten.java index 59c4c1697e49..8e832ccc0052 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/CheckSubqueryNodesAreRewritten.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/CheckSubqueryNodesAreRewritten.java @@ -14,13 +14,8 @@ package io.trino.sql.planner.optimizations; -import io.trino.Session; -import io.trino.execution.warnings.WarningCollector; import io.trino.spi.TrinoException; -import io.trino.sql.planner.PlanNodeIdAllocator; import io.trino.sql.planner.Symbol; -import io.trino.sql.planner.SymbolAllocator; -import io.trino.sql.planner.TypeProvider; import io.trino.sql.planner.plan.ApplyNode; import io.trino.sql.planner.plan.CorrelatedJoinNode; import io.trino.sql.planner.plan.PlanNode; @@ -37,7 +32,7 @@ public class CheckSubqueryNodesAreRewritten implements PlanOptimizer { @Override - public PlanNode optimize(PlanNode plan, Session session, TypeProvider types, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator, WarningCollector warningCollector) + public PlanNode optimize(PlanNode plan, Context context) { searchFrom(plan).where(ApplyNode.class::isInstance) .findFirst() diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/HashGenerationOptimizer.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/HashGenerationOptimizer.java index 24b0b323bf22..afc1f9cc5dc6 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/HashGenerationOptimizer.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/HashGenerationOptimizer.java @@ -25,7 +25,6 @@ import com.google.common.collect.Multimap; import io.trino.Session; import io.trino.SystemSessionProperties; -import io.trino.execution.warnings.WarningCollector; import io.trino.metadata.Metadata; import io.trino.spi.function.OperatorType; import io.trino.spi.type.StandardTypes; @@ -106,15 +105,17 @@ public HashGenerationOptimizer(Metadata metadata) } @Override - public PlanNode optimize(PlanNode plan, Session session, TypeProvider types, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator, WarningCollector warningCollector) + public PlanNode optimize(PlanNode plan, Context context) { - requireNonNull(plan, "plan is null"); - requireNonNull(session, "session is null"); - requireNonNull(types, "types is null"); - requireNonNull(symbolAllocator, "symbolAllocator is null"); - requireNonNull(idAllocator, "idAllocator is null"); - if (SystemSessionProperties.isOptimizeHashGenerationEnabled(session)) { - PlanWithProperties result = plan.accept(new Rewriter(session, metadata, idAllocator, symbolAllocator, types), new HashComputationSet()); + if (SystemSessionProperties.isOptimizeHashGenerationEnabled(context.getSession())) { + PlanWithProperties result = plan.accept( + new Rewriter( + context.getSession(), + metadata, + context.getIdAllocator(), + context.getSymbolAllocator(), + context.getSymbolAllocator().getTypes()), + new HashComputationSet()); return result.getNode(); } return plan; diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/IndexJoinOptimizer.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/IndexJoinOptimizer.java index ec72bf630aa6..a1d76d89ad55 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/IndexJoinOptimizer.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/IndexJoinOptimizer.java @@ -21,7 +21,6 @@ import com.google.common.collect.Lists; import com.google.common.collect.Maps; import io.trino.Session; -import io.trino.execution.warnings.WarningCollector; import io.trino.metadata.BoundSignature; import io.trino.metadata.ResolvedFunction; import io.trino.metadata.ResolvedIndex; @@ -29,10 +28,10 @@ import io.trino.spi.predicate.TupleDomain; import io.trino.sql.PlannerContext; import io.trino.sql.planner.DomainTranslator; +import io.trino.sql.planner.ExpressionInterpreter; import io.trino.sql.planner.PlanNodeIdAllocator; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.SymbolAllocator; -import io.trino.sql.planner.TypeProvider; import io.trino.sql.planner.plan.AggregationNode; import io.trino.sql.planner.plan.Assignments; import io.trino.sql.planner.plan.FilterNode; @@ -80,14 +79,17 @@ public IndexJoinOptimizer(PlannerContext plannerContext) } @Override - public PlanNode optimize(PlanNode plan, Session session, TypeProvider type, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator, WarningCollector warningCollector) + public PlanNode optimize(PlanNode plan, Context context) { - requireNonNull(plan, "plan is null"); - requireNonNull(session, "session is null"); - requireNonNull(symbolAllocator, "symbolAllocator is null"); - requireNonNull(idAllocator, "idAllocator is null"); - - return SimplePlanRewriter.rewriteWith(new Rewriter(symbolAllocator, idAllocator, plannerContext, session), plan, null); + return SimplePlanRewriter.rewriteWith( + new Rewriter( + context.getSymbolAllocator(), + context.getIdAllocator(), + plannerContext, + context.getSession(), + context.getExpressionInterpreter()), + plan, + null); } private static class Rewriter @@ -97,17 +99,20 @@ private static class Rewriter private final PlanNodeIdAllocator idAllocator; private final PlannerContext plannerContext; private final Session session; + private final ExpressionInterpreter expressionInterpreter; private Rewriter( SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator, PlannerContext plannerContext, - Session session) + Session session, + ExpressionInterpreter expressionInterpreter) { this.symbolAllocator = requireNonNull(symbolAllocator, "symbolAllocator is null"); this.idAllocator = requireNonNull(idAllocator, "idAllocator is null"); this.plannerContext = requireNonNull(plannerContext, "plannerContext is null"); this.session = requireNonNull(session, "session is null"); + this.expressionInterpreter = requireNonNull(expressionInterpreter, "expressionInterpreter is null"); } @Override @@ -126,7 +131,8 @@ public PlanNode visitJoin(JoinNode node, RewriteContext context) symbolAllocator, idAllocator, plannerContext, - session); + session, + expressionInterpreter); if (leftIndexCandidate.isPresent()) { // Sanity check that we can trace the path for the index lookup key Map trace = IndexKeyTracer.trace(leftIndexCandidate.get(), ImmutableSet.copyOf(leftJoinSymbols)); @@ -139,7 +145,8 @@ public PlanNode visitJoin(JoinNode node, RewriteContext context) symbolAllocator, idAllocator, plannerContext, - session); + session, + expressionInterpreter); if (rightIndexCandidate.isPresent()) { // Sanity check that we can trace the path for the index lookup key Map trace = IndexKeyTracer.trace(rightIndexCandidate.get(), ImmutableSet.copyOf(rightJoinSymbols)); @@ -250,18 +257,21 @@ private static class IndexSourceRewriter private final PlannerContext plannerContext; private final DomainTranslator domainTranslator; private final Session session; + private final ExpressionInterpreter expressionInterpreter; private IndexSourceRewriter( SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator, PlannerContext plannerContext, - Session session) + Session session, + ExpressionInterpreter expressionInterpreter) { this.plannerContext = requireNonNull(plannerContext, "plannerContext is null"); this.domainTranslator = new DomainTranslator(plannerContext); this.symbolAllocator = requireNonNull(symbolAllocator, "symbolAllocator is null"); this.idAllocator = requireNonNull(idAllocator, "idAllocator is null"); this.session = requireNonNull(session, "session is null"); + this.expressionInterpreter = requireNonNull(expressionInterpreter, "expressionInterpreter is null"); } public static Optional rewriteWithIndex( @@ -270,10 +280,16 @@ public static Optional rewriteWithIndex( SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator, PlannerContext plannerContext, - Session session) + Session session, + ExpressionInterpreter interpreter) { AtomicBoolean success = new AtomicBoolean(); - IndexSourceRewriter indexSourceRewriter = new IndexSourceRewriter(symbolAllocator, idAllocator, plannerContext, session); + IndexSourceRewriter indexSourceRewriter = new IndexSourceRewriter( + symbolAllocator, + idAllocator, + plannerContext, + session, + interpreter); PlanNode rewritten = SimplePlanRewriter.rewriteWith(indexSourceRewriter, planNode, new Context(lookupSymbols, success)); if (success.get()) { return Optional.of(rewritten); @@ -300,7 +316,8 @@ private PlanNode planTableScan(TableScanNode node, Expression predicate, Context plannerContext, session, predicate, - symbolAllocator.getTypes()); + symbolAllocator.getTypes(), + expressionInterpreter); TupleDomain simplifiedConstraint = decomposedPredicate.getTupleDomain() .transformKeys(node.getAssignments()::get) diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/LimitPushDown.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/LimitPushDown.java index 10da7a2f3ef9..c93d96089c3b 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/LimitPushDown.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/LimitPushDown.java @@ -14,11 +14,7 @@ package io.trino.sql.planner.optimizations; import com.google.common.collect.ImmutableList; -import io.trino.Session; -import io.trino.execution.warnings.WarningCollector; import io.trino.sql.planner.PlanNodeIdAllocator; -import io.trino.sql.planner.SymbolAllocator; -import io.trino.sql.planner.TypeProvider; import io.trino.sql.planner.plan.AggregationNode; import io.trino.sql.planner.plan.DistinctLimitNode; import io.trino.sql.planner.plan.LimitNode; @@ -43,15 +39,10 @@ public class LimitPushDown implements PlanOptimizer { @Override - public PlanNode optimize(PlanNode plan, Session session, TypeProvider types, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator, WarningCollector warningCollector) + public PlanNode optimize(PlanNode plan, Context context) { requireNonNull(plan, "plan is null"); - requireNonNull(session, "session is null"); - requireNonNull(types, "types is null"); - requireNonNull(symbolAllocator, "symbolAllocator is null"); - requireNonNull(idAllocator, "idAllocator is null"); - - return SimplePlanRewriter.rewriteWith(new Rewriter(idAllocator), plan, null); + return SimplePlanRewriter.rewriteWith(new Rewriter(context.getIdAllocator()), plan, null); } private static class LimitContext diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/MetadataQueryOptimizer.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/MetadataQueryOptimizer.java index ce64684bad76..12436700720a 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/MetadataQueryOptimizer.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/MetadataQueryOptimizer.java @@ -19,7 +19,6 @@ import com.google.common.collect.Iterables; import io.trino.Session; import io.trino.SystemSessionProperties; -import io.trino.execution.warnings.WarningCollector; import io.trino.metadata.TableProperties; import io.trino.spi.connector.ColumnHandle; import io.trino.spi.connector.ColumnMetadata; @@ -31,8 +30,6 @@ import io.trino.sql.planner.LiteralEncoder; import io.trino.sql.planner.PlanNodeIdAllocator; import io.trino.sql.planner.Symbol; -import io.trino.sql.planner.SymbolAllocator; -import io.trino.sql.planner.TypeProvider; import io.trino.sql.planner.plan.AggregationNode; import io.trino.sql.planner.plan.AggregationNode.Aggregation; import io.trino.sql.planner.plan.FilterNode; @@ -73,12 +70,15 @@ public MetadataQueryOptimizer(PlannerContext plannerContext) } @Override - public PlanNode optimize(PlanNode plan, Session session, TypeProvider types, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator, WarningCollector warningCollector) + public PlanNode optimize(PlanNode plan, Context context) { - if (!SystemSessionProperties.isOptimizeMetadataQueries(session)) { + if (!SystemSessionProperties.isOptimizeMetadataQueries(context.getSession())) { return plan; } - return SimplePlanRewriter.rewriteWith(new Optimizer(session, plannerContext, idAllocator), plan, null); + return SimplePlanRewriter.rewriteWith( + new Optimizer(context.getSession(), plannerContext, context.getIdAllocator()), + plan, + null); } private static class Optimizer diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/OptimizeMixedDistinctAggregations.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/OptimizeMixedDistinctAggregations.java index 35a2a9ce67e4..752bd3c7254f 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/OptimizeMixedDistinctAggregations.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/OptimizeMixedDistinctAggregations.java @@ -18,13 +18,11 @@ import com.google.common.collect.ImmutableSet; import com.google.common.collect.Iterables; import io.trino.Session; -import io.trino.execution.warnings.WarningCollector; import io.trino.metadata.Metadata; import io.trino.spi.type.Type; import io.trino.sql.planner.PlanNodeIdAllocator; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.SymbolAllocator; -import io.trino.sql.planner.TypeProvider; import io.trino.sql.planner.plan.AggregationNode; import io.trino.sql.planner.plan.AggregationNode.Aggregation; import io.trino.sql.planner.plan.Assignments; @@ -83,10 +81,13 @@ public OptimizeMixedDistinctAggregations(Metadata metadata) } @Override - public PlanNode optimize(PlanNode plan, Session session, TypeProvider types, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator, WarningCollector warningCollector) + public PlanNode optimize(PlanNode plan, Context context) { - if (isOptimizeDistinctAggregationEnabled(session)) { - return SimplePlanRewriter.rewriteWith(new Optimizer(session, idAllocator, symbolAllocator, metadata), plan, Optional.empty()); + if (isOptimizeDistinctAggregationEnabled(context.getSession())) { + return SimplePlanRewriter.rewriteWith( + new Optimizer(context.getSession(), context.getIdAllocator(), context.getSymbolAllocator(), metadata), + plan, + Optional.empty()); } return plan; diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PlanOptimizer.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PlanOptimizer.java index 1e018beb8866..07d9f16250ac 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PlanOptimizer.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PlanOptimizer.java @@ -15,18 +15,25 @@ import io.trino.Session; import io.trino.execution.warnings.WarningCollector; +import io.trino.sql.planner.ExpressionInterpreter; import io.trino.sql.planner.PlanNodeIdAllocator; import io.trino.sql.planner.SymbolAllocator; -import io.trino.sql.planner.TypeProvider; import io.trino.sql.planner.plan.PlanNode; public interface PlanOptimizer { - PlanNode optimize( - PlanNode plan, - Session session, - TypeProvider types, - SymbolAllocator symbolAllocator, - PlanNodeIdAllocator idAllocator, - WarningCollector warningCollector); + PlanNode optimize(PlanNode plan, Context context); + + interface Context + { + Session getSession(); + + SymbolAllocator getSymbolAllocator(); + + PlanNodeIdAllocator getIdAllocator(); + + WarningCollector getWarningCollector(); + + ExpressionInterpreter getExpressionInterpreter(); + } } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PredicatePushDown.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PredicatePushDown.java index 130577a49d44..0ce82e4a4cbd 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PredicatePushDown.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PredicatePushDown.java @@ -21,7 +21,6 @@ import com.google.common.collect.Sets; import com.google.common.collect.Streams; import io.trino.Session; -import io.trino.execution.warnings.WarningCollector; import io.trino.metadata.Metadata; import io.trino.spi.type.Type; import io.trino.sql.PlannerContext; @@ -143,15 +142,21 @@ public PredicatePushDown( } @Override - public PlanNode optimize(PlanNode plan, Session session, TypeProvider types, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator, WarningCollector warningCollector) + public PlanNode optimize(PlanNode plan, Context context) { requireNonNull(plan, "plan is null"); - requireNonNull(session, "session is null"); - requireNonNull(types, "types is null"); - requireNonNull(idAllocator, "idAllocator is null"); return SimplePlanRewriter.rewriteWith( - new Rewriter(symbolAllocator, idAllocator, plannerContext, typeAnalyzer, session, types, useTableProperties, dynamicFiltering), + new Rewriter( + context.getSymbolAllocator(), + context.getIdAllocator(), + plannerContext, + typeAnalyzer, + context.getSession(), + context.getSymbolAllocator().getTypes(), + context.getExpressionInterpreter(), + useTableProperties, + dynamicFiltering), plan, TRUE_LITERAL); } @@ -166,6 +171,7 @@ private static class Rewriter private final TypeAnalyzer typeAnalyzer; private final Session session; private final TypeProvider types; + private final ExpressionInterpreter expressionInterpreter; private final ExpressionEquivalence expressionEquivalence; private final boolean dynamicFiltering; private final LiteralEncoder literalEncoder; @@ -178,6 +184,7 @@ private Rewriter( TypeAnalyzer typeAnalyzer, Session session, TypeProvider types, + ExpressionInterpreter expressionInterpreter, boolean useTableProperties, boolean dynamicFiltering) { @@ -188,6 +195,7 @@ private Rewriter( this.typeAnalyzer = requireNonNull(typeAnalyzer, "typeAnalyzer is null"); this.session = requireNonNull(session, "session is null"); this.types = requireNonNull(types, "types is null"); + this.expressionInterpreter = requireNonNull(expressionInterpreter, "expressionInterpreter is null"); this.expressionEquivalence = new ExpressionEquivalence(plannerContext.getMetadata(), plannerContext.getFunctionManager(), typeAnalyzer); this.dynamicFiltering = dynamicFiltering; @@ -296,7 +304,7 @@ public PlanNode visitProject(ProjectNode node, RewriteContext contex List inlinedDeterministicConjuncts = inlineConjuncts.get(true).stream() .map(entry -> inlineSymbols(node.getAssignments().getMap(), entry)) .map(conjunct -> canonicalizeExpression(conjunct, typeAnalyzer.getTypes(session, types, conjunct), plannerContext, session)) // normalize expressions to a form that unwrapCasts understands - .map(conjunct -> unwrapCasts(session, plannerContext, typeAnalyzer, types, conjunct)) + .map(conjunct -> unwrapCasts(session, plannerContext, typeAnalyzer, types, conjunct, expressionInterpreter)) .collect(Collectors.toList()); PlanNode rewrittenNode = context.defaultRewrite(node, combineConjuncts(metadata, inlinedDeterministicConjuncts)); @@ -426,8 +434,18 @@ public PlanNode visitJoin(JoinNode node, RewriteContext context) // See if we can rewrite outer joins in terms of a plain inner join node = tryNormalizeToOuterToInnerJoin(node, inheritedPredicate); - Expression leftEffectivePredicate = effectivePredicateExtractor.extract(session, node.getLeft(), types, typeAnalyzer); - Expression rightEffectivePredicate = effectivePredicateExtractor.extract(session, node.getRight(), types, typeAnalyzer); + Expression leftEffectivePredicate = effectivePredicateExtractor.extract( + session, + node.getLeft(), + types, + typeAnalyzer, + expressionInterpreter); + Expression rightEffectivePredicate = effectivePredicateExtractor.extract( + session, + node.getRight(), + types, + typeAnalyzer, + expressionInterpreter); Expression joinPredicate = extractJoinPredicate(node); Expression leftPredicate; @@ -748,8 +766,18 @@ public PlanNode visitSpatialJoin(SpatialJoinNode node, RewriteContext innerSymbolsForOuterJoin, Ex private Expression simplifyExpression(Expression expression) { Map, Type> expressionTypes = typeAnalyzer.getTypes(session, symbolAllocator.getTypes(), expression); - ExpressionInterpreter optimizer = new ExpressionInterpreter(expression, plannerContext, session, expressionTypes); - return literalEncoder.toExpression(session, optimizer.optimize(NoOpSymbolResolver.INSTANCE), expressionTypes.get(NodeRef.of(expression))); + return literalEncoder.toExpression( + session, + expressionInterpreter.optimize(expression, expressionTypes, NoOpSymbolResolver.INSTANCE), + expressionTypes.get(NodeRef.of(expression))); } private boolean areExpressionsEquivalent(Expression leftExpression, Expression rightExpression) @@ -1236,8 +1266,8 @@ private boolean areExpressionsEquivalent(Expression leftExpression, Expression r private Object nullInputEvaluator(Collection nullSymbols, Expression expression) { Map, Type> expressionTypes = typeAnalyzer.getTypes(session, symbolAllocator.getTypes(), expression); - return new ExpressionInterpreter(expression, plannerContext, session, expressionTypes) - .optimize(symbol -> nullSymbols.contains(symbol) ? null : symbol.toSymbolReference()); + return expressionInterpreter + .optimize(expression, expressionTypes, symbol -> nullSymbols.contains(symbol) ? null : symbol.toSymbolReference()); } private boolean joinEqualityExpression(Expression expression, Collection leftSymbols, Collection rightSymbols) @@ -1359,8 +1389,12 @@ private PlanNode visitFilteringSemiJoin(SemiJoinNode node, RewriteContext inputProperties = node.getSources().stream() - .map(source -> derivePropertiesRecursively(source, plannerContext, session, types, typeAnalyzer)) + .map(source -> derivePropertiesRecursively(source, plannerContext, session, types, typeAnalyzer, expressionInterpreter)) .collect(toImmutableList()); - return deriveProperties(node, inputProperties, plannerContext, session, types, typeAnalyzer); + return deriveProperties(node, inputProperties, plannerContext, session, types, typeAnalyzer, expressionInterpreter); } public static ActualProperties deriveProperties( @@ -136,9 +137,12 @@ public static ActualProperties deriveProperties( PlannerContext plannerContext, Session session, TypeProvider types, - TypeAnalyzer typeAnalyzer) + TypeAnalyzer typeAnalyzer, + ExpressionInterpreter expressionInterpreter) { - ActualProperties output = node.accept(new Visitor(plannerContext, session, types, typeAnalyzer), inputProperties); + ActualProperties output = node.accept( + new Visitor(plannerContext, session, types, typeAnalyzer, expressionInterpreter), + inputProperties); output.getNodePartitioning().ifPresent(partitioning -> verify(node.getOutputSymbols().containsAll(partitioning.getColumns()), "Node-level partitioning properties contain columns not present in node's output")); @@ -159,9 +163,10 @@ public static ActualProperties streamBackdoorDeriveProperties( PlannerContext plannerContext, Session session, TypeProvider types, - TypeAnalyzer typeAnalyzer) + TypeAnalyzer typeAnalyzer, + ExpressionInterpreter expressionInterpreter) { - return node.accept(new Visitor(plannerContext, session, types, typeAnalyzer), inputProperties); + return node.accept(new Visitor(plannerContext, session, types, typeAnalyzer, expressionInterpreter), inputProperties); } private static class Visitor @@ -171,13 +176,20 @@ private static class Visitor private final Session session; private final TypeProvider types; private final TypeAnalyzer typeAnalyzer; + private final ExpressionInterpreter expressionInterpreter; - public Visitor(PlannerContext plannerContext, Session session, TypeProvider types, TypeAnalyzer typeAnalyzer) + public Visitor( + PlannerContext plannerContext, + Session session, + TypeProvider types, + TypeAnalyzer typeAnalyzer, + ExpressionInterpreter expressionInterpreter) { this.plannerContext = plannerContext; this.session = session; this.types = types; this.typeAnalyzer = typeAnalyzer; + this.expressionInterpreter = expressionInterpreter; } @Override @@ -709,7 +721,8 @@ public ActualProperties visitFilter(FilterNode node, List inpu plannerContext, session, node.getPredicate(), - types); + types, + expressionInterpreter); Map constants = new HashMap<>(properties.getConstants()); constants.putAll(extractFixedValues(decomposedPredicate.getTupleDomain()).orElse(ImmutableMap.of())); @@ -735,13 +748,12 @@ public ActualProperties visitProject(ProjectNode node, List in Map, Type> expressionTypes = typeAnalyzer.getTypes(session, types, expression); Type type = requireNonNull(expressionTypes.get(NodeRef.of(expression))); - ExpressionInterpreter optimizer = new ExpressionInterpreter(expression, plannerContext, session, expressionTypes); // TODO: // We want to use a symbol resolver that looks up in the constants from the input subplan // to take advantage of constant-folding for complex expressions // However, that currently causes errors when those expressions operate on arrays or row types // ("ROW comparison not supported for fields with null elements", etc) - Object value = optimizer.optimize(NoOpSymbolResolver.INSTANCE); + Object value = expressionInterpreter.optimize(expression, expressionTypes, NoOpSymbolResolver.INSTANCE); if (value instanceof SymbolReference) { Symbol symbol = Symbol.from((SymbolReference) value); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/ReplicateSemiJoinInDelete.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/ReplicateSemiJoinInDelete.java index ae9fe64fecc4..58fcf6d1a5f8 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/ReplicateSemiJoinInDelete.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/ReplicateSemiJoinInDelete.java @@ -13,11 +13,6 @@ */ package io.trino.sql.planner.optimizations; -import io.trino.Session; -import io.trino.execution.warnings.WarningCollector; -import io.trino.sql.planner.PlanNodeIdAllocator; -import io.trino.sql.planner.SymbolAllocator; -import io.trino.sql.planner.TypeProvider; import io.trino.sql.planner.plan.DeleteNode; import io.trino.sql.planner.plan.PlanNode; import io.trino.sql.planner.plan.SemiJoinNode; @@ -30,7 +25,7 @@ public class ReplicateSemiJoinInDelete implements PlanOptimizer { @Override - public PlanNode optimize(PlanNode plan, Session session, TypeProvider types, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator, WarningCollector warningCollector) + public PlanNode optimize(PlanNode plan, Context context) { requireNonNull(plan, "plan is null"); return SimplePlanRewriter.rewriteWith(new Rewriter(), plan); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/StatsRecordingPlanOptimizer.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/StatsRecordingPlanOptimizer.java index 3cc822fa3b87..105345241031 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/StatsRecordingPlanOptimizer.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/StatsRecordingPlanOptimizer.java @@ -13,12 +13,7 @@ */ package io.trino.sql.planner.optimizations; -import io.trino.Session; -import io.trino.execution.warnings.WarningCollector; import io.trino.sql.planner.OptimizerStatsRecorder; -import io.trino.sql.planner.PlanNodeIdAllocator; -import io.trino.sql.planner.SymbolAllocator; -import io.trino.sql.planner.TypeProvider; import io.trino.sql.planner.plan.PlanNode; import static java.util.Objects.requireNonNull; @@ -37,19 +32,13 @@ public StatsRecordingPlanOptimizer(OptimizerStatsRecorder stats, PlanOptimizer d } @Override - public PlanNode optimize( - PlanNode plan, - Session session, - TypeProvider types, - SymbolAllocator symbolAllocator, - PlanNodeIdAllocator idAllocator, - WarningCollector warningCollector) + public PlanNode optimize(PlanNode plan, Context context) { PlanNode result; long duration; try { long start = System.nanoTime(); - result = delegate.optimize(plan, session, types, symbolAllocator, idAllocator, warningCollector); + result = delegate.optimize(plan, context); duration = System.nanoTime() - start; } catch (RuntimeException e) { diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/StreamPropertyDerivations.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/StreamPropertyDerivations.java index 01b89d74aaa5..c882ad13d069 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/StreamPropertyDerivations.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/StreamPropertyDerivations.java @@ -25,6 +25,7 @@ import io.trino.spi.connector.ColumnHandle; import io.trino.spi.connector.LocalProperty; import io.trino.sql.PlannerContext; +import io.trino.sql.planner.ExpressionInterpreter; import io.trino.sql.planner.Partitioning.ArgumentBinding; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.TypeAnalyzer; @@ -110,12 +111,13 @@ public static StreamProperties derivePropertiesRecursively( PlannerContext plannerContext, Session session, TypeProvider types, - TypeAnalyzer typeAnalyzer) + TypeAnalyzer typeAnalyzer, + ExpressionInterpreter expressionInterpreter) { List inputProperties = node.getSources().stream() - .map(source -> derivePropertiesRecursively(source, plannerContext, session, types, typeAnalyzer)) + .map(source -> derivePropertiesRecursively(source, plannerContext, session, types, typeAnalyzer, expressionInterpreter)) .collect(toImmutableList()); - return deriveProperties(node, inputProperties, plannerContext, session, types, typeAnalyzer); + return deriveProperties(node, inputProperties, plannerContext, session, types, typeAnalyzer, expressionInterpreter); } public static StreamProperties deriveProperties( @@ -124,9 +126,10 @@ public static StreamProperties deriveProperties( PlannerContext plannerContext, Session session, TypeProvider types, - TypeAnalyzer typeAnalyzer) + TypeAnalyzer typeAnalyzer, + ExpressionInterpreter expressionInterpreter) { - return deriveProperties(node, ImmutableList.of(inputProperties), plannerContext, session, types, typeAnalyzer); + return deriveProperties(node, ImmutableList.of(inputProperties), plannerContext, session, types, typeAnalyzer, expressionInterpreter); } public static StreamProperties deriveProperties( @@ -135,7 +138,8 @@ public static StreamProperties deriveProperties( PlannerContext plannerContext, Session session, TypeProvider types, - TypeAnalyzer typeAnalyzer) + TypeAnalyzer typeAnalyzer, + ExpressionInterpreter expressionInterpreter) { requireNonNull(node, "node is null"); requireNonNull(inputProperties, "inputProperties is null"); @@ -143,6 +147,7 @@ public static StreamProperties deriveProperties( requireNonNull(session, "session is null"); requireNonNull(types, "types is null"); requireNonNull(typeAnalyzer, "typeAnalyzer is null"); + requireNonNull(expressionInterpreter, "expressionInterpreter is null"); // properties.otherActualProperties will never be null here because the only way // an external caller should obtain StreamProperties is from this method, and the @@ -155,7 +160,8 @@ public static StreamProperties deriveProperties( plannerContext, session, types, - typeAnalyzer); + typeAnalyzer, + expressionInterpreter); StreamProperties result = node.accept(new Visitor(plannerContext.getMetadata(), session), inputProperties) .withOtherActualProperties(otherProperties); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/TransformQuantifiedComparisonApplyToCorrelatedJoin.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/TransformQuantifiedComparisonApplyToCorrelatedJoin.java index 69d38052ee7c..8791d1e4b3b2 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/TransformQuantifiedComparisonApplyToCorrelatedJoin.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/TransformQuantifiedComparisonApplyToCorrelatedJoin.java @@ -16,7 +16,6 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import io.trino.Session; -import io.trino.execution.warnings.WarningCollector; import io.trino.metadata.Metadata; import io.trino.spi.type.BigintType; import io.trino.spi.type.Type; @@ -81,9 +80,17 @@ public TransformQuantifiedComparisonApplyToCorrelatedJoin(Metadata metadata) } @Override - public PlanNode optimize(PlanNode plan, Session session, TypeProvider types, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator, WarningCollector warningCollector) + public PlanNode optimize(PlanNode plan, Context context) { - return rewriteWith(new Rewriter(idAllocator, types, symbolAllocator, metadata, session), plan, null); + return rewriteWith( + new Rewriter( + context.getIdAllocator(), + context.getSymbolAllocator().getTypes(), + context.getSymbolAllocator(), + metadata, + context.getSession()), + plan, + null); } private static class Rewriter diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/UnaliasSymbolReferences.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/UnaliasSymbolReferences.java index bc97d123799d..3e6fa4e0d3c9 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/UnaliasSymbolReferences.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/UnaliasSymbolReferences.java @@ -19,9 +19,7 @@ import com.google.common.collect.ImmutableSet; import com.google.common.collect.ListMultimap; import com.google.common.collect.Sets; -import io.trino.Session; import io.trino.cost.PlanNodeStatsEstimate; -import io.trino.execution.warnings.WarningCollector; import io.trino.metadata.Metadata; import io.trino.spi.connector.ColumnHandle; import io.trino.sql.DynamicFilters; @@ -29,10 +27,8 @@ import io.trino.sql.planner.NodeAndMappings; import io.trino.sql.planner.OrderingScheme; import io.trino.sql.planner.PartitioningScheme; -import io.trino.sql.planner.PlanNodeIdAllocator; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.SymbolAllocator; -import io.trino.sql.planner.TypeProvider; import io.trino.sql.planner.plan.AggregationNode; import io.trino.sql.planner.plan.ApplyNode; import io.trino.sql.planner.plan.AssignUniqueId; @@ -134,13 +130,9 @@ public UnaliasSymbolReferences(Metadata metadata) } @Override - public PlanNode optimize(PlanNode plan, Session session, TypeProvider types, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator, WarningCollector warningCollector) + public PlanNode optimize(PlanNode plan, Context context) { requireNonNull(plan, "plan is null"); - requireNonNull(session, "session is null"); - requireNonNull(types, "types is null"); - requireNonNull(symbolAllocator, "symbolAllocator is null"); - requireNonNull(idAllocator, "idAllocator is null"); Visitor visitor = new Visitor(metadata, SymbolMapper::symbolMapper); PlanAndMappings result = plan.accept(visitor, UnaliasContext.empty()); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/WindowFilterPushDown.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/WindowFilterPushDown.java index 43189b43b5a0..84956b66c39b 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/WindowFilterPushDown.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/WindowFilterPushDown.java @@ -15,7 +15,6 @@ import com.google.common.collect.ImmutableList; import io.trino.Session; -import io.trino.execution.warnings.WarningCollector; import io.trino.metadata.FunctionId; import io.trino.spi.predicate.Domain; import io.trino.spi.predicate.Range; @@ -24,9 +23,9 @@ import io.trino.sql.ExpressionUtils; import io.trino.sql.PlannerContext; import io.trino.sql.planner.DomainTranslator; +import io.trino.sql.planner.ExpressionInterpreter; import io.trino.sql.planner.PlanNodeIdAllocator; import io.trino.sql.planner.Symbol; -import io.trino.sql.planner.SymbolAllocator; import io.trino.sql.planner.TypeProvider; import io.trino.sql.planner.plan.FilterNode; import io.trino.sql.planner.plan.LimitNode; @@ -67,15 +66,18 @@ public WindowFilterPushDown(PlannerContext plannerContext) } @Override - public PlanNode optimize(PlanNode plan, Session session, TypeProvider types, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator, WarningCollector warningCollector) + public PlanNode optimize(PlanNode plan, Context context) { requireNonNull(plan, "plan is null"); - requireNonNull(session, "session is null"); - requireNonNull(types, "types is null"); - requireNonNull(symbolAllocator, "symbolAllocator is null"); - requireNonNull(idAllocator, "idAllocator is null"); - - return SimplePlanRewriter.rewriteWith(new Rewriter(idAllocator, plannerContext, session, types), plan, null); + return SimplePlanRewriter.rewriteWith( + new Rewriter( + context.getIdAllocator(), + plannerContext, + context.getSession(), + context.getSymbolAllocator().getTypes(), + context.getExpressionInterpreter()), + plan, + null); } private static class Rewriter @@ -85,6 +87,7 @@ private static class Rewriter private final PlannerContext plannerContext; private final Session session; private final TypeProvider types; + private final ExpressionInterpreter expressionInterpreter; private final FunctionId rowNumberFunctionId; private final FunctionId rankFunctionId; private final DomainTranslator domainTranslator; @@ -93,12 +96,14 @@ private Rewriter( PlanNodeIdAllocator idAllocator, PlannerContext plannerContext, Session session, - TypeProvider types) + TypeProvider types, + ExpressionInterpreter expressionInterpreter) { this.idAllocator = requireNonNull(idAllocator, "idAllocator is null"); this.plannerContext = requireNonNull(plannerContext, "plannerContext is null"); this.session = requireNonNull(session, "session is null"); this.types = requireNonNull(types, "types is null"); + this.expressionInterpreter = requireNonNull(expressionInterpreter, "expressionInterpreter is null"); rowNumberFunctionId = plannerContext.getMetadata().resolveFunction(session, QualifiedName.of("row_number"), ImmutableList.of()).getFunctionId(); rankFunctionId = plannerContext.getMetadata().resolveFunction(session, QualifiedName.of("rank"), ImmutableList.of()).getFunctionId(); domainTranslator = new DomainTranslator(plannerContext); @@ -166,7 +171,12 @@ public PlanNode visitFilter(FilterNode node, RewriteContext context) { PlanNode source = context.rewrite(node.getSource()); - TupleDomain tupleDomain = DomainTranslator.getExtractionResult(plannerContext, session, node.getPredicate(), types).getTupleDomain(); + TupleDomain tupleDomain = DomainTranslator.getExtractionResult( + plannerContext, + session, + node.getPredicate(), + types, + expressionInterpreter).getTupleDomain(); if (source instanceof RowNumberNode) { Symbol rowNumberSymbol = ((RowNumberNode) source).getRowNumberSymbol(); @@ -201,7 +211,12 @@ else if (source instanceof WindowNode && isOptimizeTopNRanking(session)) { private PlanNode rewriteFilterSource(FilterNode filterNode, PlanNode source, Symbol rankingSymbol, int upperBound) { - ExtractionResult extractionResult = DomainTranslator.getExtractionResult(plannerContext, session, filterNode.getPredicate(), types); + ExtractionResult extractionResult = DomainTranslator.getExtractionResult( + plannerContext, + session, + filterNode.getPredicate(), + types, + expressionInterpreter); TupleDomain tupleDomain = extractionResult.getTupleDomain(); if (!allRankingValuesInDomain(tupleDomain, rankingSymbol, upperBound)) { diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/IoPlanPrinter.java b/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/IoPlanPrinter.java index 26cff3e6c597..92af5be8c76e 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/IoPlanPrinter.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/IoPlanPrinter.java @@ -31,6 +31,7 @@ import io.trino.spi.type.Type; import io.trino.sql.PlannerContext; import io.trino.sql.planner.DomainTranslator; +import io.trino.sql.planner.ExpressionInterpreter; import io.trino.sql.planner.Plan; import io.trino.sql.planner.plan.FilterNode; import io.trino.sql.planner.plan.PlanNode; @@ -619,7 +620,8 @@ public Void visitFilter(FilterNode node, IoPlanBuilder context) plannerContext, session, node.getPredicate(), - plan.getTypes()); + plan.getTypes(), + new ExpressionInterpreter(plannerContext, session)); TupleDomain filterDomain = decomposedPredicate.getTupleDomain() .transformKeys(tableScanNode.getAssignments()::get); addInputTableConstraints(filterDomain, tableScanNode, context); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/sanity/ValidateAggregationsWithDefaultValues.java b/core/trino-main/src/main/java/io/trino/sql/planner/sanity/ValidateAggregationsWithDefaultValues.java index 263b0e3a4934..3ea85e7f9370 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/sanity/ValidateAggregationsWithDefaultValues.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/sanity/ValidateAggregationsWithDefaultValues.java @@ -16,6 +16,7 @@ import io.trino.Session; import io.trino.execution.warnings.WarningCollector; import io.trino.sql.PlannerContext; +import io.trino.sql.planner.ExpressionInterpreter; import io.trino.sql.planner.TypeAnalyzer; import io.trino.sql.planner.TypeProvider; import io.trino.sql.planner.optimizations.ActualProperties; @@ -77,6 +78,7 @@ private class Visitor final PlannerContext plannerContext; final TypeAnalyzer typeAnalyzer; final TypeProvider types; + final ExpressionInterpreter expressionInterpreter; Visitor(Session session, PlannerContext plannerContext, TypeAnalyzer typeAnalyzer, TypeProvider types) { @@ -84,6 +86,7 @@ private class Visitor this.plannerContext = requireNonNull(plannerContext, "plannerContext is null"); this.typeAnalyzer = requireNonNull(typeAnalyzer, "typeAnalyzer is null"); this.types = requireNonNull(types, "types is null"); + this.expressionInterpreter = new ExpressionInterpreter(plannerContext, session); } @Override @@ -120,14 +123,26 @@ public Optional visitAggregation(AggregationNode node, Void conte // No remote repartition exchange between final and partial aggregation. // Make sure that final aggregation operators are executed on a single node. - ActualProperties globalProperties = PropertyDerivations.derivePropertiesRecursively(node, plannerContext, session, types, typeAnalyzer); + ActualProperties globalProperties = PropertyDerivations.derivePropertiesRecursively( + node, + plannerContext, + session, + types, + typeAnalyzer, + expressionInterpreter); checkArgument(forceSingleNode || globalProperties.isSingleNode(), "Final aggregation with default value not separated from partial aggregation by remote hash exchange"); if (!seenExchanges.localRepartitionExchange) { // No local repartition exchange between final and partial aggregation. // Make sure that final aggregation operators are executed by single thread. - StreamProperties localProperties = StreamPropertyDerivations.derivePropertiesRecursively(node, plannerContext, session, types, typeAnalyzer); + StreamProperties localProperties = StreamPropertyDerivations.derivePropertiesRecursively( + node, + plannerContext, + session, + types, + typeAnalyzer, + expressionInterpreter); checkArgument(localProperties.isSingleStream(), "Final aggregation with default value not separated from partial aggregation by local hash exchange"); } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/sanity/ValidateLimitWithPresortedInput.java b/core/trino-main/src/main/java/io/trino/sql/planner/sanity/ValidateLimitWithPresortedInput.java index 66b852155ef9..c91d9825b511 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/sanity/ValidateLimitWithPresortedInput.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/sanity/ValidateLimitWithPresortedInput.java @@ -21,6 +21,7 @@ import io.trino.spi.connector.LocalProperty; import io.trino.spi.connector.SortingProperty; import io.trino.sql.PlannerContext; +import io.trino.sql.planner.ExpressionInterpreter; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.TypeAnalyzer; import io.trino.sql.planner.TypeProvider; @@ -62,6 +63,7 @@ private static final class Visitor private final PlannerContext plannerContext; private final TypeAnalyzer typeAnalyzer; private final TypeProvider types; + private final ExpressionInterpreter expressionInterpreter; private Visitor(Session session, PlannerContext plannerContext, @@ -72,6 +74,7 @@ private Visitor(Session session, this.plannerContext = plannerContext; this.typeAnalyzer = typeAnalyzer; this.types = types; + this.expressionInterpreter = new ExpressionInterpreter(plannerContext, session); } @Override @@ -91,7 +94,13 @@ public Void visitLimit(LimitNode node, Void context) return null; } - StreamProperties properties = derivePropertiesRecursively(node.getSource(), plannerContext, session, types, typeAnalyzer); + StreamProperties properties = derivePropertiesRecursively( + node.getSource(), + plannerContext, + session, + types, + typeAnalyzer, + expressionInterpreter); PeekingIterator> actuals = peekingIterator(normalizeAndPrune(properties.getLocalProperties()).iterator()); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/sanity/ValidateStreamingAggregations.java b/core/trino-main/src/main/java/io/trino/sql/planner/sanity/ValidateStreamingAggregations.java index 555bd27278f8..9d0ef45adadc 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/sanity/ValidateStreamingAggregations.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/sanity/ValidateStreamingAggregations.java @@ -20,6 +20,7 @@ import io.trino.spi.connector.GroupingProperty; import io.trino.spi.connector.LocalProperty; import io.trino.sql.PlannerContext; +import io.trino.sql.planner.ExpressionInterpreter; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.TypeAnalyzer; import io.trino.sql.planner.TypeProvider; @@ -61,6 +62,7 @@ private static final class Visitor private final PlannerContext plannerContext; private final TypeAnalyzer typeAnalyzer; private final TypeProvider types; + private final ExpressionInterpreter expressionInterpreter; private Visitor(Session session, PlannerContext plannerContext, @@ -71,6 +73,7 @@ private Visitor(Session session, this.plannerContext = plannerContext; this.typeAnalyzer = typeAnalyzer; this.types = types; + this.expressionInterpreter = new ExpressionInterpreter(plannerContext, session); } @Override @@ -87,7 +90,13 @@ public Void visitAggregation(AggregationNode node, Void context) return null; } - StreamProperties properties = derivePropertiesRecursively(node.getSource(), plannerContext, session, types, typeAnalyzer); + StreamProperties properties = derivePropertiesRecursively( + node.getSource(), + plannerContext, + session, + types, + typeAnalyzer, + expressionInterpreter); List> desiredProperties = ImmutableList.of(new GroupingProperty<>(node.getPreGroupedSymbols())); Iterator>> matchIterator = LocalProperties.match(properties.getLocalProperties(), desiredProperties).iterator(); diff --git a/core/trino-main/src/test/java/io/trino/operator/scalar/FunctionAssertions.java b/core/trino-main/src/test/java/io/trino/operator/scalar/FunctionAssertions.java index fe2a32b4e9e8..fab2c98efca1 100644 --- a/core/trino-main/src/test/java/io/trino/operator/scalar/FunctionAssertions.java +++ b/core/trino-main/src/test/java/io/trino/operator/scalar/FunctionAssertions.java @@ -762,9 +762,9 @@ protected Void visitSymbolReference(SymbolReference node, Void context) private Object interpret(Expression expression, Type expectedType, Session session) { Map, Type> expressionTypes = getTypes(session, getPlannerContext(), INPUT_TYPES, expression); - ExpressionInterpreter evaluator = new ExpressionInterpreter(expression, runner.getPlannerContext(), session, expressionTypes); + ExpressionInterpreter evaluator = new ExpressionInterpreter(runner.getPlannerContext(), session); - Object result = evaluator.evaluate(symbol -> { + Object result = evaluator.evaluate(expression, expressionTypes, symbol -> { int position = 0; int channel = INPUT_MAPPING.get(symbol); Type type = INPUT_TYPES.get(symbol); diff --git a/core/trino-main/src/test/java/io/trino/sql/TestExpressionInterpreter.java b/core/trino-main/src/test/java/io/trino/sql/TestExpressionInterpreter.java index a7c8104076d4..09428153ec68 100644 --- a/core/trino-main/src/test/java/io/trino/sql/TestExpressionInterpreter.java +++ b/core/trino-main/src/test/java/io/trino/sql/TestExpressionInterpreter.java @@ -14,6 +14,7 @@ package io.trino.sql; import com.google.common.base.Joiner; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import io.airlift.slice.Slice; import io.airlift.slice.Slices; @@ -30,8 +31,12 @@ import io.trino.sql.planner.TypeProvider; import io.trino.sql.planner.assertions.SymbolAliases; import io.trino.sql.planner.iterative.rule.CanonicalizeExpressionRewriter; +import io.trino.sql.tree.Cast; import io.trino.sql.tree.Expression; +import io.trino.sql.tree.InListExpression; +import io.trino.sql.tree.InPredicate; import io.trino.sql.tree.LikePredicate; +import io.trino.sql.tree.LongLiteral; import io.trino.sql.tree.NodeRef; import io.trino.sql.tree.StringLiteral; import io.trino.sql.tree.SymbolReference; @@ -71,6 +76,7 @@ import static io.trino.sql.ExpressionTestUtils.resolveFunctionCalls; import static io.trino.sql.ExpressionUtils.rewriteIdentifiersToSymbolReferences; import static io.trino.sql.ParsingUtil.createParsingOptions; +import static io.trino.sql.analyzer.TypeSignatureTranslator.toSqlType; import static io.trino.sql.planner.TestingPlannerContext.PLANNER_CONTEXT; import static io.trino.sql.planner.TypeAnalyzer.createTestingTypeAnalyzer; import static io.trino.testing.assertions.TrinoExceptionAssert.assertTrinoExceptionThrownBy; @@ -81,6 +87,8 @@ import static java.util.function.Function.identity; import static org.joda.time.DateTimeZone.UTC; import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertNotSame; +import static org.testng.Assert.assertSame; import static org.testng.Assert.assertTrue; public class TestExpressionInterpreter @@ -520,6 +528,50 @@ public void testIn() assertOptimizedEquals("0 / 0 in (2, 2)", "0 / 0 = 2"); } + @Test + public void testDoesNotOptimizeInPredicateWithLiteralsInList() + { + // should not create a new instance of InPredicate when IN list consists of literals only + InListExpression inList = new InListExpression(ImmutableList.of(new LongLiteral("42"), new LongLiteral("43"))); + InPredicate inPredicate = new InPredicate(new SymbolReference("unbound_integer"), inList); + assertSame(optimize(inPredicate), inPredicate); + } + + @Test + public void testDoesNotOptimizeUnboundedInPredicate() + { + // should not create a new instance of InPredicate when it contains unbounded expressions that cannot be simplified + InListExpression inList = new InListExpression(ImmutableList.of( + new SymbolReference("unbound_integer"), + new Cast(new SymbolReference("unbound_long"), toSqlType(INTEGER)))); + InPredicate inPredicate = new InPredicate(new SymbolReference("unbound_integer"), inList); + assertSame(optimize(inPredicate), inPredicate); + } + + @Test + public void testDoesNotOptimizeInListExpression() + { + // should not create a new instance of InListExpression since only value changed + InListExpression inList = new InListExpression(ImmutableList.of(new LongLiteral("42"), new LongLiteral("43"))); + InPredicate inPredicate = new InPredicate(planExpression("3 * 2 * unbound_integer"), inList); + InPredicate optimizedInPredicate = (InPredicate) optimize(inPredicate); + assertNotSame(inPredicate, optimizedInPredicate); + assertSame(inPredicate.getValueList(), inList); + } + + @Test + public void testDoesNotOptimizeInPredicateValueExpression() + { + // should not create a new instance of value since only InListExpression changed + Expression value = new SymbolReference("unbound_integer"); + InPredicate inPredicate = new InPredicate(value, new InListExpression(ImmutableList.of( + new LongLiteral("42"), + planExpression("2 * 3")))); + InPredicate optimizedInPredicate = (InPredicate) optimize(inPredicate); + assertNotSame(inPredicate, optimizedInPredicate); + assertSame(optimizedInPredicate.getValue(), value); + } + @Test public void testInComplexTypes() { @@ -1926,8 +1978,8 @@ private static Object optimize(@Language("SQL") String expression) static Object optimize(Expression parsedExpression) { Map, Type> expressionTypes = getTypes(TEST_SESSION, PLANNER_CONTEXT, SYMBOL_TYPES, parsedExpression); - ExpressionInterpreter interpreter = new ExpressionInterpreter(parsedExpression, PLANNER_CONTEXT, TEST_SESSION, expressionTypes); - return interpreter.optimize(INPUTS); + ExpressionInterpreter interpreter = new ExpressionInterpreter(PLANNER_CONTEXT, TEST_SESSION); + return interpreter.optimize(parsedExpression, expressionTypes, INPUTS); } // TODO replace that method with io.trino.sql.ExpressionTestUtils.planExpression @@ -1974,8 +2026,8 @@ private static void assertRoundTrip(String expression) private static Object evaluate(Expression expression) { Map, Type> expressionTypes = getTypes(TEST_SESSION, PLANNER_CONTEXT, SYMBOL_TYPES, expression); - ExpressionInterpreter interpreter = new ExpressionInterpreter(expression, PLANNER_CONTEXT, TEST_SESSION, expressionTypes); + ExpressionInterpreter interpreter = new ExpressionInterpreter(PLANNER_CONTEXT, TEST_SESSION); - return interpreter.evaluate(INPUTS); + return interpreter.evaluate(expression, expressionTypes, INPUTS); } } diff --git a/core/trino-main/src/test/java/io/trino/sql/TestSqlToRowExpressionTranslator.java b/core/trino-main/src/test/java/io/trino/sql/TestSqlToRowExpressionTranslator.java index 6061d83908c6..43e5d2740309 100644 --- a/core/trino-main/src/test/java/io/trino/sql/TestSqlToRowExpressionTranslator.java +++ b/core/trino-main/src/test/java/io/trino/sql/TestSqlToRowExpressionTranslator.java @@ -100,8 +100,8 @@ private Expression simplifyExpression(Expression expression) // Testing simplified expressions is important, since simplification may create CASTs or function calls that cannot be simplified by the ExpressionOptimizer Map, Type> expressionTypes = getExpressionTypes(expression); - ExpressionInterpreter interpreter = new ExpressionInterpreter(expression, PLANNER_CONTEXT, TEST_SESSION, expressionTypes); - Object value = interpreter.optimize(NoOpSymbolResolver.INSTANCE); + ExpressionInterpreter interpreter = new ExpressionInterpreter(PLANNER_CONTEXT, TEST_SESSION); + Object value = interpreter.optimize(expression, expressionTypes, NoOpSymbolResolver.INSTANCE); return literalEncoder.toExpression(TEST_SESSION, value, expressionTypes.get(NodeRef.of(expression))); } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestDomainTranslator.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestDomainTranslator.java index 5fce15968c26..d08a022ea80e 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestDomainTranslator.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestDomainTranslator.java @@ -1929,7 +1929,12 @@ private ExtractionResult fromPredicate(Expression originalPredicate) return transaction(new TestingTransactionManager(), new AllowAllAccessControl()) .singleStatement() .execute(TEST_SESSION, transactionSession -> { - return DomainTranslator.getExtractionResult(functionResolution.getPlannerContext(), transactionSession, originalPredicate, TYPES); + return DomainTranslator.getExtractionResult( + functionResolution.getPlannerContext(), + transactionSession, + originalPredicate, + TYPES, + new ExpressionInterpreter(functionResolution.getPlannerContext(), transactionSession)); }); } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestEffectivePredicateExtractor.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestEffectivePredicateExtractor.java index ff39b5360640..f1d818004e88 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestEffectivePredicateExtractor.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestEffectivePredicateExtractor.java @@ -174,6 +174,7 @@ public TableProperties getTableProperties(Session session, TableHandle handle) private final PlannerContext plannerContext = plannerContextBuilder().withMetadata(metadata).build(); private final TypeAnalyzer typeAnalyzer = createTestingTypeAnalyzer(plannerContext); + private final ExpressionInterpreter expressionInterpreter = new ExpressionInterpreter(plannerContext, SESSION); private final EffectivePredicateExtractor effectivePredicateExtractor = new EffectivePredicateExtractor(new DomainTranslator(plannerContext), plannerContext, true); private final EffectivePredicateExtractor effectivePredicateExtractorWithoutTableProperties = new EffectivePredicateExtractor(new DomainTranslator(plannerContext), plannerContext, false); @@ -242,7 +243,12 @@ D, new Aggregation( Optional.empty(), Optional.empty()); - Expression effectivePredicate = effectivePredicateExtractor.extract(SESSION, node, TypeProvider.empty(), typeAnalyzer); + Expression effectivePredicate = effectivePredicateExtractor.extract( + SESSION, + node, + TypeProvider.empty(), + typeAnalyzer, + expressionInterpreter); // Rewrite in terms of group by symbols assertEquals( @@ -267,7 +273,12 @@ public void testGroupByEmpty() Optional.empty(), Optional.empty()); - Expression effectivePredicate = effectivePredicateExtractor.extract(SESSION, node, TypeProvider.empty(), typeAnalyzer); + Expression effectivePredicate = effectivePredicateExtractor.extract( + SESSION, + node, + TypeProvider.empty(), + typeAnalyzer, + expressionInterpreter); assertEquals(effectivePredicate, TRUE_LITERAL); } @@ -285,7 +296,12 @@ public void testFilter() .build()), lessThan(BE, bigintLiteral(10)))); - Expression effectivePredicate = effectivePredicateExtractor.extract(SESSION, node, TypeProvider.empty(), typeAnalyzer); + Expression effectivePredicate = effectivePredicateExtractor.extract( + SESSION, + node, + TypeProvider.empty(), + typeAnalyzer, + expressionInterpreter); // Non-deterministic functions should be purged assertEquals( @@ -306,7 +322,12 @@ public void testProject() lessThan(CE, bigintLiteral(10)))), Assignments.of(D, AE, E, CE)); - Expression effectivePredicate = effectivePredicateExtractor.extract(SESSION, node, TypeProvider.empty(), typeAnalyzer); + Expression effectivePredicate = effectivePredicateExtractor.extract( + SESSION, + node, + TypeProvider.empty(), + typeAnalyzer, + expressionInterpreter); // Rewrite in terms of project output symbols assertEquals( @@ -331,7 +352,12 @@ public void testProjectWithSymbolReuse() lessThan(CE, bigintLiteral(10)))), Assignments.of(D, AE, B, CE)); - Expression effectivePredicateWhenBReused = effectivePredicateExtractor.extract(SESSION, projectReusingB, TypeProvider.empty(), typeAnalyzer); + Expression effectivePredicateWhenBReused = effectivePredicateExtractor.extract( + SESSION, + projectReusingB, + TypeProvider.empty(), + typeAnalyzer, + expressionInterpreter); assertEquals( normalizeConjuncts(effectivePredicateWhenBReused), @@ -354,7 +380,12 @@ public void testProjectWithSymbolReuse() .put(F, BE) .build()); - Expression effectivePredicateWhenCReused = effectivePredicateExtractor.extract(SESSION, projectReusingC, TypeProvider.empty(), typeAnalyzer); + Expression effectivePredicateWhenCReused = effectivePredicateExtractor.extract( + SESSION, + projectReusingC, + TypeProvider.empty(), + typeAnalyzer, + expressionInterpreter); assertEquals( normalizeConjuncts(effectivePredicateWhenCReused), @@ -375,7 +406,12 @@ public void testTopN() 1, new OrderingScheme(ImmutableList.of(A), ImmutableMap.of(A, SortOrder.ASC_NULLS_LAST)), TopNNode.Step.PARTIAL); - Expression effectivePredicate = effectivePredicateExtractor.extract(SESSION, node, TypeProvider.empty(), typeAnalyzer); + Expression effectivePredicate = effectivePredicateExtractor.extract( + SESSION, + node, + TypeProvider.empty(), + typeAnalyzer, + expressionInterpreter); // Pass through assertEquals( @@ -400,7 +436,12 @@ public void testLimit() 1, false); - Expression effectivePredicate = effectivePredicateExtractor.extract(SESSION, node, TypeProvider.empty(), typeAnalyzer); + Expression effectivePredicate = effectivePredicateExtractor.extract( + SESSION, + node, + TypeProvider.empty(), + typeAnalyzer, + expressionInterpreter); // Pass through assertEquals( @@ -425,7 +466,12 @@ public void testSort() new OrderingScheme(ImmutableList.of(A), ImmutableMap.of(A, SortOrder.ASC_NULLS_LAST)), false); - Expression effectivePredicate = effectivePredicateExtractor.extract(SESSION, node, TypeProvider.empty(), typeAnalyzer); + Expression effectivePredicate = effectivePredicateExtractor.extract( + SESSION, + node, + TypeProvider.empty(), + typeAnalyzer, + expressionInterpreter); // Pass through assertEquals( @@ -457,7 +503,12 @@ public void testWindow() ImmutableSet.of(), 0); - Expression effectivePredicate = effectivePredicateExtractor.extract(SESSION, node, TypeProvider.empty(), typeAnalyzer); + Expression effectivePredicate = effectivePredicateExtractor.extract( + SESSION, + node, + TypeProvider.empty(), + typeAnalyzer, + expressionInterpreter); // Pass through assertEquals( @@ -480,7 +531,12 @@ public void testTableScan() assignments, false, Optional.empty()); - Expression effectivePredicate = effectivePredicateExtractor.extract(SESSION, node, TypeProvider.empty(), typeAnalyzer); + Expression effectivePredicate = effectivePredicateExtractor.extract( + SESSION, + node, + TypeProvider.empty(), + typeAnalyzer, + expressionInterpreter); assertEquals(effectivePredicate, BooleanLiteral.TRUE_LITERAL); node = new TableScanNode( @@ -492,7 +548,12 @@ public void testTableScan() Optional.empty(), false, Optional.empty()); - effectivePredicate = effectivePredicateExtractor.extract(SESSION, node, TypeProvider.empty(), typeAnalyzer); + effectivePredicate = effectivePredicateExtractor.extract( + SESSION, + node, + TypeProvider.empty(), + typeAnalyzer, + expressionInterpreter); assertEquals(effectivePredicate, FALSE_LITERAL); TupleDomain predicate = TupleDomain.withColumnDomains(ImmutableMap.of(scanAssignments.get(A), Domain.singleValue(BIGINT, 1L))); @@ -505,7 +566,12 @@ public void testTableScan() Optional.empty(), false, Optional.empty()); - effectivePredicate = effectivePredicateExtractor.extract(SESSION, node, TypeProvider.empty(), typeAnalyzer); + effectivePredicate = effectivePredicateExtractor.extract( + SESSION, + node, + TypeProvider.empty(), + typeAnalyzer, + expressionInterpreter); assertEquals(normalizeConjuncts(effectivePredicate), normalizeConjuncts(equals(bigintLiteral(1L), AE))); predicate = TupleDomain.withColumnDomains(ImmutableMap.of( @@ -520,7 +586,12 @@ public void testTableScan() Optional.empty(), false, Optional.empty()); - effectivePredicate = effectivePredicateExtractorWithoutTableProperties.extract(SESSION, node, TypeProvider.empty(), typeAnalyzer); + effectivePredicate = effectivePredicateExtractorWithoutTableProperties.extract( + SESSION, + node, + TypeProvider.empty(), + typeAnalyzer, + expressionInterpreter); assertEquals(normalizeConjuncts(effectivePredicate), normalizeConjuncts(equals(bigintLiteral(2L), BE), equals(bigintLiteral(1L), AE))); node = new TableScanNode( @@ -532,7 +603,12 @@ public void testTableScan() Optional.empty(), false, Optional.empty()); - effectivePredicate = effectivePredicateExtractor.extract(SESSION, node, TypeProvider.empty(), typeAnalyzer); + effectivePredicate = effectivePredicateExtractor.extract( + SESSION, + node, + TypeProvider.empty(), + typeAnalyzer, + expressionInterpreter); assertEquals(effectivePredicate, and(equals(AE, bigintLiteral(1)), equals(BE, bigintLiteral(2)))); node = new TableScanNode( @@ -546,7 +622,12 @@ public void testTableScan() Optional.empty(), false, Optional.empty()); - effectivePredicate = effectivePredicateExtractor.extract(SESSION, node, TypeProvider.empty(), typeAnalyzer); + effectivePredicate = effectivePredicateExtractor.extract( + SESSION, + node, + TypeProvider.empty(), + typeAnalyzer, + expressionInterpreter); assertEquals(normalizeConjuncts(effectivePredicate), normalizeConjuncts(equals(bigintLiteral(2L), BE), equals(bigintLiteral(1L), AE))); node = new TableScanNode( @@ -558,7 +639,12 @@ public void testTableScan() Optional.empty(), false, Optional.empty()); - effectivePredicate = effectivePredicateExtractor.extract(SESSION, node, TypeProvider.empty(), typeAnalyzer); + effectivePredicate = effectivePredicateExtractor.extract( + SESSION, + node, + TypeProvider.empty(), + typeAnalyzer, + expressionInterpreter); assertEquals(effectivePredicate, BooleanLiteral.TRUE_LITERAL); } @@ -583,7 +669,8 @@ public void testValues() new Row(ImmutableList.of(bigintLiteral(1))), new Row(ImmutableList.of(bigintLiteral(2))))), types, - typeAnalyzer), + typeAnalyzer, + expressionInterpreter), new InPredicate(AE, new InListExpression(ImmutableList.of(bigintLiteral(1), bigintLiteral(2))))); // one column with null @@ -598,7 +685,8 @@ public void testValues() new Row(ImmutableList.of(bigintLiteral(2))), new Row(ImmutableList.of(new Cast(new NullLiteral(), toSqlType(BIGINT)))))), types, - typeAnalyzer), + typeAnalyzer, + expressionInterpreter), or( new InPredicate(AE, new InListExpression(ImmutableList.of(bigintLiteral(1), bigintLiteral(2)))), new IsNullPredicate(AE))); @@ -612,7 +700,8 @@ public void testValues() ImmutableList.of(A), ImmutableList.of(new Row(ImmutableList.of(new Cast(new NullLiteral(), toSqlType(BIGINT)))))), types, - typeAnalyzer), + typeAnalyzer, + expressionInterpreter), new IsNullPredicate(AE)); // nested row @@ -624,7 +713,8 @@ public void testValues() ImmutableList.of(R), ImmutableList.of(new Row(ImmutableList.of(new Row(ImmutableList.of(bigintLiteral(1), new NullLiteral())))))), types, - typeAnalyzer), + typeAnalyzer, + expressionInterpreter), TRUE_LITERAL); // many rows @@ -641,7 +731,8 @@ public void testValues() ImmutableList.of(A), rows), types, - typeAnalyzer), + typeAnalyzer, + expressionInterpreter), new BetweenPredicate(AE, bigintLiteral(0), bigintLiteral(499))); // NaN @@ -653,7 +744,8 @@ public void testValues() ImmutableList.of(D), ImmutableList.of(new Row(ImmutableList.of(doubleLiteral(Double.NaN))))), types, - typeAnalyzer), + typeAnalyzer, + expressionInterpreter), new NotExpression(new IsNullPredicate(DE))); // NaN and NULL @@ -667,7 +759,8 @@ public void testValues() new Row(ImmutableList.of(new Cast(new NullLiteral(), toSqlType(DOUBLE)))), new Row(ImmutableList.of(doubleLiteral(Double.NaN))))), types, - typeAnalyzer), + typeAnalyzer, + expressionInterpreter), TRUE_LITERAL); // NaN and value @@ -681,7 +774,8 @@ public void testValues() new Row(ImmutableList.of(doubleLiteral(42.))), new Row(ImmutableList.of(doubleLiteral(Double.NaN))))), types, - typeAnalyzer), + typeAnalyzer, + expressionInterpreter), new NotExpression(new IsNullPredicate(DE))); // Real NaN @@ -693,7 +787,8 @@ public void testValues() ImmutableList.of(D), ImmutableList.of(new Row(ImmutableList.of(new Cast(doubleLiteral(Double.NaN), toSqlType(REAL)))))), TypeProvider.copyOf(ImmutableMap.of(D, REAL)), - typeAnalyzer), + typeAnalyzer, + expressionInterpreter), new NotExpression(new IsNullPredicate(DE))); // multiple columns @@ -707,7 +802,8 @@ public void testValues() new Row(ImmutableList.of(bigintLiteral(1), bigintLiteral(100))), new Row(ImmutableList.of(bigintLiteral(2), bigintLiteral(200))))), types, - typeAnalyzer), + typeAnalyzer, + expressionInterpreter), and( new InPredicate(AE, new InListExpression(ImmutableList.of(bigintLiteral(1), bigintLiteral(2)))), new InPredicate(BE, new InListExpression(ImmutableList.of(bigintLiteral(100), bigintLiteral(200)))))); @@ -723,7 +819,8 @@ public void testValues() new Row(ImmutableList.of(bigintLiteral(1), new Cast(new NullLiteral(), toSqlType(BIGINT)))), new Row(ImmutableList.of(new Cast(new NullLiteral(), toSqlType(BIGINT)), bigintLiteral(200))))), types, - typeAnalyzer), + typeAnalyzer, + expressionInterpreter), and( or(new ComparisonExpression(EQUAL, AE, bigintLiteral(1)), new IsNullPredicate(AE)), or(new ComparisonExpression(EQUAL, BE, bigintLiteral(200)), new IsNullPredicate(BE)))); @@ -747,7 +844,8 @@ public void testValues() new Row(ImmutableList.of(bigintLiteral(1))), new Row(ImmutableList.of(BE)))), types, - typeAnalyzer), + typeAnalyzer, + expressionInterpreter), TRUE_LITERAL); } @@ -756,7 +854,12 @@ private Expression extract(TypeProvider types, PlanNode node) return transaction(new TestingTransactionManager(), new AllowAllAccessControl()) .singleStatement() .execute(SESSION, transactionSession -> { - return effectivePredicateExtractor.extract(transactionSession, node, types, typeAnalyzer); + return effectivePredicateExtractor.extract( + transactionSession, + node, + types, + typeAnalyzer, + expressionInterpreter); }); } @@ -773,7 +876,12 @@ public void testUnion() symbolMapping, ImmutableList.copyOf(symbolMapping.keySet())); - Expression effectivePredicate = effectivePredicateExtractor.extract(SESSION, node, TypeProvider.empty(), typeAnalyzer); + Expression effectivePredicate = effectivePredicateExtractor.extract( + SESSION, + node, + TypeProvider.empty(), + typeAnalyzer, + expressionInterpreter); // Only the common conjuncts can be inferred through a Union assertEquals( @@ -824,7 +932,12 @@ public void testInnerJoin() ImmutableMap.of(), Optional.empty()); - Expression effectivePredicate = effectivePredicateExtractor.extract(SESSION, node, TypeProvider.empty(), typeAnalyzer); + Expression effectivePredicate = effectivePredicateExtractor.extract( + SESSION, + node, + TypeProvider.empty(), + typeAnalyzer, + expressionInterpreter); // All predicates having output symbol should be carried through assertEquals( @@ -868,7 +981,12 @@ public void testInnerJoinPropagatesPredicatesViaEquiConditions() ImmutableMap.of(), Optional.empty()); - Expression effectivePredicate = effectivePredicateExtractor.extract(SESSION, node, TypeProvider.empty(), typeAnalyzer); + Expression effectivePredicate = effectivePredicateExtractor.extract( + SESSION, + node, + TypeProvider.empty(), + typeAnalyzer, + expressionInterpreter); assertEquals( normalizeConjuncts(effectivePredicate), @@ -901,7 +1019,12 @@ public void testInnerJoinWithFalseFilter() ImmutableMap.of(), Optional.empty()); - Expression effectivePredicate = effectivePredicateExtractor.extract(SESSION, node, TypeProvider.empty(), typeAnalyzer); + Expression effectivePredicate = effectivePredicateExtractor.extract( + SESSION, + node, + TypeProvider.empty(), + typeAnalyzer, + expressionInterpreter); assertEquals(effectivePredicate, FALSE_LITERAL); } @@ -948,7 +1071,12 @@ public void testLeftJoin() ImmutableMap.of(), Optional.empty()); - Expression effectivePredicate = effectivePredicateExtractor.extract(SESSION, node, TypeProvider.empty(), typeAnalyzer); + Expression effectivePredicate = effectivePredicateExtractor.extract( + SESSION, + node, + TypeProvider.empty(), + typeAnalyzer, + expressionInterpreter); // All right side symbols having output symbols should be checked against NULL assertEquals( @@ -997,7 +1125,12 @@ public void testLeftJoinWithFalseInner() ImmutableMap.of(), Optional.empty()); - Expression effectivePredicate = effectivePredicateExtractor.extract(SESSION, node, TypeProvider.empty(), typeAnalyzer); + Expression effectivePredicate = effectivePredicateExtractor.extract( + SESSION, + node, + TypeProvider.empty(), + typeAnalyzer, + expressionInterpreter); // False literal on the right side should be ignored assertEquals( @@ -1050,7 +1183,12 @@ public void testRightJoin() ImmutableMap.of(), Optional.empty()); - Expression effectivePredicate = effectivePredicateExtractor.extract(SESSION, node, TypeProvider.empty(), typeAnalyzer); + Expression effectivePredicate = effectivePredicateExtractor.extract( + SESSION, + node, + TypeProvider.empty(), + typeAnalyzer, + expressionInterpreter); // All left side symbols should be checked against NULL assertEquals( @@ -1098,7 +1236,12 @@ public void testRightJoinWithFalseInner() ImmutableMap.of(), Optional.empty()); - Expression effectivePredicate = effectivePredicateExtractor.extract(SESSION, node, TypeProvider.empty(), typeAnalyzer); + Expression effectivePredicate = effectivePredicateExtractor.extract( + SESSION, + node, + TypeProvider.empty(), + typeAnalyzer, + expressionInterpreter); // False literal on the left side should be ignored assertEquals( @@ -1122,7 +1265,12 @@ public void testSemiJoin() Optional.empty(), Optional.empty()); - Expression effectivePredicate = effectivePredicateExtractor.extract(SESSION, node, TypeProvider.empty(), typeAnalyzer); + Expression effectivePredicate = effectivePredicateExtractor.extract( + SESSION, + node, + TypeProvider.empty(), + typeAnalyzer, + expressionInterpreter); // Currently, only pull predicates through the source plan assertEquals( diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestLiteralEncoder.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestLiteralEncoder.java index 3f89830e100e..ff2c4b860451 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestLiteralEncoder.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestLiteralEncoder.java @@ -289,7 +289,8 @@ private void assertRoundTrip(T value, Type type, BiPredicate predicate private Object getExpressionValue(Expression expression) { - return new ExpressionInterpreter(expression, PLANNER_CONTEXT, TEST_SESSION, getExpressionTypes(expression)).evaluate(); + return new ExpressionInterpreter(PLANNER_CONTEXT, TEST_SESSION) + .evaluate(expression, getExpressionTypes(expression)); } private Type getExpressionType(Expression expression) diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/assertions/BasePlanTest.java b/core/trino-main/src/test/java/io/trino/sql/planner/assertions/BasePlanTest.java index 7f15d7d8d824..9a1ff9e57005 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/assertions/BasePlanTest.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/assertions/BasePlanTest.java @@ -20,10 +20,13 @@ import io.trino.connector.CatalogName; import io.trino.execution.warnings.WarningCollector; import io.trino.plugin.tpch.TpchConnectorFactory; +import io.trino.sql.planner.ExpressionInterpreter; import io.trino.sql.planner.LogicalPlanner; import io.trino.sql.planner.Plan; +import io.trino.sql.planner.PlanNodeIdAllocator; import io.trino.sql.planner.RuleStatsRecorder; import io.trino.sql.planner.SubPlan; +import io.trino.sql.planner.SymbolAllocator; import io.trino.sql.planner.iterative.IterativeOptimizer; import io.trino.sql.planner.iterative.Rule; import io.trino.sql.planner.iterative.rule.RemoveRedundantIdentityProjections; @@ -226,4 +229,44 @@ protected SubPlan subplan(@Language("SQL") String sql, LogicalPlanner.Stage stag throw new AssertionError("Planning failed for SQL: " + sql, e); } } + + protected PlanOptimizer.Context createOptimizerContext( + Session session, + SymbolAllocator symbolAllocator, + PlanNodeIdAllocator idAllocator, + ExpressionInterpreter interpreter) + { + return new PlanOptimizer.Context() + { + @Override + public Session getSession() + { + return session; + } + + @Override + public SymbolAllocator getSymbolAllocator() + { + return symbolAllocator; + } + + @Override + public PlanNodeIdAllocator getIdAllocator() + { + return idAllocator; + } + + @Override + public WarningCollector getWarningCollector() + { + return WarningCollector.NOOP; + } + + @Override + public ExpressionInterpreter getExpressionInterpreter() + { + return interpreter; + } + }; + } } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestJoinEnumerator.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestJoinEnumerator.java index 7a42a90b894d..ac5703f54887 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestJoinEnumerator.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestJoinEnumerator.java @@ -24,6 +24,7 @@ import io.trino.cost.PlanCostEstimate; import io.trino.cost.StatsProvider; import io.trino.execution.warnings.WarningCollector; +import io.trino.sql.planner.ExpressionInterpreter; import io.trino.sql.planner.PlanNodeIdAllocator; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.SymbolAllocator; @@ -124,6 +125,9 @@ private Rule.Context createContext() Optional.empty(), queryRunner.getDefaultSession(), symbolAllocator.getTypes()); + ExpressionInterpreter expressionInterpreter = new ExpressionInterpreter( + queryRunner.getPlannerContext(), + queryRunner.getDefaultSession()); return new Rule.Context() { @@ -171,6 +175,12 @@ public WarningCollector getWarningCollector() { return WarningCollector.NOOP; } + + @Override + public ExpressionInterpreter getExpressionInterpreter() + { + return expressionInterpreter; + } }; } } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestSimplifyExpressions.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestSimplifyExpressions.java index ff44fa6a41c9..a79249bb5770 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestSimplifyExpressions.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestSimplifyExpressions.java @@ -16,6 +16,7 @@ import io.trino.spi.type.Type; import io.trino.sql.parser.ParsingOptions; import io.trino.sql.parser.SqlParser; +import io.trino.sql.planner.ExpressionInterpreter; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.SymbolAllocator; import io.trino.sql.planner.SymbolsExtractor; @@ -266,7 +267,13 @@ private static void assertSimplifies(@Language("SQL") String expression, @Langua private static Expression simplify(@Language("SQL") String expression) { Expression actualExpression = rewriteIdentifiersToSymbolReferences(SQL_PARSER.createExpression(expression, new ParsingOptions())); - return normalize(rewrite(actualExpression, TEST_SESSION, new SymbolAllocator(booleanSymbolTypeMapFor(actualExpression)), PLANNER_CONTEXT, createTestingTypeAnalyzer(PLANNER_CONTEXT))); + return normalize(rewrite( + actualExpression, + TEST_SESSION, + new SymbolAllocator(booleanSymbolTypeMapFor(actualExpression)), + PLANNER_CONTEXT, + createTestingTypeAnalyzer(PLANNER_CONTEXT), + new ExpressionInterpreter(PLANNER_CONTEXT, TEST_SESSION))); } private static Map booleanSymbolTypeMapFor(Expression expression) @@ -342,7 +349,13 @@ private static void assertSimplifiesNumericTypes(String expression, String expec ParsingOptions parsingOptions = new ParsingOptions(); Expression actualExpression = rewriteIdentifiersToSymbolReferences(SQL_PARSER.createExpression(expression, parsingOptions)); Expression expectedExpression = rewriteIdentifiersToSymbolReferences(SQL_PARSER.createExpression(expected, parsingOptions)); - Expression rewritten = rewrite(actualExpression, TEST_SESSION, new SymbolAllocator(numericAndBooleanSymbolTypeMapFor(actualExpression)), PLANNER_CONTEXT, createTestingTypeAnalyzer(PLANNER_CONTEXT)); + Expression rewritten = rewrite( + actualExpression, + TEST_SESSION, + new SymbolAllocator(numericAndBooleanSymbolTypeMapFor(actualExpression)), + PLANNER_CONTEXT, + createTestingTypeAnalyzer(PLANNER_CONTEXT), + new ExpressionInterpreter(PLANNER_CONTEXT, TEST_SESSION)); assertEquals( normalize(rewritten), normalize(expectedExpression)); diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/test/RuleAssert.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/test/RuleAssert.java index 0614e668ee9e..60869c15e3ff 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/test/RuleAssert.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/test/RuleAssert.java @@ -30,6 +30,8 @@ import io.trino.metadata.FunctionManager; import io.trino.metadata.Metadata; import io.trino.security.AccessControl; +import io.trino.sql.PlannerContext; +import io.trino.sql.planner.ExpressionInterpreter; import io.trino.sql.planner.Plan; import io.trino.sql.planner.PlanNodeIdAllocator; import io.trino.sql.planner.SymbolAllocator; @@ -60,6 +62,7 @@ public class RuleAssert { + private final PlannerContext plannerContext; private final Metadata metadata; private final FunctionManager functionManager; private final TestingStatsCalculator statsCalculator; @@ -75,8 +78,7 @@ public class RuleAssert private final AccessControl accessControl; public RuleAssert( - Metadata metadata, - FunctionManager functionManager, + PlannerContext plannerContext, StatsCalculator statsCalculator, CostCalculator costCalculator, Session session, @@ -84,8 +86,9 @@ public RuleAssert( TransactionManager transactionManager, AccessControl accessControl) { - this.metadata = metadata; - this.functionManager = functionManager; + this.plannerContext = plannerContext; + this.metadata = plannerContext.getMetadata(); + this.functionManager = plannerContext.getFunctionManager(); this.statsCalculator = new TestingStatsCalculator(statsCalculator); this.costCalculator = costCalculator; this.session = session; @@ -225,7 +228,7 @@ private Rule.Context ruleContext(StatsCalculator statsCalculator, CostCalculator { StatsProvider statsProvider = new CachingStatsProvider(statsCalculator, Optional.of(memo), lookup, session, symbolAllocator.getTypes()); CostProvider costProvider = new CachingCostProvider(costCalculator, statsProvider, Optional.of(memo), session, symbolAllocator.getTypes()); - + ExpressionInterpreter expressionInterpreter = new ExpressionInterpreter(plannerContext, session); return new Rule.Context() { @Override @@ -272,6 +275,12 @@ public WarningCollector getWarningCollector() { return WarningCollector.NOOP; } + + @Override + public ExpressionInterpreter getExpressionInterpreter() + { + return expressionInterpreter; + } }; } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/test/RuleTester.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/test/RuleTester.java index 32175057f61e..d5ebcedc953b 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/test/RuleTester.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/test/RuleTester.java @@ -102,7 +102,7 @@ public RuleTester(LocalQueryRunner queryRunner) public RuleAssert assertThat(Rule rule) { - return new RuleAssert(metadata, functionManager, queryRunner.getStatsCalculator(), queryRunner.getEstimatedExchangesCostCalculator(), session, rule, transactionManager, accessControl); + return new RuleAssert(queryRunner.getPlannerContext(), queryRunner.getStatsCalculator(), queryRunner.getEstimatedExchangesCostCalculator(), session, rule, transactionManager, accessControl); } @Override diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestBeginTableWrite.java b/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestBeginTableWrite.java index 904a9285c267..228040018e13 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestBeginTableWrite.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestBeginTableWrite.java @@ -22,6 +22,7 @@ import io.trino.spi.connector.ColumnHandle; import io.trino.spi.connector.SchemaTableName; import io.trino.spi.type.BigintType; +import io.trino.sql.planner.ExpressionInterpreter; import io.trino.sql.planner.PlanNodeIdAllocator; import io.trino.sql.planner.SymbolAllocator; import io.trino.sql.planner.iterative.rule.test.PlanBuilder; @@ -32,7 +33,7 @@ import java.util.function.Function; import static io.trino.metadata.FunctionManager.createTestingFunctionManager; -import static io.trino.sql.planner.TypeProvider.empty; +import static io.trino.sql.planner.TestingPlannerContext.plannerContextBuilder; import static io.trino.sql.planner.plan.JoinNode.Type.INNER; import static io.trino.testing.TestingSession.testSessionBuilder; import static org.assertj.core.api.Assertions.assertThatCode; @@ -132,14 +133,45 @@ public void testUpdateWithInvalidNode() private void applyOptimization(Function planProvider) { Metadata metadata = new MockMetadata(); - new BeginTableWrite(metadata, createTestingFunctionManager()) - .optimize( - planProvider.apply(new PlanBuilder(new PlanNodeIdAllocator(), metadata, testSessionBuilder().build())), - testSessionBuilder().build(), - empty(), - new SymbolAllocator(), - new PlanNodeIdAllocator(), - WarningCollector.NOOP); + Session session = testSessionBuilder().build(); + SymbolAllocator symbolAllocator = new SymbolAllocator(); + PlanNodeIdAllocator idAllocator = new PlanNodeIdAllocator(); + ExpressionInterpreter interpreter = new ExpressionInterpreter(plannerContextBuilder().withMetadata(metadata).build(), session); + + new BeginTableWrite(metadata, createTestingFunctionManager()).optimize( + planProvider.apply(new PlanBuilder(new PlanNodeIdAllocator(), metadata, testSessionBuilder().build())), + new PlanOptimizer.Context() + { + @Override + public Session getSession() + { + return session; + } + + @Override + public SymbolAllocator getSymbolAllocator() + { + return symbolAllocator; + } + + @Override + public PlanNodeIdAllocator getIdAllocator() + { + return idAllocator; + } + + @Override + public WarningCollector getWarningCollector() + { + return WarningCollector.NOOP; + } + + @Override + public ExpressionInterpreter getExpressionInterpreter() + { + return interpreter; + } + }); } private static class MockMetadata diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestRemoveUnsupportedDynamicFilters.java b/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestRemoveUnsupportedDynamicFilters.java index eda5b4cc2e0c..9cd56a5002db 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestRemoveUnsupportedDynamicFilters.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestRemoveUnsupportedDynamicFilters.java @@ -23,6 +23,7 @@ import io.trino.plugin.tpch.TpchColumnHandle; import io.trino.plugin.tpch.TpchTableHandle; import io.trino.sql.PlannerContext; +import io.trino.sql.planner.ExpressionInterpreter; import io.trino.sql.planner.Plan; import io.trino.sql.planner.PlanNodeIdAllocator; import io.trino.sql.planner.Symbol; @@ -477,7 +478,13 @@ private PlanNode removeUnsupportedDynamicFilters(PlanNode root) return getQueryRunner().inTransaction(session -> { // metadata.getCatalogHandle() registers the catalog for the transaction session.getCatalog().ifPresent(catalog -> metadata.getCatalogHandle(session, catalog)); - PlanNode rewrittenPlan = new RemoveUnsupportedDynamicFilters(plannerContext).optimize(root, session, builder.getTypes(), new SymbolAllocator(), new PlanNodeIdAllocator(), WarningCollector.NOOP); + PlanNode rewrittenPlan = new RemoveUnsupportedDynamicFilters(plannerContext).optimize( + root, + createOptimizerContext( + session, + new SymbolAllocator(builder.getTypes().allTypes()), + new PlanNodeIdAllocator(), + new ExpressionInterpreter(plannerContext, session))); new DynamicFiltersChecker().validate(rewrittenPlan, session, plannerContext, createTestingTypeAnalyzer(plannerContext), diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestUnaliasSymbolReferences.java b/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestUnaliasSymbolReferences.java index 65f05d325480..478416ef613d 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestUnaliasSymbolReferences.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestUnaliasSymbolReferences.java @@ -18,7 +18,6 @@ import io.trino.Session; import io.trino.connector.CatalogName; import io.trino.cost.StatsAndCosts; -import io.trino.execution.warnings.WarningCollector; import io.trino.metadata.Metadata; import io.trino.metadata.TableHandle; import io.trino.plugin.tpch.TpchColumnHandle; @@ -26,6 +25,7 @@ import io.trino.spi.connector.ColumnHandle; import io.trino.spi.type.BigintType; import io.trino.sql.ExpressionUtils; +import io.trino.sql.planner.ExpressionInterpreter; import io.trino.sql.planner.Plan; import io.trino.sql.planner.PlanNodeIdAllocator; import io.trino.sql.planner.Symbol; @@ -123,15 +123,15 @@ private void assertOptimizedPlan(PlanOptimizer optimizer, PlanCreator planCreato PlanNodeIdAllocator idAllocator = new PlanNodeIdAllocator(); PlanBuilder planBuilder = new PlanBuilder(idAllocator, metadata, session); - SymbolAllocator symbolAllocator = new SymbolAllocator(); + SymbolAllocator symbolAllocator = new SymbolAllocator(planBuilder.getTypes().allTypes()); PlanNode plan = planCreator.create(planBuilder, session, metadata); PlanNode optimized = optimizer.optimize( plan, - session, - planBuilder.getTypes(), - symbolAllocator, - idAllocator, - WarningCollector.NOOP); + createOptimizerContext( + session, + symbolAllocator, + idAllocator, + new ExpressionInterpreter(queryRunner.getPlannerContext(), session))); Plan actual = new Plan(optimized, planBuilder.getTypes(), StatsAndCosts.empty()); PlanAssert.assertPlan(session, queryRunner.getMetadata(), queryRunner.getFunctionManager(), queryRunner.getStatsCalculator(), actual, pattern);