diff --git a/plugin/trino-mariadb/pom.xml b/plugin/trino-mariadb/pom.xml
index 2486e2b90d37..606437f9ed4a 100644
--- a/plugin/trino-mariadb/pom.xml
+++ b/plugin/trino-mariadb/pom.xml
@@ -33,6 +33,11 @@
configuration
+
+ io.airlift
+ log
+
+
io.trino
trino-base-jdbc
@@ -48,6 +53,11 @@
jakarta.validation-api
+
+ org.jdbi
+ jdbi3-core
+
+
org.mariadb.jdbc
mariadb-java-client
@@ -89,12 +99,6 @@
provided
-
- io.airlift
- log
- runtime
-
-
io.airlift
log-manager
diff --git a/plugin/trino-mariadb/src/main/java/io/trino/plugin/mariadb/MariaDbClient.java b/plugin/trino-mariadb/src/main/java/io/trino/plugin/mariadb/MariaDbClient.java
index 91054052e412..23dbb45e3310 100644
--- a/plugin/trino-mariadb/src/main/java/io/trino/plugin/mariadb/MariaDbClient.java
+++ b/plugin/trino-mariadb/src/main/java/io/trino/plugin/mariadb/MariaDbClient.java
@@ -13,8 +13,10 @@
*/
package io.trino.plugin.mariadb;
+import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.inject.Inject;
+import io.airlift.log.Logger;
import io.trino.plugin.base.aggregation.AggregateFunctionRewriter;
import io.trino.plugin.base.aggregation.AggregateFunctionRule;
import io.trino.plugin.base.expression.ConnectorExpressionRewriter;
@@ -27,6 +29,7 @@
import io.trino.plugin.jdbc.JdbcExpression;
import io.trino.plugin.jdbc.JdbcJoinCondition;
import io.trino.plugin.jdbc.JdbcSortItem;
+import io.trino.plugin.jdbc.JdbcStatisticsConfig;
import io.trino.plugin.jdbc.JdbcTableHandle;
import io.trino.plugin.jdbc.JdbcTypeHandle;
import io.trino.plugin.jdbc.LongWriteFunction;
@@ -56,6 +59,9 @@
import io.trino.spi.connector.JoinStatistics;
import io.trino.spi.connector.JoinType;
import io.trino.spi.connector.SchemaTableName;
+import io.trino.spi.statistics.ColumnStatistics;
+import io.trino.spi.statistics.Estimate;
+import io.trino.spi.statistics.TableStatistics;
import io.trino.spi.type.CharType;
import io.trino.spi.type.DecimalType;
import io.trino.spi.type.Decimals;
@@ -63,6 +69,8 @@
import io.trino.spi.type.TimestampType;
import io.trino.spi.type.Type;
import io.trino.spi.type.VarcharType;
+import org.jdbi.v3.core.Handle;
+import org.jdbi.v3.core.Jdbi;
import java.sql.Connection;
import java.sql.DatabaseMetaData;
@@ -75,13 +83,18 @@
import java.util.Collection;
import java.util.List;
import java.util.Map;
+import java.util.Map.Entry;
import java.util.Optional;
import java.util.function.BiFunction;
import java.util.stream.Stream;
+import static com.google.common.base.MoreObjects.firstNonNull;
import static com.google.common.base.Preconditions.checkArgument;
+import static com.google.common.base.Preconditions.checkState;
import static com.google.common.base.Strings.emptyToNull;
+import static com.google.common.base.Throwables.throwIfInstanceOf;
import static com.google.common.base.Verify.verify;
+import static com.google.common.collect.ImmutableMap.toImmutableMap;
import static io.trino.plugin.jdbc.DecimalConfig.DecimalMapping.ALLOW_OVERFLOW;
import static io.trino.plugin.jdbc.DecimalSessionSessionProperties.getDecimalDefaultScale;
import static io.trino.plugin.jdbc.DecimalSessionSessionProperties.getDecimalRounding;
@@ -137,12 +150,15 @@
import static java.lang.Math.min;
import static java.lang.String.format;
import static java.lang.String.join;
+import static java.util.Map.entry;
import static java.util.Objects.requireNonNull;
import static java.util.stream.Collectors.joining;
public class MariaDbClient
extends BaseJdbcClient
{
+ private static final Logger log = Logger.get(MariaDbClient.class);
+
private static final int MAX_SUPPORTED_DATE_TIME_PRECISION = 6;
// MariaDB driver returns width of time types instead of precision.
private static final int ZERO_PRECISION_TIME_COLUMN_SIZE = 10;
@@ -156,10 +172,17 @@ public class MariaDbClient
// MariaDB Error Codes https://mariadb.com/kb/en/mariadb-error-codes/
private static final int PARSE_ERROR = 1064;
+ private final boolean statisticsEnabled;
private final AggregateFunctionRewriter aggregateFunctionRewriter;
@Inject
- public MariaDbClient(BaseJdbcConfig config, ConnectionFactory connectionFactory, QueryBuilder queryBuilder, IdentifierMapping identifierMapping, RemoteQueryModifier queryModifier)
+ public MariaDbClient(
+ BaseJdbcConfig config,
+ JdbcStatisticsConfig statisticsConfig,
+ ConnectionFactory connectionFactory,
+ QueryBuilder queryBuilder,
+ IdentifierMapping identifierMapping,
+ RemoteQueryModifier queryModifier)
{
super("`", connectionFactory, queryBuilder, config.getJdbcTypesMappedToVarchar(), identifierMapping, queryModifier, false);
@@ -167,6 +190,7 @@ public MariaDbClient(BaseJdbcConfig config, ConnectionFactory connectionFactory,
ConnectorExpressionRewriter connectorExpressionRewriter = JdbcConnectorExpressionRewriterBuilder.newBuilder()
.addStandardRules(this::quoted)
.build();
+ this.statisticsEnabled = statisticsConfig.isEnabled();
this.aggregateFunctionRewriter = new AggregateFunctionRewriter<>(
connectorExpressionRewriter,
ImmutableSet.>builder()
@@ -623,6 +647,102 @@ protected boolean isSupportedJoinCondition(ConnectorSession session, JdbcJoinCon
.noneMatch(type -> type instanceof CharType || type instanceof VarcharType);
}
+ @Override
+ public TableStatistics getTableStatistics(ConnectorSession session, JdbcTableHandle handle)
+ {
+ if (!statisticsEnabled) {
+ return TableStatistics.empty();
+ }
+ if (!handle.isNamedRelation()) {
+ return TableStatistics.empty();
+ }
+ try {
+ return readTableStatistics(session, handle);
+ }
+ catch (SQLException | RuntimeException e) {
+ throwIfInstanceOf(e, TrinoException.class);
+ throw new TrinoException(JDBC_ERROR, "Failed fetching statistics for table: " + handle, e);
+ }
+ }
+
+ private TableStatistics readTableStatistics(ConnectorSession session, JdbcTableHandle table)
+ throws SQLException
+ {
+ checkArgument(table.isNamedRelation(), "Relation is not a table: %s", table);
+
+ log.debug("Reading statistics for %s", table);
+ try (Connection connection = connectionFactory.openConnection(session);
+ Handle handle = Jdbi.open(connection)) {
+ StatisticsDao statisticsDao = new StatisticsDao(handle);
+
+ Long rowCount = statisticsDao.getTableRowCount(table);
+ Long indexMaxCardinality = statisticsDao.getTableMaxColumnIndexCardinality(table);
+ log.debug("Estimated row count of table %s is %s, and max index cardinality is %s", table, rowCount, indexMaxCardinality);
+
+ if (rowCount != null && rowCount == 0) {
+ // MariaDB may report 0 row count until a table is analyzed for the first time.
+ rowCount = null;
+ }
+
+ if (rowCount == null && indexMaxCardinality == null) {
+ // Table not found, or is a view, or has no usable statistics
+ return TableStatistics.empty();
+ }
+ rowCount = max(firstNonNull(rowCount, 0L), firstNonNull(indexMaxCardinality, 0L));
+
+ TableStatistics.Builder tableStatistics = TableStatistics.builder();
+ tableStatistics.setRowCount(Estimate.of(rowCount));
+
+ // TODO statistics from ANALYZE TABLE (https://mariadb.com/kb/en/engine-independent-table-statistics/)
+ // Map columnStatistics = statisticsDao.getColumnStatistics(table);
+ Map columnStatistics = ImmutableMap.of();
+
+ // TODO add support for histograms https://mariadb.com/kb/en/histogram-based-statistics/
+
+ // statistics based on existing indexes
+ Map columnStatisticsFromIndexes = statisticsDao.getColumnIndexStatistics(table);
+
+ if (columnStatistics.isEmpty() && columnStatisticsFromIndexes.isEmpty()) {
+ log.debug("No column and index statistics read");
+ // No more information to work on
+ return tableStatistics.build();
+ }
+
+ for (JdbcColumnHandle column : getColumns(session, table)) {
+ ColumnStatistics.Builder columnStatisticsBuilder = ColumnStatistics.builder();
+
+ String columnName = column.getColumnName();
+ AnalyzeColumnStatistics analyzeColumnStatistics = columnStatistics.get(columnName);
+ if (analyzeColumnStatistics != null) {
+ log.debug("Reading column statistics for %s, %s from analayze's column statistics: %s", table, columnName, analyzeColumnStatistics);
+ columnStatisticsBuilder.setNullsFraction(Estimate.of(analyzeColumnStatistics.nullsRatio()));
+ }
+
+ ColumnIndexStatistics columnIndexStatistics = columnStatisticsFromIndexes.get(columnName);
+ if (columnIndexStatistics != null) {
+ log.debug("Reading column statistics for %s, %s from index statistics: %s", table, columnName, columnIndexStatistics);
+ columnStatisticsBuilder.setDistinctValuesCount(Estimate.of(columnIndexStatistics.cardinality()));
+
+ if (!columnIndexStatistics.nullable()) {
+ double knownNullFraction = columnStatisticsBuilder.build().getNullsFraction().getValue();
+ if (knownNullFraction > 0) {
+ log.warn("Inconsistent statistics, null fraction for a column %s, %s, that is not nullable according to index statistics: %s", table, columnName, knownNullFraction);
+ }
+ columnStatisticsBuilder.setNullsFraction(Estimate.zero());
+ }
+
+ // row count from INFORMATION_SCHEMA.TABLES may be very inaccurate
+ rowCount = max(rowCount, columnIndexStatistics.cardinality());
+ }
+
+ tableStatistics.setColumnStatistics(column, columnStatisticsBuilder.build());
+ }
+
+ tableStatistics.setRowCount(Estimate.of(rowCount));
+ return tableStatistics.build();
+ }
+ }
+
private static LongWriteFunction dateWriteFunction()
{
return (statement, index, day) -> statement.setString(index, DATE_FORMATTER.format(LocalDate.ofEpochDay(day)));
@@ -650,4 +770,101 @@ private static Optional getUnsignedMapping(JdbcTypeHandle typeHan
return Optional.empty();
}
+
+ private static class StatisticsDao
+ {
+ private final Handle handle;
+
+ public StatisticsDao(Handle handle)
+ {
+ this.handle = requireNonNull(handle, "handle is null");
+ }
+
+ Long getTableRowCount(JdbcTableHandle table)
+ {
+ RemoteTableName remoteTableName = table.getRequiredNamedRelation().getRemoteTableName();
+ return handle.createQuery("""
+ SELECT TABLE_ROWS FROM INFORMATION_SCHEMA.TABLES
+ WHERE TABLE_SCHEMA = :schema AND TABLE_NAME = :table_name
+ AND TABLE_TYPE = 'BASE TABLE'
+ """)
+ .bind("schema", remoteTableName.getCatalogName().orElse(null))
+ .bind("table_name", remoteTableName.getTableName())
+ .mapTo(Long.class)
+ .findOne()
+ .orElse(null);
+ }
+
+ Long getTableMaxColumnIndexCardinality(JdbcTableHandle table)
+ {
+ RemoteTableName remoteTableName = table.getRequiredNamedRelation().getRemoteTableName();
+ return handle.createQuery("""
+ SELECT max(CARDINALITY) AS row_count FROM INFORMATION_SCHEMA.STATISTICS
+ WHERE TABLE_SCHEMA = :schema AND TABLE_NAME = :table_name
+ """)
+ .bind("schema", remoteTableName.getCatalogName().orElse(null))
+ .bind("table_name", remoteTableName.getTableName())
+ .mapTo(Long.class)
+ .findOne()
+ .orElse(null);
+ }
+
+ Map getColumnStatistics(JdbcTableHandle table)
+ {
+ RemoteTableName remoteTableName = table.getRequiredNamedRelation().getRemoteTableName();
+ return handle.createQuery("""
+ SELECT
+ column_name,
+ -- TODO min_value, max_value,
+ nulls_ratio
+ FROM mysql.column_stats
+ WHERE db_name = :database AND TABLE_NAME = :table_name
+ AND nulls_ratio IS NOT NULL
+ """)
+ .bind("database", remoteTableName.getCatalogName().orElse(null))
+ .bind("table_name", remoteTableName.getTableName())
+ .map((rs, ctx) -> {
+ String columnName = rs.getString("column_name");
+ double nullsRatio = rs.getDouble("nulls_ratio");
+ return entry(columnName, new AnalyzeColumnStatistics(nullsRatio));
+ })
+ .collect(toImmutableMap(Entry::getKey, Entry::getValue));
+ }
+
+ Map getColumnIndexStatistics(JdbcTableHandle table)
+ {
+ RemoteTableName remoteTableName = table.getRequiredNamedRelation().getRemoteTableName();
+ return handle.createQuery("""
+ SELECT
+ COLUMN_NAME,
+ MAX(NULLABLE) AS NULLABLE,
+ MAX(CARDINALITY) AS CARDINALITY
+ FROM INFORMATION_SCHEMA.STATISTICS
+ WHERE TABLE_SCHEMA = :schema AND TABLE_NAME = :table_name
+ AND SEQ_IN_INDEX = 1 -- first column in the index
+ AND SUB_PART IS NULL -- ignore cases where only a column prefix is indexed
+ AND CARDINALITY IS NOT NULL -- CARDINALITY might be null (https://stackoverflow.com/a/42242729/65458)
+ AND CARDINALITY != 0 -- CARDINALITY is initially 0 until analyzed
+ GROUP BY COLUMN_NAME -- there might be multiple indexes on a column
+ """)
+ .bind("schema", remoteTableName.getCatalogName().orElse(null))
+ .bind("table_name", remoteTableName.getTableName())
+ .map((rs, ctx) -> {
+ String columnName = rs.getString("COLUMN_NAME");
+
+ boolean nullable = rs.getString("NULLABLE").equalsIgnoreCase("YES");
+ checkState(!rs.wasNull(), "NULLABLE is null");
+
+ long cardinality = rs.getLong("CARDINALITY");
+ checkState(!rs.wasNull(), "CARDINALITY is null");
+
+ return entry(columnName, new ColumnIndexStatistics(nullable, cardinality));
+ })
+ .collect(toImmutableMap(Entry::getKey, Entry::getValue));
+ }
+ }
+
+ private record AnalyzeColumnStatistics(double nullsRatio) {}
+
+ private record ColumnIndexStatistics(boolean nullable, long cardinality) {}
}
diff --git a/plugin/trino-mariadb/src/main/java/io/trino/plugin/mariadb/MariaDbClientModule.java b/plugin/trino-mariadb/src/main/java/io/trino/plugin/mariadb/MariaDbClientModule.java
index c516a8ee9195..e06913b26e39 100644
--- a/plugin/trino-mariadb/src/main/java/io/trino/plugin/mariadb/MariaDbClientModule.java
+++ b/plugin/trino-mariadb/src/main/java/io/trino/plugin/mariadb/MariaDbClientModule.java
@@ -25,6 +25,7 @@
import io.trino.plugin.jdbc.DriverConnectionFactory;
import io.trino.plugin.jdbc.ForBaseJdbc;
import io.trino.plugin.jdbc.JdbcClient;
+import io.trino.plugin.jdbc.JdbcStatisticsConfig;
import io.trino.plugin.jdbc.credential.CredentialProvider;
import io.trino.plugin.jdbc.ptf.Query;
import io.trino.spi.function.table.ConnectorTableFunction;
@@ -43,6 +44,7 @@ public void configure(Binder binder)
{
binder.bind(JdbcClient.class).annotatedWith(ForBaseJdbc.class).to(MariaDbClient.class).in(Scopes.SINGLETON);
configBinder(binder).bindConfig(MariaDbJdbcConfig.class);
+ configBinder(binder).bindConfig(JdbcStatisticsConfig.class);
binder.install(new DecimalModule());
newSetBinder(binder, ConnectorTableFunction.class).addBinding().toProvider(Query.class).in(Scopes.SINGLETON);
}
diff --git a/plugin/trino-mariadb/src/test/java/io/trino/plugin/mariadb/BaseMariaDbTableIndexStatisticsTest.java b/plugin/trino-mariadb/src/test/java/io/trino/plugin/mariadb/BaseMariaDbTableIndexStatisticsTest.java
new file mode 100644
index 000000000000..b585ce4c59cf
--- /dev/null
+++ b/plugin/trino-mariadb/src/test/java/io/trino/plugin/mariadb/BaseMariaDbTableIndexStatisticsTest.java
@@ -0,0 +1,126 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.trino.plugin.mariadb;
+
+import io.trino.testing.MaterializedRow;
+import org.testng.SkipException;
+import org.testng.annotations.Test;
+
+import static java.lang.String.format;
+
+public abstract class BaseMariaDbTableIndexStatisticsTest
+ extends BaseMariaDbTableStatisticsTest
+{
+ protected BaseMariaDbTableIndexStatisticsTest(String dockerImageName)
+ {
+ super(
+ dockerImageName,
+ nullFraction -> 0.1, // Without mysql.column_stats we have no way of knowing real null fraction, 10% is just a "wild guess"
+ varcharNdv -> null); // Without mysql.column_stats we don't know cardinality for varchar columns
+ }
+
+ @Override
+ protected void gatherStats(String tableName)
+ {
+ for (MaterializedRow row : computeActual("SHOW COLUMNS FROM " + tableName)) {
+ String columnName = (String) row.getField(0);
+ String columnType = (String) row.getField(1);
+ if (columnType.startsWith("varchar")) {
+ continue;
+ }
+ executeInMariaDb(format("CREATE INDEX %2$s ON %1$s (%2$s)", tableName, columnName).replace("\"", "`"));
+ }
+ executeInMariaDb("ANALYZE TABLE " + tableName.replace("\"", "`"));
+ }
+
+ @Test
+ @Override
+ public void testStatsWithPredicatePushdownWithStatsPrecalculationDisabled()
+ {
+ // TODO (https://github.com/trinodb/trino/issues/11664) implement the test for MariaDB, with permissive approximate assertions
+ throw new SkipException("Test to be implemented");
+ }
+
+ @Test
+ @Override
+ public void testStatsWithPredicatePushdown()
+ {
+ // TODO (https://github.com/trinodb/trino/issues/11664) implement the test for MariaDB, with permissive approximate assertions
+ throw new SkipException("Test to be implemented");
+ }
+
+ @Test
+ @Override
+ public void testStatsWithVarcharPredicatePushdown()
+ {
+ // TODO (https://github.com/trinodb/trino/issues/11664) implement the test for MariaDB, with permissive approximate assertions
+ throw new SkipException("Test to be implemented");
+ }
+
+ @Test
+ @Override
+ public void testStatsWithLimitPushdown()
+ {
+ // TODO (https://github.com/trinodb/trino/issues/11664) implement the test for MariaDB, with permissive approximate assertions
+ throw new SkipException("Test to be implemented");
+ }
+
+ @Test
+ @Override
+ public void testStatsWithTopNPushdown()
+ {
+ // TODO (https://github.com/trinodb/trino/issues/11664) implement the test for MariaDB, with permissive approximate assertions
+ throw new SkipException("Test to be implemented");
+ }
+
+ @Test
+ @Override
+ public void testStatsWithDistinctPushdown()
+ {
+ // TODO (https://github.com/trinodb/trino/issues/11664) implement the test for MariaDB, with permissive approximate assertions
+ throw new SkipException("Test to be implemented");
+ }
+
+ @Test
+ @Override
+ public void testStatsWithDistinctLimitPushdown()
+ {
+ // TODO (https://github.com/trinodb/trino/issues/11664) implement the test for MariaDB, with permissive approximate assertions
+ throw new SkipException("Test to be implemented");
+ }
+
+ @Test
+ @Override
+ public void testStatsWithAggregationPushdown()
+ {
+ // TODO (https://github.com/trinodb/trino/issues/11664) implement the test for MariaDB, with permissive approximate assertions
+ throw new SkipException("Test to be implemented");
+ }
+
+ @Test
+ @Override
+ public void testStatsWithSimpleJoinPushdown()
+ {
+ // TODO (https://github.com/trinodb/trino/issues/11664) implement the test for MariaDB, with permissive approximate assertions
+ throw new SkipException("Test to be implemented");
+ }
+
+ @Test
+ @Override
+ public void testStatsWithJoinPushdown()
+ {
+ // TODO (https://github.com/trinodb/trino/issues/11664) implement the test for MariaDB, with permissive approximate assertions
+ throw new SkipException("Test to be implemented");
+ }
+}
diff --git a/plugin/trino-mariadb/src/test/java/io/trino/plugin/mariadb/BaseMariaDbTableStatisticsTest.java b/plugin/trino-mariadb/src/test/java/io/trino/plugin/mariadb/BaseMariaDbTableStatisticsTest.java
new file mode 100644
index 000000000000..45d230e6e1c6
--- /dev/null
+++ b/plugin/trino-mariadb/src/test/java/io/trino/plugin/mariadb/BaseMariaDbTableStatisticsTest.java
@@ -0,0 +1,443 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.trino.plugin.mariadb;
+
+import com.google.common.collect.ImmutableMap;
+import io.trino.plugin.jdbc.BaseJdbcTableStatisticsTest;
+import io.trino.testing.MaterializedResult;
+import io.trino.testing.MaterializedRow;
+import io.trino.testing.QueryRunner;
+import io.trino.testing.sql.TestTable;
+import org.assertj.core.api.AbstractDoubleAssert;
+import org.testng.SkipException;
+import org.testng.annotations.Test;
+
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.function.Function;
+
+import static com.google.common.base.Preconditions.checkArgument;
+import static com.google.common.base.Verify.verify;
+import static com.google.common.collect.ImmutableList.toImmutableList;
+import static com.google.common.collect.Streams.stream;
+import static io.trino.plugin.mariadb.MariaDbQueryRunner.createMariaDbQueryRunner;
+import static io.trino.testing.TestingNames.randomNameSuffix;
+import static io.trino.testing.sql.TestTable.fromColumns;
+import static io.trino.tpch.TpchTable.ORDERS;
+import static java.lang.Math.min;
+import static java.lang.String.format;
+import static java.util.Objects.requireNonNull;
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.withinPercentage;
+import static org.testng.Assert.assertEquals;
+import static org.testng.Assert.assertNotNull;
+import static org.testng.Assert.assertNull;
+
+public abstract class BaseMariaDbTableStatisticsTest
+ extends BaseJdbcTableStatisticsTest
+{
+ protected final String dockerImageName;
+ protected final Function nullFractionToExpected;
+ protected final Function varcharNdvToExpected;
+ protected TestingMariaDbServer mariaDbServer;
+
+ protected BaseMariaDbTableStatisticsTest(
+ String dockerImageName,
+ Function nullFractionToExpected,
+ Function varcharNdvToExpected)
+ {
+ this.dockerImageName = requireNonNull(dockerImageName, "dockerImageName is null");
+ this.nullFractionToExpected = requireNonNull(nullFractionToExpected, "nullFractionToExpected is null");
+ this.varcharNdvToExpected = requireNonNull(varcharNdvToExpected, "varcharNdvToExpected is null");
+ }
+
+ @Override
+ protected QueryRunner createQueryRunner()
+ throws Exception
+ {
+ mariaDbServer = closeAfterClass(new TestingMariaDbServer(dockerImageName));
+
+ return createMariaDbQueryRunner(
+ mariaDbServer,
+ Map.of(),
+ Map.of("case-insensitive-name-matching", "true"),
+ List.of(ORDERS));
+ }
+
+ @Test
+ @Override
+ public void testNotAnalyzed()
+ {
+ String tableName = "test_not_analyzed_" + randomNameSuffix();
+ computeActual(format("CREATE TABLE %s AS SELECT * FROM tpch.tiny.orders", tableName));
+ try {
+ MaterializedResult statsResult = computeActual("SHOW STATS FOR " + tableName);
+ Double cardinality = getTableCardinalityFromStats(statsResult);
+
+ if (cardinality != null) {
+ // TABLE_ROWS in INFORMATION_SCHEMA.TABLES can be estimated as a very small number
+ assertThat(cardinality).isBetween(1d, 15000 * 1.5);
+ }
+
+ assertColumnStats(statsResult, new MapBuilder()
+ .put("orderkey", null)
+ .put("custkey", null)
+ .put("orderstatus", null)
+ .put("totalprice", null)
+ .put("orderdate", null)
+ .put("orderpriority", null)
+ .put("clerk", null)
+ .put("shippriority", null)
+ .put("comment", null)
+ .build());
+ }
+ finally {
+ assertUpdate("DROP TABLE " + tableName);
+ }
+ }
+
+ @Test
+ @Override
+ public void testBasic()
+ {
+ String tableName = "test_stats_orders_" + randomNameSuffix();
+ computeActual(format("CREATE TABLE %s AS SELECT * FROM tpch.tiny.orders", tableName));
+ try {
+ gatherStats(tableName);
+ MaterializedResult statsResult = computeActual("SHOW STATS FOR " + tableName);
+ assertColumnStats(statsResult, new MapBuilder()
+ .put("orderkey", 15000)
+ .put("custkey", 1000)
+ .put("orderstatus", varcharNdvToExpected.apply(3))
+ .put("totalprice", 14996)
+ .put("orderdate", 2401)
+ .put("orderpriority", varcharNdvToExpected.apply(5))
+ .put("clerk", varcharNdvToExpected.apply(1000))
+ .put("shippriority", 1)
+ .put("comment", varcharNdvToExpected.apply(14995))
+ .build());
+ assertThat(getTableCardinalityFromStats(statsResult)).isCloseTo(15000, withinPercentage(20));
+ }
+ finally {
+ assertUpdate("DROP TABLE " + tableName);
+ }
+ }
+
+ @Test
+ @Override
+ public void testAllNulls()
+ {
+ String tableName = "test_stats_table_all_nulls_" + randomNameSuffix();
+ computeActual(format("CREATE TABLE %s AS SELECT orderkey, custkey, orderpriority, comment FROM tpch.tiny.orders WHERE false", tableName));
+ try {
+ computeActual(format("INSERT INTO %s (orderkey) VALUES NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL", tableName));
+ gatherStats(tableName);
+ MaterializedResult statsResult = computeActual("SHOW STATS FOR " + tableName);
+ for (MaterializedRow row : statsResult) {
+ String columnName = (String) row.getField(0);
+ if (columnName == null) {
+ // table summary row
+ return;
+ }
+ assertThat(columnName).isIn("orderkey", "custkey", "orderpriority", "comment");
+
+ Double dataSize = (Double) row.getField(1);
+ if (dataSize != null) {
+ assertThat(dataSize).as("Data size for " + columnName)
+ .isEqualTo(0);
+ }
+
+ if ((columnName.equals("orderpriority") || columnName.equals("comment")) && varcharNdvToExpected.apply(2) == null) {
+ assertNull(row.getField(2), "NDV for " + columnName);
+ assertNull(row.getField(3), "null fraction for " + columnName);
+ }
+ else {
+ assertNotNull(row.getField(2), "NDV for " + columnName);
+ assertThat((Double) row.getField(2)).as("NDV for " + columnName).isBetween(0.0, 2.0);
+ assertEquals(row.getField(3), nullFractionToExpected.apply(1.0), "null fraction for " + columnName);
+ }
+
+ assertNull(row.getField(4), "min");
+ assertNull(row.getField(5), "max");
+ }
+ double cardinality = getTableCardinalityFromStats(statsResult);
+ if (cardinality != 15.0) {
+ // sometimes all-NULLs tables are reported as containing 0-2 rows
+ assertThat(cardinality).isBetween(0.0, 2.0);
+ }
+ }
+ finally {
+ assertUpdate("DROP TABLE " + tableName);
+ }
+ }
+
+ @Test
+ @Override
+ public void testNullsFraction()
+ {
+ String tableName = "test_stats_table_with_nulls_" + randomNameSuffix();
+ assertUpdate("" +
+ "CREATE TABLE " + tableName + " AS " +
+ "SELECT " +
+ " orderkey, " +
+ " if(orderkey % 3 = 0, NULL, custkey) custkey, " +
+ " if(orderkey % 5 = 0, NULL, orderpriority) orderpriority " +
+ "FROM tpch.tiny.orders",
+ 15000);
+ try {
+ gatherStats(tableName);
+ MaterializedResult statsResult = computeActual("SHOW STATS FOR " + tableName);
+ assertColumnStats(
+ statsResult,
+ new MapBuilder()
+ .put("orderkey", 15000)
+ .put("custkey", 1000)
+ .put("orderpriority", varcharNdvToExpected.apply(5))
+ .build(),
+ new MapBuilder()
+ .put("orderkey", nullFractionToExpected.apply(0.0))
+ .put("custkey", nullFractionToExpected.apply(1.0 / 3))
+ .put("orderpriority", nullFractionToExpected.apply(1.0 / 5))
+ .build());
+ assertThat(getTableCardinalityFromStats(statsResult)).isCloseTo(15000, withinPercentage(20));
+ }
+ finally {
+ assertUpdate("DROP TABLE " + tableName);
+ }
+ }
+
+ @Test
+ @Override
+ public void testAverageColumnLength()
+ {
+ throw new SkipException("MariaDB connector does not report average column length");
+ }
+
+ @Test
+ @Override
+ public void testPartitionedTable()
+ {
+ throw new SkipException("Not implemented"); // TODO
+ }
+
+ @Test
+ @Override
+ public void testView()
+ {
+ String tableName = "test_stats_view_" + randomNameSuffix();
+ executeInMariaDb("CREATE OR REPLACE VIEW " + tableName + " AS SELECT orderkey, custkey, orderpriority, comment FROM orders");
+ try {
+ assertQuery(
+ "SHOW STATS FOR " + tableName,
+ "VALUES " +
+ "('orderkey', null, null, null, null, null, null)," +
+ "('custkey', null, null, null, null, null, null)," +
+ "('orderpriority', null, null, null, null, null, null)," +
+ "('comment', null, null, null, null, null, null)," +
+ "(null, null, null, null, null, null, null)");
+ // It's not possible to ANALYZE a VIEW in MariaDB
+ }
+ finally {
+ executeInMariaDb("DROP VIEW " + tableName);
+ }
+ }
+
+ @Test
+ @Override
+ public void testMaterializedView()
+ {
+ throw new SkipException(""); // TODO is there a concept like materialized view in MariaDB?
+ }
+
+ @Test(dataProvider = "testCaseColumnNamesDataProvider")
+ @Override
+ public void testCaseColumnNames(String tableName)
+ {
+ executeInMariaDb(("" +
+ "CREATE TABLE " + tableName + " " +
+ "AS SELECT " +
+ " orderkey AS CASE_UNQUOTED_UPPER, " +
+ " custkey AS case_unquoted_lower, " +
+ " orderstatus AS cASe_uNQuoTeD_miXED, " +
+ " totalprice AS \"CASE_QUOTED_UPPER\", " +
+ " orderdate AS \"case_quoted_lower\"," +
+ " orderpriority AS \"CasE_QuoTeD_miXED\" " +
+ "FROM orders")
+ .replace("\"", "`"));
+ try {
+ gatherStats(tableName);
+ MaterializedResult statsResult = computeActual("SHOW STATS FOR " + tableName);
+ assertColumnStats(statsResult, new MapBuilder()
+ .put("case_unquoted_upper", 15000)
+ .put("case_unquoted_lower", 1000)
+ .put("case_unquoted_mixed", varcharNdvToExpected.apply(3))
+ .put("case_quoted_upper", 14996)
+ .put("case_quoted_lower", 2401)
+ .put("case_quoted_mixed", varcharNdvToExpected.apply(5))
+ .build());
+ assertThat(getTableCardinalityFromStats(statsResult)).isCloseTo(15000, withinPercentage(20));
+ }
+ finally {
+ executeInMariaDb("DROP TABLE " + tableName.replace("\"", "`"));
+ }
+ }
+
+ @Test
+ @Override
+ public void testNumericCornerCases()
+ {
+ try (TestTable table = fromColumns(
+ getQueryRunner()::execute,
+ "test_numeric_corner_cases_",
+ ImmutableMap.>builder()
+ // TODO Infinity and NaNs not supported by MySQL. Are they not supported in MariaDB as well?
+// .put("only_negative_infinity double", List.of("-infinity()", "-infinity()", "-infinity()", "-infinity()"))
+// .put("only_positive_infinity double", List.of("infinity()", "infinity()", "infinity()", "infinity()"))
+// .put("mixed_infinities double", List.of("-infinity()", "infinity()", "-infinity()", "infinity()"))
+// .put("mixed_infinities_and_numbers double", List.of("-infinity()", "infinity()", "-5.0", "7.0"))
+// .put("nans_only double", List.of("nan()", "nan()"))
+// .put("nans_and_numbers double", List.of("nan()", "nan()", "-5.0", "7.0"))
+ .put("large_doubles double", List.of("CAST(-50371909150609548946090.0 AS DOUBLE)", "CAST(50371909150609548946090.0 AS DOUBLE)")) // 2^77 DIV 3
+ .put("short_decimals_big_fraction decimal(16,15)", List.of("-1.234567890123456", "1.234567890123456"))
+ .put("short_decimals_big_integral decimal(16,1)", List.of("-123456789012345.6", "123456789012345.6"))
+ // DECIMALS up to precision 30 are supported
+ .put("long_decimals_big_fraction decimal(30,29)", List.of("-1.23456789012345678901234567890", "1.23456789012345678901234567890"))
+ .put("long_decimals_middle decimal(30,16)", List.of("-12345678901234.5678901234567890", "12345678901234.5678901234567890"))
+ .put("long_decimals_big_integral decimal(30,1)", List.of("-12345678901234567890123456789.0", "12345678901234567890123456789.0"))
+ .buildOrThrow(),
+ "null")) {
+ gatherStats(table.getName());
+ assertQuery(
+ "SHOW STATS FOR " + table.getName(),
+ "VALUES " +
+ // TODO Infinity and NaNs not supported by MySQL. Are they not supported in MariaDB as well?
+// "('only_negative_infinity', null, 1, 0, null, null, null)," +
+// "('only_positive_infinity', null, 1, 0, null, null, null)," +
+// "('mixed_infinities', null, 2, 0, null, null, null)," +
+// "('mixed_infinities_and_numbers', null, 4.0, 0.0, null, null, null)," +
+// "('nans_only', null, 1.0, 0.5, null, null, null)," +
+// "('nans_and_numbers', null, 3.0, 0.0, null, null, null)," +
+ "('large_doubles', null, 2.0, 0.0, null, null, null)," +
+ "('short_decimals_big_fraction', null, 2.0, 0.0, null, null, null)," +
+ "('short_decimals_big_integral', null, 2.0, 0.0, null, null, null)," +
+ "('long_decimals_big_fraction', null, 2.0, 0.0, null, null, null)," +
+ "('long_decimals_middle', null, 2.0, 0.0, null, null, null)," +
+ "('long_decimals_big_integral', null, 2.0, 0.0, null, null, null)," +
+ "(null, null, null, null, 2, null, null)");
+ }
+ }
+
+ protected void executeInMariaDb(String sql)
+ {
+ mariaDbServer.execute(sql);
+ }
+
+ protected void assertColumnStats(MaterializedResult statsResult, Map columnNdvs)
+ {
+ assertColumnStats(statsResult, columnNdvs, nullFractionToExpected.apply(0.0));
+ }
+
+ protected void assertColumnStats(MaterializedResult statsResult, Map columnNdvs, double nullFraction)
+ {
+ Map columnNullFractions = new HashMap<>();
+ columnNdvs.forEach((columnName, ndv) -> columnNullFractions.put(columnName, ndv == null ? null : nullFraction));
+
+ assertColumnStats(statsResult, columnNdvs, columnNullFractions);
+ }
+
+ protected void assertColumnStats(MaterializedResult statsResult, Map columnNdvs, Map columnNullFractions)
+ {
+ assertEquals(columnNdvs.keySet(), columnNullFractions.keySet());
+ List reportedColumns = stream(statsResult)
+ .map(row -> row.getField(0)) // column name
+ .filter(Objects::nonNull)
+ .map(String.class::cast)
+ .collect(toImmutableList());
+ assertThat(reportedColumns)
+ .containsOnlyOnce(columnNdvs.keySet().toArray(new String[0]));
+
+ Double tableCardinality = getTableCardinalityFromStats(statsResult);
+ for (MaterializedRow row : statsResult) {
+ if (row.getField(0) == null) {
+ continue;
+ }
+ String columnName = (String) row.getField(0);
+ verify(columnNdvs.containsKey(columnName));
+ Integer expectedNdv = columnNdvs.get(columnName);
+ verify(columnNullFractions.containsKey(columnName));
+ Double expectedNullFraction = columnNullFractions.get(columnName);
+
+ Double dataSize = (Double) row.getField(1);
+ if (dataSize != null) {
+ assertThat(dataSize).as("Data size for " + columnName)
+ .isEqualTo(0);
+ }
+
+ Double distinctCount = (Double) row.getField(2);
+ Double nullsFraction = (Double) row.getField(3);
+ AbstractDoubleAssert> ndvAssertion = assertThat(distinctCount).as("NDV for " + columnName);
+ if (expectedNdv == null) {
+ ndvAssertion.isNull();
+ assertNull(nullsFraction, "null fraction for " + columnName);
+ }
+ else {
+ ndvAssertion.isBetween(expectedNdv * 0.5, min(expectedNdv * 4.0, tableCardinality)); // [-50%, +300%] but no more than row count
+ AbstractDoubleAssert> nullsAssertion = assertThat(nullsFraction).as("Null fraction for " + columnName);
+ if (distinctCount.compareTo(tableCardinality) >= 0) {
+ nullsAssertion.isEqualTo(0);
+ }
+ else {
+ double maxNullsFraction = (tableCardinality - distinctCount) / tableCardinality;
+ expectedNullFraction = Math.min(expectedNullFraction, maxNullsFraction);
+ nullsAssertion.isBetween(expectedNullFraction * 0.4, expectedNullFraction * 1.1);
+ }
+ }
+
+ assertNull(row.getField(4), "min");
+ assertNull(row.getField(5), "max");
+ }
+ }
+
+ protected static Double getTableCardinalityFromStats(MaterializedResult statsResult)
+ {
+ MaterializedRow lastRow = statsResult.getMaterializedRows().get(statsResult.getRowCount() - 1);
+ assertNull(lastRow.getField(0));
+ assertNull(lastRow.getField(1));
+ assertNull(lastRow.getField(2));
+ assertNull(lastRow.getField(3));
+ assertNull(lastRow.getField(5));
+ assertNull(lastRow.getField(6));
+ assertEquals(lastRow.getFieldCount(), 7);
+ return ((Double) lastRow.getField(4));
+ }
+
+ protected static class MapBuilder
+ {
+ private final Map map = new HashMap<>();
+
+ public MapBuilder put(K key, V value)
+ {
+ checkArgument(!map.containsKey(key), "Key already present: %s", key);
+ map.put(requireNonNull(key, "key is null"), value);
+ return this;
+ }
+
+ public Map build()
+ {
+ return new HashMap<>(map);
+ }
+ }
+}
diff --git a/plugin/trino-mariadb/src/test/java/io/trino/plugin/mariadb/TestMariaDbClient.java b/plugin/trino-mariadb/src/test/java/io/trino/plugin/mariadb/TestMariaDbClient.java
index 92969e2c7d53..b184d7a8008b 100644
--- a/plugin/trino-mariadb/src/test/java/io/trino/plugin/mariadb/TestMariaDbClient.java
+++ b/plugin/trino-mariadb/src/test/java/io/trino/plugin/mariadb/TestMariaDbClient.java
@@ -20,6 +20,7 @@
import io.trino.plugin.jdbc.JdbcClient;
import io.trino.plugin.jdbc.JdbcColumnHandle;
import io.trino.plugin.jdbc.JdbcExpression;
+import io.trino.plugin.jdbc.JdbcStatisticsConfig;
import io.trino.plugin.jdbc.JdbcTypeHandle;
import io.trino.plugin.jdbc.logging.RemoteQueryModifier;
import io.trino.spi.connector.AggregateFunction;
@@ -59,6 +60,7 @@ public class TestMariaDbClient
private static final JdbcClient JDBC_CLIENT = new MariaDbClient(
new BaseJdbcConfig(),
+ new JdbcStatisticsConfig(),
session -> {
throw new UnsupportedOperationException();
},
diff --git a/plugin/trino-mariadb/src/test/java/io/trino/plugin/mariadb/TestMariaDbTableIndexStatistics.java b/plugin/trino-mariadb/src/test/java/io/trino/plugin/mariadb/TestMariaDbTableIndexStatistics.java
new file mode 100644
index 000000000000..d61c36705f5c
--- /dev/null
+++ b/plugin/trino-mariadb/src/test/java/io/trino/plugin/mariadb/TestMariaDbTableIndexStatistics.java
@@ -0,0 +1,23 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.trino.plugin.mariadb;
+
+public class TestMariaDbTableIndexStatistics
+ extends BaseMariaDbTableIndexStatisticsTest
+{
+ public TestMariaDbTableIndexStatistics()
+ {
+ super(TestingMariaDbServer.DEFAULT_VERSION);
+ }
+}
diff --git a/plugin/trino-mariadb/src/test/java/io/trino/plugin/mariadb/TestMariaDbTableIndexStatisticsLatest.java b/plugin/trino-mariadb/src/test/java/io/trino/plugin/mariadb/TestMariaDbTableIndexStatisticsLatest.java
new file mode 100644
index 000000000000..3ebd8caff96e
--- /dev/null
+++ b/plugin/trino-mariadb/src/test/java/io/trino/plugin/mariadb/TestMariaDbTableIndexStatisticsLatest.java
@@ -0,0 +1,23 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.trino.plugin.mariadb;
+
+public class TestMariaDbTableIndexStatisticsLatest
+ extends BaseMariaDbTableIndexStatisticsTest
+{
+ public TestMariaDbTableIndexStatisticsLatest()
+ {
+ super(TestingMariaDbServer.LATEST_VERSION);
+ }
+}
diff --git a/plugin/trino-mariadb/src/test/java/io/trino/plugin/mariadb/TestingMariaDbServer.java b/plugin/trino-mariadb/src/test/java/io/trino/plugin/mariadb/TestingMariaDbServer.java
index 7c6c80cf80a5..c9796befa167 100644
--- a/plugin/trino-mariadb/src/test/java/io/trino/plugin/mariadb/TestingMariaDbServer.java
+++ b/plugin/trino-mariadb/src/test/java/io/trino/plugin/mariadb/TestingMariaDbServer.java
@@ -46,17 +46,19 @@ public TestingMariaDbServer(String tag)
// explicit-defaults-for-timestamp: 1 is ON, the default set is 0 (OFF)
container.withCommand("--character-set-server", "utf8mb4", "--explicit-defaults-for-timestamp=1");
container.start();
- execute(format("GRANT ALL PRIVILEGES ON *.* TO '%s'", container.getUsername()), "root", container.getPassword());
- }
- public void execute(String sql)
- {
- execute(sql, getUsername(), getPassword());
+ try (Connection connection = DriverManager.getConnection(getJdbcUrl(), "root", container.getPassword());
+ Statement statement = connection.createStatement()) {
+ statement.execute(format("GRANT ALL PRIVILEGES ON *.* TO '%s'", container.getUsername()));
+ }
+ catch (SQLException e) {
+ throw new RuntimeException(e);
+ }
}
- private void execute(String sql, String user, String password)
+ public void execute(String sql)
{
- try (Connection connection = DriverManager.getConnection(getJdbcUrl(), user, password);
+ try (Connection connection = container.createConnection("");
Statement statement = connection.createStatement()) {
statement.execute(sql);
}
diff --git a/plugin/trino-mysql/src/test/java/io/trino/plugin/mysql/BaseTestMySqlTableStatisticsTest.java b/plugin/trino-mysql/src/test/java/io/trino/plugin/mysql/BaseTestMySqlTableStatisticsTest.java
index 1d24da6fb2f6..3dce3fcd9108 100644
--- a/plugin/trino-mysql/src/test/java/io/trino/plugin/mysql/BaseTestMySqlTableStatisticsTest.java
+++ b/plugin/trino-mysql/src/test/java/io/trino/plugin/mysql/BaseTestMySqlTableStatisticsTest.java
@@ -20,8 +20,6 @@
import io.trino.testing.QueryRunner;
import io.trino.testing.sql.TestTable;
import org.assertj.core.api.AbstractDoubleAssert;
-import org.jdbi.v3.core.Handle;
-import org.jdbi.v3.core.Jdbi;
import org.testng.SkipException;
import org.testng.annotations.Test;
@@ -170,7 +168,7 @@ public void testAllNulls()
}
else {
assertNotNull(row.getField(2), "NDV for " + columnName);
- assertThat(((Number) row.getField(2)).doubleValue()).as("NDV for " + columnName).isBetween(0.0, 2.0);
+ assertThat((Double) row.getField(2)).as("NDV for " + columnName).isBetween(0.0, 2.0);
assertEquals(row.getField(3), nullFractionToExpected.apply(1.0), "null fraction for " + columnName);
}
@@ -347,10 +345,7 @@ public void testNumericCornerCases()
protected void executeInMysql(String sql)
{
- try (Handle handle = Jdbi.open(() -> mysqlServer.createConnection())) {
- handle.execute("USE tpch");
- handle.execute(sql);
- }
+ mysqlServer.execute(sql);
}
protected void assertColumnStats(MaterializedResult statsResult, Map columnNdvs)
@@ -430,7 +425,7 @@ protected static double getTableCardinalityFromStats(MaterializedResult statsRes
assertNull(lastRow.getField(6));
assertEquals(lastRow.getFieldCount(), 7);
assertNotNull(lastRow.getField(4));
- return ((Number) lastRow.getField(4)).doubleValue();
+ return (Double) lastRow.getField(4);
}
protected static class MapBuilder
diff --git a/plugin/trino-mysql/src/test/java/io/trino/plugin/mysql/TestMySqlAutomaticJoinPushdown.java b/plugin/trino-mysql/src/test/java/io/trino/plugin/mysql/TestMySqlAutomaticJoinPushdown.java
index e579751c97cd..60a625d9949d 100644
--- a/plugin/trino-mysql/src/test/java/io/trino/plugin/mysql/TestMySqlAutomaticJoinPushdown.java
+++ b/plugin/trino-mysql/src/test/java/io/trino/plugin/mysql/TestMySqlAutomaticJoinPushdown.java
@@ -18,8 +18,6 @@
import io.trino.plugin.jdbc.BaseAutomaticJoinPushdownTest;
import io.trino.testing.MaterializedRow;
import io.trino.testing.QueryRunner;
-import org.jdbi.v3.core.Handle;
-import org.jdbi.v3.core.Jdbi;
import org.junit.jupiter.api.Test;
import static io.trino.plugin.mysql.MySqlQueryRunner.createMySqlQueryRunner;
@@ -73,9 +71,6 @@ protected void gatherStats(String tableName)
protected void onRemoteDatabase(String sql)
{
- try (Handle handle = Jdbi.open(() -> mySqlServer.createConnection())) {
- handle.execute("USE tpch");
- handle.execute(sql);
- }
+ mySqlServer.execute(sql);
}
}
diff --git a/plugin/trino-mysql/src/test/java/io/trino/plugin/mysql/TestingMySqlServer.java b/plugin/trino-mysql/src/test/java/io/trino/plugin/mysql/TestingMySqlServer.java
index 48aec3b77fdf..af3adaa9ed7b 100644
--- a/plugin/trino-mysql/src/test/java/io/trino/plugin/mysql/TestingMySqlServer.java
+++ b/plugin/trino-mysql/src/test/java/io/trino/plugin/mysql/TestingMySqlServer.java
@@ -70,7 +70,14 @@ public TestingMySqlServer(String dockerImageName, boolean globalTransactionEnabl
this.container = container;
configureContainer(container);
cleanup = startOrReuse(container);
- execute(format("GRANT ALL PRIVILEGES ON *.* TO '%s'", container.getUsername()), "root", container.getPassword());
+
+ try (Connection connection = DriverManager.getConnection(getJdbcUrl(), "root", container.getPassword());
+ Statement statement = connection.createStatement()) {
+ statement.execute(format("GRANT ALL PRIVILEGES ON *.* TO '%s'", container.getUsername()));
+ }
+ catch (SQLException e) {
+ throw new RuntimeException(e);
+ }
}
private void configureContainer(MySQLContainer> container)
@@ -79,20 +86,9 @@ private void configureContainer(MySQLContainer> container)
container.addParameter("TC_MY_CNF", null);
}
- public Connection createConnection()
- throws SQLException
- {
- return container.createConnection("");
- }
-
public void execute(String sql)
{
- execute(sql, getUsername(), getPassword());
- }
-
- public void execute(String sql, String user, String password)
- {
- try (Connection connection = DriverManager.getConnection(getJdbcUrl(), user, password);
+ try (Connection connection = createConnection();
Statement statement = connection.createStatement()) {
statement.execute(sql);
}
@@ -101,6 +97,12 @@ public void execute(String sql, String user, String password)
}
}
+ public Connection createConnection()
+ throws SQLException
+ {
+ return container.createConnection("");
+ }
+
public String getUsername()
{
return container.getUsername();