Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.analyzer.Field;
import com.facebook.presto.sql.planner.iterative.Lookup;
import com.facebook.presto.sql.planner.plan.ExchangeNode;
import com.facebook.presto.sql.planner.planPrinter.PlanPrinter;
import com.facebook.presto.sql.relational.FunctionResolution;
import com.facebook.presto.sql.tree.ComparisonExpression;
Expand Down Expand Up @@ -511,6 +512,25 @@ public static boolean containsSystemTableScan(PlanNode plan, Lookup lookup)
.matches();
}

/// Checks whether a node is directly on top of a system table scan without exchange in between
public static boolean directlyOnSystemTableScan(PlanNode plan, Lookup lookup)
{
plan = lookup.resolve(plan);
for (PlanNode source : plan.getSources()) {
source = lookup.resolve(source);
if (source instanceof TableScanNode && isInternalSystemConnector(((TableScanNode) source).getTable().getConnectorId())) {
return true;
}
if (source instanceof ExchangeNode) {
continue;
}
if (directlyOnSystemTableScan(source, lookup)) {
return true;
}
}
return false;
}

public static boolean isConstant(RowExpression expression, Type type, Object value)
{
return expression instanceof ConstantExpression &&
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ public interface Lookup
default PlanNode resolve(PlanNode node)
{
if (node instanceof GroupReference) {
return resolveGroup(node).collect(toOptional()).get();
return resolveGroup(node).collect(toOptional()).orElse(node);
}
return node;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.analyzer.FeaturesConfig.PartialAggregationStrategy;
import com.facebook.presto.sql.planner.PlannerUtils;
import com.facebook.presto.sql.planner.iterative.Rule;
import com.facebook.presto.sql.planner.optimizations.SymbolMapper;
import com.facebook.presto.sql.planner.plan.ExchangeNode;
Expand Down Expand Up @@ -60,7 +61,6 @@
import static com.facebook.presto.spi.statistics.SourceInfo.ConfidenceLevel.LOW;
import static com.facebook.presto.sql.analyzer.FeaturesConfig.PartialAggregationStrategy.AUTOMATIC;
import static com.facebook.presto.sql.analyzer.FeaturesConfig.PartialAggregationStrategy.NEVER;
import static com.facebook.presto.sql.planner.PlannerUtils.containsSystemTableScan;
import static com.facebook.presto.sql.planner.plan.ExchangeNode.Type.GATHER;
import static com.facebook.presto.sql.planner.plan.ExchangeNode.Type.REPARTITION;
import static com.facebook.presto.sql.planner.plan.Patterns.aggregation;
Expand Down Expand Up @@ -166,8 +166,13 @@ public Result apply(AggregationNode aggregationNode, Captures captures, Context
return Result.empty();
}

// System table scan must be run in Java on coordinator and partial aggregation output may not be compatible with Velox
if (nativeExecution && containsSystemTableScan(exchangeNode, context.getLookup())) {
// For native execution:
// Partial aggregation result from Java coordinator task is not compatible with native worker.
// System table scan must be run in on coordinator and addExchange would always add a GatherExchange on top of it.
// We should never push partial aggregation past the GatherExchange.
if (nativeExecution
&& exchangeNode.getType() == GATHER
&& PlannerUtils.directlyOnSystemTableScan(exchangeNode, context.getLookup())) {
return Result.empty();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
import com.facebook.presto.spi.SchemaTableName;
import com.facebook.presto.spi.plan.SortNode;
import com.facebook.presto.sql.analyzer.FeaturesConfig;
import com.facebook.presto.sql.planner.Symbol;
import com.facebook.presto.sql.planner.assertions.PlanMatchPattern;
import com.facebook.presto.sql.planner.plan.ExchangeNode;
import com.facebook.presto.testing.MaterializedResult;
import com.facebook.presto.testing.MaterializedRow;
Expand All @@ -33,6 +35,7 @@
import com.facebook.presto.tests.DistributedQueryRunner;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import org.intellij.lang.annotations.Language;
import org.testng.annotations.Test;

Expand Down Expand Up @@ -71,10 +74,13 @@
import static com.facebook.presto.nativeworker.NativeQueryRunnerUtils.createRegion;
import static com.facebook.presto.nativeworker.NativeQueryRunnerUtils.createSupplier;
import static com.facebook.presto.nativeworker.NativeQueryRunnerUtils.createTableToTestHiddenColumns;
import static com.facebook.presto.spi.plan.AggregationNode.Step.SINGLE;
import static com.facebook.presto.spi.plan.AggregationNode.Step.FINAL;
import static com.facebook.presto.spi.plan.AggregationNode.Step.PARTIAL;
import static com.facebook.presto.spi.plan.ExchangeEncoding.COLUMNAR;
import static com.facebook.presto.spi.plan.ExchangeEncoding.ROW_WISE;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.GroupingSetDescriptor;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.aggregation;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.anySymbol;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.anyTree;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.exchange;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.expression;
Expand Down Expand Up @@ -1520,66 +1526,72 @@ public void testSystemTables()
"AS " +
"SELECT nationkey, name, comment, regionkey FROM nation", tableName));

String filter = format("SELECT regionkey FROM \"%s\" WHERE regionkey %% 3 = 1", partitionsTableName);
String groupingSet = format("SELECT count(*) FROM \"%s\" GROUP BY GROUPING SETS ((regionkey), ())", partitionsTableName);
assertPlan(
filter,
anyTree(
exchange(REMOTE_STREAMING, GATHER,
filter(
"REGION_KEY % 3 = 1",
tableScan(partitionsTableName, ImmutableMap.of("REGION_KEY", "regionkey"))))));
assertQuery(filter);

String project = format("SELECT regionkey + 1 FROM \"%s\"", partitionsTableName);
assertPlan(
project,
anyTree(
exchange(REMOTE_STREAMING, GATHER,
project(
ImmutableMap.of("EXPRESSION", expression("REGION_KEY + CAST(1 AS bigint)")),
tableScan(partitionsTableName, ImmutableMap.of("REGION_KEY", "regionkey"))))));
assertQuery(project);

String filterProject = format("SELECT regionkey + 1 FROM \"%s\" WHERE regionkey %% 3 = 1", partitionsTableName);
assertPlan(
filterProject,
anyTree(
exchange(REMOTE_STREAMING, GATHER,
project(
ImmutableMap.of("EXPRESSION", expression("REGION_KEY + CAST(1 AS bigint)")),
filter(
"REGION_KEY % 3 = 1",
tableScan(partitionsTableName, ImmutableMap.of("REGION_KEY", "regionkey")))))));
assertQuery(filterProject);
groupingSet,
PlanMatchPattern.output(project(
aggregation(
new PlanMatchPattern.GroupingSetDescriptor(ImmutableList.of("regionkey$gid", "groupid"), 2, ImmutableSet.of(1)),
ImmutableMap.of(Optional.empty(), functionCall("count", false, ImmutableList.of(anySymbol()))),
ImmutableMap.of(),
Optional.of(new Symbol("groupid")),
FINAL,
exchange(LOCAL, REPARTITION,
aggregation(
new GroupingSetDescriptor(ImmutableList.of("regionkey$gid", "groupid"), 2, ImmutableSet.of(1)),
ImmutableMap.of(Optional.empty(), functionCall("count", false, ImmutableList.of())),
ImmutableList.of(),
ImmutableMap.of(),
Optional.of(new Symbol("groupid")),
PARTIAL,
PlanMatchPattern.groupingSet(
ImmutableList.of(ImmutableList.of("REGION_KEY"), ImmutableList.of()),
ImmutableMap.of(),
"groupid",
ImmutableMap.of("regionkey$gid", expression("REGION_KEY")),
exchange(REMOTE_STREAMING, GATHER,
tableScan(partitionsTableName, ImmutableMap.of("REGION_KEY", "regionkey"))))))))));

String aggregation = format("SELECT count(*), sum(regionkey) FROM \"%s\"", partitionsTableName);
assertPlan(
aggregation,
anyTree(
PlanMatchPattern.output(
aggregation(
ImmutableMap.of(
"FINAL_COUNT", functionCall("count", ImmutableList.of()),
"FINAL_SUM", functionCall("sum", ImmutableList.of("REGION_KEY"))),
SINGLE,
"FINAL_COUNT", functionCall("count", false, ImmutableList.of(anySymbol())),
"FINAL_SUM", functionCall("sum", false, ImmutableList.of(anySymbol()))),
FINAL,
exchange(LOCAL, GATHER,
exchange(REMOTE_STREAMING, GATHER,
tableScan(partitionsTableName, ImmutableMap.of("REGION_KEY", "regionkey")))))));
aggregation(
ImmutableMap.of(
"PARTIAL_COUNT", functionCall("count", false, ImmutableList.of()),
"PARTIAL_SUM", functionCall("sum", false, ImmutableList.of(anySymbol()))),
PARTIAL,
exchange(REMOTE_STREAMING, GATHER,
tableScan(partitionsTableName, ImmutableMap.of("REGION_KEY", "regionkey"))))))));
assertQuery(aggregation);

