Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 14 additions & 3 deletions client/trino-jdbc/src/main/java/io/trino/jdbc/TrinoConnection.java
Original file line number Diff line number Diff line change
Expand Up @@ -93,11 +93,11 @@ public class TrinoConnection
private final AtomicReference<Locale> locale = new AtomicReference<>();
private final AtomicReference<Integer> networkTimeoutMillis = new AtomicReference<>(Ints.saturatedCast(MINUTES.toMillis(2)));
private final AtomicLong nextStatementId = new AtomicLong(1);
private final AtomicReference<Optional<String>> sessionUser = new AtomicReference<>();

private final URI jdbcUri;
private final URI httpUri;
private final Optional<String> user;
private final Optional<String> sessionUser;
private final boolean compressionDisabled;
private final boolean assumeLiteralNamesInMetadataCallsForNonConformingClients;
private final boolean assumeLiteralUnderscoreInMetadataCallsForNonConformingClients;
Expand All @@ -120,7 +120,7 @@ public class TrinoConnection
uri.getSchema().ifPresent(schema::set);
uri.getCatalog().ifPresent(catalog::set);
this.user = uri.getUser();
this.sessionUser = uri.getSessionUser();
this.sessionUser.set(uri.getSessionUser());
this.applicationNamePrefix = uri.getApplicationNamePrefix();
this.source = uri.getSource();
this.extraCredentials = uri.getExtraCredentials();
Expand Down Expand Up @@ -636,6 +636,17 @@ public void setSessionProperty(String name, String value)
sessionProperties.put(name, value);
}

public void setSessionUser(String sessionUser)
{
requireNonNull(sessionUser, "sessionUser is null");
this.sessionUser.set(Optional.of(sessionUser));
}

public void clearSessionUser()
{
this.sessionUser.set(Optional.empty());
}

@VisibleForTesting
Map<String, ClientSelectedRole> getRoles()
{
Expand Down Expand Up @@ -734,7 +745,7 @@ StatementClient startQuery(String sql, Map<String, String> sessionPropertiesOver
ClientSession session = ClientSession.builder()
.server(httpUri)
.principal(user)
.user(sessionUser)
.user(sessionUser.get())
.source(source)
.traceToken(Optional.ofNullable(clientInfo.get(TRACE_TOKEN)))
.clientTags(ImmutableSet.copyOf(clientTags))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,21 @@ public void testSessionProperties()
}
}

@Test
public void testSessionUser()
throws SQLException
{
try (Connection connection = createConnection()) {
assertThat(getSingleStringColumn(connection, "select current_user")).isEqualTo("admin");
TrinoConnection trinoConnection = connection.unwrap(TrinoConnection.class);
String impersonatedUser = "alice";
trinoConnection.setSessionUser(impersonatedUser);
assertThat(getSingleStringColumn(connection, "select current_user")).isEqualTo(impersonatedUser);
trinoConnection.clearSessionUser();
assertThat(getSingleStringColumn(connection, "select current_user")).isEqualTo("admin");
}
}

/**
* @see TestJdbcStatement#testCancellationOnStatementClose()
* @see TestJdbcStatement#testConcurrentCancellationOnStatementClose()
Expand Down Expand Up @@ -570,6 +585,18 @@ private List<String> listSingleStringColumn(String sql)
return statuses.build();
}

private String getSingleStringColumn(Connection connection, String sql)
throws SQLException
{
try (Statement statement = connection.createStatement(); ResultSet resultSet = statement.executeQuery(sql)) {
assertThat(resultSet.getMetaData().getColumnCount()).isOne();
assertThat(resultSet.next()).isTrue();
String result = resultSet.getString(1);
assertThat(resultSet.next()).isFalse();
return result;
}
}

private static void assertConnectionSource(Connection connection, String expectedSource)
throws SQLException
{
Expand Down