Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ enum SslVerificationMode
public static final ConnectionProperty<String, String> DNS_RESOLVER_CONTEXT = new ResolverContext();
public static final ConnectionProperty<String, String> HOSTNAME_IN_CERTIFICATE = new HostnameInCertificate();
public static final ConnectionProperty<String, ZoneId> TIMEZONE = new TimeZone();
public static final ConnectionProperty<String, Boolean> LEGACY_PREPARED_STATEMENTS = new LegacyPreparedStatements();
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cc: @wendigo since you were recently unifying TrinoUri and ConnectionProperties.


private static final Set<ConnectionProperty<?, ?>> ALL_PROPERTIES = ImmutableSet.<ConnectionProperty<?, ?>>builder()
.add(USER)
Expand Down Expand Up @@ -144,6 +145,7 @@ enum SslVerificationMode
.add(DNS_RESOLVER_CONTEXT)
.add(HOSTNAME_IN_CERTIFICATE)
.add(TIMEZONE)
.add(LEGACY_PREPARED_STATEMENTS)
.build();

private static final Map<String, ConnectionProperty<?, ?>> KEY_LOOKUP = unmodifiableMap(ALL_PROPERTIES.stream()
Expand Down Expand Up @@ -716,6 +718,15 @@ public TimeZone()
}
}

private static class LegacyPreparedStatements
extends AbstractConnectionProperty<String, Boolean>
{
public LegacyPreparedStatements()
{
super(PropertyName.LEGACY_PREPARED_STATEMENTS, NOT_REQUIRED, ALLOWED, BOOLEAN_CONVERTER);
}
}

private static class MapPropertyParser
{
private static final CharMatcher PRINTABLE_ASCII = CharMatcher.inRange((char) 0x21, (char) 0x7E);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ public enum PropertyName
TRACE_TOKEN("traceToken"),
SESSION_PROPERTIES("sessionProperties"),
SOURCE("source"),
LEGACY_PREPARED_STATEMENTS("legacyPreparedStatements"),
DNS_RESOLVER("dnsResolver"),
DNS_RESOLVER_CONTEXT("dnsResolverContext"),
HOSTNAME_IN_CERTIFICATE("hostnameInCertificate"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@
import static io.trino.client.uri.ConnectionProperties.KERBEROS_REMOTE_SERVICE_NAME;
import static io.trino.client.uri.ConnectionProperties.KERBEROS_SERVICE_PRINCIPAL_PATTERN;
import static io.trino.client.uri.ConnectionProperties.KERBEROS_USE_CANONICAL_HOSTNAME;
import static io.trino.client.uri.ConnectionProperties.LEGACY_PREPARED_STATEMENTS;
import static io.trino.client.uri.ConnectionProperties.PASSWORD;
import static io.trino.client.uri.ConnectionProperties.ROLES;
import static io.trino.client.uri.ConnectionProperties.SESSION_PROPERTIES;
Expand Down Expand Up @@ -167,6 +168,7 @@ public class TrinoUri
private Optional<String> traceToken;
private Optional<Map<String, String>> sessionProperties;
private Optional<String> source;
private Optional<Boolean> legacyPreparedStatements;

private Optional<String> catalog = Optional.empty();
private Optional<String> schema = Optional.empty();
Expand Down Expand Up @@ -219,7 +221,8 @@ private TrinoUri(
Optional<String> clientTags,
Optional<String> traceToken,
Optional<Map<String, String>> sessionProperties,
Optional<String> source)
Optional<String> source,
Optional<Boolean> legacyPreparedStatements)
throws SQLException
{
this.uri = requireNonNull(uri, "uri is null");
Expand Down Expand Up @@ -272,6 +275,7 @@ private TrinoUri(
this.traceToken = TRACE_TOKEN.getValueOrDefault(urlProperties, traceToken);
this.sessionProperties = SESSION_PROPERTIES.getValueOrDefault(urlProperties, sessionProperties);
this.source = SOURCE.getValueOrDefault(urlProperties, source);
this.legacyPreparedStatements = LEGACY_PREPARED_STATEMENTS.getValueOrDefault(urlProperties, legacyPreparedStatements);

properties = buildProperties();

Expand Down Expand Up @@ -357,6 +361,7 @@ private Properties buildProperties()
clientTags.ifPresent(value -> properties.setProperty(PropertyName.CLIENT_TAGS.toString(), value));
traceToken.ifPresent(value -> properties.setProperty(PropertyName.TRACE_TOKEN.toString(), value));
source.ifPresent(value -> properties.setProperty(PropertyName.SOURCE.toString(), value));
legacyPreparedStatements.ifPresent(value -> properties.setProperty(PropertyName.LEGACY_PREPARED_STATEMENTS.toString(), value.toString()));
return properties;
}

Expand Down Expand Up @@ -416,6 +421,7 @@ protected TrinoUri(URI uri, Properties driverProperties)
this.traceToken = TRACE_TOKEN.getValue(properties);
this.sessionProperties = SESSION_PROPERTIES.getValue(properties);
this.source = SOURCE.getValue(properties);
this.legacyPreparedStatements = LEGACY_PREPARED_STATEMENTS.getValue(properties);

// enable SSL by default for the trino schema and the standard port
useSecureConnection = ssl.orElse(uri.getScheme().equals("https") || (uri.getScheme().equals("trino") && uri.getPort() == 443));
Expand Down Expand Up @@ -523,6 +529,11 @@ public Optional<String> getSource()
return source;
}

public Optional<Boolean> getLegacyPreparedStatements()
{
return legacyPreparedStatements;
}

public boolean isCompressionDisabled()
{
return disableCompression.orElse(false);
Expand Down Expand Up @@ -927,6 +938,7 @@ public static final class Builder
private String traceToken;
private Map<String, String> sessionProperties;
private String source;
private Boolean legacyPreparedStatements;

private Builder() {}

Expand Down Expand Up @@ -1215,6 +1227,12 @@ public Builder setSource(String source)
return this;
}

public Builder setLegacyPreparedStatements(Boolean legacyPreparedStatements)
{
this.legacyPreparedStatements = requireNonNull(legacyPreparedStatements, "legacyPreparedStatements is null");
return this;
}

public TrinoUri build()
throws SQLException
{
Expand Down Expand Up @@ -1263,7 +1281,8 @@ public TrinoUri build()
Optional.ofNullable(clientTags),
Optional.ofNullable(traceToken),
Optional.ofNullable(sessionProperties),
Optional.ofNullable(source));
Optional.ofNullable(source),
Optional.ofNullable(legacyPreparedStatements));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ public class TrinoConnection
private final AtomicReference<String> transactionId = new AtomicReference<>();
private final Call.Factory httpCallFactory;
private final Set<TrinoStatement> statements = newSetFromMap(new ConcurrentHashMap<>());
private boolean useLegacyPreparedStatements = true;

