diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/ApplyConnectorOptimization.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/ApplyConnectorOptimization.java index e774f9b0955d9..3ff1769071468 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/ApplyConnectorOptimization.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/ApplyConnectorOptimization.java @@ -54,6 +54,7 @@ import java.util.HashMap; import java.util.HashSet; import java.util.LinkedList; +import java.util.List; import java.util.Map; import java.util.Optional; import java.util.Queue; @@ -133,77 +134,97 @@ public PlanOptimizerResult optimize(PlanNode plan, Session session, TypeProvider // In order to preserve the fixpoint, we will "pretend" the newly added C2 table scan is part of C1's job to maintain. for (ConnectorId connectorId : connectorIds.build()) { Set optimizers = connectorOptimizers.get(connectorId); - if (optimizers == null) { + if (optimizers == null || optimizers.isEmpty()) { continue; } + ImmutableMap.Builder, Set> optimizersWithConnectorRange = ImmutableMap.builder(); + List currentConnectors = null; + ImmutableSet.Builder currentGroup = null; + for (ConnectorPlanOptimizer optimizer : optimizers) { + List supportedConnectors = optimizer.getSupportedConnectorIds().isEmpty() + ? 
ImmutableList.of(connectorId) + : optimizer.getSupportedConnectorIds(); + + if (!supportedConnectors.equals(currentConnectors)) { + if (currentGroup != null) { + optimizersWithConnectorRange.put(currentConnectors, currentGroup.build()); + } + currentConnectors = supportedConnectors; + currentGroup = ImmutableSet.builder(); + } + currentGroup.add(optimizer); + } + optimizersWithConnectorRange.put(currentConnectors, currentGroup.build()); + ImmutableMap.Builder contextMapBuilder = ImmutableMap.builder(); buildConnectorPlanNodeContext(plan, null, contextMapBuilder); Map contextMap = contextMapBuilder.build(); + for (Map.Entry, Set> entry : optimizersWithConnectorRange.build().entrySet()) { + // keep track of changed nodes; the keys are original nodes and the values are the new nodes + Map updates = new HashMap<>(); + + // process connector optimizers + for (PlanNode node : contextMap.keySet()) { + // For a subtree with root `node` to be a max closure, the following conditions must hold: + // * The subtree with root `node` is a closure. + // * `node` has no parent, or the subtree with root as `node`'s parent is not a closure. + ConnectorPlanNodeContext context = contextMap.get(node); + if (!context.isClosure(connectorId, session, entry.getKey()) || + !context.getParent().isPresent() || + contextMap.get(context.getParent().get()).isClosure(connectorId, session, entry.getKey())) { + continue; + } - // keep track of changed nodes; the keys are original nodes and the values are the new nodes - Map updates = new HashMap<>(); - - // process connector optimizers - for (PlanNode node : contextMap.keySet()) { - // For a subtree with root `node` to be a max closure, the following conditions must hold: - // * The subtree with root `node` is a closure. - // * `node` has no parent, or the subtree with root as `node`'s parent is not a closure. 
- ConnectorPlanNodeContext context = contextMap.get(node); - if (!context.isClosure(connectorId, session) || - !context.getParent().isPresent() || - contextMap.get(context.getParent().get()).isClosure(connectorId, session)) { - continue; - } - - PlanNode newNode = node; + PlanNode newNode = node; - // the returned node is still a max closure (only if there is no new connector added, which does happen but ignored here) - for (ConnectorPlanOptimizer optimizer : optimizers) { - long start = System.nanoTime(); - newNode = optimizer.optimize(newNode, session.toConnectorSession(connectorId), variableAllocator, idAllocator); - if (enableVerboseRuntimeStats || trackOptimizerRuntime(session, optimizer)) { - session.getRuntimeStats().addMetricValue(String.format("optimizer%sTimeNanos", getOptimizerNameForLog(optimizer)), NANO, System.nanoTime() - start); + // the returned node is still a max closure (only if there is no new connector added, which does happen but ignored here) + for (ConnectorPlanOptimizer optimizer : entry.getValue()) { + long start = System.nanoTime(); + newNode = optimizer.optimize(newNode, session.toConnectorSession(connectorId), variableAllocator, idAllocator); + if (enableVerboseRuntimeStats || trackOptimizerRuntime(session, optimizer)) { + session.getRuntimeStats().addMetricValue(String.format("optimizer%sTimeNanos", getOptimizerNameForLog(optimizer)), NANO, System.nanoTime() - start); + } } - } - if (node != newNode) { - // the optimizer has allocated a new PlanNode - checkState( - containsAll(ImmutableSet.copyOf(newNode.getOutputVariables()), node.getOutputVariables()), - "the connector optimizer from %s returns a node that does not cover all output before optimization", - connectorId); + if (node != newNode) { + // the optimizer has allocated a new PlanNode + checkState( + containsAll(ImmutableSet.copyOf(newNode.getOutputVariables()), node.getOutputVariables()), + "the connector optimizer from %s returns a node that does not cover all output before 
optimization", + connectorId); - updates.put(node, newNode); - } - } - // up to this point, we have a set of updated nodes; need to recursively update their parents - - // alter the plan with a bottom-up approach (but does not have to be strict bottom-up to guarantee the correctness of the algorithm) - // use "original nodes" to keep track of the plan structure and "updates" to keep track of the new nodes - Queue originalNodes = new LinkedList<>(updates.keySet()); - while (!originalNodes.isEmpty()) { - PlanNode originalNode = originalNodes.poll(); - - if (!contextMap.get(originalNode).getParent().isPresent()) { - // originalNode must be the root; update the plan - plan = updates.get(originalNode); - continue; + updates.put(node, newNode); + } } + // up to this point, we have a set of updated nodes; need to recursively update their parents + + // alter the plan with a bottom-up approach (but does not have to be strict bottom-up to guarantee the correctness of the algorithm) + // use "original nodes" to keep track of the plan structure and "updates" to keep track of the new nodes + Queue originalNodes = new LinkedList<>(updates.keySet()); + while (!originalNodes.isEmpty()) { + PlanNode originalNode = originalNodes.poll(); + + if (!contextMap.get(originalNode).getParent().isPresent()) { + // originalNode must be the root; update the plan + plan = updates.get(originalNode); + continue; + } - PlanNode originalParent = contextMap.get(originalNode).getParent().get(); + PlanNode originalParent = contextMap.get(originalNode).getParent().get(); - // need to create a new parent given the child has changed; the new parent needs to point to the new child. 
- // if a node has been updated, it will occur in `updates`; otherwise, just use the original node - ImmutableList.Builder newChildren = ImmutableList.builder(); - originalParent.getSources().forEach(child -> newChildren.add(updates.getOrDefault(child, child))); - PlanNode newParent = originalParent.replaceChildren(newChildren.build()); + // need to create a new parent given the child has changed; the new parent needs to point to the new child. + // if a node has been updated, it will occur in `updates`; otherwise, just use the original node + ImmutableList.Builder newChildren = ImmutableList.builder(); + originalParent.getSources().forEach(child -> newChildren.add(updates.getOrDefault(child, child))); + PlanNode newParent = originalParent.replaceChildren(newChildren.build()); - // mark the new parent as updated - updates.put(originalParent, newParent); + // mark the new parent as updated + updates.put(originalParent, newParent); - // enqueue the parent node in order to recursively update its ancestors - originalNodes.add(originalParent); + // enqueue the parent node in order to recursively update its ancestors + originalNodes.add(originalParent); + } } } @@ -306,17 +327,16 @@ public Set> getReachablePlanNodeTypes() return reachablePlanNodeTypes; } - boolean isClosure(ConnectorId connectorId, Session session) + boolean isClosure(ConnectorId connectorId, Session session, List supportedConnectorId) { // check if all children can reach the only connector boolean includeValuesNode = isIncludeValuesNodeInConnectorOptimizer(session); Set connectorIds = includeValuesNode ? 
reachableConnectors.stream().filter(x -> !x.equals(EMPTY_CONNECTOR_ID)).collect(toImmutableSet()) : reachableConnectors; - if (connectorIds.size() != 1 || !connectorIds.contains(connectorId)) { - return false; + if (connectorIds.contains(connectorId) && new HashSet<>(supportedConnectorId).containsAll(connectorIds) && supportedConnectorId.size() == connectorIds.size()) { + // check if all children are accessible by connectors + return containsAll(CONNECTOR_ACCESSIBLE_PLAN_NODES, reachablePlanNodeTypes); } - - // check if all children are accessible by connectors - return containsAll(CONNECTOR_ACCESSIBLE_PLAN_NODES, reachablePlanNodeTypes); + return false; } } diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/optimizations/TestConnectorOptimization.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/optimizations/TestConnectorOptimization.java index 62b96f0973778..3ed25119466af 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/optimizations/TestConnectorOptimization.java +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/optimizations/TestConnectorOptimization.java @@ -223,6 +223,223 @@ public void testAddFilterToTableScan() TypeProvider.viewOf(ImmutableMap.of("a", BIGINT, "b", BIGINT))); } + @Test + public void testMultipleConnectorOptimization() + { + PlanNode plan = output( + union( + tableScan("cat1", "a", "b"), + tableScan("cat2", "a", "b")), + "a"); + + ConnectorPlanOptimizer multiConnectorOptimizer = createMultiConnectorOptimizer( + ImmutableList.of(new ConnectorId("cat1"), new ConnectorId("cat2"))); + + PlanNode actual = optimize(plan, ImmutableMap.of( + new ConnectorId("cat1"), ImmutableSet.of(multiConnectorOptimizer))); + + assertPlanMatch( + actual, + PlanMatchPattern.output( + PlanMatchPattern.union( + PlanMatchPattern.filter( + "true", + SimpleTableScanMatcher.tableScan("cat1", "a", "b")), + PlanMatchPattern.filter( + "true", + SimpleTableScanMatcher.tableScan("cat2", 
"a", "b"))))); + + ConnectorPlanOptimizer crossConnectorUnionOptimizer = createCrossConnectorUnionOptimizer( + ImmutableList.of(new ConnectorId("cat1"), new ConnectorId("cat2"))); + + actual = optimize(plan, ImmutableMap.of( + new ConnectorId("cat1"), ImmutableSet.of(crossConnectorUnionOptimizer))); + + assertPlanMatch( + actual, + PlanMatchPattern.output( + PlanMatchPattern.union( + PlanMatchPattern.filter( + "true", + SimpleTableScanMatcher.tableScan("cat1", "a", "b")), + PlanMatchPattern.filter( + "true", + SimpleTableScanMatcher.tableScan("cat2", "a", "b"))))); + + plan = output( + union( + filter(tableScan("cat1", "a", "b"), TRUE_CONSTANT), + filter(tableScan("cat2", "a", "b"), TRUE_CONSTANT), + filter(tableScan("cat3", "a", "b"), TRUE_CONSTANT)), + "a"); + + ConnectorPlanOptimizer multiConnectorOptimizer12 = createCrossConnectorUnionOptimizer( + ImmutableList.of(new ConnectorId("cat1"), new ConnectorId("cat2"))); + + actual = optimize(plan, ImmutableMap.of( + new ConnectorId("cat1"), ImmutableSet.of(multiConnectorOptimizer12), + new ConnectorId("cat3"), ImmutableSet.of(filterPushdown()))); + + assertPlanMatch( + actual, + PlanMatchPattern.output( + PlanMatchPattern.union( + PlanMatchPattern.filter( + "true", + SimpleTableScanMatcher.tableScan("cat1", "a", "b")), + PlanMatchPattern.filter( + "true", + SimpleTableScanMatcher.tableScan("cat2", "a", "b")), + SimpleTableScanMatcher.tableScan("cat3", TRUE_CONSTANT)))); + + plan = output( + union( + union( + tableScan("cat1", "a", "b"), + tableScan("cat2", "a", "b")), // This union only contains supported connectors + tableScan("cat4", "a", "b")), // cat4 in separate part of plan + "a"); + + actual = optimize(plan, ImmutableMap.of( + new ConnectorId("cat1"), ImmutableSet.of(crossConnectorUnionOptimizer))); + + assertPlanMatch( + actual, + PlanMatchPattern.output( + PlanMatchPattern.union( + PlanMatchPattern.union( + PlanMatchPattern.filter( + "true", + SimpleTableScanMatcher.tableScan("cat1", "a", "b")), + 
PlanMatchPattern.filter( + "true", + SimpleTableScanMatcher.tableScan("cat2", "a", "b"))), + SimpleTableScanMatcher.tableScan("cat4", "a", "b")))); + + plan = output( + union( + tableScan("cat1", "a", "b"), + tableScan("cat2", "a", "b"), + tableScan("cat3", "a", "b")), + "a"); + + ConnectorPlanOptimizer singleConnectorOptimizer1 = addFilterToTableScan(TRUE_CONSTANT); + ConnectorPlanOptimizer singleConnectorOptimizer2 = noop(); + ConnectorPlanOptimizer multiConnectorOptimizer13 = createMultiConnectorOptimizer( + ImmutableList.of(new ConnectorId("cat1"), new ConnectorId("cat3"))); + + actual = optimize(plan, ImmutableMap.of( + new ConnectorId("cat1"), ImmutableSet.of(singleConnectorOptimizer1, multiConnectorOptimizer13), + new ConnectorId("cat2"), ImmutableSet.of(singleConnectorOptimizer2), + new ConnectorId("cat3"), ImmutableSet.of(singleConnectorOptimizer1))); + + assertPlanMatch( + actual, + PlanMatchPattern.output( + PlanMatchPattern.union( + PlanMatchPattern.filter( + "true", + SimpleTableScanMatcher.tableScan("cat1", "a", "b")), + SimpleTableScanMatcher.tableScan("cat2", "a", "b"), + PlanMatchPattern.filter( + "true", + SimpleTableScanMatcher.tableScan("cat3", "a", "b"))))); + + plan = output( + union( + union( + tableScan("cat1", "a", "b"), + tableScan("cat3", "a", "b")), // This inner union has exactly cat1, cat3 + union( + tableScan("cat2", "a", "b"), + tableScan("cat4", "a", "b"))), // This inner union has cat2, cat4 + "a"); + + ConnectorPlanOptimizer exactMatchOptimizer = createCrossConnectorUnionOptimizer( + ImmutableList.of(new ConnectorId("cat1"), new ConnectorId("cat3"))); // Only supports cat1 and cat3 + + actual = optimize(plan, ImmutableMap.of( + new ConnectorId("cat1"), ImmutableSet.of(exactMatchOptimizer), + new ConnectorId("cat2"), ImmutableSet.of(filterPushdown()), + new ConnectorId("cat4"), ImmutableSet.of(filterPushdown()))); + + assertPlanMatch( + actual, + PlanMatchPattern.output( + PlanMatchPattern.union( + PlanMatchPattern.union( + 
PlanMatchPattern.filter( + "true", + SimpleTableScanMatcher.tableScan("cat1", "a", "b")), + PlanMatchPattern.filter( + "true", + SimpleTableScanMatcher.tableScan("cat3", "a", "b"))), + PlanMatchPattern.union( + SimpleTableScanMatcher.tableScan("cat2", "a", "b"), + SimpleTableScanMatcher.tableScan("cat4", "a", "b"))))); + + plan = output( + union( + tableScan("cat1", "a", "b"), + tableScan("cat2", "a", "b"), + tableScan("cat3", "a", "b")), + "a"); + + ConnectorPlanOptimizer partialCoverageOptimizer = createMultiConnectorOptimizer( + ImmutableList.of(new ConnectorId("cat1"), new ConnectorId("cat2"))); // Only supports cat1, cat2 + + actual = optimize(plan, ImmutableMap.of( + new ConnectorId("cat1"), ImmutableSet.of(partialCoverageOptimizer), + new ConnectorId("cat3"), ImmutableSet.of(filterPushdown()))); + + assertPlanMatch( + actual, + PlanMatchPattern.output( + PlanMatchPattern.union( + SimpleTableScanMatcher.tableScan("cat1", "a", "b"), + SimpleTableScanMatcher.tableScan("cat2", "a", "b"), + SimpleTableScanMatcher.tableScan("cat3", "a", "b")))); + + plan = output( + union( + union( + tableScan("cat1", "a", "b"), + tableScan("cat2", "a", "b")), // This inner union has exactly cat1, cat2 + union( + tableScan("cat2", "a", "b"), + tableScan("cat3", "a", "b"))), // This inner union has exactly cat2, cat3 + "a"); + + ConnectorPlanOptimizer multiConnectorOptimizer12v2 = createMultiConnectorOptimizer( + ImmutableList.of(new ConnectorId("cat1"), new ConnectorId("cat2"))); + ConnectorPlanOptimizer multiConnectorOptimizer23 = createCrossConnectorUnionOptimizer( + ImmutableList.of(new ConnectorId("cat2"), new ConnectorId("cat3"))); + + actual = optimize(plan, ImmutableMap.of( + new ConnectorId("cat1"), ImmutableSet.of(multiConnectorOptimizer12v2), + new ConnectorId("cat2"), ImmutableSet.of(multiConnectorOptimizer23), + new ConnectorId("cat3"), ImmutableSet.of(noop()))); + + assertPlanMatch( + actual, + PlanMatchPattern.output( + PlanMatchPattern.union( + 
PlanMatchPattern.union( + PlanMatchPattern.filter( + "true", + SimpleTableScanMatcher.tableScan("cat1", "a", "b")), + PlanMatchPattern.filter( + "true", + SimpleTableScanMatcher.tableScan("cat2", "a", "b"))), + PlanMatchPattern.union( + PlanMatchPattern.filter( + "true", + SimpleTableScanMatcher.tableScan("cat2", "a", "b")), + PlanMatchPattern.filter( + "true", + SimpleTableScanMatcher.tableScan("cat3", "a", "b")))))); + } + private TableScanNode tableScan(String connectorName, String... columnNames) { return PLAN_BUILDER.tableScan( @@ -303,6 +520,42 @@ private static ConnectorPlanOptimizer noop() return (maxSubplan, session, variableAllocator, idAllocator) -> maxSubplan; } + private static ConnectorPlanOptimizer createMultiConnectorOptimizer(java.util.List supportedConnectors) + { + return new ConnectorPlanOptimizer() + { + @Override + public PlanNode optimize(PlanNode maxSubplan, com.facebook.presto.spi.ConnectorSession session, VariableAllocator variableAllocator, PlanNodeIdAllocator idAllocator) + { + return maxSubplan.accept(new TestMultiConnectorOptimizationVisitor(supportedConnectors, idAllocator), null); + } + + @Override + public java.util.List getSupportedConnectorIds() + { + return supportedConnectors; + } + }; + } + + private static ConnectorPlanOptimizer createCrossConnectorUnionOptimizer(java.util.List supportedConnectors) + { + return new ConnectorPlanOptimizer() + { + @Override + public PlanNode optimize(PlanNode maxSubplan, com.facebook.presto.spi.ConnectorSession session, VariableAllocator variableAllocator, PlanNodeIdAllocator idAllocator) + { + return maxSubplan.accept(new TestCrossConnectorUnionVisitor(supportedConnectors, idAllocator), null); + } + + @Override + public java.util.List getSupportedConnectorIds() + { + return supportedConnectors; + } + }; + } + private static class TestPlanOptimizationVisitor extends PlanVisitor { @@ -404,6 +657,88 @@ public PlanNode visitTableScan(TableScanNode node, Void context) } } + /** + * Multi-connector 
visitor that adds filters to table scans from supported connectors + */ + private static class TestMultiConnectorOptimizationVisitor + extends TestPlanOptimizationVisitor + { + private final java.util.List supportedConnectors; + private final PlanNodeIdAllocator idAllocator; + + TestMultiConnectorOptimizationVisitor(java.util.List supportedConnectors, PlanNodeIdAllocator idAllocator) + { + this.supportedConnectors = supportedConnectors; + this.idAllocator = idAllocator; + } + + @Override + public PlanNode visitTableScan(TableScanNode node, Void context) + { + if (supportedConnectors.contains(node.getTable().getConnectorId())) { + return new FilterNode(Optional.empty(), idAllocator.getNextId(), node, TRUE_CONSTANT); + } + return node; + } + } + + /** + * Multi-connector visitor that optimizes unions across different connectors + */ + private static class TestCrossConnectorUnionVisitor + extends TestPlanOptimizationVisitor + { + private final java.util.List supportedConnectors; + private final PlanNodeIdAllocator idAllocator; + + TestCrossConnectorUnionVisitor(java.util.List supportedConnectors, PlanNodeIdAllocator idAllocator) + { + this.supportedConnectors = supportedConnectors; + this.idAllocator = idAllocator; + } + + @Override + public PlanNode visitUnion(UnionNode node, Void context) + { + Set foundConnectors = new java.util.HashSet<>(); + boolean hasMultipleConnectors = false; + + for (PlanNode source : node.getSources()) { + if (source instanceof TableScanNode) { + ConnectorId connectorId = ((TableScanNode) source).getTable().getConnectorId(); + if (supportedConnectors.contains(connectorId)) { + foundConnectors.add(connectorId); + if (foundConnectors.size() > 1) { + hasMultipleConnectors = true; + break; + } + } + } + } + + if (hasMultipleConnectors) { + ImmutableList.Builder newSources = ImmutableList.builder(); + for (PlanNode source : node.getSources()) { + if (source instanceof TableScanNode) { + TableScanNode tableScan = (TableScanNode) source; + if 
(supportedConnectors.contains(tableScan.getTable().getConnectorId())) { + newSources.add(new FilterNode(Optional.empty(), idAllocator.getNextId(), tableScan, TRUE_CONSTANT)); + } + else { + newSources.add(source); + } + } + else { + newSources.add(source.accept(this, context)); + } + } + return node.replaceChildren(newSources.build()); + } + + return super.visitUnion(node, context); + } + } + /** * A simplified table scan matcher for multiple-connector support. * The goal is to test plan structural matching rather than table scan details diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/ConnectorPlanOptimizer.java b/presto-spi/src/main/java/com/facebook/presto/spi/ConnectorPlanOptimizer.java index 1915138d33c03..b7c0114b51d89 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/ConnectorPlanOptimizer.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/ConnectorPlanOptimizer.java @@ -16,6 +16,9 @@ import com.facebook.presto.spi.plan.PlanNode; import com.facebook.presto.spi.plan.PlanNodeIdAllocator; +import java.util.Collections; +import java.util.List; + /** * Given a PlanNode, return a transformed PlanNode. *

@@ -34,4 +37,26 @@ PlanNode optimize(
             ConnectorSession session,
             VariableAllocator variableAllocator,
             PlanNodeIdAllocator idAllocator);
+
+    /**
+     * Returns the list of connector IDs that this optimizer can run on when operating on
+     * subplans that span multiple connectors.
+     * <p>
+     * If this method returns an empty list (the default), the optimizer will only be applied
+     * to subplans that belong exclusively to the connector that registered this optimizer.
+     * <p>
+     * If this method returns a non-empty list, the optimizer will be applied to subplans
+     * that contain table scans from exactly the connectors specified in the returned list.
+     * This allows cross-connector optimizations for federated queries.
+     * <p>
+     * A non-empty list must include the ID of the connector that registers this optimizer
+     * and must not contain duplicates: the engine compares the list with the set of connectors
+     * reachable from a subplan, so a list violating either rule never matches any subplan.
+     *
+     * @return the supported connector IDs; never {@code null}
+     */
+    default List<ConnectorId> getSupportedConnectorIds()
+    {
+        return Collections.emptyList();
+    }
 }