diff --git a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/BaseJdbcClient.java b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/BaseJdbcClient.java index 4961a18a2ff5c..66ca36f05c7da 100644 --- a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/BaseJdbcClient.java +++ b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/BaseJdbcClient.java @@ -19,6 +19,8 @@ import com.facebook.presto.common.type.DecimalType; import com.facebook.presto.common.type.Type; import com.facebook.presto.common.type.VarcharType; +import com.facebook.presto.plugin.jdbc.optimization.JdbcQueryGeneratorContext; +import com.facebook.presto.plugin.jdbc.optimization.JdbcSortItem; import com.facebook.presto.spi.ColumnHandle; import com.facebook.presto.spi.ColumnMetadata; import com.facebook.presto.spi.ConnectorSession; @@ -53,6 +55,8 @@ import java.util.Optional; import java.util.Set; import java.util.UUID; +import java.util.function.BiFunction; +import java.util.function.Function; import static com.facebook.presto.common.type.BigintType.BIGINT; import static com.facebook.presto.common.type.BooleanType.BOOLEAN; @@ -136,6 +140,66 @@ public void destroy() connectionFactory.close(); } + @Override + public boolean supportsLimit() + { + return limitFunction().isPresent(); + } + + protected Optional> limitFunction() + { + return Optional.empty(); + } + + private Function applyLimit(long limit) + { + return query -> limitFunction() + .orElseThrow(() -> new PrestoException(JDBC_ERROR, "limitFunction is not set!")) + .apply(query, limit); + } + + @Override + public boolean supportsTopN(List sortItems) + { + if (!topNFunction().isPresent()) { + return false; + } + throw new UnsupportedOperationException("topNFunction() implemented without implementing supportsTopN()"); + } + + protected Optional topNFunction() + { + return Optional.empty(); + } + + private Function applyTopN(List sortItems, long limit) + { + return query -> topNFunction() + .orElseThrow(() -> new PrestoException(JDBC_ERROR, "topNFunction is not set!")) + .apply(query, sortItems, limit); + } + + @FunctionalInterface + public interface TopNFunction + { + String apply(String query, List sortItems, long limit); + } + + @Override + public String applyQueryTransformations(String query, JdbcQueryGeneratorContext context) + { + if (context.getLimit().isPresent()) { + if (context.getSortOrder().isPresent()) { + return applyTopN(context.getSortOrder().get(), context.getLimit().getAsLong()).apply(query); + } + else { + return applyLimit(context.getLimit().getAsLong()).apply(query); + } + } + + return query; + } + @Override public String getIdentifierQuote() { @@ -208,7 +272,8 @@ public JdbcTableHandle getTableHandle(JdbcIdentity identity, SchemaTableName sch schemaTableName, resultSet.getString("TABLE_CAT"), resultSet.getString("TABLE_SCHEM"), - resultSet.getString("TABLE_NAME"))); + resultSet.getString("TABLE_NAME"), + Optional.empty())); } if (tableHandles.isEmpty()) { return null; @@ -233,6 +298,7 @@ public List getColumns(ConnectorSession session, JdbcTableHand while (resultSet.next()) { JdbcTypeHandle typeHandle = new JdbcTypeHandle( resultSet.getInt("DATA_TYPE"), + Optional.ofNullable(resultSet.getString("TYPE_NAME")), resultSet.getInt("COLUMN_SIZE"), resultSet.getInt("DECIMAL_DIGITS")); Optional columnMapping = toPrestoType(session, typeHandle); @@ -271,7 +337,7 @@ public ConnectorSplitSource getSplits(JdbcIdentity identity, JdbcTableLayoutHand tableHandle.getSchemaName(), tableHandle.getTableName(), layoutHandle.getTupleDomain(), - layoutHandle.getAdditionalPredicate()); + tableHandle.getContext()); return new FixedSplitSource(ImmutableList.of(jdbcSplit)); } @@ -303,7 +369,7 @@ public PreparedStatement buildSql(ConnectorSession session, Connection connectio split.getTableName(), columnHandles, split.getTupleDomain(), - split.getAdditionalPredicate()); + split.getContext()); } @Override @@ -551,7 +617,8 @@ public void rollbackCreateTable(JdbcIdentity identity, JdbcOutputTableHandle han new SchemaTableName(handle.getSchemaName(), handle.getTemporaryTableName()), handle.getCatalogName(), handle.getSchemaName(), - handle.getTemporaryTableName())); + handle.getTemporaryTableName(), + Optional.empty())); } @Override diff --git a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcClient.java b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcClient.java index 4035ee3daeeed..081d629e85265 100644 --- a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcClient.java +++ b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcClient.java @@ -14,6 +14,8 @@ package com.facebook.presto.plugin.jdbc; import com.facebook.presto.common.predicate.TupleDomain; +import com.facebook.presto.plugin.jdbc.optimization.JdbcQueryGeneratorContext; +import com.facebook.presto.plugin.jdbc.optimization.JdbcSortItem; import com.facebook.presto.spi.ColumnHandle; import com.facebook.presto.spi.ColumnMetadata; import com.facebook.presto.spi.ConnectorSession; @@ -38,6 +40,12 @@ default boolean schemaExists(JdbcIdentity identity, String schema) return getSchemaNames(identity).contains(schema); } + boolean supportsLimit(); + + boolean supportsTopN(List sortItems); + + String applyQueryTransformations(String query, JdbcQueryGeneratorContext context); + String getIdentifierQuote(); Set getSchemaNames(JdbcIdentity identity); diff --git a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcConnector.java b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcConnector.java index b14656b4b1b3f..c321d97ecbc06 100644 --- a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcConnector.java +++ b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcConnector.java @@ -15,6 +15,7 @@ import com.facebook.airlift.bootstrap.LifeCycleManager; import com.facebook.airlift.log.Logger; +import com.facebook.presto.common.type.TypeManager; import com.facebook.presto.plugin.jdbc.optimization.JdbcPlanOptimizerProvider; import com.facebook.presto.spi.connector.Connector; import com.facebook.presto.spi.connector.ConnectorAccessControl; @@ -63,6 +64,7 @@ public class JdbcConnector private final FunctionMetadataManager functionManager; private final StandardFunctionResolution functionResolution; private final RowExpressionService rowExpressionService; + private final TypeManager typeManager; private final JdbcClient jdbcClient; @Inject @@ -77,6 +79,7 @@ public JdbcConnector( FunctionMetadataManager functionManager, StandardFunctionResolution functionResolution, RowExpressionService rowExpressionService, + TypeManager typeManager, JdbcClient jdbcClient) { this.lifeCycleManager = requireNonNull(lifeCycleManager, "lifeCycleManager is null"); @@ -89,6 +92,7 @@ public JdbcConnector( this.functionManager = requireNonNull(functionManager, "functionManager is null"); this.functionResolution = requireNonNull(functionResolution, "functionResolution is null"); this.rowExpressionService = requireNonNull(rowExpressionService, "rowExpressionService is null"); + this.typeManager = requireNonNull(typeManager, "typeManager is null"); this.jdbcClient = requireNonNull(jdbcClient, "jdbcClient is null"); } @@ -97,6 +101,7 @@ public ConnectorPlanOptimizerProvider getConnectorPlanOptimizerProvider() { return new JdbcPlanOptimizerProvider( jdbcClient, + typeManager, functionManager, functionResolution, rowExpressionService.getDeterminismEvaluator(), diff --git a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcConnectorFactory.java b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcConnectorFactory.java index e31e547f41514..ca93f61b46780 100644 --- a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcConnectorFactory.java +++ b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcConnectorFactory.java @@ -14,6 +14,7 @@ package com.facebook.presto.plugin.jdbc; import com.facebook.airlift.bootstrap.Bootstrap; +import com.facebook.presto.common.type.TypeManager; import com.facebook.presto.spi.ConnectorHandleResolver; import com.facebook.presto.spi.classloader.ThreadContextClassLoader; import com.facebook.presto.spi.connector.Connector; @@ -67,6 +68,7 @@ public Connector create(String catalogName, Map requiredConfig, try (ThreadContextClassLoader ignored = new ThreadContextClassLoader(classLoader)) { Bootstrap app = new Bootstrap( binder -> { + binder.bind(TypeManager.class).toInstance(context.getTypeManager()); binder.bind(FunctionMetadataManager.class).toInstance(context.getFunctionMetadataManager()); binder.bind(StandardFunctionResolution.class).toInstance(context.getStandardFunctionResolution()); binder.bind(RowExpressionService.class).toInstance(context.getRowExpressionService()); diff --git a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcMetadata.java b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcMetadata.java index 26ca1c0066872..50673833a19cd 100644 --- a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcMetadata.java +++ b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcMetadata.java @@ -85,7 +85,7 @@ public JdbcTableHandle getTableHandle(ConnectorSession session, SchemaTableName public List getTableLayouts(ConnectorSession session, ConnectorTableHandle table, Constraint constraint, Optional> desiredColumns) { JdbcTableHandle tableHandle = (JdbcTableHandle) table; - ConnectorTableLayout layout = new ConnectorTableLayout(new JdbcTableLayoutHandle(tableHandle, constraint.getSummary(), Optional.empty())); + ConnectorTableLayout layout = new ConnectorTableLayout(new JdbcTableLayoutHandle(tableHandle, constraint.getSummary())); return ImmutableList.of(new ConnectorTableLayoutResult(layout, constraint.getSummary())); } diff --git a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcSplit.java b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcSplit.java index bd3ff527ca5a5..d2bca8c6e1e65 100644 --- a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcSplit.java +++ b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcSplit.java @@ -14,7 +14,7 @@ package com.facebook.presto.plugin.jdbc; import com.facebook.presto.common.predicate.TupleDomain; -import com.facebook.presto.plugin.jdbc.optimization.JdbcExpression; +import com.facebook.presto.plugin.jdbc.optimization.JdbcQueryGeneratorContext; import com.facebook.presto.spi.ColumnHandle; import com.facebook.presto.spi.ConnectorSplit; import com.facebook.presto.spi.HostAddress; @@ -39,7 +39,7 @@ public class JdbcSplit private final String schemaName; private final String tableName; private final TupleDomain tupleDomain; - private final Optional additionalPredicate; + private final Optional context; @JsonCreator public JdbcSplit( @@ -48,14 +48,14 @@ public JdbcSplit( @JsonProperty("schemaName") @Nullable String schemaName, @JsonProperty("tableName") String tableName, @JsonProperty("tupleDomain") TupleDomain tupleDomain, - @JsonProperty("additionalProperty") Optional additionalPredicate) + @JsonProperty("context") Optional context) { this.connectorId = requireNonNull(connectorId, "connector id is null"); this.catalogName = catalogName; this.schemaName = schemaName; this.tableName = requireNonNull(tableName, "table name is null"); this.tupleDomain = requireNonNull(tupleDomain, "tupleDomain is null"); - this.additionalPredicate = requireNonNull(additionalPredicate, "additionalPredicate is null"); + this.context = requireNonNull(context, "context is null"); } @JsonProperty @@ -91,9 +91,9 @@ public TupleDomain getTupleDomain() } @JsonProperty - public Optional getAdditionalPredicate() + public Optional getContext() { - return additionalPredicate; + return context; } @Override diff --git a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcTableHandle.java b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcTableHandle.java index c9ec95841b14f..a8cc08af776d2 100644 --- a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcTableHandle.java +++ b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcTableHandle.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.plugin.jdbc; +import com.facebook.presto.plugin.jdbc.optimization.JdbcQueryGeneratorContext; import com.facebook.presto.spi.ConnectorTableHandle; import com.facebook.presto.spi.SchemaTableName; import com.fasterxml.jackson.annotation.JsonCreator; @@ -22,6 +23,7 @@ import javax.annotation.Nullable; import java.util.Objects; +import java.util.Optional; import static java.util.Objects.requireNonNull; @@ -35,6 +37,7 @@ public final class JdbcTableHandle private final String catalogName; private final String schemaName; private final String tableName; + private final Optional context; @JsonCreator public JdbcTableHandle( @@ -42,13 +45,15 @@ public JdbcTableHandle( @JsonProperty("schemaTableName") SchemaTableName schemaTableName, @JsonProperty("catalogName") @Nullable String catalogName, @JsonProperty("schemaName") @Nullable String schemaName, - @JsonProperty("tableName") String tableName) + @JsonProperty("tableName") String tableName, + @JsonProperty("context") Optional context) { this.connectorId = requireNonNull(connectorId, "connectorId is null"); this.schemaTableName = requireNonNull(schemaTableName, "schemaTableName is null"); this.catalogName = catalogName; this.schemaName = schemaName; this.tableName = requireNonNull(tableName, "tableName is null"); + this.context = requireNonNull(context, "context is null"); } @JsonProperty @@ -83,6 +88,12 @@ public String getTableName() return tableName; } + @JsonProperty + public Optional getContext() + { + return context; + } + @Override public boolean equals(Object obj) { diff --git a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcTableLayoutHandle.java b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcTableLayoutHandle.java index 69f0cf615feda..2eb09e2e37ed5 100644 --- a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcTableLayoutHandle.java +++ b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcTableLayoutHandle.java @@ -14,14 +14,12 @@ package com.facebook.presto.plugin.jdbc; import com.facebook.presto.common.predicate.TupleDomain; -import com.facebook.presto.plugin.jdbc.optimization.JdbcExpression; import com.facebook.presto.spi.ColumnHandle; import com.facebook.presto.spi.ConnectorTableLayoutHandle; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; import java.util.Objects; -import java.util.Optional; import static java.util.Objects.requireNonNull; @@ -30,23 +28,14 @@ public class JdbcTableLayoutHandle { private final JdbcTableHandle table; private final TupleDomain tupleDomain; - private final Optional additionalPredicate; @JsonCreator public JdbcTableLayoutHandle( @JsonProperty("table") JdbcTableHandle table, - @JsonProperty("tupleDomain") TupleDomain domain, - @JsonProperty("additionalPredicate") Optional additionalPredicate) + @JsonProperty("tupleDomain") TupleDomain domain) { this.table = requireNonNull(table, "table is null"); this.tupleDomain = requireNonNull(domain, "tupleDomain is null"); - this.additionalPredicate = additionalPredicate; - } - - @JsonProperty - public Optional getAdditionalPredicate() - { - return additionalPredicate; } @JsonProperty @@ -72,14 +61,13 @@ public boolean equals(Object o) } JdbcTableLayoutHandle that = (JdbcTableLayoutHandle) o; return Objects.equals(table, that.table) && - Objects.equals(tupleDomain, that.tupleDomain) && - Objects.equals(additionalPredicate, that.additionalPredicate); + Objects.equals(tupleDomain, that.tupleDomain); } @Override public int hashCode() { - return Objects.hash(table, tupleDomain, additionalPredicate); + return Objects.hash(table, tupleDomain); } @Override diff --git a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcTypeHandle.java b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcTypeHandle.java index 250dbdbbe4c27..769d30916af7e 100644 --- a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcTypeHandle.java +++ b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcTypeHandle.java @@ -17,22 +17,26 @@ import com.fasterxml.jackson.annotation.JsonProperty; import java.util.Objects; +import java.util.Optional; import static com.google.common.base.MoreObjects.toStringHelper; public final class JdbcTypeHandle { private final int jdbcType; + private final Optional jdbcTypeName; private final int columnSize; private final int decimalDigits; @JsonCreator public JdbcTypeHandle( @JsonProperty("jdbcType") int jdbcType, + @JsonProperty("jdbcTypeName") Optional jdbcTypeName, @JsonProperty("columnSize") int columnSize, @JsonProperty("decimalDigits") int decimalDigits) { this.jdbcType = jdbcType; + this.jdbcTypeName = jdbcTypeName; this.columnSize = columnSize; this.decimalDigits = decimalDigits; } @@ -43,6 +47,12 @@ public int getJdbcType() return jdbcType; } + @JsonProperty + public Optional getJdbcTypeName() + { + return jdbcTypeName; + } + @JsonProperty public int getColumnSize() { @@ -58,7 +68,7 @@ public int getDecimalDigits() @Override public int hashCode() { - return Objects.hash(jdbcType, columnSize, decimalDigits); + return Objects.hash(jdbcType, jdbcTypeName, columnSize, decimalDigits); } @Override @@ -72,6 +82,7 @@ public boolean equals(Object o) } JdbcTypeHandle that = (JdbcTypeHandle) o; return jdbcType == that.jdbcType && + jdbcTypeName == that.jdbcTypeName && columnSize == that.columnSize && decimalDigits == that.decimalDigits; } @@ -81,6 +92,7 @@ public String toString() { return toStringHelper(this) .add("jdbcType", jdbcType) + .add("jdbcTypeName", jdbcTypeName) .add("columnSize", columnSize) .add("decimalDigits", decimalDigits) .toString(); diff --git a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/QueryBuilder.java b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/QueryBuilder.java index 0926e605ff169..866498af940db 100644 --- a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/QueryBuilder.java +++ b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/QueryBuilder.java @@ -32,6 +32,7 @@ import com.facebook.presto.common.type.Type; import com.facebook.presto.common.type.VarcharType; import com.facebook.presto.plugin.jdbc.optimization.JdbcExpression; +import com.facebook.presto.plugin.jdbc.optimization.JdbcQueryGeneratorContext; import com.facebook.presto.spi.ColumnHandle; import com.facebook.presto.spi.ConnectorSession; import com.google.common.base.Joiner; @@ -105,7 +106,7 @@ public PreparedStatement buildSql( String table, List columns, TupleDomain tupleDomain, - Optional additionalPredicate) + Optional context) throws SQLException { StringBuilder sql = new StringBuilder(); @@ -133,21 +134,30 @@ public PreparedStatement buildSql( List accumulator = new ArrayList<>(); List clauses = toConjuncts(columns, tupleDomain, accumulator); - if (additionalPredicate.isPresent()) { - clauses = ImmutableList.builder() - .addAll(clauses) - .add(additionalPredicate.get().getExpression()) - .build(); - accumulator.addAll(additionalPredicate.get().getBoundConstantValues().stream() - .map(constantExpression -> new TypeAndValue(constantExpression.getType(), constantExpression.getValue())) - .collect(ImmutableList.toImmutableList())); + if (context.isPresent()) { + Optional additionalPredicate = context.get().getAdditionalPredicate(); + if (additionalPredicate.isPresent()) { + clauses = ImmutableList.builder() + .addAll(clauses) + .add(additionalPredicate.get().getExpression()) + .build(); + accumulator.addAll(additionalPredicate.get().getBoundConstantValues().stream() + .map(constantExpression -> new TypeAndValue(constantExpression.getType(), constantExpression.getValue())) + .collect(ImmutableList.toImmutableList())); + } } if (!clauses.isEmpty()) { sql.append(" WHERE ") .append(Joiner.on(" AND ").join(clauses)); } + + String sqlStr = sql.toString(); + if (context.isPresent()) { + sqlStr = client.applyQueryTransformations(sqlStr, context.get()); + } + sql.append(String.format("/* %s : %s */", session.getUser(), session.getQueryId())); - PreparedStatement statement = client.getPreparedStatement(connection, sql.toString()); + PreparedStatement statement = client.getPreparedStatement(connection, sqlStr); for (int i = 0; i < accumulator.size(); i++) { TypeAndValue typeAndValue = accumulator.get(i); diff --git a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/optimization/JdbcComputePushdown.java b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/optimization/JdbcComputePushdown.java index 9f66c9f7b3715..4b20f038e7b75 100644 --- a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/optimization/JdbcComputePushdown.java +++ b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/optimization/JdbcComputePushdown.java @@ -13,49 +13,67 @@ */ package com.facebook.presto.plugin.jdbc.optimization; +import com.facebook.presto.common.type.TypeManager; import com.facebook.presto.expressions.LogicalRowExpressions; import com.facebook.presto.expressions.translator.TranslatedExpression; +import com.facebook.presto.plugin.jdbc.JdbcClient; +import com.facebook.presto.plugin.jdbc.JdbcColumnHandle; import com.facebook.presto.plugin.jdbc.JdbcTableHandle; import com.facebook.presto.plugin.jdbc.JdbcTableLayoutHandle; +import com.facebook.presto.spi.ColumnHandle; import com.facebook.presto.spi.ConnectorPlanOptimizer; import com.facebook.presto.spi.ConnectorSession; +import com.facebook.presto.spi.ConnectorTableLayoutHandle; import com.facebook.presto.spi.TableHandle; import com.facebook.presto.spi.VariableAllocator; import com.facebook.presto.spi.function.FunctionMetadataManager; import com.facebook.presto.spi.function.StandardFunctionResolution; +import com.facebook.presto.spi.plan.Assignments; import com.facebook.presto.spi.plan.FilterNode; +import com.facebook.presto.spi.plan.LimitNode; import com.facebook.presto.spi.plan.PlanNode; import com.facebook.presto.spi.plan.PlanNodeIdAllocator; import com.facebook.presto.spi.plan.PlanVisitor; +import com.facebook.presto.spi.plan.ProjectNode; import com.facebook.presto.spi.plan.TableScanNode; +import com.facebook.presto.spi.plan.TopNNode; import com.facebook.presto.spi.relation.ConstantExpression; import com.facebook.presto.spi.relation.DeterminismEvaluator; import com.facebook.presto.spi.relation.ExpressionOptimizer; import com.facebook.presto.spi.relation.RowExpression; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.google.common.base.Joiner; import com.google.common.collect.ImmutableList; import java.util.ArrayList; +import java.util.LinkedHashMap; import java.util.List; +import java.util.Map; import java.util.Optional; +import java.util.OptionalLong; import java.util.Set; +import java.util.stream.Collectors; import static com.facebook.presto.expressions.translator.FunctionTranslator.buildFunctionTranslator; import static com.facebook.presto.expressions.translator.RowExpressionTreeTranslator.translateWith; import static com.facebook.presto.plugin.jdbc.optimization.function.JdbcTranslationUtil.mergeSqlBodies; import static com.facebook.presto.plugin.jdbc.optimization.function.JdbcTranslationUtil.mergeVariableBindings; -import static com.facebook.presto.spi.relation.ExpressionOptimizer.Level.OPTIMIZED; +import static com.google.common.base.Verify.verify; +import static com.google.common.collect.Maps.immutableEntry; import static java.lang.String.format; import static java.util.Objects.requireNonNull; public class JdbcComputePushdown implements ConnectorPlanOptimizer { + private final JdbcClient jdbcClient; private final ExpressionOptimizer expressionOptimizer; private final JdbcFilterToSqlTranslator jdbcFilterToSqlTranslator; private final LogicalRowExpressions logicalRowExpressions; public JdbcComputePushdown( + JdbcClient jdbcClient, + TypeManager typeManager, FunctionMetadataManager functionMetadataManager, StandardFunctionResolution functionResolution, DeterminismEvaluator determinismEvaluator, @@ -68,11 +86,15 @@ public JdbcComputePushdown( requireNonNull(functionTranslators, "functionTranslators is null"); requireNonNull(determinismEvaluator, "determinismEvaluator is null"); requireNonNull(functionResolution, "functionResolution is null"); + requireNonNull(typeManager, "typeManager is null"); + this.jdbcClient = requireNonNull(jdbcClient, "jdbcClient is null"); this.expressionOptimizer = requireNonNull(expressionOptimizer, "expressionOptimizer is null"); this.jdbcFilterToSqlTranslator = new JdbcFilterToSqlTranslator( + typeManager, functionMetadataManager, buildFunctionTranslator(functionTranslators), + functionResolution, identifierQuote); this.logicalRowExpressions = new LogicalRowExpressions( determinismEvaluator, @@ -87,11 +109,12 @@ public PlanNode optimize( VariableAllocator variableAllocator, PlanNodeIdAllocator idAllocator) { - return maxSubplan.accept(new Visitor(session, idAllocator), null); + JdbcQueryGeneratorContext context = new JdbcQueryGeneratorContext(); + return maxSubplan.accept(new Visitor(session, idAllocator), context); } private class Visitor - extends PlanVisitor + extends PlanVisitor { private final ConnectorSession session; private final PlanNodeIdAllocator idAllocator; @@ -103,12 +126,12 @@ public Visitor(ConnectorSession session, PlanNodeIdAllocator idAllocator) } @Override - public PlanNode visitPlan(PlanNode node, Void context) + public PlanNode visitPlan(PlanNode node, JdbcQueryGeneratorContext context) { ImmutableList.Builder children = ImmutableList.builder(); boolean changed = false; for (PlanNode child : node.getSources()) { - PlanNode newChild = child.accept(this, null); + PlanNode newChild = child.accept(this, context); if (newChild != child) { changed = true; } @@ -122,7 +145,114 @@ public PlanNode visitPlan(PlanNode node, Void context) } @Override - public PlanNode visitFilter(FilterNode node, Void context) + public PlanNode visitProject(ProjectNode node, JdbcQueryGeneratorContext context) + { + PlanNode planNode = node.getSource().accept(this, context); + if (!(planNode instanceof TableScanNode)) { + if (node.getSource() == planNode) { + return node; + } + else { + return node.replaceChildren(ImmutableList.of(planNode)); + } + } + + Assignments assignments = node.getAssignments(); + boolean allVariableReferenceExpression = assignments.entrySet().stream() + .allMatch(assignment -> assignment.getValue() instanceof VariableReferenceExpression); + + if (allVariableReferenceExpression) { + TableScanNode oldTableScanNode = (TableScanNode) planNode; + LinkedHashMap newAssignments = node.getAssignments().getMap().entrySet() + .stream() + .map(assignment -> { + VariableReferenceExpression value = (VariableReferenceExpression) assignment.getValue(); + return immutableEntry(assignment.getKey(), oldTableScanNode.getAssignments().get(value)); + }).collect(Collectors.toMap(Map.Entry::getKey, + Map.Entry::getValue, + (val1, val2) -> val1, + LinkedHashMap::new)); + + return new TableScanNode( + idAllocator.getNextId(), + oldTableScanNode.getTable(), + node.getOutputVariables(), + newAssignments, + oldTableScanNode.getCurrentConstraint(), + oldTableScanNode.getEnforcedConstraint()); + } + + return node; + } + + @Override + public PlanNode visitLimit(LimitNode node, JdbcQueryGeneratorContext context) + { + PlanNode planNode = node.getSource().accept(this, context); + if (!(planNode instanceof TableScanNode)) { + if (node.getSource() == planNode) { + return node; + } + else { + return node.replaceChildren(ImmutableList.of(planNode)); + } + } + + if (!jdbcClient.supportsLimit()) { + return new LimitNode(idAllocator.getNextId(), planNode, node.getCount(), node.getStep()); + } + + long count = node.getCount(); + TableScanNode tableScanNode = (TableScanNode) planNode; + TableHandle handle = tableScanNode.getTable(); + Optional oldLayout = handle.getLayout(); + if (oldLayout.isPresent()) { + JdbcTableHandle oldConnectorTable = (JdbcTableHandle) handle.getConnectorHandle(); + return createNewTableScanNode(tableScanNode, handle, oldConnectorTable, context.withLimit(OptionalLong.of(count))); + } + + return node; + } + + @Override + public PlanNode visitTopN(TopNNode node, JdbcQueryGeneratorContext context) + { + PlanNode planNode = node.getSource().accept(this, context); + if (!(planNode instanceof TableScanNode)) { + if (node.getSource() == planNode) { + return node; + } + else { + return node.replaceChildren(ImmutableList.of(planNode)); + } + } + + TableScanNode tableScanNode = (TableScanNode) planNode; + Map assignments = tableScanNode.getAssignments(); + List sortItems = node.getOrderingScheme().getOrderByVariables().stream() + .map(orderBy -> { + verify(assignments.containsKey(orderBy), "assignments does not contain order by item: %s", orderBy.getName()); + JdbcColumnHandle columnHandle = (JdbcColumnHandle) assignments.get(orderBy); + return new JdbcSortItem(columnHandle, node.getOrderingScheme().getOrdering(orderBy)); + }).collect(Collectors.toList()); + + if (!jdbcClient.supportsTopN(sortItems)) { + return new TopNNode(idAllocator.getNextId(), planNode, node.getCount(), node.getOrderingScheme(), node.getStep()); + } + + long count = node.getCount(); + TableHandle handle = tableScanNode.getTable(); + Optional oldLayout = handle.getLayout(); + if (oldLayout.isPresent()) { + JdbcTableHandle oldConnectorTable = (JdbcTableHandle) handle.getConnectorHandle(); + return createNewTableScanNode(tableScanNode, handle, oldConnectorTable, context.withTopN(Optional.of(sortItems), OptionalLong.of(count))); + } + + return node; + } + + @Override + public PlanNode visitFilter(FilterNode node, JdbcQueryGeneratorContext context) { if (!(node.getSource() instanceof TableScanNode)) { return node; @@ -134,8 +264,7 @@ public PlanNode visitFilter(FilterNode node, Void context) return node; } - RowExpression predicate = expressionOptimizer.optimize(node.getPredicate(), OPTIMIZED, session); - predicate = logicalRowExpressions.convertToConjunctiveNormalForm(predicate); + RowExpression predicate = logicalRowExpressions.convertToConjunctiveNormalForm(node.getPredicate()); TranslatedExpression jdbcExpression = translateWith( predicate, jdbcFilterToSqlTranslator, @@ -145,7 +274,7 @@ public PlanNode visitFilter(FilterNode node, Void context) JdbcTableHandle oldConnectorTable = (JdbcTableHandle) oldTableHandle.getConnectorHandle(); // All filter can be pushed down if (translated.isPresent()) { - return createNewTableScanNode(oldTableScanNode, oldTableHandle, oldConnectorTable, translated); + return createNewTableScanNode(oldTableScanNode, oldTableHandle, oldConnectorTable, context.withFilter(translated)); } // Find out which parts can be pushed down @@ -175,7 +304,8 @@ public PlanNode visitFilter(FilterNode node, Void context) List sqlBodies = mergeSqlBodies(translatedExpressions); List variableBindings = mergeVariableBindings(translatedExpressions); translated = Optional.of(new JdbcExpression(format("%s", Joiner.on(" AND ").join(sqlBodies)), variableBindings)); - TableScanNode newTableScanNode = createNewTableScanNode(oldTableScanNode, oldTableHandle, oldConnectorTable, translated); + TableScanNode newTableScanNode = createNewTableScanNode(oldTableScanNode, oldTableHandle, + oldConnectorTable, context.withFilter(translated)); return new FilterNode(idAllocator.getNextId(), newTableScanNode, logicalRowExpressions.combineConjuncts(remainingExpressions)); } @@ -184,17 +314,24 @@ private TableScanNode createNewTableScanNode( TableScanNode oldTableScanNode, TableHandle oldTableHandle, JdbcTableHandle oldConnectorTable, - Optional additionalPredicate) + JdbcQueryGeneratorContext context) { + JdbcTableHandle newJdbcTableHandle = new JdbcTableHandle( + oldConnectorTable.getConnectorId(), + oldConnectorTable.getSchemaTableName(), + oldConnectorTable.getCatalogName(), + oldConnectorTable.getSchemaName(), + oldConnectorTable.getTableName(), + Optional.of(context)); + JdbcTableLayoutHandle oldTableLayoutHandle = (JdbcTableLayoutHandle) oldTableHandle.getLayout().get(); JdbcTableLayoutHandle newTableLayoutHandle = new JdbcTableLayoutHandle( - oldConnectorTable, - oldTableLayoutHandle.getTupleDomain(), - additionalPredicate); + newJdbcTableHandle, + oldTableLayoutHandle.getTupleDomain()); TableHandle tableHandle = new TableHandle( oldTableHandle.getConnectorId(), - oldTableHandle.getConnectorHandle(), + newJdbcTableHandle, oldTableHandle.getTransaction(), Optional.of(newTableLayoutHandle)); diff --git a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/optimization/JdbcFilterToSqlTranslator.java b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/optimization/JdbcFilterToSqlTranslator.java index b8e47412e0672..95e75fc822e35 100644 --- a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/optimization/JdbcFilterToSqlTranslator.java +++ b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/optimization/JdbcFilterToSqlTranslator.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.plugin.jdbc.optimization; +import com.facebook.presto.common.function.OperatorType; import com.facebook.presto.common.type.BigintType; import com.facebook.presto.common.type.BooleanType; import com.facebook.presto.common.type.CharType; @@ -27,6 +28,7 @@ import com.facebook.presto.common.type.TimestampWithTimeZoneType; import com.facebook.presto.common.type.TinyintType; import com.facebook.presto.common.type.Type; +import com.facebook.presto.common.type.TypeManager; import com.facebook.presto.common.type.VarcharType; import com.facebook.presto.expressions.translator.FunctionTranslator; import com.facebook.presto.expressions.translator.RowExpressionTranslator; @@ -34,11 +36,14 @@ import com.facebook.presto.expressions.translator.TranslatedExpression; import com.facebook.presto.plugin.jdbc.JdbcColumnHandle; import com.facebook.presto.spi.ColumnHandle; +import com.facebook.presto.spi.function.FunctionHandle; import com.facebook.presto.spi.function.FunctionMetadata; import com.facebook.presto.spi.function.FunctionMetadataManager; +import com.facebook.presto.spi.function.StandardFunctionResolution; import com.facebook.presto.spi.relation.CallExpression; import com.facebook.presto.spi.relation.ConstantExpression; import com.facebook.presto.spi.relation.LambdaDefinitionExpression; +import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.spi.relation.SpecialFormExpression; import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.google.common.base.Joiner; @@ -48,6 +53,7 @@ import java.util.Map; import java.util.Optional; +import static com.facebook.presto.common.function.OperatorType.IS_DISTINCT_FROM; import static com.facebook.presto.expressions.translator.TranslatedExpression.untranslated; import static com.facebook.presto.plugin.jdbc.optimization.function.JdbcTranslationUtil.mergeSqlBodies; import static com.facebook.presto.plugin.jdbc.optimization.function.JdbcTranslationUtil.mergeVariableBindings; @@ -58,14 +64,23 @@ public class JdbcFilterToSqlTranslator extends RowExpressionTranslator> { + private final TypeManager typeManager; private final FunctionMetadataManager functionMetadataManager; private final FunctionTranslator functionTranslator; + private final StandardFunctionResolution standardFunctionResolution; private final String quote; - public JdbcFilterToSqlTranslator(FunctionMetadataManager functionMetadataManager, FunctionTranslator functionTranslator, String quote) + public JdbcFilterToSqlTranslator( + TypeManager typeManager, + FunctionMetadataManager functionMetadataManager, + FunctionTranslator functionTranslator, + StandardFunctionResolution standardFunctionResolution, + String quote) { + this.typeManager = requireNonNull(typeManager, "typeManager is null"); this.functionMetadataManager = requireNonNull(functionMetadataManager, "functionMetadataManager is null"); this.functionTranslator = requireNonNull(functionTranslator, "functionTranslator is null"); + this.standardFunctionResolution = requireNonNull(standardFunctionResolution, "standardFunctionResolution is null"); this.quote = requireNonNull(quote, "quote is null"); } @@ -108,7 +123,47 @@ public TranslatedExpression translateCall(CallExpression call, M FunctionMetadata functionMetadata = functionMetadataManager.getFunctionMetadata(call.getFunctionHandle()); try { - return functionTranslator.translate(functionMetadata, call, translatedExpressions); + TranslatedExpression translate = functionTranslator.translate(functionMetadata, call, translatedExpressions); + if (translate.getTranslated().isPresent()) { + return translate; + } + + FunctionHandle functionHandle = call.getFunctionHandle(); + + if (standardFunctionResolution.isCastFunction(functionHandle)) { + return handleCast(call, context, rowExpressionTreeTranslator); + } + + if (standardFunctionResolution.isBetweenFunction(functionHandle)) { + return handleBetween(call, context, rowExpressionTreeTranslator); + } + + Optional operatorTypeOptional = functionMetadata.getOperatorType(); + if (operatorTypeOptional.isPresent()) { + OperatorType operatorType = operatorTypeOptional.get(); + if (operatorType.isArithmeticOperator() || operatorType.isComparisonOperator()) { + if (operatorType == IS_DISTINCT_FROM) { + return untranslated(call); + } + + List translatedArguments = translatedExpressions.stream() + .map(TranslatedExpression::getTranslated) + .map(Optional::get) + .collect(toImmutableList()); + + if (translatedArguments.size() != 2) { + return untranslated(call, translatedExpressions); + } + + List sqlBodies = mergeSqlBodies(translatedArguments); + List variableBindings = mergeVariableBindings(translatedArguments); + String arithmeticOperator = String.format(" %s ", operatorType.getOperator()); + return new TranslatedExpression<>( + Optional.of(new JdbcExpression(format("%s", Joiner.on(arithmeticOperator).join(sqlBodies)), variableBindings)), + call, + translatedExpressions); + } + } } catch (Throwable t) { // no-op @@ -116,6 +171,52 @@ public TranslatedExpression translateCall(CallExpression call, M return untranslated(call, translatedExpressions); } + private TranslatedExpression handleCast(CallExpression cast, + Map context, + RowExpressionTreeTranslator> rowExpressionTreeTranslator) + { + if (cast.getArguments().size() == 1) { + RowExpression input = cast.getArguments().get(0); + Type expectedType = cast.getType(); + if (typeManager.canCoerce(input.getType(), expectedType)) { + return rowExpressionTreeTranslator.rewrite(input, context); + } + } + + return untranslated(cast); + } + + private TranslatedExpression handleBetween(CallExpression between, + Map context, + RowExpressionTreeTranslator> rowExpressionTreeTranslator) + { + if (between.getArguments().size() == 3) { + List> translatedExpressions = between.getArguments().stream() + .map(expression -> rowExpressionTreeTranslator.rewrite(expression, context)) + .collect(toImmutableList()); + + List jdbcExpressions = translatedExpressions.stream() + .map(TranslatedExpression::getTranslated) + .filter(Optional::isPresent) + .map(Optional::get) + .collect(toImmutableList()); + + if (jdbcExpressions.size() < translatedExpressions.size()) { + return untranslated(between, translatedExpressions); + } + + List sqlBodies = mergeSqlBodies(jdbcExpressions); + List variableBindings = mergeVariableBindings(jdbcExpressions); + + return new TranslatedExpression<>( + Optional.of(new JdbcExpression(format("(%s BETWEEN %s)", sqlBodies.get(0), Joiner.on(" AND ").join(sqlBodies.subList(1, sqlBodies.size()))), variableBindings)), + between, + translatedExpressions); + } + + return untranslated(between); + } + @Override public TranslatedExpression translateSpecialForm(SpecialFormExpression specialForm, Map context, RowExpressionTreeTranslator> rowExpressionTreeTranslator) { @@ -137,6 +238,11 @@ public TranslatedExpression translateSpecialForm(SpecialFormExpr List variableBindings = mergeVariableBindings(jdbcExpressions); switch (specialForm.getForm()) { + case IS_NULL: + return new TranslatedExpression<>( + Optional.of(new JdbcExpression(format("(%s IS NULL)", sqlBodies.get(0)))), + specialForm, + translatedExpressions); case AND: return new TranslatedExpression<>( Optional.of(new JdbcExpression(format("(%s)", Joiner.on(" AND ").join(sqlBodies)), variableBindings)), diff --git a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/optimization/JdbcPlanOptimizerProvider.java b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/optimization/JdbcPlanOptimizerProvider.java index 9dcd399b66472..db3a1af1d203a 100644 --- a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/optimization/JdbcPlanOptimizerProvider.java +++ b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/optimization/JdbcPlanOptimizerProvider.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.plugin.jdbc.optimization; +import com.facebook.presto.common.type.TypeManager; import com.facebook.presto.plugin.jdbc.JdbcClient; import com.facebook.presto.plugin.jdbc.optimization.function.OperatorTranslators; import com.facebook.presto.spi.ConnectorPlanOptimizer; @@ -31,6 +32,8 @@ public class JdbcPlanOptimizerProvider implements ConnectorPlanOptimizerProvider { + private final JdbcClient jdbcClient; + private final TypeManager typeManager; private final FunctionMetadataManager functionManager; private final StandardFunctionResolution functionResolution; private final DeterminismEvaluator determinismEvaluator; @@ -40,11 +43,14 @@ public class JdbcPlanOptimizerProvider @Inject public JdbcPlanOptimizerProvider( JdbcClient jdbcClient, + TypeManager typeManager, FunctionMetadataManager functionManager, StandardFunctionResolution functionResolution, DeterminismEvaluator determinismEvaluator, ExpressionOptimizer expressionOptimizer) { + this.jdbcClient = requireNonNull(jdbcClient, "jdbcClient is null"); + this.typeManager = requireNonNull(typeManager, "typeManager is null"); this.functionManager = requireNonNull(functionManager, "functionManager is null"); this.functionResolution = requireNonNull(functionResolution, "functionResolution is null"); this.determinismEvaluator = requireNonNull(determinismEvaluator, "determinismEvaluator is null"); @@ -54,14 +60,10 @@ public JdbcPlanOptimizerProvider( @Override public Set getLogicalPlanOptimizers() - { - return ImmutableSet.of(); - } - - @Override - public Set getPhysicalPlanOptimizers() { return ImmutableSet.of(new JdbcComputePushdown( + jdbcClient, + typeManager, functionManager, functionResolution, determinismEvaluator, @@ -70,6 +72,12 @@ public Set getPhysicalPlanOptimizers() getFunctionTranslators())); } + @Override + public Set getPhysicalPlanOptimizers() + { + return ImmutableSet.of(); + } + private Set> getFunctionTranslators() { return ImmutableSet.of(OperatorTranslators.class); diff --git a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/optimization/JdbcQueryGeneratorContext.java b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/optimization/JdbcQueryGeneratorContext.java new file mode 100644 index 0000000000000..0e68ea82fefa1 --- /dev/null +++ b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/optimization/JdbcQueryGeneratorContext.java @@ -0,0 +1,131 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.plugin.jdbc.optimization; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.util.List; +import java.util.Objects; +import java.util.Optional; +import java.util.OptionalLong; + +import static com.google.common.base.MoreObjects.toStringHelper; +import static com.google.common.base.Preconditions.checkState; +import static java.util.Objects.requireNonNull; + +public class JdbcQueryGeneratorContext +{ + private Optional additionalPredicate; + private Optional> sortOrder; + private OptionalLong limit; + + public JdbcQueryGeneratorContext() + { + this(Optional.empty(), Optional.empty(), OptionalLong.empty()); + } + + @JsonCreator + public JdbcQueryGeneratorContext( + @JsonProperty("additionalPredicate") Optional additionalPredicate, + @JsonProperty("sortOrder") Optional> sortOrder, + @JsonProperty("limit") OptionalLong limit) + { + this.additionalPredicate = requireNonNull(additionalPredicate, "additionalPredicate is null"); + this.sortOrder = requireNonNull(sortOrder, "sortOrder is null"); + this.limit = requireNonNull(limit, "limit is null"); + } + + JdbcQueryGeneratorContext withLimit(OptionalLong limit) + { + checkState(!hasLimit(), "Limit already exists. Jdbc datasources doesn't support limit on top of another limit"); + this.limit = limit; + return this; + } + + JdbcQueryGeneratorContext withTopN(Optional> sortOrder, OptionalLong limit) + { + checkState(!hasLimit(), "Limit already exists. Jdbc datasources doesn't support limit on top of another limit"); + this.sortOrder = sortOrder; + this.limit = limit; + return this; + } + + JdbcQueryGeneratorContext withFilter(Optional additionalPredicate) + { + checkState(!hasFilter(), "The filter has been set!"); + this.additionalPredicate = additionalPredicate; + return this; + } + + @JsonProperty + public Optional getAdditionalPredicate() + { + return additionalPredicate; + } + + @JsonProperty + public OptionalLong getLimit() + { + return limit; + } + + @JsonProperty + public Optional> getSortOrder() + { + return sortOrder; + } + + private boolean hasLimit() + { + return limit.isPresent(); + } + + private boolean hasFilter() + { + return additionalPredicate.isPresent(); + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + + JdbcQueryGeneratorContext context = (JdbcQueryGeneratorContext) o; + return Objects.equals(additionalPredicate, context.additionalPredicate) && + Objects.equals(sortOrder, context.sortOrder) && + Objects.equals(limit, context.limit); + } + + @Override + public int hashCode() + { + return Objects.hash(additionalPredicate, sortOrder, limit); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("additionalPredicate", additionalPredicate) + .add("sortOrder", sortOrder) + .add("limit", limit) + .toString(); + } +} diff --git a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/optimization/JdbcSortItem.java b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/optimization/JdbcSortItem.java new file mode 100644 index 0000000000000..623be3e2eeeae --- /dev/null +++ b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/optimization/JdbcSortItem.java @@ -0,0 +1,77 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.plugin.jdbc.optimization; + +import com.facebook.presto.common.block.SortOrder; +import com.facebook.presto.plugin.jdbc.JdbcColumnHandle; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import javax.annotation.concurrent.Immutable; + +import java.util.Objects; + +import static java.util.Objects.requireNonNull; + +@Immutable +public final class JdbcSortItem +{ + private final JdbcColumnHandle column; + private final SortOrder sortOrder; + + @JsonCreator + public JdbcSortItem(JdbcColumnHandle column, SortOrder sortOrder) + { + this.column = requireNonNull(column, "column is null"); + this.sortOrder = requireNonNull(sortOrder, "sortOrder is null"); + } + + @JsonProperty + public JdbcColumnHandle getColumn() + { + return column; + } + + @JsonProperty + public SortOrder getSortOrder() + { + return sortOrder; + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + JdbcSortItem that = (JdbcSortItem) o; + return sortOrder == that.sortOrder && + Objects.equals(column, that.column); + } + + @Override + public int hashCode() + { + return Objects.hash(column, sortOrder); + } + + @Override + public String toString() + { + return column + " " + sortOrder; + } +} diff --git a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/optimization/function/JdbcTranslationUtil.java b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/optimization/function/JdbcTranslationUtil.java index 26931f15ee790..02c546746ce83 100644 --- a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/optimization/function/JdbcTranslationUtil.java +++ b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/optimization/function/JdbcTranslationUtil.java @@ -43,7 +43,6 @@ public static List mergeSqlBodies(List jdbcExpressions) { return jdbcExpressions.stream() .map(JdbcExpression::getExpression) - .map(sql -> '(' + sql + ')') .collect(toImmutableList()); } diff --git a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/optimization/function/OperatorTranslators.java b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/optimization/function/OperatorTranslators.java index 14bafd3326062..3c670b85e7980 100644 --- a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/optimization/function/OperatorTranslators.java +++ b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/optimization/function/OperatorTranslators.java @@ -20,6 +20,7 @@ import com.facebook.presto.spi.function.SqlType; import static com.facebook.presto.common.function.OperatorType.ADD; +import static com.facebook.presto.common.function.OperatorType.CAST; import static com.facebook.presto.common.function.OperatorType.EQUAL; import static com.facebook.presto.common.function.OperatorType.NOT_EQUAL; import static com.facebook.presto.common.function.OperatorType.SUBTRACT; @@ -66,4 +67,25 @@ public static JdbcExpression not(@SqlType(StandardTypes.BOOLEAN) JdbcExpression { return new JdbcExpression(String.format("(NOT(%s))", expression.getExpression()), expression.getBoundConstantValues()); } + + @ScalarFunction("like") + @SqlType(StandardTypes.BOOLEAN) + public static JdbcExpression like(@SqlType(StandardTypes.VARCHAR) JdbcExpression left, @SqlType("LikePattern") JdbcExpression right) + { + return new JdbcExpression(infixOperation("LIKE", left, right), forwardBindVariables(left, right)); + } + + @ScalarFunction("like_pattern") + @SqlType("LikePattern") + public static JdbcExpression likePattern(@SqlType(StandardTypes.VARCHAR) JdbcExpression left, @SqlType(StandardTypes.VARCHAR) JdbcExpression right) + { + return new JdbcExpression(String.format("%s ESCAPE %s", left.getExpression(), right.getExpression()), forwardBindVariables(left, right)); + } + + @ScalarOperator(CAST) + @SqlType("LikePattern") + public static JdbcExpression cast(@SqlType(StandardTypes.VARCHAR) JdbcExpression expression) + { + return expression; + } } diff --git a/presto-base-jdbc/src/test/java/com/facebook/presto/plugin/jdbc/TestJdbcMetadata.java b/presto-base-jdbc/src/test/java/com/facebook/presto/plugin/jdbc/TestJdbcMetadata.java index 35013909158b9..a780c2f32b03f 100644 --- a/presto-base-jdbc/src/test/java/com/facebook/presto/plugin/jdbc/TestJdbcMetadata.java +++ b/presto-base-jdbc/src/test/java/com/facebook/presto/plugin/jdbc/TestJdbcMetadata.java @@ -91,8 +91,8 @@ public void testGetColumnHandles() "value", new JdbcColumnHandle(CONNECTOR_ID, "VALUE", JDBC_BIGINT, BIGINT, true))); // unknown table - unknownTableColumnHandle(new JdbcTableHandle(CONNECTOR_ID, new SchemaTableName("unknown", "unknown"), "unknown", "unknown", "unknown")); - unknownTableColumnHandle(new JdbcTableHandle(CONNECTOR_ID, new SchemaTableName("example", "numbers"), null, "example", "unknown")); + unknownTableColumnHandle(new JdbcTableHandle(CONNECTOR_ID, new SchemaTableName("unknown", "unknown"), "unknown", "unknown", "unknown", Optional.empty())); + unknownTableColumnHandle(new JdbcTableHandle(CONNECTOR_ID, new SchemaTableName("example", "numbers"), null, "example", "unknown", Optional.empty())); } private void unknownTableColumnHandle(JdbcTableHandle tableHandle) @@ -125,9 +125,9 @@ public void getTableMetadata() new ColumnMetadata("va%ue", BIGINT))); // unknown tables should produce null - unknownTableMetadata(new JdbcTableHandle(CONNECTOR_ID, new SchemaTableName("u", "numbers"), null, "unknown", "unknown")); - unknownTableMetadata(new JdbcTableHandle(CONNECTOR_ID, new SchemaTableName("example", "numbers"), null, "example", "unknown")); - unknownTableMetadata(new JdbcTableHandle(CONNECTOR_ID, new SchemaTableName("example", "numbers"), null, "unknown", "numbers")); + unknownTableMetadata(new JdbcTableHandle(CONNECTOR_ID, new SchemaTableName("u", "numbers"), null, "unknown", "unknown", Optional.empty())); + unknownTableMetadata(new JdbcTableHandle(CONNECTOR_ID, new SchemaTableName("example", "numbers"), null, "example", "unknown", Optional.empty())); + unknownTableMetadata(new JdbcTableHandle(CONNECTOR_ID, new SchemaTableName("example", "numbers"), null, "unknown", "numbers", Optional.empty())); } private void unknownTableMetadata(JdbcTableHandle tableHandle) diff --git a/presto-base-jdbc/src/test/java/com/facebook/presto/plugin/jdbc/TestJdbcRecordSetProvider.java b/presto-base-jdbc/src/test/java/com/facebook/presto/plugin/jdbc/TestJdbcRecordSetProvider.java index 8947b48c2facb..64ec7a7828cec 100644 --- a/presto-base-jdbc/src/test/java/com/facebook/presto/plugin/jdbc/TestJdbcRecordSetProvider.java +++ b/presto-base-jdbc/src/test/java/com/facebook/presto/plugin/jdbc/TestJdbcRecordSetProvider.java @@ -32,7 +32,6 @@ import java.util.LinkedHashMap; import java.util.List; import java.util.Map; -import java.util.Optional; import static com.facebook.airlift.concurrent.MoreFutures.getFutureValue; import static com.facebook.presto.common.type.BigintType.BIGINT; @@ -181,7 +180,7 @@ public void testTupleDomain() private RecordCursor getCursor(JdbcTableHandle jdbcTableHandle, List columns, TupleDomain domain) { - JdbcTableLayoutHandle layoutHandle = new JdbcTableLayoutHandle(jdbcTableHandle, domain, Optional.empty()); + JdbcTableLayoutHandle layoutHandle = new JdbcTableLayoutHandle(jdbcTableHandle, domain); ConnectorSplitSource splits = jdbcClient.getSplits(IDENTITY, layoutHandle); JdbcSplit split = (JdbcSplit) getOnlyElement(getFutureValue(splits.getNextBatch(NOT_PARTITIONED, 1000)).getSplits()); diff --git a/presto-base-jdbc/src/test/java/com/facebook/presto/plugin/jdbc/TestJdbcTableHandle.java b/presto-base-jdbc/src/test/java/com/facebook/presto/plugin/jdbc/TestJdbcTableHandle.java index 3c3393d361c45..088c6befd0eab 100644 --- a/presto-base-jdbc/src/test/java/com/facebook/presto/plugin/jdbc/TestJdbcTableHandle.java +++ b/presto-base-jdbc/src/test/java/com/facebook/presto/plugin/jdbc/TestJdbcTableHandle.java @@ -17,6 +17,8 @@ import com.facebook.presto.spi.SchemaTableName; import org.testng.annotations.Test; +import java.util.Optional; + import static com.facebook.presto.plugin.jdbc.MetadataUtil.TABLE_CODEC; import static com.facebook.presto.plugin.jdbc.MetadataUtil.assertJsonRoundTrip; @@ -25,7 +27,7 @@ public class TestJdbcTableHandle @Test public void testJsonRoundTrip() { - assertJsonRoundTrip(TABLE_CODEC, new JdbcTableHandle("connectorId", new SchemaTableName("schema", "table"), "jdbcCatalog", "jdbcSchema", "jdbcTable")); + assertJsonRoundTrip(TABLE_CODEC, new JdbcTableHandle("connectorId", new SchemaTableName("schema", "table"), "jdbcCatalog", "jdbcSchema", "jdbcTable", Optional.empty())); } @Test @@ -33,20 +35,20 @@ public void testEquivalence() { EquivalenceTester.equivalenceTester() .addEquivalentGroup( - new JdbcTableHandle("connectorId", new SchemaTableName("schema", "table"), "jdbcCatalog", "jdbcSchema", "jdbcTable"), - new JdbcTableHandle("connectorId", new SchemaTableName("schema", "table"), "jdbcCatalogX", "jdbcSchema", "jdbcTable"), - new JdbcTableHandle("connectorId", new SchemaTableName("schema", "table"), "jdbcCatalog", "jdbcSchemaX", "jdbcTable"), - new JdbcTableHandle("connectorId", new SchemaTableName("schema", "table"), "jdbcCatalog", "jdbcSchema", "jdbcTableX")) + new JdbcTableHandle("connectorId", new SchemaTableName("schema", "table"), "jdbcCatalog", "jdbcSchema", "jdbcTable", Optional.empty()), + new JdbcTableHandle("connectorId", new SchemaTableName("schema", "table"), "jdbcCatalogX", "jdbcSchema", "jdbcTable", Optional.empty()), + new JdbcTableHandle("connectorId", new SchemaTableName("schema", "table"), "jdbcCatalog", "jdbcSchemaX", "jdbcTable", Optional.empty()), + new JdbcTableHandle("connectorId", new SchemaTableName("schema", "table"), "jdbcCatalog", "jdbcSchema", "jdbcTableX", Optional.empty())) .addEquivalentGroup( - new JdbcTableHandle("connectorIdX", new SchemaTableName("schema", "table"), "jdbcCatalog", "jdbcSchema", "jdbcTable"), - new JdbcTableHandle("connectorIdX", new SchemaTableName("schema", "table"), "jdbcCatalogX", "jdbcSchema", "jdbcTable"), - new JdbcTableHandle("connectorIdX", new SchemaTableName("schema", "table"), "jdbcCatalog", "jdbcSchemaX", "jdbcTable"), - new JdbcTableHandle("connectorIdX", new SchemaTableName("schema", "table"), "jdbcCatalog", "jdbcSchema", "jdbcTableX")) + new JdbcTableHandle("connectorIdX", new SchemaTableName("schema", "table"), "jdbcCatalog", "jdbcSchema", "jdbcTable", Optional.empty()), + new JdbcTableHandle("connectorIdX", new SchemaTableName("schema", "table"), "jdbcCatalogX", "jdbcSchema", "jdbcTable", Optional.empty()), + new JdbcTableHandle("connectorIdX", new SchemaTableName("schema", "table"), "jdbcCatalog", "jdbcSchemaX", "jdbcTable", Optional.empty()), + new JdbcTableHandle("connectorIdX", new SchemaTableName("schema", "table"), "jdbcCatalog", "jdbcSchema", "jdbcTableX", Optional.empty())) .addEquivalentGroup( - new JdbcTableHandle("connectorId", new SchemaTableName("schemaX", "table"), "jdbcCatalog", "jdbcSchema", "jdbcTable"), - new JdbcTableHandle("connectorId", new SchemaTableName("schemaX", "table"), "jdbcCatalogX", "jdbcSchema", "jdbcTable"), - new JdbcTableHandle("connectorId", new SchemaTableName("schemaX", "table"), "jdbcCatalog", "jdbcSchemaX", "jdbcTable"), - new JdbcTableHandle("connectorId", new SchemaTableName("schemaX", "table"), "jdbcCatalog", "jdbcSchema", "jdbcTableX")) + new JdbcTableHandle("connectorId", new SchemaTableName("schemaX", "table"), "jdbcCatalog", "jdbcSchema", "jdbcTable", Optional.empty()), + new JdbcTableHandle("connectorId", new SchemaTableName("schemaX", "table"), "jdbcCatalogX", "jdbcSchema", "jdbcTable", Optional.empty()), + new JdbcTableHandle("connectorId", new SchemaTableName("schemaX", "table"), "jdbcCatalog", "jdbcSchemaX", "jdbcTable", Optional.empty()), + new JdbcTableHandle("connectorId", new SchemaTableName("schemaX", "table"), "jdbcCatalog", "jdbcSchema", "jdbcTableX", Optional.empty())) .check(); } } diff --git a/presto-base-jdbc/src/test/java/com/facebook/presto/plugin/jdbc/TestingBaseJdbcClient.java b/presto-base-jdbc/src/test/java/com/facebook/presto/plugin/jdbc/TestingBaseJdbcClient.java new file mode 100644 index 0000000000000..b9fe04556fca7 --- /dev/null +++ b/presto-base-jdbc/src/test/java/com/facebook/presto/plugin/jdbc/TestingBaseJdbcClient.java @@ -0,0 +1,40 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.plugin.jdbc; + +import com.facebook.presto.plugin.jdbc.optimization.JdbcSortItem; + +import java.util.List; + +public class TestingBaseJdbcClient + extends BaseJdbcClient +{ + public TestingBaseJdbcClient(JdbcConnectorId connectorId, BaseJdbcConfig config, + String identifierQuote, ConnectionFactory connectionFactory) + { + super(connectorId, config, identifierQuote, connectionFactory); + } + + @Override + public boolean supportsLimit() + { + return true; + } + + @Override + public boolean supportsTopN(List sortItems) + { + return true; + } +} diff --git a/presto-base-jdbc/src/test/java/com/facebook/presto/plugin/jdbc/TestingDatabase.java b/presto-base-jdbc/src/test/java/com/facebook/presto/plugin/jdbc/TestingDatabase.java index 61d8c6c792e27..70da1eb6b4ade 100644 --- a/presto-base-jdbc/src/test/java/com/facebook/presto/plugin/jdbc/TestingDatabase.java +++ b/presto-base-jdbc/src/test/java/com/facebook/presto/plugin/jdbc/TestingDatabase.java @@ -101,7 +101,7 @@ public JdbcSplit getSplit(String schemaName, String tableName) { JdbcIdentity identity = JdbcIdentity.from(session); JdbcTableHandle jdbcTableHandle = jdbcClient.getTableHandle(identity, new SchemaTableName(schemaName, tableName)); - JdbcTableLayoutHandle jdbcLayoutHandle = new JdbcTableLayoutHandle(jdbcTableHandle, TupleDomain.all(), Optional.empty()); + JdbcTableLayoutHandle jdbcLayoutHandle = new JdbcTableLayoutHandle(jdbcTableHandle, TupleDomain.all()); ConnectorSplitSource splits = jdbcClient.getSplits(identity, jdbcLayoutHandle); return (JdbcSplit) getOnlyElement(getFutureValue(splits.getNextBatch(NOT_PARTITIONED, 1000)).getSplits()); } diff --git a/presto-base-jdbc/src/test/java/com/facebook/presto/plugin/jdbc/TestingJdbcTypeHandle.java b/presto-base-jdbc/src/test/java/com/facebook/presto/plugin/jdbc/TestingJdbcTypeHandle.java index 55b8b377beb62..b2a708108766d 100644 --- a/presto-base-jdbc/src/test/java/com/facebook/presto/plugin/jdbc/TestingJdbcTypeHandle.java +++ b/presto-base-jdbc/src/test/java/com/facebook/presto/plugin/jdbc/TestingJdbcTypeHandle.java @@ -14,25 +14,26 @@ package com.facebook.presto.plugin.jdbc; import java.sql.Types; +import java.util.Optional; public final class TestingJdbcTypeHandle { private TestingJdbcTypeHandle() {} - public static final JdbcTypeHandle JDBC_BOOLEAN = new JdbcTypeHandle(Types.BOOLEAN, 1, 0); + public static final JdbcTypeHandle JDBC_BOOLEAN = new JdbcTypeHandle(Types.BOOLEAN, Optional.of("boolean"), 1, 0); - public static final JdbcTypeHandle JDBC_SMALLINT = new JdbcTypeHandle(Types.SMALLINT, 1, 0); - public static final JdbcTypeHandle JDBC_TINYINT = new JdbcTypeHandle(Types.TINYINT, 2, 0); - public static final JdbcTypeHandle JDBC_INTEGER = new JdbcTypeHandle(Types.INTEGER, 4, 0); - public static final JdbcTypeHandle JDBC_BIGINT = new JdbcTypeHandle(Types.BIGINT, 8, 0); + public static final JdbcTypeHandle JDBC_SMALLINT = new JdbcTypeHandle(Types.SMALLINT, Optional.of("smallint"), 1, 0); + public static final JdbcTypeHandle JDBC_TINYINT = new JdbcTypeHandle(Types.TINYINT, Optional.of("tinyint"), 2, 0); + public static final JdbcTypeHandle JDBC_INTEGER = new JdbcTypeHandle(Types.INTEGER, Optional.of("integer"), 4, 0); + public static final JdbcTypeHandle JDBC_BIGINT = new JdbcTypeHandle(Types.BIGINT, Optional.of("bigint"), 8, 0); - public static final JdbcTypeHandle JDBC_REAL = new JdbcTypeHandle(Types.REAL, 8, 0); - public static final JdbcTypeHandle JDBC_DOUBLE = new JdbcTypeHandle(Types.DOUBLE, 8, 0); + public static final JdbcTypeHandle JDBC_REAL = new JdbcTypeHandle(Types.REAL, Optional.of("real"), 8, 0); + public static final JdbcTypeHandle JDBC_DOUBLE = new JdbcTypeHandle(Types.DOUBLE, Optional.of("double"), 8, 0); - public static final JdbcTypeHandle JDBC_CHAR = new JdbcTypeHandle(Types.CHAR, 10, 0); - public static final JdbcTypeHandle JDBC_VARCHAR = new JdbcTypeHandle(Types.VARCHAR, 10, 0); + public static final JdbcTypeHandle JDBC_CHAR = new JdbcTypeHandle(Types.CHAR, Optional.of("char"), 10, 0); + public static final JdbcTypeHandle JDBC_VARCHAR = new JdbcTypeHandle(Types.VARCHAR, Optional.of("varchar"), 10, 0); - public static final JdbcTypeHandle JDBC_DATE = new JdbcTypeHandle(Types.DATE, 8, 0); - public static final JdbcTypeHandle JDBC_TIME = new JdbcTypeHandle(Types.TIME, 4, 0); - public static final JdbcTypeHandle JDBC_TIMESTAMP = new JdbcTypeHandle(Types.TIMESTAMP, 8, 0); + public static final JdbcTypeHandle JDBC_DATE = new JdbcTypeHandle(Types.DATE, Optional.of("date"), 8, 0); + public static final JdbcTypeHandle JDBC_TIME = new JdbcTypeHandle(Types.TIME, Optional.of("time"), 4, 0); + public static final JdbcTypeHandle JDBC_TIMESTAMP = new JdbcTypeHandle(Types.TIMESTAMP, Optional.of("timestamp"), 8, 0); } diff --git a/presto-base-jdbc/src/test/java/com/facebook/presto/plugin/jdbc/optimization/TestJdbcComputePushdown.java b/presto-base-jdbc/src/test/java/com/facebook/presto/plugin/jdbc/optimization/TestJdbcComputePushdown.java index 27438ae8bb4b5..831ce0629be60 100644 --- a/presto-base-jdbc/src/test/java/com/facebook/presto/plugin/jdbc/optimization/TestJdbcComputePushdown.java +++ b/presto-base-jdbc/src/test/java/com/facebook/presto/plugin/jdbc/optimization/TestJdbcComputePushdown.java @@ -14,6 +14,8 @@ package com.facebook.presto.plugin.jdbc.optimization; import com.facebook.presto.Session; +import com.facebook.presto.common.QualifiedObjectName; +import com.facebook.presto.common.block.SortOrder; import com.facebook.presto.common.predicate.TupleDomain; import com.facebook.presto.common.type.Type; import com.facebook.presto.cost.PlanNodeStatsEstimate; @@ -22,10 +24,14 @@ import com.facebook.presto.metadata.FunctionAndTypeManager; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.metadata.MetadataManager; +import com.facebook.presto.plugin.jdbc.BaseJdbcConfig; +import com.facebook.presto.plugin.jdbc.DriverConnectionFactory; +import com.facebook.presto.plugin.jdbc.JdbcClient; import com.facebook.presto.plugin.jdbc.JdbcColumnHandle; +import com.facebook.presto.plugin.jdbc.JdbcConnectorId; import com.facebook.presto.plugin.jdbc.JdbcTableHandle; import com.facebook.presto.plugin.jdbc.JdbcTableLayoutHandle; -import com.facebook.presto.plugin.jdbc.JdbcTypeHandle; +import com.facebook.presto.plugin.jdbc.TestingBaseJdbcClient; import com.facebook.presto.plugin.jdbc.optimization.function.OperatorTranslators; import com.facebook.presto.spi.ColumnHandle; import com.facebook.presto.spi.ConnectorId; @@ -33,11 +39,23 @@ import com.facebook.presto.spi.SchemaTableName; import com.facebook.presto.spi.TableHandle; import com.facebook.presto.spi.connector.ConnectorTransactionHandle; +import com.facebook.presto.spi.function.FunctionHandle; +import com.facebook.presto.spi.function.Parameter; +import com.facebook.presto.spi.function.RoutineCharacteristics; +import com.facebook.presto.spi.function.SqlFunctionId; +import com.facebook.presto.spi.function.SqlInvokedFunction; import com.facebook.presto.spi.function.StandardFunctionResolution; +import com.facebook.presto.spi.plan.Assignments; import com.facebook.presto.spi.plan.FilterNode; +import com.facebook.presto.spi.plan.LimitNode; +import com.facebook.presto.spi.plan.Ordering; +import com.facebook.presto.spi.plan.OrderingScheme; import com.facebook.presto.spi.plan.PlanNode; import com.facebook.presto.spi.plan.PlanNodeIdAllocator; +import com.facebook.presto.spi.plan.ProjectNode; import com.facebook.presto.spi.plan.TableScanNode; +import com.facebook.presto.spi.plan.TopNNode; +import com.facebook.presto.spi.relation.CallExpression; import com.facebook.presto.spi.relation.ConstantExpression; import com.facebook.presto.spi.relation.DeterminismEvaluator; import com.facebook.presto.spi.relation.RowExpression; @@ -45,6 +63,7 @@ import com.facebook.presto.sql.TestingRowExpressionTranslator; import com.facebook.presto.sql.planner.Plan; import com.facebook.presto.sql.planner.TypeProvider; +import com.facebook.presto.sql.planner.assertions.ExpressionMatcher; import com.facebook.presto.sql.planner.assertions.MatchResult; import com.facebook.presto.sql.planner.assertions.Matcher; import com.facebook.presto.sql.planner.assertions.PlanAssert; @@ -59,23 +78,51 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Lists; +import com.google.common.collect.Sets; +import io.airlift.slice.Slices; +import org.h2.Driver; import org.testng.annotations.Test; -import java.sql.Types; import java.util.Arrays; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; import java.util.Optional; +import java.util.OptionalLong; +import java.util.Properties; import java.util.Set; import java.util.stream.Collectors; +import java.util.stream.IntStream; import java.util.stream.Stream; import static com.facebook.presto.SessionTestUtils.TEST_SESSION; import static com.facebook.presto.common.type.BigintType.BIGINT; import static com.facebook.presto.common.type.BooleanType.BOOLEAN; +import static com.facebook.presto.common.type.CharType.createCharType; +import static com.facebook.presto.common.type.DoubleType.DOUBLE; import static com.facebook.presto.common.type.IntegerType.INTEGER; +import static com.facebook.presto.common.type.TypeSignature.parseTypeSignature; +import static com.facebook.presto.common.type.VarcharType.VARCHAR; +import static com.facebook.presto.common.type.VarcharType.createVarcharType; +import static com.facebook.presto.metadata.FunctionAndTypeManager.createTestFunctionAndTypeManager; +import static com.facebook.presto.plugin.jdbc.TestingJdbcTypeHandle.JDBC_BIGINT; +import static com.facebook.presto.plugin.jdbc.TestingJdbcTypeHandle.JDBC_BOOLEAN; +import static com.facebook.presto.plugin.jdbc.TestingJdbcTypeHandle.JDBC_CHAR; +import static com.facebook.presto.plugin.jdbc.TestingJdbcTypeHandle.JDBC_DOUBLE; +import static com.facebook.presto.plugin.jdbc.TestingJdbcTypeHandle.JDBC_VARCHAR; +import static com.facebook.presto.spi.function.FunctionVersion.notVersioned; +import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.node; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.sort; import static com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder.expression; +import static com.facebook.presto.sql.tree.SortItem.NullOrdering.FIRST; +import static com.facebook.presto.sql.tree.SortItem.Ordering.ASCENDING; +import static com.facebook.presto.sql.tree.SortItem.Ordering.DESCENDING; +import static com.facebook.presto.testing.TestingSession.testSessionBuilder; import static com.google.common.base.Preconditions.checkState; import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.Maps.immutableEntry; import static java.util.function.Function.identity; import static java.util.stream.Collectors.toMap; @@ -87,9 +134,21 @@ public class TestJdbcComputePushdown private static final String CATALOG_NAME = "Jdbc"; private static final String CONNECTOR_ID = new ConnectorId(CATALOG_NAME).toString(); + private static List jdbcColumnHandles = Lists.newArrayList( + new JdbcColumnHandle(CONNECTOR_ID, "l_orderkey", JDBC_BIGINT, BIGINT, false), + new JdbcColumnHandle(CONNECTOR_ID, "l_partkey", JDBC_BIGINT, BIGINT, false), + new JdbcColumnHandle(CONNECTOR_ID, "l_quantity", JDBC_DOUBLE, DOUBLE, false), + new JdbcColumnHandle(CONNECTOR_ID, "l_extendedprice", JDBC_DOUBLE, DOUBLE, false), + new JdbcColumnHandle(CONNECTOR_ID, "l_returnflag", JDBC_CHAR, createCharType(1), false), + new JdbcColumnHandle(CONNECTOR_ID, "l_shipdate", JDBC_VARCHAR, VARCHAR, false), + new JdbcColumnHandle(CONNECTOR_ID, "l_commitdate", JDBC_VARCHAR, VARCHAR, false), + new JdbcColumnHandle(CONNECTOR_ID, "l_receiptdate", JDBC_VARCHAR, VARCHAR, false)); + private final TestingRowExpressionTranslator sqlToRowExpressionTranslator; private final JdbcComputePushdown jdbcComputePushdown; + private final FunctionAndTypeManager typeManager = createTestFunctionAndTypeManager(); + private final JdbcClient jdbcClient; public TestJdbcComputePushdown() { @@ -98,7 +157,16 @@ public TestJdbcComputePushdown() StandardFunctionResolution functionResolution = new FunctionResolution(functionAndTypeManager); DeterminismEvaluator determinismEvaluator = new RowExpressionDeterminismEvaluator(functionAndTypeManager); + String connectionUrl = "jdbc:h2:mem:test" + System.nanoTime(); + jdbcClient = new TestingBaseJdbcClient( + new JdbcConnectorId(CONNECTOR_ID), + new BaseJdbcConfig(), + "\"", + new DriverConnectionFactory(new Driver(), connectionUrl, Optional.empty(), Optional.empty(), new Properties())); + this.jdbcComputePushdown = new JdbcComputePushdown( + jdbcClient, + typeManager, functionAndTypeManager, functionResolution, determinismEvaluator, @@ -119,8 +187,10 @@ public void testJdbcComputePushdownAll() Set columns = Stream.of("c1", "c2").map(TestJdbcComputePushdown::integerJdbcColumnHandle).collect(Collectors.toSet()); PlanNode original = filter(jdbcTableScan(schema, table, BIGINT, "c1", "c2"), rowExpression); - JdbcTableHandle jdbcTableHandle = new JdbcTableHandle(CONNECTOR_ID, new SchemaTableName(schema, table), CATALOG_NAME, schema, table); - JdbcTableLayoutHandle jdbcTableLayoutHandle = new JdbcTableLayoutHandle(jdbcTableHandle, TupleDomain.none(), Optional.of(new JdbcExpression("(('c1' + 'c2') - 'c2')"))); + Optional context = getJdbcQueryGeneratorContext( + new JdbcExpression("(('c1' + 'c2') - 'c2')")); + JdbcTableHandle jdbcTableHandle = new JdbcTableHandle(CONNECTOR_ID, new SchemaTableName(schema, table), CATALOG_NAME, schema, table, context); + JdbcTableLayoutHandle jdbcTableLayoutHandle = new JdbcTableLayoutHandle(jdbcTableHandle, TupleDomain.none()); ConnectorSession session = new TestingConnectorSession(ImmutableList.of()); PlanNode actual = this.jdbcComputePushdown.optimize(original, session, null, ID_ALLOCATOR); @@ -141,11 +211,12 @@ public void testJdbcComputePushdownBooleanOperations() Set columns = Stream.of("c1", "c2").map(TestJdbcComputePushdown::integerJdbcColumnHandle).collect(Collectors.toSet()); PlanNode original = filter(jdbcTableScan(schema, table, BIGINT, "c1", "c2"), rowExpression); - JdbcTableHandle jdbcTableHandle = new JdbcTableHandle(CONNECTOR_ID, new SchemaTableName(schema, table), CATALOG_NAME, schema, table); + Optional context = getJdbcQueryGeneratorContext( + new JdbcExpression("((((('c1' + 'c2') - 'c2') <> 'c2') OR ('c2' = 'c1')) AND ('c1' <> 'c2'))")); + JdbcTableHandle jdbcTableHandle = new JdbcTableHandle(CONNECTOR_ID, new SchemaTableName(schema, table), CATALOG_NAME, schema, table, context); JdbcTableLayoutHandle jdbcTableLayoutHandle = new JdbcTableLayoutHandle( jdbcTableHandle, - TupleDomain.none(), - Optional.of(new JdbcExpression("((((((('c1' + 'c2') - 'c2') <> 'c2')) OR (('c2' = 'c1')))) AND (('c1' <> 'c2')))"))); + TupleDomain.none()); ConnectorSession session = new TestingConnectorSession(ImmutableList.of()); PlanNode actual = this.jdbcComputePushdown.optimize(original, session, null, ID_ALLOCATOR); @@ -155,7 +226,7 @@ public void testJdbcComputePushdownBooleanOperations() } @Test - public void testJdbcComputePushdownUnsupported() + public void testJdbcComputePushdownGreaterThen() { String table = "test_table"; String schema = "test_schema"; @@ -166,15 +237,17 @@ public void testJdbcComputePushdownUnsupported() Set columns = Stream.of("c1", "c2").map(TestJdbcComputePushdown::integerJdbcColumnHandle).collect(Collectors.toSet()); PlanNode original = filter(jdbcTableScan(schema, table, BIGINT, "c1", "c2"), rowExpression); - JdbcTableHandle jdbcTableHandle = new JdbcTableHandle(CONNECTOR_ID, new SchemaTableName(schema, table), CATALOG_NAME, schema, table); - // Test should expect an empty entry for translatedSql since > is an unsupported function currently in the optimizer - JdbcTableLayoutHandle jdbcTableLayoutHandle = new JdbcTableLayoutHandle(jdbcTableHandle, TupleDomain.none(), Optional.empty()); + Optional context = getJdbcQueryGeneratorContext( + new JdbcExpression("('c1' + 'c2') > 'c2'")); + + JdbcTableHandle jdbcTableHandle = new JdbcTableHandle(CONNECTOR_ID, new SchemaTableName(schema, table), CATALOG_NAME, schema, table, context); + JdbcTableLayoutHandle jdbcTableLayoutHandle = new JdbcTableLayoutHandle( + jdbcTableHandle, + TupleDomain.none()); ConnectorSession session = new TestingConnectorSession(ImmutableList.of()); PlanNode actual = this.jdbcComputePushdown.optimize(original, session, null, ID_ALLOCATOR); - assertPlanMatch(actual, PlanMatchPattern.filter( - expression, - JdbcTableScanMatcher.jdbcTableScanPattern(jdbcTableLayoutHandle, columns))); + assertPlanMatch(actual, JdbcTableScanMatcher.jdbcTableScanPattern(jdbcTableLayoutHandle, columns)); } @Test @@ -189,11 +262,13 @@ public void testJdbcComputePushdownWithConstants() Set columns = Stream.of("c1", "c2").map(TestJdbcComputePushdown::integerJdbcColumnHandle).collect(Collectors.toSet()); PlanNode original = filter(jdbcTableScan(schema, table, BIGINT, "c1", "c2"), rowExpression); - JdbcTableHandle jdbcTableHandle = new JdbcTableHandle(CONNECTOR_ID, new SchemaTableName(schema, table), CATALOG_NAME, schema, table); + Optional context = getJdbcQueryGeneratorContext( + new JdbcExpression("(('c1' + 'c2') = ?)", ImmutableList.of(new ConstantExpression(3L, INTEGER)))); + + JdbcTableHandle jdbcTableHandle = new JdbcTableHandle(CONNECTOR_ID, new SchemaTableName(schema, table), CATALOG_NAME, schema, table, context); JdbcTableLayoutHandle jdbcTableLayoutHandle = new JdbcTableLayoutHandle( jdbcTableHandle, - TupleDomain.none(), - Optional.of(new JdbcExpression("(('c1' + 'c2') = ?)", ImmutableList.of(new ConstantExpression(Long.valueOf(3), INTEGER))))); + TupleDomain.none()); ConnectorSession session = new TestingConnectorSession(ImmutableList.of()); PlanNode actual = this.jdbcComputePushdown.optimize(original, session, null, ID_ALLOCATOR); @@ -211,12 +286,13 @@ public void testJdbcComputePushdownNotOperator() RowExpression rowExpression = sqlToRowExpressionTranslator.translateAndOptimize(expression(expression), typeProvider); PlanNode original = filter(jdbcTableScan(schema, table, BOOLEAN, "c1", "c2"), rowExpression); + Optional context = getJdbcQueryGeneratorContext( + new JdbcExpression("('c1' AND (NOT('c2')))")); Set columns = Stream.of("c1", "c2").map(TestJdbcComputePushdown::booleanJdbcColumnHandle).collect(Collectors.toSet()); - JdbcTableHandle jdbcTableHandle = new JdbcTableHandle(CONNECTOR_ID, new SchemaTableName(schema, table), CATALOG_NAME, schema, table); + JdbcTableHandle jdbcTableHandle = new JdbcTableHandle(CONNECTOR_ID, new SchemaTableName(schema, table), CATALOG_NAME, schema, table, context); JdbcTableLayoutHandle jdbcTableLayoutHandle = new JdbcTableLayoutHandle( jdbcTableHandle, - TupleDomain.none(), - Optional.of(new JdbcExpression("(('c1') AND ((NOT('c2'))))"))); + TupleDomain.none()); ConnectorSession session = new TestingConnectorSession(ImmutableList.of()); PlanNode actual = this.jdbcComputePushdown.optimize(original, session, null, ID_ALLOCATOR); @@ -235,12 +311,13 @@ public void testJdbcComputePartialPushdown() RowExpression rowExpression = sqlToRowExpressionTranslator.translateAndOptimize(expression(expression), typeProvider); PlanNode original = filter(jdbcTableScan(schema, table, BIGINT, "c1", "c2"), rowExpression); + Optional context = getJdbcQueryGeneratorContext( + new JdbcExpression("(('c1' + 'c2') = ?)", ImmutableList.of(new ConstantExpression(3L, INTEGER)))); Set columns = Stream.of("c1", "c2").map(TestJdbcComputePushdown::booleanJdbcColumnHandle).collect(Collectors.toSet()); - JdbcTableHandle jdbcTableHandle = new JdbcTableHandle(CONNECTOR_ID, new SchemaTableName(schema, table), CATALOG_NAME, schema, table); + JdbcTableHandle jdbcTableHandle = new JdbcTableHandle(CONNECTOR_ID, new SchemaTableName(schema, table), CATALOG_NAME, schema, table, context); JdbcTableLayoutHandle jdbcTableLayoutHandle = new JdbcTableLayoutHandle( jdbcTableHandle, - TupleDomain.none(), - Optional.of(new JdbcExpression("((('c1' + 'c2') = ?))", ImmutableList.of(new ConstantExpression(3L, INTEGER))))); + TupleDomain.none()); ConnectorSession session = new TestingConnectorSession(ImmutableList.of()); PlanNode actual = this.jdbcComputePushdown.optimize(original, session, null, ID_ALLOCATOR); @@ -261,12 +338,13 @@ public void testJdbcComputePartialPushdownWithOrOperator() RowExpression rowExpression = sqlToRowExpressionTranslator.translateAndOptimize(expression(expression), typeProvider); PlanNode original = filter(jdbcTableScan(schema, table, BIGINT, "c1", "c2"), rowExpression); + Optional context = getJdbcQueryGeneratorContext( + new JdbcExpression("((('c1' + 'c2') = ?) OR ('c1' <> 'c2'))", ImmutableList.of(new ConstantExpression(3L, INTEGER)))); Set columns = Stream.of("c1", "c2").map(TestJdbcComputePushdown::booleanJdbcColumnHandle).collect(Collectors.toSet()); - JdbcTableHandle jdbcTableHandle = new JdbcTableHandle(CONNECTOR_ID, new SchemaTableName(schema, table), CATALOG_NAME, schema, table); + JdbcTableHandle jdbcTableHandle = new JdbcTableHandle(CONNECTOR_ID, new SchemaTableName(schema, table), CATALOG_NAME, schema, table, context); JdbcTableLayoutHandle jdbcTableLayoutHandle = new JdbcTableLayoutHandle( jdbcTableHandle, - TupleDomain.none(), - Optional.of(new JdbcExpression("((((('c1' + 'c2') = ?)) OR (('c1' <> 'c2'))))", ImmutableList.of(new ConstantExpression(3L, INTEGER))))); + TupleDomain.none()); ConnectorSession session = new TestingConnectorSession(ImmutableList.of()); PlanNode actual = this.jdbcComputePushdown.optimize(original, session, null, ID_ALLOCATOR); @@ -282,17 +360,16 @@ public void testJdbcComputeNoPushdown() String schema = "test_schema"; // no filter can be pushed down - String expression = "CAST(c1 AS varchar(1024)) = '123' and ((c1 - c2) > c2 or c1 <> c2)"; + String expression = "CAST(c1 AS varchar(1024)) = '123' and ((c1 - c2) > c2 or CAST(c2 AS varchar(1024)) = '456')"; TypeProvider typeProvider = TypeProvider.copyOf(ImmutableMap.of("c1", BIGINT, "c2", BIGINT)); RowExpression rowExpression = sqlToRowExpressionTranslator.translateAndOptimize(expression(expression), typeProvider); PlanNode original = filter(jdbcTableScan(schema, table, BIGINT, "c1", "c2"), rowExpression); Set columns = Stream.of("c1", "c2").map(TestJdbcComputePushdown::booleanJdbcColumnHandle).collect(Collectors.toSet()); - JdbcTableHandle jdbcTableHandle = new JdbcTableHandle(CONNECTOR_ID, new SchemaTableName(schema, table), CATALOG_NAME, schema, table); + JdbcTableHandle jdbcTableHandle = new JdbcTableHandle(CONNECTOR_ID, new SchemaTableName(schema, table), CATALOG_NAME, schema, table, Optional.empty()); JdbcTableLayoutHandle jdbcTableLayoutHandle = new JdbcTableLayoutHandle( jdbcTableHandle, - TupleDomain.none(), - Optional.empty()); + TupleDomain.none()); ConnectorSession session = new TestingConnectorSession(ImmutableList.of()); PlanNode actual = this.jdbcComputePushdown.optimize(original, session, null, ID_ALLOCATOR); @@ -301,6 +378,286 @@ public void testJdbcComputeNoPushdown() JdbcTableScanMatcher.jdbcTableScanPattern(jdbcTableLayoutHandle, columns))); } + @Test + public void testJdbcComputeLimitPushdownWithoutFilter() + { + String table = "test_table"; + String schema = "test_schema"; + + PlanNode original = limit(8, jdbcTableScan(schema, table, BIGINT, "c1", "c2")); + + Optional context = getJdbcQueryGeneratorContext(Optional.empty(), OptionalLong.of(8)); + Set columns = Stream.of("c1", "c2").map(TestJdbcComputePushdown::booleanJdbcColumnHandle).collect(Collectors.toSet()); + JdbcTableHandle jdbcTableHandle = new JdbcTableHandle(CONNECTOR_ID, new SchemaTableName(schema, table), CATALOG_NAME, schema, table, context); + JdbcTableLayoutHandle jdbcTableLayoutHandle = new JdbcTableLayoutHandle( + jdbcTableHandle, + TupleDomain.none()); + + ConnectorSession session = new TestingConnectorSession(ImmutableList.of()); + PlanNode actual = this.jdbcComputePushdown.optimize(original, session, null, ID_ALLOCATOR); + assertPlanMatch(actual, JdbcTableScanMatcher.jdbcTableScanPattern(jdbcTableLayoutHandle, columns)); + } + + @Test + public void testJdbcComputeLimitPushdownWithFilter() + { + String table = "test_table"; + String schema = "test_schema"; + + String expression = "c1 + c2 = 3"; + TypeProvider typeProvider = TypeProvider.copyOf(ImmutableMap.of("c1", BIGINT, "c2", BIGINT)); + RowExpression rowExpression = sqlToRowExpressionTranslator.translateAndOptimize(expression(expression), typeProvider); + FilterNode filter = filter(jdbcTableScan(schema, table, BIGINT, "c1", "c2"), rowExpression); + PlanNode original = limit(8, filter); + + Optional context = getJdbcQueryGeneratorContext( + new JdbcExpression("(('c1' + 'c2') = ?)", ImmutableList.of(new ConstantExpression(3L, INTEGER))), + 8L); + Set columns = Stream.of("c1", "c2").map(TestJdbcComputePushdown::booleanJdbcColumnHandle).collect(Collectors.toSet()); + JdbcTableHandle jdbcTableHandle = new JdbcTableHandle(CONNECTOR_ID, new SchemaTableName(schema, table), CATALOG_NAME, schema, table, context); + JdbcTableLayoutHandle jdbcTableLayoutHandle = new JdbcTableLayoutHandle( + jdbcTableHandle, + TupleDomain.none()); + + ConnectorSession session = new TestingConnectorSession(ImmutableList.of()); + PlanNode actual = this.jdbcComputePushdown.optimize(original, session, null, ID_ALLOCATOR); + assertPlanMatch(actual, JdbcTableScanMatcher.jdbcTableScanPattern(jdbcTableLayoutHandle, columns)); + } + + @Test + public void testJdbcComputeLimitNotPushdownWithFilter() + { + String table = "test_table"; + String schema = "test_schema"; + + String expression = "CAST(c1 AS varchar(1024)) = '123' and c1 + c2 = 3"; + TypeProvider typeProvider = TypeProvider.copyOf(ImmutableMap.of("c1", BIGINT, "c2", BIGINT)); + RowExpression rowExpression = sqlToRowExpressionTranslator.translateAndOptimize(expression(expression), typeProvider); + FilterNode filter = filter(jdbcTableScan(schema, table, BIGINT, "c1", "c2"), rowExpression); + PlanNode original = limit(8, filter); + + Optional context = getJdbcQueryGeneratorContext( + new JdbcExpression("(('c1' + 'c2') = ?)", ImmutableList.of(new ConstantExpression(3L, INTEGER)))); + Set columns = Stream.of("c1", "c2").map(TestJdbcComputePushdown::booleanJdbcColumnHandle).collect(Collectors.toSet()); + JdbcTableHandle jdbcTableHandle = new JdbcTableHandle(CONNECTOR_ID, new SchemaTableName(schema, table), CATALOG_NAME, schema, table, context); + JdbcTableLayoutHandle jdbcTableLayoutHandle = new JdbcTableLayoutHandle( + jdbcTableHandle, + TupleDomain.none()); + + ConnectorSession session = new TestingConnectorSession(ImmutableList.of()); + PlanNode actual = this.jdbcComputePushdown.optimize(original, session, null, ID_ALLOCATOR); + + assertPlanMatch(actual, PlanMatchPattern.limit(8L, + PlanMatchPattern.filter( + "CAST(c1 AS varchar(1024)) = '123'", + JdbcTableScanMatcher.jdbcTableScanPattern(jdbcTableLayoutHandle, columns)))); + } + + @Test + public void testSimpleOrderByLimitPushdown() + { + ImmutableList expected = ImmutableList.of(new JdbcSortItem(getColumnHandleForVariable("l_orderkey", jdbcColumnHandles), SortOrder.DESC_NULLS_FIRST)); + simpleOrderByLimitPushdownCommon(ImmutableList.of("l_orderkey"), ImmutableList.of(false), expected); + + expected = ImmutableList.of(new JdbcSortItem(getColumnHandleForVariable("l_orderkey", jdbcColumnHandles), SortOrder.DESC_NULLS_FIRST), + new JdbcSortItem(getColumnHandleForVariable("l_extendedprice", jdbcColumnHandles), SortOrder.ASC_NULLS_FIRST)); + simpleOrderByLimitPushdownCommon(ImmutableList.of("l_orderkey", "l_extendedprice"), ImmutableList.of(false, true), expected); + } + + private void simpleOrderByLimitPushdownCommon(List orderingColumns, List ascending, List expected) + { + String table = "test_table"; + String schema = "test_schema"; + + TableScanNode tableScanNode = jdbcTableScan(); + PlanNode original = topN(8, orderingColumns, ascending, tableScanNode); + + Optional context = getJdbcQueryGeneratorContext( + Optional.empty(), + Optional.of(expected), + OptionalLong.of(8)); + + JdbcTableHandle jdbcTableHandle = new JdbcTableHandle(CONNECTOR_ID, new SchemaTableName(schema, table), CATALOG_NAME, schema, table, context); + JdbcTableLayoutHandle jdbcTableLayoutHandle = new JdbcTableLayoutHandle( + jdbcTableHandle, + TupleDomain.none()); + + ConnectorSession session = new TestingConnectorSession(ImmutableList.of()); + PlanNode actual = this.jdbcComputePushdown.optimize(original, session, null, ID_ALLOCATOR); + assertPlanMatch(actual, JdbcTableScanMatcher.jdbcTableScanPattern(jdbcTableLayoutHandle, Sets.newHashSet(jdbcColumnHandles))); + } + + @Test + public void testOrderByLimitWithFilterPushdown() + { + String table = "test_table"; + String schema = "test_schema"; + + String expression = "l_orderkey = 3 and l_commitdate between '2021-07-19' and '2021-07-20'"; + TableScanNode tableScanNode = jdbcTableScan(); + TypeProvider typeProvider = TypeProvider.fromVariables(tableScanNode.getOutputVariables()); + RowExpression rowExpression = sqlToRowExpressionTranslator.translateAndOptimize(expression(expression), typeProvider); + + FilterNode filter = filter(tableScanNode, rowExpression); + PlanNode original = topN(8, ImmutableList.of("l_orderkey", "l_extendedprice"), ImmutableList.of(false, true), filter); + + List constantBindValues = ImmutableList.of(new ConstantExpression(3L, INTEGER), + new ConstantExpression(Slices.utf8Slice("2021-07-19"), createVarcharType(10)), + new ConstantExpression(Slices.utf8Slice("2021-07-20"), createVarcharType(10))); + + List sortOrder = ImmutableList.of(new JdbcSortItem(getColumnHandleForVariable("l_orderkey", jdbcColumnHandles), SortOrder.DESC_NULLS_FIRST), + new JdbcSortItem(getColumnHandleForVariable("l_extendedprice", jdbcColumnHandles), SortOrder.ASC_NULLS_FIRST)); + + Optional context = getJdbcQueryGeneratorContext( + Optional.of(new JdbcExpression("(('l_orderkey' = ?) AND ('l_commitdate' BETWEEN ? AND ?))", constantBindValues)), + Optional.of(sortOrder), + OptionalLong.of(8)); + + JdbcTableHandle jdbcTableHandle = new JdbcTableHandle(CONNECTOR_ID, new SchemaTableName(schema, table), CATALOG_NAME, schema, table, context); + JdbcTableLayoutHandle jdbcTableLayoutHandle = new JdbcTableLayoutHandle( + jdbcTableHandle, + TupleDomain.none()); + + ConnectorSession session = new TestingConnectorSession(ImmutableList.of()); + PlanNode actual = this.jdbcComputePushdown.optimize(original, session, null, ID_ALLOCATOR); + assertPlanMatch(actual, JdbcTableScanMatcher.jdbcTableScanPattern(jdbcTableLayoutHandle, Sets.newHashSet(jdbcColumnHandles))); + } + + @Test + public void testOrderByLimitNotPushdown() + { + String table = "test_table"; + String schema = "test_schema"; + + TableScanNode tableScanNode = jdbcTableScan(); + Map assignments = tableScanNode.getOutputVariables().stream() + .map(v -> immutableEntry(v, v)) + .collect(toMap(Map.Entry::getKey, + Map.Entry::getValue, + (val1, val2) -> val1, + LinkedHashMap::new)); + + Map functions = ImmutableMap.of(getSqlFunctionId(), getSqlInvokedFunction()); + + FunctionHandle functionHandle = typeManager.resolveFunction( + Optional.of(functions), + Optional.empty(), + getSqlFunctionId().getFunctionName(), + fromTypes(VARCHAR, BIGINT, BIGINT)); + + List arguments = Lists.newArrayList(new VariableReferenceExpression("l_commitdate", VARCHAR), + new ConstantExpression(1L, BIGINT), + new ConstantExpression(2L, BIGINT)); + assignments.put(new VariableReferenceExpression("substring", VARCHAR), + new CallExpression("substring", functionHandle, VARCHAR, arguments)); + + ProjectNode project = project(Assignments.copyOf(assignments), tableScanNode); + PlanNode original = topN(8, ImmutableList.of("substring", "l_orderkey"), ImmutableList.of(false, true), project); + + JdbcTableHandle jdbcTableHandle = new JdbcTableHandle(CONNECTOR_ID, new SchemaTableName(schema, table), CATALOG_NAME, schema, table, Optional.empty()); + JdbcTableLayoutHandle jdbcTableLayoutHandle = new JdbcTableLayoutHandle( + jdbcTableHandle, + TupleDomain.none()); + + ConnectorSession session = new TestingConnectorSession(ImmutableList.of()); + PlanNode actual = this.jdbcComputePushdown.optimize(original, session, null, ID_ALLOCATOR); + + Map assignmentsMap = tableScanNode.getOutputVariables().stream().map(v -> + immutableEntry(v.getName(), PlanMatchPattern.expression(v.getName()))) + .collect(toMap(Map.Entry::getKey, + Map.Entry::getValue, + (val1, val2) -> val1, + LinkedHashMap::new)); + + assignmentsMap.put("substring", PlanMatchPattern.expression("substring(l_commitdate, 1, 2)")); + List orderBy = ImmutableList.of(sort("substring", DESCENDING, FIRST), sort("l_orderkey", ASCENDING, FIRST)); + + assertPlanMatch(actual, + PlanMatchPattern.topN(8, orderBy, + PlanMatchPattern.project(assignmentsMap, + JdbcTableScanMatcher.jdbcTableScanPattern(jdbcTableLayoutHandle, Sets.newHashSet(jdbcColumnHandles))))); + } + + @Test + public void testOrderByLimitWithFilterNotPushdown() + { + String table = "test_table"; + String schema = "test_schema"; + + String expression = "l_orderkey = 3 and substring(l_commitdate, 1, 2) = '20'"; + TableScanNode tableScanNode = jdbcTableScan(); + TypeProvider typeProvider = TypeProvider.fromVariables(tableScanNode.getOutputVariables()); + RowExpression rowExpression = sqlToRowExpressionTranslator.translateAndOptimize(expression(expression), typeProvider, createSessionWithTempFunctionSubstring()); + + FilterNode filter = filter(tableScanNode, rowExpression); + PlanNode original = topN(8, ImmutableList.of("l_orderkey", "l_extendedprice"), ImmutableList.of(false, true), filter); + Optional context = getJdbcQueryGeneratorContext( + Optional.of(new JdbcExpression("('l_orderkey' = ?)", ImmutableList.of(new ConstantExpression(3L, INTEGER)))), + Optional.empty(), + OptionalLong.empty()); + + JdbcTableHandle jdbcTableHandle = new JdbcTableHandle(CONNECTOR_ID, new SchemaTableName(schema, table), CATALOG_NAME, schema, table, context); + JdbcTableLayoutHandle jdbcTableLayoutHandle = new JdbcTableLayoutHandle( + jdbcTableHandle, + TupleDomain.none()); + + ConnectorSession session = new TestingConnectorSession(ImmutableList.of()); + PlanNode actual = this.jdbcComputePushdown.optimize(original, session, null, ID_ALLOCATOR); + + List orderBy = ImmutableList.of(sort("l_orderkey", DESCENDING, FIRST), sort("l_extendedprice", ASCENDING, FIRST)); + assertPlanMatch(actual, + PlanMatchPattern.topN(8, orderBy, + PlanMatchPattern.filter("substring(l_commitdate, 1, 2) = '20'", + JdbcTableScanMatcher.jdbcTableScanPattern(jdbcTableLayoutHandle, Sets.newHashSet(jdbcColumnHandles))))); + } + + private Session createSessionWithTempFunctionSubstring() + { + return testSessionBuilder() + .addSessionFunction(getSqlFunctionId(), getSqlInvokedFunction()) + .build(); + } + + private SqlFunctionId getSqlFunctionId() + { + return new SqlFunctionId(QualifiedObjectName.valueOf("presto.default.substring"), + ImmutableList.of(parseTypeSignature("varchar"), parseTypeSignature("bigint"), parseTypeSignature("bigint"))); + } + + private SqlInvokedFunction getSqlInvokedFunction() + { + return new SqlInvokedFunction( + getSqlFunctionId().getFunctionName(), + ImmutableList.of(new Parameter("x", parseTypeSignature("varchar")), + new Parameter("y", parseTypeSignature("bigint")), + new Parameter("z", parseTypeSignature("bigint"))), + parseTypeSignature("varchar"), + "", + RoutineCharacteristics.builder().build(), + "", + notVersioned()); + } + + private Optional getJdbcQueryGeneratorContext(JdbcExpression jdbcExpression, Long limit) + { + return getJdbcQueryGeneratorContext(Optional.of(jdbcExpression), OptionalLong.of(limit)); + } + + private Optional getJdbcQueryGeneratorContext(JdbcExpression jdbcExpression) + { + return getJdbcQueryGeneratorContext(Optional.of(jdbcExpression), OptionalLong.empty()); + } + + private Optional getJdbcQueryGeneratorContext(Optional jdbcExpression, OptionalLong limit) + { + return getJdbcQueryGeneratorContext(jdbcExpression, Optional.empty(), limit); + } + + private Optional getJdbcQueryGeneratorContext(Optional jdbcExpression, Optional> sortOrder, OptionalLong limit) + { + return Optional.of(new JdbcQueryGeneratorContext(jdbcExpression, sortOrder, limit)); + } + private Set> getFunctionTranslators() { return ImmutableSet.of(OperatorTranslators.class); @@ -313,12 +670,19 @@ private static VariableReferenceExpression newVariable(String name, Type type) private static JdbcColumnHandle integerJdbcColumnHandle(String name) { - return new JdbcColumnHandle(CONNECTOR_ID, name, new JdbcTypeHandle(Types.BIGINT, 10, 0), BIGINT, false); + return new JdbcColumnHandle(CONNECTOR_ID, name, JDBC_BIGINT, BIGINT, false); } private static JdbcColumnHandle booleanJdbcColumnHandle(String name) { - return new JdbcColumnHandle(CONNECTOR_ID, name, new JdbcTypeHandle(Types.BOOLEAN, 1, 0), BOOLEAN, false); + return new JdbcColumnHandle(CONNECTOR_ID, name, JDBC_BOOLEAN, BOOLEAN, false); + } + + private static JdbcColumnHandle getColumnHandleForVariable(String name, List jdbcColumnHandles) + { + return jdbcColumnHandles.stream() + .filter(h -> h.getColumnName().equalsIgnoreCase(name)) + .findFirst().orElseThrow(() -> new IllegalArgumentException("Cannot find jdbcColumnHandle " + name)); } private static JdbcColumnHandle getColumnHandleForVariable(String name, Type type) @@ -343,8 +707,8 @@ private static void assertPlanMatch(PlanNode actual, PlanMatchPattern expected, private TableScanNode jdbcTableScan(String schema, String table, Type type, String... columnNames) { - JdbcTableHandle jdbcTableHandle = new JdbcTableHandle(CONNECTOR_ID, new SchemaTableName(schema, table), CATALOG_NAME, schema, table); - JdbcTableLayoutHandle jdbcTableLayoutHandle = new JdbcTableLayoutHandle(jdbcTableHandle, TupleDomain.none(), Optional.empty()); + JdbcTableHandle jdbcTableHandle = new JdbcTableHandle(CONNECTOR_ID, new SchemaTableName(schema, table), CATALOG_NAME, schema, table, Optional.empty()); + JdbcTableLayoutHandle jdbcTableLayoutHandle = new JdbcTableLayoutHandle(jdbcTableHandle, TupleDomain.none()); TableHandle tableHandle = new TableHandle(new ConnectorId(CATALOG_NAME), jdbcTableHandle, new ConnectorTransactionHandle() {}, Optional.of(jdbcTableLayoutHandle)); return PLAN_BUILDER.tableScan( @@ -355,11 +719,53 @@ private TableScanNode jdbcTableScan(String schema, String table, Type type, Stri .collect(toMap(identity(), entry -> getColumnHandleForVariable(entry.getName(), type)))); } + private TableScanNode jdbcTableScan() + { + String table = "test_table"; + String schema = "test_schema"; + JdbcTableHandle jdbcTableHandle = new JdbcTableHandle(CONNECTOR_ID, new SchemaTableName(schema, table), CATALOG_NAME, schema, table, Optional.empty()); + JdbcTableLayoutHandle jdbcTableLayoutHandle = new JdbcTableLayoutHandle(jdbcTableHandle, TupleDomain.none()); + TableHandle tableHandle = new TableHandle(new ConnectorId(CATALOG_NAME), jdbcTableHandle, new ConnectorTransactionHandle() {}, Optional.of(jdbcTableLayoutHandle)); + + return PLAN_BUILDER.tableScan( + tableHandle, + jdbcColumnHandles.stream().map(column -> newVariable(column.getColumnName(), column.getColumnType())).collect(toImmutableList()), + jdbcColumnHandles.stream() + .map(column -> newVariable(column.getColumnName(), column.getColumnType())) + .collect(toMap(identity(), entry -> getColumnHandleForVariable(entry.getName(), jdbcColumnHandles)))); + } + + private ProjectNode project(Assignments assignments, PlanNode source) + { + return PLAN_BUILDER.project(assignments, source); + } + private FilterNode filter(PlanNode source, RowExpression predicate) { return PLAN_BUILDER.filter(predicate, source); } + private LimitNode limit(long count, PlanNode source) + { + return PLAN_BUILDER.limit(count, source); + } + + private TopNNode topN(long count, List orderingColumns, List ascending, PlanNode source) + { + ImmutableList ordering = IntStream.range(0, orderingColumns.size()) + .boxed() + .map(i -> new Ordering(variable(source.getOutputVariables(), orderingColumns.get(i)), ascending.get(i) ? SortOrder.ASC_NULLS_FIRST : SortOrder.DESC_NULLS_FIRST)) + .collect(toImmutableList()); + + return new TopNNode(PLAN_BUILDER.getIdAllocator().getNextId(), source, count, new OrderingScheme(ordering), TopNNode.Step.PARTIAL); + } + + private VariableReferenceExpression variable(List outputVariables, String name) + { + return outputVariables.stream().filter(v -> v.getName().equals(name)) + .findFirst().orElseThrow(() -> new IllegalArgumentException("Cannot find variable " + name)); + } + private static final class JdbcTableScanMatcher implements Matcher { @@ -390,16 +796,20 @@ public MatchResult detailMatches(PlanNode node, StatsProvider stats, Session ses TableScanNode tableScanNode = (TableScanNode) node; JdbcTableLayoutHandle layoutHandle = (JdbcTableLayoutHandle) tableScanNode.getTable().getLayout().get(); + + JdbcTableHandle expectedTableHandle = jdbcTableLayoutHandle.getTable(); + JdbcTableHandle actualTableHandle = layoutHandle.getTable(); + if (jdbcTableLayoutHandle.getTable().equals(layoutHandle.getTable()) && jdbcTableLayoutHandle.getTupleDomain().equals(layoutHandle.getTupleDomain()) - && ((!jdbcTableLayoutHandle.getAdditionalPredicate().isPresent() && !layoutHandle.getAdditionalPredicate().isPresent()) - || jdbcTableLayoutHandle.getAdditionalPredicate().get().getExpression().equals(layoutHandle.getAdditionalPredicate().get().getExpression()))) { + && ((!expectedTableHandle.getContext().isPresent() && !actualTableHandle.getContext().isPresent()) + || expectedTableHandle.getContext().get().equals(actualTableHandle.getContext().get()))) { return MatchResult.match( SymbolAliases.builder().putAll( - columns.stream() - .map(column -> ((JdbcColumnHandle) column).getColumnName()) - .collect(toMap(identity(), SymbolReference::new))) - .build()); + columns.stream() + .map(column -> ((JdbcColumnHandle) column).getColumnName()) + .collect(toMap(identity(), SymbolReference::new))) + .build()); } return MatchResult.NO_MATCH; diff --git a/presto-main/src/test/java/com/facebook/presto/sql/TestingRowExpressionTranslator.java b/presto-main/src/test/java/com/facebook/presto/sql/TestingRowExpressionTranslator.java index e405d694d7c8d..519888e7972a0 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/TestingRowExpressionTranslator.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/TestingRowExpressionTranslator.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.sql; +import com.facebook.presto.Session; import com.facebook.presto.common.type.Type; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.metadata.MetadataManager; @@ -63,6 +64,11 @@ public RowExpression translateAndOptimize(Expression expression, TypeProvider ty return translateAndOptimize(expression, getExpressionTypes(expression, typeProvider)); } + public RowExpression translateAndOptimize(Expression expression, TypeProvider typeProvider, Session session) + { + return translateAndOptimize(expression, getExpressionTypes(expression, typeProvider, session), session); + } + public RowExpression translate(String sql, Map types) { return translate(ExpressionUtils.rewriteIdentifiersToSymbolReferences(new SqlParser().createExpression(sql)), TypeProvider.viewOf(types)); @@ -80,9 +86,14 @@ public RowExpression translate(Expression expression, TypeProvider typeProvider) public RowExpression translateAndOptimize(Expression expression, Map, Type> types) { - RowExpression rowExpression = SqlToRowExpressionTranslator.translate(expression, types, ImmutableMap.of(), metadata.getFunctionAndTypeManager(), TEST_SESSION); + return translateAndOptimize(expression, types, TEST_SESSION); + } + + public RowExpression translateAndOptimize(Expression expression, Map, Type> types, Session session) + { + RowExpression rowExpression = SqlToRowExpressionTranslator.translate(expression, types, ImmutableMap.of(), metadata.getFunctionAndTypeManager(), session); RowExpressionOptimizer optimizer = new RowExpressionOptimizer(metadata); - return optimizer.optimize(rowExpression, OPTIMIZED, TEST_SESSION.toConnectorSession()); + return optimizer.optimize(rowExpression, OPTIMIZED, session.toConnectorSession()); } Expression simplifyExpression(Expression expression) @@ -96,10 +107,15 @@ Expression simplifyExpression(Expression expression) } private Map, Type> getExpressionTypes(Expression expression, TypeProvider typeProvider) + { + return getExpressionTypes(expression, typeProvider, TEST_SESSION); + } + + private Map, Type> getExpressionTypes(Expression expression, TypeProvider typeProvider, Session session) { ExpressionAnalyzer expressionAnalyzer = ExpressionAnalyzer.createWithoutSubqueries( metadata.getFunctionAndTypeManager(), - TEST_SESSION, + session, typeProvider, emptyList(), node -> new IllegalStateException("Unexpected node: %s" + node), diff --git a/presto-mysql/src/main/java/com/facebook/presto/plugin/mysql/MySqlClient.java b/presto-mysql/src/main/java/com/facebook/presto/plugin/mysql/MySqlClient.java index d3b29fed56801..ec450f88cf566 100644 --- a/presto-mysql/src/main/java/com/facebook/presto/plugin/mysql/MySqlClient.java +++ b/presto-mysql/src/main/java/com/facebook/presto/plugin/mysql/MySqlClient.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.plugin.mysql; +import com.facebook.presto.common.type.CharType; import com.facebook.presto.common.type.Type; import com.facebook.presto.common.type.VarcharType; import com.facebook.presto.plugin.jdbc.BaseJdbcClient; @@ -23,6 +24,7 @@ import com.facebook.presto.plugin.jdbc.JdbcConnectorId; import com.facebook.presto.plugin.jdbc.JdbcIdentity; import com.facebook.presto.plugin.jdbc.JdbcTableHandle; +import com.facebook.presto.plugin.jdbc.optimization.JdbcSortItem; import com.facebook.presto.spi.ConnectorSession; import com.facebook.presto.spi.ConnectorTableMetadata; import com.facebook.presto.spi.PrestoException; @@ -39,8 +41,11 @@ import java.sql.ResultSet; import java.sql.SQLException; import java.util.Collection; +import java.util.List; import java.util.Optional; import java.util.Properties; +import java.util.function.BiFunction; +import java.util.stream.Stream; import static com.facebook.presto.common.type.RealType.REAL; import static com.facebook.presto.common.type.TimeWithTimeZoneType.TIME_WITH_TIME_ZONE; @@ -57,6 +62,7 @@ import static com.mysql.jdbc.SQLError.SQL_STATE_SYNTAX_ERROR; import static java.lang.String.format; import static java.util.Locale.ENGLISH; +import static java.util.stream.Collectors.joining; public class MySqlClient extends BaseJdbcClient @@ -93,6 +99,57 @@ private static ConnectionFactory connectionFactory(BaseJdbcConfig config, MySqlC connectionProperties); } + @Override + protected Optional> limitFunction() + { + return Optional.of((sql, limit) -> sql + " LIMIT " + limit); + } + + @Override + public boolean supportsTopN(List sortOrder) + { + for (JdbcSortItem sortItem : sortOrder) { + Type sortItemType = sortItem.getColumn().getColumnType(); + if (sortItemType instanceof CharType || sortItemType instanceof VarcharType) { + // Remote database can be case insensitive. + return false; + } + } + return true; + } + + @Override + protected Optional topNFunction() + { + return Optional.of((query, sortItems, limit) -> { + String orderBy = sortItems.stream() + .flatMap(sortItem -> { + String ordering = sortItem.getSortOrder().isAscending() ? "ASC" : "DESC"; + String columnSorting = format("%s %s", quoted(sortItem.getColumn().getColumnName()), ordering); + + switch (sortItem.getSortOrder()) { + case ASC_NULLS_FIRST: + // In MySQL ASC implies NULLS FIRST + case DESC_NULLS_LAST: + // In MySQL DESC implies NULLS LAST + return Stream.of(columnSorting); + + case ASC_NULLS_LAST: + return Stream.of( + format("ISNULL(%s) ASC", quoted(sortItem.getColumn().getColumnName())), + columnSorting); + case DESC_NULLS_FIRST: + return Stream.of( + format("ISNULL(%s) DESC", quoted(sortItem.getColumn().getColumnName())), + columnSorting); + } + throw new UnsupportedOperationException("Unsupported sort order: " + sortItem.getSortOrder()); + }) + .collect(joining(", ")); + return format("%s ORDER BY %s LIMIT %s", query, orderBy, limit); + }); + } + @Override protected Collection listSchemas(Connection connection) { diff --git a/presto-postgresql/src/main/java/com/facebook/presto/plugin/postgresql/PostgreSqlClient.java b/presto-postgresql/src/main/java/com/facebook/presto/plugin/postgresql/PostgreSqlClient.java index 907ce55552283..a13180efdb31d 100644 --- a/presto-postgresql/src/main/java/com/facebook/presto/plugin/postgresql/PostgreSqlClient.java +++ b/presto-postgresql/src/main/java/com/facebook/presto/plugin/postgresql/PostgreSqlClient.java @@ -13,12 +13,16 @@ */ package com.facebook.presto.plugin.postgresql; +import com.facebook.presto.common.type.CharType; import com.facebook.presto.common.type.Type; +import com.facebook.presto.common.type.VarcharType; import com.facebook.presto.plugin.jdbc.BaseJdbcClient; import com.facebook.presto.plugin.jdbc.BaseJdbcConfig; import com.facebook.presto.plugin.jdbc.DriverConnectionFactory; +import com.facebook.presto.plugin.jdbc.JdbcColumnHandle; import com.facebook.presto.plugin.jdbc.JdbcConnectorId; import com.facebook.presto.plugin.jdbc.JdbcIdentity; +import com.facebook.presto.plugin.jdbc.optimization.JdbcSortItem; import com.facebook.presto.spi.ConnectorSession; import com.facebook.presto.spi.ConnectorTableMetadata; import com.facebook.presto.spi.PrestoException; @@ -32,12 +36,15 @@ import java.sql.PreparedStatement; import java.sql.ResultSet; import java.sql.SQLException; +import java.util.List; import java.util.Optional; +import java.util.function.BiFunction; import static com.facebook.presto.common.type.VarbinaryType.VARBINARY; import static com.facebook.presto.plugin.jdbc.JdbcErrorCode.JDBC_ERROR; import static com.facebook.presto.spi.StandardErrorCode.ALREADY_EXISTS; import static java.lang.String.format; +import static java.util.stream.Collectors.joining; public class PostgreSqlClient extends BaseJdbcClient @@ -50,6 +57,58 @@ public PostgreSqlClient(JdbcConnectorId connectorId, BaseJdbcConfig config) super(connectorId, config, "\"", new DriverConnectionFactory(new Driver(), config)); } + @Override + protected Optional> limitFunction() + { + return Optional.of((sql, limit) -> sql + " LIMIT " + limit); + } + + @Override + public boolean supportsTopN(List sortOrder) + { + for (JdbcSortItem sortItem : sortOrder) { + Type sortItemType = sortItem.getColumn().getColumnType(); + if (sortItemType instanceof CharType || sortItemType instanceof VarcharType) { + if (!isCollatable(sortItem.getColumn())) { + return false; + } + } + } + return true; + } + + @Override + protected Optional topNFunction() + { + return Optional.of((query, sortItems, limit) -> { + String orderBy = sortItems.stream() + .map(sortItem -> { + String ordering = sortItem.getSortOrder().isAscending() ? "ASC" : "DESC"; + String nullsHandling = sortItem.getSortOrder().isNullsFirst() ? "NULLS FIRST" : "NULLS LAST"; + String collation = ""; + if (isCollatable(sortItem.getColumn())) { + collation = "COLLATE \"C\""; + } + return format("%s %s %s %s", quoted(sortItem.getColumn().getColumnName()), collation, ordering, nullsHandling); + }) + .collect(joining(", ")); + return format("%s ORDER BY %s LIMIT %d", query, orderBy, limit); + }); + } + + private boolean isCollatable(JdbcColumnHandle column) + { + if (column.getColumnType() instanceof CharType || column.getColumnType() instanceof VarcharType) { + String jdbcTypeName = column.getJdbcTypeHandle().getJdbcTypeName() + .orElseThrow(() -> new PrestoException(JDBC_ERROR, "Type name is missing: " + column.getJdbcTypeHandle())); + // Only char (internally named bpchar)/varchar/text are the built-in collatable types + return "bpchar".equals(jdbcTypeName) || "varchar".equals(jdbcTypeName) || "text".equals(jdbcTypeName); + } + + // non-textual types don't have the concept of collation + return false; + } + @Override public PreparedStatement getPreparedStatement(Connection connection, String sql) throws SQLException diff --git a/presto-sqlserver/src/main/java/com/facebook/presto/plugin/sqlserver/SqlServerClient.java b/presto-sqlserver/src/main/java/com/facebook/presto/plugin/sqlserver/SqlServerClient.java index f507246716da1..9b2a4c19a8f13 100644 --- a/presto-sqlserver/src/main/java/com/facebook/presto/plugin/sqlserver/SqlServerClient.java +++ b/presto-sqlserver/src/main/java/com/facebook/presto/plugin/sqlserver/SqlServerClient.java @@ -13,6 +13,9 @@ */ package com.facebook.presto.plugin.sqlserver; +import com.facebook.presto.common.type.CharType; +import com.facebook.presto.common.type.Type; +import com.facebook.presto.common.type.VarcharType; import com.facebook.presto.plugin.jdbc.BaseJdbcClient; import com.facebook.presto.plugin.jdbc.BaseJdbcConfig; import com.facebook.presto.plugin.jdbc.DriverConnectionFactory; @@ -20,6 +23,7 @@ import com.facebook.presto.plugin.jdbc.JdbcConnectorId; import com.facebook.presto.plugin.jdbc.JdbcIdentity; import com.facebook.presto.plugin.jdbc.JdbcTableHandle; +import com.facebook.presto.plugin.jdbc.optimization.JdbcSortItem; import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.SchemaTableName; import com.google.common.base.Joiner; @@ -29,9 +33,14 @@ import java.sql.Connection; import java.sql.SQLException; +import java.util.List; +import java.util.Optional; +import java.util.function.BiFunction; +import java.util.stream.Stream; import static com.facebook.presto.plugin.jdbc.JdbcErrorCode.JDBC_ERROR; import static java.lang.String.format; +import static java.util.stream.Collectors.joining; public class SqlServerClient extends BaseJdbcClient @@ -44,6 +53,57 @@ public SqlServerClient(JdbcConnectorId connectorId, BaseJdbcConfig config) super(connectorId, config, "\"", new DriverConnectionFactory(new SQLServerDriver(), config)); } + @Override + protected Optional> limitFunction() + { + return Optional.of((sql, limit) -> format("SELECT TOP %s * FROM (%s) o", limit, sql)); + } + + @Override + public boolean supportsTopN(List sortOrder) + { + for (JdbcSortItem sortItem : sortOrder) { + Type sortItemType = sortItem.getColumn().getColumnType(); + if (sortItemType instanceof CharType || sortItemType instanceof VarcharType) { + // Remote database can be case insensitive. + return false; + } + } + return true; + } + + @Override + protected Optional topNFunction() + { + return Optional.of((query, sortItems, limit) -> { + String orderBy = sortItems.stream() + .flatMap(sortItem -> { + String ordering = sortItem.getSortOrder().isAscending() ? "ASC" : "DESC"; + String columnSorting = format("%s %s", quoted(sortItem.getColumn().getColumnName()), ordering); + + switch (sortItem.getSortOrder()) { + case ASC_NULLS_FIRST: + // In SQL Server ASC implies NULLS FIRST + case DESC_NULLS_LAST: + // In SQL Server DESC implies NULLS LAST + return Stream.of(columnSorting); + + case ASC_NULLS_LAST: + return Stream.of( + format("(CASE WHEN %s IS NULL THEN 1 ELSE 0 END) ASC", quoted(sortItem.getColumn().getColumnName())), + columnSorting); + case DESC_NULLS_FIRST: + return Stream.of( + format("(CASE WHEN %s IS NULL THEN 1 ELSE 0 END) DESC", quoted(sortItem.getColumn().getColumnName())), + columnSorting); + } + throw new UnsupportedOperationException("Unsupported sort order: " + sortItem.getSortOrder()); + }) + .collect(joining(", ")); + return format("%s ORDER BY %s OFFSET 0 ROWS FETCH NEXT %s ROWS ONLY", query, orderBy, limit); + }); + } + @Override protected void renameTable(JdbcIdentity identity, String catalogName, SchemaTableName oldTable, SchemaTableName newTable) {