diff --git a/core/trino-main/src/main/java/io/trino/cost/FilterStatsCalculator.java b/core/trino-main/src/main/java/io/trino/cost/FilterStatsCalculator.java index 56776ce6f3a6..1f9b04e990b9 100644 --- a/core/trino-main/src/main/java/io/trino/cost/FilterStatsCalculator.java +++ b/core/trino-main/src/main/java/io/trino/cost/FilterStatsCalculator.java @@ -20,12 +20,12 @@ import io.trino.execution.warnings.WarningCollector; import io.trino.security.AllowAllAccessControl; import io.trino.spi.type.Type; +import io.trino.sql.ExpressionUtils; import io.trino.sql.PlannerContext; import io.trino.sql.analyzer.ExpressionAnalyzer; import io.trino.sql.analyzer.Scope; import io.trino.sql.planner.ExpressionInterpreter; import io.trino.sql.planner.LiteralEncoder; -import io.trino.sql.planner.LiteralInterpreter; import io.trino.sql.planner.NoOpSymbolResolver; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.TypeProvider; @@ -39,7 +39,6 @@ import io.trino.sql.tree.InPredicate; import io.trino.sql.tree.IsNotNullPredicate; import io.trino.sql.tree.IsNullPredicate; -import io.trino.sql.tree.Literal; import io.trino.sql.tree.LogicalExpression; import io.trino.sql.tree.Node; import io.trino.sql.tree.NodeRef; @@ -56,6 +55,7 @@ import java.util.OptionalDouble; import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Verify.verify; import static com.google.common.collect.ImmutableList.toImmutableList; import static io.trino.cost.ComparisonStatsCalculator.estimateExpressionToExpressionComparison; import static io.trino.cost.ComparisonStatsCalculator.estimateExpressionToLiteralComparison; @@ -66,6 +66,8 @@ import static io.trino.spi.type.BooleanType.BOOLEAN; import static io.trino.sql.DynamicFilters.isDynamicFilter; import static io.trino.sql.ExpressionUtils.and; +import static io.trino.sql.ExpressionUtils.getExpressionTypes; +import static io.trino.sql.planner.ExpressionInterpreter.evaluateConstantExpression; import static io.trino.sql.tree.ComparisonExpression.Operator.EQUAL; import static io.trino.sql.tree.ComparisonExpression.Operator.GREATER_THAN_OR_EQUAL; import static io.trino.sql.tree.ComparisonExpression.Operator.LESS_THAN_OR_EQUAL; @@ -74,7 +76,6 @@ import static java.lang.Double.isNaN; import static java.lang.Double.min; import static java.lang.String.format; -import static java.util.Collections.emptyMap; import static java.util.Objects.requireNonNull; public class FilterStatsCalculator @@ -108,7 +109,7 @@ private Expression simplifyExpression(Session session, Expression predicate, Typ { // TODO reuse io.trino.sql.planner.iterative.rule.SimplifyExpressions.rewrite - Map, Type> expressionTypes = getExpressionTypes(session, predicate, types); + Map, Type> expressionTypes = getExpressionTypes(plannerContext, session, predicate, types); ExpressionInterpreter interpreter = new ExpressionInterpreter(predicate, plannerContext, session, expressionTypes); Object value = interpreter.optimize(NoOpSymbolResolver.INSTANCE); @@ -119,21 +120,6 @@ private Expression simplifyExpression(Session session, Expression predicate, Typ return new LiteralEncoder(plannerContext).toExpression(session, value, BOOLEAN); } - private Map, Type> getExpressionTypes(Session session, Expression expression, TypeProvider types) - { - ExpressionAnalyzer expressionAnalyzer = ExpressionAnalyzer.createWithoutSubqueries( - plannerContext, - new AllowAllAccessControl(), - session, - types, - emptyMap(), - node -> new IllegalStateException("Unexpected node: " + node), - WarningCollector.NOOP, - false); - expressionAnalyzer.analyze(expression, Scope.create()); - return expressionAnalyzer.getExpressionTypes(); - } - private class FilterExpressionStatsCalculatingVisitor extends AstVisitor { @@ -367,14 +353,15 @@ protected PlanNodeStatsEstimate visitComparisonExpression(ComparisonExpression n Expression left = node.getLeft(); Expression right = node.getRight(); - checkArgument(!(left instanceof Literal && right instanceof Literal), "Literal-to-literal not supported here, should be eliminated earlier"); + checkArgument(!(isEffectivelyLiteral(left) && isEffectivelyLiteral(right)), "Literal-to-literal not supported here, should be eliminated earlier"); if (!(left instanceof SymbolReference) && right instanceof SymbolReference) { // normalize so that symbol is on the left return process(new ComparisonExpression(operator.flip(), right, left)); } - if (left instanceof Literal && !(right instanceof Literal)) { + if (isEffectivelyLiteral(left)) { + verify(!isEffectivelyLiteral(right)); // normalize so that literal is on the right return process(new ComparisonExpression(operator.flip(), right, left)); } @@ -385,8 +372,8 @@ protected PlanNodeStatsEstimate visitComparisonExpression(ComparisonExpression n SymbolStatsEstimate leftStats = getExpressionStats(left); Optional leftSymbol = left instanceof SymbolReference ? Optional.of(Symbol.from(left)) : Optional.empty(); - if (right instanceof Literal) { - OptionalDouble literal = doubleValueFromLiteral(getType(left), (Literal) right); + if (isEffectivelyLiteral(right)) { + OptionalDouble literal = doubleValueFromLiteral(getType(left), right); return estimateExpressionToLiteralComparison(input, leftStats, leftSymbol, literal, operator); } @@ -438,9 +425,20 @@ private SymbolStatsEstimate getExpressionStats(Expression expression) return scalarStatsCalculator.calculate(expression, input, session, types); } - private OptionalDouble doubleValueFromLiteral(Type type, Literal literal) + private boolean isEffectivelyLiteral(Expression expression) { - Object literalValue = LiteralInterpreter.evaluate(plannerContext, session, getExpressionTypes(session, literal, types), literal); + return ExpressionUtils.isEffectivelyLiteral(plannerContext, session, expression); + } + + private OptionalDouble doubleValueFromLiteral(Type type, Expression literal) + { + Object literalValue = evaluateConstantExpression( + literal, + type, + plannerContext, + session, + new AllowAllAccessControl(), + ImmutableMap.of()); return toStatsRepresentation(type, literalValue); } } diff --git a/core/trino-main/src/main/java/io/trino/cost/ScalarStatsCalculator.java b/core/trino-main/src/main/java/io/trino/cost/ScalarStatsCalculator.java index 66cb569d39e1..e6a3bb8daf04 100644 --- a/core/trino-main/src/main/java/io/trino/cost/ScalarStatsCalculator.java +++ b/core/trino-main/src/main/java/io/trino/cost/ScalarStatsCalculator.java @@ -50,6 +50,8 @@ import java.util.OptionalDouble; import static io.trino.spi.statistics.StatsUtil.toStatsRepresentation; +import static io.trino.sql.ExpressionUtils.getExpressionTypes; +import static io.trino.sql.ExpressionUtils.isEffectivelyLiteral; import static io.trino.sql.analyzer.ExpressionAnalyzer.createConstantAnalyzer; import static io.trino.sql.planner.LiteralInterpreter.evaluate; import static io.trino.util.MoreMath.max; @@ -58,7 +60,6 @@ import static java.lang.Double.isFinite; import static java.lang.Double.isNaN; import static java.lang.Math.abs; -import static java.util.Collections.emptyMap; import static java.util.Objects.requireNonNull; public class ScalarStatsCalculator @@ -132,7 +133,7 @@ protected SymbolStatsEstimate visitLiteral(Literal node, Void context) @Override protected SymbolStatsEstimate visitFunctionCall(FunctionCall node, Void context) { - Map, Type> expressionTypes = getExpressionTypes(session, node, types); + Map, Type> expressionTypes = getExpressionTypes(plannerContext, session, node, types); ExpressionInterpreter interpreter = new ExpressionInterpreter(node, plannerContext, session, expressionTypes); Object value = interpreter.optimize(NoOpSymbolResolver.INSTANCE); @@ -140,7 +141,7 @@ protected SymbolStatsEstimate visitFunctionCall(FunctionCall node, Void context) return nullStatsEstimate(); } - if (value instanceof Expression && !(value instanceof Literal)) { + if (value instanceof Expression && !isEffectivelyLiteral(plannerContext, session, (Expression) value)) { // value is not a constant return SymbolStatsEstimate.unknown(); } @@ -152,21 +153,6 @@ protected SymbolStatsEstimate visitFunctionCall(FunctionCall node, Void context) .build(); } - private Map, Type> getExpressionTypes(Session session, Expression expression, TypeProvider types) - { - ExpressionAnalyzer expressionAnalyzer = ExpressionAnalyzer.createWithoutSubqueries( - plannerContext, - new AllowAllAccessControl(), - session, - types, - emptyMap(), - node -> new IllegalStateException("Unexpected node: %s" + node), - WarningCollector.NOOP, - false); - expressionAnalyzer.analyze(expression, Scope.create()); - return expressionAnalyzer.getExpressionTypes(); - } - @Override protected SymbolStatsEstimate visitCast(Cast node, Void context) { diff --git a/core/trino-main/src/main/java/io/trino/sql/ExpressionUtils.java b/core/trino-main/src/main/java/io/trino/sql/ExpressionUtils.java index b6a1828ce4d9..c907c48ed64b 100644 --- a/core/trino-main/src/main/java/io/trino/sql/ExpressionUtils.java +++ b/core/trino-main/src/main/java/io/trino/sql/ExpressionUtils.java @@ -14,20 +14,37 @@ package io.trino.sql; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import com.google.common.collect.Iterables; +import io.trino.Session; +import io.trino.execution.warnings.WarningCollector; import io.trino.metadata.Metadata; +import io.trino.metadata.ResolvedFunction; +import io.trino.security.AllowAllAccessControl; +import io.trino.spi.type.Type; +import io.trino.sql.analyzer.ExpressionAnalyzer; +import io.trino.sql.analyzer.Scope; import io.trino.sql.planner.DeterminismEvaluator; +import io.trino.sql.planner.ExpressionInterpreter; +import io.trino.sql.planner.LiteralEncoder; +import io.trino.sql.planner.NoOpSymbolResolver; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.SymbolsExtractor; +import io.trino.sql.planner.TypeProvider; +import io.trino.sql.tree.Cast; import io.trino.sql.tree.Expression; import io.trino.sql.tree.ExpressionRewriter; import io.trino.sql.tree.ExpressionTreeRewriter; +import io.trino.sql.tree.FunctionCall; import io.trino.sql.tree.GenericDataType; import io.trino.sql.tree.Identifier; import io.trino.sql.tree.IsNullPredicate; import io.trino.sql.tree.LambdaExpression; +import io.trino.sql.tree.Literal; import io.trino.sql.tree.LogicalExpression; import io.trino.sql.tree.LogicalExpression.Operator; +import io.trino.sql.tree.NodeRef; +import io.trino.sql.tree.QualifiedName; import io.trino.sql.tree.RowDataType; import io.trino.sql.tree.SymbolReference; @@ -35,12 +52,15 @@ import java.util.Collection; import java.util.HashSet; import java.util.List; +import java.util.Map; import java.util.Set; import java.util.function.Function; import java.util.function.Predicate; import static com.google.common.base.Predicates.not; import static com.google.common.collect.ImmutableList.toImmutableList; +import static io.trino.metadata.LiteralFunction.LITERAL_FUNCTION_NAME; +import static io.trino.metadata.ResolvedFunction.isResolved; import static io.trino.sql.tree.BooleanLiteral.FALSE_LITERAL; import static io.trino.sql.tree.BooleanLiteral.TRUE_LITERAL; import static java.util.Objects.requireNonNull; @@ -255,6 +275,59 @@ public static Function expressionOrNullSymbols(Predicate }; } + /** + * Returns whether expression is effectively literal. An effectitvely literal expression is a simple constant value, or null, + * in either {@link Literal} form, or other form returned by {@link LiteralEncoder}. In particular, other constant expressions + * like a deterministic function call with constant arguments are not considered effectitvely literal. + */ + public static boolean isEffectivelyLiteral(PlannerContext plannerContext, Session session, Expression expression) + { + if (expression instanceof Literal) { + return true; + } + if (expression instanceof Cast) { + return ((Cast) expression).getExpression() instanceof Literal + // a Cast(Literal(...)) can fail, so this requires verification + && constantExpressionEvaluatesSuccessfully(plannerContext, session, expression); + } + if (expression instanceof FunctionCall) { + QualifiedName functionName = ((FunctionCall) expression).getName(); + if (isResolved(functionName)) { + ResolvedFunction resolvedFunction = plannerContext.getMetadata().decodeFunction(functionName); + return LITERAL_FUNCTION_NAME.equals(resolvedFunction.getSignature().getName()); + } + } + + return false; + } + + private static boolean constantExpressionEvaluatesSuccessfully(PlannerContext plannerContext, Session session, Expression constantExpression) + { + Map, Type> types = getExpressionTypes(plannerContext, session, constantExpression, TypeProvider.empty()); + ExpressionInterpreter interpreter = new ExpressionInterpreter(constantExpression, plannerContext, session, types); + Object literalValue = interpreter.optimize(NoOpSymbolResolver.INSTANCE); + return !(literalValue instanceof Expression); + } + + /** + * @deprecated Use {@link io.trino.sql.planner.TypeAnalyzer#getTypes(Session, TypeProvider, Expression)}. + */ + @Deprecated + public static Map, Type> getExpressionTypes(PlannerContext plannerContext, Session session, Expression expression, TypeProvider types) + { + ExpressionAnalyzer expressionAnalyzer = ExpressionAnalyzer.createWithoutSubqueries( + plannerContext, + new AllowAllAccessControl(), + session, + types, + ImmutableMap.of(), + node -> new IllegalStateException("Unexpected node: " + node), + WarningCollector.NOOP, + false); + expressionAnalyzer.analyze(expression, Scope.create()); + return expressionAnalyzer.getExpressionTypes(); + } + /** * Removes duplicate deterministic expressions. Preserves the relative order * of the expressions in the list. diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/ExpressionInterpreter.java b/core/trino-main/src/main/java/io/trino/sql/planner/ExpressionInterpreter.java index 36e7edf80a22..ea287a8bb82c 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/ExpressionInterpreter.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/ExpressionInterpreter.java @@ -134,6 +134,7 @@ import static io.trino.spi.type.TypeUtils.writeNativeValue; import static io.trino.spi.type.VarcharType.createVarcharType; import static io.trino.sql.DynamicFilters.isDynamicFilter; +import static io.trino.sql.ExpressionUtils.isEffectivelyLiteral; import static io.trino.sql.analyzer.ConstantExpressionVerifier.verifyExpressionIsConstant; import static io.trino.sql.analyzer.ExpressionAnalyzer.createConstantAnalyzer; import static io.trino.sql.analyzer.SemanticExceptions.semanticException; @@ -238,7 +239,7 @@ public static Object evaluateConstantExpression( // expressionInterpreter/optimizer only understands a subset of expression types // TODO: remove this when the new expression tree is implemented - Expression canonicalized = canonicalizeExpression(rewrite, analyzer.getExpressionTypes(), plannerContext.getMetadata(), session); + Expression canonicalized = canonicalizeExpression(rewrite, analyzer.getExpressionTypes(), plannerContext, session); // The optimization above may have rewritten the expression tree which breaks all the identity maps, so redo the analysis // to re-analyze coercions that might be necessary @@ -546,7 +547,6 @@ protected Object visitCoalesceExpression(CoalesceExpression node, Object context private List processOperands(CoalesceExpression node, Object context) { - Type type = type(node); List newOperands = new ArrayList<>(); Set uniqueNewOperands = new HashSet<>(); for (Expression operand : node.getOperands()) { @@ -559,14 +559,18 @@ private List processOperands(CoalesceExpression node, Object context) newOperands.add(nestedOperand); } // This operand can be evaluated to a non-null value. Remaining operands can be skipped. - if (nestedOperand instanceof Literal) { + if (isEffectivelyLiteral(plannerContext, session, nestedOperand)) { + verify( + !(nestedOperand instanceof NullLiteral) && !(nestedOperand instanceof Cast && ((Cast) nestedOperand).getExpression() instanceof NullLiteral), + "Null operand should have been removed by recursive coalesce processing"); return newOperands; } } } else if (value instanceof Expression) { + verify(!(value instanceof NullLiteral), "Null value is expected to be represented as null, not NullLiteral"); // Skip duplicates unless they are non-deterministic. - Expression expression = toExpression(value, type); + Expression expression = (Expression) value; if (!isDeterministic(expression, metadata) || uniqueNewOperands.add(expression)) { newOperands.add(expression); } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/PlanOptimizers.java b/core/trino-main/src/main/java/io/trino/sql/planner/PlanOptimizers.java index 1d5b239f049d..f19d367e76d0 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/PlanOptimizers.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/PlanOptimizers.java @@ -349,7 +349,7 @@ public PlanOptimizers( statsCalculator, estimatedExchangesCostCalculator, ImmutableSet.of( - new InlineProjections(typeAnalyzer), + new InlineProjections(plannerContext, typeAnalyzer), new RemoveRedundantIdentityProjections())); Set> simplifyOptimizerRules = ImmutableSet.>builder() @@ -358,7 +358,7 @@ public PlanOptimizers( .addAll(new PushCastIntoRow().rules()) .addAll(new UnwrapCastInComparison(plannerContext, typeAnalyzer).rules()) .addAll(new RemoveDuplicateConditions(metadata).rules()) - .addAll(new CanonicalizeExpressions(metadata, typeAnalyzer).rules()) + .addAll(new CanonicalizeExpressions(plannerContext, typeAnalyzer).rules()) .add(new RemoveTrivialFilters()) .build(); IterativeOptimizer simplifyOptimizer = new IterativeOptimizer( @@ -397,7 +397,7 @@ public PlanOptimizers( statsCalculator, estimatedExchangesCostCalculator, ImmutableSet.>builder() - .addAll(new CanonicalizeExpressions(metadata, typeAnalyzer).rules()) + .addAll(new CanonicalizeExpressions(plannerContext, typeAnalyzer).rules()) .add(new OptimizeRowPattern()) .build()), new IterativeOptimizer( @@ -417,7 +417,7 @@ public PlanOptimizers( new EvaluateEmptyIntersect(), new RemoveEmptyExceptBranches(), new MergeFilters(metadata), - new InlineProjections(typeAnalyzer), + new InlineProjections(plannerContext, typeAnalyzer), new RemoveRedundantIdentityProjections(), new RemoveFullSample(), new EvaluateZeroSample(), @@ -444,7 +444,7 @@ public PlanOptimizers( new PruneCountAggregationOverScalar(metadata), new PruneOrderByInAggregation(metadata), new RewriteSpatialPartitioningAggregation(plannerContext), - new SimplifyCountOverConstant(metadata))) + new SimplifyCountOverConstant(plannerContext))) .build()), new IterativeOptimizer( metadata, @@ -534,7 +534,7 @@ public PlanOptimizers( statsCalculator, estimatedExchangesCostCalculator, ImmutableSet.of( - new InlineProjections(typeAnalyzer), + new InlineProjections(plannerContext, typeAnalyzer), new RemoveRedundantIdentityProjections(), new TransformCorrelatedSingleRowSubqueryToProject(), new RemoveAggregationInSemiJoin(), @@ -697,7 +697,7 @@ public PlanOptimizers( ruleStats, statsCalculator, estimatedExchangesCostCalculator, - ImmutableSet.of(new EliminateCrossJoins(metadata, typeAnalyzer))), // This can pull up Filter and Project nodes from between Joins, so we need to push them down again + ImmutableSet.of(new EliminateCrossJoins(plannerContext, typeAnalyzer))), // This can pull up Filter and Project nodes from between Joins, so we need to push them down again new StatsRecordingPlanOptimizer( optimizerStats, new PredicatePushDown(plannerContext, typeAnalyzer, true, false)), @@ -750,7 +750,7 @@ public PlanOptimizers( ruleStats, statsCalculator, estimatedExchangesCostCalculator, - ImmutableSet.of(new ReorderJoins(metadata, costComparator, typeAnalyzer)))); + ImmutableSet.of(new ReorderJoins(plannerContext, costComparator, typeAnalyzer)))); builder.add(new OptimizeMixedDistinctAggregations(metadata)); builder.add(new IterativeOptimizer( @@ -772,7 +772,7 @@ public PlanOptimizers( ImmutableSet.>builder() .add(new RemoveRedundantIdentityProjections()) .addAll(new ExtractSpatialJoins(plannerContext, splitManager, pageSourceManager, typeAnalyzer).rules()) - .add(new InlineProjections(typeAnalyzer)) + .add(new InlineProjections(plannerContext, typeAnalyzer)) .build())); builder.add(new IterativeOptimizer( @@ -884,7 +884,7 @@ public PlanOptimizers( ImmutableSet.>builder() .add(new RemoveRedundantIdentityProjections()) .add(new PushRemoteExchangeThroughAssignUniqueId()) - .add(new InlineProjections(typeAnalyzer)) + .add(new InlineProjections(plannerContext, typeAnalyzer)) .build())); // Optimizers above this don't understand local exchanges, so be careful moving this. diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/CanonicalizeExpressionRewriter.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/CanonicalizeExpressionRewriter.java index 45bf05446288..74ddb648aa40 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/CanonicalizeExpressionRewriter.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/CanonicalizeExpressionRewriter.java @@ -23,6 +23,7 @@ import io.trino.spi.type.TimestampWithTimeZoneType; import io.trino.spi.type.Type; import io.trino.spi.type.VarcharType; +import io.trino.sql.PlannerContext; import io.trino.sql.planner.FunctionCallBuilder; import io.trino.sql.planner.TypeAnalyzer; import io.trino.sql.planner.TypeProvider; @@ -39,7 +40,6 @@ import io.trino.sql.tree.IfExpression; import io.trino.sql.tree.IsNotNullPredicate; import io.trino.sql.tree.IsNullPredicate; -import io.trino.sql.tree.Literal; import io.trino.sql.tree.NodeRef; import io.trino.sql.tree.NotExpression; import io.trino.sql.tree.NullLiteral; @@ -57,6 +57,7 @@ import static com.google.common.collect.ImmutableList.toImmutableList; import static io.trino.metadata.ResolvedFunction.extractFunctionName; import static io.trino.spi.type.VarcharType.VARCHAR; +import static io.trino.sql.ExpressionUtils.isEffectivelyLiteral; import static io.trino.sql.analyzer.TypeSignatureTranslator.toSqlType; import static io.trino.sql.tree.ArithmeticBinaryExpression.Operator.ADD; import static io.trino.sql.tree.ArithmeticBinaryExpression.Operator.MULTIPLY; @@ -64,16 +65,16 @@ public final class CanonicalizeExpressionRewriter { - public static Expression canonicalizeExpression(Expression expression, Map, Type> expressionTypes, Metadata metadata, Session session) + public static Expression canonicalizeExpression(Expression expression, Map, Type> expressionTypes, PlannerContext plannerContext, Session session) { - return ExpressionTreeRewriter.rewriteWith(new Visitor(session, metadata, expressionTypes), expression); + return ExpressionTreeRewriter.rewriteWith(new Visitor(session, plannerContext, expressionTypes), expression); } private CanonicalizeExpressionRewriter() {} - public static Expression rewrite(Expression expression, Session session, Metadata metadata, TypeAnalyzer typeAnalyzer, TypeProvider types) + public static Expression rewrite(Expression expression, Session session, PlannerContext plannerContext, TypeAnalyzer typeAnalyzer, TypeProvider types) { - requireNonNull(metadata, "metadata is null"); + requireNonNull(plannerContext, "plannerContext is null"); requireNonNull(typeAnalyzer, "typeAnalyzer is null"); if (expression instanceof SymbolReference) { @@ -81,20 +82,22 @@ public static Expression rewrite(Expression expression, Session session, Metadat } Map, Type> expressionTypes = typeAnalyzer.getTypes(session, types, expression); - return ExpressionTreeRewriter.rewriteWith(new Visitor(session, metadata, expressionTypes), expression); + return ExpressionTreeRewriter.rewriteWith(new Visitor(session, plannerContext, expressionTypes), expression); } private static class Visitor extends ExpressionRewriter { private final Session session; + private final PlannerContext plannerContext; private final Metadata metadata; private final Map, Type> expressionTypes; - public Visitor(Session session, Metadata metadata, Map, Type> expressionTypes) + public Visitor(Session session, PlannerContext plannerContext, Map, Type> expressionTypes) { this.session = session; - this.metadata = metadata; + this.plannerContext = plannerContext; + this.metadata = plannerContext.getMetadata(); this.expressionTypes = expressionTypes; } @@ -291,20 +294,16 @@ public Expression rewriteFormat(Format node, Void context, ExpressionTreeRewrite .addArgument(RowType.anonymous(argumentTypes.subList(1, arguments.size())), new Row(arguments.subList(1, arguments.size()))) .build(); } - } - private static boolean isConstant(Expression expression) - { - // Current IR has no way to represent typed constants. It encodes simple ones as Cast(Literal) - // This is the simplest possible check that - // 1) doesn't require ExpressionInterpreter.optimize(), which is not cheap - // 2) doesn't try to duplicate all the logic in LiteralEncoder - // 3) covers a sufficient portion of the use cases that occur in practice - // TODO: this should eventually be removed when IR includes types - if (expression instanceof Cast && ((Cast) expression).getExpression() instanceof Literal) { - return true; + private boolean isConstant(Expression expression) + { + // Current IR has no way to represent typed constants. It encodes simple ones as Cast(Literal) + // This is the simplest possible check that + // 1) doesn't require ExpressionInterpreter.optimize(), which is not cheap + // 2) doesn't try to duplicate all the logic in LiteralEncoder + // 3) covers a sufficient portion of the use cases that occur in practice + // TODO: this should eventually be removed when IR includes types + return isEffectivelyLiteral(plannerContext, session, expression); } - - return expression instanceof Literal; } } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/CanonicalizeExpressions.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/CanonicalizeExpressions.java index b2b52217adde..a1c2e25157ee 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/CanonicalizeExpressions.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/CanonicalizeExpressions.java @@ -13,7 +13,7 @@ */ package io.trino.sql.planner.iterative.rule; -import io.trino.metadata.Metadata; +import io.trino.sql.PlannerContext; import io.trino.sql.planner.TypeAnalyzer; import static io.trino.sql.planner.iterative.rule.CanonicalizeExpressionRewriter.rewrite; @@ -21,8 +21,8 @@ public class CanonicalizeExpressions extends ExpressionRewriteRuleSet { - public CanonicalizeExpressions(Metadata metadata, TypeAnalyzer typeAnalyzer) + public CanonicalizeExpressions(PlannerContext plannerContext, TypeAnalyzer typeAnalyzer) { - super((expression, context) -> rewrite(expression, context.getSession(), metadata, typeAnalyzer, context.getSymbolAllocator().getTypes())); + super((expression, context) -> rewrite(expression, context.getSession(), plannerContext, typeAnalyzer, context.getSymbolAllocator().getTypes())); } } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/EliminateCrossJoins.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/EliminateCrossJoins.java index 539fe3ac0d66..8623a66b5dd3 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/EliminateCrossJoins.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/EliminateCrossJoins.java @@ -20,7 +20,7 @@ import io.trino.Session; import io.trino.matching.Captures; import io.trino.matching.Pattern; -import io.trino.metadata.Metadata; +import io.trino.sql.PlannerContext; import io.trino.sql.planner.PlanNodeIdAllocator; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.TypeAnalyzer; @@ -55,12 +55,12 @@ public class EliminateCrossJoins implements Rule { private static final Pattern PATTERN = join(); - private final Metadata metadata; + private final PlannerContext plannerContext; private final TypeAnalyzer typeAnalyzer; - public EliminateCrossJoins(Metadata metadata, TypeAnalyzer typeAnalyzer) + public EliminateCrossJoins(PlannerContext plannerContext, TypeAnalyzer typeAnalyzer) { - this.metadata = metadata; + this.plannerContext = requireNonNull(plannerContext, "plannerContext is null"); this.typeAnalyzer = requireNonNull(typeAnalyzer, "typeAnalyzer is null"); } @@ -81,7 +81,7 @@ public boolean isEnabled(Session session) @Override public Result apply(JoinNode node, Captures captures, Context context) { - JoinGraph joinGraph = JoinGraph.buildFrom(metadata, node, context.getLookup(), context.getIdAllocator(), context.getSession(), typeAnalyzer, context.getSymbolAllocator().getTypes()); + JoinGraph joinGraph = JoinGraph.buildFrom(plannerContext, node, context.getLookup(), context.getIdAllocator(), context.getSession(), typeAnalyzer, context.getSymbolAllocator().getTypes()); if (joinGraph.size() < 3 || !joinGraph.isContainsCrossJoin()) { return Result.empty(); } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/InlineProjections.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/InlineProjections.java index fdd8b2675d47..b7991165f2da 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/InlineProjections.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/InlineProjections.java @@ -21,6 +21,7 @@ import io.trino.matching.Captures; import io.trino.matching.Pattern; import io.trino.spi.type.RowType; +import io.trino.sql.PlannerContext; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.SymbolsExtractor; import io.trino.sql.planner.TypeAnalyzer; @@ -30,7 +31,6 @@ import io.trino.sql.planner.plan.PlanNode; import io.trino.sql.planner.plan.ProjectNode; import io.trino.sql.tree.Expression; -import io.trino.sql.tree.Literal; import io.trino.sql.tree.SubscriptExpression; import io.trino.sql.tree.SymbolReference; import io.trino.sql.tree.TryExpression; @@ -43,6 +43,7 @@ import java.util.stream.Collectors; import static io.trino.matching.Capture.newCapture; +import static io.trino.sql.ExpressionUtils.isEffectivelyLiteral; import static io.trino.sql.planner.ExpressionSymbolInliner.inlineSymbols; import static io.trino.sql.planner.plan.Patterns.project; import static io.trino.sql.planner.plan.Patterns.source; @@ -62,10 +63,13 @@ public class InlineProjections private static final Pattern PATTERN = project() .with(source().matching(project().capturedAs(CHILD))); + + private final PlannerContext plannerContext; private final TypeAnalyzer typeAnalyzer; - public InlineProjections(TypeAnalyzer typeAnalyzer) + public InlineProjections(PlannerContext plannerContext, TypeAnalyzer typeAnalyzer) { + this.plannerContext = requireNonNull(plannerContext, "plannerContext is null"); this.typeAnalyzer = requireNonNull(typeAnalyzer, "typeAnalyzer is null"); } @@ -80,19 +84,19 @@ public Result apply(ProjectNode parent, Captures captures, Context context) { ProjectNode child = captures.get(CHILD); - return inlineProjections(parent, child, context.getSession(), typeAnalyzer, context.getSymbolAllocator().getTypes()) + return inlineProjections(plannerContext, parent, child, context.getSession(), typeAnalyzer, context.getSymbolAllocator().getTypes()) .map(Result::ofPlanNode) .orElse(Result.empty()); } - static Optional inlineProjections(ProjectNode parent, ProjectNode child, Session session, TypeAnalyzer typeAnalyzer, TypeProvider types) + static Optional inlineProjections(PlannerContext plannerContext, ProjectNode parent, ProjectNode child, Session session, TypeAnalyzer typeAnalyzer, TypeProvider types) { // squash identity projections if (parent.isIdentity() && child.isIdentity()) { return Optional.of((ProjectNode) parent.replaceChildren(ImmutableList.of(child.getSource()))); } - Set targets = extractInliningTargets(parent, child, session, typeAnalyzer, types); + Set targets = extractInliningTargets(plannerContext, parent, child, session, typeAnalyzer, types); if (targets.isEmpty()) { return Optional.empty(); } @@ -157,7 +161,7 @@ private static Expression inlineReferences(Expression expression, Assignments as return inlineSymbols(mapping, expression); } - private static Set extractInliningTargets(ProjectNode parent, ProjectNode child, Session session, TypeAnalyzer typeAnalyzer, TypeProvider types) + private static Set extractInliningTargets(PlannerContext plannerContext, ProjectNode parent, ProjectNode child, Session session, TypeAnalyzer typeAnalyzer, TypeProvider types) { // candidates for inlining are // 1. references to simple constants or symbol references @@ -177,7 +181,7 @@ private static Set extractInliningTargets(ProjectNode parent, ProjectNod // find references to simple constants or symbol references Set basicReferences = dependencies.keySet().stream() - .filter(input -> child.getAssignments().get(input) instanceof Literal || child.getAssignments().get(input) instanceof SymbolReference) + .filter(input -> isEffectivelyLiteral(plannerContext, session, child.getAssignments().get(input)) || child.getAssignments().get(input) instanceof SymbolReference) .filter(input -> !child.getAssignments().isIdentity(input)) // skip identities, otherwise, this rule will keep firing forever .collect(toSet()); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushProjectionThroughJoin.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushProjectionThroughJoin.java index 70dcec79dcbf..3545e08fd5ce 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushProjectionThroughJoin.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushProjectionThroughJoin.java @@ -16,7 +16,7 @@ import com.google.common.collect.ImmutableSet; import com.google.common.collect.Streams; import io.trino.Session; -import io.trino.metadata.Metadata; +import io.trino.sql.PlannerContext; import io.trino.sql.planner.PlanNodeIdAllocator; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.SymbolsExtractor; @@ -48,7 +48,7 @@ public final class PushProjectionThroughJoin { public static Optional pushProjectionThroughJoin( - Metadata metadata, + PlannerContext plannerContext, ProjectNode projectNode, Lookup lookup, PlanNodeIdAllocator planNodeIdAllocator, @@ -56,7 +56,7 @@ public static Optional pushProjectionThroughJoin( TypeAnalyzer typeAnalyzer, TypeProvider types) { - if (!projectNode.getAssignments().getExpressions().stream().allMatch(expression -> isDeterministic(expression, metadata))) { + if (!projectNode.getAssignments().getExpressions().stream().allMatch(expression -> isDeterministic(expression, plannerContext.getMetadata()))) { return Optional.empty(); } @@ -117,12 +117,14 @@ else if (rightChild.getOutputSymbols().containsAll(symbols)) { joinNode.getId(), joinNode.getType(), inlineProjections( + plannerContext, new ProjectNode(planNodeIdAllocator.getNextId(), leftChild, leftAssignments), lookup, session, typeAnalyzer, types), inlineProjections( + plannerContext, new ProjectNode(planNodeIdAllocator.getNextId(), rightChild, rightAssignments), lookup, session, @@ -141,7 +143,13 @@ else if (rightChild.getOutputSymbols().containsAll(symbols)) { joinNode.getReorderJoinStatsAndCost())); } - private static PlanNode inlineProjections(ProjectNode parentProjection, Lookup lookup, Session session, TypeAnalyzer typeAnalyzer, TypeProvider types) + private static PlanNode inlineProjections( + PlannerContext plannerContext, + ProjectNode parentProjection, + Lookup lookup, + Session session, + TypeAnalyzer typeAnalyzer, + TypeProvider types) { PlanNode child = lookup.resolve(parentProjection.getSource()); if (!(child instanceof ProjectNode)) { @@ -149,8 +157,8 @@ private static PlanNode inlineProjections(ProjectNode parentProjection, Lookup l } ProjectNode childProjection = (ProjectNode) child; - return InlineProjections.inlineProjections(parentProjection, childProjection, session, typeAnalyzer, types) - .map(node -> inlineProjections(node, lookup, session, typeAnalyzer, types)) + return InlineProjections.inlineProjections(plannerContext, parentProjection, childProjection, session, typeAnalyzer, types) + .map(node -> inlineProjections(plannerContext, node, lookup, session, typeAnalyzer, types)) .orElse(parentProjection); } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ReorderJoins.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ReorderJoins.java index 849dca596c59..741e13ff3bad 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ReorderJoins.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ReorderJoins.java @@ -33,6 +33,7 @@ import io.trino.matching.Captures; import io.trino.matching.Pattern; import io.trino.metadata.Metadata; +import io.trino.sql.PlannerContext; import io.trino.sql.planner.EqualityInference; import io.trino.sql.planner.PlanNodeIdAllocator; import io.trino.sql.planner.Symbol; @@ -103,17 +104,17 @@ public class ReorderJoins private final Pattern pattern; private final TypeAnalyzer typeAnalyzer; - private final Metadata metadata; + private final PlannerContext plannerContext; private final CostComparator costComparator; - public ReorderJoins(Metadata metadata, CostComparator costComparator, TypeAnalyzer typeAnalyzer) + public ReorderJoins(PlannerContext plannerContext, CostComparator costComparator, TypeAnalyzer typeAnalyzer) { - this.metadata = requireNonNull(metadata, "metadata is null"); + this.plannerContext = requireNonNull(plannerContext, "plannerContext is null"); this.costComparator = requireNonNull(costComparator, "costComparator is null"); this.pattern = join().matching( joinNode -> joinNode.getDistributionType().isEmpty() && joinNode.getType() == INNER - && isDeterministic(joinNode.getFilter().orElse(TRUE_LITERAL), metadata)); + && isDeterministic(joinNode.getFilter().orElse(TRUE_LITERAL), plannerContext.getMetadata())); this.typeAnalyzer = requireNonNull(typeAnalyzer, "typeAnalyzer is null"); } @@ -133,7 +134,7 @@ public boolean isEnabled(Session session) public Result apply(JoinNode joinNode, Captures captures, Context context) { // try reorder joins with projection pushdown first - MultiJoinNode multiJoinNode = toMultiJoinNode(metadata, joinNode, context, true, typeAnalyzer); + MultiJoinNode multiJoinNode = toMultiJoinNode(plannerContext, joinNode, context, true, typeAnalyzer); JoinEnumerationResult resultWithProjectionPushdown = chooseJoinOrder(multiJoinNode, context); if (resultWithProjectionPushdown.getPlanNode().isEmpty()) { return Result.empty(); @@ -144,7 +145,7 @@ public Result apply(JoinNode joinNode, Captures captures, Context context) } // try reorder joins without projection pushdown - multiJoinNode = toMultiJoinNode(metadata, joinNode, context, false, typeAnalyzer); + multiJoinNode = toMultiJoinNode(plannerContext, joinNode, context, false, typeAnalyzer); JoinEnumerationResult resultWithoutProjectionPushdown = chooseJoinOrder(multiJoinNode, context); if (resultWithoutProjectionPushdown.getPlanNode().isEmpty() || costComparator.compare(context.getSession(), resultWithProjectionPushdown.cost, resultWithoutProjectionPushdown.cost) < 0) { @@ -157,7 +158,7 @@ public Result apply(JoinNode joinNode, Captures captures, Context context) private JoinEnumerationResult chooseJoinOrder(MultiJoinNode multiJoinNode, Context context) { JoinEnumerator joinEnumerator = new JoinEnumerator( - metadata, + plannerContext.getMetadata(), costComparator, multiJoinNode.getFilter(), context); @@ -549,10 +550,10 @@ public boolean equals(Object obj) && this.pushedProjectionThroughJoin == other.pushedProjectionThroughJoin; } - static MultiJoinNode toMultiJoinNode(Metadata metadata, JoinNode joinNode, Context context, boolean pushProjectionsThroughJoin, TypeAnalyzer typeAnalyzer) + static MultiJoinNode toMultiJoinNode(PlannerContext plannerContext, JoinNode joinNode, Context context, boolean pushProjectionsThroughJoin, TypeAnalyzer typeAnalyzer) { return toMultiJoinNode( - metadata, + plannerContext, joinNode, context.getLookup(), context.getIdAllocator(), @@ -564,7 +565,7 @@ static MultiJoinNode toMultiJoinNode(Metadata metadata, JoinNode joinNode, Conte } static MultiJoinNode toMultiJoinNode( - Metadata metadata, + PlannerContext plannerContext, JoinNode joinNode, Lookup lookup, PlanNodeIdAllocator planNodeIdAllocator, @@ -575,13 +576,13 @@ static MultiJoinNode toMultiJoinNode( TypeProvider types) { // the number of sources is the number of joins + 1 - return new JoinNodeFlattener(metadata, joinNode, lookup, planNodeIdAllocator, joinLimit + 1, pushProjectionsThroughJoin, session, typeAnalyzer, types) + return new JoinNodeFlattener(plannerContext, joinNode, lookup, planNodeIdAllocator, joinLimit + 1, pushProjectionsThroughJoin, session, typeAnalyzer, types) .toMultiJoinNode(); } private static class JoinNodeFlattener { - private final Metadata metadata; + private final PlannerContext plannerContext; private final Session session; private final TypeAnalyzer typeAnalyzer; private final TypeProvider types; @@ -597,7 +598,7 @@ private static class JoinNodeFlattener private boolean pushedProjectionThroughJoin; JoinNodeFlattener( - Metadata metadata, + PlannerContext plannerContext, JoinNode node, Lookup lookup, PlanNodeIdAllocator planNodeIdAllocator, @@ -607,7 +608,7 @@ private static class JoinNodeFlattener TypeAnalyzer typeAnalyzer, TypeProvider types) { - this.metadata = requireNonNull(metadata, "metadata is null"); + this.plannerContext = requireNonNull(plannerContext, "plannerContext is null"); requireNonNull(node, "node is null"); checkState(node.getType() == INNER, "join type must be INNER"); this.outputSymbols = node.getOutputSymbols(); @@ -631,7 +632,7 @@ private void flattenNode(PlanNode node, int limit) return; } - Optional rewrittenNode = pushProjectionThroughJoin(metadata, (ProjectNode) resolved, lookup, planNodeIdAllocator, session, typeAnalyzer, types); + Optional rewrittenNode = pushProjectionThroughJoin(plannerContext, (ProjectNode) resolved, lookup, planNodeIdAllocator, session, typeAnalyzer, types); if (rewrittenNode.isEmpty()) { sources.add(node); return; @@ -649,7 +650,7 @@ private void flattenNode(PlanNode node, int limit) } JoinNode joinNode = (JoinNode) resolved; - if (joinNode.getType() != INNER || !isDeterministic(joinNode.getFilter().orElse(TRUE_LITERAL), metadata) || joinNode.getDistributionType().isPresent()) { + if (joinNode.getType() != INNER || !isDeterministic(joinNode.getFilter().orElse(TRUE_LITERAL), plannerContext.getMetadata()) || joinNode.getDistributionType().isPresent()) { sources.add(node); return; } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/SimplifyCountOverConstant.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/SimplifyCountOverConstant.java index 3e4c7aa8638d..d8a25f30c0b7 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/SimplifyCountOverConstant.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/SimplifyCountOverConstant.java @@ -14,20 +14,22 @@ package io.trino.sql.planner.iterative.rule; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import io.trino.Session; import io.trino.matching.Capture; import io.trino.matching.Captures; import io.trino.matching.Pattern; import io.trino.metadata.BoundSignature; -import io.trino.metadata.Metadata; import io.trino.metadata.ResolvedFunction; +import io.trino.security.AllowAllAccessControl; +import io.trino.sql.PlannerContext; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.Rule; import io.trino.sql.planner.plan.AggregationNode; import io.trino.sql.planner.plan.Assignments; import io.trino.sql.planner.plan.ProjectNode; import io.trino.sql.tree.Expression; -import io.trino.sql.tree.Literal; -import io.trino.sql.tree.NullLiteral; import io.trino.sql.tree.QualifiedName; import io.trino.sql.tree.SymbolReference; @@ -36,7 +38,10 @@ import java.util.Map.Entry; import java.util.Optional; +import static com.google.common.base.Verify.verify; import static io.trino.matching.Capture.newCapture; +import static io.trino.sql.ExpressionUtils.isEffectivelyLiteral; +import static io.trino.sql.planner.ExpressionInterpreter.evaluateConstantExpression; import static io.trino.sql.planner.plan.Patterns.aggregation; import static io.trino.sql.planner.plan.Patterns.project; import static io.trino.sql.planner.plan.Patterns.source; @@ -50,11 +55,11 @@ public class SimplifyCountOverConstant private static final Pattern PATTERN = aggregation() .with(source().matching(project().capturedAs(CHILD))); - private final Metadata metadata; + private final PlannerContext plannerContext; - public SimplifyCountOverConstant(Metadata metadata) + public SimplifyCountOverConstant(PlannerContext plannerContext) { - this.metadata = requireNonNull(metadata, "metadata is null"); + this.plannerContext = requireNonNull(plannerContext, "plannerContext is null"); } @Override @@ -71,13 +76,13 @@ public Result apply(AggregationNode parent, Captures captures, Context context) boolean changed = false; Map aggregations = new LinkedHashMap<>(parent.getAggregations()); - ResolvedFunction countFunction = metadata.resolveFunction(context.getSession(), QualifiedName.of("count"), ImmutableList.of()); + ResolvedFunction countFunction = plannerContext.getMetadata().resolveFunction(context.getSession(), QualifiedName.of("count"), ImmutableList.of()); for (Entry entry : parent.getAggregations().entrySet()) { Symbol symbol = entry.getKey(); AggregationNode.Aggregation aggregation = entry.getValue(); - if (isCountOverConstant(aggregation, child.getAssignments())) { + if (isCountOverConstant(context.getSession(), aggregation, child.getAssignments())) { changed = true; aggregations.put(symbol, new AggregationNode.Aggregation( countFunction, @@ -104,7 +109,7 @@ public Result apply(AggregationNode parent, Captures captures, Context context) parent.getGroupIdSymbol())); } - private static boolean isCountOverConstant(AggregationNode.Aggregation aggregation, Assignments inputs) + private boolean isCountOverConstant(Session session, AggregationNode.Aggregation aggregation, Assignments inputs) { BoundSignature signature = aggregation.getResolvedFunction().getSignature(); if (!signature.getName().equals("count") || signature.getArgumentTypes().size() != 1) { @@ -116,6 +121,20 @@ private static boolean isCountOverConstant(AggregationNode.Aggregation aggregati argument = inputs.get(Symbol.from(argument)); } - return argument instanceof Literal && !(argument instanceof NullLiteral); + if (isEffectivelyLiteral(plannerContext, session, argument)) { + Object value = evaluateConstantExpression( + argument, + ImmutableMap.of(), + ImmutableSet.of(), + plannerContext, + session, + new AllowAllAccessControl(), + ImmutableSet.of(), + ImmutableMap.of()); + verify(!(value instanceof Expression)); + return value != null; + } + + return false; } } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/AddLocalExchanges.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/AddLocalExchanges.java index 6be6de932a02..3192ad55cec9 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/AddLocalExchanges.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/AddLocalExchanges.java @@ -59,7 +59,6 @@ import io.trino.sql.planner.plan.TopNRankingNode; import io.trino.sql.planner.plan.UnionNode; import io.trino.sql.planner.plan.WindowNode; -import io.trino.sql.tree.Literal; import io.trino.sql.tree.SymbolReference; import java.util.ArrayList; @@ -77,6 +76,7 @@ import static io.trino.SystemSessionProperties.getTaskWriterCount; import static io.trino.SystemSessionProperties.isDistributedSortEnabled; import static io.trino.SystemSessionProperties.isSpillEnabled; +import static io.trino.sql.ExpressionUtils.isEffectivelyLiteral; import static io.trino.sql.planner.SystemPartitioningHandle.FIXED_ARBITRARY_DISTRIBUTION; import static io.trino.sql.planner.SystemPartitioningHandle.FIXED_HASH_DISTRIBUTION; import static io.trino.sql.planner.SystemPartitioningHandle.SINGLE_DISTRIBUTION; @@ -178,7 +178,7 @@ public PlanWithProperties visitProject(ProjectNode node, StreamPreferredProperti { // Special handling for trivial projections. Applies to identity and renaming projections, and constants // It might be extended to handle other low-cost projections. - if (node.getAssignments().getExpressions().stream().allMatch(expression -> expression instanceof SymbolReference || expression instanceof Literal)) { + if (node.getAssignments().getExpressions().stream().allMatch(expression -> expression instanceof SymbolReference || isEffectivelyLiteral(plannerContext, session, expression))) { if (parentPreferences.isSingleStreamPreferred()) { // Do not enforce gathering exchange below project: // - if project's source is single stream, no exchanges will be added around project, diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PredicatePushDown.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PredicatePushDown.java index 16341610f9ea..2d0b314294a6 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PredicatePushDown.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PredicatePushDown.java @@ -61,7 +61,6 @@ import io.trino.sql.tree.BooleanLiteral; import io.trino.sql.tree.ComparisonExpression; import io.trino.sql.tree.Expression; -import io.trino.sql.tree.Literal; import io.trino.sql.tree.NodeRef; import io.trino.sql.tree.NotExpression; import io.trino.sql.tree.NullLiteral; @@ -97,6 +96,7 @@ import static io.trino.sql.ExpressionUtils.combineConjuncts; import static io.trino.sql.ExpressionUtils.extractConjuncts; import static io.trino.sql.ExpressionUtils.filterDeterministicConjuncts; +import static io.trino.sql.ExpressionUtils.isEffectivelyLiteral; import static io.trino.sql.planner.DeterminismEvaluator.isDeterministic; import static io.trino.sql.planner.ExpressionSymbolInliner.inlineSymbols; import static io.trino.sql.planner.SymbolsExtractor.extractUnique; @@ -295,7 +295,7 @@ public PlanNode visitProject(ProjectNode node, RewriteContext contex List inlinedDeterministicConjuncts = inlineConjuncts.get(true).stream() .map(entry -> inlineSymbols(node.getAssignments().getMap(), entry)) - .map(conjunct -> canonicalizeExpression(conjunct, typeAnalyzer.getTypes(session, types, conjunct), metadata, session)) // normalize expressions to a form that unwrapCasts understands + .map(conjunct -> canonicalizeExpression(conjunct, typeAnalyzer.getTypes(session, types, conjunct), plannerContext, session)) // normalize expressions to a form that unwrapCasts understands .map(conjunct -> unwrapCasts(session, plannerContext, typeAnalyzer, types, conjunct)) .collect(Collectors.toList()); @@ -331,7 +331,7 @@ private boolean isInliningCandidate(Expression expression, ProjectNode node) return dependencies.entrySet().stream() .allMatch(entry -> entry.getValue() == 1 - || node.getAssignments().get(entry.getKey()) instanceof Literal + || isEffectivelyLiteral(plannerContext, session, node.getAssignments().get(entry.getKey())) || node.getAssignments().get(entry.getKey()) instanceof SymbolReference); } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/joins/JoinGraph.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/joins/JoinGraph.java index 044e0060f5d1..ff7db9fa64d9 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/joins/JoinGraph.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/joins/JoinGraph.java @@ -17,7 +17,7 @@ import com.google.common.collect.ImmutableMultimap; import com.google.common.collect.Multimap; import io.trino.Session; -import io.trino.metadata.Metadata; +import io.trino.sql.PlannerContext; import io.trino.sql.planner.PlanNodeIdAllocator; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.TypeAnalyzer; @@ -62,7 +62,7 @@ public class JoinGraph * Builds {@link JoinGraph} containing {@code plan} node. */ public static JoinGraph buildFrom( - Metadata metadata, + PlannerContext plannerContext, PlanNode plan, Lookup lookup, PlanNodeIdAllocator planNodeIdAllocator, @@ -70,7 +70,7 @@ public static JoinGraph buildFrom( TypeAnalyzer typeAnalyzer, TypeProvider types) { - return plan.accept(new Builder(metadata, lookup, planNodeIdAllocator, session, typeAnalyzer, types), new Context()); + return plan.accept(new Builder(plannerContext, lookup, planNodeIdAllocator, session, typeAnalyzer, types), new Context()); } public JoinGraph(PlanNode node) @@ -202,16 +202,16 @@ private JoinGraph joinWith(JoinGraph other, List joinCl private static class Builder extends PlanVisitor { - private final Metadata metadata; + private final PlannerContext plannerContext; private final Lookup lookup; private final PlanNodeIdAllocator planNodeIdAllocator; private final Session session; private final TypeAnalyzer typeAnalyzer; private final TypeProvider types; - private Builder(Metadata metadata, Lookup lookup, PlanNodeIdAllocator planNodeIdAllocator, Session session, TypeAnalyzer typeAnalyzer, TypeProvider types) + private Builder(PlannerContext plannerContext, Lookup lookup, PlanNodeIdAllocator planNodeIdAllocator, Session session, TypeAnalyzer typeAnalyzer, TypeProvider types) { - this.metadata = requireNonNull(metadata, "metadata is null"); + this.plannerContext = requireNonNull(plannerContext, "plannerContext is null"); this.lookup = requireNonNull(lookup, "lookup cannot be null"); this.planNodeIdAllocator = requireNonNull(planNodeIdAllocator, "planNodeIdAllocator is null"); this.session = requireNonNull(session, "session is null"); @@ -257,7 +257,7 @@ public JoinGraph visitJoin(JoinNode node, Context context) @Override public JoinGraph visitProject(ProjectNode node, Context context) { - Optional rewrittenNode = pushProjectionThroughJoin(metadata, node, lookup, planNodeIdAllocator, session, typeAnalyzer, types); + Optional rewrittenNode = pushProjectionThroughJoin(plannerContext, node, lookup, planNodeIdAllocator, session, typeAnalyzer, types); if (rewrittenNode.isPresent()) { return rewrittenNode.get().accept(this, context); } diff --git a/core/trino-main/src/main/java/io/trino/util/SpatialJoinUtils.java b/core/trino-main/src/main/java/io/trino/util/SpatialJoinUtils.java index 31aa83bf8a8a..3e748dfc8e5d 100644 --- a/core/trino-main/src/main/java/io/trino/util/SpatialJoinUtils.java +++ b/core/trino-main/src/main/java/io/trino/util/SpatialJoinUtils.java @@ -13,25 +13,15 @@ */ package io.trino.util; -import io.trino.sql.planner.Symbol; -import io.trino.sql.planner.plan.PlanNode; import io.trino.sql.tree.ComparisonExpression; import io.trino.sql.tree.Expression; import io.trino.sql.tree.FunctionCall; -import io.trino.sql.tree.Literal; -import io.trino.sql.tree.SymbolReference; -import java.util.Collection; import java.util.List; -import java.util.Set; -import static com.google.common.base.Verify.verify; import static com.google.common.collect.ImmutableList.toImmutableList; -import static com.google.common.collect.ImmutableSet.toImmutableSet; import static io.trino.metadata.ResolvedFunction.extractFunctionName; import static io.trino.sql.ExpressionUtils.extractConjuncts; -import static io.trino.sql.tree.ComparisonExpression.Operator.LESS_THAN; -import static io.trino.sql.tree.ComparisonExpression.Operator.LESS_THAN_OR_EQUAL; public final class SpatialJoinUtils { @@ -108,59 +98,4 @@ private static boolean isSTDistance(Expression expression) return false; } - - public static boolean isSpatialJoinFilter(PlanNode left, PlanNode right, Expression filterExpression) - { - List functionCalls = extractSupportedSpatialFunctions(filterExpression); - for (FunctionCall functionCall : functionCalls) { - if (isSpatialJoinFilter(left, right, functionCall)) { - return true; - } - } - - List spatialComparisons = extractSupportedSpatialComparisons(filterExpression); - for (ComparisonExpression spatialComparison : spatialComparisons) { - if (spatialComparison.getOperator() == LESS_THAN || spatialComparison.getOperator() == LESS_THAN_OR_EQUAL) { - // ST_Distance(a, b) <= r - Expression radius = spatialComparison.getRight(); - if (radius instanceof Literal || (radius instanceof SymbolReference && getSymbolReferences(right.getOutputSymbols()).contains(radius))) { - if (isSpatialJoinFilter(left, right, (FunctionCall) spatialComparison.getLeft())) { - return true; - } - } - } - } - - return false; - } - - private static boolean isSpatialJoinFilter(PlanNode left, PlanNode right, FunctionCall spatialFunction) - { - List arguments = spatialFunction.getArguments(); - verify(arguments.size() == 2); - if (!(arguments.get(0) instanceof SymbolReference) || !(arguments.get(1) instanceof SymbolReference)) { - return false; - } - - SymbolReference firstSymbol = (SymbolReference) arguments.get(0); - SymbolReference secondSymbol = (SymbolReference) arguments.get(1); - - Set probeSymbols = getSymbolReferences(left.getOutputSymbols()); - Set buildSymbols = getSymbolReferences(right.getOutputSymbols()); - - if (probeSymbols.contains(firstSymbol) && buildSymbols.contains(secondSymbol)) { - return true; - } - - if (probeSymbols.contains(secondSymbol) && buildSymbols.contains(firstSymbol)) { - return true; - } - - return false; - } - - private static Set getSymbolReferences(Collection symbols) - { - return symbols.stream().map(Symbol::toSymbolReference).collect(toImmutableSet()); - } } diff --git a/core/trino-main/src/test/java/io/trino/sql/ExpressionTestUtils.java b/core/trino-main/src/test/java/io/trino/sql/ExpressionTestUtils.java index c4d2af59be9b..c84920114415 100644 --- a/core/trino-main/src/test/java/io/trino/sql/ExpressionTestUtils.java +++ b/core/trino-main/src/test/java/io/trino/sql/ExpressionTestUtils.java @@ -107,7 +107,7 @@ private static Expression planExpressionInExistingTx(PlannerContext plannerConte Expression rewritten = rewriteIdentifiersToSymbolReferences(expression); rewritten = DesugarLikeRewriter.rewrite(rewritten, transactionSession, plannerContext.getMetadata(), createTestingTypeAnalyzer(plannerContext), typeProvider); rewritten = DesugarArrayConstructorRewriter.rewrite(rewritten, transactionSession, plannerContext.getMetadata(), createTestingTypeAnalyzer(plannerContext), typeProvider); - rewritten = CanonicalizeExpressionRewriter.rewrite(rewritten, transactionSession, plannerContext.getMetadata(), createTestingTypeAnalyzer(plannerContext), typeProvider); + rewritten = CanonicalizeExpressionRewriter.rewrite(rewritten, transactionSession, plannerContext, createTestingTypeAnalyzer(plannerContext), typeProvider); return resolveFunctionCalls(plannerContext, transactionSession, typeProvider, rewritten); } diff --git a/core/trino-main/src/test/java/io/trino/sql/TestExpressionInterpreter.java b/core/trino-main/src/test/java/io/trino/sql/TestExpressionInterpreter.java index ec079a58ba78..4cc5422d7c5b 100644 --- a/core/trino-main/src/test/java/io/trino/sql/TestExpressionInterpreter.java +++ b/core/trino-main/src/test/java/io/trino/sql/TestExpressionInterpreter.java @@ -1939,7 +1939,7 @@ static Expression planExpression(@Language("SQL") String expression) parsedExpression = CanonicalizeExpressionRewriter.rewrite( parsedExpression, transactionSession, - PLANNER_CONTEXT.getMetadata(), + PLANNER_CONTEXT, createTestingTypeAnalyzer(PLANNER_CONTEXT), SYMBOL_TYPES); return parsedExpression; diff --git a/core/trino-main/src/test/java/io/trino/sql/TestSqlToRowExpressionTranslator.java b/core/trino-main/src/test/java/io/trino/sql/TestSqlToRowExpressionTranslator.java index 4224fa04531a..7208d2309ccc 100644 --- a/core/trino-main/src/test/java/io/trino/sql/TestSqlToRowExpressionTranslator.java +++ b/core/trino-main/src/test/java/io/trino/sql/TestSqlToRowExpressionTranslator.java @@ -14,12 +14,8 @@ package io.trino.sql; import com.google.common.collect.ImmutableMap; -import io.trino.execution.warnings.WarningCollector; -import io.trino.security.AllowAllAccessControl; import io.trino.spi.type.Decimals; import io.trino.spi.type.Type; -import io.trino.sql.analyzer.ExpressionAnalyzer; -import io.trino.sql.analyzer.Scope; import io.trino.sql.planner.ExpressionInterpreter; import io.trino.sql.planner.LiteralEncoder; import io.trino.sql.planner.NoOpSymbolResolver; @@ -42,7 +38,6 @@ import static io.trino.sql.planner.iterative.rule.test.PlanBuilder.expression; import static io.trino.sql.relational.Expressions.constant; import static io.trino.testing.assertions.Assert.assertEquals; -import static java.util.Collections.emptyMap; public class TestSqlToRowExpressionTranslator { @@ -105,16 +100,6 @@ private Expression simplifyExpression(Expression expression) private Map, Type> getExpressionTypes(Expression expression) { - ExpressionAnalyzer expressionAnalyzer = ExpressionAnalyzer.createWithoutSubqueries( - PLANNER_CONTEXT, - new AllowAllAccessControl(), - TEST_SESSION, - TypeProvider.empty(), - emptyMap(), - node -> new IllegalStateException("Unexpected node: " + node), - WarningCollector.NOOP, - false); - expressionAnalyzer.analyze(expression, Scope.create()); - return expressionAnalyzer.getExpressionTypes(); + return ExpressionUtils.getExpressionTypes(PLANNER_CONTEXT, TEST_SESSION, expression, TypeProvider.empty()); } } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestLiteralEncoder.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestLiteralEncoder.java index 048fd379caf7..e57d9e258e60 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestLiteralEncoder.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestLiteralEncoder.java @@ -17,7 +17,6 @@ import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import io.airlift.slice.Slice; -import io.trino.execution.warnings.WarningCollector; import io.trino.metadata.BoundSignature; import io.trino.metadata.FunctionNullability; import io.trino.metadata.LiteralFunction; @@ -31,8 +30,7 @@ import io.trino.spi.type.Type; import io.trino.spi.type.TypeSignature; import io.trino.spi.type.VarcharType; -import io.trino.sql.analyzer.ExpressionAnalyzer; -import io.trino.sql.analyzer.Scope; +import io.trino.sql.ExpressionUtils; import io.trino.sql.tree.Expression; import io.trino.sql.tree.NodeRef; import io.trino.transaction.TestingTransactionManager; @@ -65,6 +63,7 @@ import static io.trino.spi.type.VarbinaryType.VARBINARY; import static io.trino.spi.type.VarcharType.VARCHAR; import static io.trino.spi.type.VarcharType.createVarcharType; +import static io.trino.sql.ExpressionUtils.isEffectivelyLiteral; import static io.trino.sql.SqlFormatter.formatSql; import static io.trino.sql.planner.TestingPlannerContext.PLANNER_CONTEXT; import static io.trino.transaction.TransactionBuilder.transaction; @@ -76,7 +75,6 @@ import static io.trino.type.UnknownType.UNKNOWN; import static java.lang.String.format; import static java.nio.charset.StandardCharsets.UTF_8; -import static java.util.Collections.emptyMap; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertTrue; @@ -269,6 +267,7 @@ private void assertEncode(Object value, Type type, String expected) private void assertEncodeCaseInsensitively(Object value, Type type, String expected) { Expression expression = encoder.toExpression(TEST_SESSION, value, type); + assertTrue(isEffectivelyLiteral(PLANNER_CONTEXT, TEST_SESSION, expression), "isEffectivelyLiteral returned false for: " + expression); assertEquals(getExpressionType(expression), type); assertEquals(getExpressionValue(expression), value); assertEqualsIgnoreCase(formatSql(expression), expected); @@ -277,6 +276,7 @@ private void assertEncodeCaseInsensitively(Object value, Type type, String expec private void assertRoundTrip(T value, Type type, BiPredicate predicate) { Expression expression = encoder.toExpression(TEST_SESSION, value, type); + assertTrue(isEffectivelyLiteral(PLANNER_CONTEXT, TEST_SESSION, expression), "isEffectivelyLiteral returned false for: " + expression); assertEquals(getExpressionType(expression), type); @SuppressWarnings("unchecked") T decodedValue = (T) getExpressionValue(expression); @@ -301,18 +301,7 @@ private Map, Type> getExpressionTypes(Expression expression) return transaction(new TestingTransactionManager(), new AllowAllAccessControl()) .singleStatement() .execute(TEST_SESSION, transactionSession -> { - ExpressionAnalyzer expressionAnalyzer = ExpressionAnalyzer.createWithoutSubqueries( - PLANNER_CONTEXT, - new AllowAllAccessControl(), - transactionSession, - TypeProvider.empty(), - emptyMap(), - node -> new IllegalStateException("Unexpected node: " + node), - WarningCollector.NOOP, - false); - expressionAnalyzer.analyze(expression, Scope.create()); - Map, Type> expressionTypes = expressionAnalyzer.getExpressionTypes(); - return expressionTypes; + return ExpressionUtils.getExpressionTypes(PLANNER_CONTEXT, transactionSession, expression, TypeProvider.empty()); }); } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestLogicalPlanner.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestLogicalPlanner.java index 56766fe4d1a2..0e367fc0a154 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestLogicalPlanner.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestLogicalPlanner.java @@ -995,7 +995,7 @@ public void testPruneCountAggregationOverScalar() } @Test - public void testInlineCountOverConstantExpression() + public void testInlineCountOverLiteral() { assertPlan( "SELECT regionkey, count(1) FROM nation GROUP BY regionkey", @@ -1006,6 +1006,18 @@ public void testInlineCountOverConstantExpression() tableScan("nation", ImmutableMap.of("regionkey", "regionkey"))))); } + @Test + public void testInlineCountOverEffectivelyLiteral() + { + assertPlan( + "SELECT regionkey, count(CAST(DECIMAL '1' AS decimal(8,4))) FROM nation GROUP BY regionkey", + anyTree( + aggregation( + ImmutableMap.of("count_0", functionCall("count", ImmutableList.of())), + PARTIAL, + tableScan("nation", ImmutableMap.of("regionkey", "regionkey"))))); + } + @Test public void testPickTableLayoutWithFilter() { diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/assertions/PlanMatchPattern.java b/core/trino-main/src/test/java/io/trino/sql/planner/assertions/PlanMatchPattern.java index 574bcc68f85a..8ca3d1223740 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/assertions/PlanMatchPattern.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/assertions/PlanMatchPattern.java @@ -72,6 +72,7 @@ import io.trino.sql.tree.Row; import io.trino.sql.tree.SortItem; import io.trino.sql.tree.WindowFrame; +import org.intellij.lang.annotations.Language; import java.util.ArrayList; import java.util.Collection; @@ -1015,7 +1016,7 @@ public static RvalueMatcher columnReference(String tableName, String columnName) return new ColumnReference(tableName, columnName); } - public static ExpressionMatcher expression(String expression) + public static ExpressionMatcher expression(@Language("SQL") String expression) { return new ExpressionMatcher(expression); } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestCanonicalizeExpressionRewriter.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestCanonicalizeExpressionRewriter.java index 7090947fe6f9..317b01a109aa 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestCanonicalizeExpressionRewriter.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestCanonicalizeExpressionRewriter.java @@ -127,7 +127,7 @@ private static void assertRewritten(String from, String to) return rewrite( PlanBuilder.expression(from), transactedSession, - PLANNER_CONTEXT.getMetadata(), + PLANNER_CONTEXT, TYPE_ANALYZER, TypeProvider.copyOf(ImmutableMap.builder() .put(new Symbol("x"), BIGINT) diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestCanonicalizeExpressions.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestCanonicalizeExpressions.java index f66d4c3487b6..aa0bcd876289 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestCanonicalizeExpressions.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestCanonicalizeExpressions.java @@ -25,7 +25,7 @@ public class TestCanonicalizeExpressions @Test public void testDoesNotFireForExpressionsInCanonicalForm() { - CanonicalizeExpressions canonicalizeExpressions = new CanonicalizeExpressions(tester().getMetadata(), tester().getTypeAnalyzer()); + CanonicalizeExpressions canonicalizeExpressions = new CanonicalizeExpressions(tester().getPlannerContext(), tester().getTypeAnalyzer()); tester().assertThat(canonicalizeExpressions.filterExpressionRewrite()) .on(p -> p.filter(FALSE_LITERAL, p.values())) .doesNotFire(); @@ -34,7 +34,7 @@ public void testDoesNotFireForExpressionsInCanonicalForm() @Test public void testDoesNotFireForUnfilteredJoin() { - CanonicalizeExpressions canonicalizeExpressions = new CanonicalizeExpressions(tester().getMetadata(), tester().getTypeAnalyzer()); + CanonicalizeExpressions canonicalizeExpressions = new CanonicalizeExpressions(tester().getPlannerContext(), tester().getTypeAnalyzer()); tester().assertThat(canonicalizeExpressions.joinExpressionRewrite()) .on(p -> p.join(INNER, p.values(), p.values())) .doesNotFire(); @@ -43,7 +43,7 @@ public void testDoesNotFireForUnfilteredJoin() @Test public void testDoesNotFireForCanonicalExpressions() { - CanonicalizeExpressions canonicalizeExpressions = new CanonicalizeExpressions(tester().getMetadata(), tester().getTypeAnalyzer()); + CanonicalizeExpressions canonicalizeExpressions = new CanonicalizeExpressions(tester().getPlannerContext(), tester().getTypeAnalyzer()); tester().assertThat(canonicalizeExpressions.joinExpressionRewrite()) .on(p -> p.join(INNER, p.values(), p.values(), FALSE_LITERAL)) .doesNotFire(); diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestEliminateCrossJoins.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestEliminateCrossJoins.java index df2d120f1bb4..132a8c02902a 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestEliminateCrossJoins.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestEliminateCrossJoins.java @@ -69,7 +69,7 @@ public class TestEliminateCrossJoins @Test public void testEliminateCrossJoin() { - tester().assertThat(new EliminateCrossJoins(tester().getMetadata(), tester().getTypeAnalyzer())) + tester().assertThat(new EliminateCrossJoins(tester().getPlannerContext(), tester().getTypeAnalyzer())) .setSystemProperty(JOIN_REORDERING_STRATEGY, "ELIMINATE_CROSS_JOINS") .on(crossJoinAndJoin(INNER)) .matches( @@ -85,7 +85,7 @@ public void testEliminateCrossJoin() @Test public void testRetainOutgoingGroupReferences() { - tester().assertThat(new EliminateCrossJoins(tester().getMetadata(), tester().getTypeAnalyzer())) + tester().assertThat(new EliminateCrossJoins(tester().getPlannerContext(), tester().getTypeAnalyzer())) .setSystemProperty(JOIN_REORDERING_STRATEGY, "ELIMINATE_CROSS_JOINS") .on(crossJoinAndJoin(INNER)) .matches( @@ -99,7 +99,7 @@ public void testRetainOutgoingGroupReferences() @Test public void testDoNotReorderOuterJoin() { - tester().assertThat(new EliminateCrossJoins(tester().getMetadata(), tester().getTypeAnalyzer())) + tester().assertThat(new EliminateCrossJoins(tester().getPlannerContext(), tester().getTypeAnalyzer())) .setSystemProperty(JOIN_REORDERING_STRATEGY, "ELIMINATE_CROSS_JOINS") .on(crossJoinAndJoin(JoinNode.Type.LEFT)) .doesNotFire(); @@ -126,7 +126,7 @@ public void testJoinOrder() "a", "c", "b", "c"); - JoinGraph joinGraph = JoinGraph.buildFrom(tester().getMetadata(), plan, noLookup(), new PlanNodeIdAllocator(), session, createTestingTypeAnalyzer(tester().getPlannerContext()), TypeProvider.empty()); + JoinGraph joinGraph = JoinGraph.buildFrom(tester().getPlannerContext(), plan, noLookup(), new PlanNodeIdAllocator(), session, createTestingTypeAnalyzer(tester().getPlannerContext()), TypeProvider.empty()); assertEquals( getJoinOrder(joinGraph), @@ -158,7 +158,7 @@ public void testJoinOrderWithRealCrossJoin() PlanNode plan = joinNode(leftPlan, rightPlan); - JoinGraph joinGraph = JoinGraph.buildFrom(tester().getMetadata(), plan, noLookup(), new PlanNodeIdAllocator(), session, createTestingTypeAnalyzer(tester().getPlannerContext()), TypeProvider.empty()); + JoinGraph joinGraph = JoinGraph.buildFrom(tester().getPlannerContext(), plan, noLookup(), new PlanNodeIdAllocator(), session, createTestingTypeAnalyzer(tester().getPlannerContext()), TypeProvider.empty()); assertEquals( getJoinOrder(joinGraph), @@ -180,7 +180,7 @@ public void testJoinOrderWithMultipleEdgesBetweenNodes() "b1", "c1", "b2", "c2"); - JoinGraph joinGraph = JoinGraph.buildFrom(tester().getMetadata(), plan, noLookup(), new PlanNodeIdAllocator(), session, createTestingTypeAnalyzer(tester().getPlannerContext()), TypeProvider.empty()); + JoinGraph joinGraph = JoinGraph.buildFrom(tester().getPlannerContext(), plan, noLookup(), new PlanNodeIdAllocator(), session, createTestingTypeAnalyzer(tester().getPlannerContext()), TypeProvider.empty()); assertEquals( getJoinOrder(joinGraph), @@ -201,7 +201,7 @@ public void testDoesNotChangeOrderWithoutCrossJoin() values("c"), "b", "c"); - JoinGraph joinGraph = JoinGraph.buildFrom(tester().getMetadata(), plan, noLookup(), new PlanNodeIdAllocator(), session, createTestingTypeAnalyzer(tester().getPlannerContext()), TypeProvider.empty()); + JoinGraph joinGraph = JoinGraph.buildFrom(tester().getPlannerContext(), plan, noLookup(), new PlanNodeIdAllocator(), session, createTestingTypeAnalyzer(tester().getPlannerContext()), TypeProvider.empty()); assertEquals( getJoinOrder(joinGraph), @@ -221,7 +221,7 @@ public void testDoNotReorderCrossJoins() values("c"), "b", "c"); - JoinGraph joinGraph = JoinGraph.buildFrom(tester().getMetadata(), plan, noLookup(), new PlanNodeIdAllocator(), session, createTestingTypeAnalyzer(tester().getPlannerContext()), TypeProvider.empty()); + JoinGraph joinGraph = JoinGraph.buildFrom(tester().getPlannerContext(), plan, noLookup(), new PlanNodeIdAllocator(), session, createTestingTypeAnalyzer(tester().getPlannerContext()), TypeProvider.empty()); assertEquals( getJoinOrder(joinGraph), @@ -231,7 +231,7 @@ public void testDoNotReorderCrossJoins() @Test public void testEliminateCrossJoinWithNonIdentityProjections() { - tester().assertThat(new EliminateCrossJoins(tester().getMetadata(), tester().getTypeAnalyzer())) + tester().assertThat(new EliminateCrossJoins(tester().getPlannerContext(), tester().getTypeAnalyzer())) .setSystemProperty(JOIN_REORDERING_STRATEGY, "ELIMINATE_CROSS_JOINS") .on(p -> { Symbol a1 = p.symbol("a1"); @@ -308,7 +308,7 @@ public void testGiveUpOnComplexProjections() "a2", "c", "b", "c"); - assertEquals(JoinGraph.buildFrom(tester().getMetadata(), plan, noLookup(), new PlanNodeIdAllocator(), session, createTestingTypeAnalyzer(tester().getPlannerContext()), TypeProvider.empty()).size(), 2); + assertEquals(JoinGraph.buildFrom(tester().getPlannerContext(), plan, noLookup(), new PlanNodeIdAllocator(), session, createTestingTypeAnalyzer(tester().getPlannerContext()), TypeProvider.empty()).size(), 2); } private Function crossJoinAndJoin(JoinNode.Type secondJoinType) diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestInlineProjections.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestInlineProjections.java index 4fef8b707819..76779bc9a242 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestInlineProjections.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestInlineProjections.java @@ -16,14 +16,18 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import io.trino.spi.type.RowType; +import io.trino.sql.planner.LiteralEncoder; import io.trino.sql.planner.assertions.ExpressionMatcher; import io.trino.sql.planner.assertions.PlanMatchPattern; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; import io.trino.sql.planner.plan.Assignments; +import io.trino.sql.tree.Literal; import org.testng.annotations.Test; +import java.util.Map; import java.util.Optional; +import static io.trino.spi.type.DecimalType.createDecimalType; import static io.trino.spi.type.VarcharType.VARCHAR; import static io.trino.sql.planner.assertions.PlanMatchPattern.project; import static io.trino.sql.planner.assertions.PlanMatchPattern.values; @@ -37,7 +41,7 @@ public class TestInlineProjections @Test public void test() { - tester().assertThat(new InlineProjections(tester().getTypeAnalyzer())) + tester().assertThat(new InlineProjections(tester().getPlannerContext(), tester().getTypeAnalyzer())) .on(p -> p.project( Assignments.builder() @@ -83,10 +87,38 @@ public void test() values(ImmutableMap.of("x", 0, "msg", 1))))); } + /** + * Verify that non-{@link Literal} but literal-like constant expression gets inlined. + * + * @implNote The test uses decimals, as decimals values do not have direct literal form (see {@link LiteralEncoder}). + */ + @Test + public void testInlineEffectivelyLiteral() + { + tester().assertThat(new InlineProjections(tester().getPlannerContext(), tester().getTypeAnalyzer())) + .on(p -> + p.project( + Assignments.builder() + // Use the literal-like expression multiple times. Single-use expression may be inlined regardless of whether it's a literal + .put(p.symbol("decimal_multiplication"), expression("decimal_literal * decimal_literal")) + .put(p.symbol("decimal_addition"), expression("decimal_literal + decimal_literal")) + .build(), + p.project(Assignments.builder() + .put(p.symbol("decimal_literal", createDecimalType(8, 4)), expression("CAST(DECIMAL '12.5' AS decimal(8,4))")) + .build(), + p.values(p.symbol("x"))))) + .matches( + project( + Map.of( + "decimal_multiplication", PlanMatchPattern.expression("CAST(DECIMAL '12.5' AS decimal(8, 4)) * CAST(DECIMAL '12.5' AS decimal(8, 4))"), + "decimal_addition", PlanMatchPattern.expression("CAST(DECIMAL '12.5' AS decimal(8, 4)) + CAST(DECIMAL '12.5' AS decimal(8, 4))")), + values(Map.of("x", 0)))); + } + @Test public void testEliminatesIdentityProjection() { - tester().assertThat(new InlineProjections(tester().getTypeAnalyzer())) + tester().assertThat(new InlineProjections(tester().getPlannerContext(), tester().getTypeAnalyzer())) .on(p -> p.project( Assignments.builder() @@ -108,7 +140,7 @@ public void testEliminatesIdentityProjection() public void testIdentityProjections() { // projection renaming symbol - tester().assertThat(new InlineProjections(tester().getTypeAnalyzer())) + tester().assertThat(new InlineProjections(tester().getPlannerContext(), tester().getTypeAnalyzer())) .on(p -> p.project( Assignments.of(p.symbol("output"), expression("value")), @@ -118,7 +150,7 @@ public void testIdentityProjections() .doesNotFire(); // identity projection - tester().assertThat(new InlineProjections(tester().getTypeAnalyzer())) + tester().assertThat(new InlineProjections(tester().getPlannerContext(), tester().getTypeAnalyzer())) .on(p -> p.project( Assignments.identity(p.symbol("x")), @@ -134,7 +166,7 @@ public void testIdentityProjections() @Test public void testSubqueryProjections() { - tester().assertThat(new InlineProjections(tester().getTypeAnalyzer())) + tester().assertThat(new InlineProjections(tester().getPlannerContext(), tester().getTypeAnalyzer())) .on(p -> p.project( Assignments.identity(p.symbol("fromOuterScope"), p.symbol("value")), @@ -147,7 +179,7 @@ public void testSubqueryProjections() // ImmutableMap.of("fromOuterScope", PlanMatchPattern.expression("fromOuterScope"), "value", PlanMatchPattern.expression("value")), values(ImmutableMap.of("value", 0)))); - tester().assertThat(new InlineProjections(tester().getTypeAnalyzer())) + tester().assertThat(new InlineProjections(tester().getPlannerContext(), tester().getTypeAnalyzer())) .on(p -> p.project( Assignments.identity(p.symbol("fromOuterScope"), p.symbol("value_1")), diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestJoinNodeFlattener.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestJoinNodeFlattener.java index 994ed2765a16..f657882e8367 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestJoinNodeFlattener.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestJoinNodeFlattener.java @@ -103,7 +103,7 @@ public void testDoesNotAllowOuterJoin() ImmutableList.of(a1), ImmutableList.of(b1), Optional.empty()); - assertThatThrownBy(() -> toMultiJoinNode(queryRunner.getMetadata(), outerJoin, noLookup(), planNodeIdAllocator, DEFAULT_JOIN_LIMIT, false, testSessionBuilder().build(), createTestingTypeAnalyzer(queryRunner.getPlannerContext()), p.getTypes())) + assertThatThrownBy(() -> toMultiJoinNode(queryRunner.getPlannerContext(), outerJoin, noLookup(), planNodeIdAllocator, DEFAULT_JOIN_LIMIT, false, testSessionBuilder().build(), createTestingTypeAnalyzer(queryRunner.getPlannerContext()), p.getTypes())) .isInstanceOf(IllegalStateException.class) .hasMessageMatching("join type must be.*"); } @@ -139,7 +139,7 @@ public void testDoesNotConvertNestedOuterJoins() .setOutputSymbols(a1, b1, c1) .build(); assertEquals( - toMultiJoinNode(queryRunner.getMetadata(), joinNode, noLookup(), planNodeIdAllocator, DEFAULT_JOIN_LIMIT, false, testSessionBuilder().build(), createTestingTypeAnalyzer(queryRunner.getPlannerContext()), p.getTypes()), + toMultiJoinNode(queryRunner.getPlannerContext(), joinNode, noLookup(), planNodeIdAllocator, DEFAULT_JOIN_LIMIT, false, testSessionBuilder().build(), createTestingTypeAnalyzer(queryRunner.getPlannerContext()), p.getTypes()), expected); } @@ -166,7 +166,7 @@ public void testPushesProjectionsThroughJoin() equiJoinClause(a, b))), valuesC, equiJoinClause(d, c)); - MultiJoinNode actual = toMultiJoinNode(queryRunner.getMetadata(), joinNode, noLookup(), planNodeIdAllocator, DEFAULT_JOIN_LIMIT, true, testSessionBuilder().build(), createTestingTypeAnalyzer(queryRunner.getPlannerContext()), p.getTypes()); + MultiJoinNode actual = toMultiJoinNode(queryRunner.getPlannerContext(), joinNode, noLookup(), planNodeIdAllocator, DEFAULT_JOIN_LIMIT, true, testSessionBuilder().build(), createTestingTypeAnalyzer(queryRunner.getPlannerContext()), p.getTypes()); assertEquals(actual.getOutputSymbols(), ImmutableList.of(d, c)); assertEquals(actual.getFilter(), and(createEqualsExpression(a, b), createEqualsExpression(d, c))); assertTrue(actual.isPushedProjectionThroughJoin()); @@ -212,7 +212,7 @@ public void testDoesNotPushStraddlingProjection() equiJoinClause(a, b))), valuesC, equiJoinClause(d, c)); - MultiJoinNode actual = toMultiJoinNode(queryRunner.getMetadata(), joinNode, noLookup(), planNodeIdAllocator, DEFAULT_JOIN_LIMIT, true, testSessionBuilder().build(), createTestingTypeAnalyzer(queryRunner.getPlannerContext()), p.getTypes()); + MultiJoinNode actual = toMultiJoinNode(queryRunner.getPlannerContext(), joinNode, noLookup(), planNodeIdAllocator, DEFAULT_JOIN_LIMIT, true, testSessionBuilder().build(), createTestingTypeAnalyzer(queryRunner.getPlannerContext()), p.getTypes()); assertEquals(actual.getOutputSymbols(), ImmutableList.of(d, c)); assertEquals(actual.getFilter(), createEqualsExpression(d, c)); assertFalse(actual.isPushedProjectionThroughJoin()); @@ -264,7 +264,7 @@ public void testRetainsOutputSymbols() .setOutputSymbols(a1, b1) .build(); assertEquals( - toMultiJoinNode(queryRunner.getMetadata(), joinNode, noLookup(), planNodeIdAllocator, DEFAULT_JOIN_LIMIT, false, testSessionBuilder().build(), createTestingTypeAnalyzer(queryRunner.getPlannerContext()), p.getTypes()), + toMultiJoinNode(queryRunner.getPlannerContext(), joinNode, noLookup(), planNodeIdAllocator, DEFAULT_JOIN_LIMIT, false, testSessionBuilder().build(), createTestingTypeAnalyzer(queryRunner.getPlannerContext()), p.getTypes()), expected); } @@ -310,7 +310,7 @@ public void testCombinesCriteriaAndFilters() ImmutableList.of(a1, b1, b2, c1, c2), false); assertEquals( - toMultiJoinNode(queryRunner.getMetadata(), joinNode, noLookup(), planNodeIdAllocator, DEFAULT_JOIN_LIMIT, false, testSessionBuilder().build(), createTestingTypeAnalyzer(queryRunner.getPlannerContext()), p.getTypes()), + toMultiJoinNode(queryRunner.getPlannerContext(), joinNode, noLookup(), planNodeIdAllocator, DEFAULT_JOIN_LIMIT, false, testSessionBuilder().build(), createTestingTypeAnalyzer(queryRunner.getPlannerContext()), p.getTypes()), expected); } @@ -368,7 +368,7 @@ public void testConvertsBushyTrees() .setOutputSymbols(a1, b1, c1, d1, d2, e1, e2) .build(); assertEquals( - toMultiJoinNode(queryRunner.getMetadata(), joinNode, noLookup(), planNodeIdAllocator, 5, false, testSessionBuilder().build(), createTestingTypeAnalyzer(queryRunner.getPlannerContext()), p.getTypes()), + toMultiJoinNode(queryRunner.getPlannerContext(), joinNode, noLookup(), planNodeIdAllocator, 5, false, testSessionBuilder().build(), createTestingTypeAnalyzer(queryRunner.getPlannerContext()), p.getTypes()), expected); } @@ -428,7 +428,7 @@ public void testMoreThanJoinLimit() .setOutputSymbols(a1, b1, c1, d1, d2, e1, e2) .build(); assertEquals( - toMultiJoinNode(queryRunner.getMetadata(), joinNode, noLookup(), planNodeIdAllocator, 2, false, testSessionBuilder().build(), createTestingTypeAnalyzer(queryRunner.getPlannerContext()), p.getTypes()), + toMultiJoinNode(queryRunner.getPlannerContext(), joinNode, noLookup(), planNodeIdAllocator, 2, false, testSessionBuilder().build(), createTestingTypeAnalyzer(queryRunner.getPlannerContext()), p.getTypes()), expected); } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushProjectionThroughJoin.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushProjectionThroughJoin.java index 431916aef21c..9665f5df3a7d 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushProjectionThroughJoin.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushProjectionThroughJoin.java @@ -88,7 +88,7 @@ a2, new ArithmeticUnaryExpression(PLUS, a0.toSymbolReference()), new JoinNode.EquiJoinClause(a1, b1))); Session session = testSessionBuilder().build(); - Optional rewritten = pushProjectionThroughJoin(PLANNER_CONTEXT.getMetadata(), planNode, noLookup(), idAllocator, session, createTestingTypeAnalyzer( + Optional rewritten = pushProjectionThroughJoin(PLANNER_CONTEXT, planNode, noLookup(), idAllocator, session, createTestingTypeAnalyzer( PLANNER_CONTEXT), p.getTypes()); assertTrue(rewritten.isPresent()); assertPlan( @@ -128,7 +128,7 @@ c, new ArithmeticBinaryExpression(ADD, a.toSymbolReference(), b.toSymbolReferenc INNER, p.values(a), p.values(b))); - Optional rewritten = pushProjectionThroughJoin(PLANNER_CONTEXT.getMetadata(), planNode, noLookup(), new PlanNodeIdAllocator(), testSessionBuilder().build(), createTestingTypeAnalyzer( + Optional rewritten = pushProjectionThroughJoin(PLANNER_CONTEXT, planNode, noLookup(), new PlanNodeIdAllocator(), testSessionBuilder().build(), createTestingTypeAnalyzer( PLANNER_CONTEXT), p.getTypes()); assertThat(rewritten).isEmpty(); } @@ -148,7 +148,7 @@ c, new ArithmeticUnaryExpression(MINUS, a.toSymbolReference())), LEFT, p.values(a), p.values(b))); - Optional rewritten = pushProjectionThroughJoin(PLANNER_CONTEXT.getMetadata(), planNode, noLookup(), new PlanNodeIdAllocator(), testSessionBuilder().build(), createTestingTypeAnalyzer( + Optional rewritten = pushProjectionThroughJoin(PLANNER_CONTEXT, planNode, noLookup(), new PlanNodeIdAllocator(), testSessionBuilder().build(), createTestingTypeAnalyzer( PLANNER_CONTEXT), p.getTypes()); assertThat(rewritten).isEmpty(); } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestReorderJoins.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestReorderJoins.java index 6d5f9796a4b9..4d62c80abc24 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestReorderJoins.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestReorderJoins.java @@ -679,6 +679,6 @@ public void testReorderAndReplicate() private RuleAssert assertReorderJoins() { - return tester.assertThat(new ReorderJoins(PLANNER_CONTEXT.getMetadata(), new CostComparator(1, 1, 1), createTestingTypeAnalyzer(PLANNER_CONTEXT))); + return tester.assertThat(new ReorderJoins(PLANNER_CONTEXT, new CostComparator(1, 1, 1), createTestingTypeAnalyzer(PLANNER_CONTEXT))); } } diff --git a/core/trino-main/src/test/java/io/trino/type/AbstractTestType.java b/core/trino-main/src/test/java/io/trino/type/AbstractTestType.java index 23714198804d..e9bf67b8097b 100644 --- a/core/trino-main/src/test/java/io/trino/type/AbstractTestType.java +++ b/core/trino-main/src/test/java/io/trino/type/AbstractTestType.java @@ -18,6 +18,14 @@ import io.airlift.slice.Slice; import io.airlift.slice.SliceOutput; import io.airlift.slice.Slices; +import io.trino.FeaturesConfig; +import io.trino.client.NodeVersion; +import io.trino.metadata.BlockEncodingManager; +import io.trino.metadata.CatalogManager; +import io.trino.metadata.DisabledSystemSecurityMetadata; +import io.trino.metadata.InternalBlockEncodingSerde; +import io.trino.metadata.MetadataManager; +import io.trino.metadata.TypeRegistry; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; import io.trino.spi.block.BlockEncodingSerde; @@ -26,7 +34,11 @@ import io.trino.spi.type.MapType; import io.trino.spi.type.RowType; import io.trino.spi.type.Type; +import io.trino.spi.type.TypeManager; import io.trino.spi.type.TypeOperators; +import io.trino.sql.PlannerContext; +import io.trino.sql.planner.LiteralEncoder; +import io.trino.sql.tree.Expression; import io.trino.type.BlockTypeOperators.BlockPositionEqual; import io.trino.type.BlockTypeOperators.BlockPositionHashCode; import io.trino.type.BlockTypeOperators.BlockPositionIsDistinctFrom; @@ -41,6 +53,7 @@ import static com.google.common.base.Preconditions.checkState; import static io.airlift.testing.Assertions.assertInstanceOf; +import static io.trino.SessionTestUtils.TEST_SESSION; import static io.trino.block.BlockSerdeUtil.writeBlock; import static io.trino.operator.OperatorAssertion.toRow; import static io.trino.spi.connector.SortOrder.ASC_NULLS_FIRST; @@ -52,7 +65,10 @@ import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.NULLABLE_RETURN; import static io.trino.spi.function.InvocationConvention.simpleConvention; +import static io.trino.spi.type.TypeUtils.readNativeValue; +import static io.trino.sql.ExpressionUtils.isEffectivelyLiteral; import static io.trino.testing.TestingConnectorSession.SESSION; +import static io.trino.transaction.InMemoryTransactionManager.createTestTransactionManager; import static io.trino.util.StructuralTestUtil.arrayBlockOf; import static io.trino.util.StructuralTestUtil.mapBlockOf; import static java.lang.String.format; @@ -62,6 +78,7 @@ import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertFalse; import static org.testng.Assert.assertTrue; +import static org.testng.Assert.fail; public abstract class AbstractTestType { @@ -140,6 +157,48 @@ else if (type.getJavaType() == Slice.class) { return nullsBlockBuilder.build(); } + @Test + public void testLiteralFormRecognized() + { + PlannerContext plannerContext = createPlannerContext(); + LiteralEncoder literalEncoder = new LiteralEncoder(plannerContext); + for (int position = 0; position < testBlock.getPositionCount(); position++) { + Object value = readNativeValue(type, testBlock, position); + Expression expression = literalEncoder.toExpression(TEST_SESSION, value, type); + if (!isEffectivelyLiteral(plannerContext, TEST_SESSION, expression)) { + fail(format( + "Expression not recognized literal for value %s at position %s (%s): %s", + value, + position, + type.getObjectValue(SESSION, testBlock, position), + expression)); + } + } + } + + protected PlannerContext createPlannerContext() + { + TypeRegistry typeRegistry = new TypeRegistry(new TypeOperators(), new FeaturesConfig()); + typeRegistry.addType(type); + + TypeManager typeManager = new InternalTypeManager(typeRegistry); + TypeOperators typeOperators = new TypeOperators(); + MetadataManager metadata = new MetadataManager( + new FeaturesConfig(), + new DisabledSystemSecurityMetadata(), + createTestTransactionManager(new CatalogManager()), + typeOperators, + new BlockTypeOperators(typeOperators), + typeManager, + new InternalBlockEncodingSerde(new BlockEncodingManager(), typeManager), + NodeVersion.UNKNOWN); + return new PlannerContext( + metadata, + new TypeOperators(), + new InternalBlockEncodingSerde(new BlockEncodingManager(), typeManager), + typeManager); + } + @Test public void testBlock() { diff --git a/plugin/trino-geospatial/src/test/java/io/trino/plugin/geospatial/TestSpatialJoinPlanning.java b/plugin/trino-geospatial/src/test/java/io/trino/plugin/geospatial/TestSpatialJoinPlanning.java index a19def637a72..63510c5709bc 100644 --- a/plugin/trino-geospatial/src/test/java/io/trino/plugin/geospatial/TestSpatialJoinPlanning.java +++ b/plugin/trino-geospatial/src/test/java/io/trino/plugin/geospatial/TestSpatialJoinPlanning.java @@ -38,10 +38,12 @@ import java.util.Base64; import java.util.Optional; +import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Strings.nullToEmpty; import static io.trino.SystemSessionProperties.SPATIAL_PARTITIONING_TABLE_NAME; import static io.trino.geospatial.KdbTree.Node.newLeaf; import static io.trino.metadata.LiteralFunction.LITERAL_FUNCTION_NAME; +import static io.trino.plugin.geospatial.GeoFunctions.stPoint; import static io.trino.spi.StandardErrorCode.INVALID_SPATIAL_PARTITIONING; import static io.trino.spi.predicate.Utils.nativeValueToBlock; import static io.trino.sql.planner.LogicalPlanner.Stage.OPTIMIZED_AND_VALIDATED; @@ -58,6 +60,8 @@ import static io.trino.sql.planner.assertions.PlanMatchPattern.tableScan; import static io.trino.sql.planner.assertions.PlanMatchPattern.unnest; import static io.trino.testing.TestingSession.testSessionBuilder; +import static java.lang.Math.cos; +import static java.lang.Math.toRadians; import static java.lang.String.format; import static java.util.Collections.emptyList; import static org.testng.Assert.assertEquals; @@ -69,6 +73,7 @@ public class TestSpatialJoinPlanning private static final String KDB_TREE_JSON = KdbTreeUtils.toJson(new KdbTree(newLeaf(new Rectangle(0, 0, 10, 10), 0))); private String kdbTreeLiteral; + private String point21x21Literal; @Override protected LocalQueryRunner createLocalQueryRunner() @@ -93,6 +98,7 @@ public void setUp() DynamicSliceOutput output = new DynamicSliceOutput(0); BlockSerdeUtil.writeBlock(new TestingBlockEncodingSerde(), output, block); kdbTreeLiteral = format("\"%s\"(from_base64('%s'))", LITERAL_FUNCTION_NAME, Base64.getEncoder().encodeToString(output.slice().getBytes())); + point21x21Literal = format("\"%s\"(from_base64('%s'))", LITERAL_FUNCTION_NAME, Base64.getEncoder().encodeToString(stPoint(2.1, 2.1).getBytes())); } @Test @@ -301,30 +307,26 @@ public void testDistanceQuery() "WHERE ST_Distance(ST_Point(a.lng, a.lat), ST_Point(b.lng, b.lat)) <= 3.1", anyTree( spatialJoin("st_distance(st_point_a, st_point_b) <= radius", - project(ImmutableMap.of("st_point_a", expression("ST_Point(cast(a_lng as double), cast(a_lng as double))"), "a_name", expression("'x'")), - project( - ImmutableMap.of("a_lng", expression("CAST(DECIMAL '2.1' AS decimal(2, 1))")), - singleRow())), + project( + ImmutableMap.of("st_point_a", expression(point21x21Literal), "a_name", expression("'x'")), + singleRow()), any( - project(ImmutableMap.of("st_point_b", expression("ST_Point(cast(b_lng as double), cast(b_lng as double))"), "radius", expression("3.1e0"), "b_name", expression("'x'")), - project( - ImmutableMap.of("b_lng", expression("CAST(DECIMAL '2.1' AS decimal(2, 1))")), - singleRow())))))); + project( + ImmutableMap.of("st_point_b", expression(point21x21Literal), "radius", expression("3.1e0"), "b_name", expression("'x'")), + singleRow()))))); assertPlan("SELECT b.name, a.name " + "FROM " + singleRow("2.1", "2.1", "'x'") + " AS a (lng, lat, name), " + singleRow("2.1", "2.1", "'x'") + " AS b (lng, lat, name) " + "WHERE ST_Distance(ST_Point(a.lng, a.lat), ST_Point(b.lng, b.lat)) <= 300 / (cos(radians(b.lat)) * 111321)", anyTree( spatialJoin("st_distance(st_point_a, st_point_b) <= radius", - project(ImmutableMap.of("st_point_a", expression("ST_Point(cast(a_lng as double), cast(a_lng as double))"), "a_name", expression("'x'")), - project( - ImmutableMap.of("a_lng", expression("CAST(DECIMAL '2.1' AS decimal(2, 1))")), - singleRow())), + project( + ImmutableMap.of("st_point_a", expression(point21x21Literal), "a_name", expression("'x'")), + singleRow()), any( - project(ImmutableMap.of("st_point_b", expression("ST_Point(cast(b_lng as double), cast(b_lng as double))"), "radius", expression("3e2 / (cos(radians(cast(b_lng as double))) * 111.321e3)"), "b_name", expression("'x'")), - project( - ImmutableMap.of("b_lng", expression("CAST(DECIMAL '2.1' AS decimal(2, 1))")), - singleRow())))))); + project( + ImmutableMap.of("st_point_b", expression(point21x21Literal), "radius", expression(doubleLiteral(3e2 / (cos(toRadians(2.1)) * 111.321e3))), "b_name", expression("'x'")), + singleRow()))))); // distributed assertDistributedPlan("SELECT b.name, a.name " + @@ -335,18 +337,19 @@ public void testDistanceQuery() spatialJoin("st_distance(st_point_a, st_point_b) <= radius", Optional.of(KDB_TREE_JSON), anyTree( unnest( - project(ImmutableMap.of("partitions", expression(format("spatial_partitions(%s, st_point_a)", kdbTreeLiteral))), - project(ImmutableMap.of("st_point_a", expression("ST_Point(cast(a_lng as double), cast(a_lng as double))")), - project( - ImmutableMap.of("a_lng", expression("CAST(DECIMAL '2.1' AS decimal(2, 1))")), - singleRow()))))), + project( + ImmutableMap.of( + "st_point_a", expression(point21x21Literal), + "partitions", expression(format("spatial_partitions(%s, %s)", kdbTreeLiteral, point21x21Literal))), + singleRow()))), anyTree( unnest( - project(ImmutableMap.of("partitions", expression(format("spatial_partitions(%s, st_point_b, 3.1e0)", kdbTreeLiteral)), "radius", expression("3.1e0")), - project(ImmutableMap.of("st_point_b", expression("ST_Point(cast(b_lng as double), cast(b_lng as double))")), - project( - ImmutableMap.of("b_lng", expression("CAST(DECIMAL '2.1' AS decimal(2, 1))")), - singleRow())))))))); + project( + ImmutableMap.of( + "st_point_b", expression(point21x21Literal), + "partitions", expression(format("spatial_partitions(%s, %s, 3.1e0)", kdbTreeLiteral, point21x21Literal)), + "radius", expression("3.1e0")), + singleRow())))))); } @Test @@ -538,4 +541,10 @@ private Session withSpatialPartitioning(String tableName) .setSystemProperty(SPATIAL_PARTITIONING_TABLE_NAME, tableName) .build(); } + + private static String doubleLiteral(double value) + { + checkArgument(Double.isFinite(value)); + return format("%.16E", value); + } } diff --git a/testing/trino-testing/src/main/java/io/trino/testing/AbstractTestQueries.java b/testing/trino-testing/src/main/java/io/trino/testing/AbstractTestQueries.java index aa83778df633..9cf81c302136 100644 --- a/testing/trino-testing/src/main/java/io/trino/testing/AbstractTestQueries.java +++ b/testing/trino-testing/src/main/java/io/trino/testing/AbstractTestQueries.java @@ -225,7 +225,8 @@ public void testCountColumn() assertQuery("SELECT COUNT(1) FROM orders"); assertQuery("SELECT COUNT(NULLIF(orderstatus, 'F')) FROM orders"); - assertQuery("SELECT COUNT(CAST(NULL AS BIGINT)) FROM orders"); // todo: make COUNT(null) work + assertQuery("SELECT COUNT(NULL) FROM orders", "VALUES 0"); + assertQuery("SELECT COUNT(CAST(NULL AS BIGINT)) FROM orders", "VALUES 0"); } @Test