From 7e01309d2e1facc11c5a37df388a3c76318f0c27 Mon Sep 17 00:00:00 2001 From: Song Gao <39278329+xsgao-github@users.noreply.github.com> Date: Wed, 13 Nov 2024 11:45:37 -0500 Subject: [PATCH] Add support for validating JDBC connections --- .../client/uri/ConnectionProperties.java | 11 + .../io/trino/client/uri/PropertyName.java | 3 +- .../java/io/trino/client/uri/TrinoUri.java | 11 + .../io/trino/client/uri/TestTrinoUri.java | 11 + .../java/io/trino/jdbc/TrinoConnection.java | 106 +++++++++- .../io/trino/jdbc/TestJdbcConnection.java | 196 +++++++++++++++++- .../java/io/trino/jdbc/TestTrinoDriver.java | 7 + .../io/trino/jdbc/TestTrinoDriverUri.java | 12 ++ .../dispatcher/QueuedStatementResource.java | 9 + docs/src/main/sphinx/client/jdbc.md | 4 +- 10 files changed, 365 insertions(+), 5 deletions(-) diff --git a/client/trino-client/src/main/java/io/trino/client/uri/ConnectionProperties.java b/client/trino-client/src/main/java/io/trino/client/uri/ConnectionProperties.java index cbabdc63864..aaa610c1210 100644 --- a/client/trino-client/src/main/java/io/trino/client/uri/ConnectionProperties.java +++ b/client/trino-client/src/main/java/io/trino/client/uri/ConnectionProperties.java @@ -115,6 +115,7 @@ enum SslVerificationMode public static final ConnectionProperty HTTP_LOGGING_LEVEL = new HttpLoggingLevel(); public static final ConnectionProperty> RESOURCE_ESTIMATES = new ResourceEstimates(); public static final ConnectionProperty> SQL_PATH = new SqlPath(); + public static final ConnectionProperty VALIDATE_CONNECTION = new ValidateConnection(); private static final Set> ALL_PROPERTIES = ImmutableSet.>builder() // Keep sorted @@ -172,6 +173,7 @@ enum SslVerificationMode .add(TIMEZONE) .add(TRACE_TOKEN) .add(USER) + .add(VALIDATE_CONNECTION) .build(); private static final Map> KEY_LOOKUP = unmodifiableMap(ALL_PROPERTIES.stream() @@ -590,6 +592,15 @@ public KerberosRemoteServiceName() } } + private static class ValidateConnection + extends AbstractConnectionProperty + { + public ValidateConnection() + { + super(PropertyName.VALIDATE_CONNECTION, NOT_REQUIRED, ALLOWED, BOOLEAN_CONVERTER); + } + } + private static Predicate isKerberosEnabled() { return properties -> KERBEROS_REMOTE_SERVICE_NAME.getValue(properties).isPresent(); diff --git a/client/trino-client/src/main/java/io/trino/client/uri/PropertyName.java b/client/trino-client/src/main/java/io/trino/client/uri/PropertyName.java index c22184ecf9c..b0f13b3ac09 100644 --- a/client/trino-client/src/main/java/io/trino/client/uri/PropertyName.java +++ b/client/trino-client/src/main/java/io/trino/client/uri/PropertyName.java @@ -75,7 +75,8 @@ public enum PropertyName TIMEOUT("timeout"), TIMEZONE("timezone"), TRACE_TOKEN("traceToken"), - USER("user"); + USER("user"), + VALIDATE_CONNECTION("validateConnection"); private final String key; diff --git a/client/trino-client/src/main/java/io/trino/client/uri/TrinoUri.java b/client/trino-client/src/main/java/io/trino/client/uri/TrinoUri.java index d844262cfa8..bbf6c6ee34d 100644 --- a/client/trino-client/src/main/java/io/trino/client/uri/TrinoUri.java +++ b/client/trino-client/src/main/java/io/trino/client/uri/TrinoUri.java @@ -98,6 +98,7 @@ import static io.trino.client.uri.ConnectionProperties.TIMEZONE; import static io.trino.client.uri.ConnectionProperties.TRACE_TOKEN; import static io.trino.client.uri.ConnectionProperties.USER; +import static io.trino.client.uri.ConnectionProperties.VALIDATE_CONNECTION; import static io.trino.client.uri.LoggingLevel.NONE; import static java.lang.String.CASE_INSENSITIVE_ORDER; import static java.lang.String.format; @@ -461,6 +462,11 @@ public LoggingLevel getHttpLoggingLevel() return resolveWithDefault(HTTP_LOGGING_LEVEL, NONE); } + public boolean isValidateConnection() + { + return resolveWithDefault(VALIDATE_CONNECTION, false); + } + private Map getResourceEstimates() { return resolveWithDefault(RESOURCE_ESTIMATES, ImmutableMap.of()); @@ -1047,6 +1053,11 @@ public Builder setPath(List path) return setProperty(SQL_PATH, requireNonNull(path, "path is null")); } + public Builder setValidateConnection(boolean value) + { + return setProperty(VALIDATE_CONNECTION, value); + } + Builder setProperty(ConnectionProperty connectionProperty, T value) { properties.put(connectionProperty.getKey(), connectionProperty.encodeValue(value)); diff --git a/client/trino-client/src/test/java/io/trino/client/uri/TestTrinoUri.java b/client/trino-client/src/test/java/io/trino/client/uri/TestTrinoUri.java index 41edea6c9f0..bb4e5a63cc1 100644 --- a/client/trino-client/src/test/java/io/trino/client/uri/TestTrinoUri.java +++ b/client/trino-client/src/test/java/io/trino/client/uri/TestTrinoUri.java @@ -495,6 +495,17 @@ public void testDefaultPorts() assertThat(secureUri.getHttpUri()).isEqualTo(URI.create("https://localhost:443")); } + @Test + public void testValidateConnection() + { + TrinoUri uri = createTrinoUri("trino://localhost:8080"); + assertThat(uri.isValidateConnection()).isFalse(); + uri = createTrinoUri("trino://localhost:8080?validateConnection=true"); + assertThat(uri.isValidateConnection()).isTrue(); + uri = createTrinoUri("trino://localhost:8080?validateConnection=false"); + assertThat(uri.isValidateConnection()).isFalse(); + } + private static boolean isBuilderHelperMethod(String name) { if (name.equals("setSslVerificationNone")) { diff --git a/client/trino-jdbc/src/main/java/io/trino/jdbc/TrinoConnection.java b/client/trino-jdbc/src/main/java/io/trino/jdbc/TrinoConnection.java index d2ddf7df4f4..101f066a62e 100644 --- a/client/trino-jdbc/src/main/java/io/trino/jdbc/TrinoConnection.java +++ b/client/trino-jdbc/src/main/java/io/trino/jdbc/TrinoConnection.java @@ -26,7 +26,13 @@ import io.trino.client.StatementClient; import jakarta.annotation.Nullable; import okhttp3.Call; +import okhttp3.HttpUrl; +import okhttp3.Request; +import okhttp3.Response; +import java.io.IOException; +import java.io.InterruptedIOException; +import java.net.ProtocolException; import java.net.URI; import java.nio.charset.CharsetEncoder; import java.sql.Array; @@ -56,6 +62,7 @@ import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.Executor; +import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicLong; @@ -66,13 +73,18 @@ import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; import static com.google.common.base.Strings.nullToEmpty; +import static com.google.common.base.Throwables.getCausalChain; import static com.google.common.collect.Maps.fromProperties; +import static io.airlift.units.Duration.nanosSince; import static io.trino.client.StatementClientFactory.newStatementClient; import static io.trino.jdbc.ClientInfoProperty.APPLICATION_NAME; import static io.trino.jdbc.ClientInfoProperty.CLIENT_INFO; import static io.trino.jdbc.ClientInfoProperty.CLIENT_TAGS; import static io.trino.jdbc.ClientInfoProperty.TRACE_TOKEN; import static java.lang.String.format; +import static java.net.HttpURLConnection.HTTP_BAD_METHOD; +import static java.net.HttpURLConnection.HTTP_OK; +import static java.net.HttpURLConnection.HTTP_UNAUTHORIZED; import static java.nio.charset.StandardCharsets.US_ASCII; import static java.util.Collections.newSetFromMap; import static java.util.Objects.requireNonNull; @@ -85,6 +97,8 @@ public class TrinoConnection { private static final Logger logger = Logger.getLogger(TrinoConnection.class.getPackage().getName()); + private static final int CONNECTION_TIMEOUT_SECONDS = 30; // Not configurable + private final AtomicBoolean closed = new AtomicBoolean(); private final AtomicBoolean autoCommit = new AtomicBoolean(true); private final AtomicInteger isolationLevel = new AtomicInteger(TRANSACTION_READ_UNCOMMITTED); @@ -119,8 +133,10 @@ public class TrinoConnection private final Set statements = newSetFromMap(new ConcurrentHashMap<>()); private boolean useExplicitPrepare = true; private boolean assumeNullCatalogMeansCurrentCatalog; + private final boolean validateConnection; TrinoConnection(TrinoDriverUri uri, Call.Factory httpCallFactory, Call.Factory segmentHttpCallFactory) + throws SQLException { requireNonNull(uri, "uri is null"); this.jdbcUri = uri.getUri(); @@ -156,6 +172,67 @@ public class TrinoConnection uri.getExplicitPrepare().ifPresent(value -> this.useExplicitPrepare = value); uri.getAssumeNullCatalogMeansCurrentCatalog().ifPresent(value -> this.assumeNullCatalogMeansCurrentCatalog = value); + + this.validateConnection = uri.isValidateConnection(); + if (validateConnection) { + try { + if (!isConnectionValid(CONNECTION_TIMEOUT_SECONDS)) { + throw new SQLException("Invalid authentication to Trino server", "28000"); + } + } + catch (UnsupportedOperationException | IOException e) { + throw new SQLException("Unable to connect to Trino server", "08001", e); + } + } + } + + private boolean isConnectionValid(int timeout) + throws IOException, UnsupportedOperationException + { + HttpUrl url = HttpUrl.get(httpUri) + .newBuilder() + .encodedPath("/v1/statement") + .build(); + + Request headRequest = new Request.Builder() + .url(url) + .head() + .build(); + + Exception lastException = null; + Duration timeoutDuration = new Duration(timeout, TimeUnit.SECONDS); + long start = System.nanoTime(); + + while (timeoutDuration.compareTo(nanosSince(start)) > 0) { + try (Response response = httpCallFactory.newCall(headRequest).execute()) { + switch (response.code()) { + case HTTP_OK: + return true; + case HTTP_UNAUTHORIZED: + return false; + case HTTP_BAD_METHOD: + throw new UnsupportedOperationException("Trino server does not support HEAD /v1/statement"); + } + + try { + MILLISECONDS.sleep(250); + } + catch (InterruptedException e) { + Thread.currentThread().interrupt(); + return false; + } + } + catch (IOException e) { + if (getCausalChain(e).stream().anyMatch(TrinoConnection::isTransientConnectionValidationException)) { + lastException = e; + } + else { + throw e; + } + } + } + + throw new IOException(format("Connection validation timed out after %ss", timeout), lastException); } @Override @@ -528,7 +605,26 @@ public boolean isValid(int timeout) if (timeout < 0) { throw new SQLException("Timeout is negative"); } - return !isClosed(); + + if (isClosed()) { + return false; + } + + if (!validateConnection) { + return true; + } + + try { + return isConnectionValid(timeout); + } + catch (UnsupportedOperationException e) { + logger.log(Level.FINE, "Trino server does not support connection validation", e); + return false; + } + catch (IOException e) { + logger.log(Level.FINE, "Connection validation has failed", e); + return false; + } } @Override @@ -901,6 +997,14 @@ else if (applicationName != null) { return source; } + private static boolean isTransientConnectionValidationException(Throwable e) + { + if (e instanceof InterruptedIOException && e.getMessage().equals("timeout")) { + return true; + } + return e instanceof ProtocolException; + } + private static final class SqlExceptionHolder { @Nullable diff --git a/client/trino-jdbc/src/test/java/io/trino/jdbc/TestJdbcConnection.java b/client/trino-jdbc/src/test/java/io/trino/jdbc/TestJdbcConnection.java index 4e97d030fee..6047c050895 100644 --- a/client/trino-jdbc/src/test/java/io/trino/jdbc/TestJdbcConnection.java +++ b/client/trino-jdbc/src/test/java/io/trino/jdbc/TestJdbcConnection.java @@ -23,6 +23,7 @@ import io.trino.client.ClientSelectedRole; import io.trino.plugin.blackhole.BlackHolePlugin; import io.trino.plugin.hive.HivePlugin; +import io.trino.server.security.PasswordAuthenticatorManager; import io.trino.server.testing.TestingTrinoServer; import io.trino.spi.connector.ConnectorSession; import io.trino.spi.connector.ConnectorTableMetadata; @@ -32,6 +33,8 @@ import io.trino.spi.connector.SchemaTableName; import io.trino.spi.connector.SystemTable; import io.trino.spi.predicate.TupleDomain; +import io.trino.spi.security.AccessDeniedException; +import io.trino.spi.security.BasicPrincipal; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Disabled; @@ -40,6 +43,12 @@ import org.junit.jupiter.api.Timeout; import org.junit.jupiter.api.parallel.Execution; +import java.io.File; +import java.net.URL; +import java.nio.file.Files; +import java.nio.file.Path; +import java.security.Key; +import java.security.Principal; import java.sql.Connection; import java.sql.DriverManager; import java.sql.PreparedStatement; @@ -47,28 +56,39 @@ import java.sql.SQLException; import java.sql.SQLFeatureNotSupportedException; import java.sql.Statement; +import java.util.Date; import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.Properties; import java.util.Set; import java.util.concurrent.ExecutorService; import java.util.concurrent.Future; import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.io.Files.asCharSource; +import static com.google.common.io.Resources.getResource; import static com.google.inject.multibindings.Multibinder.newSetBinder; import static io.airlift.concurrent.Threads.daemonThreadsNamed; import static io.airlift.testing.Closeables.closeAll; +import static io.jsonwebtoken.security.Keys.hmacShaKeyFor; import static io.trino.metadata.MetadataUtil.TableMetadataBuilder.tableMetadataBuilder; +import static io.trino.server.security.jwt.JwtUtil.newJwtBuilder; import static io.trino.spi.connector.SystemTable.Distribution.ALL_NODES; import static io.trino.spi.type.VarcharType.createUnboundedVarcharType; import static io.trino.testing.assertions.Assert.assertEventually; import static java.lang.String.format; +import static java.lang.System.currentTimeMillis; +import static java.nio.charset.StandardCharsets.US_ASCII; import static java.sql.Types.VARCHAR; +import static java.util.Base64.getMimeDecoder; import static java.util.UUID.randomUUID; import static java.util.concurrent.Executors.newCachedThreadPool; import static java.util.stream.IntStream.range; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatCode; import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.Assertions.catchThrowableOfType; import static org.assertj.core.api.Fail.fail; import static org.junit.jupiter.api.TestInstance.Lifecycle.PER_CLASS; import static org.junit.jupiter.api.parallel.ExecutionMode.CONCURRENT; @@ -77,18 +97,47 @@ @Execution(CONCURRENT) public class TestJdbcConnection { + private static final String TEST_USER = "admin"; + private static final String TEST_PASSWORD = "password"; + private final ExecutorService executor = newCachedThreadPool(daemonThreadsNamed(getClass().getName())); private TestingTrinoServer server; + private Key defaultKey; + private String sslTrustStorePath; @BeforeAll public void setupServer() throws Exception { Logging.initialize(); + + Path passwordConfigDummy = Files.createTempFile("passwordConfigDummy", ""); + passwordConfigDummy.toFile().deleteOnExit(); + + URL resource = getClass().getClassLoader().getResource("33.privateKey"); + assertThat(resource) + .describedAs("key directory not found") + .isNotNull(); + File keyDir = new File(resource.toURI()).getAbsoluteFile().getParentFile(); + + defaultKey = hmacShaKeyFor(getMimeDecoder().decode(asCharSource(new File(keyDir, "default-key.key"), US_ASCII).read().getBytes(US_ASCII))); + sslTrustStorePath = new File(getResource("localhost.truststore").toURI()).getPath(); + Module systemTables = binder -> newSetBinder(binder, SystemTable.class) .addBinding().to(ExtraCredentialsSystemTable.class).in(Scopes.SINGLETON); + server = TestingTrinoServer.builder() + .setProperties(ImmutableMap.builder() + .put("http-server.authentication.type", "PASSWORD,JWT") + .put("password-authenticator.config-files", passwordConfigDummy.toString()) + .put("http-server.authentication.allow-insecure-over-http", "false") + .put("http-server.process-forwarded", "true") + .put("http-server.authentication.jwt.key-file", new File(keyDir, "${KID}.key").getPath()) + .put("http-server.https.enabled", "true") + .put("http-server.https.keystore.path", new File(getResource("localhost.keystore").toURI()).getPath()) + .put("http-server.https.keystore.key", "changeit") + .buildOrThrow()) .setAdditionalModule(systemTables) .build(); server.installPlugin(new HivePlugin()); @@ -98,6 +147,7 @@ public void setupServer() .put("hive.security", "sql-standard") .put("fs.hadoop.enabled", "true") .buildOrThrow()); + server.getInstance(com.google.inject.Key.get(PasswordAuthenticatorManager.class)).setAuthenticators(TestJdbcConnection::authenticate); server.installPlugin(new BlackHolePlugin()); server.createCatalog("blackhole", "blackhole", ImmutableMap.of()); @@ -115,6 +165,14 @@ public void setupServer() } } + private static Principal authenticate(String user, String password) + { + if ((TEST_USER.equals(user) && TEST_PASSWORD.equals(password))) { + return new BasicPrincipal(user); + } + throw new AccessDeniedException("Invalid credentials"); + } + @AfterAll public void tearDown() throws Exception @@ -482,6 +540,96 @@ public void testConcurrentCancellationOnConnectionClose() testConcurrentCancellationOnConnectionClose(false); } + @Test + public void testConnectionValidation() + throws Exception + { + String validAccessToken = newJwtBuilder() + .subject("test") + .signWith(defaultKey) + .compact(); + + try (Connection conn = createConnectionUsingAccessToken(validAccessToken, "validateConnection=true")) { + assertThat(conn.isValid(10)).isTrue(); + } + + long delay = 50; + long expirationTime = 0; + long timeout = currentTimeMillis() + 60 * 1000; + + Connection acquiredConnection = null; + while (currentTimeMillis() < timeout) { + expirationTime = currentTimeMillis() + delay; + String expiringAccessToken = newJwtBuilder() + .subject("test") + .expiration(new Date(expirationTime)) + .signWith(defaultKey) + .compact(); + + try (Connection newConnection = createConnectionUsingAccessToken(expiringAccessToken, "validateConnection=true")) { + acquiredConnection = newConnection; + break; + } + catch (SQLException e) { + // Connection failure because the token is about to expire + } + + delay *= 2; + } + + assertThat(acquiredConnection).isNotNull(); + + try { + while (currentTimeMillis() < expirationTime) { + try { + Thread.sleep(100); + } + catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + } + assertThat(acquiredConnection.isValid(10)).isFalse(); + } + finally { + acquiredConnection.close(); + } + + try (Connection conn = createConnectionUsingAccessToken(validAccessToken, "")) { + assertThat(conn.isValid(10)).isTrue(); + } + + // With an expired token, isValid returns true if validateConnection is not enabled + try (Connection conn = createConnectionUsingAccessToken(validAccessToken, "validateConnection=false");) { + assertThat(conn.isValid(10)).isTrue(); + } + } + + @Test + public void testValidateConnection() + { + // Invalid host + assertThatCode(() -> createConnectionUsingInvalidHost("")) + .doesNotThrowAnyException(); + + SQLException e = catchThrowableOfType(() -> createConnectionUsingInvalidHost("validateConnection=true"), + SQLException.class); + assertThat(e.getSQLState().equals("08001")).isTrue(); + + assertThatCode(() -> createConnectionUsingInvalidHost("validateConnection=false")) + .doesNotThrowAnyException(); + + // Invalid password + assertThatCode(() -> createConnectionUsingInvalidPassword("")) + .doesNotThrowAnyException(); + + e = catchThrowableOfType(() -> createConnectionUsingInvalidPassword("validateConnection=true"), + SQLException.class); + assertThat(e.getSQLState().equals("28000")).isTrue(); + + assertThatCode(() -> createConnectionUsingInvalidPassword("validateConnection=false")) + .doesNotThrowAnyException(); + } + private void testConcurrentCancellationOnConnectionClose(boolean autoCommit) throws Exception { @@ -528,8 +676,52 @@ private Connection createConnection() private Connection createConnection(String extra) throws SQLException { - String url = format("jdbc:trino://%s/hive/default?%s", server.getAddress(), extra); - return DriverManager.getConnection(url, "admin", null); + String url = format("jdbc:trino://localhost:%s/hive/default?%s", server.getHttpsAddress().getPort(), extra); + Properties properties = new Properties(); + properties.put("user", TEST_USER); + properties.put("password", TEST_PASSWORD); + properties.setProperty("SSL", "true"); + properties.setProperty("SSLTrustStorePath", sslTrustStorePath); + properties.setProperty("SSLTrustStorePassword", "changeit"); + return DriverManager.getConnection(url, properties); + } + + private Connection createConnectionUsingInvalidHost(String extra) + throws SQLException + { + String url = format("jdbc:trino://invalidhost:%s/hive/default?%s", server.getHttpsAddress().getPort(), extra); + Properties properties = new Properties(); + properties.put("user", TEST_USER); + properties.put("password", TEST_PASSWORD); + properties.setProperty("SSL", "true"); + properties.setProperty("SSLTrustStorePath", sslTrustStorePath); + properties.setProperty("SSLTrustStorePassword", "changeit"); + return DriverManager.getConnection(url, properties); + } + + private Connection createConnectionUsingInvalidPassword(String extra) + throws SQLException + { + String url = format("jdbc:trino://localhost:%s/hive/default?%s", server.getHttpsAddress().getPort(), extra); + Properties properties = new Properties(); + properties.put("user", TEST_USER); + properties.put("password", "invalid_" + TEST_PASSWORD); + properties.setProperty("SSL", "true"); + properties.setProperty("SSLTrustStorePath", sslTrustStorePath); + properties.setProperty("SSLTrustStorePassword", "changeit"); + return DriverManager.getConnection(url, properties); + } + + private Connection createConnectionUsingAccessToken(String accessToken, String extra) + throws SQLException + { + String url = format("jdbc:trino://localhost:%s/hive/default?%s", server.getHttpsAddress().getPort(), extra); + Properties properties = new Properties(); + properties.put("accessToken", accessToken); + properties.setProperty("SSL", "true"); + properties.setProperty("SSLTrustStorePath", sslTrustStorePath); + properties.setProperty("SSLTrustStorePassword", "changeit"); + return DriverManager.getConnection(url, properties); } private static Set listTables(Connection connection) diff --git a/client/trino-jdbc/src/test/java/io/trino/jdbc/TestTrinoDriver.java b/client/trino-jdbc/src/test/java/io/trino/jdbc/TestTrinoDriver.java index b34c562bfe9..c292b48a05d 100644 --- a/client/trino-jdbc/src/test/java/io/trino/jdbc/TestTrinoDriver.java +++ b/client/trino-jdbc/src/test/java/io/trino/jdbc/TestTrinoDriver.java @@ -870,6 +870,13 @@ public void testPropertyAllowed() .put("assumeLiteralUnderscoreInMetadataCallsForNonConformingClients", "true") .buildOrThrow()))) .isNotNull(); + + assertThat(DriverManager.getConnection(jdbcUrl(), + toProperties(ImmutableMap.builder() + .put("user", "test") + .put("validateConnection", "false") + .buildOrThrow()))) + .isNotNull(); } @Test diff --git a/client/trino-jdbc/src/test/java/io/trino/jdbc/TestTrinoDriverUri.java b/client/trino-jdbc/src/test/java/io/trino/jdbc/TestTrinoDriverUri.java index b2a4fc09a69..adf6a552dd6 100644 --- a/client/trino-jdbc/src/test/java/io/trino/jdbc/TestTrinoDriverUri.java +++ b/client/trino-jdbc/src/test/java/io/trino/jdbc/TestTrinoDriverUri.java @@ -479,6 +479,18 @@ public void testDefaultPorts() assertThat(secureUri.getHttpUri()).isEqualTo(URI.create("https://localhost:443")); } + @Test + public void testAValidateConnection() + throws SQLException + { + TrinoDriverUri uri = createDriverUri("jdbc:trino://localhost:8080"); + assertThat(uri.isValidateConnection()).isFalse(); + uri = createDriverUri("jdbc:trino://localhost:8080?validateConnection=true"); + assertThat(uri.isValidateConnection()).isTrue(); + uri = createDriverUri("jdbc:trino://localhost:8080?validateConnection=false"); + assertThat(uri.isValidateConnection()).isFalse(); + } + private static void assertUriPortScheme(TrinoDriverUri parameters, int port, String scheme) { URI uri = parameters.getHttpUri(); diff --git a/core/trino-main/src/main/java/io/trino/dispatcher/QueuedStatementResource.java b/core/trino-main/src/main/java/io/trino/dispatcher/QueuedStatementResource.java index 1b10eb1f979..d4bafcebb86 100644 --- a/core/trino-main/src/main/java/io/trino/dispatcher/QueuedStatementResource.java +++ b/core/trino-main/src/main/java/io/trino/dispatcher/QueuedStatementResource.java @@ -54,6 +54,7 @@ import jakarta.ws.rs.DELETE; import jakarta.ws.rs.ForbiddenException; import jakarta.ws.rs.GET; +import jakarta.ws.rs.HEAD; import jakarta.ws.rs.NotFoundException; import jakarta.ws.rs.POST; import jakarta.ws.rs.Path; @@ -154,6 +155,14 @@ public void stop() queryManager.destroy(); } + @ResourceSecurity(AUTHENTICATED_USER) + @HEAD + @Produces(APPLICATION_JSON) + public Response validateConnection() + { + return Response.ok().build(); + } + @ResourceSecurity(AUTHENTICATED_USER) @POST @Produces(APPLICATION_JSON) diff --git a/docs/src/main/sphinx/client/jdbc.md b/docs/src/main/sphinx/client/jdbc.md index bcf7af7a018..29cddf87b1b 100644 --- a/docs/src/main/sphinx/client/jdbc.md +++ b/docs/src/main/sphinx/client/jdbc.md @@ -269,7 +269,9 @@ may not be specified using both methods. Valid values are JSON with Zstandard compression, `json+zstd` (recommended), JSON with LZ4 compression `json+lz4`, and uncompressed JSON `json`. By default, the default encoding configured on the cluster is used. - +* - `validateConnection` + - Defaults to `false`. If set to `true`, connectivity and credentials are validated + when the connection is created, and when `java.sql.Connection.isValid(int)` is called. ::: (jdbc-spooling-protocol)=