Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -28,17 +28,24 @@
import com.facebook.presto.spi.plan.PlanNodeIdAllocator;
import com.facebook.presto.spi.plan.PlanVisitor;
import com.facebook.presto.spi.plan.TableScanNode;
import com.facebook.presto.spi.relation.ConstantExpression;
import com.facebook.presto.spi.relation.DeterminismEvaluator;
import com.facebook.presto.spi.relation.ExpressionOptimizer;
import com.facebook.presto.spi.relation.RowExpression;
import com.google.common.base.Joiner;
import com.google.common.collect.ImmutableList;

import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.Set;

import static com.facebook.presto.expressions.translator.FunctionTranslator.buildFunctionTranslator;
import static com.facebook.presto.expressions.translator.RowExpressionTreeTranslator.translateWith;
import static com.facebook.presto.plugin.jdbc.optimization.function.JdbcTranslationUtil.mergeSqlBodies;
import static com.facebook.presto.plugin.jdbc.optimization.function.JdbcTranslationUtil.mergeVariableBindings;
import static com.facebook.presto.spi.relation.ExpressionOptimizer.Level.OPTIMIZED;
import static java.lang.String.format;
import static java.util.Objects.requireNonNull;

public class JdbcComputePushdown
Expand Down Expand Up @@ -123,7 +130,9 @@ public PlanNode visitFilter(FilterNode node, Void context)

TableScanNode oldTableScanNode = (TableScanNode) node.getSource();
TableHandle oldTableHandle = oldTableScanNode.getTable();
JdbcTableHandle oldConnectorTable = (JdbcTableHandle) oldTableHandle.getConnectorHandle();
if (!oldTableHandle.getLayout().isPresent()) {
return node;
}

RowExpression predicate = expressionOptimizer.optimize(node.getPredicate(), OPTIMIZED, session);
predicate = logicalRowExpressions.convertToConjunctiveNormalForm(predicate);
Expand All @@ -132,32 +141,70 @@ public PlanNode visitFilter(FilterNode node, Void context)
jdbcFilterToSqlTranslator,
oldTableScanNode.getAssignments());

// TODO if jdbcExpression is not present, walk through translated subtree to find out which parts can be pushed down
if (!oldTableHandle.getLayout().isPresent() || !jdbcExpression.getTranslated().isPresent()) {
Optional<JdbcExpression> translated = jdbcExpression.getTranslated();
JdbcTableHandle oldConnectorTable = (JdbcTableHandle) oldTableHandle.getConnectorHandle();
// All filter can be pushed down
if (translated.isPresent()) {
return createNewTableScanNode(oldTableScanNode, oldTableHandle, oldConnectorTable, translated);
}

// Find out which parts can be pushed down
List<RowExpression> remainingExpressions = new ArrayList<>();
List<JdbcExpression> translatedExpressions = new ArrayList<>();

List<RowExpression> rowExpressions = LogicalRowExpressions.extractConjuncts(predicate);
for (RowExpression expression : rowExpressions) {
TranslatedExpression<JdbcExpression> translatedExpression = translateWith(
expression,
jdbcFilterToSqlTranslator,
oldTableScanNode.getAssignments());

if (!translatedExpression.getTranslated().isPresent()) {
remainingExpressions.add(expression);
}
else {
translatedExpressions.add(translatedExpression.getTranslated().get());
}
}

// no filter can be pushed down
if (!remainingExpressions.isEmpty() && translatedExpressions.isEmpty()) {
return node;
}

List<String> sqlBodies = mergeSqlBodies(translatedExpressions);
List<ConstantExpression> variableBindings = mergeVariableBindings(translatedExpressions);
translated = Optional.of(new JdbcExpression(format("%s", Joiner.on(" AND ").join(sqlBodies)), variableBindings));
TableScanNode newTableScanNode = createNewTableScanNode(oldTableScanNode, oldTableHandle, oldConnectorTable, translated);

return new FilterNode(idAllocator.getNextId(), newTableScanNode, logicalRowExpressions.combineConjuncts(remainingExpressions));
}

private TableScanNode createNewTableScanNode(
TableScanNode oldTableScanNode,
TableHandle oldTableHandle,
JdbcTableHandle oldConnectorTable,
Optional<JdbcExpression> additionalPredicate)
{
JdbcTableLayoutHandle oldTableLayoutHandle = (JdbcTableLayoutHandle) oldTableHandle.getLayout().get();
JdbcTableLayoutHandle newTableLayoutHandle = new JdbcTableLayoutHandle(
oldConnectorTable,
oldTableLayoutHandle.getTupleDomain(),
jdbcExpression.getTranslated());
additionalPredicate);

TableHandle tableHandle = new TableHandle(
oldTableHandle.getConnectorId(),
oldTableHandle.getConnectorHandle(),
oldTableHandle.getTransaction(),
Optional.of(newTableLayoutHandle));

