diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/FaultTolerantStageScheduler.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/FaultTolerantStageScheduler.java
index f637e20ad618..397b6740c52a 100644
--- a/core/trino-main/src/main/java/io/trino/execution/scheduler/FaultTolerantStageScheduler.java
+++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/FaultTolerantStageScheduler.java
@@ -23,6 +23,7 @@
 import com.google.common.collect.ImmutableSet;
 import com.google.common.collect.Multimap;
 import com.google.common.collect.Ordering;
+import com.google.common.util.concurrent.Futures;
 import com.google.common.util.concurrent.ListenableFuture;
 import com.google.common.util.concurrent.SettableFuture;
 import io.airlift.concurrent.MoreFutures;
@@ -80,6 +81,7 @@
 import static com.google.common.util.concurrent.Futures.allAsList;
 import static com.google.common.util.concurrent.Futures.immediateVoidFuture;
 import static com.google.common.util.concurrent.Futures.nonCancellationPropagating;
+import static com.google.common.util.concurrent.MoreExecutors.directExecutor;
 import static io.airlift.concurrent.MoreFutures.asVoid;
 import static io.airlift.concurrent.MoreFutures.getFutureValue;
 import static io.airlift.concurrent.MoreFutures.toListenableFuture;
@@ -285,18 +287,29 @@ public synchronized void schedule()
         while (!pendingPartitions.isEmpty() || !queuedPartitions.isEmpty() || !taskSource.isFinished()) {
             while (queuedPartitions.isEmpty() && pendingPartitions.size() < maxTasksWaitingForNodePerStage && !taskSource.isFinished()) {
-                List<TaskDescriptor> tasks = taskSource.getMoreTasks();
-                for (TaskDescriptor task : tasks) {
-                    queuedPartitions.add(task.getPartitionId());
-                    allPartitions.add(task.getPartitionId());
-                    taskDescriptorStorage.put(stage.getStageId(), task);
-                    sinkExchange.ifPresent(exchange -> {
-                        ExchangeSinkHandle exchangeSinkHandle = exchange.addSink(task.getPartitionId());
-                        partitionToExchangeSinkHandleMap.put(task.getPartitionId(), exchangeSinkHandle);
-                    });
-                }
-                if (taskSource.isFinished()) {
-                    sinkExchange.ifPresent(Exchange::noMoreSinks);
+                ListenableFuture<Void> tasksPopulatedFuture = Futures.transform(
+                        taskSource.getMoreTasks(),
+                        tasks -> {
+                            synchronized (this) {
+                                for (TaskDescriptor task : tasks) {
+                                    queuedPartitions.add(task.getPartitionId());
+                                    allPartitions.add(task.getPartitionId());
+                                    taskDescriptorStorage.put(stage.getStageId(), task);
+                                    sinkExchange.ifPresent(exchange -> {
+                                        ExchangeSinkHandle exchangeSinkHandle = exchange.addSink(task.getPartitionId());
+                                        partitionToExchangeSinkHandleMap.put(task.getPartitionId(), exchangeSinkHandle);
+                                    });
+                                }
+                                if (taskSource.isFinished()) {
+                                    sinkExchange.ifPresent(Exchange::noMoreSinks);
+                                }
+                                return null;
+                            }
+                        },
+                        directExecutor());
+                if (!tasksPopulatedFuture.isDone()) {
+                    blocked = tasksPopulatedFuture;
+                    return;
                 }
             }
diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/StageTaskSourceFactory.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/StageTaskSourceFactory.java
index 9c5d78d39de7..7fe09fcec382 100644
--- a/core/trino-main/src/main/java/io/trino/execution/scheduler/StageTaskSourceFactory.java
+++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/StageTaskSourceFactory.java
@@ -26,11 +26,15 @@
 import com.google.common.collect.Multimaps;
 import com.google.common.collect.SetMultimap;
 import com.google.common.collect.Sets;
+import com.google.common.util.concurrent.AbstractFuture;
+import com.google.common.util.concurrent.FutureCallback;
+import
com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.ListenableFuture; import io.airlift.log.Logger; import io.airlift.units.DataSize; import io.trino.Session; import io.trino.connector.CatalogName; +import io.trino.execution.ForQueryExecution; import io.trino.execution.Lifespan; import io.trino.execution.QueryManagerConfig; import io.trino.execution.TableExecuteContext; @@ -52,6 +56,7 @@ import io.trino.sql.planner.plan.PlanNodeId; import io.trino.sql.planner.plan.RemoteSourceNode; +import javax.annotation.concurrent.GuardedBy; import javax.inject.Inject; import java.util.ArrayList; @@ -62,6 +67,8 @@ import java.util.Map; import java.util.Optional; import java.util.Set; +import java.util.concurrent.Executor; +import java.util.concurrent.ExecutorService; import java.util.function.LongConsumer; import static com.google.common.base.Preconditions.checkArgument; @@ -71,8 +78,10 @@ import static com.google.common.collect.Iterables.getOnlyElement; import static com.google.common.collect.Sets.newIdentityHashSet; import static com.google.common.collect.Sets.union; +import static com.google.common.util.concurrent.Futures.addCallback; +import static com.google.common.util.concurrent.Futures.allAsList; +import static com.google.common.util.concurrent.Futures.immediateFuture; import static io.airlift.concurrent.MoreFutures.addSuccessCallback; -import static io.airlift.concurrent.MoreFutures.getFutureValue; import static io.trino.SystemSessionProperties.getFaultTolerantExecutionDefaultTaskMemory; import static io.trino.SystemSessionProperties.getFaultTolerantExecutionMaxTaskSplitCount; import static io.trino.SystemSessionProperties.getFaultTolerantExecutionMinTaskSplitCount; @@ -97,27 +106,32 @@ public class StageTaskSourceFactory private final SplitSourceFactory splitSourceFactory; private final TableExecuteContextManager tableExecuteContextManager; private final int splitBatchSize; + private final Executor executor; @Inject public StageTaskSourceFactory( SplitSourceFactory splitSourceFactory, TableExecuteContextManager tableExecuteContextManager, - QueryManagerConfig queryManagerConfig) + QueryManagerConfig queryManagerConfig, + @ForQueryExecution ExecutorService executor) { this( splitSourceFactory, tableExecuteContextManager, - requireNonNull(queryManagerConfig, "queryManagerConfig is null").getScheduleSplitBatchSize()); + requireNonNull(queryManagerConfig, "queryManagerConfig is null").getScheduleSplitBatchSize(), + executor); } public StageTaskSourceFactory( SplitSourceFactory splitSourceFactory, TableExecuteContextManager tableExecuteContextManager, - int splitBatchSize) + int splitBatchSize, + ExecutorService executor) { this.splitSourceFactory = requireNonNull(splitSourceFactory, "splitSourceFactory is null"); this.tableExecuteContextManager = requireNonNull(tableExecuteContextManager, "tableExecuteContextManager is null"); this.splitBatchSize = splitBatchSize; + this.executor = requireNonNull(executor, "executor is null"); } @Override @@ -155,7 +169,8 @@ public TaskSource create( bucketToPartitionMap.orElseThrow(() -> new IllegalArgumentException("bucketToPartitionMap is expected to be present for hash distributed stages")), bucketNodeMap, getFaultTolerantExecutionTargetTaskSplitCount(session) * SplitWeight.standard().getRawValue(), - getFaultTolerantExecutionTargetTaskInputSize(session)); + getFaultTolerantExecutionTargetTaskInputSize(session), + executor); } if (partitioning.equals(SOURCE_DISTRIBUTION)) { return SourceDistributionTaskSource.create( @@ -168,7 
+183,8 @@ public TaskSource create( getSplitTimeRecorder, getFaultTolerantExecutionMinTaskSplitCount(session), getFaultTolerantExecutionTargetTaskSplitCount(session) * SplitWeight.standard().getRawValue(), - getFaultTolerantExecutionMaxTaskSplitCount(session)); + getFaultTolerantExecutionMaxTaskSplitCount(session), + executor); } // other partitioning handles are not expected to be set as a fragment partitioning @@ -196,15 +212,18 @@ public SingleDistributionTaskSource(ListMultimap getMoreTasks() + public ListenableFuture> getMoreTasks() { + if (finished) { + return immediateFuture(ImmutableList.of()); + } List result = ImmutableList.of(new TaskDescriptor( 0, ImmutableListMultimap.of(), exchangeSourceHandles, new NodeRequirements(Optional.empty(), ImmutableSet.of(), taskMemory))); finished = true; - return result; + return immediateFuture(result); } @Override @@ -273,8 +292,11 @@ public ArbitraryDistributionTaskSource( } @Override - public List getMoreTasks() + public ListenableFuture> getMoreTasks() { + if (finished) { + return immediateFuture(ImmutableList.of()); + } NodeRequirements nodeRequirements = new NodeRequirements(Optional.empty(), ImmutableSet.of(), taskMemory); ImmutableList.Builder result = ImmutableList.builder(); @@ -322,7 +344,7 @@ public List getMoreTasks() } finished = true; - return result.build(); + return immediateFuture(result.build()); } @Override @@ -353,8 +375,13 @@ public static class HashDistributionTaskSource private final Optional catalogRequirement; private final long targetPartitionSourceSizeInBytes; // compared data read from ExchangeSources private final long targetPartitionSplitWeight; // compared against splits from SplitSources + private final Executor executor; + @GuardedBy("this") + private ListenableFuture> loadedSplitsFuture; + @GuardedBy("this") private boolean finished; + @GuardedBy("this") private boolean closed; public static HashDistributionTaskSource create( @@ -368,7 +395,8 @@ public static HashDistributionTaskSource create( int[] bucketToPartitionMap, Optional bucketNodeMap, long targetPartitionSplitWeight, - DataSize targetPartitionSourceSize) + DataSize targetPartitionSourceSize, + Executor executor) { checkArgument(bucketNodeMap.isPresent() || fragment.getPartitionedSources().isEmpty(), "bucketNodeMap is expected to be set when the fragment reads partitioned sources (tables)"); Map splitSources = splitSourceFactory.createSplitSources(session, fragment); @@ -384,7 +412,8 @@ public static HashDistributionTaskSource create( bucketNodeMap, fragment.getPartitioning().getConnectorId(), targetPartitionSplitWeight, targetPartitionSourceSize, - getFaultTolerantExecutionDefaultTaskMemory(session)); + getFaultTolerantExecutionDefaultTaskMemory(session), + executor); } public HashDistributionTaskSource( @@ -399,7 +428,8 @@ public HashDistributionTaskSource( Optional catalogRequirement, long targetPartitionSplitWeight, DataSize targetPartitionSourceSize, - DataSize taskMemory) + DataSize taskMemory, + Executor executor) { this.splitSources = ImmutableMap.copyOf(requireNonNull(splitSources, "splitSources is null")); this.exchangeForHandle = new IdentityHashMap<>(); @@ -415,105 +445,107 @@ public HashDistributionTaskSource( this.catalogRequirement = requireNonNull(catalogRequirement, "catalogRequirement is null"); this.targetPartitionSourceSizeInBytes = requireNonNull(targetPartitionSourceSize, "targetPartitionSourceSize is null").toBytes(); this.targetPartitionSplitWeight = targetPartitionSplitWeight; + this.executor = requireNonNull(executor, 
"executor is null"); } @Override - public List getMoreTasks() + public synchronized ListenableFuture> getMoreTasks() { if (finished || closed) { - return ImmutableList.of(); + return immediateFuture(ImmutableList.of()); } - - Map> partitionToSplitsMap = new HashMap<>(); - SetMultimap partitionToNodeMap = HashMultimap.create(); - for (Map.Entry entry : splitSources.entrySet()) { - SplitSource splitSource = entry.getValue(); - BucketNodeMap bucketNodeMap = this.bucketNodeMap - .orElseThrow(() -> new VerifyException("bucket to node map is expected to be present")); - while (!splitSource.isFinished()) { - ListenableFuture splitBatchFuture = splitSource.getNextBatch(NOT_PARTITIONED, Lifespan.taskWide(), splitBatchSize); - - long start = System.nanoTime(); - addSuccessCallback(splitBatchFuture, () -> getSplitTimeRecorder.accept(start)); - - SplitBatch splitBatch = getFutureValue(splitBatchFuture); - - for (Split split : splitBatch.getSplits()) { - int bucket = bucketNodeMap.getBucket(split); - int partition = getPartitionForBucket(bucket); - - if (!bucketNodeMap.isDynamic()) { - HostAddress requiredAddress = bucketNodeMap.getAssignedNode(split).get().getHostAndPort(); - Set existingRequirement = partitionToNodeMap.get(partition); - if (existingRequirement.isEmpty()) { - existingRequirement.add(requiredAddress); - } - else { - checkState( - existingRequirement.contains(requiredAddress), - "Unable to satisfy host requirement for partition %s. Existing requirement %s; Current split requirement: %s;", - partition, - existingRequirement, - requiredAddress); - existingRequirement.removeIf(host -> !host.equals(requiredAddress)); + checkState(loadedSplitsFuture == null, "getMoreTasks called again while splits are being loaded"); + + List> splitSourceCompletionFutures = splitSources.entrySet().stream() + .map(entry -> { + SplitLoadingFuture future = new SplitLoadingFuture(entry.getKey(), entry.getValue(), splitBatchSize, getSplitTimeRecorder, executor); + future.load(); + return future; + }) + .collect(toImmutableList()); + + loadedSplitsFuture = allAsList(splitSourceCompletionFutures); + return Futures.transform( + loadedSplitsFuture, + loadedSplitsList -> { + synchronized (this) { + Map> partitionToSplitsMap = new HashMap<>(); + SetMultimap partitionToNodeMap = HashMultimap.create(); + for (LoadedSplits loadedSplits : loadedSplitsList) { + BucketNodeMap bucketNodeMap = this.bucketNodeMap + .orElseThrow(() -> new VerifyException("bucket to node map is expected to be present")); + for (Split split : loadedSplits.getSplits()) { + int bucket = bucketNodeMap.getBucket(split); + int partition = getPartitionForBucket(bucket); + + if (!bucketNodeMap.isDynamic()) { + HostAddress requiredAddress = bucketNodeMap.getAssignedNode(split).get().getHostAndPort(); + Set existingRequirement = partitionToNodeMap.get(partition); + if (existingRequirement.isEmpty()) { + existingRequirement.add(requiredAddress); + } + else { + checkState( + existingRequirement.contains(requiredAddress), + "Unable to satisfy host requirement for partition %s. 
Existing requirement %s; Current split requirement: %s;", + partition, + existingRequirement, + requiredAddress); + existingRequirement.removeIf(host -> !host.equals(requiredAddress)); + } + } + + if (!split.isRemotelyAccessible()) { + Set requiredAddresses = ImmutableSet.copyOf(split.getAddresses()); + verify(!requiredAddresses.isEmpty(), "split is not remotely accessible but the list of addresses is empty: %s", split); + Set existingRequirement = partitionToNodeMap.get(partition); + if (existingRequirement.isEmpty()) { + existingRequirement.addAll(requiredAddresses); + } + else { + Set intersection = Sets.intersection(requiredAddresses, existingRequirement); + checkState( + !intersection.isEmpty(), + "Unable to satisfy host requirement for partition %s. Existing requirement %s; Current split requirement: %s;", + partition, + existingRequirement, + requiredAddresses); + partitionToNodeMap.replaceValues(partition, ImmutableSet.copyOf(intersection)); + } + } + + Multimap partitionSplits = partitionToSplitsMap.computeIfAbsent(partition, (p) -> ArrayListMultimap.create()); + partitionSplits.put(loadedSplits.getPlanNodeId(), split); + } } - } - if (!split.isRemotelyAccessible()) { - Set requiredAddresses = ImmutableSet.copyOf(split.getAddresses()); - verify(!requiredAddresses.isEmpty(), "split is not remotely accessible but the list of addresses is empty: %s", split); - Set existingRequirement = partitionToNodeMap.get(partition); - if (existingRequirement.isEmpty()) { - existingRequirement.addAll(requiredAddresses); - } - else { - Set intersection = Sets.intersection(requiredAddresses, existingRequirement); - checkState( - !intersection.isEmpty(), - "Unable to satisfy host requirement for partition %s. Existing requirement %s; Current split requirement: %s;", - partition, - existingRequirement, - requiredAddresses); - partitionToNodeMap.replaceValues(partition, ImmutableSet.copyOf(intersection)); + Map> partitionToExchangeSourceHandlesMap = new HashMap<>(); + for (Map.Entry entry : partitionedExchangeSourceHandles.entries()) { + PlanNodeId planNodeId = entry.getKey(); + ExchangeSourceHandle handle = entry.getValue(); + int partition = handle.getPartitionId(); + Multimap partitionSourceHandles = partitionToExchangeSourceHandlesMap.computeIfAbsent(partition, (p) -> ArrayListMultimap.create()); + partitionSourceHandles.put(planNodeId, handle); } - } - - Multimap partitionSplits = partitionToSplitsMap.computeIfAbsent(partition, (p) -> ArrayListMultimap.create()); - partitionSplits.put(entry.getKey(), split); - } - - if (splitBatch.isLastBatch()) { - splitSource.close(); - break; - } - } - } - - Map> partitionToExchangeSourceHandlesMap = new HashMap<>(); - for (Map.Entry entry : partitionedExchangeSourceHandles.entries()) { - PlanNodeId planNodeId = entry.getKey(); - ExchangeSourceHandle handle = entry.getValue(); - int partition = handle.getPartitionId(); - Multimap partitionSourceHandles = partitionToExchangeSourceHandlesMap.computeIfAbsent(partition, (p) -> ArrayListMultimap.create()); - partitionSourceHandles.put(planNodeId, handle); - } - - int taskPartitionId = 0; - ImmutableList.Builder partitionTasks = ImmutableList.builder(); - for (Integer partition : union(partitionToSplitsMap.keySet(), partitionToExchangeSourceHandlesMap.keySet())) { - ListMultimap splits = partitionToSplitsMap.getOrDefault(partition, ImmutableListMultimap.of()); - ListMultimap exchangeSourceHandles = ImmutableListMultimap.builder() - .putAll(partitionToExchangeSourceHandlesMap.getOrDefault(partition, 
ImmutableMultimap.of())) - // replicated exchange source will be added in postprocessTasks below - .build(); - Set hostRequirement = partitionToNodeMap.get(partition); - partitionTasks.add(new TaskDescriptor(taskPartitionId++, splits, exchangeSourceHandles, new NodeRequirements(catalogRequirement, hostRequirement, taskMemory))); - } - List result = postprocessTasks(partitionTasks.build()); + int taskPartitionId = 0; + ImmutableList.Builder partitionTasks = ImmutableList.builder(); + for (Integer partition : union(partitionToSplitsMap.keySet(), partitionToExchangeSourceHandlesMap.keySet())) { + ListMultimap splits = partitionToSplitsMap.getOrDefault(partition, ImmutableListMultimap.of()); + ListMultimap exchangeSourceHandles = ImmutableListMultimap.builder() + .putAll(partitionToExchangeSourceHandlesMap.getOrDefault(partition, ImmutableMultimap.of())) + // replicated exchange source will be added in postprocessTasks below + .build(); + Set hostRequirement = partitionToNodeMap.get(partition); + partitionTasks.add(new TaskDescriptor(taskPartitionId++, splits, exchangeSourceHandles, new NodeRequirements(catalogRequirement, hostRequirement, taskMemory))); + } - finished = true; - return result; + List result = postprocessTasks(partitionTasks.build()); + finished = true; + return result; + } + }, + executor); } private List postprocessTasks(List tasks) @@ -584,13 +616,13 @@ private int getPartitionForBucket(int bucket) } @Override - public boolean isFinished() + public synchronized boolean isFinished() { return finished; } @Override - public void close() + public synchronized void close() { if (closed) { return; @@ -622,13 +654,21 @@ public static class SourceDistributionTaskSource private final long targetPartitionSplitWeight; private final int maxPartitionSplitCount; private final DataSize taskMemory; + private final Executor executor; + @GuardedBy("this") private final Set remotelyAccessibleSplitBuffer = newIdentityHashSet(); + @GuardedBy("this") private final Map> locallyAccessibleSplitBuffer = new HashMap<>(); + @GuardedBy("this") private int currentPartitionId; + @GuardedBy("this") private boolean finished; + @GuardedBy("this") private boolean closed; + @GuardedBy("this") + private ListenableFuture currentSplitBatchFuture = immediateFuture(null); public static SourceDistributionTaskSource create( Session session, @@ -640,7 +680,8 @@ public static SourceDistributionTaskSource create( LongConsumer getSplitTimeRecorder, int minPartitionSplitCount, long targetPartitionSplitWeight, - int maxPartitionSplitCount) + int maxPartitionSplitCount, + Executor executor) { checkArgument(fragment.getPartitionedSources().size() == 1, "single partitioned source is expected, got: %s", fragment.getPartitionedSources()); @@ -666,7 +707,8 @@ public static SourceDistributionTaskSource create( minPartitionSplitCount, targetPartitionSplitWeight, maxPartitionSplitCount, - getFaultTolerantExecutionDefaultTaskMemory(session)); + getFaultTolerantExecutionDefaultTaskMemory(session), + executor); } public SourceDistributionTaskSource( @@ -681,7 +723,8 @@ public SourceDistributionTaskSource( int minPartitionSplitCount, long targetPartitionSplitWeight, int maxPartitionSplitCount, - DataSize taskMemory) + DataSize taskMemory, + Executor executor) { this.queryId = requireNonNull(queryId, "queryId is null"); this.partitionedSourceNodeId = requireNonNull(partitionedSourceNodeId, "partitionedSourceNodeId is null"); @@ -702,77 +745,80 @@ public SourceDistributionTaskSource( minPartitionSplitCount); this.maxPartitionSplitCount = 
maxPartitionSplitCount; this.taskMemory = requireNonNull(taskMemory, "taskMemory is null"); + this.executor = requireNonNull(executor, "executor is null"); } @Override - public List getMoreTasks() + public synchronized ListenableFuture> getMoreTasks() { if (finished || closed) { - return ImmutableList.of(); + return immediateFuture(ImmutableList.of()); } - List result = new ArrayList<>(); - - boolean splitSourceFinished = false; - while (result.isEmpty()) { - ListenableFuture splitBatchFuture = splitSource.getNextBatch(NOT_PARTITIONED, Lifespan.taskWide(), splitBatchSize); - - long start = System.nanoTime(); - addSuccessCallback(splitBatchFuture, () -> getSplitTimeRecorder.accept(start)); + checkState(currentSplitBatchFuture.isDone(), "getMoreTasks called again before the previous batch of splits was ready"); + currentSplitBatchFuture = splitSource.getNextBatch(NOT_PARTITIONED, Lifespan.taskWide(), splitBatchSize); + + long start = System.nanoTime(); + addSuccessCallback(currentSplitBatchFuture, () -> getSplitTimeRecorder.accept(start)); + + return Futures.transform( + currentSplitBatchFuture, + splitBatch -> { + synchronized (this) { + for (Split split : splitBatch.getSplits()) { + if (split.isRemotelyAccessible()) { + remotelyAccessibleSplitBuffer.add(split); + } + else { + List addresses = split.getAddresses(); + checkArgument(!addresses.isEmpty(), "split is not remotely accessible but the list of addresses is empty"); + for (HostAddress hostAddress : addresses) { + locallyAccessibleSplitBuffer.computeIfAbsent(hostAddress, key -> newIdentityHashSet()).add(split); + } + } + } - SplitBatch splitBatch = getFutureValue(splitBatchFuture); - List splits = splitBatch.getSplits(); + ImmutableList.Builder readyTasksBuilder = ImmutableList.builder(); + boolean isLastBatch = splitBatch.isLastBatch(); + readyTasksBuilder.addAll(getReadyTasks( + remotelyAccessibleSplitBuffer, + ImmutableList.of(), + new NodeRequirements(catalogRequirement, ImmutableSet.of(), taskMemory), + isLastBatch)); + for (HostAddress remoteHost : locallyAccessibleSplitBuffer.keySet()) { + readyTasksBuilder.addAll(getReadyTasks( + locallyAccessibleSplitBuffer.get(remoteHost), + locallyAccessibleSplitBuffer.entrySet().stream() + .filter(entry -> !entry.getKey().equals(remoteHost)) + .map(Map.Entry::getValue) + .collect(toImmutableList()), + new NodeRequirements(catalogRequirement, ImmutableSet.of(remoteHost), taskMemory), + isLastBatch)); + } + List readyTasks = readyTasksBuilder.build(); + + if (isLastBatch) { + Optional> tableExecuteSplitsInfo = splitSource.getTableExecuteSplitsInfo(); + + // Here we assume that we can get non-empty tableExecuteSplitsInfo only for queries which facilitate single split source. 
+ tableExecuteSplitsInfo.ifPresent(info -> { + TableExecuteContext tableExecuteContext = tableExecuteContextManager.getTableExecuteContextForQuery(queryId); + tableExecuteContext.setSplitsInfo(info); + }); + + try { + splitSource.close(); + } + catch (RuntimeException e) { + log.error(e, "Error closing split source"); + } + finished = true; + } - for (Split split : splits) { - if (split.isRemotelyAccessible()) { - remotelyAccessibleSplitBuffer.add(split); - } - else { - List addresses = split.getAddresses(); - checkArgument(!addresses.isEmpty(), "split is not remotely accessible but the list of addresses is empty"); - for (HostAddress hostAddress : addresses) { - locallyAccessibleSplitBuffer.computeIfAbsent(hostAddress, key -> newIdentityHashSet()).add(split); + return readyTasks; } - } - } - - splitSourceFinished = splitSource.isFinished(); - - result.addAll(getReadyTasks( - remotelyAccessibleSplitBuffer, - ImmutableList.of(), - new NodeRequirements(catalogRequirement, ImmutableSet.of(), taskMemory), - splitSourceFinished)); - for (HostAddress remoteHost : locallyAccessibleSplitBuffer.keySet()) { - result.addAll(getReadyTasks( - locallyAccessibleSplitBuffer.get(remoteHost), - locallyAccessibleSplitBuffer.entrySet().stream() - .filter(entry -> !entry.getKey().equals(remoteHost)) - .map(Map.Entry::getValue) - .collect(toImmutableList()), - new NodeRequirements(catalogRequirement, ImmutableSet.of(remoteHost), taskMemory), - splitSourceFinished)); - } - - if (splitSourceFinished) { - break; - } - } - - if (splitSourceFinished) { - Optional> tableExecuteSplitsInfo = splitSource.getTableExecuteSplitsInfo(); - - // Here we assume that we can get non-empty tableExecuteSplitsInfo only for queries which facilitate single split source. - tableExecuteSplitsInfo.ifPresent(info -> { - TableExecuteContext tableExecuteContext = tableExecuteContextManager.getTableExecuteContextForQuery(queryId); - tableExecuteContext.setSplitsInfo(info); - }); - - finished = true; - splitSource.close(); - } - - return ImmutableList.copyOf(result); + }, + executor); } private List getReadyTasks(Set splits, List> otherSplitSets, NodeRequirements nodeRequirements, boolean includeRemainder) @@ -818,7 +864,7 @@ private Optional getReadyTask(Set splits, List return Optional.empty(); } - private TaskDescriptor buildTaskDescriptor(Collection splits, NodeRequirements nodeRequirements) + private synchronized TaskDescriptor buildTaskDescriptor(Collection splits, NodeRequirements nodeRequirements) { return new TaskDescriptor( currentPartitionId++, @@ -828,13 +874,13 @@ private TaskDescriptor buildTaskDescriptor(Collection splits, NodeRequire } @Override - public boolean isFinished() + public synchronized boolean isFinished() { return finished; } @Override - public void close() + public synchronized void close() { if (closed) { return; @@ -904,4 +950,104 @@ private static ListMultimap getInputsForRemote } return result.build(); } + + private static class LoadedSplits + { + private final PlanNodeId planNodeId; + private final List splits; + + private LoadedSplits(PlanNodeId planNodeId, List splits) + { + this.planNodeId = requireNonNull(planNodeId, "planNodeId is null"); + this.splits = ImmutableList.copyOf(requireNonNull(splits, "splits is null")); + } + + public PlanNodeId getPlanNodeId() + { + return planNodeId; + } + + public List getSplits() + { + return splits; + } + } + + private static class SplitLoadingFuture + extends AbstractFuture + { + private final PlanNodeId planNodeId; + private final SplitSource splitSource; + private 
final int splitBatchSize;
+        private final LongConsumer getSplitTimeRecorder;
+        private final Executor executor;
+        @GuardedBy("this")
+        private final List<Split> loadedSplits = new ArrayList<>();
+        @GuardedBy("this")
+        private ListenableFuture<SplitBatch> currentSplitBatch = immediateFuture(null);
+
+        SplitLoadingFuture(PlanNodeId planNodeId, SplitSource splitSource, int splitBatchSize, LongConsumer getSplitTimeRecorder, Executor executor)
+        {
+            this.planNodeId = requireNonNull(planNodeId, "planNodeId is null");
+            this.splitSource = requireNonNull(splitSource, "splitSource is null");
+            this.splitBatchSize = splitBatchSize;
+            this.getSplitTimeRecorder = requireNonNull(getSplitTimeRecorder, "getSplitTimeRecorder is null");
+            this.executor = requireNonNull(executor, "executor is null");
+        }
+
+        // Called to initiate loading and to load next batch if not finished
+        public synchronized void load()
+        {
+            if (currentSplitBatch == null) {
+                checkState(isCancelled(), "SplitLoadingFuture should be in cancelled state");
+                return;
+            }
+            checkState(currentSplitBatch.isDone(), "next batch of splits requested before previous batch is done");
+            currentSplitBatch = splitSource.getNextBatch(NOT_PARTITIONED, Lifespan.taskWide(), splitBatchSize);
+
+            long start = System.nanoTime();
+            addCallback(
+                    currentSplitBatch,
+                    new FutureCallback<>()
+                    {
+                        @Override
+                        public void onSuccess(SplitBatch splitBatch)
+                        {
+                            getSplitTimeRecorder.accept(start);
+                            synchronized (SplitLoadingFuture.this) {
+                                loadedSplits.addAll(splitBatch.getSplits());
+
+                                if (splitBatch.isLastBatch()) {
+                                    set(new LoadedSplits(planNodeId, loadedSplits));
+                                    try {
+                                        splitSource.close();
+                                    }
+                                    catch (RuntimeException e) {
+                                        log.error(e, "Error closing split source");
+                                    }
+                                }
+                                else {
+                                    load();
+                                }
+                            }
+                        }
+
+                        @Override
+                        public void onFailure(Throwable throwable)
+                        {
+                            setException(throwable);
+                        }
+                    },
+                    executor);
+        }
+
+        @Override
+        protected synchronized void interruptTask()
+        {
+            if (currentSplitBatch != null) {
+                currentSplitBatch.cancel(true);
+                currentSplitBatch = null;
+            }
+        }
+    }
 }
diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/TaskSource.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/TaskSource.java
index d7891c9b9b0f..7e93d3df7731 100644
--- a/core/trino-main/src/main/java/io/trino/execution/scheduler/TaskSource.java
+++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/TaskSource.java
@@ -13,13 +13,15 @@
  */
 package io.trino.execution.scheduler;
 
+import com.google.common.util.concurrent.ListenableFuture;
+
 import java.io.Closeable;
 import java.util.List;
 
 public interface TaskSource
         extends Closeable
 {
-    List<TaskDescriptor> getMoreTasks();
+    ListenableFuture<List<TaskDescriptor>> getMoreTasks();
 
     boolean isFinished();
diff --git a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestFaultTolerantStageScheduler.java b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestFaultTolerantStageScheduler.java
index 248451b8eaa0..bf495e4c0a3d 100644
--- a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestFaultTolerantStageScheduler.java
+++ b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestFaultTolerantStageScheduler.java
@@ -884,6 +884,51 @@ private void testCancellation(boolean abort)
         }
     }
 
+    @Test
+    public void testAsyncTaskSource()
+            throws Exception
+    {
+        TestingRemoteTaskFactory remoteTaskFactory = new TestingRemoteTaskFactory();
+        SettableFuture<List<Split>> splitsFuture = SettableFuture.create();
+        TestingTaskSourceFactory taskSourceFactory = new TestingTaskSourceFactory(Optional.of(CATALOG), splitsFuture,
1); + TestingNodeSupplier nodeSupplier = TestingNodeSupplier.create(ImmutableMap.of( + NODE_1, ImmutableList.of(CATALOG), + NODE_2, ImmutableList.of(CATALOG))); + setupNodeAllocatorService(nodeSupplier); + + TestingExchange sourceExchange1 = new TestingExchange(false); + TestingExchange sourceExchange2 = new TestingExchange(false); + + try (NodeAllocator nodeAllocator = nodeAllocatorService.getNodeAllocator(SESSION, 1)) { + FaultTolerantStageScheduler scheduler = createFaultTolerantTaskScheduler( + remoteTaskFactory, + taskSourceFactory, + nodeAllocator, + TaskLifecycleListener.NO_OP, + Optional.empty(), + ImmutableMap.of(SOURCE_FRAGMENT_ID_1, sourceExchange1, SOURCE_FRAGMENT_ID_2, sourceExchange2), + 2, + 1); + + sourceExchange1.setSourceHandles(ImmutableList.of(new TestingExchangeSourceHandle(0, 1))); + sourceExchange2.setSourceHandles(ImmutableList.of(new TestingExchangeSourceHandle(0, 1))); + assertUnblocked(scheduler.isBlocked()); + + scheduler.schedule(); + assertBlocked(scheduler.isBlocked()); + + splitsFuture.set(createSplits(2)); + assertUnblocked(scheduler.isBlocked()); + scheduler.schedule(); + assertThat(remoteTaskFactory.getTasks()).hasSize(2); + remoteTaskFactory.getTasks().values().forEach(task -> { + assertThat(task.getSplits().values()).hasSize(2); + task.finish(); + }); + assertThat(scheduler.isFinished()).isTrue(); + } + } + private FaultTolerantStageScheduler createFaultTolerantTaskScheduler( RemoteTaskFactory remoteTaskFactory, TaskSourceFactory taskSourceFactory, diff --git a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestStageTaskSourceFactory.java b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestStageTaskSourceFactory.java index 8e72222d82ec..59134ca1d828 100644 --- a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestStageTaskSourceFactory.java +++ b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestStageTaskSourceFactory.java @@ -22,6 +22,8 @@ import com.google.common.collect.ListMultimap; import com.google.common.collect.Multimap; import com.google.common.collect.Multimaps; +import com.google.common.util.concurrent.ListenableFuture; +import com.google.common.util.concurrent.SettableFuture; import io.airlift.units.DataSize; import io.trino.connector.CatalogName; import io.trino.execution.Lifespan; @@ -60,6 +62,9 @@ import static com.google.common.collect.Iterables.getOnlyElement; import static com.google.common.collect.Multimaps.toMultimap; import static com.google.common.collect.Streams.findLast; +import static com.google.common.util.concurrent.MoreExecutors.directExecutor; +import static io.airlift.concurrent.MoreFutures.getDone; +import static io.airlift.concurrent.MoreFutures.getFutureValue; import static io.airlift.slice.SizeOf.estimatedSizeOf; import static io.airlift.slice.SizeOf.sizeOf; import static io.airlift.units.DataSize.Unit.BYTE; @@ -94,7 +99,7 @@ public void testSingleDistributionTaskSource() assertFalse(taskSource.isFinished()); - List tasks = taskSource.getMoreTasks(); + List tasks = getFutureValue(taskSource.getMoreTasks()); assertThat(tasks).hasSize(1); assertTrue(taskSource.isFinished()); @@ -118,7 +123,7 @@ public void testArbitraryDistributionTaskSource() DataSize.of(3, BYTE), DataSize.of(4, GIGABYTE)); assertFalse(taskSource.isFinished()); - List tasks = taskSource.getMoreTasks(); + List tasks = getFutureValue(taskSource.getMoreTasks()); assertThat(tasks).isEmpty(); assertTrue(taskSource.isFinished()); @@ -136,7 +141,7 @@ public void testArbitraryDistributionTaskSource() 
ImmutableListMultimap.of(), DataSize.of(3, BYTE), DataSize.of(4, GIGABYTE)); - tasks = taskSource.getMoreTasks(); + tasks = getFutureValue(taskSource.getMoreTasks()); assertTrue(taskSource.isFinished()); assertThat(tasks).hasSize(1); assertEquals(tasks, ImmutableList.of(new TaskDescriptor( @@ -153,7 +158,7 @@ public void testArbitraryDistributionTaskSource() ImmutableListMultimap.of(), DataSize.of(3, BYTE), DataSize.of(4, GIGABYTE)); - tasks = taskSource.getMoreTasks(); + tasks = getFutureValue(taskSource.getMoreTasks()); assertEquals(tasks, ImmutableList.of(new TaskDescriptor( 0, ImmutableListMultimap.of(), @@ -172,7 +177,7 @@ public void testArbitraryDistributionTaskSource() ImmutableListMultimap.of(), DataSize.of(3, BYTE), DataSize.of(4, GIGABYTE)); - tasks = taskSource.getMoreTasks(); + tasks = getFutureValue(taskSource.getMoreTasks()); assertEquals(tasks, ImmutableList.of( new TaskDescriptor( 0, @@ -199,7 +204,7 @@ public void testArbitraryDistributionTaskSource() ImmutableListMultimap.of(), DataSize.of(3, BYTE), DataSize.of(4, GIGABYTE)); - tasks = taskSource.getMoreTasks(); + tasks = getFutureValue(taskSource.getMoreTasks()); assertEquals(tasks, ImmutableList.of( new TaskDescriptor( 0, @@ -233,7 +238,7 @@ PLAN_NODE_1, new TestingExchangeSourceHandle(0, 2)), ImmutableListMultimap.of(), DataSize.of(3, BYTE), DataSize.of(4, GIGABYTE)); - tasks = taskSource.getMoreTasks(); + tasks = getFutureValue(taskSource.getMoreTasks()); assertEquals(tasks, ImmutableList.of( new TaskDescriptor( 0, @@ -274,7 +279,7 @@ PLAN_NODE_1, new TestingExchangeSourceHandle(0, 2)), replicatedSources, DataSize.of(3, BYTE), DataSize.of(4, GIGABYTE)); - tasks = taskSource.getMoreTasks(); + tasks = getFutureValue(taskSource.getMoreTasks()); assertEquals(tasks, ImmutableList.of( new TaskDescriptor( 0, @@ -313,7 +318,7 @@ public void testHashDistributionTaskSource() 0, DataSize.of(3, BYTE)); assertFalse(taskSource.isFinished()); - assertEquals(taskSource.getMoreTasks(), ImmutableList.of()); + assertEquals(getFutureValue(taskSource.getMoreTasks()), ImmutableList.of()); assertTrue(taskSource.isFinished()); taskSource = createHashDistributionTaskSource( @@ -331,7 +336,7 @@ PLAN_NODE_3, new TestingExchangeSourceHandle(0, 1)), 0, DataSize.of(0, BYTE)); assertFalse(taskSource.isFinished()); - assertEquals(taskSource.getMoreTasks(), ImmutableList.of( + assertEquals(getFutureValue(taskSource.getMoreTasks()), ImmutableList.of( new TaskDescriptor(0, ImmutableListMultimap.of(), ImmutableListMultimap.of( PLAN_NODE_1, new TestingExchangeSourceHandle(0, 1), PLAN_NODE_2, new TestingExchangeSourceHandle(0, 1), @@ -362,7 +367,7 @@ PLAN_NODE_3, new TestingExchangeSourceHandle(0, 1)), 0, DataSize.of(0, BYTE)); assertFalse(taskSource.isFinished()); - assertEquals(taskSource.getMoreTasks(), ImmutableList.of( + assertEquals(getFutureValue(taskSource.getMoreTasks()), ImmutableList.of( new TaskDescriptor( 0, ImmutableListMultimap.of( @@ -406,7 +411,7 @@ PLAN_NODE_3, new TestingExchangeSourceHandle(0, 1)), 0, DataSize.of(0, BYTE)); assertFalse(taskSource.isFinished()); - assertEquals(taskSource.getMoreTasks(), ImmutableList.of( + assertEquals(getFutureValue(taskSource.getMoreTasks()), ImmutableList.of( new TaskDescriptor( 0, ImmutableListMultimap.of( @@ -452,7 +457,7 @@ PLAN_NODE_3, new TestingExchangeSourceHandle(0, 1)), Optional.of(getTestingBucketNodeMap(4)), 0, DataSize.of(0, BYTE)); assertFalse(taskSource.isFinished()); - assertEquals(taskSource.getMoreTasks(), ImmutableList.of( + assertEquals(getFutureValue(taskSource.getMoreTasks()), 
ImmutableList.of( new TaskDescriptor( 0, ImmutableListMultimap.of( @@ -491,7 +496,7 @@ PLAN_NODE_3, new TestingExchangeSourceHandle(17, 1)), 2 * STANDARD_WEIGHT, DataSize.of(100, GIGABYTE)); assertFalse(taskSource.isFinished()); - assertEquals(taskSource.getMoreTasks(), ImmutableList.of( + assertEquals(getFutureValue(taskSource.getMoreTasks()), ImmutableList.of( new TaskDescriptor( 0, ImmutableListMultimap.of( @@ -534,7 +539,7 @@ PLAN_NODE_3, new TestingExchangeSourceHandle(17, 1)), 100 * STANDARD_WEIGHT, DataSize.of(100, BYTE)); assertFalse(taskSource.isFinished()); - assertEquals(taskSource.getMoreTasks(), ImmutableList.of( + assertEquals(getFutureValue(taskSource.getMoreTasks()), ImmutableList.of( new TaskDescriptor( 0, ImmutableListMultimap.of( @@ -593,7 +598,8 @@ private static HashDistributionTaskSource createHashDistributionTaskSource( Optional.of(CATALOG), targetPartitionSplitWeight, targetPartitionSourceSize, - DataSize.of(4, GIGABYTE)); + DataSize.of(4, GIGABYTE), + directExecutor()); } @Test @@ -601,7 +607,7 @@ public void testSourceDistributionTaskSource() { TaskSource taskSource = createSourceDistributionTaskSource(ImmutableList.of(), ImmutableListMultimap.of(), 2, 0, 3 * STANDARD_WEIGHT, 1000); assertFalse(taskSource.isFinished()); - assertEquals(taskSource.getMoreTasks(), ImmutableList.of()); + assertEquals(getFutureValue(taskSource.getMoreTasks()), ImmutableList.of()); assertTrue(taskSource.isFinished()); Split split1 = createSplit(1); @@ -615,7 +621,7 @@ public void testSourceDistributionTaskSource() 0, 2 * STANDARD_WEIGHT, 1000); - assertEquals(taskSource.getMoreTasks(), ImmutableList.of(new TaskDescriptor( + assertEquals(getFutureValue(taskSource.getMoreTasks()), ImmutableList.of(new TaskDescriptor( 0, ImmutableListMultimap.of(PLAN_NODE_1, split1), ImmutableListMultimap.of(), @@ -818,6 +824,67 @@ public void testSourceDistributionTaskSourceLastIncompleteTaskAlwaysCreated() } } + @Test + public void testSourceDistributionTaskSourceWithAsyncSplitSource() + { + SettableFuture> splitsFuture = SettableFuture.create(); + TaskSource taskSource = createSourceDistributionTaskSource( + new TestingSplitSource(CATALOG, splitsFuture, 0), + ImmutableListMultimap.of(), + 2, + 0, + 2 * STANDARD_WEIGHT, + 1000); + ListenableFuture> tasksFuture = taskSource.getMoreTasks(); + assertThat(tasksFuture).isNotDone(); + + splitsFuture.set(ImmutableList.of(createSplit(1), createSplit(2), createSplit(3))); + List tasks = getDone(tasksFuture); + assertThat(tasks).hasSize(1); + assertThat(tasks.get(0).getSplits()).hasSize(2); + + tasksFuture = taskSource.getMoreTasks(); + assertThat(tasksFuture).isDone(); + tasks = getDone(tasksFuture); + assertThat(tasks).hasSize(1); + assertThat(tasks.get(0).getSplits()).hasSize(1); + assertThat(taskSource.isFinished()).isTrue(); + } + + @Test + public void testHashDistributionTaskSourceWithAsyncSplitSource() + { + SettableFuture> splitsFuture1 = SettableFuture.create(); + SettableFuture> splitsFuture2 = SettableFuture.create(); + TaskSource taskSource = createHashDistributionTaskSource( + ImmutableMap.of( + PLAN_NODE_1, new TestingSplitSource(CATALOG, splitsFuture1, 0), + PLAN_NODE_2, new TestingSplitSource(CATALOG, splitsFuture2, 0)), + ImmutableListMultimap.of(), + ImmutableListMultimap.of( + PLAN_NODE_3, new TestingExchangeSourceHandle(0, 1)), + 1, + new int[] {0, 1, 2, 3}, + Optional.of(getTestingBucketNodeMap(4)), + 0, + DataSize.of(0, BYTE)); + ListenableFuture> tasksFuture = taskSource.getMoreTasks(); + assertThat(tasksFuture).isNotDone(); + + Split 
bucketedSplit1 = createBucketedSplit(0, 0); + Split bucketedSplit2 = createBucketedSplit(0, 2); + Split bucketedSplit3 = createBucketedSplit(0, 3); + splitsFuture1.set(ImmutableList.of(bucketedSplit1, bucketedSplit2, bucketedSplit3)); + assertThat(tasksFuture).isNotDone(); + + Split bucketedSplit4 = createBucketedSplit(0, 1); + splitsFuture2.set(ImmutableList.of(bucketedSplit4)); + List tasks = getDone(tasksFuture); + assertThat(tasks).hasSize(4); + tasks.forEach(task -> assertThat(task.getSplits()).hasSize(1)); + assertThat(taskSource.isFinished()).isTrue(); + } + private static SourceDistributionTaskSource createSourceDistributionTaskSource( List splits, ListMultimap replicatedSources, @@ -855,7 +922,8 @@ private static SourceDistributionTaskSource createSourceDistributionTaskSource( minSplitsPerTask, splitWeightPerTask, maxSplitsPerTask, - DataSize.of(4, GIGABYTE)); + DataSize.of(4, GIGABYTE), + directExecutor()); } private static Split createSplit(int id, String... addresses) @@ -877,7 +945,7 @@ private List readAllTasks(TaskSource taskSource) { ImmutableList.Builder tasks = ImmutableList.builder(); while (!taskSource.isFinished()) { - tasks.addAll(taskSource.getMoreTasks()); + tasks.addAll(getFutureValue(taskSource.getMoreTasks())); } return tasks.build(); } diff --git a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestingSplitSource.java b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestingSplitSource.java index 6a0692d19f9f..8fb3afae1be2 100644 --- a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestingSplitSource.java +++ b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestingSplitSource.java @@ -14,6 +14,7 @@ package io.trino.execution.scheduler; import com.google.common.collect.ImmutableList; +import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.ListenableFuture; import io.trino.connector.CatalogName; import io.trino.execution.Lifespan; @@ -25,15 +26,18 @@ import java.util.List; import java.util.Optional; +import static com.google.common.base.Preconditions.checkState; import static com.google.common.util.concurrent.Futures.immediateFuture; +import static com.google.common.util.concurrent.MoreExecutors.directExecutor; import static java.util.Objects.requireNonNull; public class TestingSplitSource implements SplitSource { private final CatalogName catalogName; - private final Iterator splits; + private final ListenableFuture> splitsFuture; private int finishDelayRemainingIterations; + private Iterator splits; public TestingSplitSource(CatalogName catalogName, List splits) { @@ -41,9 +45,17 @@ public TestingSplitSource(CatalogName catalogName, List splits) } public TestingSplitSource(CatalogName catalogName, List splits, int finishDelayIterations) + { + this( + catalogName, + immediateFuture(ImmutableList.copyOf(requireNonNull(splits, "splits is null"))), + finishDelayIterations); + } + + public TestingSplitSource(CatalogName catalogName, ListenableFuture> splitsFuture, int finishDelayIterations) { this.catalogName = requireNonNull(catalogName, "catalogName is null"); - this.splits = ImmutableList.copyOf(requireNonNull(splits, "splits is null")).iterator(); + this.splitsFuture = requireNonNull(splitsFuture, "splitsFuture is null"); this.finishDelayRemainingIterations = finishDelayIterations; } @@ -59,14 +71,19 @@ public ListenableFuture getNextBatch(ConnectorPartitionHandle partit if (isFinished()) { return immediateFuture(new SplitBatch(ImmutableList.of(), true)); } - ImmutableList.Builder 
result = ImmutableList.builder(); - for (int i = 0; i < maxSize; i++) { - if (!splits.hasNext()) { - break; - } - result.add(splits.next()); + + if (splits == null) { + return Futures.transform( + splitsFuture, + splits -> { + checkState(this.splits == null, "splits should be null"); + this.splits = splits.iterator(); + return populateSplitBatch(maxSize); + }, + directExecutor()); } - return immediateFuture(new SplitBatch(result.build(), isFinished())); + checkState(splitsFuture.isDone(), "splitsFuture should be completed"); + return immediateFuture(populateSplitBatch(maxSize)); } @Override @@ -77,7 +94,8 @@ public void close() @Override public boolean isFinished() { - return !splits.hasNext() && finishDelayRemainingIterations-- <= 0; + return (splits != null && !splits.hasNext()) + && finishDelayRemainingIterations-- <= 0; } @Override @@ -85,4 +103,16 @@ public Optional> getTableExecuteSplitsInfo() { return Optional.empty(); } + + private SplitBatch populateSplitBatch(int maxSize) + { + ImmutableList.Builder result = ImmutableList.builder(); + for (int i = 0; i < maxSize; i++) { + if (!splits.hasNext()) { + break; + } + result.add(splits.next()); + } + return new SplitBatch(result.build(), isFinished()); + } } diff --git a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestingTaskSourceFactory.java b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestingTaskSourceFactory.java index e1e66d806fab..38956531ab71 100644 --- a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestingTaskSourceFactory.java +++ b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestingTaskSourceFactory.java @@ -18,6 +18,8 @@ import com.google.common.collect.ImmutableSet; import com.google.common.collect.ListMultimap; import com.google.common.collect.Multimap; +import com.google.common.util.concurrent.Futures; +import com.google.common.util.concurrent.ListenableFuture; import io.airlift.units.DataSize; import io.trino.Session; import io.trino.connector.CatalogName; @@ -40,6 +42,8 @@ import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; import static com.google.common.collect.Iterables.getOnlyElement; +import static com.google.common.util.concurrent.Futures.immediateFuture; +import static com.google.common.util.concurrent.MoreExecutors.directExecutor; import static io.airlift.units.DataSize.Unit.GIGABYTE; import static io.trino.sql.planner.plan.ExchangeNode.Type.REPLICATE; import static java.util.Objects.requireNonNull; @@ -48,13 +52,18 @@ public class TestingTaskSourceFactory implements TaskSourceFactory { private final Optional catalog; - private final List splits; + private final ListenableFuture> splitsFuture; private final int tasksPerBatch; public TestingTaskSourceFactory(Optional catalog, List splits, int tasksPerBatch) + { + this(catalog, immediateFuture(ImmutableList.copyOf(requireNonNull(splits, "splits is null"))), tasksPerBatch); + } + + public TestingTaskSourceFactory(Optional catalog, ListenableFuture> splitsFuture, int tasksPerBatch) { this.catalog = requireNonNull(catalog, "catalog is null"); - this.splits = ImmutableList.copyOf(requireNonNull(splits, "splits is null")); + this.splitsFuture = requireNonNull(splitsFuture, "splitsFuture is null"); this.tasksPerBatch = tasksPerBatch; } @@ -73,7 +82,7 @@ public TaskSource create( return new TestingTaskSource( catalog, - splits, + splitsFuture, tasksPerBatch, getOnlyElement(partitionedSources), 
getHandlesForRemoteSources(fragment.getRemoteSourceNodes(), exchangeSourceHandles)); @@ -99,32 +108,60 @@ public static class TestingTaskSource implements TaskSource { private final Optional catalogRequirement; - private final Iterator splits; + private final ListenableFuture> splitsFuture; private final int tasksPerBatch; private final PlanNodeId tableScanPlanNodeId; private final ListMultimap exchangeSourceHandles; private final AtomicInteger nextPartitionId = new AtomicInteger(); + private Iterator splits; public TestingTaskSource( Optional catalogRequirement, - List splits, + ListenableFuture> splitsFuture, int tasksPerBatch, PlanNodeId tableScanPlanNodeId, ListMultimap exchangeSourceHandles) { this.catalogRequirement = requireNonNull(catalogRequirement, "catalogRequirement is null"); - this.splits = ImmutableList.copyOf(requireNonNull(splits, "splits is null")).iterator(); + this.splitsFuture = requireNonNull(splitsFuture, "splitsFuture is null"); this.tasksPerBatch = tasksPerBatch; this.tableScanPlanNodeId = requireNonNull(tableScanPlanNodeId, "tableScanPlanNodeId is null"); this.exchangeSourceHandles = ImmutableListMultimap.copyOf(requireNonNull(exchangeSourceHandles, "exchangeSourceHandles is null")); } @Override - public List getMoreTasks() + public ListenableFuture> getMoreTasks() { checkState(!isFinished(), "already finished"); + if (splits == null) { + return Futures.transform( + splitsFuture, + loadedSplits -> { + checkState(this.splits == null, "splits should be null"); + splits = loadedSplits.iterator(); + return getTasksBatch(); + }, + directExecutor()); + } + checkState(splitsFuture.isDone(), "splitsFuture should be completed"); + return immediateFuture(getTasksBatch()); + } + + @Override + public boolean isFinished() + { + return splits != null && !splits.hasNext(); + } + + @Override + public void close() + { + } + + private List getTasksBatch() + { ImmutableList.Builder result = ImmutableList.builder(); for (int i = 0; i < tasksPerBatch; i++) { if (isFinished()) { @@ -138,19 +175,7 @@ public List getMoreTasks() new NodeRequirements(catalogRequirement, ImmutableSet.copyOf(split.getAddresses()), DataSize.of(4, GIGABYTE))); result.add(task); } - return result.build(); } - - @Override - public boolean isFinished() - { - return !splits.hasNext(); - } - - @Override - public void close() - { - } } }
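The heart of this change is the TaskSource interface: getMoreTasks() now returns ListenableFuture<List<TaskDescriptor>> instead of a materialized list, and FaultTolerantStageScheduler chains its bookkeeping onto that future, surfacing it as the stage's blocked future when it is not yet done. The snippet below is a minimal, self-contained sketch of that consumption pattern using only Guava; the class and field names are illustrative and not part of the patch.

import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.SettableFuture;

import java.util.ArrayList;
import java.util.List;

import static com.google.common.util.concurrent.MoreExecutors.directExecutor;

public class AsyncTaskPollingSketch
{
    // Stand-in for taskSource.getMoreTasks(); after this patch the real call
    // returns ListenableFuture<List<TaskDescriptor>>.
    private final SettableFuture<List<String>> moreTasks = SettableFuture.create();

    private final List<String> queuedPartitions = new ArrayList<>();
    private ListenableFuture<?> blocked = Futures.immediateVoidFuture();

    public synchronized void schedule()
    {
        // Chain the bookkeeping onto the future instead of blocking the scheduler thread.
        ListenableFuture<Void> tasksPopulated = Futures.transform(
                moreTasks,
                tasks -> {
                    synchronized (this) {
                        queuedPartitions.addAll(tasks);
                        return null;
                    }
                },
                directExecutor());
        if (!tasksPopulated.isDone()) {
            // Surface the pending future; the caller re-invokes schedule() once it completes.
            blocked = tasksPopulated;
            return;
        }
        // ... proceed to hand out the queued work ...
    }

    public synchronized ListenableFuture<?> isBlocked()
    {
        return blocked;
    }
}

Running the transform on directExecutor() is why the callback re-acquires the lock: it may fire immediately on the scheduling thread or later on whichever thread completes the upstream future.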
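On the producer side, StageTaskSourceFactory now drains each SplitSource through the new SplitLoadingFuture, an AbstractFuture subclass whose load() keeps requesting the next batch until the source reports its last batch, and only then completes itself. Below is a simplified, generic sketch of that pattern; the Batch record and the Supplier-based source are stand-ins for Trino's SplitSource and SplitBatch, not the real API.

import com.google.common.util.concurrent.AbstractFuture;
import com.google.common.util.concurrent.FutureCallback;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Executor;
import java.util.function.Supplier;

// Hypothetical batch of items; the last flag marks the final chunk.
record Batch<T>(List<T> items, boolean lastBatch) {}

// Completes once the underlying source reports its last batch: each onSuccess
// either finishes this future or issues the next batch request.
class BatchLoadingFuture<T>
        extends AbstractFuture<List<T>>
{
    private final Supplier<ListenableFuture<Batch<T>>> nextBatch;
    private final Executor executor;
    private final List<T> loaded = new ArrayList<>();

    BatchLoadingFuture(Supplier<ListenableFuture<Batch<T>>> nextBatch, Executor executor)
    {
        this.nextBatch = nextBatch;
        this.executor = executor;
    }

    public synchronized void load()
    {
        Futures.addCallback(nextBatch.get(), new FutureCallback<Batch<T>>()
        {
            @Override
            public void onSuccess(Batch<T> batch)
            {
                synchronized (BatchLoadingFuture.this) {
                    loaded.addAll(batch.items());
                    if (batch.lastBatch()) {
                        set(List.copyOf(loaded)); // complete this AbstractFuture
                    }
                    else {
                        load(); // chain the next batch request
                    }
                }
            }

            @Override
            public void onFailure(Throwable throwable)
            {
                setException(throwable);
            }
        }, executor);
    }
}

The real SplitLoadingFuture additionally closes its SplitSource after the last batch and overrides interruptTask() so that cancelling the future also cancels the in-flight getNextBatch() call.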
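HashDistributionTaskSource starts one such loading future per split source and waits for all of them with allAsList before assigning splits to partitions, again doing the follow-up work in a transform on directExecutor(). A compact sketch of that combining step, with hypothetical element types:

import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;

import java.util.List;

import static com.google.common.util.concurrent.MoreExecutors.directExecutor;

public class CombineLoadersSketch
{
    // Waits for every per-source loader, then folds the results into a single value,
    // the same shape as allAsList(splitSourceCompletionFutures) followed by a transform.
    static ListenableFuture<Integer> totalItemCount(List<ListenableFuture<List<String>>> loaders)
    {
        ListenableFuture<List<List<String>>> allLoaded = Futures.allAsList(loaders);
        return Futures.transform(
                allLoaded,
                lists -> lists.stream().mapToInt(List::size).sum(),
                directExecutor());
    }
}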