TrinoConnection(TrinoDriverUri uri, Call.Factory httpCallFactory)
{
Expand Down Expand Up @@ -144,6 +145,8 @@ public class TrinoConnection
timeZoneId.set(uri.getTimeZone());
locale.set(Locale.getDefault());
sessionProperties.putAll(uri.getSessionProperties());

uri.getLegacyPreparedStatements().ifPresent(value -> this.useLegacyPreparedStatements = value);
}

@Override
Expand Down Expand Up @@ -911,4 +914,9 @@ public void throwIfHeld()
}
}
}

public Boolean isUseLegacyPreparedStatements()
{
return this.useLegacyPreparedStatements;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -109,14 +109,18 @@ public class TrinoPreparedStatement
private final String statementName;
private final String originalSql;
private boolean isBatch;
private boolean prepareStatementExecuted;

TrinoPreparedStatement(TrinoConnection connection, Consumer<TrinoStatement> onClose, String statementName, String sql)
throws SQLException
{
super(connection, onClose);
this.statementName = requireNonNull(statementName, "statementName is null");
this.originalSql = requireNonNull(sql, "sql is null");
super.execute(format("PREPARE %s FROM %s", statementName, sql));
if (connection().isUseLegacyPreparedStatements()) {
super.execute(format("PREPARE %s FROM %s", statementName, sql));
prepareStatementExecuted = true;
}
}

@Override
Expand Down Expand Up @@ -683,6 +687,8 @@ public void setArray(int parameterIndex, Array x)
public ResultSetMetaData getMetaData()
throws SQLException
{
prepareStatementIfNecessary();

try (Statement statement = connection().createStatement(); ResultSet resultSet = statement.executeQuery("DESCRIBE OUTPUT " + statementName)) {
return new TrinoResultSetMetaData(getDescribeOutputColumnInfoList(resultSet));
}
Expand Down Expand Up @@ -720,6 +726,8 @@ public void setURL(int parameterIndex, URL x)
public ParameterMetaData getParameterMetaData()
throws SQLException
{
prepareStatementIfNecessary();

try (Statement statement = connection().createStatement(); ResultSet resultSet = statement.executeQuery("DESCRIBE INPUT " + statementName)) {
return new TrinoParameterMetaData(getParamerters(resultSet));
}
Expand Down Expand Up @@ -986,7 +994,19 @@ private void requireNonBatchStatement()
}
}

private static String getExecuteSql(String statementName, List<String> values)
private String getExecuteImmediateSql(List<String> values)
{
StringBuilder sql = new StringBuilder();
sql.append("EXECUTE IMMEDIATE ");
sql.append(formatStringLiteral(originalSql));
if (!values.isEmpty()) {
sql.append(" USING ");
Joiner.on(", ").appendTo(sql, values);
}
return sql.toString();
}

private String getLegacySql(String statementName, List<String> values)
{
StringBuilder sql = new StringBuilder();
sql.append("EXECUTE ").append(statementName);
Expand All @@ -997,6 +1017,14 @@ private static String getExecuteSql(String statementName, List<String> values)
return sql.toString();
}

private String getExecuteSql(String statementName, List<String> values)
throws SQLException
{
return connection().isUseLegacyPreparedStatements()
? getLegacySql(statementName, values)
: getExecuteImmediateSql(values);
}

private static String formatLiteral(String type, String x)
{
return type + " " + formatStringLiteral(x);
Expand Down Expand Up @@ -1118,6 +1146,22 @@ private static List<ColumnInfo> getDescribeOutputColumnInfoList(ResultSet result
return list.build();
}

/*
When isUseLegacyPreparedStatements is disabled, the PREPARE statement won't be executed unless needed
e.g. when getMetadata() or getParameterMetadata() are called.
When needed, just make sure it is executed only once, even if the metadata methods are called many times
*/
private void prepareStatementIfNecessary()
throws SQLException
{
if (prepareStatementExecuted) {
return;
}

super.execute(format("PREPARE %s FROM %s", statementName, originalSql));
prepareStatementExecuted = true;
}

@VisibleForTesting
static ClientTypeSignature getClientTypeSignatureFromTypeString(String type)
{
Expand Down
Loading