Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Multimap;
import com.google.common.collect.Ordering;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.SettableFuture;
import io.airlift.concurrent.MoreFutures;
Expand Down Expand Up @@ -80,6 +81,7 @@
import static com.google.common.util.concurrent.Futures.allAsList;
import static com.google.common.util.concurrent.Futures.immediateVoidFuture;
import static com.google.common.util.concurrent.Futures.nonCancellationPropagating;
import static com.google.common.util.concurrent.MoreExecutors.directExecutor;
import static io.airlift.concurrent.MoreFutures.asVoid;
import static io.airlift.concurrent.MoreFutures.getFutureValue;
import static io.airlift.concurrent.MoreFutures.toListenableFuture;
Expand Down Expand Up @@ -285,18 +287,29 @@ public synchronized void schedule()

while (!pendingPartitions.isEmpty() || !queuedPartitions.isEmpty() || !taskSource.isFinished()) {
while (queuedPartitions.isEmpty() && pendingPartitions.size() < maxTasksWaitingForNodePerStage && !taskSource.isFinished()) {
List<TaskDescriptor> tasks = taskSource.getMoreTasks();
for (TaskDescriptor task : tasks) {
queuedPartitions.add(task.getPartitionId());
allPartitions.add(task.getPartitionId());
taskDescriptorStorage.put(stage.getStageId(), task);
sinkExchange.ifPresent(exchange -> {
ExchangeSinkHandle exchangeSinkHandle = exchange.addSink(task.getPartitionId());
partitionToExchangeSinkHandleMap.put(task.getPartitionId(), exchangeSinkHandle);
});
}
if (taskSource.isFinished()) {
sinkExchange.ifPresent(Exchange::noMoreSinks);
ListenableFuture<Void> tasksPopulatedFuture = Futures.transform(
taskSource.getMoreTasks(),
tasks -> {
Comment thread
raunaqmorarka marked this conversation as resolved.
Outdated
synchronized (this) {
for (TaskDescriptor task : tasks) {
queuedPartitions.add(task.getPartitionId());
allPartitions.add(task.getPartitionId());
taskDescriptorStorage.put(stage.getStageId(), task);
sinkExchange.ifPresent(exchange -> {
ExchangeSinkHandle exchangeSinkHandle = exchange.addSink(task.getPartitionId());
partitionToExchangeSinkHandleMap.put(task.getPartitionId(), exchangeSinkHandle);
});
}
if (taskSource.isFinished()) {
sinkExchange.ifPresent(Exchange::noMoreSinks);
}
return null;
}
},
directExecutor());
if (!tasksPopulatedFuture.isDone()) {
blocked = tasksPopulatedFuture;
return;
}
}

Expand Down

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,15 @@
*/
package io.trino.execution.scheduler;

import com.google.common.util.concurrent.ListenableFuture;

import java.io.Closeable;
import java.util.List;

public interface TaskSource
extends Closeable
{
List<TaskDescriptor> getMoreTasks();
ListenableFuture<List<TaskDescriptor>> getMoreTasks();

boolean isFinished();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -884,6 +884,51 @@ private void testCancellation(boolean abort)
}
}

@Test
public void testAsyncTaskSource()
throws Exception
{
TestingRemoteTaskFactory remoteTaskFactory = new TestingRemoteTaskFactory();
SettableFuture<List<Split>> splitsFuture = SettableFuture.create();
TestingTaskSourceFactory taskSourceFactory = new TestingTaskSourceFactory(Optional.of(CATALOG), splitsFuture, 1);
TestingNodeSupplier nodeSupplier = TestingNodeSupplier.create(ImmutableMap.of(
NODE_1, ImmutableList.of(CATALOG),
NODE_2, ImmutableList.of(CATALOG)));
setupNodeAllocatorService(nodeSupplier);

TestingExchange sourceExchange1 = new TestingExchange(false);
TestingExchange sourceExchange2 = new TestingExchange(false);

try (NodeAllocator nodeAllocator = nodeAllocatorService.getNodeAllocator(SESSION, 1)) {
FaultTolerantStageScheduler scheduler = createFaultTolerantTaskScheduler(
remoteTaskFactory,
taskSourceFactory,
nodeAllocator,
TaskLifecycleListener.NO_OP,
Optional.empty(),
ImmutableMap.of(SOURCE_FRAGMENT_ID_1, sourceExchange1, SOURCE_FRAGMENT_ID_2, sourceExchange2),
2,
1);

sourceExchange1.setSourceHandles(ImmutableList.of(new TestingExchangeSourceHandle(0, 1)));
sourceExchange2.setSourceHandles(ImmutableList.of(new TestingExchangeSourceHandle(0, 1)));
assertUnblocked(scheduler.isBlocked());

scheduler.schedule();
assertBlocked(scheduler.isBlocked());

splitsFuture.set(createSplits(2));
assertUnblocked(scheduler.isBlocked());
scheduler.schedule();
assertThat(remoteTaskFactory.getTasks()).hasSize(2);
remoteTaskFactory.getTasks().values().forEach(task -> {
assertThat(task.getSplits().values()).hasSize(2);
task.finish();
});
assertThat(scheduler.isFinished()).isTrue();
}
}

private FaultTolerantStageScheduler createFaultTolerantTaskScheduler(
RemoteTaskFactory remoteTaskFactory,
TaskSourceFactory taskSourceFactory,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
import com.google.common.collect.ListMultimap;
import com.google.common.collect.Multimap;
import com.google.common.collect.Multimaps;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.SettableFuture;
import io.airlift.units.DataSize;
import io.trino.connector.CatalogName;
import io.trino.execution.Lifespan;
Expand Down Expand Up @@ -60,6 +62,9 @@
import static com.google.common.collect.Iterables.getOnlyElement;
import static com.google.common.collect.Multimaps.toMultimap;
import static com.google.common.collect.Streams.findLast;
import static com.google.common.util.concurrent.MoreExecutors.directExecutor;
import static io.airlift.concurrent.MoreFutures.getDone;
import static io.airlift.concurrent.MoreFutures.getFutureValue;
import static io.airlift.slice.SizeOf.estimatedSizeOf;
import static io.airlift.slice.SizeOf.sizeOf;
import static io.airlift.units.DataSize.Unit.BYTE;
Expand Down Expand Up @@ -94,7 +99,7 @@ public void testSingleDistributionTaskSource()

assertFalse(taskSource.isFinished());

List<TaskDescriptor> tasks = taskSource.getMoreTasks();
List<TaskDescriptor> tasks = getFutureValue(taskSource.getMoreTasks());
assertThat(tasks).hasSize(1);
assertTrue(taskSource.isFinished());

Expand All @@ -118,7 +123,7 @@ public void testArbitraryDistributionTaskSource()
DataSize.of(3, BYTE),
DataSize.of(4, GIGABYTE));
assertFalse(taskSource.isFinished());
List<TaskDescriptor> tasks = taskSource.getMoreTasks();
List<TaskDescriptor> tasks = getFutureValue(taskSource.getMoreTasks());
assertThat(tasks).isEmpty();
assertTrue(taskSource.isFinished());

Expand All @@ -136,7 +141,7 @@ public void testArbitraryDistributionTaskSource()
ImmutableListMultimap.of(),
DataSize.of(3, BYTE),
DataSize.of(4, GIGABYTE));
tasks = taskSource.getMoreTasks();
tasks = getFutureValue(taskSource.getMoreTasks());
assertTrue(taskSource.isFinished());
assertThat(tasks).hasSize(1);
assertEquals(tasks, ImmutableList.of(new TaskDescriptor(
Expand All @@ -153,7 +158,7 @@ public void testArbitraryDistributionTaskSource()
ImmutableListMultimap.of(),
DataSize.of(3, BYTE),
DataSize.of(4, GIGABYTE));
tasks = taskSource.getMoreTasks();
tasks = getFutureValue(taskSource.getMoreTasks());
assertEquals(tasks, ImmutableList.of(new TaskDescriptor(
0,
ImmutableListMultimap.of(),
Expand All @@ -172,7 +177,7 @@ public void testArbitraryDistributionTaskSource()
ImmutableListMultimap.of(),
DataSize.of(3, BYTE),
DataSize.of(4, GIGABYTE));
tasks = taskSource.getMoreTasks();
tasks = getFutureValue(taskSource.getMoreTasks());
assertEquals(tasks, ImmutableList.of(
new TaskDescriptor(
0,
Expand All @@ -199,7 +204,7 @@ public void testArbitraryDistributionTaskSource()
ImmutableListMultimap.of(),
DataSize.of(3, BYTE),
DataSize.of(4, GIGABYTE));
tasks = taskSource.getMoreTasks();
tasks = getFutureValue(taskSource.getMoreTasks());
assertEquals(tasks, ImmutableList.of(
new TaskDescriptor(
0,
Expand Down Expand Up @@ -233,7 +238,7 @@ PLAN_NODE_1, new TestingExchangeSourceHandle(0, 2)),
ImmutableListMultimap.of(),
DataSize.of(3, BYTE),
DataSize.of(4, GIGABYTE));
tasks = taskSource.getMoreTasks();
tasks = getFutureValue(taskSource.getMoreTasks());
assertEquals(tasks, ImmutableList.of(
new TaskDescriptor(
0,
Expand Down Expand Up @@ -274,7 +279,7 @@ PLAN_NODE_1, new TestingExchangeSourceHandle(0, 2)),
replicatedSources,
DataSize.of(3, BYTE),
DataSize.of(4, GIGABYTE));
tasks = taskSource.getMoreTasks();
tasks = getFutureValue(taskSource.getMoreTasks());
assertEquals(tasks, ImmutableList.of(
new TaskDescriptor(
0,
Expand Down Expand Up @@ -313,7 +318,7 @@ public void testHashDistributionTaskSource()
0,
DataSize.of(3, BYTE));
assertFalse(taskSource.isFinished());
assertEquals(taskSource.getMoreTasks(), ImmutableList.of());
assertEquals(getFutureValue(taskSource.getMoreTasks()), ImmutableList.of());
assertTrue(taskSource.isFinished());

taskSource = createHashDistributionTaskSource(
Expand All @@ -331,7 +336,7 @@ PLAN_NODE_3, new TestingExchangeSourceHandle(0, 1)),
0,
DataSize.of(0, BYTE));
assertFalse(taskSource.isFinished());
assertEquals(taskSource.getMoreTasks(), ImmutableList.of(
assertEquals(getFutureValue(taskSource.getMoreTasks()), ImmutableList.of(
new TaskDescriptor(0, ImmutableListMultimap.of(), ImmutableListMultimap.of(
PLAN_NODE_1, new TestingExchangeSourceHandle(0, 1),
PLAN_NODE_2, new TestingExchangeSourceHandle(0, 1),
Expand Down Expand Up @@ -362,7 +367,7 @@ PLAN_NODE_3, new TestingExchangeSourceHandle(0, 1)),
0,
DataSize.of(0, BYTE));
assertFalse(taskSource.isFinished());
assertEquals(taskSource.getMoreTasks(), ImmutableList.of(
assertEquals(getFutureValue(taskSource.getMoreTasks()), ImmutableList.of(
new TaskDescriptor(
0,
ImmutableListMultimap.of(
Expand Down Expand Up @@ -406,7 +411,7 @@ PLAN_NODE_3, new TestingExchangeSourceHandle(0, 1)),
0,
DataSize.of(0, BYTE));
assertFalse(taskSource.isFinished());
assertEquals(taskSource.getMoreTasks(), ImmutableList.of(
assertEquals(getFutureValue(taskSource.getMoreTasks()), ImmutableList.of(
new TaskDescriptor(
0,
ImmutableListMultimap.of(
Expand Down Expand Up @@ -452,7 +457,7 @@ PLAN_NODE_3, new TestingExchangeSourceHandle(0, 1)),
Optional.of(getTestingBucketNodeMap(4)),
0, DataSize.of(0, BYTE));
assertFalse(taskSource.isFinished());
assertEquals(taskSource.getMoreTasks(), ImmutableList.of(
assertEquals(getFutureValue(taskSource.getMoreTasks()), ImmutableList.of(
new TaskDescriptor(
0,
ImmutableListMultimap.of(
Expand Down Expand Up @@ -491,7 +496,7 @@ PLAN_NODE_3, new TestingExchangeSourceHandle(17, 1)),
2 * STANDARD_WEIGHT,
DataSize.of(100, GIGABYTE));
assertFalse(taskSource.isFinished());
assertEquals(taskSource.getMoreTasks(), ImmutableList.of(
assertEquals(getFutureValue(taskSource.getMoreTasks()), ImmutableList.of(
new TaskDescriptor(
0,
ImmutableListMultimap.of(
Expand Down Expand Up @@ -534,7 +539,7 @@ PLAN_NODE_3, new TestingExchangeSourceHandle(17, 1)),
100 * STANDARD_WEIGHT,
DataSize.of(100, BYTE));
assertFalse(taskSource.isFinished());
assertEquals(taskSource.getMoreTasks(), ImmutableList.of(
assertEquals(getFutureValue(taskSource.getMoreTasks()), ImmutableList.of(
new TaskDescriptor(
0,
ImmutableListMultimap.of(
Expand Down Expand Up @@ -593,15 +598,16 @@ private static HashDistributionTaskSource createHashDistributionTaskSource(
Optional.of(CATALOG),
targetPartitionSplitWeight,
targetPartitionSourceSize,
DataSize.of(4, GIGABYTE));
DataSize.of(4, GIGABYTE),
directExecutor());
}

@Test
public void testSourceDistributionTaskSource()
{
TaskSource taskSource = createSourceDistributionTaskSource(ImmutableList.of(), ImmutableListMultimap.of(), 2, 0, 3 * STANDARD_WEIGHT, 1000);
assertFalse(taskSource.isFinished());
assertEquals(taskSource.getMoreTasks(), ImmutableList.of());
assertEquals(getFutureValue(taskSource.getMoreTasks()), ImmutableList.of());
Comment thread
raunaqmorarka marked this conversation as resolved.
Outdated
assertTrue(taskSource.isFinished());

Split split1 = createSplit(1);
Expand All @@ -615,7 +621,7 @@ public void testSourceDistributionTaskSource()
0,
2 * STANDARD_WEIGHT,
1000);
assertEquals(taskSource.getMoreTasks(), ImmutableList.of(new TaskDescriptor(
assertEquals(getFutureValue(taskSource.getMoreTasks()), ImmutableList.of(new TaskDescriptor(
0,
ImmutableListMultimap.of(PLAN_NODE_1, split1),
ImmutableListMultimap.of(),
Expand Down Expand Up @@ -818,6 +824,67 @@ public void testSourceDistributionTaskSourceLastIncompleteTaskAlwaysCreated()
}
}

@Test
public void testSourceDistributionTaskSourceWithAsyncSplitSource()
{
SettableFuture<List<Split>> splitsFuture = SettableFuture.create();
TaskSource taskSource = createSourceDistributionTaskSource(
new TestingSplitSource(CATALOG, splitsFuture, 0),
ImmutableListMultimap.of(),
2,
0,
2 * STANDARD_WEIGHT,
1000);
ListenableFuture<List<TaskDescriptor>> tasksFuture = taskSource.getMoreTasks();
assertThat(tasksFuture).isNotDone();

splitsFuture.set(ImmutableList.of(createSplit(1), createSplit(2), createSplit(3)));
List<TaskDescriptor> tasks = getDone(tasksFuture);
assertThat(tasks).hasSize(1);
assertThat(tasks.get(0).getSplits()).hasSize(2);

tasksFuture = taskSource.getMoreTasks();
assertThat(tasksFuture).isDone();
tasks = getDone(tasksFuture);
assertThat(tasks).hasSize(1);
assertThat(tasks.get(0).getSplits()).hasSize(1);
assertThat(taskSource.isFinished()).isTrue();
}

@Test
public void testHashDistributionTaskSourceWithAsyncSplitSource()
{
SettableFuture<List<Split>> splitsFuture1 = SettableFuture.create();
SettableFuture<List<Split>> splitsFuture2 = SettableFuture.create();
TaskSource taskSource = createHashDistributionTaskSource(
ImmutableMap.of(
PLAN_NODE_1, new TestingSplitSource(CATALOG, splitsFuture1, 0),
PLAN_NODE_2, new TestingSplitSource(CATALOG, splitsFuture2, 0)),
ImmutableListMultimap.of(),
ImmutableListMultimap.of(
PLAN_NODE_3, new TestingExchangeSourceHandle(0, 1)),
1,
new int[] {0, 1, 2, 3},
Optional.of(getTestingBucketNodeMap(4)),
0,
DataSize.of(0, BYTE));
ListenableFuture<List<TaskDescriptor>> tasksFuture = taskSource.getMoreTasks();
assertThat(tasksFuture).isNotDone();

Split bucketedSplit1 = createBucketedSplit(0, 0);
Split bucketedSplit2 = createBucketedSplit(0, 2);
Split bucketedSplit3 = createBucketedSplit(0, 3);
splitsFuture1.set(ImmutableList.of(bucketedSplit1, bucketedSplit2, bucketedSplit3));
assertThat(tasksFuture).isNotDone();

Split bucketedSplit4 = createBucketedSplit(0, 1);
splitsFuture2.set(ImmutableList.of(bucketedSplit4));
List<TaskDescriptor> tasks = getDone(tasksFuture);
assertThat(tasks).hasSize(4);
tasks.forEach(task -> assertThat(task.getSplits()).hasSize(1));
assertThat(taskSource.isFinished()).isTrue();
}

private static SourceDistributionTaskSource createSourceDistributionTaskSource(
List<Split> splits,
ListMultimap<PlanNodeId, ExchangeSourceHandle> replicatedSources,
Expand Down Expand Up @@ -855,7 +922,8 @@ private static SourceDistributionTaskSource createSourceDistributionTaskSource(
minSplitsPerTask,
splitWeightPerTask,
maxSplitsPerTask,
DataSize.of(4, GIGABYTE));
DataSize.of(4, GIGABYTE),
directExecutor());
}

private static Split createSplit(int id, String... addresses)
Expand All @@ -877,7 +945,7 @@ private List<TaskDescriptor> readAllTasks(TaskSource taskSource)
{
ImmutableList.Builder<TaskDescriptor> tasks = ImmutableList.builder();
while (!taskSource.isFinished()) {
tasks.addAll(taskSource.getMoreTasks());
tasks.addAll(getFutureValue(taskSource.getMoreTasks()));
}
return tasks.build();
}
Expand Down
Loading