diff --git a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/optimization/JdbcComputePushdown.java b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/optimization/JdbcComputePushdown.java index 812cfd8512df9..9f66c9f7b3715 100644 --- a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/optimization/JdbcComputePushdown.java +++ b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/optimization/JdbcComputePushdown.java @@ -28,17 +28,24 @@ import com.facebook.presto.spi.plan.PlanNodeIdAllocator; import com.facebook.presto.spi.plan.PlanVisitor; import com.facebook.presto.spi.plan.TableScanNode; +import com.facebook.presto.spi.relation.ConstantExpression; import com.facebook.presto.spi.relation.DeterminismEvaluator; import com.facebook.presto.spi.relation.ExpressionOptimizer; import com.facebook.presto.spi.relation.RowExpression; +import com.google.common.base.Joiner; import com.google.common.collect.ImmutableList; +import java.util.ArrayList; +import java.util.List; import java.util.Optional; import java.util.Set; import static com.facebook.presto.expressions.translator.FunctionTranslator.buildFunctionTranslator; import static com.facebook.presto.expressions.translator.RowExpressionTreeTranslator.translateWith; +import static com.facebook.presto.plugin.jdbc.optimization.function.JdbcTranslationUtil.mergeSqlBodies; +import static com.facebook.presto.plugin.jdbc.optimization.function.JdbcTranslationUtil.mergeVariableBindings; import static com.facebook.presto.spi.relation.ExpressionOptimizer.Level.OPTIMIZED; +import static java.lang.String.format; import static java.util.Objects.requireNonNull; public class JdbcComputePushdown @@ -123,7 +130,9 @@ public PlanNode visitFilter(FilterNode node, Void context) TableScanNode oldTableScanNode = (TableScanNode) node.getSource(); TableHandle oldTableHandle = oldTableScanNode.getTable(); - JdbcTableHandle oldConnectorTable = (JdbcTableHandle) oldTableHandle.getConnectorHandle(); + if (!oldTableHandle.getLayout().isPresent()) { + return node; + } RowExpression predicate = expressionOptimizer.optimize(node.getPredicate(), OPTIMIZED, session); predicate = logicalRowExpressions.convertToConjunctiveNormalForm(predicate); @@ -132,16 +141,56 @@ public PlanNode visitFilter(FilterNode node, Void context) jdbcFilterToSqlTranslator, oldTableScanNode.getAssignments()); - // TODO if jdbcExpression is not present, walk through translated subtree to find out which parts can be pushed down - if (!oldTableHandle.getLayout().isPresent() || !jdbcExpression.getTranslated().isPresent()) { + Optional translated = jdbcExpression.getTranslated(); + JdbcTableHandle oldConnectorTable = (JdbcTableHandle) oldTableHandle.getConnectorHandle(); + // All filter can be pushed down + if (translated.isPresent()) { + return createNewTableScanNode(oldTableScanNode, oldTableHandle, oldConnectorTable, translated); + } + + // Find out which parts can be pushed down + List remainingExpressions = new ArrayList<>(); + List translatedExpressions = new ArrayList<>(); + + List rowExpressions = LogicalRowExpressions.extractConjuncts(predicate); + for (RowExpression expression : rowExpressions) { + TranslatedExpression translatedExpression = translateWith( + expression, + jdbcFilterToSqlTranslator, + oldTableScanNode.getAssignments()); + + if (!translatedExpression.getTranslated().isPresent()) { + remainingExpressions.add(expression); + } + else { + translatedExpressions.add(translatedExpression.getTranslated().get()); + } + } + + // no filter can be pushed down + if (!remainingExpressions.isEmpty() && translatedExpressions.isEmpty()) { return node; } + List sqlBodies = mergeSqlBodies(translatedExpressions); + List variableBindings = mergeVariableBindings(translatedExpressions); + translated = Optional.of(new JdbcExpression(format("%s", Joiner.on(" AND ").join(sqlBodies)), variableBindings)); + TableScanNode newTableScanNode = createNewTableScanNode(oldTableScanNode, oldTableHandle, oldConnectorTable, translated); + + return new FilterNode(idAllocator.getNextId(), newTableScanNode, logicalRowExpressions.combineConjuncts(remainingExpressions)); + } + + private TableScanNode createNewTableScanNode( + TableScanNode oldTableScanNode, + TableHandle oldTableHandle, + JdbcTableHandle oldConnectorTable, + Optional additionalPredicate) + { JdbcTableLayoutHandle oldTableLayoutHandle = (JdbcTableLayoutHandle) oldTableHandle.getLayout().get(); JdbcTableLayoutHandle newTableLayoutHandle = new JdbcTableLayoutHandle( oldConnectorTable, oldTableLayoutHandle.getTupleDomain(), - jdbcExpression.getTranslated()); + additionalPredicate); TableHandle tableHandle = new TableHandle( oldTableHandle.getConnectorId(), @@ -149,15 +198,13 @@ public PlanNode visitFilter(FilterNode node, Void context) oldTableHandle.getTransaction(), Optional.of(newTableLayoutHandle)); - TableScanNode newTableScanNode = new TableScanNode( + return new TableScanNode( idAllocator.getNextId(), tableHandle, oldTableScanNode.getOutputVariables(), oldTableScanNode.getAssignments(), oldTableScanNode.getCurrentConstraint(), oldTableScanNode.getEnforcedConstraint()); - - return new FilterNode(idAllocator.getNextId(), newTableScanNode, node.getPredicate()); } } } diff --git a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/optimization/JdbcFilterToSqlTranslator.java b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/optimization/JdbcFilterToSqlTranslator.java index 90f0c786cd31b..b8e47412e0672 100644 --- a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/optimization/JdbcFilterToSqlTranslator.java +++ b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/optimization/JdbcFilterToSqlTranslator.java @@ -49,6 +49,8 @@ import java.util.Optional; import static com.facebook.presto.expressions.translator.TranslatedExpression.untranslated; +import static com.facebook.presto.plugin.jdbc.optimization.function.JdbcTranslationUtil.mergeSqlBodies; +import static com.facebook.presto.plugin.jdbc.optimization.function.JdbcTranslationUtil.mergeVariableBindings; import static com.google.common.collect.ImmutableList.toImmutableList; import static java.lang.String.format; import static java.util.Objects.requireNonNull; @@ -131,14 +133,8 @@ public TranslatedExpression translateSpecialForm(SpecialFormExpr return untranslated(specialForm, translatedExpressions); } - List sqlBodies = jdbcExpressions.stream() - .map(JdbcExpression::getExpression) - .map(sql -> '(' + sql + ')') - .collect(toImmutableList()); - List variableBindings = jdbcExpressions.stream() - .map(JdbcExpression::getBoundConstantValues) - .flatMap(List::stream) - .collect(toImmutableList()); + List sqlBodies = mergeSqlBodies(jdbcExpressions); + List variableBindings = mergeVariableBindings(jdbcExpressions); switch (specialForm.getForm()) { case AND: diff --git a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/optimization/function/JdbcTranslationUtil.java b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/optimization/function/JdbcTranslationUtil.java index f0feef55129f8..26931f15ee790 100644 --- a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/optimization/function/JdbcTranslationUtil.java +++ b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/optimization/function/JdbcTranslationUtil.java @@ -38,4 +38,20 @@ public static List forwardBindVariables(JdbcExpression... jd .flatMap(List::stream) .collect(toImmutableList()); } + + public static List mergeSqlBodies(List jdbcExpressions) + { + return jdbcExpressions.stream() + .map(JdbcExpression::getExpression) + .map(sql -> '(' + sql + ')') + .collect(toImmutableList()); + } + + public static List mergeVariableBindings(List jdbcExpressions) + { + return jdbcExpressions.stream() + .map(JdbcExpression::getBoundConstantValues) + .flatMap(List::stream) + .collect(toImmutableList()); + } } diff --git a/presto-base-jdbc/src/test/java/com/facebook/presto/plugin/jdbc/optimization/TestJdbcComputePushdown.java b/presto-base-jdbc/src/test/java/com/facebook/presto/plugin/jdbc/optimization/TestJdbcComputePushdown.java index ae182f4c5ec3b..27438ae8bb4b5 100644 --- a/presto-base-jdbc/src/test/java/com/facebook/presto/plugin/jdbc/optimization/TestJdbcComputePushdown.java +++ b/presto-base-jdbc/src/test/java/com/facebook/presto/plugin/jdbc/optimization/TestJdbcComputePushdown.java @@ -126,7 +126,7 @@ public void testJdbcComputePushdownAll() PlanNode actual = this.jdbcComputePushdown.optimize(original, session, null, ID_ALLOCATOR); assertPlanMatch( actual, - PlanMatchPattern.filter(expression, JdbcTableScanMatcher.jdbcTableScanPattern(jdbcTableLayoutHandle, columns))); + JdbcTableScanMatcher.jdbcTableScanPattern(jdbcTableLayoutHandle, columns)); } @Test @@ -151,7 +151,7 @@ public void testJdbcComputePushdownBooleanOperations() PlanNode actual = this.jdbcComputePushdown.optimize(original, session, null, ID_ALLOCATOR); assertPlanMatch( actual, - PlanMatchPattern.filter(expression, JdbcTableScanMatcher.jdbcTableScanPattern(jdbcTableLayoutHandle, columns))); + JdbcTableScanMatcher.jdbcTableScanPattern(jdbcTableLayoutHandle, columns)); } @Test @@ -197,9 +197,7 @@ public void testJdbcComputePushdownWithConstants() ConnectorSession session = new TestingConnectorSession(ImmutableList.of()); PlanNode actual = this.jdbcComputePushdown.optimize(original, session, null, ID_ALLOCATOR); - assertPlanMatch(actual, PlanMatchPattern.filter( - expression, - JdbcTableScanMatcher.jdbcTableScanPattern(jdbcTableLayoutHandle, columns))); + assertPlanMatch(actual, JdbcTableScanMatcher.jdbcTableScanPattern(jdbcTableLayoutHandle, columns)); } @Test @@ -220,6 +218,82 @@ public void testJdbcComputePushdownNotOperator() TupleDomain.none(), Optional.of(new JdbcExpression("(('c1') AND ((NOT('c2'))))"))); + ConnectorSession session = new TestingConnectorSession(ImmutableList.of()); + PlanNode actual = this.jdbcComputePushdown.optimize(original, session, null, ID_ALLOCATOR); + assertPlanMatch(actual, JdbcTableScanMatcher.jdbcTableScanPattern(jdbcTableLayoutHandle, columns)); + } + + @Test + public void testJdbcComputePartialPushdown() + { + String table = "test_table"; + String schema = "test_schema"; + + // CAST(c1 AS varchar(1024)) = '123' cannot be pushed down, but c1 + c2 = 3 can. + String expression = "CAST(c1 AS varchar(1024)) = '123' and c1 + c2 = 3"; + TypeProvider typeProvider = TypeProvider.copyOf(ImmutableMap.of("c1", BIGINT, "c2", BIGINT)); + RowExpression rowExpression = sqlToRowExpressionTranslator.translateAndOptimize(expression(expression), typeProvider); + PlanNode original = filter(jdbcTableScan(schema, table, BIGINT, "c1", "c2"), rowExpression); + + Set columns = Stream.of("c1", "c2").map(TestJdbcComputePushdown::booleanJdbcColumnHandle).collect(Collectors.toSet()); + JdbcTableHandle jdbcTableHandle = new JdbcTableHandle(CONNECTOR_ID, new SchemaTableName(schema, table), CATALOG_NAME, schema, table); + JdbcTableLayoutHandle jdbcTableLayoutHandle = new JdbcTableLayoutHandle( + jdbcTableHandle, + TupleDomain.none(), + Optional.of(new JdbcExpression("((('c1' + 'c2') = ?))", ImmutableList.of(new ConstantExpression(3L, INTEGER))))); + + ConnectorSession session = new TestingConnectorSession(ImmutableList.of()); + PlanNode actual = this.jdbcComputePushdown.optimize(original, session, null, ID_ALLOCATOR); + assertPlanMatch(actual, PlanMatchPattern.filter( + "CAST(c1 AS varchar(1024)) = '123'", + JdbcTableScanMatcher.jdbcTableScanPattern(jdbcTableLayoutHandle, columns))); + } + + @Test + public void testJdbcComputePartialPushdownWithOrOperator() + { + String table = "test_table"; + String schema = "test_schema"; + + // CAST(c1 AS varchar(1024)) = '123' cannot be pushed down, but c1 + c2 = 3 or c1 <> c2 can. + String expression = "CAST(c1 AS varchar(1024)) = '123' and (c1 + c2 = 3 or c1 <> c2)"; + TypeProvider typeProvider = TypeProvider.copyOf(ImmutableMap.of("c1", BIGINT, "c2", BIGINT)); + RowExpression rowExpression = sqlToRowExpressionTranslator.translateAndOptimize(expression(expression), typeProvider); + PlanNode original = filter(jdbcTableScan(schema, table, BIGINT, "c1", "c2"), rowExpression); + + Set columns = Stream.of("c1", "c2").map(TestJdbcComputePushdown::booleanJdbcColumnHandle).collect(Collectors.toSet()); + JdbcTableHandle jdbcTableHandle = new JdbcTableHandle(CONNECTOR_ID, new SchemaTableName(schema, table), CATALOG_NAME, schema, table); + JdbcTableLayoutHandle jdbcTableLayoutHandle = new JdbcTableLayoutHandle( + jdbcTableHandle, + TupleDomain.none(), + Optional.of(new JdbcExpression("((((('c1' + 'c2') = ?)) OR (('c1' <> 'c2'))))", ImmutableList.of(new ConstantExpression(3L, INTEGER))))); + + ConnectorSession session = new TestingConnectorSession(ImmutableList.of()); + PlanNode actual = this.jdbcComputePushdown.optimize(original, session, null, ID_ALLOCATOR); + assertPlanMatch(actual, PlanMatchPattern.filter( + "CAST(c1 AS varchar(1024)) = '123'", + JdbcTableScanMatcher.jdbcTableScanPattern(jdbcTableLayoutHandle, columns))); + } + + @Test + public void testJdbcComputeNoPushdown() + { + String table = "test_table"; + String schema = "test_schema"; + + // no filter can be pushed down + String expression = "CAST(c1 AS varchar(1024)) = '123' and ((c1 - c2) > c2 or c1 <> c2)"; + TypeProvider typeProvider = TypeProvider.copyOf(ImmutableMap.of("c1", BIGINT, "c2", BIGINT)); + RowExpression rowExpression = sqlToRowExpressionTranslator.translateAndOptimize(expression(expression), typeProvider); + PlanNode original = filter(jdbcTableScan(schema, table, BIGINT, "c1", "c2"), rowExpression); + + Set columns = Stream.of("c1", "c2").map(TestJdbcComputePushdown::booleanJdbcColumnHandle).collect(Collectors.toSet()); + JdbcTableHandle jdbcTableHandle = new JdbcTableHandle(CONNECTOR_ID, new SchemaTableName(schema, table), CATALOG_NAME, schema, table); + JdbcTableLayoutHandle jdbcTableLayoutHandle = new JdbcTableLayoutHandle( + jdbcTableHandle, + TupleDomain.none(), + Optional.empty()); + ConnectorSession session = new TestingConnectorSession(ImmutableList.of()); PlanNode actual = this.jdbcComputePushdown.optimize(original, session, null, ID_ALLOCATOR); assertPlanMatch(actual, PlanMatchPattern.filter(