diff --git a/presto-base-jdbc/pom.xml b/presto-base-jdbc/pom.xml index e09fe83a02697..0ba93bc831cb1 100644 --- a/presto-base-jdbc/pom.xml +++ b/presto-base-jdbc/pom.xml @@ -90,6 +90,7 @@ presto-spi + com.facebook.presto presto-expressions @@ -172,6 +173,12 @@ presto-tests test + + + com.facebook.presto + presto-parser + test + diff --git a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/BaseJdbcClient.java b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/BaseJdbcClient.java index 19b837533eadd..fb855c7a8b4f4 100644 --- a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/BaseJdbcClient.java +++ b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/BaseJdbcClient.java @@ -136,6 +136,12 @@ public void destroy() connectionFactory.close(); } + @Override + public String getIdentifierQuote() + { + return identifierQuote; + } + @Override public final Set getSchemaNames(JdbcIdentity identity) { @@ -265,7 +271,7 @@ public ConnectorSplitSource getSplits(JdbcIdentity identity, JdbcTableLayoutHand tableHandle.getSchemaName(), tableHandle.getTableName(), layoutHandle.getTupleDomain(), - Optional.empty()); + layoutHandle.getAdditionalPredicate()); return new FixedSplitSource(ImmutableList.of(jdbcSplit)); } diff --git a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcClient.java b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcClient.java index 24c54cc16d41d..6c4779e411c8d 100644 --- a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcClient.java +++ b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcClient.java @@ -38,6 +38,8 @@ default boolean schemaExists(JdbcIdentity identity, String schema) return getSchemaNames(identity).contains(schema); } + String getIdentifierQuote(); + Set getSchemaNames(JdbcIdentity identity); List getTableNames(JdbcIdentity identity, Optional schema); diff --git a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcConnector.java b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcConnector.java index 7878e3a0023aa..b14656b4b1b3f 100644 --- a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcConnector.java +++ b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcConnector.java @@ -15,15 +15,20 @@ import com.facebook.airlift.bootstrap.LifeCycleManager; import com.facebook.airlift.log.Logger; +import com.facebook.presto.plugin.jdbc.optimization.JdbcPlanOptimizerProvider; import com.facebook.presto.spi.connector.Connector; import com.facebook.presto.spi.connector.ConnectorAccessControl; import com.facebook.presto.spi.connector.ConnectorCapabilities; import com.facebook.presto.spi.connector.ConnectorMetadata; import com.facebook.presto.spi.connector.ConnectorPageSinkProvider; +import com.facebook.presto.spi.connector.ConnectorPlanOptimizerProvider; import com.facebook.presto.spi.connector.ConnectorRecordSetProvider; import com.facebook.presto.spi.connector.ConnectorSplitManager; import com.facebook.presto.spi.connector.ConnectorTransactionHandle; +import com.facebook.presto.spi.function.FunctionMetadataManager; +import com.facebook.presto.spi.function.StandardFunctionResolution; import com.facebook.presto.spi.procedure.Procedure; +import com.facebook.presto.spi.relation.RowExpressionService; import com.facebook.presto.spi.transaction.IsolationLevel; import com.google.common.collect.ImmutableSet; @@ -55,6 +60,10 @@ public class JdbcConnector private final Set procedures; private final ConcurrentMap transactions = new ConcurrentHashMap<>(); + private final FunctionMetadataManager functionManager; + private final StandardFunctionResolution functionResolution; + private final RowExpressionService rowExpressionService; + private final JdbcClient jdbcClient; @Inject public JdbcConnector( @@ -64,7 +73,11 @@ public JdbcConnector( JdbcRecordSetProvider jdbcRecordSetProvider, JdbcPageSinkProvider jdbcPageSinkProvider, Optional accessControl, - Set procedures) + Set procedures, + FunctionMetadataManager functionManager, + StandardFunctionResolution functionResolution, + RowExpressionService rowExpressionService, + JdbcClient jdbcClient) { this.lifeCycleManager = requireNonNull(lifeCycleManager, "lifeCycleManager is null"); this.jdbcMetadataFactory = requireNonNull(jdbcMetadataFactory, "jdbcMetadataFactory is null"); @@ -73,6 +86,21 @@ public JdbcConnector( this.jdbcPageSinkProvider = requireNonNull(jdbcPageSinkProvider, "jdbcPageSinkProvider is null"); this.accessControl = requireNonNull(accessControl, "accessControl is null"); this.procedures = ImmutableSet.copyOf(requireNonNull(procedures, "procedures is null")); + this.functionManager = requireNonNull(functionManager, "functionManager is null"); + this.functionResolution = requireNonNull(functionResolution, "functionResolution is null"); + this.rowExpressionService = requireNonNull(rowExpressionService, "rowExpressionService is null"); + this.jdbcClient = requireNonNull(jdbcClient, "jdbcClient is null"); + } + + @Override + public ConnectorPlanOptimizerProvider getConnectorPlanOptimizerProvider() + { + return new JdbcPlanOptimizerProvider( + jdbcClient, + functionManager, + functionResolution, + rowExpressionService.getDeterminismEvaluator(), + rowExpressionService.getExpressionOptimizer()); } @Override diff --git a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcConnectorFactory.java b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcConnectorFactory.java index 234702b41ce91..7866025905ef2 100644 --- a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcConnectorFactory.java +++ b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcConnectorFactory.java @@ -19,6 +19,9 @@ import com.facebook.presto.spi.connector.Connector; import com.facebook.presto.spi.connector.ConnectorContext; import com.facebook.presto.spi.connector.ConnectorFactory; +import com.facebook.presto.spi.function.FunctionMetadataManager; +import com.facebook.presto.spi.function.StandardFunctionResolution; +import com.facebook.presto.spi.relation.RowExpressionService; import com.google.inject.Injector; import com.google.inject.Module; @@ -62,7 +65,14 @@ public Connector create(String catalogName, Map requiredConfig, requireNonNull(requiredConfig, "requiredConfig is null"); try (ThreadContextClassLoader ignored = new ThreadContextClassLoader(classLoader)) { - Bootstrap app = new Bootstrap(new JdbcModule(catalogName), module); + Bootstrap app = new Bootstrap( + binder -> { + binder.bind(FunctionMetadataManager.class).toInstance(context.getFunctionMetadataManager()); + binder.bind(StandardFunctionResolution.class).toInstance(context.getStandardFunctionResolution()); + binder.bind(RowExpressionService.class).toInstance(context.getRowExpressionService()); + }, + new JdbcModule(catalogName), + module); Injector injector = app .strictConfig() diff --git a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcMetadata.java b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcMetadata.java index b0e987d3302ae..19929e1054944 100644 --- a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcMetadata.java +++ b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcMetadata.java @@ -85,7 +85,7 @@ public JdbcTableHandle getTableHandle(ConnectorSession session, SchemaTableName public List getTableLayouts(ConnectorSession session, ConnectorTableHandle table, Constraint constraint, Optional> desiredColumns) { JdbcTableHandle tableHandle = (JdbcTableHandle) table; - ConnectorTableLayout layout = new ConnectorTableLayout(new JdbcTableLayoutHandle(tableHandle, constraint.getSummary())); + ConnectorTableLayout layout = new ConnectorTableLayout(new JdbcTableLayoutHandle(tableHandle, constraint.getSummary(), Optional.empty())); return ImmutableList.of(new ConnectorTableLayoutResult(layout, constraint.getSummary())); } diff --git a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcSplit.java b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcSplit.java index 912ca7f39e11e..f7a1ed702aab3 100644 --- a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcSplit.java +++ b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcSplit.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.plugin.jdbc; +import com.facebook.presto.plugin.jdbc.optimization.JdbcExpression; import com.facebook.presto.spi.ColumnHandle; import com.facebook.presto.spi.ConnectorSplit; import com.facebook.presto.spi.HostAddress; @@ -36,7 +37,7 @@ public class JdbcSplit private final String schemaName; private final String tableName; private final TupleDomain tupleDomain; - private final Optional additionalPredicate; + private final Optional additionalPredicate; @JsonCreator public JdbcSplit( @@ -45,7 +46,7 @@ public JdbcSplit( @JsonProperty("schemaName") @Nullable String schemaName, @JsonProperty("tableName") String tableName, @JsonProperty("tupleDomain") TupleDomain tupleDomain, - @JsonProperty("additionalProperty") Optional additionalPredicate) + @JsonProperty("additionalProperty") Optional additionalPredicate) { this.connectorId = requireNonNull(connectorId, "connector id is null"); this.catalogName = catalogName; @@ -88,7 +89,7 @@ public TupleDomain getTupleDomain() } @JsonProperty - public Optional getAdditionalPredicate() + public Optional getAdditionalPredicate() { return additionalPredicate; } diff --git a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcTableLayoutHandle.java b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcTableLayoutHandle.java index e1c5a9a4446b2..0bea87df5c37c 100644 --- a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcTableLayoutHandle.java +++ b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcTableLayoutHandle.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.plugin.jdbc; +import com.facebook.presto.plugin.jdbc.optimization.JdbcExpression; import com.facebook.presto.spi.ColumnHandle; import com.facebook.presto.spi.ConnectorTableLayoutHandle; import com.facebook.presto.spi.predicate.TupleDomain; @@ -20,6 +21,7 @@ import com.fasterxml.jackson.annotation.JsonProperty; import java.util.Objects; +import java.util.Optional; import static java.util.Objects.requireNonNull; @@ -28,14 +30,23 @@ public class JdbcTableLayoutHandle { private final JdbcTableHandle table; private final TupleDomain tupleDomain; + private final Optional additionalPredicate; @JsonCreator public JdbcTableLayoutHandle( @JsonProperty("table") JdbcTableHandle table, - @JsonProperty("tupleDomain") TupleDomain domain) + @JsonProperty("tupleDomain") TupleDomain domain, + @JsonProperty("additionalPredicate") Optional additionalPredicate) { this.table = requireNonNull(table, "table is null"); this.tupleDomain = requireNonNull(domain, "tupleDomain is null"); + this.additionalPredicate = additionalPredicate; + } + + @JsonProperty + public Optional getAdditionalPredicate() + { + return additionalPredicate; } @JsonProperty @@ -61,13 +72,14 @@ public boolean equals(Object o) } JdbcTableLayoutHandle that = (JdbcTableLayoutHandle) o; return Objects.equals(table, that.table) && - Objects.equals(tupleDomain, that.tupleDomain); + Objects.equals(tupleDomain, that.tupleDomain) && + Objects.equals(additionalPredicate, that.additionalPredicate); } @Override public int hashCode() { - return Objects.hash(table, tupleDomain); + return Objects.hash(table, tupleDomain, additionalPredicate); } @Override diff --git a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/QueryBuilder.java b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/QueryBuilder.java index 8de9f4265e42e..fa27d8c4a1fae 100644 --- a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/QueryBuilder.java +++ b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/QueryBuilder.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.plugin.jdbc; +import com.facebook.presto.plugin.jdbc.optimization.JdbcExpression; import com.facebook.presto.spi.ColumnHandle; import com.facebook.presto.spi.predicate.Domain; import com.facebook.presto.spi.predicate.Range; @@ -102,7 +103,7 @@ public PreparedStatement buildSql( String table, List columns, TupleDomain tupleDomain, - Optional additionalPredicate) + Optional additionalPredicate) throws SQLException { StringBuilder sql = new StringBuilder(); @@ -133,8 +134,11 @@ public PreparedStatement buildSql( if (additionalPredicate.isPresent()) { clauses = ImmutableList.builder() .addAll(clauses) - .add(additionalPredicate.get()) + .add(additionalPredicate.get().getExpression()) .build(); + accumulator.addAll(additionalPredicate.get().getBoundConstantValues().stream() + .map(constantExpression -> new TypeAndValue(constantExpression.getType(), constantExpression.getValue())) + .collect(ImmutableList.toImmutableList())); } if (!clauses.isEmpty()) { sql.append(" WHERE ") diff --git a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/optimization/JdbcComputePushdown.java b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/optimization/JdbcComputePushdown.java index fec957589b265..812cfd8512df9 100644 --- a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/optimization/JdbcComputePushdown.java +++ b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/optimization/JdbcComputePushdown.java @@ -13,6 +13,8 @@ */ package com.facebook.presto.plugin.jdbc.optimization; +import com.facebook.presto.expressions.LogicalRowExpressions; +import com.facebook.presto.expressions.translator.TranslatedExpression; import com.facebook.presto.plugin.jdbc.JdbcTableHandle; import com.facebook.presto.plugin.jdbc.JdbcTableLayoutHandle; import com.facebook.presto.spi.ConnectorPlanOptimizer; @@ -28,24 +30,47 @@ import com.facebook.presto.spi.plan.TableScanNode; import com.facebook.presto.spi.relation.DeterminismEvaluator; import com.facebook.presto.spi.relation.ExpressionOptimizer; +import com.facebook.presto.spi.relation.RowExpression; import com.google.common.collect.ImmutableList; import java.util.Optional; +import java.util.Set; +import static com.facebook.presto.expressions.translator.FunctionTranslator.buildFunctionTranslator; +import static com.facebook.presto.expressions.translator.RowExpressionTreeTranslator.translateWith; +import static com.facebook.presto.spi.relation.ExpressionOptimizer.Level.OPTIMIZED; import static java.util.Objects.requireNonNull; public class JdbcComputePushdown implements ConnectorPlanOptimizer { private final ExpressionOptimizer expressionOptimizer; + private final JdbcFilterToSqlTranslator jdbcFilterToSqlTranslator; + private final LogicalRowExpressions logicalRowExpressions; public JdbcComputePushdown( FunctionMetadataManager functionMetadataManager, StandardFunctionResolution functionResolution, DeterminismEvaluator determinismEvaluator, - ExpressionOptimizer expressionOptimizer) + ExpressionOptimizer expressionOptimizer, + String identifierQuote, + Set> functionTranslators) { - this.expressionOptimizer = expressionOptimizer; + requireNonNull(functionMetadataManager, "functionMetadataManager is null"); + requireNonNull(identifierQuote, "identifierQuote is null"); + requireNonNull(functionTranslators, "functionTranslators is null"); + requireNonNull(determinismEvaluator, "determinismEvaluator is null"); + requireNonNull(functionResolution, "functionResolution is null"); + + this.expressionOptimizer = requireNonNull(expressionOptimizer, "expressionOptimizer is null"); + this.jdbcFilterToSqlTranslator = new JdbcFilterToSqlTranslator( + functionMetadataManager, + buildFunctionTranslator(functionTranslators), + identifierQuote); + this.logicalRowExpressions = new LogicalRowExpressions( + determinismEvaluator, + functionResolution, + functionMetadataManager); } @Override @@ -100,18 +125,23 @@ public PlanNode visitFilter(FilterNode node, Void context) TableHandle oldTableHandle = oldTableScanNode.getTable(); JdbcTableHandle oldConnectorTable = (JdbcTableHandle) oldTableHandle.getConnectorHandle(); - // TODO: remove dependency on oldTableLayoutHandle, currently it needs oldTableLayoutHandle to get predicate - if (!oldTableHandle.getLayout().isPresent()) { + RowExpression predicate = expressionOptimizer.optimize(node.getPredicate(), OPTIMIZED, session); + predicate = logicalRowExpressions.convertToConjunctiveNormalForm(predicate); + TranslatedExpression jdbcExpression = translateWith( + predicate, + jdbcFilterToSqlTranslator, + oldTableScanNode.getAssignments()); + + // TODO if jdbcExpression is not present, walk through translated subtree to find out which parts can be pushed down + if (!oldTableHandle.getLayout().isPresent() || !jdbcExpression.getTranslated().isPresent()) { return node; } - // TODO: FilterRowExpression is currently mocked, needs to be implemented - JdbcTableLayoutHandle oldTableLayoutHandle = (JdbcTableLayoutHandle) oldTableHandle.getLayout().get(); - // TODO: add pushdownResult to new TableLayoutHandle JdbcTableLayoutHandle newTableLayoutHandle = new JdbcTableLayoutHandle( oldConnectorTable, - oldTableLayoutHandle.getTupleDomain()); + oldTableLayoutHandle.getTupleDomain(), + jdbcExpression.getTranslated()); TableHandle tableHandle = new TableHandle( oldTableHandle.getConnectorId(), diff --git a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/optimization/JdbcExpression.java b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/optimization/JdbcExpression.java new file mode 100644 index 0000000000000..80618e0fa18b7 --- /dev/null +++ b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/optimization/JdbcExpression.java @@ -0,0 +1,83 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.plugin.jdbc.optimization; + +import com.facebook.presto.spi.relation.ConstantExpression; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.common.collect.ImmutableList; + +import java.util.List; +import java.util.Objects; + +import static java.util.Objects.requireNonNull; + +public class JdbcExpression +{ + private final String expression; + private final List boundConstantValues; + + public JdbcExpression(String expression) + { + this(expression, ImmutableList.of()); + } + + @JsonCreator + public JdbcExpression( + @JsonProperty("translatedString") String expression, + @JsonProperty("boundConstantValues") List constantBindValues) + { + this.expression = requireNonNull(expression, "expression is null"); + this.boundConstantValues = requireNonNull(constantBindValues, "boundConstantValues is null"); + } + + @JsonProperty + public String getExpression() + { + return expression; + } + + /** + * Constant expressions are not added to the expression String. Instead they appear as "?" in the query. + * This is because we would potentially lose precision on double values. Hence when we make a PreparedStatement + * out of the SQL string replacing every "?" by it's corresponding actual bindValue. + * + * @return List of constants to replace in the SQL string. + */ + @JsonProperty + public List getBoundConstantValues() + { + return boundConstantValues; + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + JdbcExpression that = (JdbcExpression) o; + return expression.equals(that.expression) && + boundConstantValues.equals(that.boundConstantValues); + } + + @Override + public int hashCode() + { + return Objects.hash(expression, boundConstantValues); + } +} diff --git a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/optimization/JdbcFilterToSqlTranslator.java b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/optimization/JdbcFilterToSqlTranslator.java new file mode 100644 index 0000000000000..3be5bd98e5aea --- /dev/null +++ b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/optimization/JdbcFilterToSqlTranslator.java @@ -0,0 +1,181 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.plugin.jdbc.optimization; + +import com.facebook.presto.expressions.translator.FunctionTranslator; +import com.facebook.presto.expressions.translator.RowExpressionTranslator; +import com.facebook.presto.expressions.translator.RowExpressionTreeTranslator; +import com.facebook.presto.expressions.translator.TranslatedExpression; +import com.facebook.presto.plugin.jdbc.JdbcColumnHandle; +import com.facebook.presto.spi.ColumnHandle; +import com.facebook.presto.spi.function.FunctionMetadata; +import com.facebook.presto.spi.function.FunctionMetadataManager; +import com.facebook.presto.spi.relation.CallExpression; +import com.facebook.presto.spi.relation.ConstantExpression; +import com.facebook.presto.spi.relation.LambdaDefinitionExpression; +import com.facebook.presto.spi.relation.SpecialFormExpression; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.spi.type.BigintType; +import com.facebook.presto.spi.type.BooleanType; +import com.facebook.presto.spi.type.CharType; +import com.facebook.presto.spi.type.DateType; +import com.facebook.presto.spi.type.DoubleType; +import com.facebook.presto.spi.type.IntegerType; +import com.facebook.presto.spi.type.RealType; +import com.facebook.presto.spi.type.SmallintType; +import com.facebook.presto.spi.type.TimeType; +import com.facebook.presto.spi.type.TimeWithTimeZoneType; +import com.facebook.presto.spi.type.TimestampType; +import com.facebook.presto.spi.type.TimestampWithTimeZoneType; +import com.facebook.presto.spi.type.TinyintType; +import com.facebook.presto.spi.type.Type; +import com.facebook.presto.spi.type.VarcharType; +import com.google.common.base.Joiner; +import com.google.common.collect.ImmutableList; + +import java.util.List; +import java.util.Map; +import java.util.Optional; + +import static com.facebook.presto.expressions.translator.TranslatedExpression.untranslated; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static java.lang.String.format; +import static java.util.Objects.requireNonNull; + +public class JdbcFilterToSqlTranslator + extends RowExpressionTranslator> +{ + private final FunctionMetadataManager functionMetadataManager; + private final FunctionTranslator functionTranslator; + private final String quote; + + public JdbcFilterToSqlTranslator(FunctionMetadataManager functionMetadataManager, FunctionTranslator functionTranslator, String quote) + { + this.functionMetadataManager = requireNonNull(functionMetadataManager, "functionMetadataManager is null"); + this.functionTranslator = requireNonNull(functionTranslator, "functionTranslator is null"); + this.quote = requireNonNull(quote, "quote is null"); + } + + @Override + public TranslatedExpression translateConstant(ConstantExpression literal, Map context, RowExpressionTreeTranslator> rowExpressionTreeTranslator) + { + if (isSupportedType(literal.getType())) { + return new TranslatedExpression<>( + Optional.of(new JdbcExpression("?", ImmutableList.of(literal))), + literal, + ImmutableList.of()); + } + return untranslated(literal); + } + + @Override + public TranslatedExpression translateVariable(VariableReferenceExpression variable, Map context, RowExpressionTreeTranslator> rowExpressionTreeTranslator) + { + JdbcColumnHandle columnHandle = (JdbcColumnHandle) context.get(variable); + requireNonNull(columnHandle, format("Unrecognized variable %s", variable)); + return new TranslatedExpression<>( + Optional.of(new JdbcExpression(quote + columnHandle.getColumnName().replace(quote, quote + quote) + quote)), + variable, + ImmutableList.of()); + } + + @Override + public TranslatedExpression translateLambda(LambdaDefinitionExpression lambda, Map context, RowExpressionTreeTranslator> rowExpressionTreeTranslator) + { + return untranslated(lambda); + } + + @Override + public TranslatedExpression translateCall(CallExpression call, Map context, RowExpressionTreeTranslator> rowExpressionTreeTranslator) + { + List> translatedExpressions = call.getArguments().stream() + .map(expression -> rowExpressionTreeTranslator.rewrite(expression, context)) + .collect(toImmutableList()); + + FunctionMetadata functionMetadata = functionMetadataManager.getFunctionMetadata(call.getFunctionHandle()); + + try { + return functionTranslator.translate(functionMetadata, call, translatedExpressions); + } + catch (Throwable t) { + // no-op + } + return untranslated(call, translatedExpressions); + } + + @Override + public TranslatedExpression translateSpecialForm(SpecialFormExpression specialForm, Map context, RowExpressionTreeTranslator> rowExpressionTreeTranslator) + { + List> translatedExpressions = specialForm.getArguments().stream() + .map(expression -> rowExpressionTreeTranslator.rewrite(expression, context)) + .collect(toImmutableList()); + + List jdbcExpressions = translatedExpressions.stream() + .map(TranslatedExpression::getTranslated) + .filter(Optional::isPresent) + .map(Optional::get) + .collect(toImmutableList()); + + if (jdbcExpressions.size() < translatedExpressions.size()) { + return untranslated(specialForm, translatedExpressions); + } + + List sqlBodies = jdbcExpressions.stream() + .map(JdbcExpression::getExpression) + .map(sql -> '(' + sql + ')') + .collect(toImmutableList()); + List variableBindings = jdbcExpressions.stream() + .map(JdbcExpression::getBoundConstantValues) + .flatMap(List::stream) + .collect(toImmutableList()); + + switch (specialForm.getForm()) { + case AND: + return new TranslatedExpression<>( + Optional.of(new JdbcExpression(format("(%s)", Joiner.on(" AND ").join(sqlBodies)), variableBindings)), + specialForm, + translatedExpressions); + case OR: + return new TranslatedExpression<>( + Optional.of(new JdbcExpression(format("(%s)", Joiner.on(" OR ").join(sqlBodies)), variableBindings)), + specialForm, + translatedExpressions); + case IN: + return new TranslatedExpression<>( + Optional.of(new JdbcExpression(format("(%s IN (%s))", sqlBodies.get(0), Joiner.on(" , ").join(sqlBodies.subList(1, sqlBodies.size()))), variableBindings)), + specialForm, + translatedExpressions); + } + return untranslated(specialForm, translatedExpressions); + } + + private static boolean isSupportedType(Type type) + { + Type validType = requireNonNull(type, "type is null"); + return validType.equals(BigintType.BIGINT) || + validType.equals(TinyintType.TINYINT) || + validType.equals(SmallintType.SMALLINT) || + validType.equals(IntegerType.INTEGER) || + validType.equals(DoubleType.DOUBLE) || + validType.equals(RealType.REAL) || + validType.equals(BooleanType.BOOLEAN) || + validType.equals(DateType.DATE) || + validType.equals(TimeType.TIME) || + validType.equals(TimeWithTimeZoneType.TIME_WITH_TIME_ZONE) || + validType.equals(TimestampType.TIMESTAMP) || + validType.equals(TimestampWithTimeZoneType.TIMESTAMP_WITH_TIME_ZONE) || + validType instanceof VarcharType || + validType instanceof CharType; + } +} diff --git a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/optimization/JdbcPlanOptimizerProvider.java b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/optimization/JdbcPlanOptimizerProvider.java index af7c76867f952..94042314c7922 100644 --- a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/optimization/JdbcPlanOptimizerProvider.java +++ b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/optimization/JdbcPlanOptimizerProvider.java @@ -13,6 +13,8 @@ */ package com.facebook.presto.plugin.jdbc.optimization; +import com.facebook.presto.plugin.jdbc.JdbcClient; +import com.facebook.presto.plugin.jdbc.optimization.function.OperatorTranslators; import com.facebook.presto.spi.ConnectorPlanOptimizer; import com.facebook.presto.spi.connector.ConnectorPlanOptimizerProvider; import com.facebook.presto.spi.function.FunctionMetadataManager; @@ -20,6 +22,7 @@ import com.facebook.presto.spi.relation.DeterminismEvaluator; import com.facebook.presto.spi.relation.ExpressionOptimizer; import com.google.common.collect.ImmutableSet; +import com.google.inject.Inject; import java.util.Set; @@ -32,18 +35,21 @@ public class JdbcPlanOptimizerProvider private final StandardFunctionResolution functionResolution; private final DeterminismEvaluator determinismEvaluator; private final ExpressionOptimizer expressionOptimizer; + private final String identifierQuote; + @Inject public JdbcPlanOptimizerProvider( + JdbcClient jdbcClient, FunctionMetadataManager functionManager, StandardFunctionResolution functionResolution, DeterminismEvaluator determinismEvaluator, ExpressionOptimizer expressionOptimizer) { - // TODO: Override getConnectorPlanOptimizer in JdbcConnector and add JdbcPlanOptimizer to it this.functionManager = requireNonNull(functionManager, "functionManager is null"); this.functionResolution = requireNonNull(functionResolution, "functionResolution is null"); this.determinismEvaluator = requireNonNull(determinismEvaluator, "determinismEvaluator is null"); this.expressionOptimizer = requireNonNull(expressionOptimizer, "expressionOptimizer is null"); + this.identifierQuote = jdbcClient.getIdentifierQuote(); } @Override @@ -53,6 +59,13 @@ public Set getConnectorPlanOptimizers() functionManager, functionResolution, determinismEvaluator, - expressionOptimizer)); + expressionOptimizer, + identifierQuote, + getFunctionTranslators())); + } + + private Set> getFunctionTranslators() + { + return ImmutableSet.of(OperatorTranslators.class); } } diff --git a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/optimization/function/JdbcTranslationUtil.java b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/optimization/function/JdbcTranslationUtil.java new file mode 100644 index 0000000000000..f0feef55129f8 --- /dev/null +++ b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/optimization/function/JdbcTranslationUtil.java @@ -0,0 +1,41 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.plugin.jdbc.optimization.function; + +import com.facebook.presto.plugin.jdbc.optimization.JdbcExpression; +import com.facebook.presto.spi.relation.ConstantExpression; + +import java.util.Arrays; +import java.util.List; + +import static com.google.common.collect.ImmutableList.toImmutableList; + +public class JdbcTranslationUtil +{ + private JdbcTranslationUtil() + { + } + + public static String infixOperation(String operator, JdbcExpression left, JdbcExpression right) + { + return String.format("(%s %s %s)", left.getExpression(), operator, right.getExpression()); + } + + public static List forwardBindVariables(JdbcExpression... jdbcExpressions) + { + return Arrays.stream(jdbcExpressions).map(JdbcExpression::getBoundConstantValues) + .flatMap(List::stream) + .collect(toImmutableList()); + } +} diff --git a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/optimization/function/OperatorTranslators.java b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/optimization/function/OperatorTranslators.java new file mode 100644 index 0000000000000..ff47b33767df5 --- /dev/null +++ b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/optimization/function/OperatorTranslators.java @@ -0,0 +1,69 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.plugin.jdbc.optimization.function; + +import com.facebook.presto.plugin.jdbc.optimization.JdbcExpression; +import com.facebook.presto.spi.function.ScalarFunction; +import com.facebook.presto.spi.function.ScalarOperator; +import com.facebook.presto.spi.function.SqlType; +import com.facebook.presto.spi.type.StandardTypes; + +import static com.facebook.presto.plugin.jdbc.optimization.function.JdbcTranslationUtil.forwardBindVariables; +import static com.facebook.presto.plugin.jdbc.optimization.function.JdbcTranslationUtil.infixOperation; +import static com.facebook.presto.spi.function.OperatorType.ADD; +import static com.facebook.presto.spi.function.OperatorType.EQUAL; +import static com.facebook.presto.spi.function.OperatorType.NOT_EQUAL; +import static com.facebook.presto.spi.function.OperatorType.SUBTRACT; + +public class OperatorTranslators +{ + private OperatorTranslators() + { + } + + @ScalarOperator(ADD) + @SqlType(StandardTypes.BIGINT) + public static JdbcExpression add(@SqlType(StandardTypes.BIGINT) JdbcExpression left, @SqlType(StandardTypes.BIGINT) JdbcExpression right) + { + return new JdbcExpression(infixOperation("+", left, right), forwardBindVariables(left, right)); + } + + @ScalarOperator(SUBTRACT) + @SqlType(StandardTypes.BIGINT) + public static JdbcExpression subtract(@SqlType(StandardTypes.BIGINT) JdbcExpression left, @SqlType(StandardTypes.BIGINT) JdbcExpression right) + { + return new JdbcExpression(infixOperation("-", left, right), forwardBindVariables(left, right)); + } + + @ScalarOperator(EQUAL) + @SqlType(StandardTypes.BOOLEAN) + public static JdbcExpression equal(@SqlType(StandardTypes.BIGINT) JdbcExpression left, @SqlType(StandardTypes.BIGINT) JdbcExpression right) + { + return new JdbcExpression(infixOperation("=", left, right), forwardBindVariables(left, right)); + } + + @ScalarOperator(NOT_EQUAL) + @SqlType(StandardTypes.BOOLEAN) + public static JdbcExpression notEqual(@SqlType(StandardTypes.BIGINT) JdbcExpression left, @SqlType(StandardTypes.BIGINT) JdbcExpression right) + { + return new JdbcExpression(infixOperation("<>", left, right), forwardBindVariables(left, right)); + } + + @ScalarFunction("not") + @SqlType(StandardTypes.BOOLEAN) + public static JdbcExpression not(@SqlType(StandardTypes.BOOLEAN) JdbcExpression expression) + { + return new JdbcExpression(String.format("(NOT(%s))", expression.getExpression()), expression.getBoundConstantValues()); + } +} diff --git a/presto-base-jdbc/src/test/java/com/facebook/presto/plugin/jdbc/TestJdbcIntegrationSmokeTest.java b/presto-base-jdbc/src/test/java/com/facebook/presto/plugin/jdbc/TestJdbcIntegrationSmokeTest.java index d0af6a8310fbf..df45f5de91445 100644 --- a/presto-base-jdbc/src/test/java/com/facebook/presto/plugin/jdbc/TestJdbcIntegrationSmokeTest.java +++ b/presto-base-jdbc/src/test/java/com/facebook/presto/plugin/jdbc/TestJdbcIntegrationSmokeTest.java @@ -14,6 +14,7 @@ package com.facebook.presto.plugin.jdbc; import com.facebook.presto.tests.AbstractTestIntegrationSmokeTest; +import org.testng.annotations.Test; import static com.facebook.presto.plugin.jdbc.JdbcQueryRunner.createJdbcQueryRunner; import static io.airlift.tpch.TpchTable.ORDERS; @@ -25,4 +26,47 @@ public TestJdbcIntegrationSmokeTest() { super(() -> createJdbcQueryRunner(ORDERS)); } + + @Test + public void testPlanOptimizerFilterPushdownCorrectness() + { + assertUpdate("CREATE TABLE test_pushdown1(col1 bigint, col2 bigint)"); + assertUpdate("INSERT INTO test_pushdown1 VALUES (5, 4), (2, 4), (4, 4)", "VALUES(3)"); + // PickTableLayout#pushPredicateIntoTableScan cannot push down this predicate however, JdbcPlanOptimizer can. + assertQuery("SELECT * FROM test_pushdown1 WHERE col1 + col2 = 8", "VALUES(4, 4)"); + } + + @Test + public void testPlanOptimizerAndDefaultFilterPushdownCorrectness() + { + assertUpdate("CREATE TABLE test_pushdown2(col1 bigint, col2 bigint)"); + assertUpdate("INSERT INTO test_pushdown2 VALUES (5, 4), (2, 4), (4, 4)", "VALUES(3)"); + // PickTableLayout#pushPredicateIntoTableScan and JdbcPlanOptimizer are both capable of pushing down these predicates. + assertQuery("SELECT * FROM test_pushdown2 WHERE col1 = 5 AND col2 = 4", "VALUES(5, 4)"); + } + + @Test + public void testIncompatibleFilterPushdownCorrectness() + { + assertUpdate("CREATE TABLE test_pushdown3(col1 bigint, col2 bigint)"); + assertUpdate("INSERT INTO test_pushdown3 VALUES (5, 4), (2, 4), (4, 4)", "VALUES(3)"); + // Neither PickTableLayout#pushPredicateIntoTableScan or JdbcPlanOptimizer are capable of pushing down these predicates. + assertQuery("SELECT * FROM test_pushdown3 WHERE NOT(col1 + col2 <= 6)", "VALUES (5, 4), (4, 4)"); + } + + @Test + public void testArithmetic() + { + assertUpdate("CREATE TABLE test_arithmetic(col1 bigint, col2 bigint)"); + assertUpdate("INSERT INTO test_arithmetic VALUES (5, 4), (2, 4), (4, 4)", "VALUES(3)"); + assertQuery("SELECT * FROM test_arithmetic WHERE NOT(col1 + col2 - 6 = 0)", "VALUES (5, 4), (4, 4)"); + } + + @Test + public void testBooleanNot() + { + assertUpdate("CREATE TABLE test_not(col1 bigint, col2 boolean)"); + assertUpdate("INSERT INTO test_not VALUES (5, true), (2, true), (4, false)", "VALUES(3)"); + assertQuery("SELECT col1 FROM test_not WHERE NOT(col2)", "VALUES(4)"); + } } diff --git a/presto-base-jdbc/src/test/java/com/facebook/presto/plugin/jdbc/TestJdbcRecordSetProvider.java b/presto-base-jdbc/src/test/java/com/facebook/presto/plugin/jdbc/TestJdbcRecordSetProvider.java index 85afd97a2b107..3e13a96ec4457 100644 --- a/presto-base-jdbc/src/test/java/com/facebook/presto/plugin/jdbc/TestJdbcRecordSetProvider.java +++ b/presto-base-jdbc/src/test/java/com/facebook/presto/plugin/jdbc/TestJdbcRecordSetProvider.java @@ -32,6 +32,7 @@ import java.util.LinkedHashMap; import java.util.List; import java.util.Map; +import java.util.Optional; import static com.facebook.airlift.concurrent.MoreFutures.getFutureValue; import static com.facebook.presto.spi.connector.NotPartitionedPartitionHandle.NOT_PARTITIONED; @@ -180,7 +181,7 @@ public void testTupleDomain() private RecordCursor getCursor(JdbcTableHandle jdbcTableHandle, List columns, TupleDomain domain) { - JdbcTableLayoutHandle layoutHandle = new JdbcTableLayoutHandle(jdbcTableHandle, domain); + JdbcTableLayoutHandle layoutHandle = new JdbcTableLayoutHandle(jdbcTableHandle, domain, Optional.empty()); ConnectorSplitSource splits = jdbcClient.getSplits(IDENTITY, layoutHandle); JdbcSplit split = (JdbcSplit) getOnlyElement(getFutureValue(splits.getNextBatch(NOT_PARTITIONED, 1000)).getSplits()); diff --git a/presto-base-jdbc/src/test/java/com/facebook/presto/plugin/jdbc/TestingDatabase.java b/presto-base-jdbc/src/test/java/com/facebook/presto/plugin/jdbc/TestingDatabase.java index cf3f470863419..e18a101177e91 100644 --- a/presto-base-jdbc/src/test/java/com/facebook/presto/plugin/jdbc/TestingDatabase.java +++ b/presto-base-jdbc/src/test/java/com/facebook/presto/plugin/jdbc/TestingDatabase.java @@ -25,6 +25,7 @@ import java.sql.SQLException; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.Properties; import static com.facebook.airlift.concurrent.MoreFutures.getFutureValue; @@ -99,7 +100,7 @@ public JdbcSplit getSplit(String schemaName, String tableName) { JdbcIdentity identity = JdbcIdentity.from(session); JdbcTableHandle jdbcTableHandle = jdbcClient.getTableHandle(identity, new SchemaTableName(schemaName, tableName)); - JdbcTableLayoutHandle jdbcLayoutHandle = new JdbcTableLayoutHandle(jdbcTableHandle, TupleDomain.all()); + JdbcTableLayoutHandle jdbcLayoutHandle = new JdbcTableLayoutHandle(jdbcTableHandle, TupleDomain.all(), Optional.empty()); ConnectorSplitSource splits = jdbcClient.getSplits(identity, jdbcLayoutHandle); return (JdbcSplit) getOnlyElement(getFutureValue(splits.getNextBatch(NOT_PARTITIONED, 1000)).getSplits()); } diff --git a/presto-base-jdbc/src/test/java/com/facebook/presto/plugin/jdbc/optimization/TestJdbcComputePushdown.java b/presto-base-jdbc/src/test/java/com/facebook/presto/plugin/jdbc/optimization/TestJdbcComputePushdown.java index 4f0b449123589..4c5966175e096 100644 --- a/presto-base-jdbc/src/test/java/com/facebook/presto/plugin/jdbc/optimization/TestJdbcComputePushdown.java +++ b/presto-base-jdbc/src/test/java/com/facebook/presto/plugin/jdbc/optimization/TestJdbcComputePushdown.java @@ -17,23 +17,32 @@ import com.facebook.presto.cost.PlanNodeStatsEstimate; import com.facebook.presto.cost.StatsAndCosts; import com.facebook.presto.cost.StatsProvider; +import com.facebook.presto.metadata.FunctionManager; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.metadata.MetadataManager; +import com.facebook.presto.plugin.jdbc.JdbcColumnHandle; import com.facebook.presto.plugin.jdbc.JdbcTableHandle; import com.facebook.presto.plugin.jdbc.JdbcTableLayoutHandle; +import com.facebook.presto.plugin.jdbc.JdbcTypeHandle; +import com.facebook.presto.plugin.jdbc.optimization.function.OperatorTranslators; import com.facebook.presto.spi.ColumnHandle; import com.facebook.presto.spi.ConnectorId; import com.facebook.presto.spi.ConnectorSession; import com.facebook.presto.spi.SchemaTableName; import com.facebook.presto.spi.TableHandle; import com.facebook.presto.spi.connector.ConnectorTransactionHandle; +import com.facebook.presto.spi.function.StandardFunctionResolution; import com.facebook.presto.spi.plan.FilterNode; import com.facebook.presto.spi.plan.PlanNode; import com.facebook.presto.spi.plan.PlanNodeIdAllocator; import com.facebook.presto.spi.plan.TableScanNode; import com.facebook.presto.spi.predicate.TupleDomain; +import com.facebook.presto.spi.relation.ConstantExpression; +import com.facebook.presto.spi.relation.DeterminismEvaluator; import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.spi.type.Type; +import com.facebook.presto.sql.TestingRowExpressionTranslator; import com.facebook.presto.sql.planner.Plan; import com.facebook.presto.sql.planner.TypeProvider; import com.facebook.presto.sql.planner.assertions.MatchResult; @@ -42,17 +51,29 @@ import com.facebook.presto.sql.planner.assertions.PlanMatchPattern; import com.facebook.presto.sql.planner.assertions.SymbolAliases; import com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder; +import com.facebook.presto.sql.relational.FunctionResolution; +import com.facebook.presto.sql.relational.RowExpressionDeterminismEvaluator; +import com.facebook.presto.sql.relational.RowExpressionOptimizer; +import com.facebook.presto.sql.tree.SymbolReference; import com.facebook.presto.testing.TestingConnectorSession; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; import org.testng.annotations.Test; +import java.sql.Types; import java.util.Arrays; import java.util.Optional; +import java.util.Set; +import java.util.stream.Collectors; +import java.util.stream.Stream; import static com.facebook.presto.SessionTestUtils.TEST_SESSION; -import static com.facebook.presto.expressions.LogicalRowExpressions.TRUE_CONSTANT; import static com.facebook.presto.spi.type.BigintType.BIGINT; +import static com.facebook.presto.spi.type.BooleanType.BOOLEAN; +import static com.facebook.presto.spi.type.IntegerType.INTEGER; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.node; +import static com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder.expression; import static com.google.common.base.Preconditions.checkState; import static com.google.common.collect.ImmutableList.toImmutableList; import static java.util.function.Function.identity; @@ -63,32 +84,172 @@ public class TestJdbcComputePushdown private static final Metadata METADATA = MetadataManager.createTestMetadataManager(); private static final PlanBuilder PLAN_BUILDER = new PlanBuilder(TEST_SESSION, new PlanNodeIdAllocator(), METADATA); private static final PlanNodeIdAllocator ID_ALLOCATOR = new PlanNodeIdAllocator(); + private static final String CATALOG_NAME = "Jdbc"; + private static final String CONNECTOR_ID = new ConnectorId(CATALOG_NAME).toString(); + + private final TestingRowExpressionTranslator sqlToRowExpressionTranslator; private final JdbcComputePushdown jdbcComputePushdown; public TestJdbcComputePushdown() { - this.jdbcComputePushdown = new JdbcComputePushdown(null, null, null, null); + this.sqlToRowExpressionTranslator = new TestingRowExpressionTranslator(METADATA); + FunctionManager functionManager = METADATA.getFunctionManager(); + StandardFunctionResolution functionResolution = new FunctionResolution(functionManager); + DeterminismEvaluator determinismEvaluator = new RowExpressionDeterminismEvaluator(functionManager); + + this.jdbcComputePushdown = new JdbcComputePushdown( + functionManager, + functionResolution, + determinismEvaluator, + new RowExpressionOptimizer(METADATA), + "'", + getFunctionTranslators()); + } + + @Test + public void testJdbcComputePushdownAll() + { + String table = "test_table"; + String schema = "test_schema"; + + String expression = "(c1 + c2) - c2"; + TypeProvider typeProvider = TypeProvider.copyOf(ImmutableMap.of("c1", BIGINT, "c2", BIGINT)); + RowExpression rowExpression = sqlToRowExpressionTranslator.translateAndOptimize(expression(expression), typeProvider); + Set columns = Stream.of("c1", "c2").map(TestJdbcComputePushdown::integerJdbcColumnHandle).collect(Collectors.toSet()); + PlanNode original = filter(jdbcTableScan(schema, table, BIGINT, "c1", "c2"), rowExpression); + + JdbcTableHandle jdbcTableHandle = new JdbcTableHandle(CONNECTOR_ID, new SchemaTableName(schema, table), CATALOG_NAME, schema, table); + JdbcTableLayoutHandle jdbcTableLayoutHandle = new JdbcTableLayoutHandle(jdbcTableHandle, TupleDomain.none(), Optional.of(new JdbcExpression("(('c1' + 'c2') - 'c2')"))); + + ConnectorSession session = new TestingConnectorSession(ImmutableList.of()); + PlanNode actual = this.jdbcComputePushdown.optimize(original, session, null, ID_ALLOCATOR); + assertPlanMatch( + actual, + PlanMatchPattern.filter(expression, JdbcTableScanMatcher.jdbcTableScanPattern(jdbcTableLayoutHandle, columns))); + } + + @Test + public void testJdbcComputePushdownBooleanOperations() + { + String table = "test_table"; + String schema = "test_schema"; + + String expression = "(((c1 + c2) - c2 <> c2) OR c2 = c1) AND c1 <> c2"; + TypeProvider typeProvider = TypeProvider.copyOf(ImmutableMap.of("c1", BIGINT, "c2", BIGINT)); + RowExpression rowExpression = sqlToRowExpressionTranslator.translateAndOptimize(expression(expression), typeProvider); + Set columns = Stream.of("c1", "c2").map(TestJdbcComputePushdown::integerJdbcColumnHandle).collect(Collectors.toSet()); + PlanNode original = filter(jdbcTableScan(schema, table, BIGINT, "c1", "c2"), rowExpression); + + JdbcTableHandle jdbcTableHandle = new JdbcTableHandle(CONNECTOR_ID, new SchemaTableName(schema, table), CATALOG_NAME, schema, table); + JdbcTableLayoutHandle jdbcTableLayoutHandle = new JdbcTableLayoutHandle( + jdbcTableHandle, + TupleDomain.none(), + Optional.of(new JdbcExpression("((((((('c1' + 'c2') - 'c2') <> 'c2')) OR (('c2' = 'c1')))) AND (('c1' <> 'c2')))"))); + + ConnectorSession session = new TestingConnectorSession(ImmutableList.of()); + PlanNode actual = this.jdbcComputePushdown.optimize(original, session, null, ID_ALLOCATOR); + assertPlanMatch( + actual, + PlanMatchPattern.filter(expression, JdbcTableScanMatcher.jdbcTableScanPattern(jdbcTableLayoutHandle, columns))); } @Test - public void testJdbcComputePushdownIsNoop() + public void testJdbcComputePushdownUnsupported() { - JdbcTableHandle jdbcTableHandle = new JdbcTableHandle("cat1", new SchemaTableName("schema", "table"), null, null, "table"); - JdbcTableLayoutHandle jdbcTableLayoutHandle = new JdbcTableLayoutHandle(jdbcTableHandle, TupleDomain.none()); - TableHandle tableHandle = new TableHandle(new ConnectorId("Jdbc"), jdbcTableHandle, new ConnectorTransactionHandle() {}, Optional.of(jdbcTableLayoutHandle)); - PlanNode original = filter(tableScan(tableHandle, "a", "b"), TRUE_CONSTANT); + String table = "test_table"; + String schema = "test_schema"; + + String expression = "(c1 + c2) > c2"; + TypeProvider typeProvider = TypeProvider.copyOf(ImmutableMap.of("c1", BIGINT, "c2", BIGINT)); + RowExpression rowExpression = sqlToRowExpressionTranslator.translateAndOptimize(expression(expression), typeProvider); + Set columns = Stream.of("c1", "c2").map(TestJdbcComputePushdown::integerJdbcColumnHandle).collect(Collectors.toSet()); + PlanNode original = filter(jdbcTableScan(schema, table, BIGINT, "c1", "c2"), rowExpression); + + JdbcTableHandle jdbcTableHandle = new JdbcTableHandle(CONNECTOR_ID, new SchemaTableName(schema, table), CATALOG_NAME, schema, table); + // Test should expect an empty entry for translatedSql since > is an unsupported function currently in the optimizer + JdbcTableLayoutHandle jdbcTableLayoutHandle = new JdbcTableLayoutHandle(jdbcTableHandle, TupleDomain.none(), Optional.empty()); ConnectorSession session = new TestingConnectorSession(ImmutableList.of()); PlanNode actual = this.jdbcComputePushdown.optimize(original, session, null, ID_ALLOCATOR); assertPlanMatch(actual, PlanMatchPattern.filter( - TRUE_CONSTANT.toString(), - JdbcTableScanMatcher.jdbcTableScanPattern(jdbcTableLayoutHandle))); + expression, + JdbcTableScanMatcher.jdbcTableScanPattern(jdbcTableLayoutHandle, columns))); } - private static VariableReferenceExpression newBigintVariable(String name) + @Test + public void testJdbcComputePushdownWithConstants() { - return new VariableReferenceExpression(name, BIGINT); + String table = "test_table"; + String schema = "test_schema"; + + String expression = "(c1 + c2) = 3"; + TypeProvider typeProvider = TypeProvider.copyOf(ImmutableMap.of("c1", BIGINT, "c2", BIGINT)); + RowExpression rowExpression = sqlToRowExpressionTranslator.translateAndOptimize(expression(expression), typeProvider); + Set columns = Stream.of("c1", "c2").map(TestJdbcComputePushdown::integerJdbcColumnHandle).collect(Collectors.toSet()); + PlanNode original = filter(jdbcTableScan(schema, table, BIGINT, "c1", "c2"), rowExpression); + + JdbcTableHandle jdbcTableHandle = new JdbcTableHandle(CONNECTOR_ID, new SchemaTableName(schema, table), CATALOG_NAME, schema, table); + JdbcTableLayoutHandle jdbcTableLayoutHandle = new JdbcTableLayoutHandle( + jdbcTableHandle, + TupleDomain.none(), + Optional.of(new JdbcExpression("(('c1' + 'c2') = ?)", ImmutableList.of(new ConstantExpression(Long.valueOf(3), INTEGER))))); + + ConnectorSession session = new TestingConnectorSession(ImmutableList.of()); + PlanNode actual = this.jdbcComputePushdown.optimize(original, session, null, ID_ALLOCATOR); + assertPlanMatch(actual, PlanMatchPattern.filter( + expression, + JdbcTableScanMatcher.jdbcTableScanPattern(jdbcTableLayoutHandle, columns))); + } + + @Test + public void testJdbcComputePushdownNotOperator() + { + String table = "test_table"; + String schema = "test_schema"; + + String expression = "c1 AND NOT(c2)"; + TypeProvider typeProvider = TypeProvider.copyOf(ImmutableMap.of("c1", BOOLEAN, "c2", BOOLEAN)); + RowExpression rowExpression = sqlToRowExpressionTranslator.translateAndOptimize(expression(expression), typeProvider); + PlanNode original = filter(jdbcTableScan(schema, table, BOOLEAN, "c1", "c2"), rowExpression); + + Set columns = Stream.of("c1", "c2").map(TestJdbcComputePushdown::booleanJdbcColumnHandle).collect(Collectors.toSet()); + JdbcTableHandle jdbcTableHandle = new JdbcTableHandle(CONNECTOR_ID, new SchemaTableName(schema, table), CATALOG_NAME, schema, table); + JdbcTableLayoutHandle jdbcTableLayoutHandle = new JdbcTableLayoutHandle( + jdbcTableHandle, + TupleDomain.none(), + Optional.of(new JdbcExpression("(('c1') AND ((NOT('c2'))))"))); + + ConnectorSession session = new TestingConnectorSession(ImmutableList.of()); + PlanNode actual = this.jdbcComputePushdown.optimize(original, session, null, ID_ALLOCATOR); + assertPlanMatch(actual, PlanMatchPattern.filter( + expression, + JdbcTableScanMatcher.jdbcTableScanPattern(jdbcTableLayoutHandle, columns))); + } + + private Set> getFunctionTranslators() + { + return ImmutableSet.of(OperatorTranslators.class); + } + + private static VariableReferenceExpression newVariable(String name, Type type) + { + return new VariableReferenceExpression(name, type); + } + + private static JdbcColumnHandle integerJdbcColumnHandle(String name) + { + return new JdbcColumnHandle(CONNECTOR_ID, name, new JdbcTypeHandle(Types.BIGINT, 10, 0), BIGINT, false); + } + + private static JdbcColumnHandle booleanJdbcColumnHandle(String name) + { + return new JdbcColumnHandle(CONNECTOR_ID, name, new JdbcTypeHandle(Types.BOOLEAN, 1, 0), BOOLEAN, false); + } + + private static JdbcColumnHandle getColumnHandleForVariable(String name, Type type) + { + return type.equals(BOOLEAN) ? booleanJdbcColumnHandle(name) : integerJdbcColumnHandle(name); } private static void assertPlanMatch(PlanNode actual, PlanMatchPattern expected) @@ -106,12 +267,18 @@ private static void assertPlanMatch(PlanNode actual, PlanMatchPattern expected, expected); } - private TableScanNode tableScan(TableHandle tableHandle, String... columnNames) + private TableScanNode jdbcTableScan(String schema, String table, Type type, String... columnNames) { + JdbcTableHandle jdbcTableHandle = new JdbcTableHandle(CONNECTOR_ID, new SchemaTableName(schema, table), CATALOG_NAME, schema, table); + JdbcTableLayoutHandle jdbcTableLayoutHandle = new JdbcTableLayoutHandle(jdbcTableHandle, TupleDomain.none(), Optional.empty()); + TableHandle tableHandle = new TableHandle(new ConnectorId(CATALOG_NAME), jdbcTableHandle, new ConnectorTransactionHandle() {}, Optional.of(jdbcTableLayoutHandle)); + return PLAN_BUILDER.tableScan( tableHandle, - Arrays.stream(columnNames).map(TestJdbcComputePushdown::newBigintVariable).collect(toImmutableList()), - Arrays.stream(columnNames).map(TestJdbcComputePushdown::newBigintVariable).collect(toMap(identity(), variable -> new ColumnHandle() {}))); + Arrays.stream(columnNames).map(column -> newVariable(column, type)).collect(toImmutableList()), + Arrays.stream(columnNames) + .map(column -> newVariable(column, type)) + .collect(toMap(identity(), entry -> getColumnHandleForVariable(entry.getName(), type)))); } private FilterNode filter(PlanNode source, RowExpression predicate) @@ -123,15 +290,17 @@ private static final class JdbcTableScanMatcher implements Matcher { private final JdbcTableLayoutHandle jdbcTableLayoutHandle; + private final Set columns; - static PlanMatchPattern jdbcTableScanPattern(JdbcTableLayoutHandle jdbcTableLayoutHandle) + static PlanMatchPattern jdbcTableScanPattern(JdbcTableLayoutHandle jdbcTableLayoutHandle, Set columns) { - return node(TableScanNode.class).with(new JdbcTableScanMatcher(jdbcTableLayoutHandle)); + return node(TableScanNode.class).with(new JdbcTableScanMatcher(jdbcTableLayoutHandle, columns)); } - private JdbcTableScanMatcher(JdbcTableLayoutHandle jdbcTableLayoutHandle) + private JdbcTableScanMatcher(JdbcTableLayoutHandle jdbcTableLayoutHandle, Set columns) { this.jdbcTableLayoutHandle = jdbcTableLayoutHandle; + this.columns = columns; } @Override @@ -147,8 +316,16 @@ public MatchResult detailMatches(PlanNode node, StatsProvider stats, Session ses TableScanNode tableScanNode = (TableScanNode) node; JdbcTableLayoutHandle layoutHandle = (JdbcTableLayoutHandle) tableScanNode.getTable().getLayout().get(); - if (jdbcTableLayoutHandle.getTable().equals(layoutHandle.getTable()) && jdbcTableLayoutHandle.getTupleDomain().equals(layoutHandle.getTupleDomain())) { - return MatchResult.match(); + if (jdbcTableLayoutHandle.getTable().equals(layoutHandle.getTable()) + && jdbcTableLayoutHandle.getTupleDomain().equals(layoutHandle.getTupleDomain()) + && ((!jdbcTableLayoutHandle.getAdditionalPredicate().isPresent() && !layoutHandle.getAdditionalPredicate().isPresent()) + || jdbcTableLayoutHandle.getAdditionalPredicate().get().getExpression().equals(layoutHandle.getAdditionalPredicate().get().getExpression()))) { + return MatchResult.match( + SymbolAliases.builder().putAll( + columns.stream() + .map(column -> ((JdbcColumnHandle) column).getColumnName()) + .collect(toMap(identity(), SymbolReference::new))) + .build()); } return MatchResult.NO_MATCH; diff --git a/presto-expressions/src/main/java/com/facebook/presto/expressions/translator/FunctionTranslator.java b/presto-expressions/src/main/java/com/facebook/presto/expressions/translator/FunctionTranslator.java index 54e34971661bc..f8c717ef770bb 100644 --- a/presto-expressions/src/main/java/com/facebook/presto/expressions/translator/FunctionTranslator.java +++ b/presto-expressions/src/main/java/com/facebook/presto/expressions/translator/FunctionTranslator.java @@ -24,6 +24,9 @@ import java.util.Optional; import java.util.Set; +import static com.facebook.presto.expressions.translator.TranslatedExpression.untranslated; +import static com.facebook.presto.expressions.translator.TranslatorAnnotationParser.removeTypeParameters; +import static com.google.common.collect.ImmutableList.toImmutableList; import static java.util.Objects.requireNonNull; public class FunctionTranslator @@ -39,13 +42,21 @@ public static FunctionTranslator buildFunctionTranslator(Set> tr return new FunctionTranslator<>(functionMappingBuilder.build()); } - public TranslatedExpression translate(FunctionMetadata functionMetadata, RowExpression original, List> translatedArguments) + public TranslatedExpression translate(FunctionMetadata functionMetadata, RowExpression original, List> translatedExpressions) throws Throwable { - if (!functionMapping.containsKey(functionMetadata)) { - return new TranslatedExpression<>(Optional.empty(), original, translatedArguments); + functionMetadata = removeTypeParameters(functionMetadata); + if (!functionMapping.containsKey(functionMetadata) + || !translatedExpressions.stream().map(TranslatedExpression::getTranslated).allMatch(Optional::isPresent)) { + return untranslated(original, translatedExpressions); } - return new TranslatedExpression<>(Optional.of((T) functionMapping.get(functionMetadata).invokeWithArguments(translatedArguments)), original, translatedArguments); + + List translatedArguments = translatedExpressions.stream() + .map(TranslatedExpression::getTranslated) + .map(Optional::get) + .collect(toImmutableList()); + + return new TranslatedExpression<>(Optional.of((T) functionMapping.get(functionMetadata).invokeWithArguments(translatedArguments)), original, translatedExpressions); } public TranslatedExpression translate(FunctionMetadata functionMetadata, RowExpression original, TranslatedExpression... translatedArguments) diff --git a/presto-expressions/src/main/java/com/facebook/presto/expressions/translator/RowExpressionTranslator.java b/presto-expressions/src/main/java/com/facebook/presto/expressions/translator/RowExpressionTranslator.java index f700cd76bdeff..2a2aa0d0725d9 100644 --- a/presto-expressions/src/main/java/com/facebook/presto/expressions/translator/RowExpressionTranslator.java +++ b/presto-expressions/src/main/java/com/facebook/presto/expressions/translator/RowExpressionTranslator.java @@ -28,14 +28,14 @@ public TranslatedExpression translateConstant(ConstantExpression literal, C c return untranslated(literal); } - public TranslatedExpression translateVariable(VariableReferenceExpression reference, C context, RowExpressionTreeTranslator rowExpressionTreeTranslator) + public TranslatedExpression translateVariable(VariableReferenceExpression variable, C context, RowExpressionTreeTranslator rowExpressionTreeTranslator) { - return untranslated(reference); + return untranslated(variable); } - public TranslatedExpression translateLambda(LambdaDefinitionExpression reference, C context, RowExpressionTreeTranslator rowExpressionTreeTranslator) + public TranslatedExpression translateLambda(LambdaDefinitionExpression lambda, C context, RowExpressionTreeTranslator rowExpressionTreeTranslator) { - return untranslated(reference); + return untranslated(lambda); } public TranslatedExpression translateCall(CallExpression call, C context, RowExpressionTreeTranslator rowExpressionTreeTranslator) diff --git a/presto-expressions/src/main/java/com/facebook/presto/expressions/translator/TranslatedExpression.java b/presto-expressions/src/main/java/com/facebook/presto/expressions/translator/TranslatedExpression.java index 33e4e5ac0788f..63322275bc169 100644 --- a/presto-expressions/src/main/java/com/facebook/presto/expressions/translator/TranslatedExpression.java +++ b/presto-expressions/src/main/java/com/facebook/presto/expressions/translator/TranslatedExpression.java @@ -53,4 +53,9 @@ public static TranslatedExpression untranslated(RowExpression originalExp { return new TranslatedExpression<>(Optional.empty(), originalExpression, ImmutableList.of()); } + + public static TranslatedExpression untranslated(RowExpression originalExpression, List> translatedArguments) + { + return new TranslatedExpression<>(Optional.empty(), originalExpression, translatedArguments); + } } diff --git a/presto-expressions/src/main/java/com/facebook/presto/expressions/translator/TranslatorAnnotationParser.java b/presto-expressions/src/main/java/com/facebook/presto/expressions/translator/TranslatorAnnotationParser.java index 659a22620b83b..2eac8d1026332 100644 --- a/presto-expressions/src/main/java/com/facebook/presto/expressions/translator/TranslatorAnnotationParser.java +++ b/presto-expressions/src/main/java/com/facebook/presto/expressions/translator/TranslatorAnnotationParser.java @@ -66,7 +66,7 @@ private static TypeSignature removeTypeParameters(TypeSignature typeSignature) return new TypeSignature(typeSignature.getBase()); } - private static FunctionMetadata removeTypeParameters(FunctionMetadata metadata) + public static FunctionMetadata removeTypeParameters(FunctionMetadata metadata) { ImmutableList.Builder argumentsBuilder = ImmutableList.builder(); for (TypeSignature typeSignature : metadata.getArgumentTypes()) { diff --git a/presto-main/src/test/java/com/facebook/presto/sql/relational/TestRowExpressionTranslator.java b/presto-main/src/test/java/com/facebook/presto/sql/relational/TestRowExpressionTranslator.java index 0a30d230f1e6d..00191d0166d53 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/relational/TestRowExpressionTranslator.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/relational/TestRowExpressionTranslator.java @@ -218,7 +218,7 @@ public TranslatedExpression translateCall(CallExpression callExpression, return functionTranslator.translate(functionMetadata, callExpression, translatedExpressions); } catch (Throwable t) { - return untranslated(callExpression); + return untranslated(callExpression, translatedExpressions); } } @@ -242,9 +242,9 @@ public TranslatedExpression translateSpecialForm(SpecialFormExpression s } @Override - public TranslatedExpression translateVariable(VariableReferenceExpression reference, Map context, RowExpressionTreeTranslator> rowExpressionTreeTranslator) + public TranslatedExpression translateVariable(VariableReferenceExpression variable, Map context, RowExpressionTreeTranslator> rowExpressionTreeTranslator) { - return new TranslatedExpression<>(Optional.of(reference.getName()), reference, emptyList()); + return new TranslatedExpression<>(Optional.of(variable.getName()), variable, emptyList()); } } @@ -252,53 +252,44 @@ public static class TestFunctions { @ScalarFunction @SqlType(StandardTypes.BIGINT) - public static String bitwiseAnd(@SqlType(StandardTypes.BIGINT) TranslatedExpression left, @SqlType(StandardTypes.BIGINT) TranslatedExpression right) + public static String bitwiseAnd(@SqlType(StandardTypes.BIGINT) String left, @SqlType(StandardTypes.BIGINT) String right) { - assertTrue(left.getTranslated().isPresent()); - assertTrue(right.getTranslated().isPresent()); - return left.getTranslated().get() + " BITWISE_AND " + right.getTranslated().get(); + return left + " BITWISE_AND " + right; } @ScalarFunction("ln") @SqlType(StandardTypes.DOUBLE) - public static String ln(@SqlType(StandardTypes.DOUBLE) TranslatedExpression sql) + public static String ln(@SqlType(StandardTypes.DOUBLE) String sql) { - assertTrue(sql.getTranslated().isPresent()); - return "LNof(" + sql.getTranslated().get() + ")"; + return "LNof(" + sql + ")"; } @ScalarFunction("ceil") @SqlType(StandardTypes.DOUBLE) - public static String ceil(@SqlType(StandardTypes.BOOLEAN) TranslatedExpression sql) + public static String ceil(@SqlType(StandardTypes.BOOLEAN) String sql) { - assertTrue(sql.getTranslated().isPresent()); - return "CEILof(" + sql.getTranslated().get() + ")"; + return "CEILof(" + sql + ")"; } @ScalarFunction("not") @SqlType(StandardTypes.BOOLEAN) - public static String not(@SqlType(StandardTypes.BOOLEAN) TranslatedExpression sql) + public static String not(@SqlType(StandardTypes.BOOLEAN) String sql) { - assertTrue(sql.getTranslated().isPresent()); - return "NOT_2 " + sql.getTranslated().get(); + return "NOT_2 " + sql; } @ScalarOperator(OperatorType.ADD) @SqlType(StandardTypes.BIGINT) - public static String plus(@SqlType(StandardTypes.BIGINT) TranslatedExpression left, @SqlType(StandardTypes.BIGINT) TranslatedExpression right) + public static String plus(@SqlType(StandardTypes.BIGINT) String left, @SqlType(StandardTypes.BIGINT) String right) { - assertTrue(left.getTranslated().isPresent()); - assertTrue(right.getTranslated().isPresent()); - return left.getTranslated().get() + " -|- " + right.getTranslated().get(); + return left + " -|- " + right; } @ScalarOperator(OperatorType.LESS_THAN) @SqlType(StandardTypes.BOOLEAN) - public static String lessThan(@SqlType(StandardTypes.BIGINT) TranslatedExpression left, @SqlType(StandardTypes.BIGINT) TranslatedExpression right) + public static String lessThan(@SqlType(StandardTypes.BIGINT) String left, @SqlType(StandardTypes.BIGINT) String right) { - assertTrue(left.getTranslated().isPresent()); - assertTrue(right.getTranslated().isPresent()); - return left.getTranslated().get() + " LT " + right.getTranslated().get(); + return left + " LT " + right; } } }