diff --git a/plugin/trino-mariadb/pom.xml b/plugin/trino-mariadb/pom.xml index 2486e2b90d37..606437f9ed4a 100644 --- a/plugin/trino-mariadb/pom.xml +++ b/plugin/trino-mariadb/pom.xml @@ -33,6 +33,11 @@ configuration + + io.airlift + log + + io.trino trino-base-jdbc @@ -48,6 +53,11 @@ jakarta.validation-api + + org.jdbi + jdbi3-core + + org.mariadb.jdbc mariadb-java-client @@ -89,12 +99,6 @@ provided - - io.airlift - log - runtime - - io.airlift log-manager diff --git a/plugin/trino-mariadb/src/main/java/io/trino/plugin/mariadb/MariaDbClient.java b/plugin/trino-mariadb/src/main/java/io/trino/plugin/mariadb/MariaDbClient.java index 91054052e412..23dbb45e3310 100644 --- a/plugin/trino-mariadb/src/main/java/io/trino/plugin/mariadb/MariaDbClient.java +++ b/plugin/trino-mariadb/src/main/java/io/trino/plugin/mariadb/MariaDbClient.java @@ -13,8 +13,10 @@ */ package io.trino.plugin.mariadb; +import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import com.google.inject.Inject; +import io.airlift.log.Logger; import io.trino.plugin.base.aggregation.AggregateFunctionRewriter; import io.trino.plugin.base.aggregation.AggregateFunctionRule; import io.trino.plugin.base.expression.ConnectorExpressionRewriter; @@ -27,6 +29,7 @@ import io.trino.plugin.jdbc.JdbcExpression; import io.trino.plugin.jdbc.JdbcJoinCondition; import io.trino.plugin.jdbc.JdbcSortItem; +import io.trino.plugin.jdbc.JdbcStatisticsConfig; import io.trino.plugin.jdbc.JdbcTableHandle; import io.trino.plugin.jdbc.JdbcTypeHandle; import io.trino.plugin.jdbc.LongWriteFunction; @@ -56,6 +59,9 @@ import io.trino.spi.connector.JoinStatistics; import io.trino.spi.connector.JoinType; import io.trino.spi.connector.SchemaTableName; +import io.trino.spi.statistics.ColumnStatistics; +import io.trino.spi.statistics.Estimate; +import io.trino.spi.statistics.TableStatistics; import io.trino.spi.type.CharType; import io.trino.spi.type.DecimalType; import io.trino.spi.type.Decimals; @@ -63,6 +69,8 @@ import io.trino.spi.type.TimestampType; import io.trino.spi.type.Type; import io.trino.spi.type.VarcharType; +import org.jdbi.v3.core.Handle; +import org.jdbi.v3.core.Jdbi; import java.sql.Connection; import java.sql.DatabaseMetaData; @@ -75,13 +83,18 @@ import java.util.Collection; import java.util.List; import java.util.Map; +import java.util.Map.Entry; import java.util.Optional; import java.util.function.BiFunction; import java.util.stream.Stream; +import static com.google.common.base.MoreObjects.firstNonNull; import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkState; import static com.google.common.base.Strings.emptyToNull; +import static com.google.common.base.Throwables.throwIfInstanceOf; import static com.google.common.base.Verify.verify; +import static com.google.common.collect.ImmutableMap.toImmutableMap; import static io.trino.plugin.jdbc.DecimalConfig.DecimalMapping.ALLOW_OVERFLOW; import static io.trino.plugin.jdbc.DecimalSessionSessionProperties.getDecimalDefaultScale; import static io.trino.plugin.jdbc.DecimalSessionSessionProperties.getDecimalRounding; @@ -137,12 +150,15 @@ import static java.lang.Math.min; import static java.lang.String.format; import static java.lang.String.join; +import static java.util.Map.entry; import static java.util.Objects.requireNonNull; import static java.util.stream.Collectors.joining; public class MariaDbClient extends BaseJdbcClient { + private static final Logger log = Logger.get(MariaDbClient.class); + private static final int MAX_SUPPORTED_DATE_TIME_PRECISION = 6; // MariaDB driver returns width of time types instead of precision. private static final int ZERO_PRECISION_TIME_COLUMN_SIZE = 10; @@ -156,10 +172,17 @@ public class MariaDbClient // MariaDB Error Codes https://mariadb.com/kb/en/mariadb-error-codes/ private static final int PARSE_ERROR = 1064; + private final boolean statisticsEnabled; private final AggregateFunctionRewriter aggregateFunctionRewriter; @Inject - public MariaDbClient(BaseJdbcConfig config, ConnectionFactory connectionFactory, QueryBuilder queryBuilder, IdentifierMapping identifierMapping, RemoteQueryModifier queryModifier) + public MariaDbClient( + BaseJdbcConfig config, + JdbcStatisticsConfig statisticsConfig, + ConnectionFactory connectionFactory, + QueryBuilder queryBuilder, + IdentifierMapping identifierMapping, + RemoteQueryModifier queryModifier) { super("`", connectionFactory, queryBuilder, config.getJdbcTypesMappedToVarchar(), identifierMapping, queryModifier, false); @@ -167,6 +190,7 @@ public MariaDbClient(BaseJdbcConfig config, ConnectionFactory connectionFactory, ConnectorExpressionRewriter connectorExpressionRewriter = JdbcConnectorExpressionRewriterBuilder.newBuilder() .addStandardRules(this::quoted) .build(); + this.statisticsEnabled = statisticsConfig.isEnabled(); this.aggregateFunctionRewriter = new AggregateFunctionRewriter<>( connectorExpressionRewriter, ImmutableSet.>builder() @@ -623,6 +647,102 @@ protected boolean isSupportedJoinCondition(ConnectorSession session, JdbcJoinCon .noneMatch(type -> type instanceof CharType || type instanceof VarcharType); } + @Override + public TableStatistics getTableStatistics(ConnectorSession session, JdbcTableHandle handle) + { + if (!statisticsEnabled) { + return TableStatistics.empty(); + } + if (!handle.isNamedRelation()) { + return TableStatistics.empty(); + } + try { + return readTableStatistics(session, handle); + } + catch (SQLException | RuntimeException e) { + throwIfInstanceOf(e, TrinoException.class); + throw new TrinoException(JDBC_ERROR, "Failed fetching statistics for table: " + handle, e); + } + } + + private TableStatistics readTableStatistics(ConnectorSession session, JdbcTableHandle table) + throws SQLException + { + checkArgument(table.isNamedRelation(), "Relation is not a table: %s", table); + + log.debug("Reading statistics for %s", table); + try (Connection connection = connectionFactory.openConnection(session); + Handle handle = Jdbi.open(connection)) { + StatisticsDao statisticsDao = new StatisticsDao(handle); + + Long rowCount = statisticsDao.getTableRowCount(table); + Long indexMaxCardinality = statisticsDao.getTableMaxColumnIndexCardinality(table); + log.debug("Estimated row count of table %s is %s, and max index cardinality is %s", table, rowCount, indexMaxCardinality); + + if (rowCount != null && rowCount == 0) { + // MariaDB may report 0 row count until a table is analyzed for the first time. + rowCount = null; + } + + if (rowCount == null && indexMaxCardinality == null) { + // Table not found, or is a view, or has no usable statistics + return TableStatistics.empty(); + } + rowCount = max(firstNonNull(rowCount, 0L), firstNonNull(indexMaxCardinality, 0L)); + + TableStatistics.Builder tableStatistics = TableStatistics.builder(); + tableStatistics.setRowCount(Estimate.of(rowCount)); + + // TODO statistics from ANALYZE TABLE (https://mariadb.com/kb/en/engine-independent-table-statistics/) + // Map columnStatistics = statisticsDao.getColumnStatistics(table); + Map columnStatistics = ImmutableMap.of(); + + // TODO add support for histograms https://mariadb.com/kb/en/histogram-based-statistics/ + + // statistics based on existing indexes + Map columnStatisticsFromIndexes = statisticsDao.getColumnIndexStatistics(table); + + if (columnStatistics.isEmpty() && columnStatisticsFromIndexes.isEmpty()) { + log.debug("No column and index statistics read"); + // No more information to work on + return tableStatistics.build(); + } + + for (JdbcColumnHandle column : getColumns(session, table)) { + ColumnStatistics.Builder columnStatisticsBuilder = ColumnStatistics.builder(); + + String columnName = column.getColumnName(); + AnalyzeColumnStatistics analyzeColumnStatistics = columnStatistics.get(columnName); + if (analyzeColumnStatistics != null) { + log.debug("Reading column statistics for %s, %s from analayze's column statistics: %s", table, columnName, analyzeColumnStatistics); + columnStatisticsBuilder.setNullsFraction(Estimate.of(analyzeColumnStatistics.nullsRatio())); + } + + ColumnIndexStatistics columnIndexStatistics = columnStatisticsFromIndexes.get(columnName); + if (columnIndexStatistics != null) { + log.debug("Reading column statistics for %s, %s from index statistics: %s", table, columnName, columnIndexStatistics); + columnStatisticsBuilder.setDistinctValuesCount(Estimate.of(columnIndexStatistics.cardinality())); + + if (!columnIndexStatistics.nullable()) { + double knownNullFraction = columnStatisticsBuilder.build().getNullsFraction().getValue(); + if (knownNullFraction > 0) { + log.warn("Inconsistent statistics, null fraction for a column %s, %s, that is not nullable according to index statistics: %s", table, columnName, knownNullFraction); + } + columnStatisticsBuilder.setNullsFraction(Estimate.zero()); + } + + // row count from INFORMATION_SCHEMA.TABLES may be very inaccurate + rowCount = max(rowCount, columnIndexStatistics.cardinality()); + } + + tableStatistics.setColumnStatistics(column, columnStatisticsBuilder.build()); + } + + tableStatistics.setRowCount(Estimate.of(rowCount)); + return tableStatistics.build(); + } + } + private static LongWriteFunction dateWriteFunction() { return (statement, index, day) -> statement.setString(index, DATE_FORMATTER.format(LocalDate.ofEpochDay(day))); @@ -650,4 +770,101 @@ private static Optional getUnsignedMapping(JdbcTypeHandle typeHan return Optional.empty(); } + + private static class StatisticsDao + { + private final Handle handle; + + public StatisticsDao(Handle handle) + { + this.handle = requireNonNull(handle, "handle is null"); + } + + Long getTableRowCount(JdbcTableHandle table) + { + RemoteTableName remoteTableName = table.getRequiredNamedRelation().getRemoteTableName(); + return handle.createQuery(""" + SELECT TABLE_ROWS FROM INFORMATION_SCHEMA.TABLES + WHERE TABLE_SCHEMA = :schema AND TABLE_NAME = :table_name + AND TABLE_TYPE = 'BASE TABLE' + """) + .bind("schema", remoteTableName.getCatalogName().orElse(null)) + .bind("table_name", remoteTableName.getTableName()) + .mapTo(Long.class) + .findOne() + .orElse(null); + } + + Long getTableMaxColumnIndexCardinality(JdbcTableHandle table) + { + RemoteTableName remoteTableName = table.getRequiredNamedRelation().getRemoteTableName(); + return handle.createQuery(""" + SELECT max(CARDINALITY) AS row_count FROM INFORMATION_SCHEMA.STATISTICS + WHERE TABLE_SCHEMA = :schema AND TABLE_NAME = :table_name + """) + .bind("schema", remoteTableName.getCatalogName().orElse(null)) + .bind("table_name", remoteTableName.getTableName()) + .mapTo(Long.class) + .findOne() + .orElse(null); + } + + Map getColumnStatistics(JdbcTableHandle table) + { + RemoteTableName remoteTableName = table.getRequiredNamedRelation().getRemoteTableName(); + return handle.createQuery(""" + SELECT + column_name, + -- TODO min_value, max_value, + nulls_ratio + FROM mysql.column_stats + WHERE db_name = :database AND TABLE_NAME = :table_name + AND nulls_ratio IS NOT NULL + """) + .bind("database", remoteTableName.getCatalogName().orElse(null)) + .bind("table_name", remoteTableName.getTableName()) + .map((rs, ctx) -> { + String columnName = rs.getString("column_name"); + double nullsRatio = rs.getDouble("nulls_ratio"); + return entry(columnName, new AnalyzeColumnStatistics(nullsRatio)); + }) + .collect(toImmutableMap(Entry::getKey, Entry::getValue)); + } + + Map getColumnIndexStatistics(JdbcTableHandle table) + { + RemoteTableName remoteTableName = table.getRequiredNamedRelation().getRemoteTableName(); + return handle.createQuery(""" + SELECT + COLUMN_NAME, + MAX(NULLABLE) AS NULLABLE, + MAX(CARDINALITY) AS CARDINALITY + FROM INFORMATION_SCHEMA.STATISTICS + WHERE TABLE_SCHEMA = :schema AND TABLE_NAME = :table_name + AND SEQ_IN_INDEX = 1 -- first column in the index + AND SUB_PART IS NULL -- ignore cases where only a column prefix is indexed + AND CARDINALITY IS NOT NULL -- CARDINALITY might be null (https://stackoverflow.com/a/42242729/65458) + AND CARDINALITY != 0 -- CARDINALITY is initially 0 until analyzed + GROUP BY COLUMN_NAME -- there might be multiple indexes on a column + """) + .bind("schema", remoteTableName.getCatalogName().orElse(null)) + .bind("table_name", remoteTableName.getTableName()) + .map((rs, ctx) -> { + String columnName = rs.getString("COLUMN_NAME"); + + boolean nullable = rs.getString("NULLABLE").equalsIgnoreCase("YES"); + checkState(!rs.wasNull(), "NULLABLE is null"); + + long cardinality = rs.getLong("CARDINALITY"); + checkState(!rs.wasNull(), "CARDINALITY is null"); + + return entry(columnName, new ColumnIndexStatistics(nullable, cardinality)); + }) + .collect(toImmutableMap(Entry::getKey, Entry::getValue)); + } + } + + private record AnalyzeColumnStatistics(double nullsRatio) {} + + private record ColumnIndexStatistics(boolean nullable, long cardinality) {} } diff --git a/plugin/trino-mariadb/src/main/java/io/trino/plugin/mariadb/MariaDbClientModule.java b/plugin/trino-mariadb/src/main/java/io/trino/plugin/mariadb/MariaDbClientModule.java index c516a8ee9195..e06913b26e39 100644 --- a/plugin/trino-mariadb/src/main/java/io/trino/plugin/mariadb/MariaDbClientModule.java +++ b/plugin/trino-mariadb/src/main/java/io/trino/plugin/mariadb/MariaDbClientModule.java @@ -25,6 +25,7 @@ import io.trino.plugin.jdbc.DriverConnectionFactory; import io.trino.plugin.jdbc.ForBaseJdbc; import io.trino.plugin.jdbc.JdbcClient; +import io.trino.plugin.jdbc.JdbcStatisticsConfig; import io.trino.plugin.jdbc.credential.CredentialProvider; import io.trino.plugin.jdbc.ptf.Query; import io.trino.spi.function.table.ConnectorTableFunction; @@ -43,6 +44,7 @@ public void configure(Binder binder) { binder.bind(JdbcClient.class).annotatedWith(ForBaseJdbc.class).to(MariaDbClient.class).in(Scopes.SINGLETON); configBinder(binder).bindConfig(MariaDbJdbcConfig.class); + configBinder(binder).bindConfig(JdbcStatisticsConfig.class); binder.install(new DecimalModule()); newSetBinder(binder, ConnectorTableFunction.class).addBinding().toProvider(Query.class).in(Scopes.SINGLETON); } diff --git a/plugin/trino-mariadb/src/test/java/io/trino/plugin/mariadb/BaseMariaDbTableIndexStatisticsTest.java b/plugin/trino-mariadb/src/test/java/io/trino/plugin/mariadb/BaseMariaDbTableIndexStatisticsTest.java new file mode 100644 index 000000000000..b585ce4c59cf --- /dev/null +++ b/plugin/trino-mariadb/src/test/java/io/trino/plugin/mariadb/BaseMariaDbTableIndexStatisticsTest.java @@ -0,0 +1,126 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.mariadb; + +import io.trino.testing.MaterializedRow; +import org.testng.SkipException; +import org.testng.annotations.Test; + +import static java.lang.String.format; + +public abstract class BaseMariaDbTableIndexStatisticsTest + extends BaseMariaDbTableStatisticsTest +{ + protected BaseMariaDbTableIndexStatisticsTest(String dockerImageName) + { + super( + dockerImageName, + nullFraction -> 0.1, // Without mysql.column_stats we have no way of knowing real null fraction, 10% is just a "wild guess" + varcharNdv -> null); // Without mysql.column_stats we don't know cardinality for varchar columns + } + + @Override + protected void gatherStats(String tableName) + { + for (MaterializedRow row : computeActual("SHOW COLUMNS FROM " + tableName)) { + String columnName = (String) row.getField(0); + String columnType = (String) row.getField(1); + if (columnType.startsWith("varchar")) { + continue; + } + executeInMariaDb(format("CREATE INDEX %2$s ON %1$s (%2$s)", tableName, columnName).replace("\"", "`")); + } + executeInMariaDb("ANALYZE TABLE " + tableName.replace("\"", "`")); + } + + @Test + @Override + public void testStatsWithPredicatePushdownWithStatsPrecalculationDisabled() + { + // TODO (https://github.com/trinodb/trino/issues/11664) implement the test for MariaDB, with permissive approximate assertions + throw new SkipException("Test to be implemented"); + } + + @Test + @Override + public void testStatsWithPredicatePushdown() + { + // TODO (https://github.com/trinodb/trino/issues/11664) implement the test for MariaDB, with permissive approximate assertions + throw new SkipException("Test to be implemented"); + } + + @Test + @Override + public void testStatsWithVarcharPredicatePushdown() + { + // TODO (https://github.com/trinodb/trino/issues/11664) implement the test for MariaDB, with permissive approximate assertions + throw new SkipException("Test to be implemented"); + } + + @Test + @Override + public void testStatsWithLimitPushdown() + { + // TODO (https://github.com/trinodb/trino/issues/11664) implement the test for MariaDB, with permissive approximate assertions + throw new SkipException("Test to be implemented"); + } + + @Test + @Override + public void testStatsWithTopNPushdown() + { + // TODO (https://github.com/trinodb/trino/issues/11664) implement the test for MariaDB, with permissive approximate assertions + throw new SkipException("Test to be implemented"); + } + + @Test + @Override + public void testStatsWithDistinctPushdown() + { + // TODO (https://github.com/trinodb/trino/issues/11664) implement the test for MariaDB, with permissive approximate assertions + throw new SkipException("Test to be implemented"); + } + + @Test + @Override + public void testStatsWithDistinctLimitPushdown() + { + // TODO (https://github.com/trinodb/trino/issues/11664) implement the test for MariaDB, with permissive approximate assertions + throw new SkipException("Test to be implemented"); + } + + @Test + @Override + public void testStatsWithAggregationPushdown() + { + // TODO (https://github.com/trinodb/trino/issues/11664) implement the test for MariaDB, with permissive approximate assertions + throw new SkipException("Test to be implemented"); + } + + @Test + @Override + public void testStatsWithSimpleJoinPushdown() + { + // TODO (https://github.com/trinodb/trino/issues/11664) implement the test for MariaDB, with permissive approximate assertions + throw new SkipException("Test to be implemented"); + } + + @Test + @Override + public void testStatsWithJoinPushdown() + { + // TODO (https://github.com/trinodb/trino/issues/11664) implement the test for MariaDB, with permissive approximate assertions + throw new SkipException("Test to be implemented"); + } +} diff --git a/plugin/trino-mariadb/src/test/java/io/trino/plugin/mariadb/BaseMariaDbTableStatisticsTest.java b/plugin/trino-mariadb/src/test/java/io/trino/plugin/mariadb/BaseMariaDbTableStatisticsTest.java new file mode 100644 index 000000000000..45d230e6e1c6 --- /dev/null +++ b/plugin/trino-mariadb/src/test/java/io/trino/plugin/mariadb/BaseMariaDbTableStatisticsTest.java @@ -0,0 +1,443 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.mariadb; + +import com.google.common.collect.ImmutableMap; +import io.trino.plugin.jdbc.BaseJdbcTableStatisticsTest; +import io.trino.testing.MaterializedResult; +import io.trino.testing.MaterializedRow; +import io.trino.testing.QueryRunner; +import io.trino.testing.sql.TestTable; +import org.assertj.core.api.AbstractDoubleAssert; +import org.testng.SkipException; +import org.testng.annotations.Test; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.function.Function; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Verify.verify; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.Streams.stream; +import static io.trino.plugin.mariadb.MariaDbQueryRunner.createMariaDbQueryRunner; +import static io.trino.testing.TestingNames.randomNameSuffix; +import static io.trino.testing.sql.TestTable.fromColumns; +import static io.trino.tpch.TpchTable.ORDERS; +import static java.lang.Math.min; +import static java.lang.String.format; +import static java.util.Objects.requireNonNull; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.withinPercentage; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertNotNull; +import static org.testng.Assert.assertNull; + +public abstract class BaseMariaDbTableStatisticsTest + extends BaseJdbcTableStatisticsTest +{ + protected final String dockerImageName; + protected final Function nullFractionToExpected; + protected final Function varcharNdvToExpected; + protected TestingMariaDbServer mariaDbServer; + + protected BaseMariaDbTableStatisticsTest( + String dockerImageName, + Function nullFractionToExpected, + Function varcharNdvToExpected) + { + this.dockerImageName = requireNonNull(dockerImageName, "dockerImageName is null"); + this.nullFractionToExpected = requireNonNull(nullFractionToExpected, "nullFractionToExpected is null"); + this.varcharNdvToExpected = requireNonNull(varcharNdvToExpected, "varcharNdvToExpected is null"); + } + + @Override + protected QueryRunner createQueryRunner() + throws Exception + { + mariaDbServer = closeAfterClass(new TestingMariaDbServer(dockerImageName)); + + return createMariaDbQueryRunner( + mariaDbServer, + Map.of(), + Map.of("case-insensitive-name-matching", "true"), + List.of(ORDERS)); + } + + @Test + @Override + public void testNotAnalyzed() + { + String tableName = "test_not_analyzed_" + randomNameSuffix(); + computeActual(format("CREATE TABLE %s AS SELECT * FROM tpch.tiny.orders", tableName)); + try { + MaterializedResult statsResult = computeActual("SHOW STATS FOR " + tableName); + Double cardinality = getTableCardinalityFromStats(statsResult); + + if (cardinality != null) { + // TABLE_ROWS in INFORMATION_SCHEMA.TABLES can be estimated as a very small number + assertThat(cardinality).isBetween(1d, 15000 * 1.5); + } + + assertColumnStats(statsResult, new MapBuilder() + .put("orderkey", null) + .put("custkey", null) + .put("orderstatus", null) + .put("totalprice", null) + .put("orderdate", null) + .put("orderpriority", null) + .put("clerk", null) + .put("shippriority", null) + .put("comment", null) + .build()); + } + finally { + assertUpdate("DROP TABLE " + tableName); + } + } + + @Test + @Override + public void testBasic() + { + String tableName = "test_stats_orders_" + randomNameSuffix(); + computeActual(format("CREATE TABLE %s AS SELECT * FROM tpch.tiny.orders", tableName)); + try { + gatherStats(tableName); + MaterializedResult statsResult = computeActual("SHOW STATS FOR " + tableName); + assertColumnStats(statsResult, new MapBuilder() + .put("orderkey", 15000) + .put("custkey", 1000) + .put("orderstatus", varcharNdvToExpected.apply(3)) + .put("totalprice", 14996) + .put("orderdate", 2401) + .put("orderpriority", varcharNdvToExpected.apply(5)) + .put("clerk", varcharNdvToExpected.apply(1000)) + .put("shippriority", 1) + .put("comment", varcharNdvToExpected.apply(14995)) + .build()); + assertThat(getTableCardinalityFromStats(statsResult)).isCloseTo(15000, withinPercentage(20)); + } + finally { + assertUpdate("DROP TABLE " + tableName); + } + } + + @Test + @Override + public void testAllNulls() + { + String tableName = "test_stats_table_all_nulls_" + randomNameSuffix(); + computeActual(format("CREATE TABLE %s AS SELECT orderkey, custkey, orderpriority, comment FROM tpch.tiny.orders WHERE false", tableName)); + try { + computeActual(format("INSERT INTO %s (orderkey) VALUES NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL", tableName)); + gatherStats(tableName); + MaterializedResult statsResult = computeActual("SHOW STATS FOR " + tableName); + for (MaterializedRow row : statsResult) { + String columnName = (String) row.getField(0); + if (columnName == null) { + // table summary row + return; + } + assertThat(columnName).isIn("orderkey", "custkey", "orderpriority", "comment"); + + Double dataSize = (Double) row.getField(1); + if (dataSize != null) { + assertThat(dataSize).as("Data size for " + columnName) + .isEqualTo(0); + } + + if ((columnName.equals("orderpriority") || columnName.equals("comment")) && varcharNdvToExpected.apply(2) == null) { + assertNull(row.getField(2), "NDV for " + columnName); + assertNull(row.getField(3), "null fraction for " + columnName); + } + else { + assertNotNull(row.getField(2), "NDV for " + columnName); + assertThat((Double) row.getField(2)).as("NDV for " + columnName).isBetween(0.0, 2.0); + assertEquals(row.getField(3), nullFractionToExpected.apply(1.0), "null fraction for " + columnName); + } + + assertNull(row.getField(4), "min"); + assertNull(row.getField(5), "max"); + } + double cardinality = getTableCardinalityFromStats(statsResult); + if (cardinality != 15.0) { + // sometimes all-NULLs tables are reported as containing 0-2 rows + assertThat(cardinality).isBetween(0.0, 2.0); + } + } + finally { + assertUpdate("DROP TABLE " + tableName); + } + } + + @Test + @Override + public void testNullsFraction() + { + String tableName = "test_stats_table_with_nulls_" + randomNameSuffix(); + assertUpdate("" + + "CREATE TABLE " + tableName + " AS " + + "SELECT " + + " orderkey, " + + " if(orderkey % 3 = 0, NULL, custkey) custkey, " + + " if(orderkey % 5 = 0, NULL, orderpriority) orderpriority " + + "FROM tpch.tiny.orders", + 15000); + try { + gatherStats(tableName); + MaterializedResult statsResult = computeActual("SHOW STATS FOR " + tableName); + assertColumnStats( + statsResult, + new MapBuilder() + .put("orderkey", 15000) + .put("custkey", 1000) + .put("orderpriority", varcharNdvToExpected.apply(5)) + .build(), + new MapBuilder() + .put("orderkey", nullFractionToExpected.apply(0.0)) + .put("custkey", nullFractionToExpected.apply(1.0 / 3)) + .put("orderpriority", nullFractionToExpected.apply(1.0 / 5)) + .build()); + assertThat(getTableCardinalityFromStats(statsResult)).isCloseTo(15000, withinPercentage(20)); + } + finally { + assertUpdate("DROP TABLE " + tableName); + } + } + + @Test + @Override + public void testAverageColumnLength() + { + throw new SkipException("MariaDB connector does not report average column length"); + } + + @Test + @Override + public void testPartitionedTable() + { + throw new SkipException("Not implemented"); // TODO + } + + @Test + @Override + public void testView() + { + String tableName = "test_stats_view_" + randomNameSuffix(); + executeInMariaDb("CREATE OR REPLACE VIEW " + tableName + " AS SELECT orderkey, custkey, orderpriority, comment FROM orders"); + try { + assertQuery( + "SHOW STATS FOR " + tableName, + "VALUES " + + "('orderkey', null, null, null, null, null, null)," + + "('custkey', null, null, null, null, null, null)," + + "('orderpriority', null, null, null, null, null, null)," + + "('comment', null, null, null, null, null, null)," + + "(null, null, null, null, null, null, null)"); + // It's not possible to ANALYZE a VIEW in MariaDB + } + finally { + executeInMariaDb("DROP VIEW " + tableName); + } + } + + @Test + @Override + public void testMaterializedView() + { + throw new SkipException(""); // TODO is there a concept like materialized view in MariaDB? + } + + @Test(dataProvider = "testCaseColumnNamesDataProvider") + @Override + public void testCaseColumnNames(String tableName) + { + executeInMariaDb(("" + + "CREATE TABLE " + tableName + " " + + "AS SELECT " + + " orderkey AS CASE_UNQUOTED_UPPER, " + + " custkey AS case_unquoted_lower, " + + " orderstatus AS cASe_uNQuoTeD_miXED, " + + " totalprice AS \"CASE_QUOTED_UPPER\", " + + " orderdate AS \"case_quoted_lower\"," + + " orderpriority AS \"CasE_QuoTeD_miXED\" " + + "FROM orders") + .replace("\"", "`")); + try { + gatherStats(tableName); + MaterializedResult statsResult = computeActual("SHOW STATS FOR " + tableName); + assertColumnStats(statsResult, new MapBuilder() + .put("case_unquoted_upper", 15000) + .put("case_unquoted_lower", 1000) + .put("case_unquoted_mixed", varcharNdvToExpected.apply(3)) + .put("case_quoted_upper", 14996) + .put("case_quoted_lower", 2401) + .put("case_quoted_mixed", varcharNdvToExpected.apply(5)) + .build()); + assertThat(getTableCardinalityFromStats(statsResult)).isCloseTo(15000, withinPercentage(20)); + } + finally { + executeInMariaDb("DROP TABLE " + tableName.replace("\"", "`")); + } + } + + @Test + @Override + public void testNumericCornerCases() + { + try (TestTable table = fromColumns( + getQueryRunner()::execute, + "test_numeric_corner_cases_", + ImmutableMap.>builder() + // TODO Infinity and NaNs not supported by MySQL. Are they not supported in MariaDB as well? +// .put("only_negative_infinity double", List.of("-infinity()", "-infinity()", "-infinity()", "-infinity()")) +// .put("only_positive_infinity double", List.of("infinity()", "infinity()", "infinity()", "infinity()")) +// .put("mixed_infinities double", List.of("-infinity()", "infinity()", "-infinity()", "infinity()")) +// .put("mixed_infinities_and_numbers double", List.of("-infinity()", "infinity()", "-5.0", "7.0")) +// .put("nans_only double", List.of("nan()", "nan()")) +// .put("nans_and_numbers double", List.of("nan()", "nan()", "-5.0", "7.0")) + .put("large_doubles double", List.of("CAST(-50371909150609548946090.0 AS DOUBLE)", "CAST(50371909150609548946090.0 AS DOUBLE)")) // 2^77 DIV 3 + .put("short_decimals_big_fraction decimal(16,15)", List.of("-1.234567890123456", "1.234567890123456")) + .put("short_decimals_big_integral decimal(16,1)", List.of("-123456789012345.6", "123456789012345.6")) + // DECIMALS up to precision 30 are supported + .put("long_decimals_big_fraction decimal(30,29)", List.of("-1.23456789012345678901234567890", "1.23456789012345678901234567890")) + .put("long_decimals_middle decimal(30,16)", List.of("-12345678901234.5678901234567890", "12345678901234.5678901234567890")) + .put("long_decimals_big_integral decimal(30,1)", List.of("-12345678901234567890123456789.0", "12345678901234567890123456789.0")) + .buildOrThrow(), + "null")) { + gatherStats(table.getName()); + assertQuery( + "SHOW STATS FOR " + table.getName(), + "VALUES " + + // TODO Infinity and NaNs not supported by MySQL. Are they not supported in MariaDB as well? +// "('only_negative_infinity', null, 1, 0, null, null, null)," + +// "('only_positive_infinity', null, 1, 0, null, null, null)," + +// "('mixed_infinities', null, 2, 0, null, null, null)," + +// "('mixed_infinities_and_numbers', null, 4.0, 0.0, null, null, null)," + +// "('nans_only', null, 1.0, 0.5, null, null, null)," + +// "('nans_and_numbers', null, 3.0, 0.0, null, null, null)," + + "('large_doubles', null, 2.0, 0.0, null, null, null)," + + "('short_decimals_big_fraction', null, 2.0, 0.0, null, null, null)," + + "('short_decimals_big_integral', null, 2.0, 0.0, null, null, null)," + + "('long_decimals_big_fraction', null, 2.0, 0.0, null, null, null)," + + "('long_decimals_middle', null, 2.0, 0.0, null, null, null)," + + "('long_decimals_big_integral', null, 2.0, 0.0, null, null, null)," + + "(null, null, null, null, 2, null, null)"); + } + } + + protected void executeInMariaDb(String sql) + { + mariaDbServer.execute(sql); + } + + protected void assertColumnStats(MaterializedResult statsResult, Map columnNdvs) + { + assertColumnStats(statsResult, columnNdvs, nullFractionToExpected.apply(0.0)); + } + + protected void assertColumnStats(MaterializedResult statsResult, Map columnNdvs, double nullFraction) + { + Map columnNullFractions = new HashMap<>(); + columnNdvs.forEach((columnName, ndv) -> columnNullFractions.put(columnName, ndv == null ? null : nullFraction)); + + assertColumnStats(statsResult, columnNdvs, columnNullFractions); + } + + protected void assertColumnStats(MaterializedResult statsResult, Map columnNdvs, Map columnNullFractions) + { + assertEquals(columnNdvs.keySet(), columnNullFractions.keySet()); + List reportedColumns = stream(statsResult) + .map(row -> row.getField(0)) // column name + .filter(Objects::nonNull) + .map(String.class::cast) + .collect(toImmutableList()); + assertThat(reportedColumns) + .containsOnlyOnce(columnNdvs.keySet().toArray(new String[0])); + + Double tableCardinality = getTableCardinalityFromStats(statsResult); + for (MaterializedRow row : statsResult) { + if (row.getField(0) == null) { + continue; + } + String columnName = (String) row.getField(0); + verify(columnNdvs.containsKey(columnName)); + Integer expectedNdv = columnNdvs.get(columnName); + verify(columnNullFractions.containsKey(columnName)); + Double expectedNullFraction = columnNullFractions.get(columnName); + + Double dataSize = (Double) row.getField(1); + if (dataSize != null) { + assertThat(dataSize).as("Data size for " + columnName) + .isEqualTo(0); + } + + Double distinctCount = (Double) row.getField(2); + Double nullsFraction = (Double) row.getField(3); + AbstractDoubleAssert ndvAssertion = assertThat(distinctCount).as("NDV for " + columnName); + if (expectedNdv == null) { + ndvAssertion.isNull(); + assertNull(nullsFraction, "null fraction for " + columnName); + } + else { + ndvAssertion.isBetween(expectedNdv * 0.5, min(expectedNdv * 4.0, tableCardinality)); // [-50%, +300%] but no more than row count + AbstractDoubleAssert nullsAssertion = assertThat(nullsFraction).as("Null fraction for " + columnName); + if (distinctCount.compareTo(tableCardinality) >= 0) { + nullsAssertion.isEqualTo(0); + } + else { + double maxNullsFraction = (tableCardinality - distinctCount) / tableCardinality; + expectedNullFraction = Math.min(expectedNullFraction, maxNullsFraction); + nullsAssertion.isBetween(expectedNullFraction * 0.4, expectedNullFraction * 1.1); + } + } + + assertNull(row.getField(4), "min"); + assertNull(row.getField(5), "max"); + } + } + + protected static Double getTableCardinalityFromStats(MaterializedResult statsResult) + { + MaterializedRow lastRow = statsResult.getMaterializedRows().get(statsResult.getRowCount() - 1); + assertNull(lastRow.getField(0)); + assertNull(lastRow.getField(1)); + assertNull(lastRow.getField(2)); + assertNull(lastRow.getField(3)); + assertNull(lastRow.getField(5)); + assertNull(lastRow.getField(6)); + assertEquals(lastRow.getFieldCount(), 7); + return ((Double) lastRow.getField(4)); + } + + protected static class MapBuilder + { + private final Map map = new HashMap<>(); + + public MapBuilder put(K key, V value) + { + checkArgument(!map.containsKey(key), "Key already present: %s", key); + map.put(requireNonNull(key, "key is null"), value); + return this; + } + + public Map build() + { + return new HashMap<>(map); + } + } +} diff --git a/plugin/trino-mariadb/src/test/java/io/trino/plugin/mariadb/TestMariaDbClient.java b/plugin/trino-mariadb/src/test/java/io/trino/plugin/mariadb/TestMariaDbClient.java index 92969e2c7d53..b184d7a8008b 100644 --- a/plugin/trino-mariadb/src/test/java/io/trino/plugin/mariadb/TestMariaDbClient.java +++ b/plugin/trino-mariadb/src/test/java/io/trino/plugin/mariadb/TestMariaDbClient.java @@ -20,6 +20,7 @@ import io.trino.plugin.jdbc.JdbcClient; import io.trino.plugin.jdbc.JdbcColumnHandle; import io.trino.plugin.jdbc.JdbcExpression; +import io.trino.plugin.jdbc.JdbcStatisticsConfig; import io.trino.plugin.jdbc.JdbcTypeHandle; import io.trino.plugin.jdbc.logging.RemoteQueryModifier; import io.trino.spi.connector.AggregateFunction; @@ -59,6 +60,7 @@ public class TestMariaDbClient private static final JdbcClient JDBC_CLIENT = new MariaDbClient( new BaseJdbcConfig(), + new JdbcStatisticsConfig(), session -> { throw new UnsupportedOperationException(); }, diff --git a/plugin/trino-mariadb/src/test/java/io/trino/plugin/mariadb/TestMariaDbTableIndexStatistics.java b/plugin/trino-mariadb/src/test/java/io/trino/plugin/mariadb/TestMariaDbTableIndexStatistics.java new file mode 100644 index 000000000000..d61c36705f5c --- /dev/null +++ b/plugin/trino-mariadb/src/test/java/io/trino/plugin/mariadb/TestMariaDbTableIndexStatistics.java @@ -0,0 +1,23 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.mariadb; + +public class TestMariaDbTableIndexStatistics + extends BaseMariaDbTableIndexStatisticsTest +{ + public TestMariaDbTableIndexStatistics() + { + super(TestingMariaDbServer.DEFAULT_VERSION); + } +} diff --git a/plugin/trino-mariadb/src/test/java/io/trino/plugin/mariadb/TestMariaDbTableIndexStatisticsLatest.java b/plugin/trino-mariadb/src/test/java/io/trino/plugin/mariadb/TestMariaDbTableIndexStatisticsLatest.java new file mode 100644 index 000000000000..3ebd8caff96e --- /dev/null +++ b/plugin/trino-mariadb/src/test/java/io/trino/plugin/mariadb/TestMariaDbTableIndexStatisticsLatest.java @@ -0,0 +1,23 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.mariadb; + +public class TestMariaDbTableIndexStatisticsLatest + extends BaseMariaDbTableIndexStatisticsTest +{ + public TestMariaDbTableIndexStatisticsLatest() + { + super(TestingMariaDbServer.LATEST_VERSION); + } +} diff --git a/plugin/trino-mariadb/src/test/java/io/trino/plugin/mariadb/TestingMariaDbServer.java b/plugin/trino-mariadb/src/test/java/io/trino/plugin/mariadb/TestingMariaDbServer.java index 7c6c80cf80a5..c9796befa167 100644 --- a/plugin/trino-mariadb/src/test/java/io/trino/plugin/mariadb/TestingMariaDbServer.java +++ b/plugin/trino-mariadb/src/test/java/io/trino/plugin/mariadb/TestingMariaDbServer.java @@ -46,17 +46,19 @@ public TestingMariaDbServer(String tag) // explicit-defaults-for-timestamp: 1 is ON, the default set is 0 (OFF) container.withCommand("--character-set-server", "utf8mb4", "--explicit-defaults-for-timestamp=1"); container.start(); - execute(format("GRANT ALL PRIVILEGES ON *.* TO '%s'", container.getUsername()), "root", container.getPassword()); - } - public void execute(String sql) - { - execute(sql, getUsername(), getPassword()); + try (Connection connection = DriverManager.getConnection(getJdbcUrl(), "root", container.getPassword()); + Statement statement = connection.createStatement()) { + statement.execute(format("GRANT ALL PRIVILEGES ON *.* TO '%s'", container.getUsername())); + } + catch (SQLException e) { + throw new RuntimeException(e); + } } - private void execute(String sql, String user, String password) + public void execute(String sql) { - try (Connection connection = DriverManager.getConnection(getJdbcUrl(), user, password); + try (Connection connection = container.createConnection(""); Statement statement = connection.createStatement()) { statement.execute(sql); } diff --git a/plugin/trino-mysql/src/test/java/io/trino/plugin/mysql/BaseTestMySqlTableStatisticsTest.java b/plugin/trino-mysql/src/test/java/io/trino/plugin/mysql/BaseTestMySqlTableStatisticsTest.java index 1d24da6fb2f6..3dce3fcd9108 100644 --- a/plugin/trino-mysql/src/test/java/io/trino/plugin/mysql/BaseTestMySqlTableStatisticsTest.java +++ b/plugin/trino-mysql/src/test/java/io/trino/plugin/mysql/BaseTestMySqlTableStatisticsTest.java @@ -20,8 +20,6 @@ import io.trino.testing.QueryRunner; import io.trino.testing.sql.TestTable; import org.assertj.core.api.AbstractDoubleAssert; -import org.jdbi.v3.core.Handle; -import org.jdbi.v3.core.Jdbi; import org.testng.SkipException; import org.testng.annotations.Test; @@ -170,7 +168,7 @@ public void testAllNulls() } else { assertNotNull(row.getField(2), "NDV for " + columnName); - assertThat(((Number) row.getField(2)).doubleValue()).as("NDV for " + columnName).isBetween(0.0, 2.0); + assertThat((Double) row.getField(2)).as("NDV for " + columnName).isBetween(0.0, 2.0); assertEquals(row.getField(3), nullFractionToExpected.apply(1.0), "null fraction for " + columnName); } @@ -347,10 +345,7 @@ public void testNumericCornerCases() protected void executeInMysql(String sql) { - try (Handle handle = Jdbi.open(() -> mysqlServer.createConnection())) { - handle.execute("USE tpch"); - handle.execute(sql); - } + mysqlServer.execute(sql); } protected void assertColumnStats(MaterializedResult statsResult, Map columnNdvs) @@ -430,7 +425,7 @@ protected static double getTableCardinalityFromStats(MaterializedResult statsRes assertNull(lastRow.getField(6)); assertEquals(lastRow.getFieldCount(), 7); assertNotNull(lastRow.getField(4)); - return ((Number) lastRow.getField(4)).doubleValue(); + return (Double) lastRow.getField(4); } protected static class MapBuilder diff --git a/plugin/trino-mysql/src/test/java/io/trino/plugin/mysql/TestMySqlAutomaticJoinPushdown.java b/plugin/trino-mysql/src/test/java/io/trino/plugin/mysql/TestMySqlAutomaticJoinPushdown.java index e579751c97cd..60a625d9949d 100644 --- a/plugin/trino-mysql/src/test/java/io/trino/plugin/mysql/TestMySqlAutomaticJoinPushdown.java +++ b/plugin/trino-mysql/src/test/java/io/trino/plugin/mysql/TestMySqlAutomaticJoinPushdown.java @@ -18,8 +18,6 @@ import io.trino.plugin.jdbc.BaseAutomaticJoinPushdownTest; import io.trino.testing.MaterializedRow; import io.trino.testing.QueryRunner; -import org.jdbi.v3.core.Handle; -import org.jdbi.v3.core.Jdbi; import org.junit.jupiter.api.Test; import static io.trino.plugin.mysql.MySqlQueryRunner.createMySqlQueryRunner; @@ -73,9 +71,6 @@ protected void gatherStats(String tableName) protected void onRemoteDatabase(String sql) { - try (Handle handle = Jdbi.open(() -> mySqlServer.createConnection())) { - handle.execute("USE tpch"); - handle.execute(sql); - } + mySqlServer.execute(sql); } } diff --git a/plugin/trino-mysql/src/test/java/io/trino/plugin/mysql/TestingMySqlServer.java b/plugin/trino-mysql/src/test/java/io/trino/plugin/mysql/TestingMySqlServer.java index 48aec3b77fdf..af3adaa9ed7b 100644 --- a/plugin/trino-mysql/src/test/java/io/trino/plugin/mysql/TestingMySqlServer.java +++ b/plugin/trino-mysql/src/test/java/io/trino/plugin/mysql/TestingMySqlServer.java @@ -70,7 +70,14 @@ public TestingMySqlServer(String dockerImageName, boolean globalTransactionEnabl this.container = container; configureContainer(container); cleanup = startOrReuse(container); - execute(format("GRANT ALL PRIVILEGES ON *.* TO '%s'", container.getUsername()), "root", container.getPassword()); + + try (Connection connection = DriverManager.getConnection(getJdbcUrl(), "root", container.getPassword()); + Statement statement = connection.createStatement()) { + statement.execute(format("GRANT ALL PRIVILEGES ON *.* TO '%s'", container.getUsername())); + } + catch (SQLException e) { + throw new RuntimeException(e); + } } private void configureContainer(MySQLContainer container) @@ -79,20 +86,9 @@ private void configureContainer(MySQLContainer container) container.addParameter("TC_MY_CNF", null); } - public Connection createConnection() - throws SQLException - { - return container.createConnection(""); - } - public void execute(String sql) { - execute(sql, getUsername(), getPassword()); - } - - public void execute(String sql, String user, String password) - { - try (Connection connection = DriverManager.getConnection(getJdbcUrl(), user, password); + try (Connection connection = createConnection(); Statement statement = connection.createStatement()) { statement.execute(sql); } @@ -101,6 +97,12 @@ public void execute(String sql, String user, String password) } } + public Connection createConnection() + throws SQLException + { + return container.createConnection(""); + } + public String getUsername() { return container.getUsername();