diff --git a/client/trino-jdbc/src/main/java/io/trino/jdbc/TrinoConnection.java b/client/trino-jdbc/src/main/java/io/trino/jdbc/TrinoConnection.java index 3130c47a65a4..2f65e127e5bc 100644 --- a/client/trino-jdbc/src/main/java/io/trino/jdbc/TrinoConnection.java +++ b/client/trino-jdbc/src/main/java/io/trino/jdbc/TrinoConnection.java @@ -16,6 +16,7 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Joiner; import com.google.common.base.Splitter; +import com.google.common.base.Supplier; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; @@ -24,7 +25,10 @@ import io.airlift.units.Duration; import io.trino.client.ClientSelectedRole; import io.trino.client.ClientSession; +import io.trino.client.JsonResponse; +import io.trino.client.ServerInfo; import io.trino.client.StatementClient; +import io.trino.client.TrinoJsonCodec; import jakarta.annotation.Nullable; import okhttp3.Call; import okhttp3.HttpUrl; @@ -74,16 +78,21 @@ import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; import static com.google.common.base.Strings.nullToEmpty; +import static com.google.common.base.Suppliers.memoize; import static com.google.common.base.Throwables.getCausalChain; import static com.google.common.collect.Maps.fromProperties; import static io.airlift.units.Duration.nanosSince; +import static io.trino.client.JsonResponse.execute; import static io.trino.client.StatementClientFactory.newStatementClient; +import static io.trino.client.TrinoJsonCodec.jsonCodec; import static io.trino.jdbc.ClientInfoProperty.APPLICATION_NAME; import static io.trino.jdbc.ClientInfoProperty.CLIENT_INFO; import static io.trino.jdbc.ClientInfoProperty.CLIENT_TAGS; import static io.trino.jdbc.ClientInfoProperty.TRACE_TOKEN; import static java.lang.String.format; import static java.net.HttpURLConnection.HTTP_BAD_METHOD; +import static java.net.HttpURLConnection.HTTP_FORBIDDEN; +import static java.net.HttpURLConnection.HTTP_NOT_FOUND; import static java.net.HttpURLConnection.HTTP_OK; import static java.net.HttpURLConnection.HTTP_UNAUTHORIZED; import static java.nio.charset.StandardCharsets.US_ASCII; @@ -98,6 +107,7 @@ public class TrinoConnection { private static final Logger logger = Logger.getLogger(TrinoConnection.class.getPackage().getName()); + private static final TrinoJsonCodec SERVER_INFO_CODEC = jsonCodec(ServerInfo.class); private static final int CONNECTION_TIMEOUT_SECONDS = 30; // Not configurable private final AtomicBoolean closed = new AtomicBoolean(); @@ -136,6 +146,7 @@ public class TrinoConnection private boolean useExplicitPrepare = true; private boolean assumeNullCatalogMeansCurrentCatalog; private final boolean validateConnection; + private final Supplier> versionSupplier; TrinoConnection(TrinoDriverUri uri, Call.Factory httpCallFactory, Call.Factory segmentHttpCallFactory) throws SQLException @@ -186,6 +197,35 @@ public class TrinoConnection throw new SQLException("Unable to connect to Trino server", "08001", e); } } + this.versionSupplier = memoize(this::fetchVersionFromServerInfo); + } + + private Optional fetchVersionFromServerInfo() + { + HttpUrl url = HttpUrl.get(httpUri) + .newBuilder() + .encodedPath("/v1/info") + .build(); + + Request request = new Request.Builder() + .url(url) + .get() + .build(); + + Duration timeoutDuration = new Duration(CONNECTION_TIMEOUT_SECONDS, TimeUnit.SECONDS); + long start = System.nanoTime(); + while (timeoutDuration.compareTo(nanosSince(start)) > 0) { + JsonResponse serverInfo = execute(SERVER_INFO_CODEC, httpCallFactory, request); + switch (serverInfo.getStatusCode()) { + case HTTP_OK: + return Optional.ofNullable(serverInfo.getValue().getNodeVersion().getVersion()); + case HTTP_NOT_FOUND: + case HTTP_FORBIDDEN: + return Optional.empty(); + } + } + + return Optional.empty(); } private boolean isConnectionValid(int timeout) @@ -237,6 +277,11 @@ private boolean isConnectionValid(int timeout) throw new IOException(format("Connection validation timed out after %ss", timeout), lastException); } + public Optional getServerVersion() + { + return versionSupplier.get(); + } + @Override public Statement createStatement() throws SQLException diff --git a/client/trino-jdbc/src/main/java/io/trino/jdbc/TrinoDatabaseMetaData.java b/client/trino-jdbc/src/main/java/io/trino/jdbc/TrinoDatabaseMetaData.java index d0834d6c802d..5edb653e39cb 100644 --- a/client/trino-jdbc/src/main/java/io/trino/jdbc/TrinoDatabaseMetaData.java +++ b/client/trino-jdbc/src/main/java/io/trino/jdbc/TrinoDatabaseMetaData.java @@ -31,6 +31,7 @@ import java.util.ArrayList; import java.util.Comparator; import java.util.List; +import java.util.Optional; import java.util.stream.Stream; import static com.google.common.base.Preconditions.checkArgument; @@ -140,6 +141,11 @@ public String getDatabaseProductName() public String getDatabaseProductVersion() throws SQLException { + Optional serverVersion = connection.getServerVersion(); + if (serverVersion.isPresent()) { + return serverVersion.orElseThrow(); + } + try (ResultSet rs = select("SELECT version()")) { rs.next(); return rs.getString(1);