diff --git a/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/SqlServerClient.java b/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/SqlServerClient.java index 60088a667596..1f1d6f8efbe2 100644 --- a/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/SqlServerClient.java +++ b/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/SqlServerClient.java @@ -16,13 +16,11 @@ import com.google.common.base.Enums; import com.google.common.base.Joiner; import com.google.common.base.Throwables; -import com.google.common.cache.CacheBuilder; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import com.microsoft.sqlserver.jdbc.SQLServerException; import io.airlift.log.Logger; import io.airlift.slice.Slice; -import io.trino.collect.cache.NonEvictableCache; import io.trino.plugin.base.aggregation.AggregateFunctionRewriter; import io.trino.plugin.base.aggregation.AggregateFunctionRule; import io.trino.plugin.base.expression.ConnectorExpressionRewriter; @@ -33,9 +31,7 @@ import io.trino.plugin.jdbc.JdbcColumnHandle; import io.trino.plugin.jdbc.JdbcExpression; import io.trino.plugin.jdbc.JdbcJoinCondition; -import io.trino.plugin.jdbc.JdbcOutputTableHandle; import io.trino.plugin.jdbc.JdbcSortItem; -import io.trino.plugin.jdbc.JdbcSplit; import io.trino.plugin.jdbc.JdbcStatisticsConfig; import io.trino.plugin.jdbc.JdbcTableHandle; import io.trino.plugin.jdbc.JdbcTypeHandle; @@ -102,7 +98,6 @@ import java.util.List; import java.util.Map; import java.util.Optional; -import java.util.concurrent.ExecutionException; import java.util.function.BiFunction; import java.util.stream.Stream; @@ -111,9 +106,7 @@ import static com.google.common.base.Throwables.throwIfInstanceOf; import static com.google.common.collect.MoreCollectors.toOptional; import static com.google.common.util.concurrent.MoreExecutors.directExecutor; -import static com.microsoft.sqlserver.jdbc.SQLServerConnection.TRANSACTION_SNAPSHOT; import static io.airlift.slice.Slices.wrappedBuffer; -import static io.trino.collect.cache.SafeCaches.buildNonEvictableCache; import static io.trino.plugin.jdbc.JdbcErrorCode.JDBC_ERROR; import static io.trino.plugin.jdbc.JdbcJoinPushdownUtil.implementJoinCostAware; import static io.trino.plugin.jdbc.PredicatePushdownController.DISABLE_PUSHDOWN; @@ -147,7 +140,6 @@ import static io.trino.plugin.jdbc.UnsupportedTypeHandling.CONVERT_TO_VARCHAR; import static io.trino.plugin.sqlserver.SqlServerTableProperties.DATA_COMPRESSION; import static io.trino.plugin.sqlserver.SqlServerTableProperties.getDataCompression; -import static io.trino.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; import static io.trino.spi.StandardErrorCode.NOT_SUPPORTED; import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.BooleanType.BOOLEAN; @@ -181,7 +173,6 @@ import static java.lang.String.format; import static java.lang.String.join; import static java.math.RoundingMode.UNNECESSARY; -import static java.time.Duration.ofMinutes; import static java.util.Objects.requireNonNull; import static java.util.stream.Collectors.joining; @@ -199,12 +190,6 @@ public class SqlServerClient private static final Joiner DOT_JOINER = Joiner.on("."); - private final boolean snapshotIsolationDisabled; - private final NonEvictableCache snapshotIsolationEnabled = buildNonEvictableCache( - CacheBuilder.newBuilder() - .maximumSize(1) - .expireAfterWrite(ofMinutes(5))); - private final boolean statisticsEnabled; private final ConnectorExpressionRewriter connectorExpressionRewriter; @@ -215,7 +200,6 @@ public class SqlServerClient @Inject public SqlServerClient( BaseJdbcConfig config, - SqlServerConfig sqlServerConfig, JdbcStatisticsConfig statisticsConfig, ConnectionFactory connectionFactory, QueryBuilder queryBuilder, @@ -223,8 +207,6 @@ public SqlServerClient( { super(config, "\"", connectionFactory, queryBuilder, identifierMapping); - requireNonNull(sqlServerConfig, "sqlServerConfig is null"); - snapshotIsolationDisabled = sqlServerConfig.isSnapshotIsolationDisabled(); this.statisticsEnabled = requireNonNull(statisticsConfig, "statisticsConfig is null").isEnabled(); this.connectorExpressionRewriter = JdbcConnectorExpressionRewriterBuilder.newBuilder() @@ -844,7 +826,7 @@ public Map getTableProperties(ConnectorSession session, JdbcTabl if (!tableHandle.isNamedRelation()) { return ImmutableMap.of(); } - try (Connection connection = configureConnectionTransactionIsolation(connectionFactory.openConnection(session)); + try (Connection connection = connectionFactory.openConnection(session); Handle handle = Jdbi.open(connection)) { return getTableDataCompressionWithRetries(handle, tableHandle) .map(dataCompression -> ImmutableMap.of(DATA_COMPRESSION, dataCompression)) @@ -866,58 +848,6 @@ public void abortReadConnection(Connection connection, ResultSet resultSet) } } - @Override - public Connection getConnection(ConnectorSession session, JdbcOutputTableHandle handle) - throws SQLException - { - return configureConnectionTransactionIsolation(super.getConnection(session, handle)); - } - - @Override - public Connection getConnection(ConnectorSession session, JdbcSplit split) - throws SQLException - { - return configureConnectionTransactionIsolation(super.getConnection(session, split)); - } - - private Connection configureConnectionTransactionIsolation(Connection connection) - throws SQLException - { - if (snapshotIsolationDisabled) { - return connection; - } - try { - if (hasSnapshotIsolationEnabled(connection)) { - // SQL Server's READ COMMITTED + SNAPSHOT ISOLATION is equivalent to ordinary READ COMMITTED in e.g. Oracle, PostgreSQL. - connection.setTransactionIsolation(TRANSACTION_SNAPSHOT); - } - } - catch (SQLException e) { - connection.close(); - throw e; - } - - return connection; - } - - private boolean hasSnapshotIsolationEnabled(Connection connection) - throws SQLException - { - try { - return snapshotIsolationEnabled.get(SnapshotIsolationEnabledCacheKey.INSTANCE, () -> { - Handle handle = Jdbi.open(connection); - return handle.createQuery("SELECT snapshot_isolation_state FROM sys.databases WHERE name = :name") - .bind("name", connection.getCatalog()) - .mapTo(Boolean.class) - .findOne() - .orElse(false); - }); - } - catch (ExecutionException e) { - throw new TrinoException(GENERIC_INTERNAL_ERROR, e); - } - } - private static String singleQuote(String... objects) { return singleQuote(DOT_JOINER.join(objects)); diff --git a/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/SqlServerClientModule.java b/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/SqlServerClientModule.java index 5b3a93eb0881..36845ae87abf 100644 --- a/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/SqlServerClientModule.java +++ b/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/SqlServerClientModule.java @@ -52,8 +52,11 @@ protected void setup(Binder binder) @Provides @Singleton @ForBaseJdbc - public ConnectionFactory getConnectionFactory(BaseJdbcConfig config, CredentialProvider credentialProvider) + public static ConnectionFactory getConnectionFactory( + BaseJdbcConfig config, + SqlServerConfig sqlServerConfig, + CredentialProvider credentialProvider) { - return new DriverConnectionFactory(new SQLServerDriver(), config, credentialProvider); + return new SqlServerConnectionFactory(new DriverConnectionFactory(new SQLServerDriver(), config, credentialProvider), sqlServerConfig.isSnapshotIsolationDisabled()); } } diff --git a/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/SqlServerConnectionFactory.java b/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/SqlServerConnectionFactory.java new file mode 100644 index 000000000000..f098bb284e3f --- /dev/null +++ b/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/SqlServerConnectionFactory.java @@ -0,0 +1,116 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.sqlserver; + +import com.google.common.cache.CacheBuilder; +import io.trino.collect.cache.NonEvictableCache; +import io.trino.plugin.jdbc.ConnectionFactory; +import io.trino.spi.TrinoException; +import io.trino.spi.connector.ConnectorSession; +import org.jdbi.v3.core.Handle; +import org.jdbi.v3.core.Jdbi; + +import java.sql.Connection; +import java.sql.SQLException; +import java.util.concurrent.ExecutionException; + +import static com.microsoft.sqlserver.jdbc.ISQLServerConnection.TRANSACTION_SNAPSHOT; +import static io.trino.collect.cache.SafeCaches.buildNonEvictableCache; +import static io.trino.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; +import static java.time.Duration.ofMinutes; +import static java.util.Objects.requireNonNull; + +public class SqlServerConnectionFactory + implements ConnectionFactory +{ + private final NonEvictableCache snapshotIsolationEnabled = buildNonEvictableCache( + CacheBuilder.newBuilder() + .maximumSize(1) + .expireAfterWrite(ofMinutes(5))); + + private final ConnectionFactory delegate; + private final boolean snapshotIsolationDisabled; + + public SqlServerConnectionFactory(ConnectionFactory delegate, boolean snapshotIsolationDisabled) + { + this.delegate = requireNonNull(delegate, "delegate is null"); + this.snapshotIsolationDisabled = snapshotIsolationDisabled; + } + + @Override + public Connection openConnection(ConnectorSession session) + throws SQLException + { + Connection connection = delegate.openConnection(session); + try { + prepare(connection); + } + catch (SQLException e) { + try (Connection ignored = connection) { + throw e; + } + } + return connection; + } + + private void prepare(Connection connection) + throws SQLException + { + if (snapshotIsolationDisabled) { + return; + } + try { + if (hasSnapshotIsolationEnabled(connection)) { + // SQL Server's READ COMMITTED + SNAPSHOT ISOLATION is equivalent to ordinary READ COMMITTED in e.g. Oracle, PostgreSQL. + connection.setTransactionIsolation(TRANSACTION_SNAPSHOT); + } + } + catch (SQLException e) { + connection.close(); + throw e; + } + } + + private boolean hasSnapshotIsolationEnabled(Connection connection) + throws SQLException + { + try { + return snapshotIsolationEnabled.get(SnapshotIsolationEnabledCacheKey.INSTANCE, () -> { + Handle handle = Jdbi.open(connection); + return handle.createQuery("SELECT snapshot_isolation_state FROM sys.databases WHERE name = :name") + .bind("name", connection.getCatalog()) + .mapTo(Boolean.class) + .findOne() + .orElse(false); + }); + } + catch (ExecutionException e) { + throw new TrinoException(GENERIC_INTERNAL_ERROR, e); + } + } + + @Override + public void close() + throws SQLException + { + delegate.close(); + } + + private enum SnapshotIsolationEnabledCacheKey + { + // The snapshot isolation can be enabled or disabled on database level. We connect to single + // database, so from our perspective, this is a global property. + INSTANCE + } +} diff --git a/plugin/trino-sqlserver/src/test/java/io/trino/plugin/sqlserver/TestSqlServerClient.java b/plugin/trino-sqlserver/src/test/java/io/trino/plugin/sqlserver/TestSqlServerClient.java index ac9ecb52b445..be18dc5faeff 100644 --- a/plugin/trino-sqlserver/src/test/java/io/trino/plugin/sqlserver/TestSqlServerClient.java +++ b/plugin/trino-sqlserver/src/test/java/io/trino/plugin/sqlserver/TestSqlServerClient.java @@ -59,7 +59,6 @@ public class TestSqlServerClient private static final JdbcClient JDBC_CLIENT = new SqlServerClient( new BaseJdbcConfig(), - new SqlServerConfig(), new JdbcStatisticsConfig(), session -> { throw new UnsupportedOperationException();