Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ public ClientSession toClientSession()
{
return new ClientSession(
parseServer(server),
user.orElse(null),
user,
sessionUser,
source,
Optional.ofNullable(traceToken),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ static ClientSession createClientSession(MockWebServer server)
{
return new ClientSession(
server.url("/").uri(),
"user",
Optional.of("user"),
Optional.empty(),
"source",
Optional.empty(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
public class ClientSession
{
private final URI server;
private final String principal;
private final Optional<String> principal;
Comment thread
lukasz-walkiewicz marked this conversation as resolved.
Outdated
private final Optional<String> user;
private final String source;
private final Optional<String> traceToken;
Expand Down Expand Up @@ -68,7 +68,7 @@ public static ClientSession stripTransactionId(ClientSession session)

public ClientSession(
URI server,
String principal,
Optional<String> principal,
Optional<String> user,
String source,
Optional<String> traceToken,
Expand Down Expand Up @@ -143,7 +143,7 @@ public URI getServer()
return server;
}

public String getPrincipal()
public Optional<String> getPrincipal()
{
return principal;
}
Expand Down Expand Up @@ -270,7 +270,7 @@ public String toString()
public static final class Builder
{
private URI server;
private String principal;
private Optional<String> principal;
private Optional<String> user;
private String source;
private Optional<String> traceToken;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Stream;

import static com.google.common.base.MoreObjects.firstNonNull;
import static com.google.common.base.Preconditions.checkState;
Expand Down Expand Up @@ -84,7 +85,7 @@ class StatementClientV1
private final AtomicBoolean clearTransactionId = new AtomicBoolean();
private final ZoneId timeZone;
private final Duration requestTimeoutNanos;
private final String user;
private final Optional<String> user;
private final String clientCapabilities;
private final boolean compressionDisabled;

Expand All @@ -100,7 +101,10 @@ public StatementClientV1(OkHttpClient httpClient, ClientSession session, String
this.timeZone = session.getTimeZone();
this.query = query;
this.requestTimeoutNanos = session.getClientRequestTimeout();
this.user = session.getUser().orElse(session.getPrincipal());
this.user = Stream.of(session.getUser(), session.getPrincipal())
.filter(Optional::isPresent)
Comment thread
lukasz-walkiewicz marked this conversation as resolved.
Outdated
.map(Optional::get)
.findFirst();
this.clientCapabilities = Joiner.on(",").join(ClientCapabilities.values());
this.compressionDisabled = session.isCompressionDisabled();

Expand Down Expand Up @@ -310,9 +314,9 @@ public boolean isClearTransactionId()
private Request.Builder prepareRequest(HttpUrl url)
{
Request.Builder builder = new Request.Builder()
.addHeader(TRINO_HEADERS.requestUser(), user)
.addHeader(USER_AGENT, USER_AGENT_VALUE)
.url(url);
user.ifPresent(requestUser -> builder.addHeader(TRINO_HEADERS.requestUser(), requestUser));
if (compressionDisabled) {
builder.header(ACCEPT_ENCODING, "identity");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ private static class User
{
public User()
{
super("user", REQUIRED, ALLOWED, NON_EMPTY_STRING_CONVERTER);
super("user", NOT_REQUIRED, ALLOWED, NON_EMPTY_STRING_CONVERTER);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ public class TrinoConnection

private final URI jdbcUri;
private final URI httpUri;
private final String user;
private final Optional<String> user;
private final Optional<String> sessionUser;
private final boolean compressionDisabled;
private final boolean assumeLiteralNamesInMetadataCallsForNonConformingClients;
Expand Down Expand Up @@ -680,11 +680,6 @@ URI getURI()
return jdbcUri;
}

String getUser()
{
return user;
}

@VisibleForTesting
Map<String, String> getExtraCredentials()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,10 @@ public String getURL()
public String getUserName()
throws SQLException
{
return connection.getUser();
try (ResultSet rs = select("SELECT current_user")) {
Comment thread
lukasz-walkiewicz marked this conversation as resolved.
Outdated
rs.next();
return rs.getString(1);
}
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -167,12 +167,18 @@ public URI getHttpUri()
return buildHttpUri();
}

public String getUser()
public String getRequiredUser()
throws SQLException
{
return USER.getRequiredValue(properties);
}

public Optional<String> getUser()
throws SQLException
{
return USER.getValue(properties);
}

public Optional<String> getSessionUser()
throws SQLException
{
Expand Down Expand Up @@ -258,7 +264,7 @@ public void setupClient(OkHttpClient.Builder builder)
if (!useSecureConnection) {
throw new SQLException("Authentication using username/password requires SSL to be enabled");
}
builder.addInterceptor(basicAuth(getUser(), password));
builder.addInterceptor(basicAuth(getRequiredUser(), password));
}

if (useSecureConnection) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,6 @@ private Connection createConnection()
{
String url = format("jdbc:trino://localhost:%s", server.getHttpsAddress().getPort());
Properties properties = new Properties();
properties.setProperty("user", "test");
properties.setProperty("SSL", "true");
properties.setProperty("SSLTrustStorePath", new File(getResource("localhost.truststore").toURI()).getPath());
properties.setProperty("SSLTrustStorePassword", "changeit");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,16 @@ public void testGetDatabaseProductVersion()
}
}

@Test
public void testGetUserName()
throws Exception
{
try (Connection connection = createConnection()) {
DatabaseMetaData metaData = connection.getMetaData();
assertEquals(metaData.getUserName(), "admin");
}
}

@Test
public void testGetCatalogs()
throws Exception
Expand Down
18 changes: 3 additions & 15 deletions client/trino-jdbc/src/test/java/io/trino/jdbc/TestTrinoDriver.java
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,7 @@ public void testDriverPropertyInfoEmpty()

assertThat(infos)
.extracting(TestTrinoDriver::driverPropertyInfoToString)
.contains("{name=user, required=true}")
.contains("{name=user, required=false}")
.contains("{name=password, required=false}")
.contains("{name=accessToken, required=false}")
.contains("{name=SSL, required=false, choices=[true, false]}");
Expand All @@ -413,7 +413,7 @@ public void testDriverPropertyInfoSslEnabled()

assertThat(infos)
.extracting(TestTrinoDriver::driverPropertyInfoToString)
.contains("{name=user, value=test, required=true}")
.contains("{name=user, value=test, required=false}")
.contains("{name=SSL, value=true, required=false, choices=[true, false]}")
.contains("{name=SSLVerification, required=false, choices=[FULL, CA, NONE]}")
.contains("{name=SSLTrustStorePath, required=false}");
Expand Down Expand Up @@ -773,23 +773,11 @@ public void testBadQuery()
}
}

@Test
public void testUserIsRequired()
{
assertThatThrownBy(() -> DriverManager.getConnection(jdbcUrl()))
.isInstanceOf(SQLException.class)
.hasMessage("Connection property 'user' is required");
}

@Test
public void testNullConnectProperties()
throws Exception
{
Driver driver = DriverManager.getDriver("jdbc:trino:");

assertThatThrownBy(() -> driver.connect(jdbcUrl(), null))
.isInstanceOf(SQLException.class)
.hasMessage("Connection property 'user' is required");
DriverManager.getDriver("jdbc:trino:").connect(jdbcUrl(), null);
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,6 @@ private Connection createConnection(Map<String, String> additionalProperties)
{
String url = format("jdbc:trino://localhost:%s", server.getHttpsAddress().getPort());
Properties properties = new Properties();
properties.setProperty("user", "test");
properties.setProperty("SSL", "true");
properties.setProperty("SSLTrustStorePath", new File(getResource("localhost.truststore").toURI()).getPath());
properties.setProperty("SSLTrustStorePassword", "changeit");
Expand All @@ -341,7 +340,6 @@ private Connection createBasicConnection(Map<String, String> additionalPropertie
{
String url = format("jdbc:trino://localhost:%s", server.getHttpsAddress().getPort());
Properties properties = new Properties();
properties.setProperty("user", "test");
additionalProperties.forEach(properties::setProperty);
return DriverManager.getConnection(url, properties);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -175,13 +175,6 @@ public void testInvalidUrls()
assertInvalid("jdbc:presto://localhost:8080", "Invalid JDBC URL: jdbc:presto://localhost:8080");
}

@Test(expectedExceptions = SQLException.class, expectedExceptionsMessageRegExp = "Connection property 'user' is required")
public void testRequireUser()
throws Exception
{
TrinoDriverUri.create("jdbc:trino://localhost:8080", new Properties());
}

@Test(expectedExceptions = SQLException.class, expectedExceptionsMessageRegExp = "Connection property 'user' value is empty")
public void testEmptyUser()
throws Exception
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
databases:
presto:
jdbc_url: "jdbc:trino://${databases.presto.host}:${databases.presto.port}/hive/${databases.hive.schema}?\
user=${databases.presto.cli_kerberos_principal}&\
SSL=true&\
SSLTrustStorePath=${databases.presto.https_keystore_path}&\
SSLTrustStorePassword=${databases.presto.https_keystore_password}&\
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@ public void shouldAuthenticateAndExecuteQuery()
throws SQLException
{
Properties properties = new Properties();
properties.setProperty("user", "test");
String jdbcUrl = format("jdbc:trino://presto-master:7778?"
+ "SSL=true&"
+ "SSLTrustStorePath=%s&"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.Properties;

import static com.google.common.base.Preconditions.checkState;
import static io.trino.tempto.assertions.QueryAssert.Row.row;
Expand Down Expand Up @@ -98,9 +97,7 @@ public void shouldAuthenticateAndExecuteQuery()
throws Exception
{
prepareHandler();
Properties properties = new Properties();
properties.setProperty("user", "test");
try (Connection connection = DriverManager.getConnection(jdbcUrl, properties);
try (Connection connection = DriverManager.getConnection(jdbcUrl);
PreparedStatement statement = connection.prepareStatement("SELECT * FROM tpch.tiny.nation");
ResultSet results = statement.executeQuery()) {
assertThat(forResultSet(results)).matches(TpchTableResults.PRESTO_NATION_RESULT);
Expand All @@ -112,9 +109,7 @@ public void shouldAuthenticateAfterTokenExpires()
throws Exception
{
prepareHandler();
Properties properties = new Properties();
properties.setProperty("user", "test");
try (Connection connection = DriverManager.getConnection(jdbcUrl, properties);
try (Connection connection = DriverManager.getConnection(jdbcUrl);
PreparedStatement statement = connection.prepareStatement("SELECT * FROM tpch.tiny.nation");
ResultSet results = statement.executeQuery()) {
assertThat(forResultSet(results)).matches(TpchTableResults.PRESTO_NATION_RESULT);
Expand All @@ -133,9 +128,7 @@ public void shouldReturnGroups()
throws SQLException
{
prepareHandler();
Properties properties = new Properties();
properties.setProperty("user", "test");
try (Connection connection = DriverManager.getConnection(jdbcUrl, properties);
try (Connection connection = DriverManager.getConnection(jdbcUrl);
PreparedStatement statement = connection.prepareStatement("SELECT array_sort(current_groups())");
ResultSet rs = statement.executeQuery()) {
assertThat(forResultSet(rs)).containsOnly(row(ImmutableList.of("admin", "public")));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ private static ClientSession toClientSession(Session session, URI server, Durati

return new ClientSession(
server,
session.getIdentity().getUser(),
Optional.of(session.getIdentity().getUser()),
Optional.empty(),
session.getSource().orElse(null),
session.getTraceToken(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ private static QueryId startQuery(String sql, DistributedQueryRunner queryRunner
try {
ClientSession clientSession = new ClientSession(
queryRunner.getCoordinator().getBaseUrl(),
"user",
Optional.of("user"),
Optional.empty(),
"source",
Optional.empty(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ private QueryError trySelectQuery(String assumedUser)
try {
ClientSession clientSession = new ClientSession(
getDistributedQueryRunner().getCoordinator().getBaseUrl(),
"user",
Optional.of("user"),
Optional.of(assumedUser),
"source",
Optional.empty(),
Expand Down