diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/HiveSessionProperties.java b/presto-hive/src/main/java/com/facebook/presto/hive/HiveSessionProperties.java index 740cad9318733..cb493055f09fc 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/HiveSessionProperties.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/HiveSessionProperties.java @@ -88,7 +88,7 @@ public final class HiveSessionProperties public static final String COLLECT_COLUMN_STATISTICS_ON_WRITE = "collect_column_statistics_on_write"; private static final String OPTIMIZE_MISMATCHED_BUCKET_COUNT = "optimize_mismatched_bucket_count"; private static final String S3_SELECT_PUSHDOWN_ENABLED = "s3_select_pushdown_enabled"; - private static final String SHUFFLE_PARTITIONED_COLUMNS_FOR_TABLE_WRITE = "shuffle_partitioned_columns_for_table_write"; + public static final String SHUFFLE_PARTITIONED_COLUMNS_FOR_TABLE_WRITE = "shuffle_partitioned_columns_for_table_write"; private static final String TEMPORARY_STAGING_DIRECTORY_ENABLED = "temporary_staging_directory_enabled"; private static final String TEMPORARY_STAGING_DIRECTORY_PATH = "temporary_staging_directory_path"; private static final String TEMPORARY_TABLE_SCHEMA = "temporary_table_schema"; diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveLogicalPlanner.java b/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveLogicalPlanner.java index 51bbbd73e253f..52780f236687a 100644 --- a/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveLogicalPlanner.java +++ b/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveLogicalPlanner.java @@ -58,6 +58,7 @@ import java.util.Set; import static com.facebook.presto.SystemSessionProperties.JOIN_REORDERING_STRATEGY; +import static com.facebook.presto.SystemSessionProperties.OPTIMIZE_METADATA_QUERIES; import static com.facebook.presto.common.function.OperatorType.EQUAL; import static com.facebook.presto.common.predicate.Domain.multipleValues; import static com.facebook.presto.common.predicate.Domain.notNull; @@ -73,10 +74,12 @@ import static com.facebook.presto.hive.HiveSessionProperties.COLLECT_COLUMN_STATISTICS_ON_WRITE; import static com.facebook.presto.hive.HiveSessionProperties.PUSHDOWN_FILTER_ENABLED; import static com.facebook.presto.hive.HiveSessionProperties.RANGE_FILTERS_ON_SUBSCRIPTS_ENABLED; +import static com.facebook.presto.hive.HiveSessionProperties.SHUFFLE_PARTITIONED_COLUMNS_FOR_TABLE_WRITE; import static com.facebook.presto.hive.TestHiveIntegrationSmokeTest.assertRemoteExchangesCount; import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes; import static com.facebook.presto.sql.planner.assertions.MatchResult.NO_MATCH; import static com.facebook.presto.sql.planner.assertions.MatchResult.match; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.any; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.anyTree; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.equiJoinClause; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.exchange; @@ -277,6 +280,91 @@ public void testPartitionPruning() } } + @Test + public void testMetadataAggregationFolding() + { + QueryRunner queryRunner = getQueryRunner(); + Session optimizeMetadataQueries = Session.builder(this.getQueryRunner().getDefaultSession()) + .setSystemProperty(OPTIMIZE_METADATA_QUERIES, Boolean.toString(true)) + .build(); + Session shufflePartitionColumns = Session.builder(this.getQueryRunner().getDefaultSession()) + .setCatalogSessionProperty(HIVE_CATALOG, SHUFFLE_PARTITIONED_COLUMNS_FOR_TABLE_WRITE, Boolean.toString(true)) + .build(); + + queryRunner.execute( + shufflePartitionColumns, + "CREATE TABLE test_metadata_aggregation_folding WITH (partitioned_by = ARRAY['ds']) AS " + + "SELECT orderkey, CAST(to_iso8601(date_add('DAY', orderkey % 7, date('2020-07-01'))) AS VARCHAR) AS ds FROM orders WHERE orderkey < 1000"); + queryRunner.execute( + shufflePartitionColumns, + "CREATE TABLE test_metadata_aggregation_folding_more_partitions WITH (partitioned_by = ARRAY['ds']) AS " + + "SELECT orderkey, CAST(to_iso8601(date_add('DAY', orderkey % 200, date('2020-07-01'))) AS VARCHAR) AS ds FROM orders WHERE orderkey < 1000"); + queryRunner.execute( + shufflePartitionColumns, + "CREATE TABLE test_metadata_aggregation_folding_null_partitions WITH (partitioned_by = ARRAY['ds']) AS " + + "SELECT orderkey, CAST(to_iso8601(date_add('DAY', orderkey % 7, date('2020-07-01'))) AS VARCHAR) AS ds FROM orders WHERE orderkey < 1000"); + queryRunner.execute( + shufflePartitionColumns, + "INSERT INTO test_metadata_aggregation_folding_null_partitions SELECT 0 as orderkey, null AS ds"); + + try { + assertPlan( + optimizeMetadataQueries, + "SELECT * FROM test_metadata_aggregation_folding WHERE ds = (SELECT max(ds) from test_metadata_aggregation_folding)", + anyTree( + join(INNER, ImmutableList.of(), + tableScan("test_metadata_aggregation_folding", getSingleValueColumnDomain("ds", "2020-07-07"), TRUE_CONSTANT, ImmutableSet.of("ds")), + anyTree(any())))); + assertPlan( + optimizeMetadataQueries, + "SELECT * FROM test_metadata_aggregation_folding WHERE ds = (SELECT min(ds) from test_metadata_aggregation_folding)", + anyTree( + join(INNER, ImmutableList.of(), + tableScan("test_metadata_aggregation_folding", getSingleValueColumnDomain("ds", "2020-07-01"), TRUE_CONSTANT, ImmutableSet.of("ds")), + anyTree(any())))); + + assertPlan( + optimizeMetadataQueries, + "SELECT * FROM test_metadata_aggregation_folding_more_partitions WHERE ds = (SELECT max(ds) from test_metadata_aggregation_folding_more_partitions)", + anyTree( + join(INNER, ImmutableList.of(), + tableScan("test_metadata_aggregation_folding_more_partitions", getSingleValueColumnDomain("ds", "2021-01-16"), TRUE_CONSTANT, ImmutableSet.of("ds")), + anyTree(any())))); + assertPlan( + optimizeMetadataQueries, + "SELECT * FROM test_metadata_aggregation_folding_more_partitions WHERE ds = (SELECT min(ds) from test_metadata_aggregation_folding_more_partitions)", + anyTree( + join(INNER, ImmutableList.of(), + tableScan("test_metadata_aggregation_folding_more_partitions", getSingleValueColumnDomain("ds", "2020-07-01"), TRUE_CONSTANT, ImmutableSet.of("ds")), + anyTree(any())))); + + assertPlan( + optimizeMetadataQueries, + "SELECT * FROM test_metadata_aggregation_folding WHERE ds = (SELECT max(ds) from test_metadata_aggregation_folding_null_partitions)", + anyTree( + join(INNER, ImmutableList.of(), + tableScan("test_metadata_aggregation_folding", getSingleValueColumnDomain("ds", "2020-07-07"), TRUE_CONSTANT, ImmutableSet.of("ds")), + anyTree(any())))); + assertPlan( + optimizeMetadataQueries, + "SELECT * FROM test_metadata_aggregation_folding WHERE ds = (SELECT min(ds) from test_metadata_aggregation_folding_null_partitions)", + anyTree( + join(INNER, ImmutableList.of(), + tableScan("test_metadata_aggregation_folding", getSingleValueColumnDomain("ds", "2020-07-01"), TRUE_CONSTANT, ImmutableSet.of("ds")), + anyTree(any())))); + } + finally { + queryRunner.execute("DROP TABLE IF EXISTS test_metadata_aggregation_folding"); + queryRunner.execute("DROP TABLE IF EXISTS test_metadata_aggregation_folding_more_partitions"); + queryRunner.execute("DROP TABLE IF EXISTS test_metadata_aggregation_folding_null_partitions"); + } + } + + private static TupleDomain getSingleValueColumnDomain(String column, String value) + { + return withColumnDomains(ImmutableMap.of(column, singleValue(VARCHAR, utf8Slice(value)))); + } + private static List utf8Slices(String... values) { return Arrays.stream(values).map(Slices::utf8Slice).collect(toImmutableList()); diff --git a/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java b/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java index eef609a427ccc..b6bb10def5f22 100644 --- a/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java +++ b/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java @@ -450,7 +450,7 @@ public SystemSessionProperties( Duration::toString), booleanProperty( OPTIMIZE_METADATA_QUERIES, - "Enable optimization for metadata queries", + "Enable optimization for metadata queries. Note if metadata entry has empty data, the result might be different (e.g. empty Hive partition)", featuresConfig.isOptimizeMetadataQueries(), false), integerProperty( diff --git a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java index ed438f0c46a93..d171915996b17 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java @@ -624,6 +624,7 @@ public boolean isOptimizeMetadataQueries() } @Config("optimizer.optimize-metadata-queries") + @ConfigDescription("Enable optimization for metadata queries. Note if metadata entry has empty data, the result might be different (e.g. empty Hive partition)") public FeaturesConfig setOptimizeMetadataQueries(boolean optimizeMetadataQueries) { this.optimizeMetadataQueries = optimizeMetadataQueries; diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/EffectivePredicateExtractor.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/EffectivePredicateExtractor.java index 929e9ad431bba..00b87f68ef6a9 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/EffectivePredicateExtractor.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/EffectivePredicateExtractor.java @@ -26,6 +26,7 @@ import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.optimizations.JoinNodeUtils; import com.facebook.presto.sql.planner.plan.AssignUniqueId; +import com.facebook.presto.sql.planner.plan.EnforceSingleRowNode; import com.facebook.presto.sql.planner.plan.ExchangeNode; import com.facebook.presto.sql.planner.plan.InternalPlanVisitor; import com.facebook.presto.sql.planner.plan.JoinNode; @@ -160,6 +161,15 @@ public Expression visitExchange(ExchangeNode node, Void context) }); } + @Override + public Expression visitEnforceSingleRow(EnforceSingleRowNode node, Void context) + { + if (node.getSource() instanceof ProjectNode) { + return node.getSource().accept(this, context); + } + return TRUE_LITERAL; + } + @Override public Expression visitProject(ProjectNode node, Void context) { diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java index 9590645171e20..251d41f0dbc17 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java @@ -410,20 +410,7 @@ public PlanOptimizers( ruleStats, statsCalculator, estimatedExchangesCostCalculator, - ImmutableSet.of(new RemoveRedundantIdentityProjections())), - new MetadataQueryOptimizer(metadata), - new IterativeOptimizer( - ruleStats, - statsCalculator, - estimatedExchangesCostCalculator, - ImmutableSet.of(new EliminateCrossJoins())), // This can pull up Filter and Project nodes from between Joins, so we need to push them down again - predicatePushDown, - simplifyOptimizer, // Should be always run after PredicatePushDown - new IterativeOptimizer( - ruleStats, - statsCalculator, - estimatedExchangesCostCalculator, - new PickTableLayout(metadata, sqlParser).rules())); + ImmutableSet.of(new RemoveRedundantIdentityProjections()))); // TODO: move this before optimization if possible!! // Replace all expressions with row expressions @@ -434,6 +421,24 @@ public PlanOptimizers( new TranslateExpressions(metadata, sqlParser).rules())); // After this point, all planNodes should not contain OriginalExpression + builder.add(new MetadataQueryOptimizer(metadata)); + + // This can pull up Filter and Project nodes from between Joins, so we need to push them down again + builder.add( + new IterativeOptimizer( + ruleStats, + statsCalculator, + estimatedExchangesCostCalculator, + ImmutableSet.of(new EliminateCrossJoins())), + rowExpressionPredicatePushDown, + simplifyRowExpressionOptimizer); // Should always run simplifyOptimizer after rowExpressionPredicatePushDown + + builder.add(new IterativeOptimizer( + ruleStats, + statsCalculator, + estimatedExchangesCostCalculator, + new PickTableLayout(metadata, sqlParser).rules())); + // PlanRemoteProjections only handles RowExpression so this need to run after TranslateExpressions // Rules applied after this need to handle locality of ProjectNode properly. builder.add(new IterativeOptimizer( diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/RowExpressionPredicateExtractor.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/RowExpressionPredicateExtractor.java index 19bb7df1037ee..2ef97d6aa5f27 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/RowExpressionPredicateExtractor.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/RowExpressionPredicateExtractor.java @@ -31,6 +31,7 @@ import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.plan.AssignUniqueId; +import com.facebook.presto.sql.planner.plan.EnforceSingleRowNode; import com.facebook.presto.sql.planner.plan.ExchangeNode; import com.facebook.presto.sql.planner.plan.InternalPlanVisitor; import com.facebook.presto.sql.planner.plan.JoinNode; @@ -156,6 +157,15 @@ public RowExpression visitExchange(ExchangeNode node, Void context) }); } + @Override + public RowExpression visitEnforceSingleRow(EnforceSingleRowNode node, Void context) + { + if (node.getSource() instanceof ProjectNode) { + return node.getSource().accept(this, context); + } + return TRUE_CONSTANT; + } + @Override public RowExpression visitProject(ProjectNode node, Void context) { diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/EliminateCrossJoins.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/EliminateCrossJoins.java index 7d134559a1a46..0ab153b9c6e7b 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/EliminateCrossJoins.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/EliminateCrossJoins.java @@ -22,13 +22,12 @@ import com.facebook.presto.spi.plan.PlanNodeId; import com.facebook.presto.spi.plan.PlanNodeIdAllocator; import com.facebook.presto.spi.plan.ProjectNode; +import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.analyzer.FeaturesConfig.JoinReorderingStrategy; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.optimizations.joins.JoinGraph; import com.facebook.presto.sql.planner.plan.JoinNode; -import com.facebook.presto.sql.relational.OriginalExpressionUtils; -import com.facebook.presto.sql.tree.Expression; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; @@ -45,11 +44,9 @@ import static com.facebook.presto.sql.analyzer.FeaturesConfig.JoinReorderingStrategy.ELIMINATE_CROSS_JOINS; import static com.facebook.presto.sql.planner.iterative.rule.Util.restrictOutputs; import static com.facebook.presto.sql.planner.plan.Patterns.join; -import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToRowExpression; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; import static com.google.common.collect.ImmutableList.toImmutableList; -import static com.google.common.collect.Maps.transformValues; import static java.util.Comparator.comparing; import static java.util.Objects.requireNonNull; @@ -192,24 +189,24 @@ public static PlanNode buildJoinTree(List expectedO Optional.empty()); } - List filters = graph.getFilters(); + List filters = graph.getFilters(); - for (Expression filter : filters) { + for (RowExpression filter : filters) { result = new FilterNode( idAllocator.getNextId(), result, - castToRowExpression(filter)); + filter); } if (graph.getAssignments().isPresent()) { result = new ProjectNode( idAllocator.getNextId(), result, - Assignments.copyOf(transformValues(graph.getAssignments().get(), OriginalExpressionUtils::castToRowExpression))); + Assignments.copyOf(graph.getAssignments().get())); } // If needed, introduce a projection to constrain the outputs to what was originally expected // Some nodes are sensitive to what's produced (e.g., DistinctLimit node) - return restrictOutputs(idAllocator, result, ImmutableSet.copyOf(expectedOutputVariables), false).orElse(result); + return restrictOutputs(idAllocator, result, ImmutableSet.copyOf(expectedOutputVariables), true).orElse(result); } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/MetadataQueryOptimizer.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/MetadataQueryOptimizer.java index 1472bfe0a7c6d..b4d7fda80a166 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/MetadataQueryOptimizer.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/MetadataQueryOptimizer.java @@ -18,14 +18,18 @@ import com.facebook.presto.common.function.QualifiedFunctionName; import com.facebook.presto.common.predicate.NullableValue; import com.facebook.presto.common.predicate.TupleDomain; +import com.facebook.presto.common.type.Type; import com.facebook.presto.execution.warnings.WarningCollector; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.metadata.TableLayout; import com.facebook.presto.spi.ColumnHandle; +import com.facebook.presto.spi.ConnectorSession; import com.facebook.presto.spi.Constraint; import com.facebook.presto.spi.DiscretePredicates; +import com.facebook.presto.spi.function.FunctionMetadata; import com.facebook.presto.spi.plan.AggregationNode; import com.facebook.presto.spi.plan.AggregationNode.Aggregation; +import com.facebook.presto.spi.plan.Assignments; import com.facebook.presto.spi.plan.FilterNode; import com.facebook.presto.spi.plan.LimitNode; import com.facebook.presto.spi.plan.MarkDistinctNode; @@ -37,27 +41,29 @@ import com.facebook.presto.spi.plan.ValuesNode; import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.spi.relation.VariableReferenceExpression; -import com.facebook.presto.sql.planner.ExpressionDeterminismEvaluator; -import com.facebook.presto.sql.planner.LiteralEncoder; import com.facebook.presto.sql.planner.PlanVariableAllocator; import com.facebook.presto.sql.planner.TypeProvider; import com.facebook.presto.sql.planner.plan.SimplePlanRewriter; import com.facebook.presto.sql.planner.plan.SortNode; -import com.facebook.presto.sql.relational.OriginalExpressionUtils; +import com.facebook.presto.sql.relational.RowExpressionDeterminismEvaluator; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Iterables; +import com.google.common.collect.Lists; +import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.Optional; import java.util.Set; import static com.facebook.presto.metadata.BuiltInFunctionNamespaceManager.DEFAULT_NAMESPACE; +import static com.facebook.presto.sql.planner.RowExpressionInterpreter.evaluateConstantRowExpression; +import static com.facebook.presto.sql.relational.Expressions.call; import static com.facebook.presto.sql.relational.Expressions.constant; +import static com.google.common.collect.Iterables.getOnlyElement; import static java.util.Objects.requireNonNull; -import static java.util.stream.Collectors.toList; /** * Converts cardinality-insensitive aggregations (max, min, "distinct") over partition keys @@ -71,15 +77,18 @@ public class MetadataQueryOptimizer QualifiedFunctionName.of(DEFAULT_NAMESPACE, "min"), QualifiedFunctionName.of(DEFAULT_NAMESPACE, "approx_distinct")); + // Min/Max could be folded into LEAST/GREATEST + private static final Map AGGREGATION_SCALAR_MAPPING = ImmutableMap.of( + QualifiedFunctionName.of(DEFAULT_NAMESPACE, "max"), QualifiedFunctionName.of(DEFAULT_NAMESPACE, "greatest"), + QualifiedFunctionName.of(DEFAULT_NAMESPACE, "min"), QualifiedFunctionName.of(DEFAULT_NAMESPACE, "least")); + private final Metadata metadata; - private final LiteralEncoder literalEncoder; public MetadataQueryOptimizer(Metadata metadata) { requireNonNull(metadata, "metadata is null"); this.metadata = metadata; - this.literalEncoder = new LiteralEncoder(metadata.getBlockEncodingSerde()); } @Override @@ -88,7 +97,7 @@ public PlanNode optimize(PlanNode plan, Session session, TypeProvider types, Pla if (!SystemSessionProperties.isOptimizeMetadataQueries(session)) { return plan; } - return SimplePlanRewriter.rewriteWith(new Optimizer(session, metadata, literalEncoder, idAllocator), plan, null); + return SimplePlanRewriter.rewriteWith(new Optimizer(session, metadata, idAllocator), plan, null); } private static class Optimizer @@ -97,14 +106,14 @@ private static class Optimizer private final PlanNodeIdAllocator idAllocator; private final Session session; private final Metadata metadata; - private final LiteralEncoder literalEncoder; + private final RowExpressionDeterminismEvaluator determinismEvaluator; - private Optimizer(Session session, Metadata metadata, LiteralEncoder literalEncoder, PlanNodeIdAllocator idAllocator) + private Optimizer(Session session, Metadata metadata, PlanNodeIdAllocator idAllocator) { this.session = session; this.metadata = metadata; - this.literalEncoder = literalEncoder; this.idAllocator = idAllocator; + this.determinismEvaluator = new RowExpressionDeterminismEvaluator(metadata); } @Override @@ -118,7 +127,7 @@ public PlanNode visitAggregation(AggregationNode node, RewriteContext cont } } - Optional result = findTableScan(node.getSource()); + Optional result = findTableScan(node.getSource(), determinismEvaluator); if (!result.isPresent()) { return context.defaultRewrite(node); } @@ -156,34 +165,124 @@ public PlanNode visitAggregation(AggregationNode node, RewriteContext cont return context.defaultRewrite(node); } + if (isReducible(node)) { + // Fold min/max aggregations to a constant value + return reduce(node, inputs, columns, context, predicates); + } + ImmutableList.Builder> rowsBuilder = ImmutableList.builder(); for (TupleDomain domain : predicates.getPredicates()) { - if (!domain.isNone()) { - Map entries = TupleDomain.extractFixedValues(domain).get(); + if (domain.isNone()) { + continue; + } + Map entries = TupleDomain.extractFixedValues(domain).get(); - ImmutableList.Builder rowBuilder = ImmutableList.builder(); - // for each input column, add a literal expression using the entry value - for (VariableReferenceExpression input : inputs) { - ColumnHandle column = columns.get(input); - NullableValue value = entries.get(column); - if (value == null) { - // partition key does not have a single value, so bail out to be safe - return context.defaultRewrite(node); - } - else { - rowBuilder.add(constant(value.getValue(), input.getType())); - } + ImmutableList.Builder rowBuilder = ImmutableList.builder(); + // for each input column, add a literal expression using the entry value + for (VariableReferenceExpression input : inputs) { + ColumnHandle column = columns.get(input); + NullableValue value = entries.get(column); + if (value == null) { + // partition key does not have a single value, so bail out to be safe + return context.defaultRewrite(node); + } + else { + rowBuilder.add(constant(value.getValue(), input.getType())); } - rowsBuilder.add(rowBuilder.build()); } + rowsBuilder.add(rowBuilder.build()); } // replace the tablescan node with a values node - ValuesNode valuesNode = new ValuesNode(idAllocator.getNextId(), inputs, rowsBuilder.build()); - return SimplePlanRewriter.rewriteWith(new Replacer(valuesNode), node); + return SimplePlanRewriter.rewriteWith(new Replacer(new ValuesNode(idAllocator.getNextId(), inputs, rowsBuilder.build())), node); + } + + private boolean isReducible(AggregationNode node) + { + if (node.getAggregations().isEmpty() || !(node.getSource() instanceof TableScanNode)) { + return false; + } + for (Aggregation aggregation : node.getAggregations().values()) { + FunctionMetadata functionMetadata = metadata.getFunctionManager().getFunctionMetadata(aggregation.getFunctionHandle()); + if (!AGGREGATION_SCALAR_MAPPING.containsKey(functionMetadata.getName()) || functionMetadata.getArgumentTypes().size() > 1) { + return false; + } + } + return true; + } + + private PlanNode reduce( + AggregationNode node, + List inputs, + Map columns, + RewriteContext context, + DiscretePredicates predicates) + { + // Fold min/max aggregations to a constant value + ImmutableList.Builder scalarsBuilder = ImmutableList.builder(); + for (int i = 0; i < inputs.size(); i++) { + ImmutableList.Builder arguments = ImmutableList.builder(); + ColumnHandle column = columns.get(inputs.get(i)); + // for each input column, add a literal expression using the entry value + for (TupleDomain domain : predicates.getPredicates()) { + if (domain.isNone()) { + continue; + } + Map entries = TupleDomain.extractFixedValues(domain).get(); + NullableValue value = entries.get(column); + if (value == null) { + // partition key does not have a single value, so bail out to be safe + return context.defaultRewrite(node); + } + // min/max ignores null value + else if (value.getValue() != null) { + Type type = inputs.get(i).getType(); + arguments.add(constant(value.getValue(), type)); + } + } + scalarsBuilder.add(evaluateMinMax( + metadata.getFunctionManager().getFunctionMetadata(node.getAggregations().get(node.getOutputVariables().get(i)).getFunctionHandle()), + arguments.build())); + } + List scalars = scalarsBuilder.build(); + + Assignments.Builder assignments = Assignments.builder(); + for (int i = 0; i < node.getOutputVariables().size(); i++) { + assignments.put(node.getOutputVariables().get(i), scalars.get(i)); + } + ValuesNode valuesNode = new ValuesNode(idAllocator.getNextId(), inputs, ImmutableList.of(scalars)); + return new ProjectNode(idAllocator.getNextId(), valuesNode, assignments.build()); + } + + private RowExpression evaluateMinMax(FunctionMetadata aggregationFunctionMetadata, List arguments) + { + Type returnType = metadata.getTypeManager().getType(aggregationFunctionMetadata.getReturnType()); + if (arguments.isEmpty()) { + return constant(null, returnType); + } + + String scalarFunctionName = AGGREGATION_SCALAR_MAPPING.get(aggregationFunctionMetadata.getName()).getFunctionName(); + ConnectorSession connectorSession = session.toConnectorSession(); + while (arguments.size() > 1) { + List reducedArguments = new ArrayList<>(); + // We fold for every 100 values because GREATEST/LEAST has argument count limit + for (List partitionedArguments : Lists.partition(arguments, 100)) { + Object reducedValue = evaluateConstantRowExpression( + call( + metadata.getFunctionManager(), + scalarFunctionName, + returnType, + partitionedArguments), + metadata, + connectorSession); + reducedArguments.add(constant(reducedValue, returnType)); + } + arguments = reducedArguments; + } + return getOnlyElement(arguments); } - private static Optional findTableScan(PlanNode source) + private static Optional findTableScan(PlanNode source, RowExpressionDeterminismEvaluator determinismEvaluator) { while (true) { // allow any chain of linear transformations @@ -197,7 +296,7 @@ private static Optional findTableScan(PlanNode source) else if (source instanceof ProjectNode) { // verify projections are deterministic ProjectNode project = (ProjectNode) source; - if (!Iterables.all(project.getAssignments().getExpressions().stream().map(OriginalExpressionUtils::castToExpression).collect(toList()), ExpressionDeterminismEvaluator::isDeterministic)) { + if (!Iterables.all(project.getAssignments().getExpressions(), determinismEvaluator::isDeterministic)) { return Optional.empty(); } source = project.getSource(); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/joins/JoinGraph.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/joins/JoinGraph.java index 4a2211c846a00..73b263b489d60 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/joins/JoinGraph.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/joins/JoinGraph.java @@ -17,13 +17,12 @@ import com.facebook.presto.spi.plan.PlanNode; import com.facebook.presto.spi.plan.PlanNodeId; import com.facebook.presto.spi.plan.ProjectNode; +import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.iterative.GroupReference; import com.facebook.presto.sql.planner.iterative.Lookup; import com.facebook.presto.sql.planner.plan.InternalPlanVisitor; import com.facebook.presto.sql.planner.plan.JoinNode; -import com.facebook.presto.sql.relational.OriginalExpressionUtils; -import com.facebook.presto.sql.tree.Expression; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMultimap; import com.google.common.collect.Multimap; @@ -36,11 +35,9 @@ import java.util.Optional; import static com.facebook.presto.sql.planner.plan.JoinNode.Type.INNER; -import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToExpression; import static com.facebook.presto.sql.relational.ProjectNodeUtils.isIdentity; import static com.google.common.base.Preconditions.checkState; import static com.google.common.collect.ImmutableList.toImmutableList; -import static com.google.common.collect.Maps.transformValues; import static java.lang.String.format; import static java.util.Objects.requireNonNull; @@ -51,8 +48,8 @@ */ public class JoinGraph { - private final Optional> assignments; - private final List filters; + private final Optional> assignments; + private final List filters; private final List nodes; // nodes in order of their appearance in tree plan (left, right, parent) private final Multimap edges; private final PlanNodeId rootId; @@ -93,8 +90,8 @@ public JoinGraph( List nodes, Multimap edges, PlanNodeId rootId, - List filters, - Optional> assignments) + List filters, + Optional> assignments) { this.nodes = nodes; this.edges = edges; @@ -103,26 +100,26 @@ public JoinGraph( this.assignments = assignments; } - public JoinGraph withAssignments(Map assignments) + public JoinGraph withAssignments(Map assignments) { return new JoinGraph(nodes, edges, rootId, filters, Optional.of(assignments)); } - public Optional> getAssignments() + public Optional> getAssignments() { return assignments; } - public JoinGraph withFilter(Expression expression) + public JoinGraph withFilter(RowExpression expression) { - ImmutableList.Builder filters = ImmutableList.builder(); + ImmutableList.Builder filters = ImmutableList.builder(); filters.addAll(this.filters); filters.add(expression); return new JoinGraph(nodes, edges, rootId, filters.build(), assignments); } - public List getFilters() + public List getFilters() { return filters; } @@ -200,7 +197,7 @@ private JoinGraph joinWith(JoinGraph other, List joinCl .putAll(this.edges) .putAll(other.edges); - List joinedFilters = ImmutableList.builder() + List joinedFilters = ImmutableList.builder() .addAll(this.filters) .addAll(other.filters) .build(); @@ -256,7 +253,7 @@ public JoinGraph visitPlan(PlanNode node, Context context) public JoinGraph visitFilter(FilterNode node, Context context) { JoinGraph graph = node.getSource().accept(this, context); - return graph.withFilter(castToExpression(node.getPredicate())); + return graph.withFilter(node.getPredicate()); } @Override @@ -273,7 +270,7 @@ public JoinGraph visitJoin(JoinNode node, Context context) JoinGraph graph = left.joinWith(right, node.getCriteria(), context, node.getId()); if (node.getFilter().isPresent()) { - return graph.withFilter(castToExpression(node.getFilter().get())); + return graph.withFilter(node.getFilter().get()); } return graph; } @@ -283,7 +280,7 @@ public JoinGraph visitProject(ProjectNode node, Context context) { if (isIdentity(node)) { JoinGraph graph = node.getSource().accept(this, context); - return graph.withAssignments(transformValues(node.getAssignments().getMap(), OriginalExpressionUtils::castToExpression)); + return graph.withAssignments(node.getAssignments().getMap()); } return visitPlan(node, context); } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestLogicalPlanner.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestLogicalPlanner.java index 782fc0cd8d718..83f6d19489c05 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestLogicalPlanner.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestLogicalPlanner.java @@ -471,6 +471,19 @@ public void testPushDownJoinConditionConjunctsToInnerSideBasedOnInheritedPredica "REGION_REGIONKEY", "regionkey"))))))); } + @Test + public void testScalarSubqueryJoinFilterPushdown() + { + assertPlan( + "SELECT * FROM orders WHERE orderkey = (SELECT 1)", + anyTree( + join(INNER, ImmutableList.of(), + filter("orderkey = BIGINT '1'", + tableScan("orders", ImmutableMap.of("orderkey", "orderkey"))), + anyTree( + project(ImmutableMap.of("orderkey", expression("1")), any()))))); + } + @Test public void testSameScalarSubqueryIsAppliedOnlyOnce() { @@ -1103,9 +1116,9 @@ public void testJoinNullFilters() LEFT, ImmutableList.of(equiJoinClause("NATION_REGIONKEY", "REGION_REGIONKEY")), anyTree( - tableScan( - "nation", - ImmutableMap.of("NATION_REGIONKEY", "regionkey"))), + tableScan( + "nation", + ImmutableMap.of("NATION_REGIONKEY", "regionkey"))), anyTree( filter("region_REGIONKEY IS NOT NULL", tableScan( @@ -1128,6 +1141,6 @@ public void testJoinNullFilters() tableScan( "region", ImmutableMap.of( - "REGION_REGIONKEY", "regionkey")))))); + "REGION_REGIONKEY", "regionkey")))))); } }