TableScanNode newTableScanNode = new TableScanNode(
return new TableScanNode(
idAllocator.getNextId(),
tableHandle,
oldTableScanNode.getOutputVariables(),
oldTableScanNode.getAssignments(),
oldTableScanNode.getCurrentConstraint(),
oldTableScanNode.getEnforcedConstraint());

return new FilterNode(idAllocator.getNextId(), newTableScanNode, node.getPredicate());
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@
import java.util.Optional;

import static com.facebook.presto.expressions.translator.TranslatedExpression.untranslated;
import static com.facebook.presto.plugin.jdbc.optimization.function.JdbcTranslationUtil.mergeSqlBodies;
import static com.facebook.presto.plugin.jdbc.optimization.function.JdbcTranslationUtil.mergeVariableBindings;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static java.lang.String.format;
import static java.util.Objects.requireNonNull;
Expand Down Expand Up @@ -131,14 +133,8 @@ public TranslatedExpression<JdbcExpression> translateSpecialForm(SpecialFormExpr
return untranslated(specialForm, translatedExpressions);
}

List<String> sqlBodies = jdbcExpressions.stream()
.map(JdbcExpression::getExpression)
.map(sql -> '(' + sql + ')')
.collect(toImmutableList());
List<ConstantExpression> variableBindings = jdbcExpressions.stream()
.map(JdbcExpression::getBoundConstantValues)
.flatMap(List::stream)
.collect(toImmutableList());
List<String> sqlBodies = mergeSqlBodies(jdbcExpressions);
List<ConstantExpression> variableBindings = mergeVariableBindings(jdbcExpressions);

switch (specialForm.getForm()) {
case AND:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,20 @@ public static List<ConstantExpression> forwardBindVariables(JdbcExpression... jd
.flatMap(List::stream)
.collect(toImmutableList());
}

public static List<String> mergeSqlBodies(List<JdbcExpression> jdbcExpressions)
{
return jdbcExpressions.stream()
.map(JdbcExpression::getExpression)
.map(sql -> '(' + sql + ')')
.collect(toImmutableList());
}

public static List<ConstantExpression> mergeVariableBindings(List<JdbcExpression> jdbcExpressions)
{
return jdbcExpressions.stream()
.map(JdbcExpression::getBoundConstantValues)
.flatMap(List::stream)
.collect(toImmutableList());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ public void testJdbcComputePushdownAll()
PlanNode actual = this.jdbcComputePushdown.optimize(original, session, null, ID_ALLOCATOR);
assertPlanMatch(
actual,
PlanMatchPattern.filter(expression, JdbcTableScanMatcher.jdbcTableScanPattern(jdbcTableLayoutHandle, columns)));
JdbcTableScanMatcher.jdbcTableScanPattern(jdbcTableLayoutHandle, columns));
}

@Test
Expand All @@ -151,7 +151,7 @@ public void testJdbcComputePushdownBooleanOperations()
PlanNode actual = this.jdbcComputePushdown.optimize(original, session, null, ID_ALLOCATOR);
assertPlanMatch(
actual,
PlanMatchPattern.filter(expression, JdbcTableScanMatcher.jdbcTableScanPattern(jdbcTableLayoutHandle, columns)));
JdbcTableScanMatcher.jdbcTableScanPattern(jdbcTableLayoutHandle, columns));
}

@Test
Expand Down Expand Up @@ -197,9 +197,7 @@ public void testJdbcComputePushdownWithConstants()

ConnectorSession session = new TestingConnectorSession(ImmutableList.of());
PlanNode actual = this.jdbcComputePushdown.optimize(original, session, null, ID_ALLOCATOR);
assertPlanMatch(actual, PlanMatchPattern.filter(
expression,
JdbcTableScanMatcher.jdbcTableScanPattern(jdbcTableLayoutHandle, columns)));
assertPlanMatch(actual, JdbcTableScanMatcher.jdbcTableScanPattern(jdbcTableLayoutHandle, columns));
}

@Test
Expand All @@ -220,6 +218,82 @@ public void testJdbcComputePushdownNotOperator()
TupleDomain.none(),
Optional.of(new JdbcExpression("(('c1') AND ((NOT('c2'))))")));

ConnectorSession session = new TestingConnectorSession(ImmutableList.of());
PlanNode actual = this.jdbcComputePushdown.optimize(original, session, null, ID_ALLOCATOR);
assertPlanMatch(actual, JdbcTableScanMatcher.jdbcTableScanPattern(jdbcTableLayoutHandle, columns));
}

