From 725b04b85399a08ead4827d55494bf05a97200a9 Mon Sep 17 00:00:00 2001 From: Alex Albu Date: Tue, 1 Aug 2023 21:37:03 -0400 Subject: [PATCH] Add a way to set the session user on the JDBC connection after creation --- .../java/io/trino/jdbc/TrinoConnection.java | 17 +++++++++--- .../io/trino/jdbc/TestJdbcConnection.java | 27 +++++++++++++++++++ 2 files changed, 41 insertions(+), 3 deletions(-) diff --git a/client/trino-jdbc/src/main/java/io/trino/jdbc/TrinoConnection.java b/client/trino-jdbc/src/main/java/io/trino/jdbc/TrinoConnection.java index ebae2f98a9eb..bedcf5e1c832 100644 --- a/client/trino-jdbc/src/main/java/io/trino/jdbc/TrinoConnection.java +++ b/client/trino-jdbc/src/main/java/io/trino/jdbc/TrinoConnection.java @@ -93,11 +93,11 @@ public class TrinoConnection private final AtomicReference locale = new AtomicReference<>(); private final AtomicReference networkTimeoutMillis = new AtomicReference<>(Ints.saturatedCast(MINUTES.toMillis(2))); private final AtomicLong nextStatementId = new AtomicLong(1); + private final AtomicReference> sessionUser = new AtomicReference<>(); private final URI jdbcUri; private final URI httpUri; private final Optional user; - private final Optional sessionUser; private final boolean compressionDisabled; private final boolean assumeLiteralNamesInMetadataCallsForNonConformingClients; private final boolean assumeLiteralUnderscoreInMetadataCallsForNonConformingClients; @@ -120,7 +120,7 @@ public class TrinoConnection uri.getSchema().ifPresent(schema::set); uri.getCatalog().ifPresent(catalog::set); this.user = uri.getUser(); - this.sessionUser = uri.getSessionUser(); + this.sessionUser.set(uri.getSessionUser()); this.applicationNamePrefix = uri.getApplicationNamePrefix(); this.source = uri.getSource(); this.extraCredentials = uri.getExtraCredentials(); @@ -636,6 +636,17 @@ public void setSessionProperty(String name, String value) sessionProperties.put(name, value); } + public void setSessionUser(String sessionUser) + { + requireNonNull(sessionUser, "sessionUser is null"); + this.sessionUser.set(Optional.of(sessionUser)); + } + + public void clearSessionUser() + { + this.sessionUser.set(Optional.empty()); + } + @VisibleForTesting Map getRoles() { @@ -734,7 +745,7 @@ StatementClient startQuery(String sql, Map sessionPropertiesOver ClientSession session = ClientSession.builder() .server(httpUri) .principal(user) - .user(sessionUser) + .user(sessionUser.get()) .source(source) .traceToken(Optional.ofNullable(clientInfo.get(TRACE_TOKEN))) .clientTags(ImmutableSet.copyOf(clientTags)) diff --git a/client/trino-jdbc/src/test/java/io/trino/jdbc/TestJdbcConnection.java b/client/trino-jdbc/src/test/java/io/trino/jdbc/TestJdbcConnection.java index 4953e52293d1..a01df9ca73c1 100644 --- a/client/trino-jdbc/src/test/java/io/trino/jdbc/TestJdbcConnection.java +++ b/client/trino-jdbc/src/test/java/io/trino/jdbc/TestJdbcConnection.java @@ -431,6 +431,21 @@ public void testSessionProperties() } } + @Test + public void testSessionUser() + throws SQLException + { + try (Connection connection = createConnection()) { + assertThat(getSingleStringColumn(connection, "select current_user")).isEqualTo("admin"); + TrinoConnection trinoConnection = connection.unwrap(TrinoConnection.class); + String impersonatedUser = "alice"; + trinoConnection.setSessionUser(impersonatedUser); + assertThat(getSingleStringColumn(connection, "select current_user")).isEqualTo(impersonatedUser); + trinoConnection.clearSessionUser(); + assertThat(getSingleStringColumn(connection, "select current_user")).isEqualTo("admin"); + } + } + /** * @see TestJdbcStatement#testCancellationOnStatementClose() * @see TestJdbcStatement#testConcurrentCancellationOnStatementClose() @@ -570,6 +585,18 @@ private List listSingleStringColumn(String sql) return statuses.build(); } + private String getSingleStringColumn(Connection connection, String sql) + throws SQLException + { + try (Statement statement = connection.createStatement(); ResultSet resultSet = statement.executeQuery(sql)) { + assertThat(resultSet.getMetaData().getColumnCount()).isOne(); + assertThat(resultSet.next()).isTrue(); + String result = resultSet.getString(1); + assertThat(resultSet.next()).isFalse(); + return result; + } + } + private static void assertConnectionSource(Connection connection, String expectedSource) throws SQLException {