String groupBy = format("SELECT regionkey, count(*) FROM \"%s\" GROUP BY regionkey", partitionsTableName);
assertPlan(
groupBy,
anyTree(
PlanMatchPattern.output(
aggregation(
singleGroupingSet("REGION_KEY"),
ImmutableMap.of(
Optional.of("FINAL_COUNT"), functionCall("count", ImmutableList.of())),
Optional.of("FINAL_COUNT"), functionCall("count", false, ImmutableList.of(anySymbol()))),
ImmutableMap.of(),
Optional.empty(),
SINGLE,
FINAL,
exchange(LOCAL, REPARTITION,
exchange(REMOTE_STREAMING, GATHER,
tableScan(partitionsTableName, ImmutableMap.of("REGION_KEY", "regionkey")))))));
aggregation(
singleGroupingSet("REGION_KEY"),
ImmutableMap.of(
Optional.of("PARTIAL_COUNT"), functionCall("count", false, ImmutableList.of())),
ImmutableMap.of(),
Optional.empty(),
PARTIAL,
exchange(REMOTE_STREAMING, GATHER,
tableScan(partitionsTableName, ImmutableMap.of("REGION_KEY", "regionkey"))))))));
assertQuery(groupBy);

String join = format("SELECT * " +
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@

import static com.facebook.airlift.testing.Assertions.assertGreaterThanOrEqual;
import static com.facebook.presto.nativeworker.PrestoNativeQueryRunnerUtils.getNativeQueryRunnerParameters;
import static com.facebook.presto.spi.plan.AggregationNode.Step.SINGLE;
import static com.facebook.presto.spi.plan.AggregationNode.Step.FINAL;
import static com.facebook.presto.spi.plan.AggregationNode.Step.PARTIAL;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.aggregation;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.anyTree;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.exchange;
Expand Down Expand Up @@ -67,10 +68,13 @@ public void testTasks()
anyTree(
aggregation(
Collections.emptyMap(),
SINGLE,
FINAL,
exchange(LOCAL, GATHER,
exchange(REMOTE_STREAMING, GATHER,
tableScan("tasks"))))));
aggregation(
Collections.emptyMap(),
PARTIAL,
exchange(REMOTE_STREAMING, GATHER,
tableScan("tasks")))))));
}

@Test
Expand Down
Loading