diff --git a/plugin/trino-sqlserver/pom.xml b/plugin/trino-sqlserver/pom.xml index 74a29e6a6f9e..fba1fc188beb 100644 --- a/plugin/trino-sqlserver/pom.xml +++ b/plugin/trino-sqlserver/pom.xml @@ -15,6 +15,14 @@ ${project.parent.basedir} + + + classes diff --git a/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/SqlServerClient.java b/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/SqlServerClient.java index f663dd8e8945..60088a667596 100644 --- a/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/SqlServerClient.java +++ b/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/SqlServerClient.java @@ -36,12 +36,14 @@ import io.trino.plugin.jdbc.JdbcOutputTableHandle; import io.trino.plugin.jdbc.JdbcSortItem; import io.trino.plugin.jdbc.JdbcSplit; +import io.trino.plugin.jdbc.JdbcStatisticsConfig; import io.trino.plugin.jdbc.JdbcTableHandle; import io.trino.plugin.jdbc.JdbcTypeHandle; import io.trino.plugin.jdbc.LongReadFunction; import io.trino.plugin.jdbc.LongWriteFunction; import io.trino.plugin.jdbc.ObjectReadFunction; import io.trino.plugin.jdbc.ObjectWriteFunction; +import io.trino.plugin.jdbc.PreparedQuery; import io.trino.plugin.jdbc.QueryBuilder; import io.trino.plugin.jdbc.RemoteTableName; import io.trino.plugin.jdbc.SliceWriteFunction; @@ -58,7 +60,13 @@ import io.trino.spi.connector.ConnectorSession; import io.trino.spi.connector.ConnectorTableMetadata; import io.trino.spi.connector.JoinCondition; +import io.trino.spi.connector.JoinStatistics; +import io.trino.spi.connector.JoinType; import io.trino.spi.connector.SchemaTableName; +import io.trino.spi.predicate.TupleDomain; +import io.trino.spi.statistics.ColumnStatistics; +import io.trino.spi.statistics.Estimate; +import io.trino.spi.statistics.TableStatistics; import io.trino.spi.type.CharType; import io.trino.spi.type.DecimalType; import io.trino.spi.type.Decimals; @@ -77,6 +85,7 @@ import javax.inject.Inject; +import java.sql.CallableStatement; import java.sql.Connection; import java.sql.PreparedStatement; import java.sql.ResultSet; @@ -89,6 +98,7 @@ import java.time.ZoneId; import java.time.ZonedDateTime; import java.time.format.DateTimeFormatter; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Optional; @@ -97,11 +107,15 @@ import java.util.stream.Stream; import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.base.Throwables.throwIfInstanceOf; +import static com.google.common.collect.MoreCollectors.toOptional; import static com.google.common.util.concurrent.MoreExecutors.directExecutor; import static com.microsoft.sqlserver.jdbc.SQLServerConnection.TRANSACTION_SNAPSHOT; import static io.airlift.slice.Slices.wrappedBuffer; import static io.trino.collect.cache.SafeCaches.buildNonEvictableCache; import static io.trino.plugin.jdbc.JdbcErrorCode.JDBC_ERROR; +import static io.trino.plugin.jdbc.JdbcJoinPushdownUtil.implementJoinCostAware; import static io.trino.plugin.jdbc.PredicatePushdownController.DISABLE_PUSHDOWN; import static io.trino.plugin.jdbc.StandardColumnMappings.bigintColumnMapping; import static io.trino.plugin.jdbc.StandardColumnMappings.bigintWriteFunction; @@ -191,18 +205,27 @@ public class SqlServerClient .maximumSize(1) .expireAfterWrite(ofMinutes(5))); + private final boolean statisticsEnabled; + private final ConnectorExpressionRewriter connectorExpressionRewriter; private final AggregateFunctionRewriter aggregateFunctionRewriter; private static final int MAX_SUPPORTED_TEMPORAL_PRECISION = 7; @Inject - public SqlServerClient(BaseJdbcConfig config, SqlServerConfig sqlServerConfig, ConnectionFactory connectionFactory, QueryBuilder queryBuilder, IdentifierMapping identifierMapping) + public SqlServerClient( + BaseJdbcConfig config, + SqlServerConfig sqlServerConfig, + JdbcStatisticsConfig statisticsConfig, + ConnectionFactory connectionFactory, + QueryBuilder queryBuilder, + IdentifierMapping identifierMapping) { super(config, "\"", connectionFactory, queryBuilder, identifierMapping); requireNonNull(sqlServerConfig, "sqlServerConfig is null"); snapshotIsolationDisabled = sqlServerConfig.isSnapshotIsolationDisabled(); + this.statisticsEnabled = requireNonNull(statisticsConfig, "statisticsConfig is null").isEnabled(); this.connectorExpressionRewriter = JdbcConnectorExpressionRewriterBuilder.newBuilder() .addStandardRules(this::quoted) @@ -452,6 +475,165 @@ public WriteMapping toWriteMapping(ConnectorSession session, Type type) throw new TrinoException(NOT_SUPPORTED, "Unsupported column type: " + type.getDisplayName()); } + @Override + public TableStatistics getTableStatistics(ConnectorSession session, JdbcTableHandle handle, TupleDomain tupleDomain) + { + if (!statisticsEnabled) { + return TableStatistics.empty(); + } + if (!handle.isNamedRelation()) { + return TableStatistics.empty(); + } + try { + return readTableStatistics(session, handle); + } + catch (SQLException | RuntimeException e) { + throwIfInstanceOf(e, TrinoException.class); + throw new TrinoException(JDBC_ERROR, "Failed fetching statistics for table: " + handle, e); + } + } + + private TableStatistics readTableStatistics(ConnectorSession session, JdbcTableHandle table) + throws SQLException + { + checkArgument(table.isNamedRelation(), "Relation is not a table: %s", table); + + try (Connection connection = connectionFactory.openConnection(session); + Handle handle = Jdbi.open(connection)) { + String catalog = table.getCatalogName(); + String schema = table.getSchemaName(); + String tableName = table.getTableName(); + + StatisticsDao statisticsDao = new StatisticsDao(handle); + Long tableObjectId = statisticsDao.getTableObjectId(catalog, schema, tableName); + if (tableObjectId == null) { + // Table not found + return TableStatistics.empty(); + } + + Long rowCount = statisticsDao.getRowCount(tableObjectId); + if (rowCount == null) { + // Table disappeared + return TableStatistics.empty(); + } + + if (rowCount == 0) { + return TableStatistics.empty(); + } + + TableStatistics.Builder tableStatistics = TableStatistics.builder(); + tableStatistics.setRowCount(Estimate.of(rowCount)); + + Map columnNameToStatisticsName = getColumnNameToStatisticsName(table, statisticsDao, tableObjectId); + + for (JdbcColumnHandle column : this.getColumns(session, table)) { + String statisticName = columnNameToStatisticsName.get(column.getColumnName()); + if (statisticName == null) { + // No statistic for column + continue; + } + + double averageColumnLength; + long notNullValues = 0; + long nullValues = 0; + long distinctValues = 0; + + try (CallableStatement showStatistics = handle.getConnection().prepareCall("DBCC SHOW_STATISTICS (?, ?)")) { + showStatistics.setString(1, format("%s.%s.%s", catalog, schema, tableName)); + showStatistics.setString(2, statisticName); + + boolean isResultSet = showStatistics.execute(); + checkState(isResultSet, "Expected SHOW_STATISTICS to return a result set"); + try (ResultSet resultSet = showStatistics.getResultSet()) { + checkState(resultSet.next(), "No rows in result set"); + + averageColumnLength = resultSet.getDouble("Average Key Length"); // NULL values are accounted for with length 0 + + checkState(!resultSet.next(), "More than one row in result set"); + } + + isResultSet = showStatistics.getMoreResults(); + checkState(isResultSet, "Expected SHOW_STATISTICS to return second result set"); + showStatistics.getResultSet().close(); + + isResultSet = showStatistics.getMoreResults(); + checkState(isResultSet, "Expected SHOW_STATISTICS to return third result set"); + try (ResultSet resultSet = showStatistics.getResultSet()) { + while (resultSet.next()) { + resultSet.getObject("RANGE_HI_KEY"); + if (resultSet.wasNull()) { + // Null fraction + checkState(resultSet.getLong("RANGE_ROWS") == 0, "Unexpected RANGE_ROWS for null fraction"); + checkState(resultSet.getLong("DISTINCT_RANGE_ROWS") == 0, "Unexpected DISTINCT_RANGE_ROWS for null fraction"); + checkState(nullValues == 0, "Multiple null fraction entries"); + nullValues += resultSet.getLong("EQ_ROWS"); + } + else { + // TODO discover min/max from resultSet.getXxx("RANGE_HI_KEY") + notNullValues += resultSet.getLong("RANGE_ROWS") // rows strictly within a bucket + + resultSet.getLong("EQ_ROWS"); // rows equal to RANGE_HI_KEY + distinctValues += resultSet.getLong("DISTINCT_RANGE_ROWS") // NDV strictly within a bucket + + (resultSet.getLong("EQ_ROWS") > 0 ? 1 : 0); + } + } + } + } + + ColumnStatistics statistics = ColumnStatistics.builder() + .setNullsFraction(Estimate.of( + (notNullValues + nullValues == 0) + ? 1 + : (1.0 * nullValues / (notNullValues + nullValues)))) + .setDistinctValuesCount(Estimate.of(distinctValues)) + .setDataSize(Estimate.of(rowCount * averageColumnLength)) + .build(); + + tableStatistics.setColumnStatistics(column, statistics); + } + + return tableStatistics.build(); + } + } + + private static Map getColumnNameToStatisticsName(JdbcTableHandle table, StatisticsDao statisticsDao, Long tableObjectId) + { + List singleColumnStatistics = statisticsDao.getSingleColumnStatistics(tableObjectId); + + Map columnNameToStatisticsName = new HashMap<>(); + for (String statisticName : singleColumnStatistics) { + String columnName = statisticsDao.getSingleColumnStatisticsColumnName(tableObjectId, statisticName); + if (columnName == null) { + // Table or statistics disappeared + continue; + } + + if (columnNameToStatisticsName.putIfAbsent(columnName, statisticName) != null) { + log.debug("Multiple statistics for %s in %s: %s and %s", columnName, table, columnNameToStatisticsName.get(columnName), statisticName); + } + } + return columnNameToStatisticsName; + } + + @Override + public Optional implementJoin( + ConnectorSession session, + JoinType joinType, + PreparedQuery leftSource, + PreparedQuery rightSource, + List joinConditions, + Map rightAssignments, + Map leftAssignments, + JoinStatistics statistics) + { + return implementJoinCostAware( + session, + joinType, + leftSource, + rightSource, + statistics, + () -> super.implementJoin(session, joinType, leftSource, rightSource, joinConditions, rightAssignments, leftAssignments, statistics)); + } + private LongWriteFunction sqlServerTimeWriteFunction(int precision) { return new LongWriteFunction() @@ -833,4 +1015,65 @@ private enum SnapshotIsolationEnabledCacheKey // database, so from our perspective, this is a global property. INSTANCE } + + private static class StatisticsDao + { + private final Handle handle; + + public StatisticsDao(Handle handle) + { + this.handle = requireNonNull(handle, "handle is null"); + } + + Long getTableObjectId(String catalog, String schema, String tableName) + { + return handle.createQuery("SELECT object_id(:table)") + .bind("table", format("%s.%s.%s", catalog, schema, tableName)) + .mapTo(Long.class) + .findOnly(); + } + + Long getRowCount(long tableObjectId) + { + return handle.createQuery("" + + "SELECT sum(rows) row_count " + + "FROM sys.partitions " + + "WHERE object_id = :object_id " + + "AND index_id IN (0, 1)") // 0 = heap, 1 = clustered index, 2 or greater = non-clustered index + .bind("object_id", tableObjectId) + .mapTo(Long.class) + .findOnly(); + } + + List getSingleColumnStatistics(long tableObjectId) + { + return handle.createQuery("" + + "SELECT s.name " + + "FROM sys.stats AS s " + + "JOIN sys.stats_columns AS sc ON s.object_id = sc.object_id AND s.stats_id = sc.stats_id " + + "WHERE s.object_id = :object_id " + + "GROUP BY s.name " + + "HAVING count(*) = 1 " + + "ORDER BY s.name") + .bind("object_id", tableObjectId) + .mapTo(String.class) + .list(); + } + + String getSingleColumnStatisticsColumnName(long tableObjectId, String statisticsName) + { + return handle.createQuery("" + + "SELECT c.name " + + "FROM sys.stats AS s " + + "JOIN sys.stats_columns AS sc ON s.object_id = sc.object_id AND s.stats_id = sc.stats_id " + + "JOIN sys.columns AS c ON sc.object_id = c.object_id AND c.column_id = sc.column_id " + + "WHERE s.object_id = :object_id " + + "AND s.name = :statistics_name") + .bind("object_id", tableObjectId) + .bind("statistics_name", statisticsName) + .mapTo(String.class) + .collect(toOptional()) // verify there is no more than 1 column name returned + .orElse(null); + } + } } diff --git a/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/SqlServerClientModule.java b/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/SqlServerClientModule.java index 7e2791ff77cd..5b3a93eb0881 100644 --- a/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/SqlServerClientModule.java +++ b/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/SqlServerClientModule.java @@ -15,16 +15,18 @@ import com.google.inject.Binder; import com.google.inject.Key; -import com.google.inject.Module; import com.google.inject.Provides; import com.google.inject.Scopes; import com.google.inject.Singleton; import com.microsoft.sqlserver.jdbc.SQLServerDriver; +import io.airlift.configuration.AbstractConfigurationAwareModule; import io.trino.plugin.jdbc.BaseJdbcConfig; import io.trino.plugin.jdbc.ConnectionFactory; import io.trino.plugin.jdbc.DriverConnectionFactory; import io.trino.plugin.jdbc.ForBaseJdbc; import io.trino.plugin.jdbc.JdbcClient; +import io.trino.plugin.jdbc.JdbcJoinPushdownSupportModule; +import io.trino.plugin.jdbc.JdbcStatisticsConfig; import io.trino.plugin.jdbc.MaxDomainCompactionThreshold; import io.trino.plugin.jdbc.credential.CredentialProvider; @@ -34,15 +36,17 @@ import static io.trino.plugin.sqlserver.SqlServerClient.SQL_SERVER_MAX_LIST_EXPRESSIONS; public class SqlServerClientModule - implements Module + extends AbstractConfigurationAwareModule { @Override - public void configure(Binder binder) + protected void setup(Binder binder) { configBinder(binder).bindConfig(SqlServerConfig.class); + configBinder(binder).bindConfig(JdbcStatisticsConfig.class); binder.bind(JdbcClient.class).annotatedWith(ForBaseJdbc.class).to(SqlServerClient.class).in(Scopes.SINGLETON); bindTablePropertiesProvider(binder, SqlServerTableProperties.class); newOptionalBinder(binder, Key.get(int.class, MaxDomainCompactionThreshold.class)).setBinding().toInstance(SQL_SERVER_MAX_LIST_EXPRESSIONS); + install(new JdbcJoinPushdownSupportModule()); } @Provides diff --git a/plugin/trino-sqlserver/src/test/java/io/trino/plugin/sqlserver/BaseSqlServerConnectorTest.java b/plugin/trino-sqlserver/src/test/java/io/trino/plugin/sqlserver/BaseSqlServerConnectorTest.java index 3383d64424b6..24d9a6a31af0 100644 --- a/plugin/trino-sqlserver/src/test/java/io/trino/plugin/sqlserver/BaseSqlServerConnectorTest.java +++ b/plugin/trino-sqlserver/src/test/java/io/trino/plugin/sqlserver/BaseSqlServerConnectorTest.java @@ -539,4 +539,13 @@ private String getLongInClause(int start, int length) .collect(joining(", ")); return "orderkey IN (" + longValues + ")"; } + + @Override + protected Session joinPushdownEnabled(Session session) + { + return Session.builder(super.joinPushdownEnabled(session)) + // strategy is AUTOMATIC by default and would not work for certain test cases (even if statistics are collected) + .setCatalogSessionProperty(session.getCatalog().orElseThrow(), "join_pushdown_strategy", "EAGER") + .build(); + } } diff --git a/plugin/trino-sqlserver/src/test/java/io/trino/plugin/sqlserver/TestSqlServerAutomaticJoinPushdown.java b/plugin/trino-sqlserver/src/test/java/io/trino/plugin/sqlserver/TestSqlServerAutomaticJoinPushdown.java new file mode 100644 index 000000000000..94c428668902 --- /dev/null +++ b/plugin/trino-sqlserver/src/test/java/io/trino/plugin/sqlserver/TestSqlServerAutomaticJoinPushdown.java @@ -0,0 +1,54 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.sqlserver; + +import io.trino.plugin.jdbc.BaseAutomaticJoinPushdownTest; +import io.trino.testing.QueryRunner; + +import java.util.List; +import java.util.Map; + +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.Streams.stream; +import static io.trino.plugin.sqlserver.SqlServerQueryRunner.createSqlServerQueryRunner; +import static java.lang.String.format; + +public class TestSqlServerAutomaticJoinPushdown + extends BaseAutomaticJoinPushdownTest +{ + private TestingSqlServer sqlServer; + + @Override + protected QueryRunner createQueryRunner() + throws Exception + { + sqlServer = closeAfterClass(new TestingSqlServer()); + return createSqlServerQueryRunner(sqlServer, Map.of(), Map.of(), List.of()); + } + + @Override + protected void gatherStats(String tableName) + { + List columnNames = stream(computeActual("SHOW COLUMNS FROM " + tableName)) + .map(row -> (String) row.getField(0)) + .map(columnName -> "\"" + columnName + "\"") + .collect(toImmutableList()); + + for (String columnName : columnNames) { + sqlServer.execute(format("CREATE STATISTICS %1$s ON %2$s (%1$s)", columnName, tableName)); + } + + sqlServer.execute("UPDATE STATISTICS " + tableName); + } +} diff --git a/plugin/trino-sqlserver/src/test/java/io/trino/plugin/sqlserver/TestSqlServerClient.java b/plugin/trino-sqlserver/src/test/java/io/trino/plugin/sqlserver/TestSqlServerClient.java index 1e6d993f6e70..ac9ecb52b445 100644 --- a/plugin/trino-sqlserver/src/test/java/io/trino/plugin/sqlserver/TestSqlServerClient.java +++ b/plugin/trino-sqlserver/src/test/java/io/trino/plugin/sqlserver/TestSqlServerClient.java @@ -19,6 +19,7 @@ import io.trino.plugin.jdbc.JdbcClient; import io.trino.plugin.jdbc.JdbcColumnHandle; import io.trino.plugin.jdbc.JdbcExpression; +import io.trino.plugin.jdbc.JdbcStatisticsConfig; import io.trino.plugin.jdbc.JdbcTypeHandle; import io.trino.plugin.jdbc.mapping.DefaultIdentifierMapping; import io.trino.spi.connector.AggregateFunction; @@ -59,6 +60,7 @@ public class TestSqlServerClient private static final JdbcClient JDBC_CLIENT = new SqlServerClient( new BaseJdbcConfig(), new SqlServerConfig(), + new JdbcStatisticsConfig(), session -> { throw new UnsupportedOperationException(); }, diff --git a/plugin/trino-sqlserver/src/test/java/io/trino/plugin/sqlserver/TestSqlServerTableStatistics.java b/plugin/trino-sqlserver/src/test/java/io/trino/plugin/sqlserver/TestSqlServerTableStatistics.java new file mode 100644 index 000000000000..b2ce0f624108 --- /dev/null +++ b/plugin/trino-sqlserver/src/test/java/io/trino/plugin/sqlserver/TestSqlServerTableStatistics.java @@ -0,0 +1,413 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.sqlserver; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import io.trino.plugin.jdbc.BaseJdbcTableStatisticsTest; +import io.trino.testing.QueryRunner; +import io.trino.testing.sql.TestTable; +import org.jdbi.v3.core.Handle; +import org.jdbi.v3.core.Jdbi; +import org.testng.SkipException; +import org.testng.annotations.Test; + +import java.util.List; +import java.util.Map; + +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.Streams.stream; +import static io.trino.testing.sql.TestTable.fromColumns; +import static io.trino.tpch.TpchTable.ORDERS; +import static java.lang.String.format; + +public class TestSqlServerTableStatistics + extends BaseJdbcTableStatisticsTest +{ + private TestingSqlServer sqlServer; + + @Override + protected QueryRunner createQueryRunner() + throws Exception + { + sqlServer = closeAfterClass(new TestingSqlServer()); + return SqlServerQueryRunner.createSqlServerQueryRunner( + sqlServer, + Map.of(), + Map.of("case-insensitive-name-matching", "true"), + List.of(ORDERS)); + } + + @Override + @Test + public void testNotAnalyzed() + { + String tableName = "test_stats_not_analyzed"; + assertUpdate("DROP TABLE IF EXISTS " + tableName); + computeActual(format("CREATE TABLE %s AS SELECT * FROM tpch.tiny.orders", tableName)); + try { + assertQuery( + "SHOW STATS FOR " + tableName, + "VALUES " + + "('orderkey', null, null, null, null, null, null)," + + "('custkey', null, null, null, null, null, null)," + + "('orderstatus', null, null, null, null, null, null)," + + "('totalprice', null, null, null, null, null, null)," + + "('orderdate', null, null, null, null, null, null)," + + "('orderpriority', null, null, null, null, null, null)," + + "('clerk', null, null, null, null, null, null)," + + "('shippriority', null, null, null, null, null, null)," + + "('comment', null, null, null, null, null, null)," + + "(null, null, null, null, 15000, null, null)"); + } + finally { + assertUpdate("DROP TABLE " + tableName); + } + } + + @Override + @Test + public void testBasic() + { + String tableName = "test_stats_orders"; + assertUpdate("DROP TABLE IF EXISTS " + tableName); + computeActual(format("CREATE TABLE %s AS SELECT * FROM tpch.tiny.orders", tableName)); + try { + gatherStats(tableName); + assertQuery( + "SHOW STATS FOR " + tableName, + "VALUES " + + "('orderkey', null, 15000, 0, null, null, null)," + + "('custkey', null, 1000, 0, null, null, null)," + + "('orderstatus', 30000, 3, 0, null, null, null)," + + "('totalprice', null, 14996, 0, null, null, null)," + + "('orderdate', null, 2401, 0, null, null, null)," + + "('orderpriority', 252376, 5, 0, null, null, null)," + + "('clerk', 450000, 1000, 0, null, null, null)," + + "('shippriority', null, 1, 0, null, null, null)," + + "('comment', 1454727, 14994, 0, null, null, null)," + + "(null, null, null, null, 15000, null, null)"); + } + finally { + assertUpdate("DROP TABLE " + tableName); + } + } + + @Override + protected void checkEmptyTableStats(String tableName) + { + // TODO: Empty tables should have NDV as 0 and nulls fraction as 1 + assertQuery( + "SHOW STATS FOR " + tableName, + "VALUES " + + "('orderkey', null, null, null, null, null, null)," + + "('custkey', null, null, null, null, null, null)," + + "('orderpriority', null, null, null, null, null, null)," + + "('comment', null, null, null, null, null, null)," + + "(null, null, null, null, null, null, null)"); + } + + @Override + @Test + public void testAllNulls() + { + String tableName = "test_stats_table_all_nulls"; + assertUpdate("DROP TABLE IF EXISTS " + tableName); + computeActual(format("CREATE TABLE %s AS SELECT orderkey, custkey, orderpriority, comment FROM tpch.tiny.orders WHERE false", tableName)); + try { + computeActual(format("INSERT INTO %s (orderkey) VALUES NULL, NULL, NULL", tableName)); + gatherStats(tableName); + assertQuery( + "SHOW STATS FOR " + tableName, + "VALUES " + + "('orderkey', 0, 0, 1, null, null, null)," + + "('custkey', 0, 0, 1, null, null, null)," + + "('orderpriority', 0, 0, 1, null, null, null)," + + "('comment', 0, 0, 1, null, null, null)," + + "(null, null, null, null, 3, null, null)"); + } + finally { + assertUpdate("DROP TABLE " + tableName); + } + } + + @Override + @Test + public void testNullsFraction() + { + String tableName = "test_stats_table_with_nulls"; + assertUpdate("DROP TABLE IF EXISTS " + tableName); + assertUpdate("" + + "CREATE TABLE " + tableName + " AS " + + "SELECT " + + " orderkey, " + + " if(orderkey % 3 = 0, NULL, custkey) custkey, " + + " if(orderkey % 5 = 0, NULL, orderpriority) orderpriority " + + "FROM tpch.tiny.orders", + 15000); + try { + gatherStats(tableName); + assertQuery( + "SHOW STATS FOR " + tableName, + "VALUES " + + "('orderkey', null, 15000, 0, null, null, null)," + + "('custkey', null, 1000, 0.3333333333333333, null, null, null)," + + "('orderpriority', 201914, 5, 0.2, null, null, null)," + + "(null, null, null, null, 15000, null, null)"); + } + finally { + assertUpdate("DROP TABLE " + tableName); + } + } + + @Override + @Test + public void testAverageColumnLength() + { + String tableName = "test_stats_table_avg_col_len"; + assertUpdate("DROP TABLE IF EXISTS " + tableName); + computeActual("" + + "CREATE TABLE " + tableName + " AS SELECT " + + " orderkey, " + + " 'abc' v3_in_3, " + + " CAST('abc' AS varchar(42)) v3_in_42, " + + " if(orderkey = 1, '0123456789', NULL) single_10v_value, " + + " if(orderkey % 2 = 0, '0123456789', NULL) half_10v_value, " + + " if(orderkey % 2 = 0, CAST((1000000 - orderkey) * (1000000 - orderkey) AS varchar(20)), NULL) half_distinct_20v_value, " + // 12 chars each + " CAST(NULL AS varchar(10)) all_nulls " + + "FROM tpch.tiny.orders " + + "ORDER BY orderkey LIMIT 100"); + try { + gatherStats(tableName); + assertQuery( + "SHOW STATS FOR " + tableName, + "VALUES " + + "('orderkey', null, 100, 0, null, null, null)," + + "('v3_in_3', 600, 1, 0, null, null, null)," + + "('v3_in_42', 600, 1, 0, null, null, null)," + + "('single_10v_value', 20, 1, 0.99, null, null, null)," + + "('half_10v_value', 1000, 1, 0.5, null, null, null)," + + "('half_distinct_20v_value', 1200, 50, 0.5, null, null, null)," + + "('all_nulls', 0, 0, 1, null, null, null)," + + "(null, null, null, null, 100, null, null)"); + } + finally { + assertUpdate("DROP TABLE " + tableName); + } + } + + @Override + @Test + public void testPartitionedTable() + { + throw new SkipException("Not implemented"); // TODO + } + + @Override + @Test + public void testView() + { + String tableName = "test_stats_view"; + sqlServer.execute("CREATE VIEW " + tableName + " AS SELECT orderkey, custkey, orderpriority, comment FROM orders"); + try { + assertQuery( + "SHOW STATS FOR " + tableName, + "VALUES " + + "('orderkey', null, null, null, null, null, null)," + + "('custkey', null, null, null, null, null, null)," + + "('orderpriority', null, null, null, null, null, null)," + + "('comment', null, null, null, null, null, null)," + + "(null, null, null, null, null, null, null)"); + // It's not possible to ANALYZE a VIEW in SQL Server + } + finally { + sqlServer.execute("DROP VIEW " + tableName); + } + } + + @Override + public void testMaterializedView() + { + throw new SkipException("see testIndexedView"); + } + + @Test + public void testIndexedView() // materialized view + { + String tableName = "test_stats_indexed_view"; + // indexed views require fixed values for several SET options + try (Handle handle = Jdbi.open(sqlServer::createConnection)) { + // indexed views require fixed values for several SET options + handle.execute("SET NUMERIC_ROUNDABORT OFF"); + handle.execute("SET ANSI_PADDING, ANSI_WARNINGS, CONCAT_NULL_YIELDS_NULL, ARITHABORT, QUOTED_IDENTIFIER, ANSI_NULLS ON"); + + handle.execute("" + + "CREATE VIEW " + tableName + " " + + "WITH SCHEMABINDING " + + "AS SELECT orderkey, custkey, orderpriority, comment FROM dbo.orders"); + try { + handle.execute("CREATE UNIQUE CLUSTERED INDEX idx1 ON " + tableName + " (orderkey, custkey, orderpriority, comment)"); + gatherStats(tableName); + assertQuery( + "SHOW STATS FOR " + tableName, + "VALUES " + + "('orderkey', null, 15000, 0, null, null, null)," + + "('custkey', null, 1000, 0, null, null, null)," + + "('orderpriority', 252376, 5, 0, null, null, null)," + + "('comment', 1454727, 14994, 0, null, null, null)," + + "(null, null, null, null, 15000, null, null)"); + } + finally { + handle.execute("DROP VIEW " + tableName); + } + } + } + + @Override + @Test(dataProvider = "testCaseColumnNamesDataProvider") + public void testCaseColumnNames(String tableName) + { + sqlServer.execute("" + + "SELECT " + + " orderkey CASE_UNQUOTED_UPPER, " + + " custkey case_unquoted_lower, " + + " orderstatus cASe_uNQuoTeD_miXED, " + + " totalprice \"CASE_QUOTED_UPPER\", " + + " orderdate \"case_quoted_lower\", " + + " orderpriority \"CasE_QuoTeD_miXED\" " + + "INTO " + tableName + " " + + "FROM orders"); + try { + gatherStats( + tableName, + ImmutableList.of( + "CASE_UNQUOTED_UPPER", + "case_unquoted_lower", + "cASe_uNQuoTeD_miXED", + "CASE_QUOTED_UPPER", + "case_quoted_lower", + "CasE_QuoTeD_miXED")); + assertQuery( + "SHOW STATS FOR " + tableName, + "VALUES " + + "('case_unquoted_upper', null, 15000, 0, null, null, null)," + + "('case_unquoted_lower', null, 1000, 0, null, null, null)," + + "('case_unquoted_mixed', 30000, 3, 0, null, null, null)," + + "('case_quoted_upper', null, 14996, 0, null, null, null)," + + "('case_quoted_lower', null, 2401, 0, null, null, null)," + + "('case_quoted_mixed', 252376, 5, 0, null, null, null)," + + "(null, null, null, null, 15000, null, null)"); + } + finally { + sqlServer.execute("DROP TABLE " + tableName); + } + } + + @Override + @Test + public void testNumericCornerCases() + { + try (TestTable table = fromColumns( + getQueryRunner()::execute, + "test_numeric_corner_cases_", + ImmutableMap.>builder() +// TODO infinity and NaNs are not supported by SQLServer +// .put("only_negative_infinity double", List.of("-infinity()", "-infinity()", "-infinity()", "-infinity()")) +// .put("only_positive_infinity double", List.of("infinity()", "infinity()", "infinity()", "infinity()")) +// .put("mixed_infinities double", List.of("-infinity()", "infinity()", "-infinity()", "infinity()")) +// .put("mixed_infinities_and_numbers double", List.of("-infinity()", "infinity()", "-5.0", "7.0")) +// .put("nans_only double", List.of("nan()", "nan()")) +// .put("nans_and_numbers double", List.of("nan()", "nan()", "-5.0", "7.0")) + .put("large_doubles double", List.of("CAST(-50371909150609548946090.0 AS DOUBLE)", "CAST(50371909150609548946090.0 AS DOUBLE)")) // 2^77 DIV 3 + .put("short_decimals_big_fraction decimal(16,15)", List.of("-1.234567890123456", "1.234567890123456")) + .put("short_decimals_big_integral decimal(16,1)", List.of("-123456789012345.6", "123456789012345.6")) + .put("long_decimals_big_fraction decimal(38,37)", List.of("-1.2345678901234567890123456789012345678", "1.2345678901234567890123456789012345678")) + .put("long_decimals_middle decimal(38,16)", List.of("-1234567890123456.7890123456789012345678", "1234567890123456.7890123456789012345678")) + .put("long_decimals_big_integral decimal(38,1)", List.of("-1234567890123456789012345678901234567.8", "1234567890123456789012345678901234567.8")) + .buildOrThrow(), + "null")) { + gatherStats(table.getName()); + assertQuery( + "SHOW STATS FOR " + table.getName(), + "VALUES " + +// TODO infinity and NaNs are not supported by SQLServer +// "('only_negative_infinity', null, 1, 0, null, null, null)," + +// "('only_positive_infinity', null, 1, 0, null, null, null)," + +// "('mixed_infinities', null, 2, 0, null, null, null)," + +// "('mixed_infinities_and_numbers', null, 4.0, 0.0, null, null, null)," + +// "('nans_only', null, 1.0, 0.5, null, null, null)," + +// "('nans_and_numbers', null, 3.0, 0.0, null, null, null)," + + "('large_doubles', null, 2.0, 0.0, null, null, null)," + + "('short_decimals_big_fraction', null, 2.0, 0.0, null, null, null)," + + "('short_decimals_big_integral', null, 2.0, 0.0, null, null, null)," + + "('long_decimals_big_fraction', null, 2.0, 0.0, null, null, null)," + + "('long_decimals_middle', null, 2.0, 0.0, null, null, null)," + + "('long_decimals_big_integral', null, 2.0, 0.0, null, null, null)," + + "(null, null, null, null, 2, null, null)"); + } + } + + @Test + public void testShowStatsAfterCreateIndex() + { + String tableName = "test_stats_create_index"; + assertUpdate("DROP TABLE IF EXISTS " + tableName); + computeActual(format("CREATE TABLE %s AS SELECT * FROM tpch.tiny.orders", tableName)); + + String expected = "VALUES " + + "('orderkey', null, 15000, 0, null, null, null)," + + "('custkey', null, 1000, 0, null, null, null)," + + "('orderstatus', 30000, 3, 0, null, null, null)," + + "('totalprice', null, 14996, 0, null, null, null)," + + "('orderdate', null, 2401, 0, null, null, null)," + + "('orderpriority', 252376, 5, 0, null, null, null)," + + "('clerk', 450000, 1000, 0, null, null, null)," + + "('shippriority', null, 1, 0, null, null, null)," + + "('comment', 1454727, 14994, 0, null, null, null)," + + "(null, null, null, null, 15000, null, null)"; + + try { + gatherStats(tableName); + assertQuery("SHOW STATS FOR " + tableName, expected); + + // CREATE INDEX statement updates sys.partitions table + sqlServer.execute(format("CREATE INDEX idx ON %s (orderkey)", tableName)); + sqlServer.execute(format("CREATE UNIQUE INDEX unique_index ON %s (orderkey)", tableName)); + sqlServer.execute(format("CREATE CLUSTERED INDEX clustered_index ON %s (orderkey)", tableName)); + sqlServer.execute(format("CREATE NONCLUSTERED INDEX non_clustered_index ON %s (orderkey)", tableName)); + + assertQuery("SHOW STATS FOR " + tableName, expected); + } + finally { + assertUpdate("DROP TABLE " + tableName); + } + } + + @Override + protected void gatherStats(String tableName) + { + List columnNames = stream(computeActual("SHOW COLUMNS FROM " + tableName)) + .map(row -> (String) row.getField(0)) + .collect(toImmutableList()); + gatherStats(tableName, columnNames); + } + + private void gatherStats(String tableName, List columnNames) + { + for (Object columnName : columnNames) { + sqlServer.execute(format("CREATE STATISTICS %1$s ON %2$s (%1$s)", columnName, tableName)); + } + sqlServer.execute("UPDATE STATISTICS " + tableName); + } +}