@Test
public void testJdbcComputePartialPushdown()
{
String table = "test_table";
String schema = "test_schema";

// CAST(c1 AS varchar(1024)) = '123' cannot be pushed down, but c1 + c2 = 3 can.
String expression = "CAST(c1 AS varchar(1024)) = '123' and c1 + c2 = 3";
TypeProvider typeProvider = TypeProvider.copyOf(ImmutableMap.of("c1", BIGINT, "c2", BIGINT));
RowExpression rowExpression = sqlToRowExpressionTranslator.translateAndOptimize(expression(expression), typeProvider);
PlanNode original = filter(jdbcTableScan(schema, table, BIGINT, "c1", "c2"), rowExpression);

Set<ColumnHandle> columns = Stream.of("c1", "c2").map(TestJdbcComputePushdown::booleanJdbcColumnHandle).collect(Collectors.toSet());
JdbcTableHandle jdbcTableHandle = new JdbcTableHandle(CONNECTOR_ID, new SchemaTableName(schema, table), CATALOG_NAME, schema, table);
JdbcTableLayoutHandle jdbcTableLayoutHandle = new JdbcTableLayoutHandle(
jdbcTableHandle,
TupleDomain.none(),
Optional.of(new JdbcExpression("((('c1' + 'c2') = ?))", ImmutableList.of(new ConstantExpression(3L, INTEGER)))));

ConnectorSession session = new TestingConnectorSession(ImmutableList.of());
PlanNode actual = this.jdbcComputePushdown.optimize(original, session, null, ID_ALLOCATOR);
assertPlanMatch(actual, PlanMatchPattern.filter(
"CAST(c1 AS varchar(1024)) = '123'",
JdbcTableScanMatcher.jdbcTableScanPattern(jdbcTableLayoutHandle, columns)));
}

@Test
public void testJdbcComputePartialPushdownWithOrOperator()
{
String table = "test_table";
String schema = "test_schema";

// CAST(c1 AS varchar(1024)) = '123' cannot be pushed down, but c1 + c2 = 3 or c1 <> c2 can.
String expression = "CAST(c1 AS varchar(1024)) = '123' and (c1 + c2 = 3 or c1 <> c2)";
TypeProvider typeProvider = TypeProvider.copyOf(ImmutableMap.of("c1", BIGINT, "c2", BIGINT));
RowExpression rowExpression = sqlToRowExpressionTranslator.translateAndOptimize(expression(expression), typeProvider);
PlanNode original = filter(jdbcTableScan(schema, table, BIGINT, "c1", "c2"), rowExpression);

Set<ColumnHandle> columns = Stream.of("c1", "c2").map(TestJdbcComputePushdown::booleanJdbcColumnHandle).collect(Collectors.toSet());
JdbcTableHandle jdbcTableHandle = new JdbcTableHandle(CONNECTOR_ID, new SchemaTableName(schema, table), CATALOG_NAME, schema, table);
JdbcTableLayoutHandle jdbcTableLayoutHandle = new JdbcTableLayoutHandle(
jdbcTableHandle,
TupleDomain.none(),
Optional.of(new JdbcExpression("((((('c1' + 'c2') = ?)) OR (('c1' <> 'c2'))))", ImmutableList.of(new ConstantExpression(3L, INTEGER)))));

ConnectorSession session = new TestingConnectorSession(ImmutableList.of());
PlanNode actual = this.jdbcComputePushdown.optimize(original, session, null, ID_ALLOCATOR);
assertPlanMatch(actual, PlanMatchPattern.filter(
"CAST(c1 AS varchar(1024)) = '123'",
JdbcTableScanMatcher.jdbcTableScanPattern(jdbcTableLayoutHandle, columns)));
}

@Test
public void testJdbcComputeNoPushdown()
{
String table = "test_table";
String schema = "test_schema";

// no filter can be pushed down
String expression = "CAST(c1 AS varchar(1024)) = '123' and ((c1 - c2) > c2 or c1 <> c2)";
TypeProvider typeProvider = TypeProvider.copyOf(ImmutableMap.of("c1", BIGINT, "c2", BIGINT));
RowExpression rowExpression = sqlToRowExpressionTranslator.translateAndOptimize(expression(expression), typeProvider);
PlanNode original = filter(jdbcTableScan(schema, table, BIGINT, "c1", "c2"), rowExpression);

Set<ColumnHandle> columns = Stream.of("c1", "c2").map(TestJdbcComputePushdown::booleanJdbcColumnHandle).collect(Collectors.toSet());
JdbcTableHandle jdbcTableHandle = new JdbcTableHandle(CONNECTOR_ID, new SchemaTableName(schema, table), CATALOG_NAME, schema, table);
JdbcTableLayoutHandle jdbcTableLayoutHandle = new JdbcTableLayoutHandle(
jdbcTableHandle,
TupleDomain.none(),
Optional.empty());

ConnectorSession session = new TestingConnectorSession(ImmutableList.of());
PlanNode actual = this.jdbcComputePushdown.optimize(original, session, null, ID_ALLOCATOR);
assertPlanMatch(actual, PlanMatchPattern.filter(
Expand Down