From 2b63a621185372ad446fa9c9ae2cda42343867c0 Mon Sep 17 00:00:00 2001 From: xiacongling Date: Thu, 25 Aug 2022 01:05:34 +0800 Subject: [PATCH] Fix closing TrinoResultSet background thread ResultSet cannot be properly closed because the inner Thread cannot be interrupted and stop data row iteration. That will lead to thread and memory leaks on the client side. This patch uses FutureTask, which is created by ThreadPoolExecutor, instead of CompletableFuture to make sure `Thread.interrupt()` can be invoked as expected. And for the case that interruption is not properly handled by the underlying StatementClient, a status check is added to the loop condition so that loop can terminate and thread can be released. --- .../java/io/trino/jdbc/TrinoResultSet.java | 53 +++- .../java/io/trino/jdbc/TestJdbcResultSet.java | 3 + .../io/trino/jdbc/TestTrinoResultSet.java | 272 ++++++++++++++++++ 3 files changed, 320 insertions(+), 8 deletions(-) create mode 100644 client/trino-jdbc/src/test/java/io/trino/jdbc/TestTrinoResultSet.java diff --git a/client/trino-jdbc/src/main/java/io/trino/jdbc/TrinoResultSet.java b/client/trino-jdbc/src/main/java/io/trino/jdbc/TrinoResultSet.java index 35fac6a3e48d..d36ad5194040 100644 --- a/client/trino-jdbc/src/main/java/io/trino/jdbc/TrinoResultSet.java +++ b/client/trino-jdbc/src/main/java/io/trino/jdbc/TrinoResultSet.java @@ -13,6 +13,7 @@ */ package io.trino.jdbc; +import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.AbstractIterator; import com.google.common.collect.Streams; import com.google.common.util.concurrent.ThreadFactoryBuilder; @@ -29,9 +30,9 @@ import java.util.Optional; import java.util.concurrent.ArrayBlockingQueue; import java.util.concurrent.BlockingQueue; -import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; +import java.util.concurrent.Future; import java.util.concurrent.Semaphore; import java.util.function.Consumer; import java.util.stream.Stream; @@ -144,7 +145,8 @@ private static Iterator flatten(Iterator> iterator, long maxR return stream.iterator(); } - private static class AsyncIterator + @VisibleForTesting + static class AsyncIterator extends AbstractIterator { private static final int MAX_QUEUED_ROWS = 50_000; @@ -152,19 +154,30 @@ private static class AsyncIterator new ThreadFactoryBuilder().setNameFormat("Trino JDBC worker-%s").setDaemon(true).build()); private final StatementClient client; - private final BlockingQueue rowQueue = new ArrayBlockingQueue<>(MAX_QUEUED_ROWS); + private final BlockingQueue rowQueue; // Semaphore to indicate that some data is ready. // Each permit represents a row of data (or that the underlying iterator is exhausted). private final Semaphore semaphore = new Semaphore(0); - private final CompletableFuture future; + private final Future future; + private volatile boolean cancelled; + private volatile boolean finished; public AsyncIterator(Iterator dataIterator, StatementClient client) + { + this(dataIterator, client, Optional.empty()); + } + + @VisibleForTesting + AsyncIterator(Iterator dataIterator, StatementClient client, Optional> queue) { requireNonNull(dataIterator, "dataIterator is null"); this.client = client; - this.future = CompletableFuture.runAsync(() -> { + this.rowQueue = queue.orElseGet(() -> new ArrayBlockingQueue<>(MAX_QUEUED_ROWS)); + this.cancelled = false; + this.finished = false; + this.future = executorService.submit(() -> { try { - while (dataIterator.hasNext()) { + while (!cancelled && dataIterator.hasNext()) { rowQueue.put(dataIterator.next()); semaphore.release(); } @@ -174,22 +187,46 @@ public AsyncIterator(Iterator dataIterator, StatementClient client) } finally { semaphore.release(); + finished = true; } - }, executorService); + }); } public void cancel() { + cancelled = true; future.cancel(true); + cleanup(); } public void interrupt(InterruptedException e) { - client.close(); + cleanup(); Thread.currentThread().interrupt(); throw new RuntimeException(new SQLException("ResultSet thread was interrupted", e)); } + private void cleanup() + { + // When thread interruption is mis-handled by underlying implementation of `client`, the thread which + // is working for `future` may be blocked by `rowQueue.put` (`rowQueue` is full) and will never finish + // its work. It is necessary to close `client` and drain `rowQueue` to avoid such leaks. + client.close(); + rowQueue.clear(); + } + + @VisibleForTesting + Future getFuture() + { + return future; + } + + @VisibleForTesting + boolean isBackgroundThreadFinished() + { + return finished; + } + @Override protected T computeNext() { diff --git a/client/trino-jdbc/src/test/java/io/trino/jdbc/TestJdbcResultSet.java b/client/trino-jdbc/src/test/java/io/trino/jdbc/TestJdbcResultSet.java index 2685b1f81613..93ca43d2b0a8 100644 --- a/client/trino-jdbc/src/test/java/io/trino/jdbc/TestJdbcResultSet.java +++ b/client/trino-jdbc/src/test/java/io/trino/jdbc/TestJdbcResultSet.java @@ -24,6 +24,9 @@ import static java.lang.String.format; +/** + * An integration test for JDBC client interacting with Trino server. + */ public class TestJdbcResultSet extends BaseTestJdbcResultSet { diff --git a/client/trino-jdbc/src/test/java/io/trino/jdbc/TestTrinoResultSet.java b/client/trino-jdbc/src/test/java/io/trino/jdbc/TestTrinoResultSet.java new file mode 100644 index 000000000000..6975a09a98ac --- /dev/null +++ b/client/trino-jdbc/src/test/java/io/trino/jdbc/TestTrinoResultSet.java @@ -0,0 +1,272 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.jdbc; + +import com.google.common.collect.ImmutableList; +import io.trino.client.ClientSelectedRole; +import io.trino.client.QueryData; +import io.trino.client.QueryStatusInfo; +import io.trino.client.StatementClient; +import io.trino.client.StatementStats; +import org.testng.annotations.Test; + +import java.time.ZoneId; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.concurrent.ArrayBlockingQueue; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; + +import static org.testng.Assert.assertTrue; + +/** + * A unit test for {@link TrinoResultSet}. + * + * @see TestJdbcResultSet for an integration test. + */ +public class TestTrinoResultSet +{ + @Test(timeOut = 10000) + public void testIteratorCancelWhenQueueNotFull() + throws Exception + { + AtomicReference thread = new AtomicReference<>(); + CountDownLatch interruptedButSwallowedLatch = new CountDownLatch(1); + MockAsyncIterator>> iterator = new MockAsyncIterator<>( + new Iterator>>() + { + @Override + public boolean hasNext() + { + return true; + } + + @Override + public Iterable> next() + { + thread.compareAndSet(null, Thread.currentThread()); + try { + TimeUnit.MILLISECONDS.sleep(1000); + } + catch (InterruptedException e) { + interruptedButSwallowedLatch.countDown(); + } + return ImmutableList.of((ImmutableList.of(new Object()))); + } + }, + new ArrayBlockingQueue<>(100)); + + while (thread.get() == null || thread.get().getState() != Thread.State.TIMED_WAITING) { + // wait for thread being waiting + } + iterator.cancel(); + while (!iterator.getFuture().isDone() || !iterator.isBackgroundThreadFinished()) { + TimeUnit.MILLISECONDS.sleep(10); + } + boolean interruptedButSwallowed = interruptedButSwallowedLatch.await(5000, TimeUnit.MILLISECONDS); + assertTrue(interruptedButSwallowed); + } + + @Test(timeOut = 10000) + public void testIteratorCancelWhenQueueIsFull() + throws Exception + { + BlockingQueue>> queue = new ArrayBlockingQueue<>(1); + queue.put(ImmutableList.of()); + // queue is full at the beginning + AtomicReference thread = new AtomicReference<>(); + MockAsyncIterator>> iterator = new MockAsyncIterator<>( + new Iterator>>() + { + @Override + public boolean hasNext() + { + return true; + } + + @Override + public Iterable> next() + { + thread.compareAndSet(null, Thread.currentThread()); + return ImmutableList.of((ImmutableList.of(new Object()))); + } + }, + queue); + + while (thread.get() == null || thread.get().getState() != Thread.State.WAITING) { + // wait for thread being waiting (for queue being not full) + TimeUnit.MILLISECONDS.sleep(10); + } + iterator.cancel(); + while (!iterator.isBackgroundThreadFinished()) { + TimeUnit.MILLISECONDS.sleep(10); + } + } + + private static class MockAsyncIterator + extends TrinoResultSet.AsyncIterator + { + public MockAsyncIterator(Iterator dataIterator, BlockingQueue queue) + { + super( + dataIterator, + new StatementClient() + { + @Override + public String getQuery() + { + throw new UnsupportedOperationException(); + } + + @Override + public ZoneId getTimeZone() + { + throw new UnsupportedOperationException(); + } + + @Override + public boolean isRunning() + { + throw new UnsupportedOperationException(); + } + + @Override + public boolean isClientAborted() + { + throw new UnsupportedOperationException(); + } + + @Override + public boolean isClientError() + { + throw new UnsupportedOperationException(); + } + + @Override + public boolean isFinished() + { + throw new UnsupportedOperationException(); + } + + @Override + public StatementStats getStats() + { + throw new UnsupportedOperationException(); + } + + @Override + public QueryStatusInfo currentStatusInfo() + { + throw new UnsupportedOperationException(); + } + + @Override + public QueryData currentData() + { + throw new UnsupportedOperationException(); + } + + @Override + public QueryStatusInfo finalStatusInfo() + { + throw new UnsupportedOperationException(); + } + + @Override + public Optional getSetCatalog() + { + throw new UnsupportedOperationException(); + } + + @Override + public Optional getSetSchema() + { + throw new UnsupportedOperationException(); + } + + @Override + public Optional getSetPath() + { + throw new UnsupportedOperationException(); + } + + @Override + public Map getSetSessionProperties() + { + throw new UnsupportedOperationException(); + } + + @Override + public Set getResetSessionProperties() + { + throw new UnsupportedOperationException(); + } + + @Override + public Map getSetRoles() + { + throw new UnsupportedOperationException(); + } + + @Override + public Map getAddedPreparedStatements() + { + throw new UnsupportedOperationException(); + } + + @Override + public Set getDeallocatedPreparedStatements() + { + throw new UnsupportedOperationException(); + } + + @Override + public String getStartedTransactionId() + { + throw new UnsupportedOperationException(); + } + + @Override + public boolean isClearTransactionId() + { + throw new UnsupportedOperationException(); + } + + @Override + public boolean advance() + { + throw new UnsupportedOperationException(); + } + + @Override + public void cancelLeafStage() + { + throw new UnsupportedOperationException(); + } + + @Override + public void close() + { + // do nothing + } + }, + Optional.of(queue)); + } + } +}