diff --git a/plugin/trino-sqlserver/pom.xml b/plugin/trino-sqlserver/pom.xml
index 74a29e6a6f9e..fba1fc188beb 100644
--- a/plugin/trino-sqlserver/pom.xml
+++ b/plugin/trino-sqlserver/pom.xml
@@ -15,6 +15,14 @@
${project.parent.basedir}
+
+
+ classes
diff --git a/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/SqlServerClient.java b/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/SqlServerClient.java
index f663dd8e8945..60088a667596 100644
--- a/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/SqlServerClient.java
+++ b/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/SqlServerClient.java
@@ -36,12 +36,14 @@
import io.trino.plugin.jdbc.JdbcOutputTableHandle;
import io.trino.plugin.jdbc.JdbcSortItem;
import io.trino.plugin.jdbc.JdbcSplit;
+import io.trino.plugin.jdbc.JdbcStatisticsConfig;
import io.trino.plugin.jdbc.JdbcTableHandle;
import io.trino.plugin.jdbc.JdbcTypeHandle;
import io.trino.plugin.jdbc.LongReadFunction;
import io.trino.plugin.jdbc.LongWriteFunction;
import io.trino.plugin.jdbc.ObjectReadFunction;
import io.trino.plugin.jdbc.ObjectWriteFunction;
+import io.trino.plugin.jdbc.PreparedQuery;
import io.trino.plugin.jdbc.QueryBuilder;
import io.trino.plugin.jdbc.RemoteTableName;
import io.trino.plugin.jdbc.SliceWriteFunction;
@@ -58,7 +60,13 @@
import io.trino.spi.connector.ConnectorSession;
import io.trino.spi.connector.ConnectorTableMetadata;
import io.trino.spi.connector.JoinCondition;
+import io.trino.spi.connector.JoinStatistics;
+import io.trino.spi.connector.JoinType;
import io.trino.spi.connector.SchemaTableName;
+import io.trino.spi.predicate.TupleDomain;
+import io.trino.spi.statistics.ColumnStatistics;
+import io.trino.spi.statistics.Estimate;
+import io.trino.spi.statistics.TableStatistics;
import io.trino.spi.type.CharType;
import io.trino.spi.type.DecimalType;
import io.trino.spi.type.Decimals;
@@ -77,6 +85,7 @@
import javax.inject.Inject;
+import java.sql.CallableStatement;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
@@ -89,6 +98,7 @@
import java.time.ZoneId;
import java.time.ZonedDateTime;
import java.time.format.DateTimeFormatter;
+import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
@@ -97,11 +107,15 @@
import java.util.stream.Stream;
import static com.google.common.base.Preconditions.checkArgument;
+import static com.google.common.base.Preconditions.checkState;
+import static com.google.common.base.Throwables.throwIfInstanceOf;
+import static com.google.common.collect.MoreCollectors.toOptional;
import static com.google.common.util.concurrent.MoreExecutors.directExecutor;
import static com.microsoft.sqlserver.jdbc.SQLServerConnection.TRANSACTION_SNAPSHOT;
import static io.airlift.slice.Slices.wrappedBuffer;
import static io.trino.collect.cache.SafeCaches.buildNonEvictableCache;
import static io.trino.plugin.jdbc.JdbcErrorCode.JDBC_ERROR;
+import static io.trino.plugin.jdbc.JdbcJoinPushdownUtil.implementJoinCostAware;
import static io.trino.plugin.jdbc.PredicatePushdownController.DISABLE_PUSHDOWN;
import static io.trino.plugin.jdbc.StandardColumnMappings.bigintColumnMapping;
import static io.trino.plugin.jdbc.StandardColumnMappings.bigintWriteFunction;
@@ -191,18 +205,27 @@ public class SqlServerClient
.maximumSize(1)
.expireAfterWrite(ofMinutes(5)));
+ private final boolean statisticsEnabled;
+
private final ConnectorExpressionRewriter connectorExpressionRewriter;
private final AggregateFunctionRewriter aggregateFunctionRewriter;
private static final int MAX_SUPPORTED_TEMPORAL_PRECISION = 7;
@Inject
- public SqlServerClient(BaseJdbcConfig config, SqlServerConfig sqlServerConfig, ConnectionFactory connectionFactory, QueryBuilder queryBuilder, IdentifierMapping identifierMapping)
+ public SqlServerClient(
+ BaseJdbcConfig config,
+ SqlServerConfig sqlServerConfig,
+ JdbcStatisticsConfig statisticsConfig,
+ ConnectionFactory connectionFactory,
+ QueryBuilder queryBuilder,
+ IdentifierMapping identifierMapping)
{
super(config, "\"", connectionFactory, queryBuilder, identifierMapping);
requireNonNull(sqlServerConfig, "sqlServerConfig is null");
snapshotIsolationDisabled = sqlServerConfig.isSnapshotIsolationDisabled();
+ this.statisticsEnabled = requireNonNull(statisticsConfig, "statisticsConfig is null").isEnabled();
this.connectorExpressionRewriter = JdbcConnectorExpressionRewriterBuilder.newBuilder()
.addStandardRules(this::quoted)
@@ -452,6 +475,165 @@ public WriteMapping toWriteMapping(ConnectorSession session, Type type)
throw new TrinoException(NOT_SUPPORTED, "Unsupported column type: " + type.getDisplayName());
}
+ @Override
+ public TableStatistics getTableStatistics(ConnectorSession session, JdbcTableHandle handle, TupleDomain tupleDomain)
+ {
+ if (!statisticsEnabled) {
+ return TableStatistics.empty();
+ }
+ if (!handle.isNamedRelation()) {
+ return TableStatistics.empty();
+ }
+ try {
+ return readTableStatistics(session, handle);
+ }
+ catch (SQLException | RuntimeException e) {
+ throwIfInstanceOf(e, TrinoException.class);
+ throw new TrinoException(JDBC_ERROR, "Failed fetching statistics for table: " + handle, e);
+ }
+ }
+
+ private TableStatistics readTableStatistics(ConnectorSession session, JdbcTableHandle table)
+ throws SQLException
+ {
+ checkArgument(table.isNamedRelation(), "Relation is not a table: %s", table);
+
+ try (Connection connection = connectionFactory.openConnection(session);
+ Handle handle = Jdbi.open(connection)) {
+ String catalog = table.getCatalogName();
+ String schema = table.getSchemaName();
+ String tableName = table.getTableName();
+
+ StatisticsDao statisticsDao = new StatisticsDao(handle);
+ Long tableObjectId = statisticsDao.getTableObjectId(catalog, schema, tableName);
+ if (tableObjectId == null) {
+ // Table not found
+ return TableStatistics.empty();
+ }
+
+ Long rowCount = statisticsDao.getRowCount(tableObjectId);
+ if (rowCount == null) {
+ // Table disappeared
+ return TableStatistics.empty();
+ }
+
+ if (rowCount == 0) {
+ return TableStatistics.empty();
+ }
+
+ TableStatistics.Builder tableStatistics = TableStatistics.builder();
+ tableStatistics.setRowCount(Estimate.of(rowCount));
+
+ Map columnNameToStatisticsName = getColumnNameToStatisticsName(table, statisticsDao, tableObjectId);
+
+ for (JdbcColumnHandle column : this.getColumns(session, table)) {
+ String statisticName = columnNameToStatisticsName.get(column.getColumnName());
+ if (statisticName == null) {
+ // No statistic for column
+ continue;
+ }
+
+ double averageColumnLength;
+ long notNullValues = 0;
+ long nullValues = 0;
+ long distinctValues = 0;
+
+ try (CallableStatement showStatistics = handle.getConnection().prepareCall("DBCC SHOW_STATISTICS (?, ?)")) {
+ showStatistics.setString(1, format("%s.%s.%s", catalog, schema, tableName));
+ showStatistics.setString(2, statisticName);
+
+ boolean isResultSet = showStatistics.execute();
+ checkState(isResultSet, "Expected SHOW_STATISTICS to return a result set");
+ try (ResultSet resultSet = showStatistics.getResultSet()) {
+ checkState(resultSet.next(), "No rows in result set");
+
+ averageColumnLength = resultSet.getDouble("Average Key Length"); // NULL values are accounted for with length 0
+
+ checkState(!resultSet.next(), "More than one row in result set");
+ }
+
+ isResultSet = showStatistics.getMoreResults();
+ checkState(isResultSet, "Expected SHOW_STATISTICS to return second result set");
+ showStatistics.getResultSet().close();
+
+ isResultSet = showStatistics.getMoreResults();
+ checkState(isResultSet, "Expected SHOW_STATISTICS to return third result set");
+ try (ResultSet resultSet = showStatistics.getResultSet()) {
+ while (resultSet.next()) {
+ resultSet.getObject("RANGE_HI_KEY");
+ if (resultSet.wasNull()) {
+ // Null fraction
+ checkState(resultSet.getLong("RANGE_ROWS") == 0, "Unexpected RANGE_ROWS for null fraction");
+ checkState(resultSet.getLong("DISTINCT_RANGE_ROWS") == 0, "Unexpected DISTINCT_RANGE_ROWS for null fraction");
+ checkState(nullValues == 0, "Multiple null fraction entries");
+ nullValues += resultSet.getLong("EQ_ROWS");
+ }
+ else {
+ // TODO discover min/max from resultSet.getXxx("RANGE_HI_KEY")
+ notNullValues += resultSet.getLong("RANGE_ROWS") // rows strictly within a bucket
+ + resultSet.getLong("EQ_ROWS"); // rows equal to RANGE_HI_KEY
+ distinctValues += resultSet.getLong("DISTINCT_RANGE_ROWS") // NDV strictly within a bucket
+ + (resultSet.getLong("EQ_ROWS") > 0 ? 1 : 0);
+ }
+ }
+ }
+ }
+
+ ColumnStatistics statistics = ColumnStatistics.builder()
+ .setNullsFraction(Estimate.of(
+ (notNullValues + nullValues == 0)
+ ? 1
+ : (1.0 * nullValues / (notNullValues + nullValues))))
+ .setDistinctValuesCount(Estimate.of(distinctValues))
+ .setDataSize(Estimate.of(rowCount * averageColumnLength))
+ .build();
+
+ tableStatistics.setColumnStatistics(column, statistics);
+ }
+
+ return tableStatistics.build();
+ }
+ }
+
+ private static Map getColumnNameToStatisticsName(JdbcTableHandle table, StatisticsDao statisticsDao, Long tableObjectId)
+ {
+ List singleColumnStatistics = statisticsDao.getSingleColumnStatistics(tableObjectId);
+
+ Map columnNameToStatisticsName = new HashMap<>();
+ for (String statisticName : singleColumnStatistics) {
+ String columnName = statisticsDao.getSingleColumnStatisticsColumnName(tableObjectId, statisticName);
+ if (columnName == null) {
+ // Table or statistics disappeared
+ continue;
+ }
+
+ if (columnNameToStatisticsName.putIfAbsent(columnName, statisticName) != null) {
+ log.debug("Multiple statistics for %s in %s: %s and %s", columnName, table, columnNameToStatisticsName.get(columnName), statisticName);
+ }
+ }
+ return columnNameToStatisticsName;
+ }
+
+ @Override
+ public Optional implementJoin(
+ ConnectorSession session,
+ JoinType joinType,
+ PreparedQuery leftSource,
+ PreparedQuery rightSource,
+ List joinConditions,
+ Map rightAssignments,
+ Map leftAssignments,
+ JoinStatistics statistics)
+ {
+ return implementJoinCostAware(
+ session,
+ joinType,
+ leftSource,
+ rightSource,
+ statistics,
+ () -> super.implementJoin(session, joinType, leftSource, rightSource, joinConditions, rightAssignments, leftAssignments, statistics));
+ }
+
private LongWriteFunction sqlServerTimeWriteFunction(int precision)
{
return new LongWriteFunction()
@@ -833,4 +1015,65 @@ private enum SnapshotIsolationEnabledCacheKey
// database, so from our perspective, this is a global property.
INSTANCE
}
+
+ private static class StatisticsDao
+ {
+ private final Handle handle;
+
+ public StatisticsDao(Handle handle)
+ {
+ this.handle = requireNonNull(handle, "handle is null");
+ }
+
+ Long getTableObjectId(String catalog, String schema, String tableName)
+ {
+ return handle.createQuery("SELECT object_id(:table)")
+ .bind("table", format("%s.%s.%s", catalog, schema, tableName))
+ .mapTo(Long.class)
+ .findOnly();
+ }
+
+ Long getRowCount(long tableObjectId)
+ {
+ return handle.createQuery("" +
+ "SELECT sum(rows) row_count " +
+ "FROM sys.partitions " +
+ "WHERE object_id = :object_id " +
+ "AND index_id IN (0, 1)") // 0 = heap, 1 = clustered index, 2 or greater = non-clustered index
+ .bind("object_id", tableObjectId)
+ .mapTo(Long.class)
+ .findOnly();
+ }
+
+ List getSingleColumnStatistics(long tableObjectId)
+ {
+ return handle.createQuery("" +
+ "SELECT s.name " +
+ "FROM sys.stats AS s " +
+ "JOIN sys.stats_columns AS sc ON s.object_id = sc.object_id AND s.stats_id = sc.stats_id " +
+ "WHERE s.object_id = :object_id " +
+ "GROUP BY s.name " +
+ "HAVING count(*) = 1 " +
+ "ORDER BY s.name")
+ .bind("object_id", tableObjectId)
+ .mapTo(String.class)
+ .list();
+ }
+
+ String getSingleColumnStatisticsColumnName(long tableObjectId, String statisticsName)
+ {
+ return handle.createQuery("" +
+ "SELECT c.name " +
+ "FROM sys.stats AS s " +
+ "JOIN sys.stats_columns AS sc ON s.object_id = sc.object_id AND s.stats_id = sc.stats_id " +
+ "JOIN sys.columns AS c ON sc.object_id = c.object_id AND c.column_id = sc.column_id " +
+ "WHERE s.object_id = :object_id " +
+ "AND s.name = :statistics_name")
+ .bind("object_id", tableObjectId)
+ .bind("statistics_name", statisticsName)
+ .mapTo(String.class)
+ .collect(toOptional()) // verify there is no more than 1 column name returned
+ .orElse(null);
+ }
+ }
}
diff --git a/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/SqlServerClientModule.java b/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/SqlServerClientModule.java
index 7e2791ff77cd..5b3a93eb0881 100644
--- a/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/SqlServerClientModule.java
+++ b/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/SqlServerClientModule.java
@@ -15,16 +15,18 @@
import com.google.inject.Binder;
import com.google.inject.Key;
-import com.google.inject.Module;
import com.google.inject.Provides;
import com.google.inject.Scopes;
import com.google.inject.Singleton;
import com.microsoft.sqlserver.jdbc.SQLServerDriver;
+import io.airlift.configuration.AbstractConfigurationAwareModule;
import io.trino.plugin.jdbc.BaseJdbcConfig;
import io.trino.plugin.jdbc.ConnectionFactory;
import io.trino.plugin.jdbc.DriverConnectionFactory;
import io.trino.plugin.jdbc.ForBaseJdbc;
import io.trino.plugin.jdbc.JdbcClient;
+import io.trino.plugin.jdbc.JdbcJoinPushdownSupportModule;
+import io.trino.plugin.jdbc.JdbcStatisticsConfig;
import io.trino.plugin.jdbc.MaxDomainCompactionThreshold;
import io.trino.plugin.jdbc.credential.CredentialProvider;
@@ -34,15 +36,17 @@
import static io.trino.plugin.sqlserver.SqlServerClient.SQL_SERVER_MAX_LIST_EXPRESSIONS;
public class SqlServerClientModule
- implements Module
+ extends AbstractConfigurationAwareModule
{
@Override
- public void configure(Binder binder)
+ protected void setup(Binder binder)
{
configBinder(binder).bindConfig(SqlServerConfig.class);
+ configBinder(binder).bindConfig(JdbcStatisticsConfig.class);
binder.bind(JdbcClient.class).annotatedWith(ForBaseJdbc.class).to(SqlServerClient.class).in(Scopes.SINGLETON);
bindTablePropertiesProvider(binder, SqlServerTableProperties.class);
newOptionalBinder(binder, Key.get(int.class, MaxDomainCompactionThreshold.class)).setBinding().toInstance(SQL_SERVER_MAX_LIST_EXPRESSIONS);
+ install(new JdbcJoinPushdownSupportModule());
}
@Provides
diff --git a/plugin/trino-sqlserver/src/test/java/io/trino/plugin/sqlserver/BaseSqlServerConnectorTest.java b/plugin/trino-sqlserver/src/test/java/io/trino/plugin/sqlserver/BaseSqlServerConnectorTest.java
index 3383d64424b6..24d9a6a31af0 100644
--- a/plugin/trino-sqlserver/src/test/java/io/trino/plugin/sqlserver/BaseSqlServerConnectorTest.java
+++ b/plugin/trino-sqlserver/src/test/java/io/trino/plugin/sqlserver/BaseSqlServerConnectorTest.java
@@ -539,4 +539,13 @@ private String getLongInClause(int start, int length)
.collect(joining(", "));
return "orderkey IN (" + longValues + ")";
}
+
+ @Override
+ protected Session joinPushdownEnabled(Session session)
+ {
+ return Session.builder(super.joinPushdownEnabled(session))
+ // strategy is AUTOMATIC by default and would not work for certain test cases (even if statistics are collected)
+ .setCatalogSessionProperty(session.getCatalog().orElseThrow(), "join_pushdown_strategy", "EAGER")
+ .build();
+ }
}
diff --git a/plugin/trino-sqlserver/src/test/java/io/trino/plugin/sqlserver/TestSqlServerAutomaticJoinPushdown.java b/plugin/trino-sqlserver/src/test/java/io/trino/plugin/sqlserver/TestSqlServerAutomaticJoinPushdown.java
new file mode 100644
index 000000000000..94c428668902
--- /dev/null
+++ b/plugin/trino-sqlserver/src/test/java/io/trino/plugin/sqlserver/TestSqlServerAutomaticJoinPushdown.java
@@ -0,0 +1,54 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.trino.plugin.sqlserver;
+
+import io.trino.plugin.jdbc.BaseAutomaticJoinPushdownTest;
+import io.trino.testing.QueryRunner;
+
+import java.util.List;
+import java.util.Map;
+
+import static com.google.common.collect.ImmutableList.toImmutableList;
+import static com.google.common.collect.Streams.stream;
+import static io.trino.plugin.sqlserver.SqlServerQueryRunner.createSqlServerQueryRunner;
+import static java.lang.String.format;
+
+public class TestSqlServerAutomaticJoinPushdown
+ extends BaseAutomaticJoinPushdownTest
+{
+ private TestingSqlServer sqlServer;
+
+ @Override
+ protected QueryRunner createQueryRunner()
+ throws Exception
+ {
+ sqlServer = closeAfterClass(new TestingSqlServer());
+ return createSqlServerQueryRunner(sqlServer, Map.of(), Map.of(), List.of());
+ }
+
+ @Override
+ protected void gatherStats(String tableName)
+ {
+ List columnNames = stream(computeActual("SHOW COLUMNS FROM " + tableName))
+ .map(row -> (String) row.getField(0))
+ .map(columnName -> "\"" + columnName + "\"")
+ .collect(toImmutableList());
+
+ for (String columnName : columnNames) {
+ sqlServer.execute(format("CREATE STATISTICS %1$s ON %2$s (%1$s)", columnName, tableName));
+ }
+
+ sqlServer.execute("UPDATE STATISTICS " + tableName);
+ }
+}
diff --git a/plugin/trino-sqlserver/src/test/java/io/trino/plugin/sqlserver/TestSqlServerClient.java b/plugin/trino-sqlserver/src/test/java/io/trino/plugin/sqlserver/TestSqlServerClient.java
index 1e6d993f6e70..ac9ecb52b445 100644
--- a/plugin/trino-sqlserver/src/test/java/io/trino/plugin/sqlserver/TestSqlServerClient.java
+++ b/plugin/trino-sqlserver/src/test/java/io/trino/plugin/sqlserver/TestSqlServerClient.java
@@ -19,6 +19,7 @@
import io.trino.plugin.jdbc.JdbcClient;
import io.trino.plugin.jdbc.JdbcColumnHandle;
import io.trino.plugin.jdbc.JdbcExpression;
+import io.trino.plugin.jdbc.JdbcStatisticsConfig;
import io.trino.plugin.jdbc.JdbcTypeHandle;
import io.trino.plugin.jdbc.mapping.DefaultIdentifierMapping;
import io.trino.spi.connector.AggregateFunction;
@@ -59,6 +60,7 @@ public class TestSqlServerClient
private static final JdbcClient JDBC_CLIENT = new SqlServerClient(
new BaseJdbcConfig(),
new SqlServerConfig(),
+ new JdbcStatisticsConfig(),
session -> {
throw new UnsupportedOperationException();
},
diff --git a/plugin/trino-sqlserver/src/test/java/io/trino/plugin/sqlserver/TestSqlServerTableStatistics.java b/plugin/trino-sqlserver/src/test/java/io/trino/plugin/sqlserver/TestSqlServerTableStatistics.java
new file mode 100644
index 000000000000..b2ce0f624108
--- /dev/null
+++ b/plugin/trino-sqlserver/src/test/java/io/trino/plugin/sqlserver/TestSqlServerTableStatistics.java
@@ -0,0 +1,413 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.trino.plugin.sqlserver;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+import io.trino.plugin.jdbc.BaseJdbcTableStatisticsTest;
+import io.trino.testing.QueryRunner;
+import io.trino.testing.sql.TestTable;
+import org.jdbi.v3.core.Handle;
+import org.jdbi.v3.core.Jdbi;
+import org.testng.SkipException;
+import org.testng.annotations.Test;
+
+import java.util.List;
+import java.util.Map;
+
+import static com.google.common.collect.ImmutableList.toImmutableList;
+import static com.google.common.collect.Streams.stream;
+import static io.trino.testing.sql.TestTable.fromColumns;
+import static io.trino.tpch.TpchTable.ORDERS;
+import static java.lang.String.format;
+
+public class TestSqlServerTableStatistics
+ extends BaseJdbcTableStatisticsTest
+{
+ private TestingSqlServer sqlServer;
+
+ @Override
+ protected QueryRunner createQueryRunner()
+ throws Exception
+ {
+ sqlServer = closeAfterClass(new TestingSqlServer());
+ return SqlServerQueryRunner.createSqlServerQueryRunner(
+ sqlServer,
+ Map.of(),
+ Map.of("case-insensitive-name-matching", "true"),
+ List.of(ORDERS));
+ }
+
+ @Override
+ @Test
+ public void testNotAnalyzed()
+ {
+ String tableName = "test_stats_not_analyzed";
+ assertUpdate("DROP TABLE IF EXISTS " + tableName);
+ computeActual(format("CREATE TABLE %s AS SELECT * FROM tpch.tiny.orders", tableName));
+ try {
+ assertQuery(
+ "SHOW STATS FOR " + tableName,
+ "VALUES " +
+ "('orderkey', null, null, null, null, null, null)," +
+ "('custkey', null, null, null, null, null, null)," +
+ "('orderstatus', null, null, null, null, null, null)," +
+ "('totalprice', null, null, null, null, null, null)," +
+ "('orderdate', null, null, null, null, null, null)," +
+ "('orderpriority', null, null, null, null, null, null)," +
+ "('clerk', null, null, null, null, null, null)," +
+ "('shippriority', null, null, null, null, null, null)," +
+ "('comment', null, null, null, null, null, null)," +
+ "(null, null, null, null, 15000, null, null)");
+ }
+ finally {
+ assertUpdate("DROP TABLE " + tableName);
+ }
+ }
+
+ @Override
+ @Test
+ public void testBasic()
+ {
+ String tableName = "test_stats_orders";
+ assertUpdate("DROP TABLE IF EXISTS " + tableName);
+ computeActual(format("CREATE TABLE %s AS SELECT * FROM tpch.tiny.orders", tableName));
+ try {
+ gatherStats(tableName);
+ assertQuery(
+ "SHOW STATS FOR " + tableName,
+ "VALUES " +
+ "('orderkey', null, 15000, 0, null, null, null)," +
+ "('custkey', null, 1000, 0, null, null, null)," +
+ "('orderstatus', 30000, 3, 0, null, null, null)," +
+ "('totalprice', null, 14996, 0, null, null, null)," +
+ "('orderdate', null, 2401, 0, null, null, null)," +
+ "('orderpriority', 252376, 5, 0, null, null, null)," +
+ "('clerk', 450000, 1000, 0, null, null, null)," +
+ "('shippriority', null, 1, 0, null, null, null)," +
+ "('comment', 1454727, 14994, 0, null, null, null)," +
+ "(null, null, null, null, 15000, null, null)");
+ }
+ finally {
+ assertUpdate("DROP TABLE " + tableName);
+ }
+ }
+
+ @Override
+ protected void checkEmptyTableStats(String tableName)
+ {
+ // TODO: Empty tables should have NDV as 0 and nulls fraction as 1
+ assertQuery(
+ "SHOW STATS FOR " + tableName,
+ "VALUES " +
+ "('orderkey', null, null, null, null, null, null)," +
+ "('custkey', null, null, null, null, null, null)," +
+ "('orderpriority', null, null, null, null, null, null)," +
+ "('comment', null, null, null, null, null, null)," +
+ "(null, null, null, null, null, null, null)");
+ }
+
+ @Override
+ @Test
+ public void testAllNulls()
+ {
+ String tableName = "test_stats_table_all_nulls";
+ assertUpdate("DROP TABLE IF EXISTS " + tableName);
+ computeActual(format("CREATE TABLE %s AS SELECT orderkey, custkey, orderpriority, comment FROM tpch.tiny.orders WHERE false", tableName));
+ try {
+ computeActual(format("INSERT INTO %s (orderkey) VALUES NULL, NULL, NULL", tableName));
+ gatherStats(tableName);
+ assertQuery(
+ "SHOW STATS FOR " + tableName,
+ "VALUES " +
+ "('orderkey', 0, 0, 1, null, null, null)," +
+ "('custkey', 0, 0, 1, null, null, null)," +
+ "('orderpriority', 0, 0, 1, null, null, null)," +
+ "('comment', 0, 0, 1, null, null, null)," +
+ "(null, null, null, null, 3, null, null)");
+ }
+ finally {
+ assertUpdate("DROP TABLE " + tableName);
+ }
+ }
+
+ @Override
+ @Test
+ public void testNullsFraction()
+ {
+ String tableName = "test_stats_table_with_nulls";
+ assertUpdate("DROP TABLE IF EXISTS " + tableName);
+ assertUpdate("" +
+ "CREATE TABLE " + tableName + " AS " +
+ "SELECT " +
+ " orderkey, " +
+ " if(orderkey % 3 = 0, NULL, custkey) custkey, " +
+ " if(orderkey % 5 = 0, NULL, orderpriority) orderpriority " +
+ "FROM tpch.tiny.orders",
+ 15000);
+ try {
+ gatherStats(tableName);
+ assertQuery(
+ "SHOW STATS FOR " + tableName,
+ "VALUES " +
+ "('orderkey', null, 15000, 0, null, null, null)," +
+ "('custkey', null, 1000, 0.3333333333333333, null, null, null)," +
+ "('orderpriority', 201914, 5, 0.2, null, null, null)," +
+ "(null, null, null, null, 15000, null, null)");
+ }
+ finally {
+ assertUpdate("DROP TABLE " + tableName);
+ }
+ }
+
+ @Override
+ @Test
+ public void testAverageColumnLength()
+ {
+ String tableName = "test_stats_table_avg_col_len";
+ assertUpdate("DROP TABLE IF EXISTS " + tableName);
+ computeActual("" +
+ "CREATE TABLE " + tableName + " AS SELECT " +
+ " orderkey, " +
+ " 'abc' v3_in_3, " +
+ " CAST('abc' AS varchar(42)) v3_in_42, " +
+ " if(orderkey = 1, '0123456789', NULL) single_10v_value, " +
+ " if(orderkey % 2 = 0, '0123456789', NULL) half_10v_value, " +
+ " if(orderkey % 2 = 0, CAST((1000000 - orderkey) * (1000000 - orderkey) AS varchar(20)), NULL) half_distinct_20v_value, " + // 12 chars each
+ " CAST(NULL AS varchar(10)) all_nulls " +
+ "FROM tpch.tiny.orders " +
+ "ORDER BY orderkey LIMIT 100");
+ try {
+ gatherStats(tableName);
+ assertQuery(
+ "SHOW STATS FOR " + tableName,
+ "VALUES " +
+ "('orderkey', null, 100, 0, null, null, null)," +
+ "('v3_in_3', 600, 1, 0, null, null, null)," +
+ "('v3_in_42', 600, 1, 0, null, null, null)," +
+ "('single_10v_value', 20, 1, 0.99, null, null, null)," +
+ "('half_10v_value', 1000, 1, 0.5, null, null, null)," +
+ "('half_distinct_20v_value', 1200, 50, 0.5, null, null, null)," +
+ "('all_nulls', 0, 0, 1, null, null, null)," +
+ "(null, null, null, null, 100, null, null)");
+ }
+ finally {
+ assertUpdate("DROP TABLE " + tableName);
+ }
+ }
+
+ @Override
+ @Test
+ public void testPartitionedTable()
+ {
+ throw new SkipException("Not implemented"); // TODO
+ }
+
+ @Override
+ @Test
+ public void testView()
+ {
+ String tableName = "test_stats_view";
+ sqlServer.execute("CREATE VIEW " + tableName + " AS SELECT orderkey, custkey, orderpriority, comment FROM orders");
+ try {
+ assertQuery(
+ "SHOW STATS FOR " + tableName,
+ "VALUES " +
+ "('orderkey', null, null, null, null, null, null)," +
+ "('custkey', null, null, null, null, null, null)," +
+ "('orderpriority', null, null, null, null, null, null)," +
+ "('comment', null, null, null, null, null, null)," +
+ "(null, null, null, null, null, null, null)");
+ // It's not possible to ANALYZE a VIEW in SQL Server
+ }
+ finally {
+ sqlServer.execute("DROP VIEW " + tableName);
+ }
+ }
+
+ @Override
+ public void testMaterializedView()
+ {
+ throw new SkipException("see testIndexedView");
+ }
+
+ @Test
+ public void testIndexedView() // materialized view
+ {
+ String tableName = "test_stats_indexed_view";
+ // indexed views require fixed values for several SET options
+ try (Handle handle = Jdbi.open(sqlServer::createConnection)) {
+ // indexed views require fixed values for several SET options
+ handle.execute("SET NUMERIC_ROUNDABORT OFF");
+ handle.execute("SET ANSI_PADDING, ANSI_WARNINGS, CONCAT_NULL_YIELDS_NULL, ARITHABORT, QUOTED_IDENTIFIER, ANSI_NULLS ON");
+
+ handle.execute("" +
+ "CREATE VIEW " + tableName + " " +
+ "WITH SCHEMABINDING " +
+ "AS SELECT orderkey, custkey, orderpriority, comment FROM dbo.orders");
+ try {
+ handle.execute("CREATE UNIQUE CLUSTERED INDEX idx1 ON " + tableName + " (orderkey, custkey, orderpriority, comment)");
+ gatherStats(tableName);
+ assertQuery(
+ "SHOW STATS FOR " + tableName,
+ "VALUES " +
+ "('orderkey', null, 15000, 0, null, null, null)," +
+ "('custkey', null, 1000, 0, null, null, null)," +
+ "('orderpriority', 252376, 5, 0, null, null, null)," +
+ "('comment', 1454727, 14994, 0, null, null, null)," +
+ "(null, null, null, null, 15000, null, null)");
+ }
+ finally {
+ handle.execute("DROP VIEW " + tableName);
+ }
+ }
+ }
+
+ @Override
+ @Test(dataProvider = "testCaseColumnNamesDataProvider")
+ public void testCaseColumnNames(String tableName)
+ {
+ sqlServer.execute("" +
+ "SELECT " +
+ " orderkey CASE_UNQUOTED_UPPER, " +
+ " custkey case_unquoted_lower, " +
+ " orderstatus cASe_uNQuoTeD_miXED, " +
+ " totalprice \"CASE_QUOTED_UPPER\", " +
+ " orderdate \"case_quoted_lower\", " +
+ " orderpriority \"CasE_QuoTeD_miXED\" " +
+ "INTO " + tableName + " " +
+ "FROM orders");
+ try {
+ gatherStats(
+ tableName,
+ ImmutableList.of(
+ "CASE_UNQUOTED_UPPER",
+ "case_unquoted_lower",
+ "cASe_uNQuoTeD_miXED",
+ "CASE_QUOTED_UPPER",
+ "case_quoted_lower",
+ "CasE_QuoTeD_miXED"));
+ assertQuery(
+ "SHOW STATS FOR " + tableName,
+ "VALUES " +
+ "('case_unquoted_upper', null, 15000, 0, null, null, null)," +
+ "('case_unquoted_lower', null, 1000, 0, null, null, null)," +
+ "('case_unquoted_mixed', 30000, 3, 0, null, null, null)," +
+ "('case_quoted_upper', null, 14996, 0, null, null, null)," +
+ "('case_quoted_lower', null, 2401, 0, null, null, null)," +
+ "('case_quoted_mixed', 252376, 5, 0, null, null, null)," +
+ "(null, null, null, null, 15000, null, null)");
+ }
+ finally {
+ sqlServer.execute("DROP TABLE " + tableName);
+ }
+ }
+
+ @Override
+ @Test
+ public void testNumericCornerCases()
+ {
+ try (TestTable table = fromColumns(
+ getQueryRunner()::execute,
+ "test_numeric_corner_cases_",
+ ImmutableMap.>builder()
+// TODO infinity and NaNs are not supported by SQLServer
+// .put("only_negative_infinity double", List.of("-infinity()", "-infinity()", "-infinity()", "-infinity()"))
+// .put("only_positive_infinity double", List.of("infinity()", "infinity()", "infinity()", "infinity()"))
+// .put("mixed_infinities double", List.of("-infinity()", "infinity()", "-infinity()", "infinity()"))
+// .put("mixed_infinities_and_numbers double", List.of("-infinity()", "infinity()", "-5.0", "7.0"))
+// .put("nans_only double", List.of("nan()", "nan()"))
+// .put("nans_and_numbers double", List.of("nan()", "nan()", "-5.0", "7.0"))
+ .put("large_doubles double", List.of("CAST(-50371909150609548946090.0 AS DOUBLE)", "CAST(50371909150609548946090.0 AS DOUBLE)")) // 2^77 DIV 3
+ .put("short_decimals_big_fraction decimal(16,15)", List.of("-1.234567890123456", "1.234567890123456"))
+ .put("short_decimals_big_integral decimal(16,1)", List.of("-123456789012345.6", "123456789012345.6"))
+ .put("long_decimals_big_fraction decimal(38,37)", List.of("-1.2345678901234567890123456789012345678", "1.2345678901234567890123456789012345678"))
+ .put("long_decimals_middle decimal(38,16)", List.of("-1234567890123456.7890123456789012345678", "1234567890123456.7890123456789012345678"))
+ .put("long_decimals_big_integral decimal(38,1)", List.of("-1234567890123456789012345678901234567.8", "1234567890123456789012345678901234567.8"))
+ .buildOrThrow(),
+ "null")) {
+ gatherStats(table.getName());
+ assertQuery(
+ "SHOW STATS FOR " + table.getName(),
+ "VALUES " +
+// TODO infinity and NaNs are not supported by SQLServer
+// "('only_negative_infinity', null, 1, 0, null, null, null)," +
+// "('only_positive_infinity', null, 1, 0, null, null, null)," +
+// "('mixed_infinities', null, 2, 0, null, null, null)," +
+// "('mixed_infinities_and_numbers', null, 4.0, 0.0, null, null, null)," +
+// "('nans_only', null, 1.0, 0.5, null, null, null)," +
+// "('nans_and_numbers', null, 3.0, 0.0, null, null, null)," +
+ "('large_doubles', null, 2.0, 0.0, null, null, null)," +
+ "('short_decimals_big_fraction', null, 2.0, 0.0, null, null, null)," +
+ "('short_decimals_big_integral', null, 2.0, 0.0, null, null, null)," +
+ "('long_decimals_big_fraction', null, 2.0, 0.0, null, null, null)," +
+ "('long_decimals_middle', null, 2.0, 0.0, null, null, null)," +
+ "('long_decimals_big_integral', null, 2.0, 0.0, null, null, null)," +
+ "(null, null, null, null, 2, null, null)");
+ }
+ }
+
+ @Test
+ public void testShowStatsAfterCreateIndex()
+ {
+ String tableName = "test_stats_create_index";
+ assertUpdate("DROP TABLE IF EXISTS " + tableName);
+ computeActual(format("CREATE TABLE %s AS SELECT * FROM tpch.tiny.orders", tableName));
+
+ String expected = "VALUES " +
+ "('orderkey', null, 15000, 0, null, null, null)," +
+ "('custkey', null, 1000, 0, null, null, null)," +
+ "('orderstatus', 30000, 3, 0, null, null, null)," +
+ "('totalprice', null, 14996, 0, null, null, null)," +
+ "('orderdate', null, 2401, 0, null, null, null)," +
+ "('orderpriority', 252376, 5, 0, null, null, null)," +
+ "('clerk', 450000, 1000, 0, null, null, null)," +
+ "('shippriority', null, 1, 0, null, null, null)," +
+ "('comment', 1454727, 14994, 0, null, null, null)," +
+ "(null, null, null, null, 15000, null, null)";
+
+ try {
+ gatherStats(tableName);
+ assertQuery("SHOW STATS FOR " + tableName, expected);
+
+ // CREATE INDEX statement updates sys.partitions table
+ sqlServer.execute(format("CREATE INDEX idx ON %s (orderkey)", tableName));
+ sqlServer.execute(format("CREATE UNIQUE INDEX unique_index ON %s (orderkey)", tableName));
+ sqlServer.execute(format("CREATE CLUSTERED INDEX clustered_index ON %s (orderkey)", tableName));
+ sqlServer.execute(format("CREATE NONCLUSTERED INDEX non_clustered_index ON %s (orderkey)", tableName));
+
+ assertQuery("SHOW STATS FOR " + tableName, expected);
+ }
+ finally {
+ assertUpdate("DROP TABLE " + tableName);
+ }
+ }
+
+ @Override
+ protected void gatherStats(String tableName)
+ {
+ List columnNames = stream(computeActual("SHOW COLUMNS FROM " + tableName))
+ .map(row -> (String) row.getField(0))
+ .collect(toImmutableList());
+ gatherStats(tableName, columnNames);
+ }
+
+ private void gatherStats(String tableName, List columnNames)
+ {
+ for (Object columnName : columnNames) {
+ sqlServer.execute(format("CREATE STATISTICS %1$s ON %2$s (%1$s)", columnName, tableName));
+ }
+ sqlServer.execute("UPDATE STATISTICS " + tableName);
+ }
+}