diff --git a/core/trino-main/src/main/java/io/trino/execution/QueryManagerConfig.java b/core/trino-main/src/main/java/io/trino/execution/QueryManagerConfig.java index f977a365ff89..c48c58c58246 100644 --- a/core/trino-main/src/main/java/io/trino/execution/QueryManagerConfig.java +++ b/core/trino-main/src/main/java/io/trino/execution/QueryManagerConfig.java @@ -77,6 +77,7 @@ public class QueryManagerConfig private int queryManagerExecutorPoolSize = 5; private int queryExecutorPoolSize = 1000; private int maxStateMachineCallbackThreads = 5; + private int maxSplitManagerCallbackThreads = 100; /** * default value is overwritten for fault tolerant execution in {@link #applyFaultTolerantExecutionDefaults()} @@ -394,6 +395,20 @@ public QueryManagerConfig setMaxStateMachineCallbackThreads(int maxStateMachineC return this; } + @Min(1) + public int getMaxSplitManagerCallbackThreads() + { + return maxSplitManagerCallbackThreads; + } + + @Config("query.max-split-manager-callback-threads") + @ConfigDescription("The maximum number of threads allowed to run splits generation callbacks concurrently") + public QueryManagerConfig setMaxSplitManagerCallbackThreads(int maxSplitManagerCallbackThreads) + { + this.maxSplitManagerCallbackThreads = maxSplitManagerCallbackThreads; + return this; + } + @NotNull @MinDuration("1s") public Duration getRemoteTaskMaxErrorDuration() diff --git a/core/trino-main/src/main/java/io/trino/split/BufferingSplitSource.java b/core/trino-main/src/main/java/io/trino/split/BufferingSplitSource.java index 084d6722fa99..da721ce70660 100644 --- a/core/trino-main/src/main/java/io/trino/split/BufferingSplitSource.java +++ b/core/trino-main/src/main/java/io/trino/split/BufferingSplitSource.java @@ -13,8 +13,11 @@ */ package io.trino.split; -import com.google.common.util.concurrent.Futures; +import com.google.common.collect.ImmutableList; +import com.google.common.util.concurrent.AbstractFuture; +import com.google.common.util.concurrent.FutureCallback; import com.google.common.util.concurrent.ListenableFuture; +import com.google.errorprone.annotations.concurrent.GuardedBy; import io.opentelemetry.context.Context; import io.trino.metadata.Split; import io.trino.spi.connector.CatalogHandle; @@ -22,9 +25,11 @@ import java.util.ArrayList; import java.util.List; import java.util.Optional; +import java.util.concurrent.Executor; import static com.google.common.base.Preconditions.checkArgument; -import static com.google.common.util.concurrent.Futures.immediateVoidFuture; +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.util.concurrent.Futures.addCallback; import static com.google.common.util.concurrent.MoreExecutors.directExecutor; import static java.util.Objects.requireNonNull; @@ -33,10 +38,12 @@ public class BufferingSplitSource { private final int bufferSize; private final SplitSource source; + private final Executor executor; - public BufferingSplitSource(SplitSource source, int bufferSize) + public BufferingSplitSource(SplitSource source, Executor executor, int bufferSize) { this.source = requireNonNull(source, "source is null"); + this.executor = requireNonNull(executor, "executor is null"); this.bufferSize = bufferSize; } @@ -50,7 +57,7 @@ public CatalogHandle getCatalogHandle() public ListenableFuture getNextBatch(int maxSize) { checkArgument(maxSize > 0, "Cannot fetch a batch of zero size"); - return GetNextBatch.fetchNextBatchAsync(source, Math.min(bufferSize, maxSize), maxSize); + return GetNextBatch.fetchNextBatchAsync(source, executor, Math.min(bufferSize, maxSize), maxSize); } @Override @@ -72,50 +79,108 @@ public Optional> getTableExecuteSplitsInfo() } private static class GetNextBatch + extends AbstractFuture { private final Context context = Context.current(); private final SplitSource splitSource; + private final Executor executor; private final int min; private final int max; + @GuardedBy("this") private final List splits = new ArrayList<>(); - private boolean noMoreSplits; + @GuardedBy("this") + private ListenableFuture nextBatchFuture; public static ListenableFuture fetchNextBatchAsync( SplitSource splitSource, + Executor executor, int min, int max) { - GetNextBatch getNextBatch = new GetNextBatch(splitSource, min, max); - ListenableFuture future = getNextBatch.fetchSplits(); - return Futures.transform(future, ignored -> new SplitBatch(getNextBatch.splits, getNextBatch.noMoreSplits), directExecutor()); + GetNextBatch getNextBatch = new GetNextBatch(splitSource, executor, min, max); + getNextBatch.fetchSplits(); + return getNextBatch; } - private GetNextBatch(SplitSource splitSource, int min, int max) + private GetNextBatch(SplitSource splitSource, Executor executor, int min, int max) { this.splitSource = requireNonNull(splitSource, "splitSource is null"); + this.executor = requireNonNull(executor, "executor is null"); checkArgument(min <= max, "Min splits greater than max splits"); this.min = min; this.max = max; } - private ListenableFuture fetchSplits() + private synchronized void fetchSplits() { - if (splits.size() >= min) { - return immediateVoidFuture(); - } - ListenableFuture future; + checkState(nextBatchFuture == null || nextBatchFuture.isDone(), "nextBatchFuture is expected to be done"); + try (var ignored = context.makeCurrent()) { - future = splitSource.getNextBatch(max - splits.size()); - } - return Futures.transformAsync(future, splitBatch -> { - splits.addAll(splitBatch.getSplits()); - if (splitBatch.isLastBatch()) { - noMoreSplits = true; - return immediateVoidFuture(); + nextBatchFuture = splitSource.getNextBatch(max - splits.size()); + // If the split source returns completed futures, we process them on + // directExecutor without chaining to avoid the overhead of going through separate executor + while (nextBatchFuture.isDone()) { + addCallback( + nextBatchFuture, + new FutureCallback<>() + { + @Override + public void onSuccess(SplitBatch splitBatch) + { + processBatch(splitBatch); + } + + @Override + public void onFailure(Throwable throwable) + { + setException(throwable); + } + }, + directExecutor()); + if (isDone()) { + return; + } + nextBatchFuture = splitSource.getNextBatch(max - splits.size()); } - return fetchSplits(); - }, directExecutor()); + } + + addCallback( + nextBatchFuture, + new FutureCallback<>() + { + @Override + public void onSuccess(SplitBatch splitBatch) + { + synchronized (GetNextBatch.this) { + if (processBatch(splitBatch)) { + return; + } + fetchSplits(); + } + } + + @Override + public void onFailure(Throwable throwable) + { + setException(throwable); + } + }, + executor); + } + + // Accumulates splits from the returned batch and returns whether + // sufficient splits have been buffered to satisfy min batch size + private synchronized boolean processBatch(SplitBatch splitBatch) + { + splits.addAll(splitBatch.getSplits()); + boolean isLastBatch = splitBatch.isLastBatch(); + if (splits.size() >= min || isLastBatch) { + set(new SplitBatch(ImmutableList.copyOf(splits), isLastBatch)); + splits.clear(); + return true; + } + return false; } } } diff --git a/core/trino-main/src/main/java/io/trino/split/SplitManager.java b/core/trino-main/src/main/java/io/trino/split/SplitManager.java index b6bc7a9c1544..70b388235b2c 100644 --- a/core/trino-main/src/main/java/io/trino/split/SplitManager.java +++ b/core/trino-main/src/main/java/io/trino/split/SplitManager.java @@ -14,6 +14,7 @@ package io.trino.split; import com.google.inject.Inject; +import io.airlift.concurrent.BoundedExecutor; import io.opentelemetry.api.trace.Span; import io.opentelemetry.api.trace.Tracer; import io.opentelemetry.context.Context; @@ -31,15 +32,19 @@ import io.trino.tracing.TrinoAttributes; import java.util.Optional; +import java.util.concurrent.Executor; +import static io.airlift.concurrent.Threads.daemonThreadsNamed; import static io.trino.SystemSessionProperties.isAllowPushdownIntoConnectors; import static java.util.Objects.requireNonNull; +import static java.util.concurrent.Executors.newCachedThreadPool; public class SplitManager { private final CatalogServiceProvider splitManagerProvider; private final Tracer tracer; private final int minScheduleSplitBatchSize; + private final Executor executor; @Inject public SplitManager(CatalogServiceProvider splitManagerProvider, Tracer tracer, QueryManagerConfig config) @@ -47,6 +52,7 @@ public SplitManager(CatalogServiceProvider splitManagerPr this.splitManagerProvider = requireNonNull(splitManagerProvider, "splitManagerProvider is null"); this.tracer = requireNonNull(tracer, "tracer is null"); this.minScheduleSplitBatchSize = config.getMinScheduleSplitBatchSize(); + this.executor = new BoundedExecutor(newCachedThreadPool(daemonThreadsNamed("splits-manager-callback-%s")), config.getMaxSplitManagerCallbackThreads()); } public SplitSource getSplits( @@ -77,7 +83,7 @@ public SplitSource getSplits( if (minScheduleSplitBatchSize > 1) { splitSource = new TracingSplitSource(splitSource, tracer, Optional.empty(), "split-batch"); - splitSource = new BufferingSplitSource(splitSource, minScheduleSplitBatchSize); + splitSource = new BufferingSplitSource(splitSource, executor, minScheduleSplitBatchSize); splitSource = new TracingSplitSource(splitSource, tracer, Optional.of(span), "split-buffer"); } else { diff --git a/core/trino-main/src/test/java/io/trino/execution/TestQueryManagerConfig.java b/core/trino-main/src/test/java/io/trino/execution/TestQueryManagerConfig.java index b0958cb8db6c..dc62cc9f36ee 100644 --- a/core/trino-main/src/test/java/io/trino/execution/TestQueryManagerConfig.java +++ b/core/trino-main/src/test/java/io/trino/execution/TestQueryManagerConfig.java @@ -57,6 +57,7 @@ public void testDefaults() .setQueryManagerExecutorPoolSize(5) .setQueryExecutorPoolSize(1000) .setMaxStateMachineCallbackThreads(5) + .setMaxSplitManagerCallbackThreads(100) .setRemoteTaskMaxErrorDuration(new Duration(5, MINUTES)) .setRemoteTaskMaxCallbackThreads(1000) .setQueryExecutionPolicy("phased") @@ -132,6 +133,7 @@ public void testExplicitPropertyMappings() .put("query.manager-executor-pool-size", "11") .put("query.executor-pool-size", "111") .put("query.max-state-machine-callback-threads", "112") + .put("query.max-split-manager-callback-threads", "113") .put("query.remote-task.max-error-duration", "60s") .put("query.remote-task.max-callback-threads", "10") .put("query.execution-policy", "foo-bar-execution-policy") @@ -204,6 +206,7 @@ public void testExplicitPropertyMappings() .setQueryManagerExecutorPoolSize(11) .setQueryExecutorPoolSize(111) .setMaxStateMachineCallbackThreads(112) + .setMaxSplitManagerCallbackThreads(113) .setRemoteTaskMaxErrorDuration(new Duration(60, SECONDS)) .setRemoteTaskMaxCallbackThreads(10) .setQueryExecutionPolicy("foo-bar-execution-policy") diff --git a/core/trino-main/src/test/java/io/trino/split/MockSplitSource.java b/core/trino-main/src/test/java/io/trino/split/MockSplitSource.java index c5cdd6e46251..94078c43ce32 100644 --- a/core/trino-main/src/test/java/io/trino/split/MockSplitSource.java +++ b/core/trino-main/src/test/java/io/trino/split/MockSplitSource.java @@ -17,7 +17,7 @@ import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.SettableFuture; -import io.trino.annotation.NotThreadSafe; +import com.google.errorprone.annotations.ThreadSafe; import io.trino.metadata.Split; import io.trino.spi.connector.CatalogHandle; import io.trino.spi.connector.ConnectorSplit; @@ -33,7 +33,7 @@ import static io.trino.split.MockSplitSource.Action.FINISH; import static io.trino.testing.TestingHandles.TEST_CATALOG_HANDLE; -@NotThreadSafe +@ThreadSafe public class MockSplitSource implements SplitSource { @@ -58,14 +58,14 @@ public MockSplitSource() { } - public MockSplitSource setBatchSize(int batchSize) + public synchronized MockSplitSource setBatchSize(int batchSize) { checkArgument(atSplitDepletion == DO_NOTHING, "cannot modify batch size once split completion action is set"); this.batchSize = batchSize; return this; } - public MockSplitSource increaseAvailableSplits(int count) + public synchronized MockSplitSource increaseAvailableSplits(int count) { checkArgument(atSplitDepletion == DO_NOTHING, "cannot increase available splits once split completion action is set"); totalSplits += count; @@ -73,7 +73,7 @@ public MockSplitSource increaseAvailableSplits(int count) return this; } - public MockSplitSource atSplitCompletion(Action action) + public synchronized MockSplitSource atSplitCompletion(Action action) { atSplitDepletion = action; doGetNextBatch(); @@ -86,9 +86,13 @@ public CatalogHandle getCatalogHandle() throw new UnsupportedOperationException(); } - private void doGetNextBatch() + private synchronized void doGetNextBatch() { checkState(splitsProduced <= totalSplits); + if (nextBatchFuture.isDone()) { + // if nextBatchFuture is already done, we need to wait until new future is created through getNextBatch to produce splits + return; + } if (splitsProduced == totalSplits) { switch (atSplitDepletion) { case FAIL: @@ -111,7 +115,7 @@ private void doGetNextBatch() } @Override - public ListenableFuture getNextBatch(int maxSize) + public synchronized ListenableFuture getNextBatch(int maxSize) { checkState(nextBatchFuture.isDone(), "concurrent getNextBatch invocation"); nextBatchFuture = SettableFuture.create(); @@ -128,7 +132,7 @@ public void close() } @Override - public boolean isFinished() + public synchronized boolean isFinished() { return splitsProduced == totalSplits && atSplitDepletion == FINISH; } @@ -139,7 +143,7 @@ public Optional> getTableExecuteSplitsInfo() return Optional.empty(); } - public int getNextBatchInvocationCount() + public synchronized int getNextBatchInvocationCount() { return nextBatchInvocationCount; } diff --git a/core/trino-main/src/test/java/io/trino/split/TestBufferingSplitSource.java b/core/trino-main/src/test/java/io/trino/split/TestBufferingSplitSource.java index fc53a4ac1fbd..9356c5d4447e 100644 --- a/core/trino-main/src/test/java/io/trino/split/TestBufferingSplitSource.java +++ b/core/trino-main/src/test/java/io/trino/split/TestBufferingSplitSource.java @@ -15,16 +15,19 @@ import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.ListenableFuture; +import io.airlift.concurrent.BoundedExecutor; import io.trino.split.SplitSource.SplitBatch; import org.junit.jupiter.api.Test; -import java.util.concurrent.Future; +import java.util.concurrent.Executor; import static com.google.common.util.concurrent.MoreExecutors.directExecutor; -import static io.airlift.concurrent.MoreFutures.tryGetFutureValue; +import static io.airlift.concurrent.MoreFutures.getFutureValue; +import static io.airlift.concurrent.Threads.daemonThreadsNamed; import static io.trino.split.MockSplitSource.Action.FAIL; import static io.trino.split.MockSplitSource.Action.FINISH; import static java.util.Objects.requireNonNull; +import static java.util.concurrent.Executors.newCachedThreadPool; import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertFalse; @@ -32,6 +35,8 @@ public class TestBufferingSplitSource { + private static final Executor executor = new BoundedExecutor(newCachedThreadPool(daemonThreadsNamed(TestBufferingSplitSource.class.getSimpleName() + "-%s")), 10); + @Test public void testSlowSource() { @@ -39,14 +44,14 @@ public void testSlowSource() .setBatchSize(1) .increaseAvailableSplits(25) .atSplitCompletion(FINISH); - try (SplitSource source = new BufferingSplitSource(mockSource, 10)) { - requireFutureValue(getNextBatch(source, 20)) + try (SplitSource source = new BufferingSplitSource(mockSource, executor, 10)) { + getFutureValue(getNextBatch(source, 20)) .assertSize(10) .assertNoMoreSplits(false); - requireFutureValue(getNextBatch(source, 6)) + getFutureValue(getNextBatch(source, 6)) .assertSize(6) .assertNoMoreSplits(false); - requireFutureValue(getNextBatch(source, 20)) + getFutureValue(getNextBatch(source, 20)) .assertSize(9) .assertNoMoreSplits(true); assertTrue(source.isFinished()); @@ -61,11 +66,11 @@ public void testFastSource() .setBatchSize(11) .increaseAvailableSplits(22) .atSplitCompletion(FINISH); - try (SplitSource source = new BufferingSplitSource(mockSource, 10)) { - requireFutureValue(getNextBatch(source, 200)) + try (SplitSource source = new BufferingSplitSource(mockSource, executor, 10)) { + getFutureValue(getNextBatch(source, 200)) .assertSize(11) .assertNoMoreSplits(false); - requireFutureValue(getNextBatch(source, 200)) + getFutureValue(getNextBatch(source, 200)) .assertSize(11) .assertNoMoreSplits(true); assertTrue(source.isFinished()); @@ -73,14 +78,29 @@ public void testFastSource() } } + @Test + public void testNoStackOverFlow() + { + MockSplitSource mockSource = new MockSplitSource() + .setBatchSize(1) + .increaseAvailableSplits(10000) + .atSplitCompletion(FINISH); + try (SplitSource source = new BufferingSplitSource(mockSource, executor, Integer.MAX_VALUE)) { + while (!source.isFinished()) { + getFutureValue(getNextBatch(source, 1000)) + .assertSize(1000); + } + } + } + @Test public void testEmptySource() { MockSplitSource mockSource = new MockSplitSource() .setBatchSize(1) .atSplitCompletion(FINISH); - try (SplitSource source = new BufferingSplitSource(mockSource, 100)) { - requireFutureValue(getNextBatch(source, 200)) + try (SplitSource source = new BufferingSplitSource(mockSource, executor, 100)) { + getFutureValue(getNextBatch(source, 200)) .assertSize(0) .assertNoMoreSplits(true); assertTrue(source.isFinished()); @@ -93,14 +113,14 @@ public void testBlocked() { MockSplitSource mockSource = new MockSplitSource() .setBatchSize(1); - try (SplitSource source = new BufferingSplitSource(mockSource, 10)) { + try (SplitSource source = new BufferingSplitSource(mockSource, executor, 10)) { // Source has 0 out of 10 needed. ListenableFuture nextBatchFuture = getNextBatch(source, 10); assertFalse(nextBatchFuture.isDone()); mockSource.increaseAvailableSplits(9); assertFalse(nextBatchFuture.isDone()); mockSource.increaseAvailableSplits(1); - requireFutureValue(nextBatchFuture) + getFutureValue(nextBatchFuture) .assertSize(10) .assertNoMoreSplits(false); @@ -108,7 +128,7 @@ public void testBlocked() nextBatchFuture = getNextBatch(source, 10); assertFalse(nextBatchFuture.isDone()); mockSource.atSplitCompletion(FINISH); - requireFutureValue(nextBatchFuture) + getFutureValue(nextBatchFuture) .assertSize(0) .assertNoMoreSplits(true); assertTrue(source.isFinished()); @@ -116,13 +136,13 @@ public void testBlocked() mockSource = new MockSplitSource() .setBatchSize(1); - try (SplitSource source = new BufferingSplitSource(mockSource, 10)) { + try (SplitSource source = new BufferingSplitSource(mockSource, executor, 10)) { // Source has 1 out of 10 needed. mockSource.increaseAvailableSplits(1); ListenableFuture nextBatchFuture = getNextBatch(source, 10); assertFalse(nextBatchFuture.isDone()); mockSource.increaseAvailableSplits(9); - requireFutureValue(nextBatchFuture) + getFutureValue(nextBatchFuture) .assertSize(10) .assertNoMoreSplits(false); @@ -131,7 +151,7 @@ public void testBlocked() mockSource.increaseAvailableSplits(5); assertFalse(nextBatchFuture.isDone()); mockSource.atSplitCompletion(FINISH); - requireFutureValue(nextBatchFuture) + getFutureValue(nextBatchFuture) .assertSize(5) .assertNoMoreSplits(true); assertTrue(source.isFinished()); @@ -139,13 +159,13 @@ public void testBlocked() mockSource = new MockSplitSource() .setBatchSize(1); - try (SplitSource source = new BufferingSplitSource(mockSource, 10)) { + try (SplitSource source = new BufferingSplitSource(mockSource, executor, 10)) { // Source has 9 out of 10 needed. mockSource.increaseAvailableSplits(9); ListenableFuture nextBatchFuture = getNextBatch(source, 10); assertFalse(nextBatchFuture.isDone()); mockSource.increaseAvailableSplits(1); - requireFutureValue(nextBatchFuture) + getFutureValue(nextBatchFuture) .assertSize(10) .assertNoMoreSplits(false); @@ -161,12 +181,12 @@ public void testBlocked() // Fast source: source produce 8 before, and 8 after invocation. BufferedSource should return all 16 at once. mockSource = new MockSplitSource() .setBatchSize(8); - try (SplitSource source = new BufferingSplitSource(mockSource, 10)) { + try (SplitSource source = new BufferingSplitSource(mockSource, executor, 10)) { mockSource.increaseAvailableSplits(8); ListenableFuture nextBatchFuture = getNextBatch(source, 20); assertFalse(nextBatchFuture.isDone()); mockSource.increaseAvailableSplits(8); - requireFutureValue(nextBatchFuture) + getFutureValue(nextBatchFuture) .assertSize(16) .assertNoMoreSplits(false); } @@ -178,8 +198,8 @@ public void testFinishedSetWithoutIndicationFromSplitBatch() MockSplitSource mockSource = new MockSplitSource() .setBatchSize(1) .increaseAvailableSplits(1); - try (SplitSource source = new BufferingSplitSource(mockSource, 100)) { - requireFutureValue(getNextBatch(source, 1)) + try (SplitSource source = new BufferingSplitSource(mockSource, executor, 100)) { + getFutureValue(getNextBatch(source, 1)) .assertSize(1) .assertNoMoreSplits(false); assertFalse(source.isFinished()); @@ -189,7 +209,7 @@ public void testFinishedSetWithoutIndicationFromSplitBatch() // In this case, the preceding getNextBatch() indicates the noMoreSplits is false, // but the next isFinished call will return true. mockSource.atSplitCompletion(FINISH); - requireFutureValue(getNextBatch(source, 1)) + getFutureValue(getNextBatch(source, 1)) .assertSize(0) .assertNoMoreSplits(true); assertTrue(source.isFinished()); @@ -203,7 +223,7 @@ public void testFailImmediate() MockSplitSource mockSource = new MockSplitSource() .setBatchSize(1) .atSplitCompletion(FAIL); - try (SplitSource source = new BufferingSplitSource(mockSource, 100)) { + try (SplitSource source = new BufferingSplitSource(mockSource, executor, 100)) { assertFutureFailsWithMockFailure(getNextBatch(source, 200)); assertEquals(mockSource.getNextBatchInvocationCount(), 1); } @@ -216,7 +236,7 @@ public void testFail() .setBatchSize(1) .increaseAvailableSplits(1) .atSplitCompletion(FAIL); - try (SplitSource source = new BufferingSplitSource(mockSource, 100)) { + try (SplitSource source = new BufferingSplitSource(mockSource, executor, 100)) { assertFutureFailsWithMockFailure(getNextBatch(source, 2)); assertEquals(mockSource.getNextBatchInvocationCount(), 2); } @@ -224,16 +244,10 @@ public void testFail() private static void assertFutureFailsWithMockFailure(ListenableFuture future) { - assertTrue(future.isDone()); assertThatThrownBy(future::get) .hasMessageContaining("Mock failure"); } - private static T requireFutureValue(Future future) - { - return tryGetFutureValue(future).orElseThrow(AssertionError::new); - } - private static ListenableFuture getNextBatch(SplitSource splitSource, int maxSize) { ListenableFuture future = splitSource.getNextBatch(maxSize);