diff --git a/core/trino-main/src/main/java/io/trino/exchange/SpoolingExchangeDataSource.java b/core/trino-main/src/main/java/io/trino/exchange/SpoolingExchangeDataSource.java index a088aa21429e..37511ee0b33c 100644 --- a/core/trino-main/src/main/java/io/trino/exchange/SpoolingExchangeDataSource.java +++ b/core/trino-main/src/main/java/io/trino/exchange/SpoolingExchangeDataSource.java @@ -91,6 +91,7 @@ public void addInput(ExchangeInput input) if (exchangeSource == null) { return; } + spoolingExchangeInput.getOutputSelector().ifPresent(exchangeSource::setOutputSelector); exchangeSource.addSourceHandles(spoolingExchangeInput.getExchangeSourceHandles()); } diff --git a/core/trino-main/src/main/java/io/trino/exchange/SpoolingExchangeInput.java b/core/trino-main/src/main/java/io/trino/exchange/SpoolingExchangeInput.java index 132613abefba..a2123fb88ee2 100644 --- a/core/trino-main/src/main/java/io/trino/exchange/SpoolingExchangeInput.java +++ b/core/trino-main/src/main/java/io/trino/exchange/SpoolingExchangeInput.java @@ -17,12 +17,15 @@ import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ImmutableList; import io.trino.spi.exchange.ExchangeSourceHandle; +import io.trino.spi.exchange.ExchangeSourceOutputSelector; import org.openjdk.jol.info.ClassLayout; import java.util.List; +import java.util.Optional; import static com.google.common.base.MoreObjects.toStringHelper; import static io.airlift.slice.SizeOf.estimatedSizeOf; +import static io.airlift.slice.SizeOf.sizeOf; import static java.lang.Math.toIntExact; import static java.util.Objects.requireNonNull; @@ -32,11 +35,15 @@ public class SpoolingExchangeInput private static final int INSTANCE_SIZE = toIntExact(ClassLayout.parseClass(SpoolingExchangeInput.class).instanceSize()); private final List exchangeSourceHandles; + private final Optional outputSelector; @JsonCreator - public SpoolingExchangeInput(@JsonProperty("exchangeSourceHandles") List exchangeSourceHandles) + public SpoolingExchangeInput( + @JsonProperty("exchangeSourceHandles") List exchangeSourceHandles, + @JsonProperty("outputSelector") Optional outputSelector) { this.exchangeSourceHandles = ImmutableList.copyOf(requireNonNull(exchangeSourceHandles, "exchangeSourceHandles is null")); + this.outputSelector = requireNonNull(outputSelector, "outputSelector is null"); } @JsonProperty @@ -45,11 +52,18 @@ public List getExchangeSourceHandles() return exchangeSourceHandles; } + @JsonProperty + public Optional getOutputSelector() + { + return outputSelector; + } + @Override public String toString() { return toStringHelper(this) .add("exchangeSourceHandles", exchangeSourceHandles) + .add("outputSelector", outputSelector) .toString(); } @@ -57,6 +71,7 @@ public String toString() public long getRetainedSizeInBytes() { return INSTANCE_SIZE - + estimatedSizeOf(exchangeSourceHandles, ExchangeSourceHandle::getRetainedSizeInBytes); + + estimatedSizeOf(exchangeSourceHandles, ExchangeSourceHandle::getRetainedSizeInBytes) + + sizeOf(outputSelector, ExchangeSourceOutputSelector::getRetainedSizeInBytes); } } diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/FaultTolerantQueryScheduler.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/FaultTolerantQueryScheduler.java index 6081a2fa38f0..1da65d6a7df6 100644 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/FaultTolerantQueryScheduler.java +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/FaultTolerantQueryScheduler.java @@ -15,14 +15,13 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import com.google.common.util.concurrent.Futures; +import com.google.common.collect.ImmutableSet; import com.google.common.util.concurrent.ListenableFuture; import io.airlift.concurrent.SetThreadName; import io.airlift.log.Logger; import io.airlift.stats.TimeStat; import io.airlift.units.Duration; import io.trino.Session; -import io.trino.exchange.ExchangeInput; import io.trino.exchange.SpoolingExchangeInput; import io.trino.execution.BasicStageStats; import io.trino.execution.NodeTaskMap; @@ -42,6 +41,7 @@ import io.trino.spi.exchange.ExchangeId; import io.trino.spi.exchange.ExchangeManager; import io.trino.spi.exchange.ExchangeSourceHandle; +import io.trino.spi.exchange.ExchangeSourceOutputSelector; import io.trino.sql.planner.NodePartitioningManager; import io.trino.sql.planner.PlanFragment; import io.trino.sql.planner.SubPlan; @@ -62,7 +62,6 @@ import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Ticker.systemTicker; import static com.google.common.base.Verify.verify; -import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.Lists.reverse; import static io.airlift.concurrent.MoreFutures.addExceptionCallback; import static io.airlift.concurrent.MoreFutures.addSuccessCallback; @@ -71,6 +70,7 @@ import static io.trino.SystemSessionProperties.getFaultTolerantExecutionPartitionCount; import static io.trino.SystemSessionProperties.getRetryPolicy; import static io.trino.execution.QueryState.FINISHING; +import static io.trino.execution.scheduler.Exchanges.getAllSourceHandles; import static io.trino.operator.RetryPolicy.TASK; import static java.util.Objects.requireNonNull; import static java.util.concurrent.TimeUnit.MILLISECONDS; @@ -214,7 +214,7 @@ private Scheduler createScheduler() session, getFaultTolerantExecutionPartitionCount(session)); - ImmutableList.Builder schedulers = ImmutableList.builder(); + Map schedulers = new HashMap<>(); Map exchanges = new HashMap<>(); NodeAllocator nodeAllocator = nodeAllocatorService.getNodeAllocator(session); @@ -226,7 +226,6 @@ private Scheduler createScheduler() checkArgument(taskRetryAttemptsOverall >= 0, "taskRetryAttemptsOverall must be greater than or equal to 0: %s", taskRetryAttemptsOverall); AtomicInteger remainingTaskRetryAttemptsOverall = new AtomicInteger(taskRetryAttemptsOverall); - List outputStages = new ArrayList<>(); for (SqlStage stage : stagesInReverseTopologicalOrder) { PlanFragment fragment = stage.getFragment(); @@ -241,17 +240,16 @@ private Scheduler createScheduler() outputStage); exchanges.put(fragment.getId(), exchange); - if (outputStage) { - // output will be consumed by coordinator - outputStages.add(exchange); - } - ImmutableMap.Builder sourceExchanges = ImmutableMap.builder(); + ImmutableMap.Builder sourceSchedulers = ImmutableMap.builder(); for (SqlStage childStage : stageManager.getChildren(fragment.getId())) { PlanFragmentId childFragmentId = childStage.getFragment().getId(); Exchange sourceExchange = exchanges.get(childFragmentId); verify(sourceExchange != null, "exchange not found for fragment: %s", childFragmentId); sourceExchanges.put(childFragmentId, sourceExchange); + FaultTolerantStageScheduler sourceScheduler = schedulers.get(childFragmentId); + verify(sourceScheduler != null, "scheduler not found for fragment: %s", childFragmentId); + sourceSchedulers.put(childFragmentId, sourceScheduler); } FaultTolerantPartitioningScheme sourcePartitioningScheme = partitioningSchemeFactory.get(fragment.getPartitioning()); @@ -268,6 +266,7 @@ private Scheduler createScheduler() systemTicker(), exchange, sinkPartitioningScheme, + sourceSchedulers.buildOrThrow(), sourceExchanges.buildOrThrow(), sourcePartitioningScheme, remainingTaskRetryAttemptsOverall, @@ -275,38 +274,38 @@ private Scheduler createScheduler() maxTasksWaitingForNodePerStage, dynamicFilterService); - schedulers.add(scheduler); - } + schedulers.put(fragment.getId(), scheduler); - if (!stagesInReverseTopologicalOrder.isEmpty()) { - verify(!outputStages.isEmpty(), "coordinatorConsumedExchanges is empty"); - List>> futures = outputStages.stream() - .map(Exchange::getSourceHandles) - .map(Exchanges::getAllSourceHandles) - .collect(toImmutableList()); - ListenableFuture>> allFuture = Futures.allAsList(futures); - addSuccessCallback(allFuture, result -> { - List handles = result.stream() - .flatMap(List::stream) - .collect(toImmutableList()); - ImmutableList.Builder inputs = ImmutableList.builder(); - if (!handles.isEmpty()) { - inputs.add(new SpoolingExchangeInput(handles)); - } - queryStateMachine.updateInputsForQueryResults(inputs.build(), true); - }); - addExceptionCallback(allFuture, queryStateMachine::transitionToFailed); + if (outputStage) { + ListenableFuture> sourceHandles = getAllSourceHandles(exchange.getSourceHandles()); + addSuccessCallback(sourceHandles, handles -> { + try { + ExchangeSourceOutputSelector.Builder selector = ExchangeSourceOutputSelector.builder(ImmutableSet.of(exchange.getId())); + Map successfulAttempts = scheduler.getSuccessfulAttempts(); + successfulAttempts.forEach((taskPartitionId, attemptId) -> + selector.include(exchange.getId(), taskPartitionId, attemptId)); + selector.setPartitionCount(exchange.getId(), successfulAttempts.size()); + selector.setFinal(); + SpoolingExchangeInput input = new SpoolingExchangeInput(handles, Optional.of(selector.build())); + queryStateMachine.updateInputsForQueryResults(ImmutableList.of(input), true); + } + catch (Throwable t) { + queryStateMachine.transitionToFailed(t); + } + }); + addExceptionCallback(sourceHandles, queryStateMachine::transitionToFailed); + } } return new Scheduler( queryStateMachine, - schedulers.build(), + ImmutableList.copyOf(schedulers.values()), stageManager, schedulerStats, nodeAllocator); } catch (Throwable t) { - for (FaultTolerantStageScheduler scheduler : schedulers.build()) { + for (FaultTolerantStageScheduler scheduler : schedulers.values()) { try { scheduler.abort(); } diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/FaultTolerantStageScheduler.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/FaultTolerantStageScheduler.java index ccabb5d23a50..b1699317ba94 100644 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/FaultTolerantStageScheduler.java +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/FaultTolerantStageScheduler.java @@ -18,6 +18,7 @@ import com.google.common.base.VerifyException; import com.google.common.collect.ArrayListMultimap; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableListMultimap; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Multimap; @@ -29,6 +30,7 @@ import io.airlift.log.Logger; import io.airlift.units.DataSize; import io.trino.Session; +import io.trino.exchange.SpoolingExchangeInput; import io.trino.execution.ExecutionFailureInfo; import io.trino.execution.RemoteTask; import io.trino.execution.SqlStage; @@ -41,13 +43,17 @@ import io.trino.execution.scheduler.PartitionMemoryEstimator.MemoryRequirements; import io.trino.failuredetector.FailureDetector; import io.trino.metadata.InternalNode; +import io.trino.metadata.Split; import io.trino.server.DynamicFilterService; import io.trino.spi.ErrorCode; import io.trino.spi.TrinoException; import io.trino.spi.exchange.Exchange; +import io.trino.spi.exchange.ExchangeId; import io.trino.spi.exchange.ExchangeSinkHandle; import io.trino.spi.exchange.ExchangeSinkInstanceHandle; import io.trino.spi.exchange.ExchangeSourceHandle; +import io.trino.spi.exchange.ExchangeSourceOutputSelector; +import io.trino.split.RemoteSplit; import io.trino.sql.planner.plan.PlanFragmentId; import io.trino.sql.planner.plan.PlanNodeId; import io.trino.sql.planner.plan.RemoteSourceNode; @@ -67,12 +73,14 @@ import java.util.Set; import java.util.concurrent.atomic.AtomicInteger; +import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; import static com.google.common.base.Throwables.propagateIfPossible; import static com.google.common.base.Verify.verify; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableListMultimap.flatteningToImmutableListMultimap; import static com.google.common.collect.ImmutableMap.toImmutableMap; +import static com.google.common.collect.ImmutableSet.toImmutableSet; import static com.google.common.util.concurrent.Futures.allAsList; import static com.google.common.util.concurrent.Futures.immediateVoidFuture; import static com.google.common.util.concurrent.Futures.nonCancellationPropagating; @@ -87,6 +95,7 @@ import static io.trino.execution.scheduler.ErrorCodes.isOutOfMemoryError; import static io.trino.execution.scheduler.Exchanges.getAllSourceHandles; import static io.trino.failuredetector.FailureDetector.State.GONE; +import static io.trino.operator.ExchangeOperator.REMOTE_CATALOG_HANDLE; import static io.trino.spi.ErrorType.EXTERNAL; import static io.trino.spi.ErrorType.INTERNAL_ERROR; import static io.trino.spi.ErrorType.USER_ERROR; @@ -115,6 +124,7 @@ public class FaultTolerantStageScheduler private final Exchange sinkExchange; private final FaultTolerantPartitioningScheme sinkPartitioningScheme; + private final Map sourceSchedulers; private final Map sourceExchanges; private final FaultTolerantPartitioningScheme sourcePartitioningScheme; @@ -153,17 +163,21 @@ public class FaultTolerantStageScheduler @GuardedBy("this") private final Set allPartitions = new HashSet<>(); @GuardedBy("this") + private boolean noMorePartitions; + @GuardedBy("this") private final Queue queuedPartitions = new ArrayDeque<>(); @GuardedBy("this") private final Queue pendingPartitions = new ArrayDeque<>(); @GuardedBy("this") - private final Set finishedPartitions = new HashSet<>(); + private final Map finishedPartitions = new HashMap<>(); @GuardedBy("this") private final AtomicInteger remainingRetryAttemptsOverall; @GuardedBy("this") private final Map remainingAttemptsPerTask = new HashMap<>(); @GuardedBy("this") private final Map partitionMemoryRequirements = new HashMap<>(); + @GuardedBy("this") + private Multimap outputSelectorSplits; private final DynamicFilterService dynamicFilterService; @@ -185,6 +199,7 @@ public FaultTolerantStageScheduler( Ticker ticker, Exchange sinkExchange, FaultTolerantPartitioningScheme sinkPartitioningScheme, + Map sourceSchedulers, Map sourceExchanges, FaultTolerantPartitioningScheme sourcePartitioningScheme, AtomicInteger remainingRetryAttemptsOverall, @@ -203,7 +218,15 @@ public FaultTolerantStageScheduler( this.futureCompletor = requireNonNull(futureCompletor, "futureCompletor is null"); this.sinkExchange = requireNonNull(sinkExchange, "sinkExchange is null"); this.sinkPartitioningScheme = requireNonNull(sinkPartitioningScheme, "sinkPartitioningScheme is null"); - this.sourceExchanges = ImmutableMap.copyOf(requireNonNull(sourceExchanges, "sourceExchanges is null")); + Set sourceFragments = stage.getFragment().getRemoteSourceNodes().stream() + .flatMap(remoteSource -> remoteSource.getSourceFragmentIds().stream()) + .collect(toImmutableSet()); + requireNonNull(sourceSchedulers, "sourceSchedulers is null"); + checkArgument(sourceSchedulers.keySet().containsAll(sourceFragments), "sourceSchedulers map is incomplete"); + this.sourceSchedulers = ImmutableMap.copyOf(sourceSchedulers); + requireNonNull(sourceExchanges, "sourceExchanges is null"); + checkArgument(sourceExchanges.keySet().containsAll(sourceFragments), "sourceExchanges map is incomplete"); + this.sourceExchanges = ImmutableMap.copyOf(sourceExchanges); this.sourcePartitioningScheme = requireNonNull(sourcePartitioningScheme, "sourcePartitioningScheme is null"); this.remainingRetryAttemptsOverall = requireNonNull(remainingRetryAttemptsOverall, "remainingRetryAttemptsOverall is null"); this.maxRetryAttemptsPerTask = taskRetryAttemptsPerTask; @@ -291,6 +314,10 @@ public synchronized void schedule() if (taskSource.isFinished()) { dynamicFilterService.stageCannotScheduleMoreTasks(stage.getStageId(), 0, allPartitions.size()); sinkExchange.noMoreSinks(); + noMorePartitions = true; + } + if (noMorePartitions && finishedPartitions.keySet().containsAll(allPartitions)) { + sinkExchange.allRequiredSinksFinished(); } return null; } @@ -351,6 +378,7 @@ public synchronized void schedule() } } + @GuardedBy("this") private void startTask(int partition, NodeAllocator.NodeLease nodeLease, MemoryRequirements memoryRequirements) { Optional taskDescriptorOptional = taskDescriptorStorage.get(stage.getStageId(), partition); @@ -376,13 +404,18 @@ private void startTask(int partition, NodeAllocator.NodeLease nodeLease, MemoryR .iterator()) .build(); + createOutputSelectorSplitsIfNecessary(); + RemoteTask task = stage.createTask( node, partition, attemptId, sinkPartitioningScheme.getBucketToPartitionMap(), outputBuffers, - taskDescriptor.getSplits(), + ImmutableListMultimap.builder() + .putAll(outputSelectorSplits) + .putAll(taskDescriptor.getSplits()) + .build(), allSourcePlanNodeIds, Optional.of(memoryRequirements.getRequiredMemory())).orElseThrow(() -> new VerifyException("stage execution is expected to be active")); @@ -400,6 +433,35 @@ private void startTask(int partition, NodeAllocator.NodeLease nodeLease, MemoryR task.start(); } + @GuardedBy("this") + private void createOutputSelectorSplitsIfNecessary() + { + if (outputSelectorSplits != null) { + return; + } + + ImmutableListMultimap.Builder selectors = ImmutableListMultimap.builder(); + for (RemoteSourceNode remoteSource : stage.getFragment().getRemoteSourceNodes()) { + List sourceFragmentIds = remoteSource.getSourceFragmentIds(); + Set sourceExchangeIds = sourceExchanges.entrySet().stream() + .filter(entry -> sourceFragmentIds.contains(entry.getKey())) + .map(entry -> entry.getValue().getId()) + .collect(toImmutableSet()); + ExchangeSourceOutputSelector.Builder selector = ExchangeSourceOutputSelector.builder(sourceExchangeIds); + for (PlanFragmentId sourceFragment : sourceFragmentIds) { + FaultTolerantStageScheduler sourceScheduler = sourceSchedulers.get(sourceFragment); + Exchange sourceExchange = sourceExchanges.get(sourceFragment); + Map successfulAttempts = sourceScheduler.getSuccessfulAttempts(); + successfulAttempts.forEach((taskPartitionId, attemptId) -> + selector.include(sourceExchange.getId(), taskPartitionId, attemptId)); + selector.setPartitionCount(sourceExchange.getId(), successfulAttempts.size()); + } + selector.setFinal(); + selectors.put(remoteSource.getId(), new Split(REMOTE_CATALOG_HANDLE, new RemoteSplit(new SpoolingExchangeInput(ImmutableList.of(), Optional.of(selector.build()))))); + } + outputSelectorSplits = selectors.build(); + } + public synchronized boolean isFinished() { return failure == null && @@ -407,7 +469,12 @@ public synchronized boolean isFinished() taskSource.isFinished() && tasksPopulatedFuture.isDone() && queuedPartitions.isEmpty() && - finishedPartitions.containsAll(allPartitions); + allPartitions.stream().allMatch(finishedPartitions::containsKey); + } + + public synchronized Map getSuccessfulAttempts() + { + return ImmutableMap.copyOf(finishedPartitions); } public void cancel() @@ -550,13 +617,16 @@ private void updateTaskStatus(TaskStatus taskStatus, ExchangeSinkHandle exchange int partitionId = taskId.getPartitionId(); - if (!finishedPartitions.contains(partitionId) && !closed) { + if (!finishedPartitions.containsKey(partitionId) && !closed) { MemoryRequirements memoryLimits = partitionMemoryRequirements.get(partitionId); verify(memoryLimits != null); switch (state) { case FINISHED: - finishedPartitions.add(partitionId); + finishedPartitions.put(partitionId, taskId.getAttemptId()); sinkExchange.sinkFinished(exchangeSinkHandle, taskId.getAttemptId()); + if (noMorePartitions && finishedPartitions.keySet().containsAll(allPartitions)) { + sinkExchange.allRequiredSinksFinished(); + } partitionToRemoteTaskMap.get(partitionId).forEach(RemoteTask::abort); partitionMemoryEstimator.registerPartitionFinished(session, memoryLimits, taskStatus.getPeakMemoryReservation(), true, Optional.empty()); diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/StageTaskSourceFactory.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/StageTaskSourceFactory.java index 5b5ee26c955e..ab9e2eadf5e3 100644 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/StageTaskSourceFactory.java +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/StageTaskSourceFactory.java @@ -921,7 +921,7 @@ static ListMultimap createRemoteSplits(ListMultimap exchangeSourceHandles) { - return new Split(REMOTE_CATALOG_HANDLE, new RemoteSplit(new SpoolingExchangeInput(ImmutableList.copyOf(exchangeSourceHandles)))); + return new Split(REMOTE_CATALOG_HANDLE, new RemoteSplit(new SpoolingExchangeInput(ImmutableList.copyOf(exchangeSourceHandles), Optional.empty()))); } private static class LoadedSplits diff --git a/core/trino-main/src/main/java/io/trino/operator/DeduplicatingDirectExchangeBuffer.java b/core/trino-main/src/main/java/io/trino/operator/DeduplicatingDirectExchangeBuffer.java index 822b182ee326..3b4624619dea 100644 --- a/core/trino-main/src/main/java/io/trino/operator/DeduplicatingDirectExchangeBuffer.java +++ b/core/trino-main/src/main/java/io/trino/operator/DeduplicatingDirectExchangeBuffer.java @@ -40,6 +40,7 @@ import io.trino.spi.exchange.ExchangeSink; import io.trino.spi.exchange.ExchangeSinkHandle; import io.trino.spi.exchange.ExchangeSource; +import io.trino.spi.exchange.ExchangeSourceOutputSelector; import javax.annotation.concurrent.GuardedBy; import javax.annotation.concurrent.NotThreadSafe; @@ -665,6 +666,7 @@ public synchronized OutputSource createOutputSource(Set selectedTasks) ListenableFuture exchangeSourceFuture = FluentFuture.from(toListenableFuture(exchangeSink.finish())) .transformAsync(ignored -> { exchange.sinkFinished(sinkHandle, 0); + exchange.allRequiredSinksFinished(); synchronized (this) { exchangeSink = null; sinkHandle = null; @@ -674,6 +676,11 @@ public synchronized OutputSource createOutputSource(Set selectedTasks) .transform(handles -> { ExchangeSource source = exchangeManager.createSource(); try { + source.setOutputSelector(ExchangeSourceOutputSelector.builder(ImmutableSet.of(exchangeId)) + .include(exchangeId, 0, 0) + .setPartitionCount(exchangeId, 1) + .setFinal() + .build()); source.addSourceHandles(handles); source.noMoreSourceHandles(); return source; diff --git a/core/trino-main/src/test/java/io/trino/exchange/TestExchangeSourceOutputSelector.java b/core/trino-main/src/test/java/io/trino/exchange/TestExchangeSourceOutputSelector.java new file mode 100644 index 000000000000..7e9a2a5a3a2c --- /dev/null +++ b/core/trino-main/src/test/java/io/trino/exchange/TestExchangeSourceOutputSelector.java @@ -0,0 +1,198 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.exchange; + +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import io.airlift.json.JsonCodec; +import io.airlift.json.JsonCodecFactory; +import io.airlift.json.ObjectMapperProvider; +import io.airlift.slice.Slice; +import io.trino.server.SliceSerialization.SliceDeserializer; +import io.trino.server.SliceSerialization.SliceSerializer; +import io.trino.spi.exchange.ExchangeId; +import io.trino.spi.exchange.ExchangeSourceOutputSelector; +import org.testng.annotations.AfterClass; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +import static io.trino.spi.exchange.ExchangeSourceOutputSelector.Selection.EXCLUDED; +import static io.trino.spi.exchange.ExchangeSourceOutputSelector.Selection.INCLUDED; +import static io.trino.spi.exchange.ExchangeSourceOutputSelector.Selection.UNKNOWN; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertFalse; +import static org.testng.Assert.assertTrue; + +public class TestExchangeSourceOutputSelector +{ + private static final ExchangeId EXCHANGE_ID_1 = new ExchangeId("exchange_1"); + private static final ExchangeId EXCHANGE_ID_2 = new ExchangeId("exchange_2"); + + private JsonCodec codec; + + @BeforeClass + public void setup() + { + ObjectMapperProvider objectMapperProvider = new ObjectMapperProvider(); + objectMapperProvider.setJsonSerializers(ImmutableMap.of(Slice.class, new SliceSerializer())); + objectMapperProvider.setJsonDeserializers(ImmutableMap.of(Slice.class, new SliceDeserializer())); + codec = new JsonCodecFactory(objectMapperProvider).jsonCodec(ExchangeSourceOutputSelector.class); + } + + @AfterClass(alwaysRun = true) + public void tearDown() + { + codec = null; + } + + @Test + public void testEmpty() + { + { + ExchangeSourceOutputSelector selector = serializeDeserialize(ExchangeSourceOutputSelector.builder(ImmutableSet.of(EXCHANGE_ID_1, EXCHANGE_ID_2)) + .build()); + assertEquals(selector.getSelection(EXCHANGE_ID_1, 100, 1), UNKNOWN); + assertEquals(selector.getSelection(EXCHANGE_ID_2, 21, 2), UNKNOWN); + assertFalse(selector.isFinal()); + } + + { + ExchangeSourceOutputSelector selector = serializeDeserialize(ExchangeSourceOutputSelector.builder(ImmutableSet.of(EXCHANGE_ID_1, EXCHANGE_ID_2)) + .setPartitionCount(EXCHANGE_ID_1, 0) + .setPartitionCount(EXCHANGE_ID_2, 0) + .setFinal() + .build()); + assertTrue(selector.isFinal()); + // final selector should have selection set for all partitions + assertThatThrownBy(() -> selector.getSelection(EXCHANGE_ID_1, 100, 1)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("selection not found for exchangeId %s, taskPartitionId %s".formatted(EXCHANGE_ID_1, 100)); + } + } + + @Test + public void testNonFinal() + { + ExchangeSourceOutputSelector selector = serializeDeserialize(ExchangeSourceOutputSelector.builder(ImmutableSet.of(EXCHANGE_ID_1, EXCHANGE_ID_2)) + .include(EXCHANGE_ID_1, 21, 2) + .exclude(EXCHANGE_ID_2, 100) + .build()); + // ensure exchange id is taken into account + assertEquals(selector.getSelection(EXCHANGE_ID_1, 100, 1), UNKNOWN); + assertEquals(selector.getSelection(EXCHANGE_ID_2, 100, 1), EXCLUDED); + // all attempts of a given task must be excluded + assertEquals(selector.getSelection(EXCHANGE_ID_2, 100, 2), EXCLUDED); + // ensure exchange id is taken into account + assertEquals(selector.getSelection(EXCHANGE_ID_2, 21, 2), UNKNOWN); + assertEquals(selector.getSelection(EXCHANGE_ID_1, 21, 2), INCLUDED); + assertEquals(selector.getSelection(EXCHANGE_ID_1, 21, 1), EXCLUDED); + assertEquals(selector.getSelection(EXCHANGE_ID_2, 1, 2), UNKNOWN); + assertEquals(selector.getSelection(EXCHANGE_ID_2, 200, 2), UNKNOWN); + assertFalse(selector.isFinal()); + } + + @Test + public void testFinal() + { + // partition count must be set + assertThatThrownBy(() -> ExchangeSourceOutputSelector.builder(ImmutableSet.of(EXCHANGE_ID_1)) + .include(EXCHANGE_ID_1, 1, 2) + .setFinal() + .build()) + .isInstanceOf(IllegalStateException.class) + .hasMessage("partition count is missing for exchange: %s".formatted(EXCHANGE_ID_1)); + + ExchangeSourceOutputSelector selector = serializeDeserialize(ExchangeSourceOutputSelector.builder(ImmutableSet.of(EXCHANGE_ID_1, EXCHANGE_ID_2)) + .include(EXCHANGE_ID_1, 0, 1) + .exclude(EXCHANGE_ID_1, 1) + .include(EXCHANGE_ID_1, 2, 0) + .exclude(EXCHANGE_ID_2, 0) + .setPartitionCount(EXCHANGE_ID_1, 3) + .setPartitionCount(EXCHANGE_ID_2, 1) + .setFinal() + .build()); + + assertEquals(selector.getSelection(EXCHANGE_ID_1, 0, 1), INCLUDED); + assertEquals(selector.getSelection(EXCHANGE_ID_1, 0, 2), EXCLUDED); + assertEquals(selector.getSelection(EXCHANGE_ID_1, 1, 0), EXCLUDED); + assertEquals(selector.getSelection(EXCHANGE_ID_1, 1, 2), EXCLUDED); + assertEquals(selector.getSelection(EXCHANGE_ID_1, 2, 0), INCLUDED); + assertEquals(selector.getSelection(EXCHANGE_ID_1, 2, 2), EXCLUDED); + assertEquals(selector.getSelection(EXCHANGE_ID_2, 0, 1), EXCLUDED); + assertEquals(selector.getSelection(EXCHANGE_ID_2, 0, 0), EXCLUDED); + + assertThatThrownBy(() -> selector.getSelection(EXCHANGE_ID_1, 100, 1)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("selection not found for exchangeId %s, taskPartitionId %s".formatted(EXCHANGE_ID_1, 100)); + } + + @Test + public void testBasicTransitions() + { + ExchangeSourceOutputSelector.Builder builder = ExchangeSourceOutputSelector.builder(ImmutableSet.of(EXCHANGE_ID_1, EXCHANGE_ID_2)); + ExchangeSourceOutputSelector partialVersion0 = builder.build(); + builder.include(EXCHANGE_ID_1, 0, 0); + ExchangeSourceOutputSelector partialVersion1 = builder.build(); + builder.exclude(EXCHANGE_ID_1, 1); + ExchangeSourceOutputSelector partialVersion2 = builder.build(); + builder.setPartitionCount(EXCHANGE_ID_1, 2); + builder.setPartitionCount(EXCHANGE_ID_2, 0); + builder.setFinal(); + ExchangeSourceOutputSelector finalVersion1 = builder.build(); + ExchangeSourceOutputSelector finalVersion2 = builder.build(); + + // legitimate transitions + partialVersion0.checkValidTransition(partialVersion1); + partialVersion1.checkValidTransition(partialVersion2); + partialVersion2.checkValidTransition(finalVersion1); + + assertThatThrownBy(() -> partialVersion1.checkValidTransition(partialVersion0)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Invalid transition to the same or an older version"); + assertThatThrownBy(() -> partialVersion2.checkValidTransition(partialVersion0)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Invalid transition to the same or an older version"); + assertThatThrownBy(() -> partialVersion2.checkValidTransition(partialVersion1)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Invalid transition to the same or an older version"); + assertThatThrownBy(() -> finalVersion2.checkValidTransition(finalVersion1)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Invalid transition to the same or an older version"); + assertThatThrownBy(() -> finalVersion2.checkValidTransition(partialVersion0)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Invalid transition to the same or an older version"); + assertThatThrownBy(() -> finalVersion1.checkValidTransition(finalVersion2)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Invalid transition from final selector"); + } + + @Test + public void testIncompatibleTransitions() + { + ExchangeSourceOutputSelector.Builder builder = ExchangeSourceOutputSelector.builder(ImmutableSet.of(EXCHANGE_ID_1)); + builder.include(EXCHANGE_ID_1, 0, 0); + assertThatThrownBy(() -> builder.include(EXCHANGE_ID_1, 0, 1)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("decision for partition 0 is already made: 0"); + assertThatThrownBy(() -> builder.exclude(EXCHANGE_ID_1, 0)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("decision for partition 0 is already made: 0"); + } + + private ExchangeSourceOutputSelector serializeDeserialize(ExchangeSourceOutputSelector selector) + { + return codec.fromJson(codec.toJson(selector)); + } +} diff --git a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestFaultTolerantStageScheduler.java b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestFaultTolerantStageScheduler.java index dee966f3229d..dbc83a99d4a4 100644 --- a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestFaultTolerantStageScheduler.java +++ b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestFaultTolerantStageScheduler.java @@ -67,6 +67,7 @@ import java.net.URI; import java.time.Duration; +import java.util.Collection; import java.util.Iterator; import java.util.List; import java.util.Map; @@ -76,6 +77,7 @@ import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; +import static com.google.common.collect.ImmutableMap.toImmutableMap; import static com.google.common.collect.Iterables.cycle; import static com.google.common.collect.Iterables.limit; import static com.google.common.util.concurrent.MoreExecutors.directExecutor; @@ -276,6 +278,7 @@ public void testHappyPath() tasks = remoteTaskFactory.getTasks(); assertThat(tasks).hasSize(6); assertThat(tasks).containsKey(getTaskId(4, 0)); + assertTrue(sinkExchange.isNoMoreSinks()); // not finished yet, will be finished when all tasks succeed assertFalse(scheduler.isFinished()); @@ -302,6 +305,8 @@ public void testHappyPath() new TestingExchangeSinkHandle(3), new TestingExchangeSinkHandle(4)); + assertTrue(sinkExchange.isAllRequiredSinksFinished()); + assertTrue(scheduler.isFinished()); } } @@ -819,7 +824,9 @@ public void testAsyncTaskSource() scheduler.schedule(); assertThat(remoteTaskFactory.getTasks()).hasSize(2); remoteTaskFactory.getTasks().values().forEach(task -> { - assertThat(task.getSplits().values()).hasSize(2); + Collection splits = task.getSplits().values(); + // 2 normal splits + 1 split containing an output selector + assertThat(splits).hasSize(3); task.finish(); }); assertThat(scheduler.isFinished()).isTrue(); @@ -856,6 +863,11 @@ public boolean isFinished() public void close() {} }; try (NodeAllocator nodeAllocator = nodeAllocatorService.getNodeAllocator(SESSION, 1)) { + TestingExchange sourceExchange1 = new TestingExchange(); + sourceExchange1.setSourceHandles(ImmutableList.of()); + TestingExchange sourceExchange2 = new TestingExchange(); + sourceExchange2.setSourceHandles(ImmutableList.of()); + TestingExchange sinkExchange = new TestingExchange(); FaultTolerantStageScheduler scheduler = createFaultTolerantTaskScheduler( remoteTaskFactory, (session, fragment, exchangeSourceHandles, getSplitTimeRecorder, bucketToPartition) -> { @@ -863,8 +875,10 @@ public void close() {} return taskSource; }, nodeAllocator, - new TestingExchange(), - ImmutableMap.of(), + sinkExchange, + ImmutableMap.of( + SOURCE_FRAGMENT_ID_1, sourceExchange1, + SOURCE_FRAGMENT_ID_2, sourceExchange2), 1, 1); @@ -885,6 +899,8 @@ public void close() {} future.set(ImmutableList.of()); assertTrue(scheduler.isFinished()); + assertTrue(sinkExchange.isNoMoreSinks()); + assertTrue(sinkExchange.isAllRequiredSinksFinished()); } } @@ -920,9 +936,60 @@ private FaultTolerantStageScheduler createFaultTolerantTaskScheduler( { TaskDescriptorStorage taskDescriptorStorage = new TaskDescriptorStorage(DataSize.of(10, MEGABYTE)); taskDescriptorStorage.initialize(SESSION.getQueryId()); + DynamicFilterService dynamicFilterService = new DynamicFilterService(PLANNER_CONTEXT.getMetadata(), PLANNER_CONTEXT.getFunctionManager(), PLANNER_CONTEXT.getTypeOperators(), new DynamicFilterConfig()); + return createStageScheduler( + session, + createSqlStage(createIntermediatePlanFragment(), remoteTaskFactory), + nodeAllocator, + retryAttempts, + maxTasksWaitingForNodePerStage, + taskDescriptorStorage, + taskSourceFactory, + dynamicFilterService, + sinkExchange, + sourceExchanges.entrySet().stream().collect(toImmutableMap(Map.Entry::getKey, entry -> { + FaultTolerantStageScheduler sourceScheduler = createStageScheduler( + session, + createSqlStage(createLeafPlanFragment(entry.getKey()), remoteTaskFactory), + nodeAllocator, + retryAttempts, + maxTasksWaitingForNodePerStage, + taskDescriptorStorage, + new TestingTaskSourceFactory(Optional.empty(), ImmutableList.of(), 1), + dynamicFilterService, + entry.getValue(), + ImmutableMap.of(), + ImmutableMap.of()); + while (!sourceScheduler.isFinished()) { + try { + sourceScheduler.schedule(); + } + catch (Exception e) { + throw new RuntimeException(e); + } + } + return sourceScheduler; + })), + sourceExchanges); + } + + private FaultTolerantStageScheduler createStageScheduler( + Session session, + SqlStage stage, + NodeAllocator nodeAllocator, + int retryAttempts, + int maxTasksWaitingForNodePerStage, + TaskDescriptorStorage taskDescriptorStorage, + TaskSourceFactory taskSourceFactory, + DynamicFilterService dynamicFilterService, + Exchange sinkExchange, + Map sourceSchedulers, + Map sourceExchanges) + { + FaultTolerantPartitioningScheme partitioningScheme = new FaultTolerantPartitioningScheme(3, Optional.empty(), Optional.empty(), Optional.empty()); return new FaultTolerantStageScheduler( session, - createSqlStage(remoteTaskFactory), + stage, new NoOpFailureDetector(), taskSourceFactory, nodeAllocator, @@ -932,18 +999,18 @@ private FaultTolerantStageScheduler createFaultTolerantTaskScheduler( futureCompletor, ticker, sinkExchange, - new FaultTolerantPartitioningScheme(3, Optional.empty(), Optional.empty(), Optional.empty()), + partitioningScheme, + sourceSchedulers, sourceExchanges, - new FaultTolerantPartitioningScheme(3, Optional.empty(), Optional.empty(), Optional.empty()), + partitioningScheme, new AtomicInteger(retryAttempts), retryAttempts, maxTasksWaitingForNodePerStage, - new DynamicFilterService(PLANNER_CONTEXT.getMetadata(), PLANNER_CONTEXT.getFunctionManager(), PLANNER_CONTEXT.getTypeOperators(), new DynamicFilterConfig())); + dynamicFilterService); } - private SqlStage createSqlStage(RemoteTaskFactory remoteTaskFactory) + private SqlStage createSqlStage(PlanFragment fragment, RemoteTaskFactory remoteTaskFactory) { - PlanFragment fragment = createPlanFragment(); return SqlStage.createSqlStage( STAGE_ID, fragment, @@ -956,7 +1023,7 @@ private SqlStage createSqlStage(RemoteTaskFactory remoteTaskFactory) new SplitSchedulerStats()); } - private PlanFragment createPlanFragment() + private PlanFragment createIntermediatePlanFragment() { Symbol probeColumnSymbol = new Symbol("probe_column"); Symbol buildColumnSymbol = new Symbol("build_column"); @@ -1003,6 +1070,29 @@ private PlanFragment createPlanFragment() Optional.empty()); } + private PlanFragment createLeafPlanFragment(PlanFragmentId fragmentId) + { + Symbol outputColumn = new Symbol("output_column"); + return new PlanFragment( + fragmentId, + new TableScanNode( + TABLE_SCAN_NODE_ID, + TEST_TABLE_HANDLE, + ImmutableList.of(outputColumn), + ImmutableMap.of(outputColumn, new TestingColumnHandle("column")), + TupleDomain.none(), + Optional.empty(), + false, + Optional.empty()), + ImmutableMap.of(outputColumn, VARCHAR), + SOURCE_DISTRIBUTION, + ImmutableList.of(TABLE_SCAN_NODE_ID), + new PartitioningScheme(Partitioning.create(SINGLE_DISTRIBUTION, ImmutableList.of()), ImmutableList.of(outputColumn)), + StatsAndCosts.empty(), + ImmutableList.of(), + Optional.empty()); + } + private static TestingTaskSourceFactory createTaskSourceFactory(int splitCount, int taskPerBatch) { return new TestingTaskSourceFactory(Optional.of(TEST_CATALOG_HANDLE), createSplits(splitCount), taskPerBatch); diff --git a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestTaskDescriptorStorage.java b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestTaskDescriptorStorage.java index ead694f8c4c1..3e7b39bdda07 100644 --- a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestTaskDescriptorStorage.java +++ b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestTaskDescriptorStorage.java @@ -200,7 +200,7 @@ private static TaskDescriptor createTaskDescriptor(int partitionId, DataSize ret partitionId, ImmutableListMultimap.of( new PlanNodeId("1"), - new Split(REMOTE_CATALOG_HANDLE, new RemoteSplit(new SpoolingExchangeInput(ImmutableList.of(new TestingExchangeSourceHandle(retainedSize.toBytes())))))), + new Split(REMOTE_CATALOG_HANDLE, new RemoteSplit(new SpoolingExchangeInput(ImmutableList.of(new TestingExchangeSourceHandle(retainedSize.toBytes())), Optional.empty())))), new NodeRequirements(catalog, ImmutableSet.of())); } diff --git a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestingExchange.java b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestingExchange.java index 2a274fa32d7a..6c260ef06d36 100644 --- a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestingExchange.java +++ b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestingExchange.java @@ -16,6 +16,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; import io.trino.spi.exchange.Exchange; +import io.trino.spi.exchange.ExchangeId; import io.trino.spi.exchange.ExchangeSinkHandle; import io.trino.spi.exchange.ExchangeSinkInstanceHandle; import io.trino.spi.exchange.ExchangeSourceHandle; @@ -30,16 +31,25 @@ import static com.google.common.base.MoreObjects.toStringHelper; import static com.google.common.collect.Sets.newConcurrentHashSet; +import static io.trino.spi.exchange.ExchangeId.createRandomExchangeId; import static java.lang.Math.toIntExact; import static java.util.Objects.requireNonNull; public class TestingExchange implements Exchange { + private final ExchangeId exchangeId = createRandomExchangeId(); private final Set finishedSinks = newConcurrentHashSet(); private final Set allSinks = newConcurrentHashSet(); private final AtomicBoolean noMoreSinks = new AtomicBoolean(); private final CompletableFuture> sourceHandles = new CompletableFuture<>(); + private final AtomicBoolean allRequiredSinksFinished = new AtomicBoolean(); + + @Override + public ExchangeId getId() + { + return exchangeId; + } @Override public ExchangeSinkHandle addSink(int taskPartitionId) @@ -78,6 +88,17 @@ public void sinkFinished(ExchangeSinkHandle sinkHandle, int taskAttemptId) finishedSinks.add((TestingExchangeSinkHandle) sinkHandle); } + @Override + public void allRequiredSinksFinished() + { + allRequiredSinksFinished.set(true); + } + + public boolean isAllRequiredSinksFinished() + { + return allRequiredSinksFinished.get(); + } + public Set getFinishedSinkHandles() { return ImmutableSet.copyOf(finishedSinks); diff --git a/core/trino-spi/src/main/java/io/trino/spi/exchange/Exchange.java b/core/trino-spi/src/main/java/io/trino/spi/exchange/Exchange.java index dd1b9183af29..070990b34f78 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/exchange/Exchange.java +++ b/core/trino-spi/src/main/java/io/trino/spi/exchange/Exchange.java @@ -24,6 +24,11 @@ public interface Exchange extends Closeable { + /** + * Get id of this exchange + */ + ExchangeId getId(); + /** * Registers a new sink * @@ -75,6 +80,12 @@ public interface Exchange */ void sinkFinished(ExchangeSinkHandle sinkHandle, int taskAttemptId); + /** + * Called by the engine when all required sinks finished successfully. + * While some source tasks may still be running and writing to their sinks the data written to these sinks could be safely ignored after this method is invoked. + */ + void allRequiredSinksFinished(); + /** * Returns an {@link ExchangeSourceHandleSource} instance to be used to enumerate {@link ExchangeSourceHandle}s. * diff --git a/core/trino-spi/src/main/java/io/trino/spi/exchange/ExchangeId.java b/core/trino-spi/src/main/java/io/trino/spi/exchange/ExchangeId.java index 95d977cae916..52d7eb7f8f33 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/exchange/ExchangeId.java +++ b/core/trino-spi/src/main/java/io/trino/spi/exchange/ExchangeId.java @@ -16,16 +16,20 @@ import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonValue; import io.trino.spi.Experimental; +import org.openjdk.jol.info.ClassLayout; import java.util.Objects; import java.util.regex.Pattern; +import static io.airlift.slice.SizeOf.estimatedSizeOf; import static java.util.Objects.requireNonNull; import static java.util.UUID.randomUUID; @Experimental(eta = "2023-01-01") public class ExchangeId { + private static final long INSTANCE_SIZE = ClassLayout.parseClass(ExchangeId.class).instanceSize(); + private static final Pattern ID_PATTERN = Pattern.compile("[a-zA-Z0-9_-]+"); private final String id; @@ -75,4 +79,9 @@ public String toString() { return id; } + + public long getRetainedSizeInBytes() + { + return INSTANCE_SIZE + estimatedSizeOf(id); + } } diff --git a/core/trino-spi/src/main/java/io/trino/spi/exchange/ExchangeSource.java b/core/trino-spi/src/main/java/io/trino/spi/exchange/ExchangeSource.java index bdbe06089c8c..6e2db4e7fff9 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/exchange/ExchangeSource.java +++ b/core/trino-spi/src/main/java/io/trino/spi/exchange/ExchangeSource.java @@ -48,6 +48,22 @@ public interface ExchangeSource */ void noMoreSourceHandles(); + /** + * Called by the engine to provide information about what source task output must be included + * and what must be skipped. + *

+ * This method can be called multiple times and out of order. + * Only a newest version (see {@link ExchangeSourceOutputSelector#getVersion()}) must be taken into account. + * Updates with an older version must be ignored. + *

+ * The information provided by the {@link ExchangeSourceOutputSelector} is incremental and decisions + * for some partitions could be missing. The implementation is free to speculate. + *

+ * The final selector is guaranteed to contain a decision for each source partition (see {@link ExchangeSourceOutputSelector#isFinal()}). + * If decision is made for a given partition in some version the decision is guaranteed not to change in newer versions. + */ + void setOutputSelector(ExchangeSourceOutputSelector selector); + /** * Returns a future that will be completed when the exchange source becomes * unblocked. If the exchange source is not blocked, this method should return diff --git a/core/trino-spi/src/main/java/io/trino/spi/exchange/ExchangeSourceOutputSelector.java b/core/trino-spi/src/main/java/io/trino/spi/exchange/ExchangeSourceOutputSelector.java new file mode 100644 index 000000000000..f8e4837beec1 --- /dev/null +++ b/core/trino-spi/src/main/java/io/trino/spi/exchange/ExchangeSourceOutputSelector.java @@ -0,0 +1,344 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.spi.exchange; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import io.airlift.slice.SizeOf; +import io.airlift.slice.Slice; +import io.airlift.slice.Slices; +import org.openjdk.jol.info.ClassLayout; + +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; +import java.util.function.Function; + +import static io.trino.spi.exchange.ExchangeSourceOutputSelector.Selection.EXCLUDED; +import static io.trino.spi.exchange.ExchangeSourceOutputSelector.Selection.INCLUDED; +import static io.trino.spi.exchange.ExchangeSourceOutputSelector.Selection.UNKNOWN; +import static java.lang.Math.max; +import static java.util.Arrays.fill; +import static java.util.Objects.requireNonNull; +import static java.util.stream.Collectors.toMap; +import static java.util.stream.Collectors.toUnmodifiableMap; + +public class ExchangeSourceOutputSelector +{ + private static final long INSTANCE_SIZE = ClassLayout.parseClass(ExchangeSourceOutputSelector.class).instanceSize(); + + private final int version; + private final Map values; + private final boolean finalSelector; + + // visible for Jackson + @JsonCreator + public ExchangeSourceOutputSelector( + @JsonProperty("version") int version, + @JsonProperty("values") Map values, + @JsonProperty("finalSelector") boolean finalSelector) + { + this.version = version; + this.values = Map.copyOf(requireNonNull(values, "values is null")); + this.finalSelector = finalSelector; + } + + @JsonProperty + public int getVersion() + { + return version; + } + + // visible for Jackson + @JsonProperty + public Map getValues() + { + return values; + } + + @JsonProperty("finalSelector") + public boolean isFinal() + { + return finalSelector; + } + + public Selection getSelection(ExchangeId exchangeId, int taskPartitionId, int attemptId) + { + requireNonNull(exchangeId, "exchangeId is null"); + if (taskPartitionId < 0) { + throw new IllegalArgumentException("unexpected taskPartitionId: " + taskPartitionId); + } + if (attemptId < 0 || attemptId > Byte.MAX_VALUE) { + throw new IllegalArgumentException("unexpected attemptId: " + attemptId); + } + Slice exchangeValues = values.get(exchangeId); + if (exchangeValues == null) { + throwIfFinal(exchangeId, taskPartitionId); + return UNKNOWN; + } + if (exchangeValues.length() <= taskPartitionId) { + throwIfFinal(exchangeId, taskPartitionId); + return UNKNOWN; + } + byte selectedAttempt = exchangeValues.getByte(taskPartitionId); + if (selectedAttempt == UNKNOWN.getValue()) { + throwIfFinal(exchangeId, taskPartitionId); + return UNKNOWN; + } + if (selectedAttempt == EXCLUDED.getValue()) { + return EXCLUDED; + } + if (selectedAttempt < 0) { + throw new IllegalArgumentException("unexpected selectedAttempt: " + selectedAttempt); + } + return selectedAttempt == attemptId ? INCLUDED : EXCLUDED; + } + + public long getRetainedSizeInBytes() + { + return INSTANCE_SIZE + + SizeOf.estimatedSizeOf(values, ExchangeId::getRetainedSizeInBytes, Slice::getRetainedSize); + } + + public void checkValidTransition(ExchangeSourceOutputSelector newSelector) + { + if (this.version >= newSelector.version) { + throw new IllegalArgumentException("Invalid transition to the same or an older version"); + } + + if (this.isFinal()) { + throw new IllegalArgumentException("Invalid transition from final selector"); + } + + Set exchangeIds = new HashSet<>(); + exchangeIds.addAll(this.values.keySet()); + exchangeIds.addAll(newSelector.values.keySet()); + + for (ExchangeId exchangeId : exchangeIds) { + int taskPartitionCount = max(this.getPartitionCount(exchangeId), newSelector.getPartitionCount(exchangeId)); + for (int taskPartitionId = 0; taskPartitionId < taskPartitionCount; taskPartitionId++) { + byte currentValue = this.getValue(exchangeId, taskPartitionId); + byte newValue = newSelector.getValue(exchangeId, taskPartitionId); + if (currentValue == UNKNOWN.getValue()) { + // transition from UNKNOWN is always valid + continue; + } + if (currentValue != newValue) { + throw new IllegalArgumentException("Invalid transition for exchange %s, taskPartitionId %s: %s -> %s".formatted(exchangeId, taskPartitionId, currentValue, newValue)); + } + } + } + } + + private int getPartitionCount(ExchangeId exchangeId) + { + Slice values = this.values.get(exchangeId); + if (values == null) { + return 0; + } + return values.length(); + } + + private byte getValue(ExchangeId exchangeId, int taskPartitionId) + { + Slice exchangeValues = values.get(exchangeId); + if (exchangeValues == null) { + return UNKNOWN.getValue(); + } + if (exchangeValues.length() <= taskPartitionId) { + return UNKNOWN.getValue(); + } + return exchangeValues.getByte(taskPartitionId); + } + + private void throwIfFinal(ExchangeId exchangeId, int taskPartitionId) + { + if (isFinal()) { + throw new IllegalArgumentException("selection not found for exchangeId %s, taskPartitionId %s".formatted(exchangeId, taskPartitionId)); + } + } + + public enum Selection + { + INCLUDED((byte) -1), + EXCLUDED((byte) -2), + UNKNOWN((byte) -3); + + private final byte value; + + Selection(byte value) + { + this.value = value; + } + + public byte getValue() + { + return value; + } + } + + public static Builder builder(Set sourceExchanges) + { + return new Builder(sourceExchanges); + } + + public static class Builder + { + private int nextVersion; + private final Map exchangeValues; + private boolean finalSelector; + private final Map exchangeTaskPartitionCount = new HashMap<>(); + + public Builder(Set sourceExchanges) + { + requireNonNull(sourceExchanges, "sourceExchanges is null"); + exchangeValues = sourceExchanges.stream() + .collect(toUnmodifiableMap(Function.identity(), exchangeId -> new ValuesBuilder())); + } + + public Builder include(ExchangeId exchangeId, int taskPartitionId, int attemptId) + { + getValuesBuilderForExchange(exchangeId).include(taskPartitionId, attemptId); + return this; + } + + public Builder exclude(ExchangeId exchangeId, int taskPartitionId) + { + getValuesBuilderForExchange(exchangeId).exclude(taskPartitionId); + return this; + } + + private ValuesBuilder getValuesBuilderForExchange(ExchangeId exchangeId) + { + ValuesBuilder result = exchangeValues.get(exchangeId); + if (result == null) { + throw new IllegalArgumentException("Unexpected exchange: " + exchangeId); + } + return result; + } + + public Builder setPartitionCount(ExchangeId exchangeId, int count) + { + Integer previousCount = exchangeTaskPartitionCount.putIfAbsent(exchangeId, count); + if (previousCount != null) { + throw new IllegalStateException("Partition count for exchange is already set: " + exchangeId); + } + return this; + } + + public Builder setFinal() + { + if (finalSelector) { + throw new IllegalStateException("selector is already marked as final"); + } + for (ExchangeId exchangeId : exchangeValues.keySet()) { + if (!exchangeTaskPartitionCount.containsKey(exchangeId)) { + throw new IllegalStateException("partition count is missing for exchange: " + exchangeId); + } + } + this.finalSelector = true; + return this; + } + + public ExchangeSourceOutputSelector build() + { + return new ExchangeSourceOutputSelector( + nextVersion++, + exchangeValues.entrySet().stream() + .collect(toMap(Map.Entry::getKey, entry -> { + ExchangeId exchangeId = entry.getKey(); + ValuesBuilder valuesBuilder = entry.getValue(); + if (finalSelector) { + return valuesBuilder.buildFinal(exchangeTaskPartitionCount.get(exchangeId)); + } + else { + return valuesBuilder.build(); + } + })), + finalSelector); + } + } + + private static class ValuesBuilder + { + private Slice values = Slices.allocate(0); + private int maxTaskPartitionId = -1; + + public void include(int taskPartitionId, int attemptId) + { + updateMaxTaskPartitionIdAndEnsureCapacity(taskPartitionId); + if (attemptId < 0 || attemptId > Byte.MAX_VALUE) { + throw new IllegalArgumentException("unexpected attemptId: " + attemptId); + } + byte currentValue = values.getByte(taskPartitionId); + if (currentValue != UNKNOWN.getValue()) { + throw new IllegalArgumentException("decision for partition %s is already made: %s".formatted(taskPartitionId, currentValue)); + } + values.setByte(taskPartitionId, (byte) attemptId); + } + + public void exclude(int taskPartitionId) + { + updateMaxTaskPartitionIdAndEnsureCapacity(taskPartitionId); + byte currentValue = values.getByte(taskPartitionId); + if (currentValue != UNKNOWN.getValue()) { + throw new IllegalArgumentException("decision for partition %s is already made: %s".formatted(taskPartitionId, currentValue)); + } + values.setByte(taskPartitionId, EXCLUDED.getValue()); + } + + private void updateMaxTaskPartitionIdAndEnsureCapacity(int taskPartitionId) + { + if (taskPartitionId > maxTaskPartitionId) { + maxTaskPartitionId = taskPartitionId; + } + if (taskPartitionId < values.length()) { + return; + } + byte[] newValues = new byte[(maxTaskPartitionId + 1) * 2]; + fill(newValues, UNKNOWN.getValue()); + values.getBytes(0, newValues, 0, values.length()); + values = Slices.wrappedBuffer(newValues); + } + + public Slice build() + { + return createResult(maxTaskPartitionId + 1); + } + + public Slice buildFinal(int totalPartitionCount) + { + Slice result = createResult(totalPartitionCount); + for (int partitionId = 0; partitionId < totalPartitionCount; partitionId++) { + byte selectedAttempt = result.getByte(partitionId); + if (selectedAttempt == UNKNOWN.getValue()) { + throw new IllegalStateException("Attempt is unknown for partition: " + partitionId); + } + } + return result; + } + + private Slice createResult(int partitionCount) + { + if (maxTaskPartitionId >= partitionCount) { + throw new IllegalArgumentException("expected maxTaskPartitionId to be less than or equal to " + (partitionCount - 1)); + } + byte[] result = new byte[partitionCount]; + fill(result, UNKNOWN.getValue()); + values.getBytes(0, result, 0, maxTaskPartitionId + 1); + return Slices.wrappedBuffer(result); + } + } +} diff --git a/plugin/trino-exchange-filesystem/src/main/java/io/trino/plugin/exchange/filesystem/ExchangeSourceFile.java b/plugin/trino-exchange-filesystem/src/main/java/io/trino/plugin/exchange/filesystem/ExchangeSourceFile.java index c8f049d94060..f51545c14962 100644 --- a/plugin/trino-exchange-filesystem/src/main/java/io/trino/plugin/exchange/filesystem/ExchangeSourceFile.java +++ b/plugin/trino-exchange-filesystem/src/main/java/io/trino/plugin/exchange/filesystem/ExchangeSourceFile.java @@ -13,6 +13,8 @@ */ package io.trino.plugin.exchange.filesystem; +import io.trino.spi.exchange.ExchangeId; + import javax.annotation.concurrent.Immutable; import javax.crypto.SecretKey; @@ -27,12 +29,18 @@ public class ExchangeSourceFile private final URI fileUri; private final Optional secretKey; private final long fileSize; + private final ExchangeId exchangeId; + private final int sourceTaskPartitionId; + private final int sourceTaskAttemptId; - public ExchangeSourceFile(URI fileUri, Optional secretKey, long fileSize) + public ExchangeSourceFile(URI fileUri, Optional secretKey, long fileSize, ExchangeId exchangeId, int sourceTaskPartitionId, int sourceTaskAttemptId) { this.fileUri = requireNonNull(fileUri, "fileUri is null"); this.secretKey = requireNonNull(secretKey, "secretKey is null"); this.fileSize = fileSize; + this.exchangeId = requireNonNull(exchangeId, "exchangeId is null"); + this.sourceTaskPartitionId = sourceTaskPartitionId; + this.sourceTaskAttemptId = sourceTaskAttemptId; } public URI getFileUri() @@ -49,4 +57,19 @@ public long getFileSize() { return fileSize; } + + public ExchangeId getExchangeId() + { + return exchangeId; + } + + public int getSourceTaskPartitionId() + { + return sourceTaskPartitionId; + } + + public int getSourceTaskAttemptId() + { + return sourceTaskAttemptId; + } } diff --git a/plugin/trino-exchange-filesystem/src/main/java/io/trino/plugin/exchange/filesystem/FileSystemExchange.java b/plugin/trino-exchange-filesystem/src/main/java/io/trino/plugin/exchange/filesystem/FileSystemExchange.java index 41688234ca27..316efdc9c927 100644 --- a/plugin/trino-exchange-filesystem/src/main/java/io/trino/plugin/exchange/filesystem/FileSystemExchange.java +++ b/plugin/trino-exchange-filesystem/src/main/java/io/trino/plugin/exchange/filesystem/FileSystemExchange.java @@ -20,8 +20,10 @@ import com.google.common.util.concurrent.FutureCallback; import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.ListenableFuture; +import io.trino.plugin.exchange.filesystem.FileSystemExchangeSourceHandle.SourceFile; import io.trino.spi.exchange.Exchange; import io.trino.spi.exchange.ExchangeContext; +import io.trino.spi.exchange.ExchangeId; import io.trino.spi.exchange.ExchangeSinkHandle; import io.trino.spi.exchange.ExchangeSinkInstanceHandle; import io.trino.spi.exchange.ExchangeSourceHandle; @@ -38,6 +40,7 @@ import java.util.ArrayList; import java.util.Collection; import java.util.Collections; +import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; @@ -50,6 +53,7 @@ import java.util.regex.Matcher; import java.util.regex.Pattern; +import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; import static com.google.common.base.Verify.verify; import static com.google.common.collect.ImmutableList.toImmutableList; @@ -58,6 +62,7 @@ import static io.trino.plugin.exchange.filesystem.FileSystemExchangeManager.PATH_SEPARATOR; import static io.trino.plugin.exchange.filesystem.FileSystemExchangeSink.COMMITTED_MARKER_FILE_NAME; import static io.trino.plugin.exchange.filesystem.FileSystemExchangeSink.DATA_FILE_SUFFIX; +import static java.lang.Integer.parseInt; import static java.lang.String.format; import static java.util.Objects.requireNonNull; @@ -84,7 +89,7 @@ public class FileSystemExchange @GuardedBy("this") private final Set allSinks = new HashSet<>(); @GuardedBy("this") - private final Set finishedSinks = new HashSet<>(); + private final Map finishedSinks = new HashMap<>(); @GuardedBy("this") private boolean noMoreSinks; @GuardedBy("this") @@ -120,6 +125,12 @@ public FileSystemExchange( this.executor = requireNonNull(executor, "executor is null"); } + @Override + public ExchangeId getId() + { + return exchangeContext.getExchangeId(); + } + @Override public synchronized ExchangeSinkHandle addSink(int taskPartition) { @@ -134,7 +145,6 @@ public void noMoreSinks() synchronized (this) { noMoreSinks = true; } - checkInputReady(); } @Override @@ -165,29 +175,29 @@ public void sinkFinished(ExchangeSinkHandle handle, int taskAttemptId) { synchronized (this) { FileSystemExchangeSinkHandle sinkHandle = (FileSystemExchangeSinkHandle) handle; - finishedSinks.add(sinkHandle.getPartitionId()); + finishedSinks.putIfAbsent(sinkHandle.getPartitionId(), taskAttemptId); } - checkInputReady(); } - private void checkInputReady() + @Override + public void allRequiredSinksFinished() { verify(!Thread.holdsLock(this)); - ListenableFuture> exchangeSourceHandlesCreationFuture = null; + ListenableFuture> exchangeSourceHandlesCreationFuture; synchronized (this) { if (exchangeSourceHandlesCreationStarted) { return; } - if (noMoreSinks && finishedSinks.containsAll(allSinks)) { - // input is ready, create exchange source handles - exchangeSourceHandlesCreationStarted = true; - exchangeSourceHandlesCreationFuture = stats.getCreateExchangeSourceHandles().record(this::createExchangeSourceHandles); - exchangeSourceHandlesFuture.whenComplete((value, failure) -> { - if (exchangeSourceHandlesFuture.isCancelled()) { - exchangeSourceHandlesFuture.cancel(true); - } - }); - } + verify(noMoreSinks, "noMoreSinks is expected to be set"); + verify(finishedSinks.keySet().containsAll(allSinks), "all sinks are expected to be finished"); + // input is ready, create exchange source handles + exchangeSourceHandlesCreationStarted = true; + exchangeSourceHandlesCreationFuture = stats.getCreateExchangeSourceHandles().record(this::createExchangeSourceHandles); + exchangeSourceHandlesFuture.whenComplete((value, failure) -> { + if (exchangeSourceHandlesFuture.isCancelled()) { + exchangeSourceHandlesFuture.cancel(true); + } + }); } if (exchangeSourceHandlesCreationFuture != null) { Futures.addCallback(exchangeSourceHandlesCreationFuture, new FutureCallback<>() { @@ -208,24 +218,26 @@ public void onFailure(Throwable throwable) private ListenableFuture> createExchangeSourceHandles() { - List finishedTaskPartitions; + List committedTaskAttempts; synchronized (this) { - finishedTaskPartitions = ImmutableList.copyOf(finishedSinks); + committedTaskAttempts = finishedSinks.entrySet().stream() + .map(entry -> new CommittedTaskAttempt(entry.getKey(), entry.getValue())) + .collect(toImmutableList()); } return Futures.transform( - processAll(finishedTaskPartitions, this::getCommittedPartitions, fileListingParallelism, executor), + processAll(committedTaskAttempts, this::getCommittedPartitions, fileListingParallelism, executor), partitionsList -> { - Multimap partitionFiles = ArrayListMultimap.create(); - partitionsList.forEach(partitions -> partitions.forEach(partitionFiles::put)); + Multimap sourceFiles = ArrayListMultimap.create(); + partitionsList.forEach(partitions -> partitions.forEach(sourceFiles::put)); ImmutableList.Builder result = ImmutableList.builder(); - for (Integer partitionId : partitionFiles.keySet()) { - Collection files = partitionFiles.get(partitionId); + for (Integer partitionId : sourceFiles.keySet()) { + Collection files = sourceFiles.get(partitionId); long currentExchangeHandleDataSizeInBytes = 0; - ImmutableList.Builder currentExchangeHandleFiles = ImmutableList.builder(); - for (FileStatus file : files) { + ImmutableList.Builder currentExchangeHandleFiles = ImmutableList.builder(); + for (SourceFile file : files) { if (currentExchangeHandleDataSizeInBytes > 0 && currentExchangeHandleDataSizeInBytes + file.getFileSize() > exchangeSourceHandleTargetDataSizeInBytes) { - result.add(new FileSystemExchangeSourceHandle(partitionId, currentExchangeHandleFiles.build(), secretKey.map(SecretKey::getEncoded))); + result.add(new FileSystemExchangeSourceHandle(exchangeContext.getExchangeId(), partitionId, currentExchangeHandleFiles.build(), secretKey.map(SecretKey::getEncoded))); currentExchangeHandleDataSizeInBytes = 0; currentExchangeHandleFiles = ImmutableList.builder(); } @@ -233,7 +245,7 @@ private ListenableFuture> createExchangeSourceHandles currentExchangeHandleFiles.add(file); } if (currentExchangeHandleDataSizeInBytes > 0) { - result.add(new FileSystemExchangeSourceHandle(partitionId, currentExchangeHandleFiles.build(), secretKey.map(SecretKey::getEncoded))); + result.add(new FileSystemExchangeSourceHandle(exchangeContext.getExchangeId(), partitionId, currentExchangeHandleFiles.build(), secretKey.map(SecretKey::getEncoded))); } } return result.build(); @@ -241,37 +253,49 @@ private ListenableFuture> createExchangeSourceHandles executor); } - private ListenableFuture> getCommittedPartitions(int taskPartitionId) + private ListenableFuture> getCommittedPartitions(CommittedTaskAttempt committedTaskAttempt) { - URI sinkOutputPath = getTaskOutputDirectory(taskPartitionId); + URI sinkOutputPath = getTaskOutputDirectory(committedTaskAttempt.partitionId()); return stats.getGetCommittedPartitions().record(Futures.transform( exchangeStorage.listFilesRecursively(sinkOutputPath), sinkOutputFiles -> { - String committedMarkerFilePath = sinkOutputFiles.stream() + List committedMarkerFilePaths = sinkOutputFiles.stream() .map(FileStatus::getFilePath) .filter(filePath -> filePath.endsWith(COMMITTED_MARKER_FILE_NAME)) - .findFirst() - .orElseThrow(() -> new IllegalStateException(format("No committed attempts found under sink output path %s", sinkOutputPath))); - // Committed marker file path format: {sinkOutputPath}/{attemptId}/committed - String[] parts = committedMarkerFilePath.split(PATH_SEPARATOR); - checkState(parts.length >= 3, "committedMarkerFilePath %s is malformed", committedMarkerFilePath); - String committedAttemptId = parts[parts.length - 2]; - int attemptIdOffset = committedMarkerFilePath.length() - committedAttemptId.length() - - PATH_SEPARATOR.length() - COMMITTED_MARKER_FILE_NAME.length(); - - // Data output file path format: {sinkOutputPath}/{attemptId}/{sourcePartitionId}_{splitId}.data - List partitionFiles = sinkOutputFiles.stream() - .filter(file -> file.getFilePath().startsWith(committedAttemptId + PATH_SEPARATOR, attemptIdOffset) && file.getFilePath().endsWith(DATA_FILE_SUFFIX)) .collect(toImmutableList()); - ImmutableMultimap.Builder result = ImmutableMultimap.builder(); - for (FileStatus partitionFile : partitionFiles) { - Matcher matcher = PARTITION_FILE_NAME_PATTERN.matcher(new File(partitionFile.getFilePath()).getName()); - checkState(matcher.matches(), "Unexpected partition file: %s", partitionFile); - int partitionId = Integer.parseInt(matcher.group(1)); - result.put(partitionId, partitionFile); + if (committedMarkerFilePaths.isEmpty()) { + throw new IllegalStateException(format("No committed attempts found under sink output path %s", sinkOutputPath)); } - return result.build(); + + for (String committedMarkerFilePath : committedMarkerFilePaths) { + // Committed marker file path format: {sinkOutputPath}/{attemptId}/committed + String[] parts = committedMarkerFilePath.split(PATH_SEPARATOR); + checkState(parts.length >= 3, "committedMarkerFilePath %s is malformed", committedMarkerFilePath); + String stringCommittedAttemptId = parts[parts.length - 2]; + if (parseInt(stringCommittedAttemptId) != committedTaskAttempt.attemptId()) { + // skip other successful attempts + continue; + } + int attemptIdOffset = committedMarkerFilePath.length() - stringCommittedAttemptId.length() + - PATH_SEPARATOR.length() - COMMITTED_MARKER_FILE_NAME.length(); + + // Data output file path format: {sinkOutputPath}/{attemptId}/{sourcePartitionId}_{splitId}.data + List partitionFiles = sinkOutputFiles.stream() + .filter(file -> file.getFilePath().startsWith(stringCommittedAttemptId + PATH_SEPARATOR, attemptIdOffset) && file.getFilePath().endsWith(DATA_FILE_SUFFIX)) + .collect(toImmutableList()); + + ImmutableMultimap.Builder result = ImmutableMultimap.builder(); + for (FileStatus partitionFile : partitionFiles) { + Matcher matcher = PARTITION_FILE_NAME_PATTERN.matcher(new File(partitionFile.getFilePath()).getName()); + checkState(matcher.matches(), "Unexpected partition file: %s", partitionFile); + int partitionId = parseInt(matcher.group(1)); + result.put(partitionId, new SourceFile(partitionFile.getFilePath(), partitionFile.getFileSize(), committedTaskAttempt.partitionId(), committedTaskAttempt.attemptId())); + } + return result.build(); + } + + throw new IllegalArgumentException("committed attempt %s for task %s not found".formatted(committedTaskAttempt.attemptId(), committedTaskAttempt.partitionId())); }, executor)); } @@ -317,4 +341,13 @@ private static String generateRandomizedHexPrefix() } return new String(value); } + + private record CommittedTaskAttempt(int partitionId, int attemptId) + { + public CommittedTaskAttempt + { + checkArgument(partitionId >= 0, "partitionId is expected to be greater than or equal to zero: %s", partitionId); + checkArgument(attemptId >= 0, "attemptId is expected to be greater than or equal to zero: %s", attemptId); + } + } } diff --git a/plugin/trino-exchange-filesystem/src/main/java/io/trino/plugin/exchange/filesystem/FileSystemExchangeSource.java b/plugin/trino-exchange-filesystem/src/main/java/io/trino/plugin/exchange/filesystem/FileSystemExchangeSource.java index b9622c70e289..1aaf302c89b8 100644 --- a/plugin/trino-exchange-filesystem/src/main/java/io/trino/plugin/exchange/filesystem/FileSystemExchangeSource.java +++ b/plugin/trino-exchange-filesystem/src/main/java/io/trino/plugin/exchange/filesystem/FileSystemExchangeSource.java @@ -22,6 +22,7 @@ import io.airlift.slice.Slice; import io.trino.spi.exchange.ExchangeSource; import io.trino.spi.exchange.ExchangeSourceHandle; +import io.trino.spi.exchange.ExchangeSourceOutputSelector; import javax.annotation.Nullable; import javax.annotation.concurrent.GuardedBy; @@ -41,10 +42,12 @@ import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReference; +import static com.google.common.base.Verify.verify; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.util.concurrent.Futures.immediateVoidFuture; import static com.google.common.util.concurrent.MoreExecutors.directExecutor; import static io.airlift.concurrent.MoreFutures.whenAnyComplete; +import static io.trino.spi.exchange.ExchangeSourceOutputSelector.Selection.INCLUDED; import static java.util.Objects.requireNonNull; public class FileSystemExchangeSource @@ -59,7 +62,9 @@ public class FileSystemExchangeSource @GuardedBy("this") private boolean noMoreFiles; @GuardedBy("this") - private SettableFuture blockedOnSourceHandles = SettableFuture.create(); + private ExchangeSourceOutputSelector currentSelector; + @GuardedBy("this") + private SettableFuture blockedOnFiles = SettableFuture.create(); private final AtomicReference> readers = new AtomicReference<>(ImmutableList.of()); private final AtomicReference> blocked = new AtomicReference<>(); @@ -94,6 +99,19 @@ public synchronized void noMoreSourceHandles() closeAndCreateReadersIfNecessary(); } + @Override + public synchronized void setOutputSelector(ExchangeSourceOutputSelector newSelector) + { + if (currentSelector != null) { + if (currentSelector.getVersion() >= newSelector.getVersion()) { + return; + } + currentSelector.checkValidTransition(newSelector); + } + currentSelector = newSelector; + closeAndCreateReadersIfNecessary(); + } + @Override public CompletableFuture isBlocked() { @@ -116,8 +134,8 @@ public CompletableFuture isBlocked() } synchronized (this) { - if (!blockedOnSourceHandles.isDone()) { - blocked = blockedOnSourceHandles; + if (!blockedOnFiles.isDone()) { + blocked = blockedOnFiles; } else if (readers.isEmpty()) { // not blocked @@ -232,12 +250,16 @@ private void closeAndCreateReadersIfNecessary() return; } - SettableFuture blockedOnSourceHandlesToBeUnblocked = null; + SettableFuture blockedOnFilesToBeUnblocked = null; synchronized (this) { if (closed.get()) { return; } + if (currentSelector == null || !currentSelector.isFinal()) { + return; + } + List activeReaders = new ArrayList<>(); for (ExchangeStorageReader reader : readers.get()) { if (reader.isFinished()) { @@ -253,6 +275,11 @@ private void closeAndCreateReadersIfNecessary() long readerFileSize = 0; while (!files.isEmpty()) { ExchangeSourceFile file = files.peek(); + verify(currentSelector.getSelection(file.getExchangeId(), file.getSourceTaskPartitionId(), file.getSourceTaskAttemptId()) == INCLUDED, + "%s.%s.%s is not marked as included by the engine", + file.getExchangeId(), + file.getSourceTaskPartitionId(), + file.getSourceTaskAttemptId()); if (readerFileSize == 0 || readerFileSize + file.getFileSize() <= maxPageStorageSize + exchangeStorage.getWriteBufferSize()) { readerFiles.add(file); readerFileSize += file.getFileSize(); @@ -266,15 +293,15 @@ private void closeAndCreateReadersIfNecessary() } if (activeReaders.isEmpty()) { if (noMoreFiles) { - blockedOnSourceHandlesToBeUnblocked = blockedOnSourceHandles; + blockedOnFilesToBeUnblocked = blockedOnFiles; close(); } - else if (blockedOnSourceHandles.isDone()) { - blockedOnSourceHandles = SettableFuture.create(); + else if (blockedOnFiles.isDone()) { + blockedOnFiles = SettableFuture.create(); } } - else if (!blockedOnSourceHandles.isDone()) { - blockedOnSourceHandlesToBeUnblocked = blockedOnSourceHandles; + else if (!blockedOnFiles.isDone()) { + blockedOnFilesToBeUnblocked = blockedOnFiles; } this.readers.set(ImmutableList.copyOf(activeReaders)); } @@ -292,8 +319,8 @@ else if (!blockedOnSourceHandles.isDone()) { throw t; } } - if (blockedOnSourceHandlesToBeUnblocked != null) { - blockedOnSourceHandlesToBeUnblocked.set(null); + if (blockedOnFilesToBeUnblocked != null) { + blockedOnFilesToBeUnblocked.set(null); } } @@ -319,11 +346,14 @@ private static List getFiles(List hand Optional secretKey = handle.getSecretKey().map(key -> new SecretKeySpec(key, 0, key.length, "AES")); return new AbstractMap.SimpleEntry<>(handle, secretKey); }) - .flatMap(entry -> entry.getKey().getFiles().stream().map(fileStatus -> + .flatMap(entry -> entry.getKey().getFiles().stream().map(sourceFile -> new ExchangeSourceFile( - URI.create(fileStatus.getFilePath()), + URI.create(sourceFile.getFilePath()), entry.getValue(), - fileStatus.getFileSize()))) + sourceFile.getFileSize(), + entry.getKey().getExchangeId(), + sourceFile.getSourceTaskPartitionId(), + sourceFile.getSourceTaskAttemptId()))) .collect(toImmutableList()); } } diff --git a/plugin/trino-exchange-filesystem/src/main/java/io/trino/plugin/exchange/filesystem/FileSystemExchangeSourceHandle.java b/plugin/trino-exchange-filesystem/src/main/java/io/trino/plugin/exchange/filesystem/FileSystemExchangeSourceHandle.java index abc021cba2da..e3bfa077a63a 100644 --- a/plugin/trino-exchange-filesystem/src/main/java/io/trino/plugin/exchange/filesystem/FileSystemExchangeSourceHandle.java +++ b/plugin/trino-exchange-filesystem/src/main/java/io/trino/plugin/exchange/filesystem/FileSystemExchangeSourceHandle.java @@ -17,10 +17,12 @@ import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ImmutableList; import io.airlift.slice.SizeOf; +import io.trino.spi.exchange.ExchangeId; import io.trino.spi.exchange.ExchangeSourceHandle; import org.openjdk.jol.info.ClassLayout; import java.util.List; +import java.util.Objects; import java.util.Optional; import static com.google.common.base.MoreObjects.toStringHelper; @@ -34,21 +36,30 @@ public class FileSystemExchangeSourceHandle { private static final int INSTANCE_SIZE = toIntExact(ClassLayout.parseClass(FileSystemExchangeSourceHandle.class).instanceSize()); + private final ExchangeId exchangeId; private final int partitionId; - private final List files; + private final List files; private final Optional secretKey; @JsonCreator public FileSystemExchangeSourceHandle( + @JsonProperty("exchangeId") ExchangeId exchangeId, @JsonProperty("partitionId") int partitionId, - @JsonProperty("files") List files, + @JsonProperty("files") List files, @JsonProperty("secretKey") Optional secretKey) { + this.exchangeId = requireNonNull(exchangeId, "exchangeId is null"); this.partitionId = partitionId; this.files = ImmutableList.copyOf(requireNonNull(files, "files is null")); this.secretKey = requireNonNull(secretKey, "secretKey is null"); } + @JsonProperty + public ExchangeId getExchangeId() + { + return exchangeId; + } + @Override @JsonProperty public int getPartitionId() @@ -60,7 +71,7 @@ public int getPartitionId() public long getDataSizeInBytes() { return files.stream() - .mapToLong(FileStatus::getFileSize) + .mapToLong(SourceFile::getFileSize) .sum(); } @@ -68,12 +79,12 @@ public long getDataSizeInBytes() public long getRetainedSizeInBytes() { return INSTANCE_SIZE - + estimatedSizeOf(files, FileStatus::getRetainedSizeInBytes) + + estimatedSizeOf(files, SourceFile::getRetainedSizeInBytes) + sizeOf(secretKey, SizeOf::sizeOf); } @JsonProperty - public List getFiles() + public List getFiles() { return files; } @@ -88,9 +99,92 @@ public Optional getSecretKey() public String toString() { return toStringHelper(this) + .add("exchangeId", exchangeId) .add("partitionId", partitionId) .add("files", files) .add("secretKey", secretKey.map(value -> "[REDACTED]")) .toString(); } + + public static class SourceFile + { + private static final int INSTANCE_SIZE = toIntExact(ClassLayout.parseClass(SourceFile.class).instanceSize()); + + private final String filePath; + private final long fileSize; + private final int sourceTaskPartitionId; + private final int sourceTaskAttemptId; + + @JsonCreator + public SourceFile( + @JsonProperty("filePath") String filePath, + @JsonProperty("fileSize") long fileSize, + @JsonProperty("sourceTaskPartitionId") int sourceTaskPartitionId, + @JsonProperty("sourceTaskAttemptId") int sourceTaskAttemptId) + { + this.filePath = requireNonNull(filePath, "filePath is null"); + this.fileSize = fileSize; + this.sourceTaskPartitionId = sourceTaskPartitionId; + this.sourceTaskAttemptId = sourceTaskAttemptId; + } + + @JsonProperty + public String getFilePath() + { + return filePath; + } + + @JsonProperty + public long getFileSize() + { + return fileSize; + } + + @JsonProperty + public int getSourceTaskPartitionId() + { + return sourceTaskPartitionId; + } + + @JsonProperty + public int getSourceTaskAttemptId() + { + return sourceTaskAttemptId; + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + SourceFile that = (SourceFile) o; + return fileSize == that.fileSize && sourceTaskPartitionId == that.sourceTaskPartitionId && sourceTaskAttemptId == that.sourceTaskAttemptId && Objects.equals(filePath, that.filePath); + } + + @Override + public int hashCode() + { + return Objects.hash(filePath, fileSize, sourceTaskPartitionId, sourceTaskAttemptId); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("filePath", filePath) + .add("fileSize", fileSize) + .add("sourceTaskPartitionId", sourceTaskPartitionId) + .add("sourceTaskAttemptId", sourceTaskAttemptId) + .toString(); + } + + public long getRetainedSizeInBytes() + { + return INSTANCE_SIZE + estimatedSizeOf(filePath); + } + } } diff --git a/plugin/trino-exchange-filesystem/src/test/java/io/trino/plugin/exchange/filesystem/AbstractTestExchangeManager.java b/plugin/trino-exchange-filesystem/src/test/java/io/trino/plugin/exchange/filesystem/AbstractTestExchangeManager.java index d7f1963fa8e0..f7246e9d1373 100644 --- a/plugin/trino-exchange-filesystem/src/test/java/io/trino/plugin/exchange/filesystem/AbstractTestExchangeManager.java +++ b/plugin/trino-exchange-filesystem/src/test/java/io/trino/plugin/exchange/filesystem/AbstractTestExchangeManager.java @@ -15,6 +15,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableListMultimap; +import com.google.common.collect.ImmutableSet; import com.google.common.collect.ListMultimap; import com.google.common.collect.Multimap; import io.airlift.slice.Slice; @@ -23,6 +24,7 @@ import io.trino.spi.QueryId; import io.trino.spi.exchange.Exchange; import io.trino.spi.exchange.ExchangeContext; +import io.trino.spi.exchange.ExchangeId; import io.trino.spi.exchange.ExchangeManager; import io.trino.spi.exchange.ExchangeSink; import io.trino.spi.exchange.ExchangeSinkHandle; @@ -30,6 +32,7 @@ import io.trino.spi.exchange.ExchangeSource; import io.trino.spi.exchange.ExchangeSourceHandle; import io.trino.spi.exchange.ExchangeSourceHandleSource.ExchangeSourceHandleBatch; +import io.trino.spi.exchange.ExchangeSourceOutputSelector; import org.testng.annotations.AfterClass; import org.testng.annotations.BeforeClass; import org.testng.annotations.Test; @@ -79,7 +82,8 @@ public void destroy() public void testHappyPath() throws Exception { - Exchange exchange = exchangeManager.createExchange(new ExchangeContext(new QueryId("query"), createRandomExchangeId()), 2, false); + ExchangeId exchangeId = createRandomExchangeId(); + Exchange exchange = exchangeManager.createExchange(new ExchangeContext(new QueryId("query"), exchangeId), 2, false); ExchangeSinkHandle sinkHandle0 = exchange.addSink(0); ExchangeSinkHandle sinkHandle1 = exchange.addSink(1); ExchangeSinkHandle sinkHandle2 = exchange.addSink(2); @@ -151,6 +155,7 @@ public void testHappyPath() 1, "2-1-0"), true); exchange.sinkFinished(sinkHandle2, 2); + exchange.allRequiredSinksFinished(); ExchangeSourceHandleBatch sourceHandleBatch = exchange.getSourceHandles().getNextBatch().get(); assertTrue(sourceHandleBatch.lastBatch()); @@ -160,10 +165,17 @@ public void testHappyPath() Map partitions = partitionHandles.stream() .collect(toImmutableMap(ExchangeSourceHandle::getPartitionId, Function.identity())); - assertThat(readData(partitions.get(0))) + ExchangeSourceOutputSelector outputSelector = ExchangeSourceOutputSelector.builder(ImmutableSet.of(exchangeId)) + .include(exchangeId, 0, 0) + .include(exchangeId, 1, 0) + .include(exchangeId, 2, 2) + .setPartitionCount(exchangeId, 3) + .setFinal() + .build(); + assertThat(readData(partitions.get(0), outputSelector)) .containsExactlyInAnyOrder("0-0-0", "0-0-1", "1-0-0", "1-0-1", "2-0-0"); - assertThat(readData(partitions.get(1))) + assertThat(readData(partitions.get(1), outputSelector)) .containsExactlyInAnyOrder("0-1-0", "0-1-1", "1-1-0", "1-1-1", "2-1-0"); exchange.close(); @@ -178,7 +190,8 @@ public void testLargePages() String largePage = "c".repeat(toIntExact(DataSize.of(5, MEGABYTE).toBytes()) - Integer.BYTES); String maxPage = "d".repeat(toIntExact(DataSize.of(16, MEGABYTE).toBytes()) - Integer.BYTES); - Exchange exchange = exchangeManager.createExchange(new ExchangeContext(new QueryId("query"), createRandomExchangeId()), 3, false); + ExchangeId exchangeId = createRandomExchangeId(); + Exchange exchange = exchangeManager.createExchange(new ExchangeContext(new QueryId("query"), exchangeId), 3, false); ExchangeSinkHandle sinkHandle0 = exchange.addSink(0); ExchangeSinkHandle sinkHandle1 = exchange.addSink(1); ExchangeSinkHandle sinkHandle2 = exchange.addSink(2); @@ -216,6 +229,7 @@ public void testLargePages() .build(), true); exchange.sinkFinished(sinkHandle2, 0); + exchange.allRequiredSinksFinished(); ExchangeSourceHandleBatch sourceHandleBatch = exchange.getSourceHandles().getNextBatch().get(); assertTrue(sourceHandleBatch.lastBatch()); @@ -225,13 +239,21 @@ public void testLargePages() ListMultimap partitions = partitionHandles.stream() .collect(toImmutableListMultimap(ExchangeSourceHandle::getPartitionId, Function.identity())); - assertThat(readData(partitions.get(0))) + ExchangeSourceOutputSelector outputSelector = ExchangeSourceOutputSelector.builder(ImmutableSet.of(exchangeId)) + .include(exchangeId, 0, 0) + .include(exchangeId, 1, 0) + .include(exchangeId, 2, 0) + .setPartitionCount(exchangeId, 3) + .setFinal() + .build(); + + assertThat(readData(partitions.get(0), outputSelector)) .containsExactlyInAnyOrder(smallPage, mediumPage, largePage, maxPage); - assertThat(readData(partitions.get(1))) + assertThat(readData(partitions.get(1), outputSelector)) .containsExactlyInAnyOrder(smallPage, mediumPage, largePage, maxPage); - assertThat(readData(partitions.get(2))) + assertThat(readData(partitions.get(2), outputSelector)) .containsExactlyInAnyOrder(smallPage, mediumPage, largePage, maxPage); exchange.close(); @@ -259,15 +281,16 @@ private void writeData(ExchangeSinkInstanceHandle handle, Multimap readData(ExchangeSourceHandle handle) + private List readData(ExchangeSourceHandle handle, ExchangeSourceOutputSelector outputSelector) { - return readData(ImmutableList.of(handle)); + return readData(ImmutableList.of(handle), outputSelector); } - private List readData(List handles) + private List readData(List handles, ExchangeSourceOutputSelector outputSelector) { ImmutableList.Builder result = ImmutableList.builder(); try (ExchangeSource source = exchangeManager.createSource()) { + source.setOutputSelector(outputSelector); Queue remainingHandles = new ArrayDeque<>(handles); while (!source.isFinished()) { Slice data = source.read();