Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,6 @@ abstract class AbstractTrinoResultSet
private final AtomicReference<List<Object>> row = new AtomicReference<>();
private final AtomicLong currentRowNumber = new AtomicLong(); // Index into 'rows' of our current row (1-based)
private final AtomicBoolean wasNull = new AtomicBoolean();
protected final AtomicBoolean closed = new AtomicBoolean();
private final Optional<Statement> statement;

AbstractTrinoResultSet(Optional<Statement> statement, List<Column> columns, Iterator<List<Object>> results)
Expand Down Expand Up @@ -1468,11 +1467,8 @@ public int getHoldability()
}

@Override
public boolean isClosed()
throws SQLException
{
return closed.get();
}
public abstract boolean isClosed()
throws SQLException;

@Override
public void updateNString(int columnIndex, String nString)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,15 @@
import java.sql.SQLException;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicBoolean;

import static java.util.Objects.requireNonNull;

class InMemoryTrinoResultSet
extends AbstractTrinoResultSet
{
private final AtomicBoolean closed = new AtomicBoolean();

public InMemoryTrinoResultSet(List<Column> columns, List<List<Object>> results)
{
super(Optional.empty(), columns, requireNonNull(results, "results is null").iterator());
Expand All @@ -32,5 +35,14 @@ public InMemoryTrinoResultSet(List<Column> columns, List<List<Object>> results)
@Override
public void close()
throws SQLException
{}
{
closed.set(true);
}

@Override
public boolean isClosed()
throws SQLException
{
return closed.get();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -782,6 +782,12 @@ private void unregisterStatement(TrinoStatement statement)
checkState(statements.remove(statement), "Statement is not registered");
}

@VisibleForTesting
int activeStatements()
{
return statements.size();
}

private void checkOpen()
throws SQLException
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import java.sql.RowIdLifetime;
import java.sql.SQLException;
import java.sql.SQLFeatureNotSupportedException;
import java.sql.Statement;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
Expand Down Expand Up @@ -1488,7 +1489,24 @@ private ResultSet selectEmpty(String sql)
private ResultSet select(String sql)
throws SQLException
{
return getConnection().createStatement().executeQuery(sql);
Statement statement = getConnection().createStatement();
TrinoResultSet resultSet;
try {
resultSet = (TrinoResultSet) statement.executeQuery(sql);
resultSet.setCloseStatementOnClose();
}
catch (Throwable e) {
try {
Comment thread
findepi marked this conversation as resolved.
Outdated
statement.close();
}
catch (Throwable closeException) {
if (closeException != e) {
e.addSuppressed(closeException);
}
}
throw e;
}
return resultSet;
}

private static void buildFilters(StringBuilder out, List<String> filters)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
import io.trino.client.QueryStatusInfo;
import io.trino.client.StatementClient;

import javax.annotation.concurrent.GuardedBy;

import java.sql.SQLException;
import java.sql.Statement;
import java.util.Iterator;
Expand All @@ -43,9 +45,15 @@
public class TrinoResultSet
extends AbstractTrinoResultSet
{
private final Statement statement;
private final StatementClient client;
private final String queryId;

@GuardedBy("this")
private boolean closed;
@GuardedBy("this")
private boolean closeStatementOnClose;

static TrinoResultSet create(Statement statement, StatementClient client, long maxRows, Consumer<QueryStats> progressCallback, WarningsManager warningsManager)
throws SQLException
{
Expand All @@ -62,6 +70,7 @@ private TrinoResultSet(Statement statement, StatementClient client, List<Column>
columns,
new AsyncIterator<>(flatten(new ResultsPageIterator(requireNonNull(client, "client is null"), progressCallback, warningsManager), maxRows), client));

this.statement = statement;
this.client = requireNonNull(client, "client is null");
requireNonNull(progressCallback, "progressCallback is null");

Expand All @@ -78,13 +87,46 @@ public QueryStats getStats()
return QueryStats.create(queryId, client.getStats());
}

void setCloseStatementOnClose()
throws SQLException
{
boolean alreadyClosed;
synchronized (this) {
alreadyClosed = closed;
if (!alreadyClosed) {
closeStatementOnClose = true;
}
}
if (alreadyClosed) {
statement.close();
}
}

@Override
public void close()
throws SQLException
{
closed.set(true);
boolean closeStatement;
synchronized (this) {
if (closed) {
return;
}
closed = true;
closeStatement = closeStatementOnClose;
}

((AsyncIterator<?>) results).cancel();
client.close();
if (closeStatement) {
statement.close();
}
}

@Override
public synchronized boolean isClosed()
throws SQLException
{
return closed;
}

void partialCancel()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1496,6 +1496,33 @@ public void testEscapeIfNecessary()
assertEquals(TrinoDatabaseMetaData.escapeIfNecessary(true, "abc\\_def"), "abc\\\\\\_def");
}

@Test
public void testStatementsDoNotLeak()
throws Exception
{
TrinoConnection connection = (TrinoConnection) this.connection;
DatabaseMetaData metaData = connection.getMetaData();

// consumed
try (ResultSet resultSet = metaData.getCatalogs()) {
assertThat(countRows(resultSet)).isEqualTo(5);
}
try (ResultSet resultSet = metaData.getSchemas(TEST_CATALOG, null)) {
assertThat(countRows(resultSet)).isEqualTo(10);
}
try (ResultSet resultSet = metaData.getTables(TEST_CATALOG, "sf%", null, null)) {
assertThat(countRows(resultSet)).isEqualTo(64);
}

// not consumed
metaData.getCatalogs().close();
metaData.getSchemas(TEST_CATALOG, null).close();
metaData.getTables(TEST_CATALOG, "sf%", null, null).close();

assertThat(connection.activeStatements()).as("activeStatements")
.isEqualTo(0);
}

private static void assertColumnSpec(ResultSet rs, int dataType, Long precision, Long numPrecRadix, String typeName)
throws SQLException
{
Expand Down Expand Up @@ -1585,6 +1612,16 @@ private MetaDataCallback<List<List<Object>>> readMetaData(MetaDataCallback<Resul
};
}

private int countRows(ResultSet resultSet)
throws Exception
Comment thread
findepi marked this conversation as resolved.
Outdated
{
int rows = 0;
while (resultSet.next()) {
rows++;
}
return rows;
}

private Connection createConnection()
throws SQLException
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@
import java.util.Map;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Strings.repeat;
import static com.google.common.io.ByteStreams.toByteArray;
import static com.google.common.io.MoreFiles.deleteRecursively;
import static com.google.common.io.RecursiveDeleteOption.ALLOW_INSECURE;
Expand Down Expand Up @@ -736,9 +735,9 @@ public void testStreamingUpload()
fs.setS3Client(s3);
try (FSDataOutputStream stream = fs.create(new Path("s3n://test-bucket/test"))) {
stream.write('a');
stream.write(repeat("foo", 2).getBytes(US_ASCII));
stream.write(repeat("bar", 3).getBytes(US_ASCII));
stream.write(repeat("orange", 4).getBytes(US_ASCII), 6, 12);
stream.write("foo".repeat(2).getBytes(US_ASCII));
stream.write("bar".repeat(3).getBytes(US_ASCII));
stream.write("orange".repeat(4).getBytes(US_ASCII), 6, 12);
}

List<UploadPartRequest> parts = s3.getUploadParts();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@
import java.util.Properties;
import java.util.function.Consumer;

import static com.google.common.base.Strings.repeat;
import static io.trino.JdbcDriverCapabilities.correctlyReportsTimestampWithTimeZone;
import static io.trino.JdbcDriverCapabilities.driverVersion;
import static io.trino.JdbcDriverCapabilities.hasBrokenParametricTimestampWithTimeZoneSupport;
Expand Down Expand Up @@ -112,7 +111,7 @@ public void tearDown()
public void testLongPreparedStatement()
throws Exception
{
String sql = format("SELECT '%s' = '%s'", repeat("x", 100_000), repeat("y", 100_000));
String sql = format("SELECT '%s' = '%s'", "x".repeat(100_000), "y".repeat(100_000));

try (ResultSet rs = runQuery(sql)) {
assertThat(rs.next()).isTrue();
Expand Down