From 170e5b3fe8c82d436fa8eb6b22a98887b8375af9 Mon Sep 17 00:00:00 2001 From: Piotr Findeisen Date: Wed, 6 Apr 2022 14:07:24 +0200 Subject: [PATCH 1/3] Amend ordering of SQL Server tests execution This is supposed to work around OOM problems when multiple test instances are initialized. This seems to be a prerequisite if we want to add more QueryRunner-based tests. --- plugin/trino-sqlserver/pom.xml | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/plugin/trino-sqlserver/pom.xml b/plugin/trino-sqlserver/pom.xml index 74a29e6a6f9e..fba1fc188beb 100644 --- a/plugin/trino-sqlserver/pom.xml +++ b/plugin/trino-sqlserver/pom.xml @@ -15,6 +15,14 @@ ${project.parent.basedir} + + + classes From a9aaca5c238a624418271bf93a4ec536659799eb Mon Sep 17 00:00:00 2001 From: Piotr Findeisen Date: Wed, 23 Mar 2022 15:05:23 +0100 Subject: [PATCH 2/3] Implement reading table statistics for SQL Server --- .../plugin/sqlserver/SqlServerClient.java | 221 +++++++++- .../sqlserver/SqlServerClientModule.java | 2 + .../plugin/sqlserver/TestSqlServerClient.java | 2 + .../TestSqlServerTableStatistics.java | 415 ++++++++++++++++++ 4 files changed, 639 insertions(+), 1 deletion(-) create mode 100644 plugin/trino-sqlserver/src/test/java/io/trino/plugin/sqlserver/TestSqlServerTableStatistics.java diff --git a/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/SqlServerClient.java b/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/SqlServerClient.java index f663dd8e8945..007752e084a1 100644 --- a/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/SqlServerClient.java +++ b/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/SqlServerClient.java @@ -36,6 +36,7 @@ import io.trino.plugin.jdbc.JdbcOutputTableHandle; import io.trino.plugin.jdbc.JdbcSortItem; import io.trino.plugin.jdbc.JdbcSplit; +import io.trino.plugin.jdbc.JdbcStatisticsConfig; import io.trino.plugin.jdbc.JdbcTableHandle; import io.trino.plugin.jdbc.JdbcTypeHandle; import 
io.trino.plugin.jdbc.LongReadFunction; @@ -59,6 +60,10 @@ import io.trino.spi.connector.ConnectorTableMetadata; import io.trino.spi.connector.JoinCondition; import io.trino.spi.connector.SchemaTableName; +import io.trino.spi.predicate.TupleDomain; +import io.trino.spi.statistics.ColumnStatistics; +import io.trino.spi.statistics.Estimate; +import io.trino.spi.statistics.TableStatistics; import io.trino.spi.type.CharType; import io.trino.spi.type.DecimalType; import io.trino.spi.type.Decimals; @@ -77,6 +82,7 @@ import javax.inject.Inject; +import java.sql.CallableStatement; import java.sql.Connection; import java.sql.PreparedStatement; import java.sql.ResultSet; @@ -89,6 +95,7 @@ import java.time.ZoneId; import java.time.ZonedDateTime; import java.time.format.DateTimeFormatter; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Optional; @@ -97,6 +104,9 @@ import java.util.stream.Stream; import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.base.Throwables.throwIfInstanceOf; +import static com.google.common.collect.MoreCollectors.toOptional; import static com.google.common.util.concurrent.MoreExecutors.directExecutor; import static com.microsoft.sqlserver.jdbc.SQLServerConnection.TRANSACTION_SNAPSHOT; import static io.airlift.slice.Slices.wrappedBuffer; @@ -191,18 +201,27 @@ public class SqlServerClient .maximumSize(1) .expireAfterWrite(ofMinutes(5))); + private final boolean statisticsEnabled; + private final ConnectorExpressionRewriter connectorExpressionRewriter; private final AggregateFunctionRewriter aggregateFunctionRewriter; private static final int MAX_SUPPORTED_TEMPORAL_PRECISION = 7; @Inject - public SqlServerClient(BaseJdbcConfig config, SqlServerConfig sqlServerConfig, ConnectionFactory connectionFactory, QueryBuilder queryBuilder, IdentifierMapping identifierMapping) + public SqlServerClient( + BaseJdbcConfig 
config, + SqlServerConfig sqlServerConfig, + JdbcStatisticsConfig statisticsConfig, + ConnectionFactory connectionFactory, + QueryBuilder queryBuilder, + IdentifierMapping identifierMapping) { super(config, "\"", connectionFactory, queryBuilder, identifierMapping); requireNonNull(sqlServerConfig, "sqlServerConfig is null"); snapshotIsolationDisabled = sqlServerConfig.isSnapshotIsolationDisabled(); + this.statisticsEnabled = requireNonNull(statisticsConfig, "statisticsConfig is null").isEnabled(); this.connectorExpressionRewriter = JdbcConnectorExpressionRewriterBuilder.newBuilder() .addStandardRules(this::quoted) @@ -452,6 +471,145 @@ public WriteMapping toWriteMapping(ConnectorSession session, Type type) throw new TrinoException(NOT_SUPPORTED, "Unsupported column type: " + type.getDisplayName()); } + @Override + public TableStatistics getTableStatistics(ConnectorSession session, JdbcTableHandle handle, TupleDomain tupleDomain) + { + if (!statisticsEnabled) { + return TableStatistics.empty(); + } + if (!handle.isNamedRelation()) { + return TableStatistics.empty(); + } + try { + return readTableStatistics(session, handle); + } + catch (SQLException | RuntimeException e) { + throwIfInstanceOf(e, TrinoException.class); + throw new TrinoException(JDBC_ERROR, "Failed fetching statistics for table: " + handle, e); + } + } + + private TableStatistics readTableStatistics(ConnectorSession session, JdbcTableHandle table) + throws SQLException + { + checkArgument(table.isNamedRelation(), "Relation is not a table: %s", table); + + try (Connection connection = connectionFactory.openConnection(session); + Handle handle = Jdbi.open(connection)) { + String catalog = table.getCatalogName(); + String schema = table.getSchemaName(); + String tableName = table.getTableName(); + + StatisticsDao statisticsDao = new StatisticsDao(handle); + Long tableObjectId = statisticsDao.getTableObjectId(catalog, schema, tableName); + if (tableObjectId == null) { + // Table not found + return 
TableStatistics.empty(); + } + + Long rowCount = statisticsDao.getRowCount(tableObjectId); + if (rowCount == null) { + // Table disappeared + return TableStatistics.empty(); + } + + if (rowCount == 0) { + return TableStatistics.empty(); + } + + TableStatistics.Builder tableStatistics = TableStatistics.builder(); + tableStatistics.setRowCount(Estimate.of(rowCount)); + + Map columnNameToStatisticsName = getColumnNameToStatisticsName(table, statisticsDao, tableObjectId); + + for (JdbcColumnHandle column : this.getColumns(session, table)) { + String statisticName = columnNameToStatisticsName.get(column.getColumnName()); + if (statisticName == null) { + // No statistic for column + continue; + } + + double averageColumnLength; + long notNullValues = 0; + long nullValues = 0; + long distinctValues = 0; + + try (CallableStatement showStatistics = handle.getConnection().prepareCall("DBCC SHOW_STATISTICS (?, ?)")) { + showStatistics.setString(1, format("%s.%s.%s", catalog, schema, tableName)); + showStatistics.setString(2, statisticName); + + boolean isResultSet = showStatistics.execute(); + checkState(isResultSet, "Expected SHOW_STATISTICS to return a result set"); + try (ResultSet resultSet = showStatistics.getResultSet()) { + checkState(resultSet.next(), "No rows in result set"); + + averageColumnLength = resultSet.getDouble("Average Key Length"); // NULL values are accounted for with length 0 + + checkState(!resultSet.next(), "More than one row in result set"); + } + + isResultSet = showStatistics.getMoreResults(); + checkState(isResultSet, "Expected SHOW_STATISTICS to return second result set"); + showStatistics.getResultSet().close(); + + isResultSet = showStatistics.getMoreResults(); + checkState(isResultSet, "Expected SHOW_STATISTICS to return third result set"); + try (ResultSet resultSet = showStatistics.getResultSet()) { + while (resultSet.next()) { + resultSet.getObject("RANGE_HI_KEY"); + if (resultSet.wasNull()) { + // Null fraction + 
checkState(resultSet.getLong("RANGE_ROWS") == 0, "Unexpected RANGE_ROWS for null fraction"); + checkState(resultSet.getLong("DISTINCT_RANGE_ROWS") == 0, "Unexpected DISTINCT_RANGE_ROWS for null fraction"); + checkState(nullValues == 0, "Multiple null fraction entries"); + nullValues += resultSet.getLong("EQ_ROWS"); + } + else { + // TODO discover min/max from resultSet.getXxx("RANGE_HI_KEY") + notNullValues += resultSet.getLong("RANGE_ROWS") // rows strictly within a bucket + + resultSet.getLong("EQ_ROWS"); // rows equal to RANGE_HI_KEY + distinctValues += resultSet.getLong("DISTINCT_RANGE_ROWS") // NDV strictly within a bucket + + (resultSet.getLong("EQ_ROWS") > 0 ? 1 : 0); + } + } + } + } + + ColumnStatistics statistics = ColumnStatistics.builder() + .setNullsFraction(Estimate.of( + (notNullValues + nullValues == 0) + ? 1 + : (1.0 * nullValues / (notNullValues + nullValues)))) + .setDistinctValuesCount(Estimate.of(distinctValues)) + .setDataSize(Estimate.of(rowCount * averageColumnLength)) + .build(); + + tableStatistics.setColumnStatistics(column, statistics); + } + + return tableStatistics.build(); + } + } + + private static Map getColumnNameToStatisticsName(JdbcTableHandle table, StatisticsDao statisticsDao, Long tableObjectId) + { + List singleColumnStatistics = statisticsDao.getSingleColumnStatistics(tableObjectId); + + Map columnNameToStatisticsName = new HashMap<>(); + for (String statisticName : singleColumnStatistics) { + String columnName = statisticsDao.getSingleColumnStatisticsColumnName(tableObjectId, statisticName); + if (columnName == null) { + // Table or statistics disappeared + continue; + } + + if (columnNameToStatisticsName.putIfAbsent(columnName, statisticName) != null) { + log.debug("Multiple statistics for %s in %s: %s and %s", columnName, table, columnNameToStatisticsName.get(columnName), statisticName); + } + } + return columnNameToStatisticsName; + } + private LongWriteFunction sqlServerTimeWriteFunction(int precision) { return new 
LongWriteFunction() @@ -833,4 +991,65 @@ private enum SnapshotIsolationEnabledCacheKey // database, so from our perspective, this is a global property. INSTANCE } + + private static class StatisticsDao + { + private final Handle handle; + + public StatisticsDao(Handle handle) + { + this.handle = requireNonNull(handle, "handle is null"); + } + + Long getTableObjectId(String catalog, String schema, String tableName) + { + return handle.createQuery("SELECT object_id(:table)") + .bind("table", format("%s.%s.%s", catalog, schema, tableName)) + .mapTo(Long.class) + .findOnly(); + } + + Long getRowCount(long tableObjectId) + { + return handle.createQuery("" + + "SELECT sum(rows) row_count " + + "FROM sys.partitions " + + "WHERE object_id = :object_id " + + "AND index_id IN (0, 1)") // 0 = heap, 1 = clustered index, 2 or greater = non-clustered index + .bind("object_id", tableObjectId) + .mapTo(Long.class) + .findOnly(); + } + + List getSingleColumnStatistics(long tableObjectId) + { + return handle.createQuery("" + + "SELECT s.name " + + "FROM sys.stats AS s " + + "JOIN sys.stats_columns AS sc ON s.object_id = sc.object_id AND s.stats_id = sc.stats_id " + + "WHERE s.object_id = :object_id " + + "GROUP BY s.name " + + "HAVING count(*) = 1 " + + "ORDER BY s.name") + .bind("object_id", tableObjectId) + .mapTo(String.class) + .list(); + } + + String getSingleColumnStatisticsColumnName(long tableObjectId, String statisticsName) + { + return handle.createQuery("" + + "SELECT c.name " + + "FROM sys.stats AS s " + + "JOIN sys.stats_columns AS sc ON s.object_id = sc.object_id AND s.stats_id = sc.stats_id " + + "JOIN sys.columns AS c ON sc.object_id = c.object_id AND c.column_id = sc.column_id " + + "WHERE s.object_id = :object_id " + + "AND s.name = :statistics_name") + .bind("object_id", tableObjectId) + .bind("statistics_name", statisticsName) + .mapTo(String.class) + .collect(toOptional()) // verify there is no more than 1 column name returned + .orElse(null); + } + } } diff 
--git a/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/SqlServerClientModule.java b/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/SqlServerClientModule.java index 7e2791ff77cd..dd1fffaf62e0 100644 --- a/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/SqlServerClientModule.java +++ b/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/SqlServerClientModule.java @@ -25,6 +25,7 @@ import io.trino.plugin.jdbc.DriverConnectionFactory; import io.trino.plugin.jdbc.ForBaseJdbc; import io.trino.plugin.jdbc.JdbcClient; +import io.trino.plugin.jdbc.JdbcStatisticsConfig; import io.trino.plugin.jdbc.MaxDomainCompactionThreshold; import io.trino.plugin.jdbc.credential.CredentialProvider; @@ -40,6 +41,7 @@ public class SqlServerClientModule public void configure(Binder binder) { configBinder(binder).bindConfig(SqlServerConfig.class); + configBinder(binder).bindConfig(JdbcStatisticsConfig.class); binder.bind(JdbcClient.class).annotatedWith(ForBaseJdbc.class).to(SqlServerClient.class).in(Scopes.SINGLETON); bindTablePropertiesProvider(binder, SqlServerTableProperties.class); newOptionalBinder(binder, Key.get(int.class, MaxDomainCompactionThreshold.class)).setBinding().toInstance(SQL_SERVER_MAX_LIST_EXPRESSIONS); diff --git a/plugin/trino-sqlserver/src/test/java/io/trino/plugin/sqlserver/TestSqlServerClient.java b/plugin/trino-sqlserver/src/test/java/io/trino/plugin/sqlserver/TestSqlServerClient.java index 1e6d993f6e70..ac9ecb52b445 100644 --- a/plugin/trino-sqlserver/src/test/java/io/trino/plugin/sqlserver/TestSqlServerClient.java +++ b/plugin/trino-sqlserver/src/test/java/io/trino/plugin/sqlserver/TestSqlServerClient.java @@ -19,6 +19,7 @@ import io.trino.plugin.jdbc.JdbcClient; import io.trino.plugin.jdbc.JdbcColumnHandle; import io.trino.plugin.jdbc.JdbcExpression; +import io.trino.plugin.jdbc.JdbcStatisticsConfig; import io.trino.plugin.jdbc.JdbcTypeHandle; import io.trino.plugin.jdbc.mapping.DefaultIdentifierMapping; 
import io.trino.spi.connector.AggregateFunction; @@ -59,6 +60,7 @@ public class TestSqlServerClient private static final JdbcClient JDBC_CLIENT = new SqlServerClient( new BaseJdbcConfig(), new SqlServerConfig(), + new JdbcStatisticsConfig(), session -> { throw new UnsupportedOperationException(); }, diff --git a/plugin/trino-sqlserver/src/test/java/io/trino/plugin/sqlserver/TestSqlServerTableStatistics.java b/plugin/trino-sqlserver/src/test/java/io/trino/plugin/sqlserver/TestSqlServerTableStatistics.java new file mode 100644 index 000000000000..a3155461ece1 --- /dev/null +++ b/plugin/trino-sqlserver/src/test/java/io/trino/plugin/sqlserver/TestSqlServerTableStatistics.java @@ -0,0 +1,415 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.plugin.sqlserver; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import io.trino.plugin.jdbc.BaseJdbcTableStatisticsTest; +import io.trino.testing.QueryRunner; +import io.trino.testing.sql.TestTable; +import org.jdbi.v3.core.Handle; +import org.jdbi.v3.core.Jdbi; +import org.testng.SkipException; +import org.testng.annotations.Test; + +import java.util.List; +import java.util.Map; + +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.Streams.stream; +import static io.trino.testing.sql.TestTable.fromColumns; +import static io.trino.tpch.TpchTable.ORDERS; +import static java.lang.String.format; + +public class TestSqlServerTableStatistics + extends BaseJdbcTableStatisticsTest +{ + private TestingSqlServer sqlServer; + + @Override + protected QueryRunner createQueryRunner() + throws Exception + { + sqlServer = closeAfterClass(new TestingSqlServer()); + return SqlServerQueryRunner.createSqlServerQueryRunner( + sqlServer, + Map.of(), + Map.of( + "case-insensitive-name-matching", "true", + "join-pushdown.enabled", "true"), + List.of(ORDERS)); + } + + @Override + @Test + public void testNotAnalyzed() + { + String tableName = "test_stats_not_analyzed"; + assertUpdate("DROP TABLE IF EXISTS " + tableName); + computeActual(format("CREATE TABLE %s AS SELECT * FROM tpch.tiny.orders", tableName)); + try { + assertQuery( + "SHOW STATS FOR " + tableName, + "VALUES " + + "('orderkey', null, null, null, null, null, null)," + + "('custkey', null, null, null, null, null, null)," + + "('orderstatus', null, null, null, null, null, null)," + + "('totalprice', null, null, null, null, null, null)," + + "('orderdate', null, null, null, null, null, null)," + + "('orderpriority', null, null, null, null, null, null)," + + "('clerk', null, null, null, null, null, null)," + + "('shippriority', null, null, null, null, null, null)," + + "('comment', null, null, 
null, null, null, null)," + + "(null, null, null, null, 15000, null, null)"); + } + finally { + assertUpdate("DROP TABLE " + tableName); + } + } + + @Override + @Test + public void testBasic() + { + String tableName = "test_stats_orders"; + assertUpdate("DROP TABLE IF EXISTS " + tableName); + computeActual(format("CREATE TABLE %s AS SELECT * FROM tpch.tiny.orders", tableName)); + try { + gatherStats(tableName); + assertQuery( + "SHOW STATS FOR " + tableName, + "VALUES " + + "('orderkey', null, 15000, 0, null, null, null)," + + "('custkey', null, 1000, 0, null, null, null)," + + "('orderstatus', 30000, 3, 0, null, null, null)," + + "('totalprice', null, 14996, 0, null, null, null)," + + "('orderdate', null, 2401, 0, null, null, null)," + + "('orderpriority', 252376, 5, 0, null, null, null)," + + "('clerk', 450000, 1000, 0, null, null, null)," + + "('shippriority', null, 1, 0, null, null, null)," + + "('comment', 1454727, 14994, 0, null, null, null)," + + "(null, null, null, null, 15000, null, null)"); + } + finally { + assertUpdate("DROP TABLE " + tableName); + } + } + + @Override + protected void checkEmptyTableStats(String tableName) + { + // TODO: Empty tables should have NDV as 0 and nulls fraction as 1 + assertQuery( + "SHOW STATS FOR " + tableName, + "VALUES " + + "('orderkey', null, null, null, null, null, null)," + + "('custkey', null, null, null, null, null, null)," + + "('orderpriority', null, null, null, null, null, null)," + + "('comment', null, null, null, null, null, null)," + + "(null, null, null, null, null, null, null)"); + } + + @Override + @Test + public void testAllNulls() + { + String tableName = "test_stats_table_all_nulls"; + assertUpdate("DROP TABLE IF EXISTS " + tableName); + computeActual(format("CREATE TABLE %s AS SELECT orderkey, custkey, orderpriority, comment FROM tpch.tiny.orders WHERE false", tableName)); + try { + computeActual(format("INSERT INTO %s (orderkey) VALUES NULL, NULL, NULL", tableName)); + gatherStats(tableName); + 
assertQuery( + "SHOW STATS FOR " + tableName, + "VALUES " + + "('orderkey', 0, 0, 1, null, null, null)," + + "('custkey', 0, 0, 1, null, null, null)," + + "('orderpriority', 0, 0, 1, null, null, null)," + + "('comment', 0, 0, 1, null, null, null)," + + "(null, null, null, null, 3, null, null)"); + } + finally { + assertUpdate("DROP TABLE " + tableName); + } + } + + @Override + @Test + public void testNullsFraction() + { + String tableName = "test_stats_table_with_nulls"; + assertUpdate("DROP TABLE IF EXISTS " + tableName); + assertUpdate("" + + "CREATE TABLE " + tableName + " AS " + + "SELECT " + + " orderkey, " + + " if(orderkey % 3 = 0, NULL, custkey) custkey, " + + " if(orderkey % 5 = 0, NULL, orderpriority) orderpriority " + + "FROM tpch.tiny.orders", + 15000); + try { + gatherStats(tableName); + assertQuery( + "SHOW STATS FOR " + tableName, + "VALUES " + + "('orderkey', null, 15000, 0, null, null, null)," + + "('custkey', null, 1000, 0.3333333333333333, null, null, null)," + + "('orderpriority', 201914, 5, 0.2, null, null, null)," + + "(null, null, null, null, 15000, null, null)"); + } + finally { + assertUpdate("DROP TABLE " + tableName); + } + } + + @Override + @Test + public void testAverageColumnLength() + { + String tableName = "test_stats_table_avg_col_len"; + assertUpdate("DROP TABLE IF EXISTS " + tableName); + computeActual("" + + "CREATE TABLE " + tableName + " AS SELECT " + + " orderkey, " + + " 'abc' v3_in_3, " + + " CAST('abc' AS varchar(42)) v3_in_42, " + + " if(orderkey = 1, '0123456789', NULL) single_10v_value, " + + " if(orderkey % 2 = 0, '0123456789', NULL) half_10v_value, " + + " if(orderkey % 2 = 0, CAST((1000000 - orderkey) * (1000000 - orderkey) AS varchar(20)), NULL) half_distinct_20v_value, " + // 12 chars each + " CAST(NULL AS varchar(10)) all_nulls " + + "FROM tpch.tiny.orders " + + "ORDER BY orderkey LIMIT 100"); + try { + gatherStats(tableName); + assertQuery( + "SHOW STATS FOR " + tableName, + "VALUES " + + "('orderkey', null, 100, 
0, null, null, null)," + + "('v3_in_3', 600, 1, 0, null, null, null)," + + "('v3_in_42', 600, 1, 0, null, null, null)," + + "('single_10v_value', 20, 1, 0.99, null, null, null)," + + "('half_10v_value', 1000, 1, 0.5, null, null, null)," + + "('half_distinct_20v_value', 1200, 50, 0.5, null, null, null)," + + "('all_nulls', 0, 0, 1, null, null, null)," + + "(null, null, null, null, 100, null, null)"); + } + finally { + assertUpdate("DROP TABLE " + tableName); + } + } + + @Override + @Test + public void testPartitionedTable() + { + throw new SkipException("Not implemented"); // TODO + } + + @Override + @Test + public void testView() + { + String tableName = "test_stats_view"; + sqlServer.execute("CREATE VIEW " + tableName + " AS SELECT orderkey, custkey, orderpriority, comment FROM orders"); + try { + assertQuery( + "SHOW STATS FOR " + tableName, + "VALUES " + + "('orderkey', null, null, null, null, null, null)," + + "('custkey', null, null, null, null, null, null)," + + "('orderpriority', null, null, null, null, null, null)," + + "('comment', null, null, null, null, null, null)," + + "(null, null, null, null, null, null, null)"); + // It's not possible to ANALYZE a VIEW in SQL Server + } + finally { + sqlServer.execute("DROP VIEW " + tableName); + } + } + + @Override + public void testMaterializedView() + { + throw new SkipException("see testIndexedView"); + } + + @Test + public void testIndexedView() // materialized view + { + String tableName = "test_stats_indexed_view"; + // indexed views require fixed values for several SET options + try (Handle handle = Jdbi.open(sqlServer::createConnection)) { + // indexed views require fixed values for several SET options + handle.execute("SET NUMERIC_ROUNDABORT OFF"); + handle.execute("SET ANSI_PADDING, ANSI_WARNINGS, CONCAT_NULL_YIELDS_NULL, ARITHABORT, QUOTED_IDENTIFIER, ANSI_NULLS ON"); + + handle.execute("" + + "CREATE VIEW " + tableName + " " + + "WITH SCHEMABINDING " + + "AS SELECT orderkey, custkey, orderpriority, 
comment FROM dbo.orders"); + try { + handle.execute("CREATE UNIQUE CLUSTERED INDEX idx1 ON " + tableName + " (orderkey, custkey, orderpriority, comment)"); + gatherStats(tableName); + assertQuery( + "SHOW STATS FOR " + tableName, + "VALUES " + + "('orderkey', null, 15000, 0, null, null, null)," + + "('custkey', null, 1000, 0, null, null, null)," + + "('orderpriority', 252376, 5, 0, null, null, null)," + + "('comment', 1454727, 14994, 0, null, null, null)," + + "(null, null, null, null, 15000, null, null)"); + } + finally { + handle.execute("DROP VIEW " + tableName); + } + } + } + + @Override + @Test(dataProvider = "testCaseColumnNamesDataProvider") + public void testCaseColumnNames(String tableName) + { + sqlServer.execute("" + + "SELECT " + + " orderkey CASE_UNQUOTED_UPPER, " + + " custkey case_unquoted_lower, " + + " orderstatus cASe_uNQuoTeD_miXED, " + + " totalprice \"CASE_QUOTED_UPPER\", " + + " orderdate \"case_quoted_lower\", " + + " orderpriority \"CasE_QuoTeD_miXED\" " + + "INTO " + tableName + " " + + "FROM orders"); + try { + gatherStats( + tableName, + ImmutableList.of( + "CASE_UNQUOTED_UPPER", + "case_unquoted_lower", + "cASe_uNQuoTeD_miXED", + "CASE_QUOTED_UPPER", + "case_quoted_lower", + "CasE_QuoTeD_miXED")); + assertQuery( + "SHOW STATS FOR " + tableName, + "VALUES " + + "('case_unquoted_upper', null, 15000, 0, null, null, null)," + + "('case_unquoted_lower', null, 1000, 0, null, null, null)," + + "('case_unquoted_mixed', 30000, 3, 0, null, null, null)," + + "('case_quoted_upper', null, 14996, 0, null, null, null)," + + "('case_quoted_lower', null, 2401, 0, null, null, null)," + + "('case_quoted_mixed', 252376, 5, 0, null, null, null)," + + "(null, null, null, null, 15000, null, null)"); + } + finally { + sqlServer.execute("DROP TABLE " + tableName); + } + } + + @Override + @Test + public void testNumericCornerCases() + { + try (TestTable table = fromColumns( + getQueryRunner()::execute, + "test_numeric_corner_cases_", + ImmutableMap.>builder() +// 
TODO infinity and NaNs are not supported by SQLServer +// .put("only_negative_infinity double", List.of("-infinity()", "-infinity()", "-infinity()", "-infinity()")) +// .put("only_positive_infinity double", List.of("infinity()", "infinity()", "infinity()", "infinity()")) +// .put("mixed_infinities double", List.of("-infinity()", "infinity()", "-infinity()", "infinity()")) +// .put("mixed_infinities_and_numbers double", List.of("-infinity()", "infinity()", "-5.0", "7.0")) +// .put("nans_only double", List.of("nan()", "nan()")) +// .put("nans_and_numbers double", List.of("nan()", "nan()", "-5.0", "7.0")) + .put("large_doubles double", List.of("CAST(-50371909150609548946090.0 AS DOUBLE)", "CAST(50371909150609548946090.0 AS DOUBLE)")) // 2^77 DIV 3 + .put("short_decimals_big_fraction decimal(16,15)", List.of("-1.234567890123456", "1.234567890123456")) + .put("short_decimals_big_integral decimal(16,1)", List.of("-123456789012345.6", "123456789012345.6")) + .put("long_decimals_big_fraction decimal(38,37)", List.of("-1.2345678901234567890123456789012345678", "1.2345678901234567890123456789012345678")) + .put("long_decimals_middle decimal(38,16)", List.of("-1234567890123456.7890123456789012345678", "1234567890123456.7890123456789012345678")) + .put("long_decimals_big_integral decimal(38,1)", List.of("-1234567890123456789012345678901234567.8", "1234567890123456789012345678901234567.8")) + .buildOrThrow(), + "null")) { + gatherStats(table.getName()); + assertQuery( + "SHOW STATS FOR " + table.getName(), + "VALUES " + +// TODO infinity and NaNs are not supported by SQLServer +// "('only_negative_infinity', null, 1, 0, null, null, null)," + +// "('only_positive_infinity', null, 1, 0, null, null, null)," + +// "('mixed_infinities', null, 2, 0, null, null, null)," + +// "('mixed_infinities_and_numbers', null, 4.0, 0.0, null, null, null)," + +// "('nans_only', null, 1.0, 0.5, null, null, null)," + +// "('nans_and_numbers', null, 3.0, 0.0, null, null, null)," + + 
"('large_doubles', null, 2.0, 0.0, null, null, null)," + + "('short_decimals_big_fraction', null, 2.0, 0.0, null, null, null)," + + "('short_decimals_big_integral', null, 2.0, 0.0, null, null, null)," + + "('long_decimals_big_fraction', null, 2.0, 0.0, null, null, null)," + + "('long_decimals_middle', null, 2.0, 0.0, null, null, null)," + + "('long_decimals_big_integral', null, 2.0, 0.0, null, null, null)," + + "(null, null, null, null, 2, null, null)"); + } + } + + @Test + public void testShowStatsAfterCreateIndex() + { + String tableName = "test_stats_create_index"; + assertUpdate("DROP TABLE IF EXISTS " + tableName); + computeActual(format("CREATE TABLE %s AS SELECT * FROM tpch.tiny.orders", tableName)); + + String expected = "VALUES " + + "('orderkey', null, 15000, 0, null, null, null)," + + "('custkey', null, 1000, 0, null, null, null)," + + "('orderstatus', 30000, 3, 0, null, null, null)," + + "('totalprice', null, 14996, 0, null, null, null)," + + "('orderdate', null, 2401, 0, null, null, null)," + + "('orderpriority', 252376, 5, 0, null, null, null)," + + "('clerk', 450000, 1000, 0, null, null, null)," + + "('shippriority', null, 1, 0, null, null, null)," + + "('comment', 1454727, 14994, 0, null, null, null)," + + "(null, null, null, null, 15000, null, null)"; + + try { + gatherStats(tableName); + assertQuery("SHOW STATS FOR " + tableName, expected); + + // CREATE INDEX statement updates sys.partitions table + sqlServer.execute(format("CREATE INDEX idx ON %s (orderkey)", tableName)); + sqlServer.execute(format("CREATE UNIQUE INDEX unique_index ON %s (orderkey)", tableName)); + sqlServer.execute(format("CREATE CLUSTERED INDEX clustered_index ON %s (orderkey)", tableName)); + sqlServer.execute(format("CREATE NONCLUSTERED INDEX non_clustered_index ON %s (orderkey)", tableName)); + + assertQuery("SHOW STATS FOR " + tableName, expected); + } + finally { + assertUpdate("DROP TABLE " + tableName); + } + } + + @Override + protected void gatherStats(String 
tableName) + { + List columnNames = stream(computeActual("SHOW COLUMNS FROM " + tableName)) + .map(row -> (String) row.getField(0)) + .collect(toImmutableList()); + gatherStats(tableName, columnNames); + } + + private void gatherStats(String tableName, List columnNames) + { + for (Object columnName : columnNames) { + sqlServer.execute(format("CREATE STATISTICS %1$s ON %2$s (%1$s)", columnName, tableName)); + } + sqlServer.execute("UPDATE STATISTICS " + tableName); + } +} From 786b693b7bce8aa0ab4c866ad89b9c0506220baf Mon Sep 17 00:00:00 2001 From: Mateusz Gajewski Date: Wed, 23 Mar 2022 15:07:38 +0100 Subject: [PATCH 3/3] Add automatic JOIN pushdown support to SQL Server connector --- .../plugin/sqlserver/SqlServerClient.java | 24 +++++++++ .../sqlserver/SqlServerClientModule.java | 8 +-- .../sqlserver/BaseSqlServerConnectorTest.java | 9 ++++ .../TestSqlServerAutomaticJoinPushdown.java | 54 +++++++++++++++++++ .../TestSqlServerTableStatistics.java | 4 +- 5 files changed, 93 insertions(+), 6 deletions(-) create mode 100644 plugin/trino-sqlserver/src/test/java/io/trino/plugin/sqlserver/TestSqlServerAutomaticJoinPushdown.java diff --git a/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/SqlServerClient.java b/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/SqlServerClient.java index 007752e084a1..60088a667596 100644 --- a/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/SqlServerClient.java +++ b/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/SqlServerClient.java @@ -43,6 +43,7 @@ import io.trino.plugin.jdbc.LongWriteFunction; import io.trino.plugin.jdbc.ObjectReadFunction; import io.trino.plugin.jdbc.ObjectWriteFunction; +import io.trino.plugin.jdbc.PreparedQuery; import io.trino.plugin.jdbc.QueryBuilder; import io.trino.plugin.jdbc.RemoteTableName; import io.trino.plugin.jdbc.SliceWriteFunction; @@ -59,6 +60,8 @@ import io.trino.spi.connector.ConnectorSession; import 
io.trino.spi.connector.ConnectorTableMetadata; import io.trino.spi.connector.JoinCondition; +import io.trino.spi.connector.JoinStatistics; +import io.trino.spi.connector.JoinType; import io.trino.spi.connector.SchemaTableName; import io.trino.spi.predicate.TupleDomain; import io.trino.spi.statistics.ColumnStatistics; @@ -112,6 +115,7 @@ import static io.airlift.slice.Slices.wrappedBuffer; import static io.trino.collect.cache.SafeCaches.buildNonEvictableCache; import static io.trino.plugin.jdbc.JdbcErrorCode.JDBC_ERROR; +import static io.trino.plugin.jdbc.JdbcJoinPushdownUtil.implementJoinCostAware; import static io.trino.plugin.jdbc.PredicatePushdownController.DISABLE_PUSHDOWN; import static io.trino.plugin.jdbc.StandardColumnMappings.bigintColumnMapping; import static io.trino.plugin.jdbc.StandardColumnMappings.bigintWriteFunction; @@ -610,6 +614,26 @@ private static Map getColumnNameToStatisticsName(JdbcTableHandle return columnNameToStatisticsName; } + @Override + public Optional implementJoin( + ConnectorSession session, + JoinType joinType, + PreparedQuery leftSource, + PreparedQuery rightSource, + List joinConditions, + Map rightAssignments, + Map leftAssignments, + JoinStatistics statistics) + { + return implementJoinCostAware( + session, + joinType, + leftSource, + rightSource, + statistics, + () -> super.implementJoin(session, joinType, leftSource, rightSource, joinConditions, rightAssignments, leftAssignments, statistics)); + } + private LongWriteFunction sqlServerTimeWriteFunction(int precision) { return new LongWriteFunction() diff --git a/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/SqlServerClientModule.java b/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/SqlServerClientModule.java index dd1fffaf62e0..5b3a93eb0881 100644 --- a/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/SqlServerClientModule.java +++ b/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/SqlServerClientModule.java @@ 
-15,16 +15,17 @@ import com.google.inject.Binder; import com.google.inject.Key; -import com.google.inject.Module; import com.google.inject.Provides; import com.google.inject.Scopes; import com.google.inject.Singleton; import com.microsoft.sqlserver.jdbc.SQLServerDriver; +import io.airlift.configuration.AbstractConfigurationAwareModule; import io.trino.plugin.jdbc.BaseJdbcConfig; import io.trino.plugin.jdbc.ConnectionFactory; import io.trino.plugin.jdbc.DriverConnectionFactory; import io.trino.plugin.jdbc.ForBaseJdbc; import io.trino.plugin.jdbc.JdbcClient; +import io.trino.plugin.jdbc.JdbcJoinPushdownSupportModule; import io.trino.plugin.jdbc.JdbcStatisticsConfig; import io.trino.plugin.jdbc.MaxDomainCompactionThreshold; import io.trino.plugin.jdbc.credential.CredentialProvider; @@ -35,16 +36,17 @@ import static io.trino.plugin.sqlserver.SqlServerClient.SQL_SERVER_MAX_LIST_EXPRESSIONS; public class SqlServerClientModule - implements Module + extends AbstractConfigurationAwareModule { @Override - public void configure(Binder binder) + protected void setup(Binder binder) { configBinder(binder).bindConfig(SqlServerConfig.class); configBinder(binder).bindConfig(JdbcStatisticsConfig.class); binder.bind(JdbcClient.class).annotatedWith(ForBaseJdbc.class).to(SqlServerClient.class).in(Scopes.SINGLETON); bindTablePropertiesProvider(binder, SqlServerTableProperties.class); newOptionalBinder(binder, Key.get(int.class, MaxDomainCompactionThreshold.class)).setBinding().toInstance(SQL_SERVER_MAX_LIST_EXPRESSIONS); + install(new JdbcJoinPushdownSupportModule()); } @Provides diff --git a/plugin/trino-sqlserver/src/test/java/io/trino/plugin/sqlserver/BaseSqlServerConnectorTest.java b/plugin/trino-sqlserver/src/test/java/io/trino/plugin/sqlserver/BaseSqlServerConnectorTest.java index 3383d64424b6..24d9a6a31af0 100644 --- a/plugin/trino-sqlserver/src/test/java/io/trino/plugin/sqlserver/BaseSqlServerConnectorTest.java +++ 
b/plugin/trino-sqlserver/src/test/java/io/trino/plugin/sqlserver/BaseSqlServerConnectorTest.java @@ -539,4 +539,13 @@ private String getLongInClause(int start, int length) .collect(joining(", ")); return "orderkey IN (" + longValues + ")"; } + + @Override + protected Session joinPushdownEnabled(Session session) + { + return Session.builder(super.joinPushdownEnabled(session)) + // strategy is AUTOMATIC by default and would not work for certain test cases (even if statistics are collected) + .setCatalogSessionProperty(session.getCatalog().orElseThrow(), "join_pushdown_strategy", "EAGER") + .build(); + } } diff --git a/plugin/trino-sqlserver/src/test/java/io/trino/plugin/sqlserver/TestSqlServerAutomaticJoinPushdown.java b/plugin/trino-sqlserver/src/test/java/io/trino/plugin/sqlserver/TestSqlServerAutomaticJoinPushdown.java new file mode 100644 index 000000000000..94c428668902 --- /dev/null +++ b/plugin/trino-sqlserver/src/test/java/io/trino/plugin/sqlserver/TestSqlServerAutomaticJoinPushdown.java @@ -0,0 +1,54 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.plugin.sqlserver; + +import io.trino.plugin.jdbc.BaseAutomaticJoinPushdownTest; +import io.trino.testing.QueryRunner; + +import java.util.List; +import java.util.Map; + +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.Streams.stream; +import static io.trino.plugin.sqlserver.SqlServerQueryRunner.createSqlServerQueryRunner; +import static java.lang.String.format; + +public class TestSqlServerAutomaticJoinPushdown + extends BaseAutomaticJoinPushdownTest +{ + private TestingSqlServer sqlServer; + + @Override + protected QueryRunner createQueryRunner() + throws Exception + { + sqlServer = closeAfterClass(new TestingSqlServer()); + return createSqlServerQueryRunner(sqlServer, Map.of(), Map.of(), List.of()); + } + + @Override + protected void gatherStats(String tableName) + { + List columnNames = stream(computeActual("SHOW COLUMNS FROM " + tableName)) + .map(row -> (String) row.getField(0)) + .map(columnName -> "\"" + columnName + "\"") + .collect(toImmutableList()); + + for (String columnName : columnNames) { + sqlServer.execute(format("CREATE STATISTICS %1$s ON %2$s (%1$s)", columnName, tableName)); + } + + sqlServer.execute("UPDATE STATISTICS " + tableName); + } +} diff --git a/plugin/trino-sqlserver/src/test/java/io/trino/plugin/sqlserver/TestSqlServerTableStatistics.java b/plugin/trino-sqlserver/src/test/java/io/trino/plugin/sqlserver/TestSqlServerTableStatistics.java index a3155461ece1..b2ce0f624108 100644 --- a/plugin/trino-sqlserver/src/test/java/io/trino/plugin/sqlserver/TestSqlServerTableStatistics.java +++ b/plugin/trino-sqlserver/src/test/java/io/trino/plugin/sqlserver/TestSqlServerTableStatistics.java @@ -45,9 +45,7 @@ protected QueryRunner createQueryRunner() return SqlServerQueryRunner.createSqlServerQueryRunner( sqlServer, Map.of(), - Map.of( - "case-insensitive-name-matching", "true", - "join-pushdown.enabled", "true"), + Map.of("case-insensitive-name-matching", 
"true"), List.of(ORDERS)); }