diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/PlannerUtils.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/PlannerUtils.java index d444edbee2648..74d182887b03d 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/PlannerUtils.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/PlannerUtils.java @@ -44,6 +44,7 @@ import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.analyzer.Field; import com.facebook.presto.sql.planner.iterative.Lookup; +import com.facebook.presto.sql.planner.plan.ExchangeNode; import com.facebook.presto.sql.planner.planPrinter.PlanPrinter; import com.facebook.presto.sql.relational.FunctionResolution; import com.facebook.presto.sql.tree.ComparisonExpression; @@ -511,6 +512,25 @@ public static boolean containsSystemTableScan(PlanNode plan, Lookup lookup) .matches(); } + /// Checks whether a node is directly on top of a system table scan without exchange in between + public static boolean directlyOnSystemTableScan(PlanNode plan, Lookup lookup) + { + plan = lookup.resolve(plan); + for (PlanNode source : plan.getSources()) { + source = lookup.resolve(source); + if (source instanceof TableScanNode && isInternalSystemConnector(((TableScanNode) source).getTable().getConnectorId())) { + return true; + } + if (source instanceof ExchangeNode) { + continue; + } + if (directlyOnSystemTableScan(source, lookup)) { + return true; + } + } + return false; + } + public static boolean isConstant(RowExpression expression, Type type, Object value) { return expression instanceof ConstantExpression && diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/Lookup.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/Lookup.java index 2709fabac6c6e..412e35666ddc8 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/Lookup.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/Lookup.java @@ -35,7 +35,7 @@ public interface Lookup default PlanNode resolve(PlanNode node) { if (node instanceof GroupReference) { - return resolveGroup(node).collect(toOptional()).get(); + return resolveGroup(node).collect(toOptional()).orElse(node); } return node; } diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushPartialAggregationThroughExchange.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushPartialAggregationThroughExchange.java index 3c5f12fa3486c..07af1c7121d7a 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushPartialAggregationThroughExchange.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushPartialAggregationThroughExchange.java @@ -33,6 +33,7 @@ import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.analyzer.FeaturesConfig.PartialAggregationStrategy; +import com.facebook.presto.sql.planner.PlannerUtils; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.optimizations.SymbolMapper; import com.facebook.presto.sql.planner.plan.ExchangeNode; @@ -60,7 +61,6 @@ import static com.facebook.presto.spi.statistics.SourceInfo.ConfidenceLevel.LOW; import static com.facebook.presto.sql.analyzer.FeaturesConfig.PartialAggregationStrategy.AUTOMATIC; import static com.facebook.presto.sql.analyzer.FeaturesConfig.PartialAggregationStrategy.NEVER; -import static com.facebook.presto.sql.planner.PlannerUtils.containsSystemTableScan; import static com.facebook.presto.sql.planner.plan.ExchangeNode.Type.GATHER; import static com.facebook.presto.sql.planner.plan.ExchangeNode.Type.REPARTITION; import static com.facebook.presto.sql.planner.plan.Patterns.aggregation; @@ -166,8 +166,13 @@ public Result apply(AggregationNode aggregationNode, Captures captures, Context return Result.empty(); } - // System table scan must be run in Java on coordinator and partial aggregation output may not be compatible with Velox - if (nativeExecution && containsSystemTableScan(exchangeNode, context.getLookup())) { + // For native execution: + // Partial aggregation result from Java coordinator task is not compatible with native worker. + // System table scan must be run in on coordinator and addExchange would always add a GatherExchange on top of it. + // We should never push partial aggregation past the GatherExchange. + if (nativeExecution + && exchangeNode.getType() == GATHER + && PlannerUtils.directlyOnSystemTableScan(exchangeNode, context.getLookup())) { return Result.empty(); } diff --git a/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/AbstractTestNativeGeneralQueries.java b/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/AbstractTestNativeGeneralQueries.java index 29bddb238025d..41b295c25abfe 100644 --- a/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/AbstractTestNativeGeneralQueries.java +++ b/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/AbstractTestNativeGeneralQueries.java @@ -25,6 +25,8 @@ import com.facebook.presto.spi.SchemaTableName; import com.facebook.presto.spi.plan.SortNode; import com.facebook.presto.sql.analyzer.FeaturesConfig; +import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.planner.assertions.PlanMatchPattern; import com.facebook.presto.sql.planner.plan.ExchangeNode; import com.facebook.presto.testing.MaterializedResult; import com.facebook.presto.testing.MaterializedRow; @@ -33,6 +35,7 @@ import com.facebook.presto.tests.DistributedQueryRunner; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; import org.intellij.lang.annotations.Language; import org.testng.annotations.Test; @@ -71,10 +74,13 @@ import static com.facebook.presto.nativeworker.NativeQueryRunnerUtils.createRegion; import static com.facebook.presto.nativeworker.NativeQueryRunnerUtils.createSupplier; import static com.facebook.presto.nativeworker.NativeQueryRunnerUtils.createTableToTestHiddenColumns; -import static com.facebook.presto.spi.plan.AggregationNode.Step.SINGLE; +import static com.facebook.presto.spi.plan.AggregationNode.Step.FINAL; +import static com.facebook.presto.spi.plan.AggregationNode.Step.PARTIAL; import static com.facebook.presto.spi.plan.ExchangeEncoding.COLUMNAR; import static com.facebook.presto.spi.plan.ExchangeEncoding.ROW_WISE; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.GroupingSetDescriptor; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.aggregation; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.anySymbol; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.anyTree; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.exchange; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.expression; @@ -1520,66 +1526,72 @@ public void testSystemTables() "AS " + "SELECT nationkey, name, comment, regionkey FROM nation", tableName)); - String filter = format("SELECT regionkey FROM \"%s\" WHERE regionkey %% 3 = 1", partitionsTableName); + String groupingSet = format("SELECT count(*) FROM \"%s\" GROUP BY GROUPING SETS ((regionkey), ())", partitionsTableName); assertPlan( - filter, - anyTree( - exchange(REMOTE_STREAMING, GATHER, - filter( - "REGION_KEY % 3 = 1", - tableScan(partitionsTableName, ImmutableMap.of("REGION_KEY", "regionkey")))))); - assertQuery(filter); - - String project = format("SELECT regionkey + 1 FROM \"%s\"", partitionsTableName); - assertPlan( - project, - anyTree( - exchange(REMOTE_STREAMING, GATHER, - project( - ImmutableMap.of("EXPRESSION", expression("REGION_KEY + CAST(1 AS bigint)")), - tableScan(partitionsTableName, ImmutableMap.of("REGION_KEY", "regionkey")))))); - assertQuery(project); - - String filterProject = format("SELECT regionkey + 1 FROM \"%s\" WHERE regionkey %% 3 = 1", partitionsTableName); - assertPlan( - filterProject, - anyTree( - exchange(REMOTE_STREAMING, GATHER, - project( - ImmutableMap.of("EXPRESSION", expression("REGION_KEY + CAST(1 AS bigint)")), - filter( - "REGION_KEY % 3 = 1", - tableScan(partitionsTableName, ImmutableMap.of("REGION_KEY", "regionkey"))))))); - assertQuery(filterProject); + groupingSet, + PlanMatchPattern.output(project( + aggregation( + new PlanMatchPattern.GroupingSetDescriptor(ImmutableList.of("regionkey$gid", "groupid"), 2, ImmutableSet.of(1)), + ImmutableMap.of(Optional.empty(), functionCall("count", false, ImmutableList.of(anySymbol()))), + ImmutableMap.of(), + Optional.of(new Symbol("groupid")), + FINAL, + exchange(LOCAL, REPARTITION, + aggregation( + new GroupingSetDescriptor(ImmutableList.of("regionkey$gid", "groupid"), 2, ImmutableSet.of(1)), + ImmutableMap.of(Optional.empty(), functionCall("count", false, ImmutableList.of())), + ImmutableList.of(), + ImmutableMap.of(), + Optional.of(new Symbol("groupid")), + PARTIAL, + PlanMatchPattern.groupingSet( + ImmutableList.of(ImmutableList.of("REGION_KEY"), ImmutableList.of()), + ImmutableMap.of(), + "groupid", + ImmutableMap.of("regionkey$gid", expression("REGION_KEY")), + exchange(REMOTE_STREAMING, GATHER, + tableScan(partitionsTableName, ImmutableMap.of("REGION_KEY", "regionkey")))))))))); String aggregation = format("SELECT count(*), sum(regionkey) FROM \"%s\"", partitionsTableName); assertPlan( aggregation, - anyTree( + PlanMatchPattern.output( aggregation( ImmutableMap.of( - "FINAL_COUNT", functionCall("count", ImmutableList.of()), - "FINAL_SUM", functionCall("sum", ImmutableList.of("REGION_KEY"))), - SINGLE, + "FINAL_COUNT", functionCall("count", false, ImmutableList.of(anySymbol())), + "FINAL_SUM", functionCall("sum", false, ImmutableList.of(anySymbol()))), + FINAL, exchange(LOCAL, GATHER, - exchange(REMOTE_STREAMING, GATHER, - tableScan(partitionsTableName, ImmutableMap.of("REGION_KEY", "regionkey"))))))); + aggregation( + ImmutableMap.of( + "PARTIAL_COUNT", functionCall("count", false, ImmutableList.of()), + "PARTIAL_SUM", functionCall("sum", false, ImmutableList.of(anySymbol()))), + PARTIAL, + exchange(REMOTE_STREAMING, GATHER, + tableScan(partitionsTableName, ImmutableMap.of("REGION_KEY", "regionkey")))))))); assertQuery(aggregation); String groupBy = format("SELECT regionkey, count(*) FROM \"%s\" GROUP BY regionkey", partitionsTableName); assertPlan( groupBy, - anyTree( + PlanMatchPattern.output( aggregation( singleGroupingSet("REGION_KEY"), ImmutableMap.of( - Optional.of("FINAL_COUNT"), functionCall("count", ImmutableList.of())), + Optional.of("FINAL_COUNT"), functionCall("count", false, ImmutableList.of(anySymbol()))), ImmutableMap.of(), Optional.empty(), - SINGLE, + FINAL, exchange(LOCAL, REPARTITION, - exchange(REMOTE_STREAMING, GATHER, - tableScan(partitionsTableName, ImmutableMap.of("REGION_KEY", "regionkey"))))))); + aggregation( + singleGroupingSet("REGION_KEY"), + ImmutableMap.of( + Optional.of("PARTIAL_COUNT"), functionCall("count", false, ImmutableList.of())), + ImmutableMap.of(), + Optional.empty(), + PARTIAL, + exchange(REMOTE_STREAMING, GATHER, + tableScan(partitionsTableName, ImmutableMap.of("REGION_KEY", "regionkey")))))))); assertQuery(groupBy); String join = format("SELECT * " + diff --git a/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/AbstractTestNativeSystemQueries.java b/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/AbstractTestNativeSystemQueries.java index a3fea4bf1d061..b4e60124674ae 100644 --- a/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/AbstractTestNativeSystemQueries.java +++ b/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/AbstractTestNativeSystemQueries.java @@ -21,7 +21,8 @@ import static com.facebook.airlift.testing.Assertions.assertGreaterThanOrEqual; import static com.facebook.presto.nativeworker.PrestoNativeQueryRunnerUtils.getNativeQueryRunnerParameters; -import static com.facebook.presto.spi.plan.AggregationNode.Step.SINGLE; +import static com.facebook.presto.spi.plan.AggregationNode.Step.FINAL; +import static com.facebook.presto.spi.plan.AggregationNode.Step.PARTIAL; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.aggregation; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.anyTree; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.exchange; @@ -67,10 +68,13 @@ public void testTasks() anyTree( aggregation( Collections.emptyMap(), - SINGLE, + FINAL, exchange(LOCAL, GATHER, - exchange(REMOTE_STREAMING, GATHER, - tableScan("tasks")))))); + aggregation( + Collections.emptyMap(), + PARTIAL, + exchange(REMOTE_STREAMING, GATHER, + tableScan("tasks"))))))); } @Test