diff --git a/client/trino-jdbc/src/main/java/io/trino/jdbc/TrinoResultSet.java b/client/trino-jdbc/src/main/java/io/trino/jdbc/TrinoResultSet.java index 35fac6a3e48d..d36ad5194040 100644 --- a/client/trino-jdbc/src/main/java/io/trino/jdbc/TrinoResultSet.java +++ b/client/trino-jdbc/src/main/java/io/trino/jdbc/TrinoResultSet.java @@ -13,6 +13,7 @@ */ package io.trino.jdbc; +import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.AbstractIterator; import com.google.common.collect.Streams; import com.google.common.util.concurrent.ThreadFactoryBuilder; @@ -29,9 +30,9 @@ import java.util.Optional; import java.util.concurrent.ArrayBlockingQueue; import java.util.concurrent.BlockingQueue; -import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; +import java.util.concurrent.Future; import java.util.concurrent.Semaphore; import java.util.function.Consumer; import java.util.stream.Stream; @@ -144,7 +145,8 @@ private static Iterator flatten(Iterator> iterator, long maxR return stream.iterator(); } - private static class AsyncIterator + @VisibleForTesting + static class AsyncIterator extends AbstractIterator { private static final int MAX_QUEUED_ROWS = 50_000; @@ -152,19 +154,30 @@ private static class AsyncIterator new ThreadFactoryBuilder().setNameFormat("Trino JDBC worker-%s").setDaemon(true).build()); private final StatementClient client; - private final BlockingQueue rowQueue = new ArrayBlockingQueue<>(MAX_QUEUED_ROWS); + private final BlockingQueue rowQueue; // Semaphore to indicate that some data is ready. // Each permit represents a row of data (or that the underlying iterator is exhausted). private final Semaphore semaphore = new Semaphore(0); - private final CompletableFuture future; + private final Future future; + private volatile boolean cancelled; + private volatile boolean finished; public AsyncIterator(Iterator dataIterator, StatementClient client) + { + this(dataIterator, client, Optional.empty()); + } + + @VisibleForTesting + AsyncIterator(Iterator dataIterator, StatementClient client, Optional> queue) { requireNonNull(dataIterator, "dataIterator is null"); this.client = client; - this.future = CompletableFuture.runAsync(() -> { + this.rowQueue = queue.orElseGet(() -> new ArrayBlockingQueue<>(MAX_QUEUED_ROWS)); + this.cancelled = false; + this.finished = false; + this.future = executorService.submit(() -> { try { - while (dataIterator.hasNext()) { + while (!cancelled && dataIterator.hasNext()) { rowQueue.put(dataIterator.next()); semaphore.release(); } @@ -174,22 +187,46 @@ public AsyncIterator(Iterator dataIterator, StatementClient client) } finally { semaphore.release(); + finished = true; } - }, executorService); + }); } public void cancel() { + cancelled = true; future.cancel(true); + cleanup(); } public void interrupt(InterruptedException e) { - client.close(); + cleanup(); Thread.currentThread().interrupt(); throw new RuntimeException(new SQLException("ResultSet thread was interrupted", e)); } + private void cleanup() + { + // When thread interruption is mis-handled by underlying implementation of `client`, the thread which + // is working for `future` may be blocked by `rowQueue.put` (`rowQueue` is full) and will never finish + // its work. It is necessary to close `client` and drain `rowQueue` to avoid such leaks. + client.close(); + rowQueue.clear(); + } + + @VisibleForTesting + Future getFuture() + { + return future; + } + + @VisibleForTesting + boolean isBackgroundThreadFinished() + { + return finished; + } + @Override protected T computeNext() { diff --git a/client/trino-jdbc/src/test/java/io/trino/jdbc/TestJdbcResultSet.java b/client/trino-jdbc/src/test/java/io/trino/jdbc/TestJdbcResultSet.java index 2685b1f81613..93ca43d2b0a8 100644 --- a/client/trino-jdbc/src/test/java/io/trino/jdbc/TestJdbcResultSet.java +++ b/client/trino-jdbc/src/test/java/io/trino/jdbc/TestJdbcResultSet.java @@ -24,6 +24,9 @@ import static java.lang.String.format; +/** + * An integration test for JDBC client interacting with Trino server. + */ public class TestJdbcResultSet extends BaseTestJdbcResultSet { diff --git a/client/trino-jdbc/src/test/java/io/trino/jdbc/TestTrinoResultSet.java b/client/trino-jdbc/src/test/java/io/trino/jdbc/TestTrinoResultSet.java new file mode 100644 index 000000000000..6975a09a98ac --- /dev/null +++ b/client/trino-jdbc/src/test/java/io/trino/jdbc/TestTrinoResultSet.java @@ -0,0 +1,272 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.jdbc; + +import com.google.common.collect.ImmutableList; +import io.trino.client.ClientSelectedRole; +import io.trino.client.QueryData; +import io.trino.client.QueryStatusInfo; +import io.trino.client.StatementClient; +import io.trino.client.StatementStats; +import org.testng.annotations.Test; + +import java.time.ZoneId; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.concurrent.ArrayBlockingQueue; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; + +import static org.testng.Assert.assertTrue; + +/** + * A unit test for {@link TrinoResultSet}. + * + * @see TestJdbcResultSet for an integration test. + */ +public class TestTrinoResultSet +{ + @Test(timeOut = 10000) + public void testIteratorCancelWhenQueueNotFull() + throws Exception + { + AtomicReference thread = new AtomicReference<>(); + CountDownLatch interruptedButSwallowedLatch = new CountDownLatch(1); + MockAsyncIterator>> iterator = new MockAsyncIterator<>( + new Iterator>>() + { + @Override + public boolean hasNext() + { + return true; + } + + @Override + public Iterable> next() + { + thread.compareAndSet(null, Thread.currentThread()); + try { + TimeUnit.MILLISECONDS.sleep(1000); + } + catch (InterruptedException e) { + interruptedButSwallowedLatch.countDown(); + } + return ImmutableList.of((ImmutableList.of(new Object()))); + } + }, + new ArrayBlockingQueue<>(100)); + + while (thread.get() == null || thread.get().getState() != Thread.State.TIMED_WAITING) { + // wait for thread being waiting + } + iterator.cancel(); + while (!iterator.getFuture().isDone() || !iterator.isBackgroundThreadFinished()) { + TimeUnit.MILLISECONDS.sleep(10); + } + boolean interruptedButSwallowed = interruptedButSwallowedLatch.await(5000, TimeUnit.MILLISECONDS); + assertTrue(interruptedButSwallowed); + } + + @Test(timeOut = 10000) + public void testIteratorCancelWhenQueueIsFull() + throws Exception + { + BlockingQueue>> queue = new ArrayBlockingQueue<>(1); + queue.put(ImmutableList.of()); + // queue is full at the beginning + AtomicReference thread = new AtomicReference<>(); + MockAsyncIterator>> iterator = new MockAsyncIterator<>( + new Iterator>>() + { + @Override + public boolean hasNext() + { + return true; + } + + @Override + public Iterable> next() + { + thread.compareAndSet(null, Thread.currentThread()); + return ImmutableList.of((ImmutableList.of(new Object()))); + } + }, + queue); + + while (thread.get() == null || thread.get().getState() != Thread.State.WAITING) { + // wait for thread being waiting (for queue being not full) + TimeUnit.MILLISECONDS.sleep(10); + } + iterator.cancel(); + while (!iterator.isBackgroundThreadFinished()) { + TimeUnit.MILLISECONDS.sleep(10); + } + } + + private static class MockAsyncIterator + extends TrinoResultSet.AsyncIterator + { + public MockAsyncIterator(Iterator dataIterator, BlockingQueue queue) + { + super( + dataIterator, + new StatementClient() + { + @Override + public String getQuery() + { + throw new UnsupportedOperationException(); + } + + @Override + public ZoneId getTimeZone() + { + throw new UnsupportedOperationException(); + } + + @Override + public boolean isRunning() + { + throw new UnsupportedOperationException(); + } + + @Override + public boolean isClientAborted() + { + throw new UnsupportedOperationException(); + } + + @Override + public boolean isClientError() + { + throw new UnsupportedOperationException(); + } + + @Override + public boolean isFinished() + { + throw new UnsupportedOperationException(); + } + + @Override + public StatementStats getStats() + { + throw new UnsupportedOperationException(); + } + + @Override + public QueryStatusInfo currentStatusInfo() + { + throw new UnsupportedOperationException(); + } + + @Override + public QueryData currentData() + { + throw new UnsupportedOperationException(); + } + + @Override + public QueryStatusInfo finalStatusInfo() + { + throw new UnsupportedOperationException(); + } + + @Override + public Optional getSetCatalog() + { + throw new UnsupportedOperationException(); + } + + @Override + public Optional getSetSchema() + { + throw new UnsupportedOperationException(); + } + + @Override + public Optional getSetPath() + { + throw new UnsupportedOperationException(); + } + + @Override + public Map getSetSessionProperties() + { + throw new UnsupportedOperationException(); + } + + @Override + public Set getResetSessionProperties() + { + throw new UnsupportedOperationException(); + } + + @Override + public Map getSetRoles() + { + throw new UnsupportedOperationException(); + } + + @Override + public Map getAddedPreparedStatements() + { + throw new UnsupportedOperationException(); + } + + @Override + public Set getDeallocatedPreparedStatements() + { + throw new UnsupportedOperationException(); + } + + @Override + public String getStartedTransactionId() + { + throw new UnsupportedOperationException(); + } + + @Override + public boolean isClearTransactionId() + { + throw new UnsupportedOperationException(); + } + + @Override + public boolean advance() + { + throw new UnsupportedOperationException(); + } + + @Override + public void cancelLeafStage() + { + throw new UnsupportedOperationException(); + } + + @Override + public void close() + { + // do nothing + } + }, + Optional.of(queue)); + } + } +}