diff --git a/core/trino-main/src/main/java/io/trino/SystemSessionProperties.java b/core/trino-main/src/main/java/io/trino/SystemSessionProperties.java index 9bdb99aa31b2..3a5fd2716621 100644 --- a/core/trino-main/src/main/java/io/trino/SystemSessionProperties.java +++ b/core/trino-main/src/main/java/io/trino/SystemSessionProperties.java @@ -174,6 +174,7 @@ public final class SystemSessionProperties public static final String JOIN_PARTITIONED_BUILD_MIN_ROW_COUNT = "join_partitioned_build_min_row_count"; public static final String USE_EXACT_PARTITIONING = "use_exact_partitioning"; public static final String FORCE_SPILLING_JOIN = "force_spilling_join"; + public static final String FAULT_TOLERANT_EXECUTION_EVENT_DRIVEN_SCHEDULER_ENABLED = "fault_tolerant_execution_event_driven_scheduler_enabled"; private final List> sessionProperties; @@ -859,7 +860,12 @@ public SystemSessionProperties( FORCE_SPILLING_JOIN, "Force the usage of spliing join operator in favor of the non-spilling one, even if spill is not enabled", featuresConfig.isForceSpillingJoin(), - false)); + false), + booleanProperty( + FAULT_TOLERANT_EXECUTION_EVENT_DRIVEN_SCHEDULER_ENABLED, + "Enable event driven scheduler for fault tolerant execution", + queryManagerConfig.isFaultTolerantExecutionEventDrivenSchedulerEnabled(), + true)); } @Override @@ -1537,4 +1543,9 @@ public static boolean isForceSpillingOperator(Session session) { return session.getSystemProperty(FORCE_SPILLING_JOIN, Boolean.class); } + + public static boolean isFaultTolerantExecutionEventDriverSchedulerEnabled(Session session) + { + return session.getSystemProperty(FAULT_TOLERANT_EXECUTION_EVENT_DRIVEN_SCHEDULER_ENABLED, Boolean.class); + } } diff --git a/core/trino-main/src/main/java/io/trino/execution/QueryManagerConfig.java b/core/trino-main/src/main/java/io/trino/execution/QueryManagerConfig.java index 035114b191d8..7a112af23656 100644 --- a/core/trino-main/src/main/java/io/trino/execution/QueryManagerConfig.java +++ 
b/core/trino-main/src/main/java/io/trino/execution/QueryManagerConfig.java @@ -95,6 +95,7 @@ public class QueryManagerConfig private DataSize faultTolerantExecutionTaskDescriptorStorageMaxMemory = DataSize.ofBytes(Math.round(AVAILABLE_HEAP_MEMORY * 0.15)); private int faultTolerantExecutionPartitionCount = 50; private boolean faultTolerantPreserveInputPartitionsInWriteStage = true; + private boolean faultTolerantExecutionEventDrivenSchedulerEnabled = true; @Min(1) public int getScheduleSplitBatchSize() @@ -628,4 +629,16 @@ public QueryManagerConfig setFaultTolerantPreserveInputPartitionsInWriteStage(bo this.faultTolerantPreserveInputPartitionsInWriteStage = faultTolerantPreserveInputPartitionsInWriteStage; return this; } + + public boolean isFaultTolerantExecutionEventDrivenSchedulerEnabled() + { + return faultTolerantExecutionEventDrivenSchedulerEnabled; + } + + @Config("experimental.fault-tolerant-execution-event-driven-scheduler-enabled") + public QueryManagerConfig setFaultTolerantExecutionEventDrivenSchedulerEnabled(boolean faultTolerantExecutionEventDrivenSchedulerEnabled) + { + this.faultTolerantExecutionEventDrivenSchedulerEnabled = faultTolerantExecutionEventDrivenSchedulerEnabled; + return this; + } } diff --git a/core/trino-main/src/main/java/io/trino/execution/SqlQueryExecution.java b/core/trino-main/src/main/java/io/trino/execution/SqlQueryExecution.java index c2d18595ee06..9b6c0edcfdcc 100644 --- a/core/trino-main/src/main/java/io/trino/execution/SqlQueryExecution.java +++ b/core/trino-main/src/main/java/io/trino/execution/SqlQueryExecution.java @@ -24,6 +24,8 @@ import io.trino.exchange.ExchangeManagerRegistry; import io.trino.execution.QueryPreparer.PreparedQuery; import io.trino.execution.StateMachine.StateChangeListener; +import io.trino.execution.scheduler.EventDrivenFaultTolerantQueryScheduler; +import io.trino.execution.scheduler.EventDrivenTaskSourceFactory; import io.trino.execution.scheduler.FaultTolerantQueryScheduler; import 
io.trino.execution.scheduler.NodeAllocatorService; import io.trino.execution.scheduler.NodeScheduler; @@ -89,6 +91,7 @@ import static io.trino.SystemSessionProperties.getTaskRetryAttemptsOverall; import static io.trino.SystemSessionProperties.getTaskRetryAttemptsPerTask; import static io.trino.SystemSessionProperties.isEnableDynamicFiltering; +import static io.trino.SystemSessionProperties.isFaultTolerantExecutionEventDriverSchedulerEnabled; import static io.trino.execution.QueryState.FAILED; import static io.trino.execution.QueryState.PLANNING; import static io.trino.server.DynamicFilterService.DynamicFiltersStats; @@ -133,6 +136,7 @@ public class SqlQueryExecution private final SqlTaskManager coordinatorTaskManager; private final ExchangeManagerRegistry exchangeManagerRegistry; private final TaskSourceFactory taskSourceFactory; + private final EventDrivenTaskSourceFactory eventDrivenTaskSourceFactory; private final TaskDescriptorStorage taskDescriptorStorage; private SqlQueryExecution( @@ -166,6 +170,7 @@ private SqlQueryExecution( SqlTaskManager coordinatorTaskManager, ExchangeManagerRegistry exchangeManagerRegistry, TaskSourceFactory taskSourceFactory, + EventDrivenTaskSourceFactory eventDrivenTaskSourceFactory, TaskDescriptorStorage taskDescriptorStorage) { try (SetThreadName ignored = new SetThreadName("Query-%s", stateMachine.getQueryId())) { @@ -213,6 +218,7 @@ private SqlQueryExecution( this.coordinatorTaskManager = requireNonNull(coordinatorTaskManager, "coordinatorTaskManager is null"); this.exchangeManagerRegistry = requireNonNull(exchangeManagerRegistry, "exchangeManagerRegistry is null"); this.taskSourceFactory = requireNonNull(taskSourceFactory, "taskSourceFactory is null"); + this.eventDrivenTaskSourceFactory = requireNonNull(eventDrivenTaskSourceFactory, "eventDrivenTaskSourceFactory is null"); this.taskDescriptorStorage = requireNonNull(taskDescriptorStorage, "taskDescriptorStorage is null"); } } @@ -521,28 +527,51 @@ private void planDistribution(PlanRoot 
plan) coordinatorTaskManager); break; case TASK: - scheduler = new FaultTolerantQueryScheduler( - stateMachine, - queryExecutor, - schedulerStats, - failureDetector, - taskSourceFactory, - taskDescriptorStorage, - exchangeManagerRegistry.getExchangeManager(), - nodePartitioningManager, - getTaskRetryAttemptsOverall(getSession()), - getTaskRetryAttemptsPerTask(getSession()), - getMaxTasksWaitingForNodePerStage(getSession()), - schedulerExecutor, - nodeAllocatorService, - partitionMemoryEstimatorFactory, - taskExecutionStats, - dynamicFilterService, - plannerContext.getMetadata(), - remoteTaskFactory, - nodeTaskMap, - plan.getRoot(), - plan.isSummarizeTaskInfos()); + if (isFaultTolerantExecutionEventDriverSchedulerEnabled(stateMachine.getSession())) { + scheduler = new EventDrivenFaultTolerantQueryScheduler( + stateMachine, + plannerContext.getMetadata(), + remoteTaskFactory, + taskDescriptorStorage, + eventDrivenTaskSourceFactory, + plan.isSummarizeTaskInfos(), + nodeTaskMap, + queryExecutor, + schedulerExecutor, + schedulerStats, + partitionMemoryEstimatorFactory, + nodePartitioningManager, + exchangeManagerRegistry.getExchangeManager(), + nodeAllocatorService, + failureDetector, + dynamicFilterService, + taskExecutionStats, + plan.getRoot()); + } + else { + scheduler = new FaultTolerantQueryScheduler( + stateMachine, + queryExecutor, + schedulerStats, + failureDetector, + taskSourceFactory, + taskDescriptorStorage, + exchangeManagerRegistry.getExchangeManager(), + nodePartitioningManager, + getTaskRetryAttemptsOverall(getSession()), + getTaskRetryAttemptsPerTask(getSession()), + getMaxTasksWaitingForNodePerStage(getSession()), + schedulerExecutor, + nodeAllocatorService, + partitionMemoryEstimatorFactory, + taskExecutionStats, + dynamicFilterService, + plannerContext.getMetadata(), + remoteTaskFactory, + nodeTaskMap, + plan.getRoot(), + plan.isSummarizeTaskInfos()); + } break; default: throw new IllegalArgumentException("Unexpected retry policy: " + retryPolicy); 
@@ -749,6 +778,7 @@ public static class SqlQueryExecutionFactory private final SqlTaskManager coordinatorTaskManager; private final ExchangeManagerRegistry exchangeManagerRegistry; private final TaskSourceFactory taskSourceFactory; + private final EventDrivenTaskSourceFactory eventDrivenTaskSourceFactory; private final TaskDescriptorStorage taskDescriptorStorage; @Inject @@ -779,6 +809,7 @@ public static class SqlQueryExecutionFactory SqlTaskManager coordinatorTaskManager, ExchangeManagerRegistry exchangeManagerRegistry, TaskSourceFactory taskSourceFactory, + EventDrivenTaskSourceFactory eventDrivenTaskSourceFactory, TaskDescriptorStorage taskDescriptorStorage) { this.schedulerStats = requireNonNull(schedulerStats, "schedulerStats is null"); @@ -807,6 +838,7 @@ public static class SqlQueryExecutionFactory this.coordinatorTaskManager = requireNonNull(coordinatorTaskManager, "coordinatorTaskManager is null"); this.exchangeManagerRegistry = requireNonNull(exchangeManagerRegistry, "exchangeManagerRegistry is null"); this.taskSourceFactory = requireNonNull(taskSourceFactory, "taskSourceFactory is null"); + this.eventDrivenTaskSourceFactory = requireNonNull(eventDrivenTaskSourceFactory, "eventDrivenTaskSourceFactory is null"); this.taskDescriptorStorage = requireNonNull(taskDescriptorStorage, "taskDescriptorStorage is null"); } @@ -852,6 +884,7 @@ public QueryExecution createQueryExecution( coordinatorTaskManager, exchangeManagerRegistry, taskSourceFactory, + eventDrivenTaskSourceFactory, taskDescriptorStorage); } } diff --git a/core/trino-main/src/main/java/io/trino/execution/SqlStage.java b/core/trino-main/src/main/java/io/trino/execution/SqlStage.java index 0ecb00f89424..29327a968b02 100644 --- a/core/trino-main/src/main/java/io/trino/execution/SqlStage.java +++ b/core/trino-main/src/main/java/io/trino/execution/SqlStage.java @@ -136,6 +136,11 @@ public StageId getStageId() return stateMachine.getStageId(); } + public StageState getState() + { + return 
stateMachine.getState(); + } + public synchronized void finish() { if (stateMachine.transitionToFinished()) { diff --git a/core/trino-main/src/main/java/io/trino/execution/StageId.java b/core/trino-main/src/main/java/io/trino/execution/StageId.java index 52da17862149..70ff78858f07 100644 --- a/core/trino-main/src/main/java/io/trino/execution/StageId.java +++ b/core/trino-main/src/main/java/io/trino/execution/StageId.java @@ -16,11 +16,13 @@ import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonValue; import io.trino.spi.QueryId; +import io.trino.sql.planner.plan.PlanFragmentId; import java.util.List; import java.util.Objects; import static com.google.common.base.Preconditions.checkArgument; +import static java.lang.Integer.parseInt; import static java.util.Objects.requireNonNull; public class StageId @@ -38,6 +40,11 @@ public static StageId valueOf(List ids) return new StageId(new QueryId(ids.get(0)), Integer.parseInt(ids.get(1))); } + public static StageId create(QueryId queryId, PlanFragmentId fragmentId) + { + return new StageId(queryId, parseInt(fragmentId.toString())); + } + private final QueryId queryId; private final int id; diff --git a/core/trino-main/src/main/java/io/trino/execution/StageInfo.java b/core/trino-main/src/main/java/io/trino/execution/StageInfo.java index 679a1482b997..f406f3e90722 100644 --- a/core/trino-main/src/main/java/io/trino/execution/StageInfo.java +++ b/core/trino-main/src/main/java/io/trino/execution/StageInfo.java @@ -17,6 +17,7 @@ import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import io.trino.spi.QueryId; import io.trino.spi.type.Type; import io.trino.sql.planner.PlanFragment; import io.trino.sql.planner.plan.PlanNodeId; @@ -152,6 +153,36 @@ public String toString() .toString(); } + public StageInfo withSubStages(List subStages) + { + return new StageInfo( + stageId, + state, + plan, + 
coordinatorOnly, + types, + stageStats, + tasks, + subStages, + tables, + failureCause); + } + + public static StageInfo createInitial(QueryId queryId, StageState state, PlanFragment fragment) + { + return new StageInfo( + StageId.create(queryId, fragment.getId()), + state, + fragment, + fragment.getPartitioning().isCoordinatorOnly(), + fragment.getTypes(), + StageStats.createInitial(), + ImmutableList.of(), + ImmutableList.of(), + ImmutableMap.of(), + null); + } + public static List getAllStages(Optional stageInfo) { if (stageInfo.isEmpty()) { diff --git a/core/trino-main/src/main/java/io/trino/execution/StageStats.java b/core/trino-main/src/main/java/io/trino/execution/StageStats.java index 60fa4a84bf3b..8358f27efc92 100644 --- a/core/trino-main/src/main/java/io/trino/execution/StageStats.java +++ b/core/trino-main/src/main/java/io/trino/execution/StageStats.java @@ -17,6 +17,7 @@ import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; +import io.airlift.stats.Distribution; import io.airlift.stats.Distribution.DistributionSnapshot; import io.airlift.units.DataSize; import io.airlift.units.Duration; @@ -34,9 +35,11 @@ import java.util.Set; import static com.google.common.base.Preconditions.checkArgument; +import static io.airlift.units.DataSize.Unit.BYTE; import static io.trino.execution.StageState.RUNNING; import static java.lang.Math.min; import static java.util.Objects.requireNonNull; +import static java.util.concurrent.TimeUnit.SECONDS; @Immutable public class StageStats @@ -658,4 +661,75 @@ public BasicStageStats toBasicStageStats(StageState stageState) blockedReasons, progressPercentage); } + + public static StageStats createInitial() + { + DataSize zeroBytes = DataSize.of(0, BYTE); + Duration zeroSeconds = new Duration(0, SECONDS); + return new StageStats( + null, + new Distribution().snapshot(), + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + zeroBytes, + 
zeroBytes, + zeroBytes, + zeroBytes, + zeroBytes, + zeroSeconds, + zeroSeconds, + zeroSeconds, + zeroSeconds, + zeroSeconds, + false, + ImmutableSet.of(), + zeroBytes, + zeroBytes, + 0, + 0, + zeroSeconds, + zeroSeconds, + zeroBytes, + zeroBytes, + 0, + 0, + zeroBytes, + zeroBytes, + 0, + 0, + zeroBytes, + zeroBytes, + 0, + 0, + zeroSeconds, + zeroSeconds, + zeroBytes, + Optional.empty(), + zeroBytes, + zeroBytes, + 0, + 0, + zeroSeconds, + zeroSeconds, + zeroBytes, + zeroBytes, + new StageGcStatistics( + 0, + 0, + 0, + 0, + 0, + 0, + 0), + ImmutableList.of()); + } } diff --git a/core/trino-main/src/main/java/io/trino/execution/TableInfo.java b/core/trino-main/src/main/java/io/trino/execution/TableInfo.java index 6a2125402bff..b18b54095529 100644 --- a/core/trino-main/src/main/java/io/trino/execution/TableInfo.java +++ b/core/trino-main/src/main/java/io/trino/execution/TableInfo.java @@ -15,14 +15,26 @@ import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; +import io.trino.Session; +import io.trino.metadata.CatalogInfo; +import io.trino.metadata.Metadata; import io.trino.metadata.QualifiedObjectName; +import io.trino.metadata.TableProperties; +import io.trino.metadata.TableSchema; import io.trino.spi.connector.ColumnHandle; import io.trino.spi.predicate.TupleDomain; +import io.trino.sql.planner.PlanFragment; +import io.trino.sql.planner.plan.PlanNode; +import io.trino.sql.planner.plan.PlanNodeId; +import io.trino.sql.planner.plan.TableScanNode; import javax.annotation.concurrent.Immutable; +import java.util.Map; import java.util.Optional; +import static com.google.common.collect.ImmutableMap.toImmutableMap; +import static io.trino.sql.planner.optimizations.PlanNodeSearcher.searchFrom; import static java.util.Objects.requireNonNull; @Immutable @@ -60,4 +72,25 @@ public TupleDomain getPredicate() { return predicate; } + + public static Map extract(Session session, Metadata metadata, PlanFragment fragment) + { + 
return searchFrom(fragment.getRoot()) + .where(TableScanNode.class::isInstance) + .findAll() + .stream() + .map(TableScanNode.class::cast) + .collect(toImmutableMap(PlanNode::getId, node -> extract(session, metadata, node))); + } + + private static TableInfo extract(Session session, Metadata metadata, TableScanNode node) + { + TableSchema tableSchema = metadata.getTableSchema(session, node.getTable()); + TableProperties tableProperties = metadata.getTableProperties(session, node.getTable()); + Optional connectorName = metadata.listCatalogs(session).stream() + .filter(catalogInfo -> catalogInfo.getCatalogName().equals(tableSchema.getCatalogName())) + .map(CatalogInfo::getConnectorName) + .findFirst(); + return new TableInfo(connectorName, tableSchema.getQualifiedName(), tableProperties.getPredicate()); + } } diff --git a/core/trino-main/src/main/java/io/trino/execution/buffer/PagesSerde.java b/core/trino-main/src/main/java/io/trino/execution/buffer/PagesSerde.java index 9e137ade26c1..e1539c338f76 100644 --- a/core/trino-main/src/main/java/io/trino/execution/buffer/PagesSerde.java +++ b/core/trino-main/src/main/java/io/trino/execution/buffer/PagesSerde.java @@ -46,14 +46,11 @@ @NotThreadSafe public class PagesSerde { - static final int SERIALIZED_PAGE_HEADER_SIZE = /*positionCount*/ Integer.BYTES + - // pageCodecMarkers - Byte.BYTES + - // uncompressedSizeInBytes - Integer.BYTES + - // sizeInBytes - Integer.BYTES; - private static final int COMPRESSED_SIZE_OFFSET = SERIALIZED_PAGE_HEADER_SIZE - Integer.BYTES; + private static final int SERIALIZED_PAGE_POSITION_COUNT_OFFSET = 0; + private static final int SERIALIZED_PAGE_CODEC_MARKERS_OFFSET = SERIALIZED_PAGE_POSITION_COUNT_OFFSET + Integer.BYTES; + private static final int SERIALIZED_PAGE_UNCOMPRESSED_SIZE_OFFSET = SERIALIZED_PAGE_CODEC_MARKERS_OFFSET + Byte.BYTES; + private static final int SERIALIZED_PAGE_COMPRESSED_SIZE_OFFSET = SERIALIZED_PAGE_UNCOMPRESSED_SIZE_OFFSET + Integer.BYTES; + static final int 
SERIALIZED_PAGE_HEADER_SIZE = SERIALIZED_PAGE_COMPRESSED_SIZE_OFFSET + Integer.BYTES; private static final double MINIMUM_COMPRESSION_RATIO = 0.8; private final BlockEncodingSerde blockEncodingSerde; @@ -143,7 +140,12 @@ public Slice serialize(PagesSerdeContext context, Page page) public static int getSerializedPagePositionCount(Slice serializedPage) { - return serializedPage.getInt(0); + return serializedPage.getInt(SERIALIZED_PAGE_POSITION_COUNT_OFFSET); + } + + public static int getSerializedPageUncompressedSizeInBytes(Slice serializedPage) + { + return serializedPage.getInt(SERIALIZED_PAGE_UNCOMPRESSED_SIZE_OFFSET); } public static boolean isSerializedPageEncrypted(Slice serializedPage) @@ -224,7 +226,7 @@ public static Slice readSerializedPage(Slice headerSlice, InputStream inputStrea { checkArgument(headerSlice.length() == SERIALIZED_PAGE_HEADER_SIZE, "headerSlice length should equal to %s", SERIALIZED_PAGE_HEADER_SIZE); - int compressedSize = getIntUnchecked(headerSlice, COMPRESSED_SIZE_OFFSET); + int compressedSize = getIntUnchecked(headerSlice, SERIALIZED_PAGE_COMPRESSED_SIZE_OFFSET); byte[] outputBuffer = new byte[SERIALIZED_PAGE_HEADER_SIZE + compressedSize]; headerSlice.getBytes(0, outputBuffer, 0, SERIALIZED_PAGE_HEADER_SIZE); readFully(inputStream, outputBuffer, SERIALIZED_PAGE_HEADER_SIZE, compressedSize); diff --git a/core/trino-main/src/main/java/io/trino/execution/buffer/SpoolingExchangeOutputBuffer.java b/core/trino-main/src/main/java/io/trino/execution/buffer/SpoolingExchangeOutputBuffer.java index c3a359b02b6a..9502283f4603 100644 --- a/core/trino-main/src/main/java/io/trino/execution/buffer/SpoolingExchangeOutputBuffer.java +++ b/core/trino-main/src/main/java/io/trino/execution/buffer/SpoolingExchangeOutputBuffer.java @@ -32,6 +32,7 @@ import static com.google.common.util.concurrent.Futures.immediateVoidFuture; import static io.airlift.concurrent.MoreFutures.toListenableFuture; import static 
io.trino.execution.buffer.PagesSerde.getSerializedPagePositionCount; +import static io.trino.execution.buffer.PagesSerde.getSerializedPageUncompressedSizeInBytes; import static java.util.Objects.requireNonNull; @ThreadSafe @@ -193,7 +194,7 @@ public void enqueue(int partition, List pages) checkState(sink != null, "exchangeSink is null"); long dataSizeInBytes = 0; for (Slice page : pages) { - dataSizeInBytes += page.length(); + dataSizeInBytes += getSerializedPageUncompressedSizeInBytes(page); sink.add(partition, page); totalRowsAdded.addAndGet(getSerializedPagePositionCount(page)); } diff --git a/core/trino-main/src/main/java/io/trino/execution/buffer/SpoolingOutputBuffers.java b/core/trino-main/src/main/java/io/trino/execution/buffer/SpoolingOutputBuffers.java index 55979fd68c39..21244b328649 100644 --- a/core/trino-main/src/main/java/io/trino/execution/buffer/SpoolingOutputBuffers.java +++ b/core/trino-main/src/main/java/io/trino/execution/buffer/SpoolingOutputBuffers.java @@ -31,6 +31,7 @@ public static SpoolingOutputBuffers createInitial(ExchangeSinkInstanceHandle exc return new SpoolingOutputBuffers(0, exchangeSinkInstanceHandle, outputPartitionCount); } + // Visible only for Jackson... 
Use the "with" methods instead @JsonCreator public SpoolingOutputBuffers( @JsonProperty("version") long version, @@ -70,4 +71,9 @@ public void checkValidTransition(OutputBuffers outputBuffers) getOutputPartitionCount(), newOutputBuffers.getOutputPartitionCount()); } + + public SpoolingOutputBuffers withExchangeSinkInstanceHandle(ExchangeSinkInstanceHandle handle) + { + return new SpoolingOutputBuffers(getVersion() + 1, handle, outputPartitionCount); + } } diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/ArbitraryDistributionSplitAssigner.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/ArbitraryDistributionSplitAssigner.java new file mode 100644 index 000000000000..4dfed047bdf3 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/ArbitraryDistributionSplitAssigner.java @@ -0,0 +1,343 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.execution.scheduler; + +import com.google.common.collect.ArrayListMultimap; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; +import com.google.common.collect.ListMultimap; +import io.trino.connector.CatalogHandle; +import io.trino.exchange.SpoolingExchangeInput; +import io.trino.execution.scheduler.EventDrivenTaskSource.Partition; +import io.trino.execution.scheduler.EventDrivenTaskSource.PartitionUpdate; +import io.trino.metadata.Split; +import io.trino.spi.HostAddress; +import io.trino.spi.SplitWeight; +import io.trino.spi.exchange.ExchangeSourceHandle; +import io.trino.split.RemoteSplit; +import io.trino.sql.planner.plan.PlanNodeId; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.base.Verify.verify; +import static io.trino.operator.ExchangeOperator.REMOTE_CATALOG_HANDLE; +import static java.lang.Math.round; +import static java.util.Objects.requireNonNull; + +class ArbitraryDistributionSplitAssigner + implements SplitAssigner +{ + private final Optional<CatalogHandle> catalogRequirement; + private final Set<PlanNodeId> partitionedSources; + private final Set<PlanNodeId> replicatedSources; + private final Set<PlanNodeId> allSources; + private final long targetPartitionSizeInBytes; + private final long standardSplitSizeInBytes; + private final int maxTaskSplitCount; + + private int nextPartitionId; + private final List<PartitionAssignment> allAssignments = new ArrayList<>(); + private final Map<Optional<HostAddress>, PartitionAssignment> openAssignments = new HashMap<>(); + + private final Set<PlanNodeId> completedSources = new HashSet<>(); + + private final ListMultimap<PlanNodeId, Split> replicatedSplits = ArrayListMultimap.create(); + private boolean noMoreReplicatedSplits; + + ArbitraryDistributionSplitAssigner( + 
Optional catalogRequirement, + Set partitionedSources, + Set replicatedSources, + long targetPartitionSizeInBytes, + long standardSplitSizeInBytes, + int maxTaskSplitCount) + { + this.catalogRequirement = requireNonNull(catalogRequirement, "catalogRequirement is null"); + this.partitionedSources = ImmutableSet.copyOf(requireNonNull(partitionedSources, "partitionedSources is null")); + this.replicatedSources = ImmutableSet.copyOf(requireNonNull(replicatedSources, "replicatedSources is null")); + allSources = ImmutableSet.builder() + .addAll(partitionedSources) + .addAll(replicatedSources) + .build(); + this.targetPartitionSizeInBytes = targetPartitionSizeInBytes; + this.standardSplitSizeInBytes = standardSplitSizeInBytes; + this.maxTaskSplitCount = maxTaskSplitCount; + } + + @Override + public AssignmentResult assign(PlanNodeId planNodeId, ListMultimap splits, boolean noMoreSplits) + { + for (Split split : splits.values()) { + Optional splitCatalogRequirement = Optional.of(split.getCatalogHandle()) + .filter(catalog -> !catalog.getType().isInternal() && !catalog.equals(REMOTE_CATALOG_HANDLE)); + checkArgument( + catalogRequirement.isEmpty() || catalogRequirement.equals(splitCatalogRequirement), + "unexpected split catalog requirement: %s", + splitCatalogRequirement); + } + if (replicatedSources.contains(planNodeId)) { + return assignReplicatedSplits(planNodeId, ImmutableList.copyOf(splits.values()), noMoreSplits); + } + return assignPartitionedSplits(planNodeId, ImmutableList.copyOf(splits.values()), noMoreSplits); + } + + @Override + public AssignmentResult finish() + { + checkState(!allAssignments.isEmpty(), "allAssignments is not expected to be empty"); + return AssignmentResult.builder().build(); + } + + private AssignmentResult assignReplicatedSplits(PlanNodeId planNodeId, List splits, boolean noMoreSplits) + { + AssignmentResult.Builder assignment = AssignmentResult.builder(); + replicatedSplits.putAll(planNodeId, splits); + for (PartitionAssignment 
partitionAssignment : allAssignments) { + assignment.updatePartition(new PartitionUpdate( + partitionAssignment.getPartitionId(), + planNodeId, + splits, + noMoreSplits)); + } + if (noMoreSplits) { + completedSources.add(planNodeId); + if (completedSources.containsAll(replicatedSources)) { + noMoreReplicatedSplits = true; + } + } + if (noMoreReplicatedSplits) { + for (PartitionAssignment partitionAssignment : allAssignments) { + if (partitionAssignment.isFull()) { + assignment.sealPartition(partitionAssignment.getPartitionId()); + } + } + } + if (completedSources.containsAll(allSources)) { + if (allAssignments.isEmpty()) { + // at least a single partition is expected to be created + allAssignments.add(new PartitionAssignment(0)); + assignment.addPartition(new Partition(0, new NodeRequirements(catalogRequirement, ImmutableSet.of()))); + for (PlanNodeId replicatedSourceId : replicatedSources) { + assignment.updatePartition(new PartitionUpdate( + 0, + replicatedSourceId, + replicatedSplits.get(replicatedSourceId), + true)); + } + assignment.sealPartition(0); + } + else { + for (PartitionAssignment partitionAssignment : allAssignments) { + // set noMoreSplits for partitioned sources + if (!partitionAssignment.isFull()) { + for (PlanNodeId partitionedSourceNodeId : partitionedSources) { + assignment.updatePartition(new PartitionUpdate( + partitionAssignment.getPartitionId(), + partitionedSourceNodeId, + ImmutableList.of(), + true)); + } + // seal partition + assignment.sealPartition(partitionAssignment.getPartitionId()); + } + } + } + replicatedSplits.clear(); + // no more partitions will be created + assignment.setNoMorePartitions(); + } + return assignment.build(); + } + + private AssignmentResult assignPartitionedSplits(PlanNodeId planNodeId, List splits, boolean noMoreSplits) + { + AssignmentResult.Builder assignment = AssignmentResult.builder(); + + for (Split split : splits) { + Optional hostRequirement = getHostRequirement(split); + PartitionAssignment 
partitionAssignment = openAssignments.get(hostRequirement); + long splitSizeInBytes = getSplitSizeInBytes(split); + if (partitionAssignment != null && ((partitionAssignment.getAssignedDataSizeInBytes() + splitSizeInBytes > targetPartitionSizeInBytes) + || (partitionAssignment.getAssignedSplitCount() + 1 > maxTaskSplitCount))) { + partitionAssignment.setFull(true); + for (PlanNodeId partitionedSourceNodeId : partitionedSources) { + assignment.updatePartition(new PartitionUpdate( + partitionAssignment.getPartitionId(), + partitionedSourceNodeId, + ImmutableList.of(), + true)); + } + if (completedSources.containsAll(replicatedSources)) { + assignment.sealPartition(partitionAssignment.getPartitionId()); + } + partitionAssignment = null; + openAssignments.remove(hostRequirement); + } + if (partitionAssignment == null) { + partitionAssignment = new PartitionAssignment(nextPartitionId++); + allAssignments.add(partitionAssignment); + openAssignments.put(hostRequirement, partitionAssignment); + assignment.addPartition(new Partition( + partitionAssignment.getPartitionId(), + new NodeRequirements(catalogRequirement, hostRequirement.map(ImmutableSet::of).orElseGet(ImmutableSet::of)))); + + for (PlanNodeId replicatedSourceId : replicatedSources) { + assignment.updatePartition(new PartitionUpdate( + partitionAssignment.getPartitionId(), + replicatedSourceId, + replicatedSplits.get(replicatedSourceId), + completedSources.contains(replicatedSourceId))); + } + } + assignment.updatePartition(new PartitionUpdate( + partitionAssignment.getPartitionId(), + planNodeId, + ImmutableList.of(split), + false)); + partitionAssignment.assignSplit(splitSizeInBytes); + } + + if (noMoreSplits) { + completedSources.add(planNodeId); + } + + if (completedSources.containsAll(allSources)) { + if (allAssignments.isEmpty()) { + // at least a single partition is expected to be created + allAssignments.add(new PartitionAssignment(0)); + assignment.addPartition(new Partition(0, new 
NodeRequirements(catalogRequirement, ImmutableSet.of()))); + for (PlanNodeId replicatedSourceId : replicatedSources) { + assignment.updatePartition(new PartitionUpdate( + 0, + replicatedSourceId, + replicatedSplits.get(replicatedSourceId), + true)); + } + assignment.sealPartition(0); + } + else { + for (PartitionAssignment partitionAssignment : openAssignments.values()) { + // set noMoreSplits for partitioned sources + for (PlanNodeId partitionedSourceNodeId : partitionedSources) { + assignment.updatePartition(new PartitionUpdate( + partitionAssignment.getPartitionId(), + partitionedSourceNodeId, + ImmutableList.of(), + true)); + } + // seal partition + assignment.sealPartition(partitionAssignment.getPartitionId()); + } + openAssignments.clear(); + } + replicatedSplits.clear(); + // no more partitions will be created + assignment.setNoMorePartitions(); + } + + return assignment.build(); + } + + private Optional getHostRequirement(Split split) + { + if (split.getConnectorSplit().isRemotelyAccessible()) { + return Optional.empty(); + } + List addresses = split.getAddresses(); + checkArgument(!addresses.isEmpty(), "split is not remotely accessible but the list of hosts is empty: %s", split); + HostAddress selectedAddress = null; + long selectedAssignmentDataSize = Long.MAX_VALUE; + for (HostAddress address : addresses) { + PartitionAssignment assignment = openAssignments.get(Optional.of(address)); + if (assignment == null) { + // prioritize unused addresses + selectedAddress = address; + break; + } + if (assignment.getAssignedDataSizeInBytes() < selectedAssignmentDataSize) { + // otherwise prioritize the smallest assignment + selectedAddress = address; + selectedAssignmentDataSize = assignment.getAssignedDataSizeInBytes(); + } + } + verify(selectedAddress != null, "selectedAddress is null"); + return Optional.of(selectedAddress); + } + + private long getSplitSizeInBytes(Split split) + { + if (split.getCatalogHandle().equals(REMOTE_CATALOG_HANDLE)) { + RemoteSplit 
remoteSplit = (RemoteSplit) split.getConnectorSplit(); + SpoolingExchangeInput exchangeInput = (SpoolingExchangeInput) remoteSplit.getExchangeInput(); + long size = 0; + for (ExchangeSourceHandle handle : exchangeInput.getExchangeSourceHandles()) { + size += handle.getDataSizeInBytes(); + } + return size; + } + return round(((split.getSplitWeight().getRawValue() * 1.0) / SplitWeight.standard().getRawValue()) * standardSplitSizeInBytes); + } + + private static class PartitionAssignment + { + private final int partitionId; + private long assignedDataSizeInBytes; + private int assignedSplitCount; + private boolean full; + + private PartitionAssignment(int partitionId) + { + this.partitionId = partitionId; + } + + public int getPartitionId() + { + return partitionId; + } + + public void assignSplit(long sizeInBytes) + { + assignedDataSizeInBytes += sizeInBytes; + assignedSplitCount++; + } + + public long getAssignedDataSizeInBytes() + { + return assignedDataSizeInBytes; + } + + public int getAssignedSplitCount() + { + return assignedSplitCount; + } + + public boolean isFull() + { + return full; + } + + public void setFull(boolean full) + { + this.full = full; + } + } +} diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/EventDrivenFaultTolerantQueryScheduler.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/EventDrivenFaultTolerantQueryScheduler.java new file mode 100644 index 000000000000..3480b594bf2d --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/EventDrivenFaultTolerantQueryScheduler.java @@ -0,0 +1,2108 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.execution.scheduler; + +import com.google.common.base.Stopwatch; +import com.google.common.collect.ArrayListMultimap; +import com.google.common.collect.HashMultimap; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableListMultimap; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import com.google.common.collect.ListMultimap; +import com.google.common.collect.SetMultimap; +import com.google.common.collect.Sets; +import com.google.common.graph.Traverser; +import com.google.common.io.Closer; +import com.google.common.primitives.ImmutableIntArray; +import com.google.common.primitives.ImmutableLongArray; +import com.google.common.util.concurrent.FutureCallback; +import com.google.common.util.concurrent.Futures; +import com.google.common.util.concurrent.ListenableFuture; +import com.google.common.util.concurrent.UncheckedExecutionException; +import io.airlift.log.Logger; +import io.airlift.units.DataSize; +import io.airlift.units.Duration; +import io.trino.Session; +import io.trino.exchange.SpoolingExchangeInput; +import io.trino.execution.BasicStageStats; +import io.trino.execution.ExecutionFailureInfo; +import io.trino.execution.NodeTaskMap; +import io.trino.execution.QueryState; +import io.trino.execution.QueryStateMachine; +import io.trino.execution.RemoteTask; +import io.trino.execution.RemoteTaskFactory; +import io.trino.execution.SqlStage; +import io.trino.execution.StageId; +import io.trino.execution.StageInfo; +import 
io.trino.execution.StageState; +import io.trino.execution.StateMachine.StateChangeListener; +import io.trino.execution.TableInfo; +import io.trino.execution.TaskId; +import io.trino.execution.TaskState; +import io.trino.execution.TaskStatus; +import io.trino.execution.buffer.OutputBufferStatus; +import io.trino.execution.buffer.SpoolingOutputBuffers; +import io.trino.execution.buffer.SpoolingOutputStats; +import io.trino.execution.resourcegroups.IndexedPriorityQueue; +import io.trino.execution.scheduler.EventDrivenTaskSource.Partition; +import io.trino.execution.scheduler.EventDrivenTaskSource.PartitionUpdate; +import io.trino.execution.scheduler.NodeAllocator.NodeLease; +import io.trino.execution.scheduler.PartitionMemoryEstimator.MemoryRequirements; +import io.trino.failuredetector.FailureDetector; +import io.trino.metadata.InternalNode; +import io.trino.metadata.Metadata; +import io.trino.metadata.Split; +import io.trino.operator.RetryPolicy; +import io.trino.server.DynamicFilterService; +import io.trino.spi.ErrorCode; +import io.trino.spi.StandardErrorCode; +import io.trino.spi.TrinoException; +import io.trino.spi.exchange.Exchange; +import io.trino.spi.exchange.ExchangeContext; +import io.trino.spi.exchange.ExchangeId; +import io.trino.spi.exchange.ExchangeManager; +import io.trino.spi.exchange.ExchangeSinkHandle; +import io.trino.spi.exchange.ExchangeSinkInstanceHandle; +import io.trino.spi.exchange.ExchangeSourceHandle; +import io.trino.spi.exchange.ExchangeSourceOutputSelector; +import io.trino.split.RemoteSplit; +import io.trino.sql.planner.NodePartitioningManager; +import io.trino.sql.planner.PlanFragment; +import io.trino.sql.planner.SubPlan; +import io.trino.sql.planner.plan.PlanFragmentId; +import io.trino.sql.planner.plan.PlanNodeId; +import io.trino.sql.planner.plan.RemoteSourceNode; +import it.unimi.dsi.fastutil.ints.Int2ObjectMap; +import it.unimi.dsi.fastutil.ints.Int2ObjectOpenHashMap; +import it.unimi.dsi.fastutil.ints.IntOpenHashSet; +import 
it.unimi.dsi.fastutil.ints.IntSet; + +import javax.annotation.Nullable; +import javax.annotation.concurrent.GuardedBy; +import javax.annotation.concurrent.ThreadSafe; + +import java.io.Closeable; +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.Set; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.atomic.AtomicReference; + +import static com.google.common.base.MoreObjects.toStringHelper; +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.base.Verify.verify; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.ImmutableMap.toImmutableMap; +import static com.google.common.util.concurrent.Futures.getDone; +import static io.trino.SystemSessionProperties.getFaultTolerantExecutionDefaultCoordinatorTaskMemory; +import static io.trino.SystemSessionProperties.getFaultTolerantExecutionDefaultTaskMemory; +import static io.trino.SystemSessionProperties.getFaultTolerantExecutionPartitionCount; +import static io.trino.SystemSessionProperties.getMaxTasksWaitingForNodePerStage; +import static io.trino.SystemSessionProperties.getRetryDelayScaleFactor; +import static io.trino.SystemSessionProperties.getRetryInitialDelay; +import static io.trino.SystemSessionProperties.getRetryMaxDelay; +import static io.trino.SystemSessionProperties.getRetryPolicy; 
+import static io.trino.SystemSessionProperties.getTaskRetryAttemptsPerTask; +import static io.trino.execution.BasicStageStats.aggregateBasicStageStats; +import static io.trino.execution.StageState.ABORTED; +import static io.trino.execution.StageState.PLANNED; +import static io.trino.execution.scheduler.ErrorCodes.isOutOfMemoryError; +import static io.trino.execution.scheduler.Exchanges.getAllSourceHandles; +import static io.trino.failuredetector.FailureDetector.State.GONE; +import static io.trino.operator.ExchangeOperator.REMOTE_CATALOG_HANDLE; +import static io.trino.operator.RetryPolicy.TASK; +import static io.trino.spi.ErrorType.EXTERNAL; +import static io.trino.spi.ErrorType.INTERNAL_ERROR; +import static io.trino.spi.ErrorType.USER_ERROR; +import static io.trino.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; +import static io.trino.spi.StandardErrorCode.REMOTE_HOST_GONE; +import static io.trino.sql.planner.SystemPartitioningHandle.COORDINATOR_DISTRIBUTION; +import static io.trino.util.Failures.toFailure; +import static java.lang.Math.max; +import static java.lang.Math.min; +import static java.lang.Math.round; +import static java.lang.String.format; +import static java.util.Objects.requireNonNull; +import static java.util.concurrent.TimeUnit.MILLISECONDS; +import static java.util.concurrent.TimeUnit.MINUTES; + +public class EventDrivenFaultTolerantQueryScheduler + implements QueryScheduler +{ + private static final Logger log = Logger.get(EventDrivenFaultTolerantQueryScheduler.class); + + private final QueryStateMachine queryStateMachine; + private final Metadata metadata; + private final RemoteTaskFactory remoteTaskFactory; + private final TaskDescriptorStorage taskDescriptorStorage; + private final EventDrivenTaskSourceFactory taskSourceFactory; + private final boolean summarizeTaskInfo; + private final NodeTaskMap nodeTaskMap; + private final ExecutorService queryExecutor; + private final ScheduledExecutorService scheduledExecutorService; + private final 
SplitSchedulerStats schedulerStats; + private final PartitionMemoryEstimatorFactory memoryEstimatorFactory; + private final NodePartitioningManager nodePartitioningManager; + private final ExchangeManager exchangeManager; + private final NodeAllocatorService nodeAllocatorService; + private final FailureDetector failureDetector; + private final DynamicFilterService dynamicFilterService; + private final TaskExecutionStats taskExecutionStats; + private final SubPlan originalPlan; + + private final StageRegistry stageRegistry; + + @GuardedBy("this") + private boolean started; + @GuardedBy("this") + private Scheduler scheduler; + + public EventDrivenFaultTolerantQueryScheduler( + QueryStateMachine queryStateMachine, + Metadata metadata, + RemoteTaskFactory remoteTaskFactory, + TaskDescriptorStorage taskDescriptorStorage, + EventDrivenTaskSourceFactory taskSourceFactory, + boolean summarizeTaskInfo, + NodeTaskMap nodeTaskMap, + ExecutorService queryExecutor, + ScheduledExecutorService scheduledExecutorService, + SplitSchedulerStats schedulerStats, + PartitionMemoryEstimatorFactory memoryEstimatorFactory, + NodePartitioningManager nodePartitioningManager, + ExchangeManager exchangeManager, + NodeAllocatorService nodeAllocatorService, + FailureDetector failureDetector, + DynamicFilterService dynamicFilterService, + TaskExecutionStats taskExecutionStats, + SubPlan originalPlan) + { + this.queryStateMachine = requireNonNull(queryStateMachine, "queryStateMachine is null"); + RetryPolicy retryPolicy = getRetryPolicy(queryStateMachine.getSession()); + verify(retryPolicy == TASK, "unexpected retry policy: %s", retryPolicy); + this.metadata = requireNonNull(metadata, "metadata is null"); + this.remoteTaskFactory = requireNonNull(remoteTaskFactory, "remoteTaskFactory is null"); + this.taskDescriptorStorage = requireNonNull(taskDescriptorStorage, "taskDescriptorStorage is null"); + this.taskSourceFactory = requireNonNull(taskSourceFactory, "taskSourceFactory is null"); + 
this.summarizeTaskInfo = summarizeTaskInfo; + this.nodeTaskMap = requireNonNull(nodeTaskMap, "nodeTaskMap is null"); + this.queryExecutor = requireNonNull(queryExecutor, "queryExecutor is null"); + this.scheduledExecutorService = requireNonNull(scheduledExecutorService, "scheduledExecutorService is null"); + this.schedulerStats = requireNonNull(schedulerStats, "schedulerStats is null"); + this.memoryEstimatorFactory = requireNonNull(memoryEstimatorFactory, "memoryEstimatorFactory is null"); + this.nodePartitioningManager = requireNonNull(nodePartitioningManager, "partitioningSchemeFactory is null"); + this.exchangeManager = requireNonNull(exchangeManager, "exchangeManager is null"); + this.nodeAllocatorService = requireNonNull(nodeAllocatorService, "nodeAllocatorService is null"); + this.failureDetector = requireNonNull(failureDetector, "failureDetector is null"); + this.dynamicFilterService = requireNonNull(dynamicFilterService, "dynamicFilterService is null"); + this.taskExecutionStats = requireNonNull(taskExecutionStats, "taskExecutionStats is null"); + this.originalPlan = requireNonNull(originalPlan, "originalPlan is null"); + + stageRegistry = new StageRegistry(queryStateMachine, originalPlan); + } + + @Override + public synchronized void start() + { + checkState(!started, "already started"); + started = true; + + if (queryStateMachine.isDone()) { + return; + } + + taskDescriptorStorage.initialize(queryStateMachine.getQueryId()); + queryStateMachine.addStateChangeListener(state -> { + if (state.isDone()) { + taskDescriptorStorage.destroy(queryStateMachine.getQueryId()); + } + }); + + // when query is done or any time a stage completes, attempt to transition query to "final query info ready" + queryStateMachine.addStateChangeListener(state -> { + if (!state.isDone()) { + return; + } + Scheduler scheduler; + synchronized (this) { + scheduler = this.scheduler; + this.scheduler = null; + } + if (scheduler != null) { + scheduler.abort(); + } + 
queryStateMachine.updateQueryInfo(Optional.ofNullable(stageRegistry.getStageInfo())); + }); + + Session session = queryStateMachine.getSession(); + FaultTolerantPartitioningSchemeFactory partitioningSchemeFactory = new FaultTolerantPartitioningSchemeFactory( + nodePartitioningManager, + session, + getFaultTolerantExecutionPartitionCount(session)); + Closer closer = Closer.create(); + NodeAllocator nodeAllocator = closer.register(nodeAllocatorService.getNodeAllocator(session)); + try { + scheduler = new Scheduler( + queryStateMachine, + metadata, + remoteTaskFactory, + taskDescriptorStorage, + taskSourceFactory, + summarizeTaskInfo, + nodeTaskMap, + queryExecutor, + scheduledExecutorService, schedulerStats, + memoryEstimatorFactory, + partitioningSchemeFactory, + exchangeManager, + getTaskRetryAttemptsPerTask(session) + 1, + getMaxTasksWaitingForNodePerStage(session), + nodeAllocator, + failureDetector, + stageRegistry, + taskExecutionStats, + dynamicFilterService, + new SchedulingDelayer( + getRetryInitialDelay(session), + getRetryMaxDelay(session), + getRetryDelayScaleFactor(session), + Stopwatch.createUnstarted()), + originalPlan); + queryExecutor.submit(scheduler::run); + } + catch (Throwable t) { + try { + closer.close(); + } + catch (Throwable closerFailure) { + if (t != closerFailure) { + t.addSuppressed(closerFailure); + } + } + throw t; + } + } + + @Override + public void cancelStage(StageId stageId) + { + throw new UnsupportedOperationException("partial cancel is not supported in fault tolerant mode"); + } + + @Override + public void failTask(TaskId taskId, Throwable failureCause) + { + stageRegistry.failTaskRemotely(taskId, failureCause); + } + + @Override + public BasicStageStats getBasicStageStats() + { + return stageRegistry.getBasicStageStats(); + } + + @Override + public StageInfo getStageInfo() + { + return stageRegistry.getStageInfo(); + } + + @Override + public long getUserMemoryReservation() + { + return stageRegistry.getUserMemoryReservation(); 
+ } + + @Override + public long getTotalMemoryReservation() + { + return stageRegistry.getTotalMemoryReservation(); + } + + @Override + public Duration getTotalCpuTime() + { + return stageRegistry.getTotalCpuTime(); + } + + @ThreadSafe + private static class StageRegistry + { + private final QueryStateMachine queryStateMachine; + private final AtomicReference plan; + private final Map stages = new ConcurrentHashMap<>(); + + public StageRegistry(QueryStateMachine queryStateMachine, SubPlan plan) + { + this.queryStateMachine = requireNonNull(queryStateMachine, "queryStateMachine is null"); + this.plan = new AtomicReference<>(requireNonNull(plan, "plan is null")); + } + + public void add(SqlStage stage) + { + verify(stages.putIfAbsent(stage.getStageId(), stage) == null, "stage %s is already present", stage.getStageId()); + } + + public void updatePlan(SubPlan plan) + { + this.plan.set(requireNonNull(plan, "plan is null")); + } + + public StageInfo getStageInfo() + { + SubPlan plan = requireNonNull(this.plan.get(), "plan is null"); + Map stageInfos = stages.values().stream() + .collect(toImmutableMap(stage -> stage.getFragment().getId(), SqlStage::getStageInfo)); + Set reportedFragments = new HashSet<>(); + StageInfo stageInfo = getStageInfo(plan, stageInfos, reportedFragments); + // TODO Some stages may no longer be present in the plan when adaptive re-planning is implemented + // TODO Figure out how to report statistics for such stages + verify(reportedFragments.containsAll(stageInfos.keySet()), "some stages are left unreported"); + return stageInfo; + } + + private StageInfo getStageInfo(SubPlan plan, Map infos, Set reportedFragments) + { + PlanFragmentId fragmentId = plan.getFragment().getId(); + reportedFragments.add(fragmentId); + StageInfo info = infos.get(fragmentId); + if (info == null) { + info = StageInfo.createInitial( + queryStateMachine.getQueryId(), + queryStateMachine.getQueryState().isDone() ? 
ABORTED : PLANNED, + plan.getFragment()); + } + List children = plan.getChildren().stream() + .map(child -> getStageInfo(child, infos, reportedFragments)) + .collect(toImmutableList()); + return info.withSubStages(children); + } + + public BasicStageStats getBasicStageStats() + { + List stageStats = stages.values().stream() + .map(SqlStage::getBasicStageStats) + .collect(toImmutableList()); + return aggregateBasicStageStats(stageStats); + } + + public long getUserMemoryReservation() + { + return stages.values().stream() + .mapToLong(SqlStage::getUserMemoryReservation) + .sum(); + } + + public long getTotalMemoryReservation() + { + return stages.values().stream() + .mapToLong(SqlStage::getTotalMemoryReservation) + .sum(); + } + + public Duration getTotalCpuTime() + { + long millis = stages.values().stream() + .mapToLong(stage -> stage.getTotalCpuTime().toMillis()) + .sum(); + return new Duration(millis, MILLISECONDS); + } + + public void failTaskRemotely(TaskId taskId, Throwable failureCause) + { + SqlStage sqlStage = requireNonNull(stages.get(taskId.getStageId()), () -> "stage not found: %s" + taskId.getStageId()); + sqlStage.failTaskRemotely(taskId, failureCause); + } + } + + private static class Scheduler + implements EventListener + { + private static final int EVENT_BUFFER_CAPACITY = 100; + + private final QueryStateMachine queryStateMachine; + private final Metadata metadata; + private final RemoteTaskFactory remoteTaskFactory; + private final TaskDescriptorStorage taskDescriptorStorage; + private final EventDrivenTaskSourceFactory taskSourceFactory; + private final boolean summarizeTaskInfo; + private final NodeTaskMap nodeTaskMap; + private final ExecutorService queryExecutor; + private final ScheduledExecutorService scheduledExecutorService; + private final SplitSchedulerStats schedulerStats; + private final PartitionMemoryEstimatorFactory memoryEstimatorFactory; + private final FaultTolerantPartitioningSchemeFactory partitioningSchemeFactory; + private 
final ExchangeManager exchangeManager; + private final int maxTaskExecutionAttempts; + private final int maxTasksWaitingForNode; + private final NodeAllocator nodeAllocator; + private final FailureDetector failureDetector; + private final StageRegistry stageRegistry; + private final TaskExecutionStats taskExecutionStats; + private final DynamicFilterService dynamicFilterService; + + private final BlockingQueue eventQueue = new LinkedBlockingQueue<>(); + private final List eventBuffer = new ArrayList<>(EVENT_BUFFER_CAPACITY); + + private boolean started; + + private SubPlan plan; + private List planInTopologicalOrder; + private final Map stageExecutions = new HashMap<>(); + private final SetMultimap stageConsumers = HashMultimap.create(); + + private final IndexedPriorityQueue schedulingQueue = new IndexedPriorityQueue<>(); + private int nextSchedulingPriority; + + private final Map nodeAcquisitions = new HashMap<>(); + + private final SchedulingDelayer schedulingDelayer; + + private boolean queryOutputSet; + + public Scheduler( + QueryStateMachine queryStateMachine, + Metadata metadata, + RemoteTaskFactory remoteTaskFactory, + TaskDescriptorStorage taskDescriptorStorage, + EventDrivenTaskSourceFactory taskSourceFactory, + boolean summarizeTaskInfo, + NodeTaskMap nodeTaskMap, + ExecutorService queryExecutor, + ScheduledExecutorService scheduledExecutorService, + SplitSchedulerStats schedulerStats, + PartitionMemoryEstimatorFactory memoryEstimatorFactory, + FaultTolerantPartitioningSchemeFactory partitioningSchemeFactory, + ExchangeManager exchangeManager, + int maxTaskExecutionAttempts, + int maxTasksWaitingForNode, + NodeAllocator nodeAllocator, + FailureDetector failureDetector, + StageRegistry stageRegistry, + TaskExecutionStats taskExecutionStats, + DynamicFilterService dynamicFilterService, + SchedulingDelayer schedulingDelayer, + SubPlan plan) + { + this.queryStateMachine = requireNonNull(queryStateMachine, "queryStateMachine is null"); + this.metadata = 
requireNonNull(metadata, "metadata is null"); + this.remoteTaskFactory = requireNonNull(remoteTaskFactory, "remoteTaskFactory is null"); + this.taskDescriptorStorage = requireNonNull(taskDescriptorStorage, "taskDescriptorStorage is null"); + this.taskSourceFactory = requireNonNull(taskSourceFactory, "taskSourceFactory is null"); + this.summarizeTaskInfo = summarizeTaskInfo; + this.nodeTaskMap = requireNonNull(nodeTaskMap, "nodeTaskMap is null"); + this.queryExecutor = requireNonNull(queryExecutor, "queryExecutor is null"); + this.scheduledExecutorService = requireNonNull(scheduledExecutorService, "scheduledExecutorService is null"); + this.schedulerStats = requireNonNull(schedulerStats, "schedulerStats is null"); + this.memoryEstimatorFactory = requireNonNull(memoryEstimatorFactory, "memoryEstimatorFactory is null"); + this.partitioningSchemeFactory = requireNonNull(partitioningSchemeFactory, "partitioningSchemeFactory is null"); + this.exchangeManager = requireNonNull(exchangeManager, "exchangeManager is null"); + checkArgument(maxTaskExecutionAttempts > 0, "maxTaskExecutionAttempts must be greater than zero: %s", maxTaskExecutionAttempts); + this.maxTaskExecutionAttempts = maxTaskExecutionAttempts; + this.maxTasksWaitingForNode = maxTasksWaitingForNode; + this.nodeAllocator = requireNonNull(nodeAllocator, "nodeAllocator is null"); + this.failureDetector = requireNonNull(failureDetector, "failureDetector is null"); + this.stageRegistry = requireNonNull(stageRegistry, "stageRegistry is null"); + this.taskExecutionStats = requireNonNull(taskExecutionStats, "taskExecutionStats is null"); + this.dynamicFilterService = requireNonNull(dynamicFilterService, "dynamicFilterService is null"); + this.schedulingDelayer = requireNonNull(schedulingDelayer, "schedulingDelayer is null"); + this.plan = requireNonNull(plan, "plan is null"); + + planInTopologicalOrder = sortPlanInTopologicalOrder(plan); + } + + public void run() + { + checkState(!started, "already started"); + 
started = true; + + queryStateMachine.addStateChangeListener(state -> { + if (state.isDone()) { + eventQueue.add(Event.WAKE_UP); + } + }); + + Optional failure = Optional.empty(); + try { + if (schedule()) { + while (processEvents()) { + if (schedulingDelayer.getRemainingDelayInMillis() > 0) { + continue; + } + if (!schedule()) { + break; + } + } + } + } + catch (Throwable t) { + failure = Optional.of(t); + } + + for (StageExecution execution : stageExecutions.values()) { + failure = closeAndAddSuppressed(failure, execution::abort); + } + for (NodeLease nodeLease : nodeAcquisitions.values()) { + failure = closeAndAddSuppressed(failure, nodeLease::release); + } + nodeAcquisitions.clear(); + failure = closeAndAddSuppressed(failure, nodeAllocator); + + failure.ifPresent(queryStateMachine::transitionToFailed); + } + + private Optional closeAndAddSuppressed(Optional existingFailure, Closeable closeable) + { + try { + closeable.close(); + } + catch (Throwable t) { + if (existingFailure.isEmpty()) { + return Optional.of(t); + } + if (existingFailure.get() != t) { + existingFailure.get().addSuppressed(t); + } + } + return existingFailure; + } + + private boolean processEvents() + { + try { + Event event = eventQueue.poll(1, MINUTES); + if (event == null) { + return true; + } + eventBuffer.add(event); + } + catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException(e); + } + + while (true) { + // poll multiple events from the queue in one shot to improve efficiency + eventQueue.drainTo(eventBuffer, EVENT_BUFFER_CAPACITY - eventBuffer.size()); + if (eventBuffer.isEmpty()) { + return true; + } + for (Event e : eventBuffer) { + if (e == Event.ABORT) { + return false; + } + if (e == Event.WAKE_UP) { + continue; + } + e.accept(this); + } + eventBuffer.clear(); + } + } + + private boolean schedule() + { + if (checkComplete()) { + return false; + } + optimize(); + updateStageExecutions(); + scheduleTasks(); + processNodeAcquisitions(); + 
return true; + } + + private boolean checkComplete() + { + if (queryStateMachine.isDone()) { + return true; + } + + for (StageExecution execution : stageExecutions.values()) { + if (execution.getState() == StageState.FAILED) { + StageInfo stageInfo = execution.getStageInfo(); + ExecutionFailureInfo failureCause = stageInfo.getFailureCause(); + RuntimeException failure = failureCause == null ? + new TrinoException(StandardErrorCode.GENERIC_INTERNAL_ERROR, "stage failed due to unknown error: %s".formatted(execution.getStageId())) : + failureCause.toException(); + queryStateMachine.transitionToFailed(failure); + return true; + } + } + setQueryOutputIfReady(); + return false; + } + + private void setQueryOutputIfReady() + { + StageId rootStageId = getStageId(plan.getFragment().getId()); + StageExecution rootStageExecution = stageExecutions.get(rootStageId); + if (!queryOutputSet && rootStageExecution != null && rootStageExecution.getState() == StageState.FINISHED) { + ListenableFuture> sourceHandles = getAllSourceHandles(rootStageExecution.getExchange().getSourceHandles()); + Futures.addCallback(sourceHandles, new FutureCallback<>() + { + @Override + public void onSuccess(List handles) + { + try { + queryStateMachine.updateInputsForQueryResults( + ImmutableList.of(new SpoolingExchangeInput(handles, Optional.of(rootStageExecution.getSinkOutputSelector()))), + true); + queryStateMachine.transitionToFinishing(); + } + catch (Throwable t) { + onFailure(t); + } + } + + @Override + public void onFailure(Throwable t) + { + queryStateMachine.transitionToFailed(t); + } + }, queryExecutor); + queryOutputSet = true; + } + } + + private void optimize() + { + plan = optimizePlan(plan); + planInTopologicalOrder = sortPlanInTopologicalOrder(plan); + stageRegistry.updatePlan(plan); + } + + private SubPlan optimizePlan(SubPlan plan) + { + // Re-optimize plan here based on available runtime statistics. 
+ // Fragments changed due to re-optimization as well as their downstream stages are expected to be assigned new fragment ids. + return plan; + } + + private void updateStageExecutions() + { + Set currentPlanStages = new HashSet<>(); + PlanFragmentId rootFragmentId = plan.getFragment().getId(); + for (SubPlan subPlan : planInTopologicalOrder) { + PlanFragmentId fragmentId = subPlan.getFragment().getId(); + StageId stageId = getStageId(fragmentId); + currentPlanStages.add(stageId); + if (isReadyForExecution(subPlan) && !stageExecutions.containsKey(stageId)) { + createStageExecution(subPlan, fragmentId.equals(rootFragmentId), nextSchedulingPriority++); + } + } + stageExecutions.forEach((stageId, stageExecution) -> { + if (!currentPlanStages.contains(stageId)) { + // stage got re-written during re-optimization + stageExecution.abort(); + } + }); + } + + private boolean isReadyForExecution(SubPlan subPlan) + { + for (SubPlan child : subPlan.getChildren()) { + StageExecution childExecution = stageExecutions.get(getStageId(child.getFragment().getId())); + if (childExecution == null) { + return false; + } + // TODO enable speculative execution + if (childExecution.getState() != StageState.FINISHED) { + return false; + } + } + return true; + } + + private void createStageExecution(SubPlan subPlan, boolean rootFragment, int schedulingPriority) + { + Closer closer = Closer.create(); + + try { + PlanFragment fragment = subPlan.getFragment(); + Session session = queryStateMachine.getSession(); + + StageId stageId = getStageId(fragment.getId()); + SqlStage stage = SqlStage.createSqlStage( + stageId, + fragment, + TableInfo.extract(session, metadata, fragment), + remoteTaskFactory, + session, + summarizeTaskInfo, + nodeTaskMap, + queryExecutor, + schedulerStats); + closer.register(stage::abort); + stageRegistry.add(stage); + stage.addFinalStageInfoListener(status -> queryStateMachine.updateQueryInfo(Optional.ofNullable(stageRegistry.getStageInfo()))); + + ImmutableMap.Builder 
sourceExchanges = ImmutableMap.builder(); + Map outputEstimates = new HashMap<>(); + for (SubPlan child : subPlan.getChildren()) { + PlanFragmentId childFragmentId = child.getFragment().getId(); + StageExecution childExecution = getStageExecution(getStageId(childFragmentId)); + sourceExchanges.put(childFragmentId, childExecution.getExchange()); + outputEstimates.put(childFragmentId, childExecution.getOutputDataSize()); + stageConsumers.put(childExecution.getStageId(), stageId); + } + + ImmutableMap.Builder outputDataSizeEstimates = ImmutableMap.builder(); + for (RemoteSourceNode remoteSource : stage.getFragment().getRemoteSourceNodes()) { + List estimates = new ArrayList<>(); + for (PlanFragmentId fragmentId : remoteSource.getSourceFragmentIds()) { + OutputDataSizeEstimate fragmentEstimate = outputEstimates.get(fragmentId); + verify(fragmentEstimate != null, "fragmentEstimate not found for fragment %s", fragmentId); + estimates.add(fragmentEstimate); + } + // merge estimates for all source fragments of a single remote source + outputDataSizeEstimates.put(remoteSource.getId(), OutputDataSizeEstimate.merge(estimates)); + } + + EventDrivenTaskSource taskSource = closer.register(taskSourceFactory.create( + createTaskSourceCallback(stageId), + session, + fragment, + sourceExchanges.buildOrThrow(), + partitioningSchemeFactory.get(fragment.getPartitioning()), + stage::recordGetSplitTime, + outputDataSizeEstimates.buildOrThrow())); + taskSource.start(); + + FaultTolerantPartitioningScheme sinkPartitioningScheme = partitioningSchemeFactory.get(fragment.getPartitioningScheme().getPartitioning().getHandle()); + ExchangeContext exchangeContext = new ExchangeContext(queryStateMachine.getQueryId(), new ExchangeId("external-exchange-" + stage.getStageId().getId())); + Exchange exchange = closer.register(exchangeManager.createExchange( + exchangeContext, + sinkPartitioningScheme.getPartitionCount(), + rootFragment)); + + boolean coordinatorStage = 
stage.getFragment().getPartitioning().equals(COORDINATOR_DISTRIBUTION); + + StageExecution execution = new StageExecution( + queryStateMachine, + taskDescriptorStorage, + stage, + taskSource, + sinkPartitioningScheme, + exchange, + memoryEstimatorFactory.createPartitionMemoryEstimator(), + // do not retry coordinator only tasks + coordinatorStage ? 1 : maxTaskExecutionAttempts, + schedulingPriority, + dynamicFilterService); + + stageExecutions.put(execution.getStageId(), execution); + + for (SubPlan child : subPlan.getChildren()) { + PlanFragmentId childFragmentId = child.getFragment().getId(); + StageExecution childExecution = getStageExecution(getStageId(childFragmentId)); + execution.setSourceOutputSelector(childFragmentId, childExecution.getSinkOutputSelector()); + } + } + catch (Throwable t) { + try { + closer.close(); + } + catch (Throwable closerFailure) { + if (closerFailure != t) { + t.addSuppressed(closerFailure); + } + } + throw t; + } + } + + private StageId getStageId(PlanFragmentId fragmentId) + { + return StageId.create(queryStateMachine.getQueryId(), fragmentId); + } + + private EventDrivenTaskSource.Callback createTaskSourceCallback(StageId stageId) + { + return new EventDrivenTaskSource.Callback() + { + @Override + public void partitionsAdded(List partitions) + { + eventQueue.add(new PartitionsAddedEvent(stageId, partitions)); + } + + @Override + public void noMorePartitions() + { + eventQueue.add(new NoMorePartitionsEvent(stageId)); + } + + @Override + public void partitionsUpdated(List partitionUpdates) + { + eventQueue.add(new PartitionsUpdatedEvent(stageId, partitionUpdates)); + } + + @Override + public void partitionsSealed(ImmutableIntArray partitionIds) + { + eventQueue.add(new PartitionsSealedEvent(stageId, partitionIds)); + } + + @Override + public void failed(Throwable t) + { + eventQueue.add(new TaskSourceFailureEvent(stageId, t)); + } + }; + } + + private void scheduleTasks() + { + while (nodeAcquisitions.size() < 
maxTasksWaitingForNode && !schedulingQueue.isEmpty()) { + ScheduledTask scheduledTask = schedulingQueue.poll(); + verify(scheduledTask != null, "scheduledTask is null"); + StageExecution stageExecution = getStageExecution(scheduledTask.stageId()); + if (stageExecution.getState().isDone()) { + continue; + } + int partitionId = scheduledTask.partitionId(); + Optional nodeRequirements = stageExecution.getNodeRequirements(partitionId); + if (nodeRequirements.isEmpty()) { + // execution finished + continue; + } + MemoryRequirements memoryRequirements = stageExecution.getMemoryRequirements(partitionId); + NodeLease lease = nodeAllocator.acquire(nodeRequirements.get(), memoryRequirements.getRequiredMemory()); + lease.getNode().addListener(() -> eventQueue.add(Event.WAKE_UP), queryExecutor); + nodeAcquisitions.put(scheduledTask, lease); + } + } + + private void processNodeAcquisitions() + { + Iterator> nodeAcquisitionIterator = nodeAcquisitions.entrySet().iterator(); + while (nodeAcquisitionIterator.hasNext()) { + Map.Entry nodeAcquisition = nodeAcquisitionIterator.next(); + ScheduledTask scheduledTask = nodeAcquisition.getKey(); + NodeLease nodeLease = nodeAcquisition.getValue(); + StageExecution stageExecution = getStageExecution(scheduledTask.stageId()); + if (stageExecution.getState().isDone()) { + nodeAcquisitionIterator.remove(); + nodeLease.release(); + } + else if (nodeLease.getNode().isDone()) { + nodeAcquisitionIterator.remove(); + try { + InternalNode node = getDone(nodeLease.getNode()); + Optional remoteTask = stageExecution.schedule(scheduledTask.partitionId(), node); + remoteTask.ifPresent(task -> { + task.addStateChangeListener(createExchangeSinkInstanceHandleUpdateRequiredListener()); + task.addStateChangeListener(taskStatus -> { + if (taskStatus.getState().isDone()) { + nodeLease.release(); + } + }); + task.addFinalTaskInfoListener(taskExecutionStats::update); + task.addFinalTaskInfoListener(taskInfo -> eventQueue.add(new 
RemoteTaskCompletedEvent(taskInfo.getTaskStatus()))); + nodeLease.attachTaskId(task.getTaskId()); + task.start(); + if (queryStateMachine.getQueryState() == QueryState.STARTING) { + queryStateMachine.transitionToRunning(); + } + }); + if (remoteTask.isEmpty()) { + nodeLease.release(); + } + } + catch (ExecutionException e) { + throw new UncheckedExecutionException(e); + } + } + } + } + + private StateChangeListener createExchangeSinkInstanceHandleUpdateRequiredListener() + { + AtomicLong respondedToVersion = new AtomicLong(-1); + return taskStatus -> { + OutputBufferStatus outputBufferStatus = taskStatus.getOutputBufferStatus(); + if (outputBufferStatus.getOutputBuffersVersion().isEmpty()) { + return; + } + if (!outputBufferStatus.isExchangeSinkInstanceHandleUpdateRequired()) { + return; + } + long remoteVersion = outputBufferStatus.getOutputBuffersVersion().getAsLong(); + while (true) { + long localVersion = respondedToVersion.get(); + if (remoteVersion <= localVersion) { + // version update is scheduled or sent already but got not propagated yet + break; + } + if (respondedToVersion.compareAndSet(localVersion, remoteVersion)) { + eventQueue.add(new RemoteTaskExchangeSinkUpdateRequiredEvent(taskStatus)); + break; + } + } + }; + } + + public void abort() + { + eventQueue.clear(); + eventQueue.add(Event.ABORT); + } + + @Override + public void onRemoteTaskCompleted(RemoteTaskCompletedEvent event) + { + TaskStatus taskStatus = event.getTaskStatus(); + TaskId taskId = taskStatus.getTaskId(); + TaskState taskState = taskStatus.getState(); + StageExecution stageExecution = getStageExecution(taskId.getStageId()); + if (taskState == TaskState.FINISHED) { + stageExecution.taskFinished(taskId, taskStatus); + } + else if (taskState == TaskState.FAILED) { + ExecutionFailureInfo failureInfo = taskStatus.getFailures().stream() + .findFirst() + .map(this::rewriteTransportFailure) + .orElse(toFailure(new TrinoException(GENERIC_INTERNAL_ERROR, "A task failed for an unknown 
reason"))); + + List replacementTasks = stageExecution.taskFailed(taskId, failureInfo, taskStatus); + replacementTasks.forEach(task -> schedulingQueue.addOrUpdate(task, task.priority())); + + if (shouldDelayScheduling(failureInfo.getErrorCode())) { + schedulingDelayer.startOrProlongDelayIfNecessary(); + scheduledExecutorService.schedule(() -> eventQueue.add(Event.WAKE_UP), schedulingDelayer.getRemainingDelayInMillis(), MILLISECONDS); + } + } + + // update output selectors + ExchangeSourceOutputSelector outputSelector = stageExecution.getSinkOutputSelector(); + for (StageId consumerStageId : stageConsumers.get(stageExecution.getStageId())) { + getStageExecution(consumerStageId).setSourceOutputSelector(stageExecution.getStageFragmentId(), outputSelector); + } + } + + @Override + public void onRemoteTaskExchangeSinkUpdateRequired(RemoteTaskExchangeSinkUpdateRequiredEvent event) + { + TaskId taskId = event.getTaskStatus().getTaskId(); + StageExecution stageExecution = getStageExecution(taskId.getStageId()); + stageExecution.updateExchangeSinkInstanceHandle(taskId); + } + + @Override + public void onPartitionsAdded(PartitionsAddedEvent event) + { + StageId stageId = event.getStageId(); + StageExecution stageExecution = getStageExecution(stageId); + for (Partition partition : event.getPartitions()) { + Optional scheduledTask = stageExecution.addPartition(partition.partitionId(), partition.nodeRequirements()); + scheduledTask.ifPresent(task -> schedulingQueue.addOrUpdate(task, task.priority())); + } + } + + @Override + public void onPartitionsUpdated(PartitionsUpdatedEvent event) + { + StageExecution stageExecution = getStageExecution(event.getStageId()); + for (PartitionUpdate partitionUpdate : event.getPartitionUpdates()) { + stageExecution.updatePartition( + partitionUpdate.partitionId(), + partitionUpdate.planNodeId(), + partitionUpdate.splits(), + partitionUpdate.noMoreSplits()); + } + } + + @Override + public void onPartitionsSealed(PartitionsSealedEvent event) + { 
+ StageId stageId = event.getStageId(); + StageExecution stageExecution = getStageExecution(stageId); + event.getPartitionIds().forEach(partitionId -> { + Optional scheduledTask = stageExecution.sealPartition(partitionId); + scheduledTask.ifPresent(task -> { + if (nodeAcquisitions.containsKey(task)) { + // task is already waiting for node + return; + } + schedulingQueue.addOrUpdate(task, task.priority()); + }); + }); + } + + @Override + public void onNoMorePartitions(NoMorePartitionsEvent event) + { + StageExecution stageExecution = getStageExecution(event.getStageId()); + stageExecution.noMorePartitions(); + } + + @Override + public void onTaskSourceFailure(TaskSourceFailureEvent event) + { + StageExecution stageExecution = getStageExecution(event.getStageId()); + stageExecution.fail(event.getFailure()); + } + + private StageExecution getStageExecution(StageId stageId) + { + StageExecution execution = stageExecutions.get(stageId); + checkState(execution != null, "stage execution does not exist for stage: %s", stageId); + return execution; + } + + private static List sortPlanInTopologicalOrder(SubPlan subPlan) + { + ImmutableList.Builder result = ImmutableList.builder(); + Traverser.forTree(SubPlan::getChildren).depthFirstPreOrder(subPlan).forEach(result::add); + return result.build(); + } + + private boolean shouldDelayScheduling(@Nullable ErrorCode errorCode) + { + return errorCode == null || errorCode.getType() == INTERNAL_ERROR || errorCode.getType() == EXTERNAL; + } + + private ExecutionFailureInfo rewriteTransportFailure(ExecutionFailureInfo executionFailureInfo) + { + if (executionFailureInfo.getRemoteHost() == null || failureDetector.getState(executionFailureInfo.getRemoteHost()) != GONE) { + return executionFailureInfo; + } + + return new ExecutionFailureInfo( + executionFailureInfo.getType(), + executionFailureInfo.getMessage(), + executionFailureInfo.getCause(), + executionFailureInfo.getSuppressed(), + executionFailureInfo.getStack(), + 
executionFailureInfo.getErrorLocation(), + REMOTE_HOST_GONE.toErrorCode(), + executionFailureInfo.getRemoteHost()); + } + } + + private static class StageExecution + { + private static final int SPECULATIVE_EXECUTION_PRIORITY = 1_000_000_000; + + private final QueryStateMachine queryStateMachine; + private final TaskDescriptorStorage taskDescriptorStorage; + + private final SqlStage stage; + private final EventDrivenTaskSource taskSource; + private final FaultTolerantPartitioningScheme sinkPartitioningScheme; + private final Exchange exchange; + private final PartitionMemoryEstimator partitionMemoryEstimator; + private final int maxTaskExecutionAttempts; + private final int schedulingPriority; + private final DynamicFilterService dynamicFilterService; + private final long[] outputDataSize; + + private final Int2ObjectMap partitions = new Int2ObjectOpenHashMap<>(); + private boolean noMorePartitions; + + private final IntSet remainingPartitions = new IntOpenHashSet(); + + private ExchangeSourceOutputSelector.Builder sinkOutputSelectorBuilder; + private ExchangeSourceOutputSelector finalSinkOutputSelector; + + private final Set remoteSourceIds; + private final Map remoteSources; + private final Map sourceOutputSelectors = new HashMap<>(); + + private StageExecution( + QueryStateMachine queryStateMachine, + TaskDescriptorStorage taskDescriptorStorage, + SqlStage stage, + EventDrivenTaskSource taskSource, + FaultTolerantPartitioningScheme sinkPartitioningScheme, + Exchange exchange, + PartitionMemoryEstimator partitionMemoryEstimator, + int maxTaskExecutionAttempts, + int schedulingPriority, + DynamicFilterService dynamicFilterService) + { + this.queryStateMachine = requireNonNull(queryStateMachine, "queryStateMachine is null"); + this.taskDescriptorStorage = requireNonNull(taskDescriptorStorage, "taskDescriptorStorage is null"); + this.stage = requireNonNull(stage, "stage is null"); + this.taskSource = requireNonNull(taskSource, "taskSource is null"); + 
this.sinkPartitioningScheme = requireNonNull(sinkPartitioningScheme, "sinkPartitioningScheme is null"); + this.exchange = requireNonNull(exchange, "exchange is null"); + this.partitionMemoryEstimator = requireNonNull(partitionMemoryEstimator, "partitionMemoryEstimator is null"); + this.maxTaskExecutionAttempts = maxTaskExecutionAttempts; + this.schedulingPriority = schedulingPriority; + this.dynamicFilterService = requireNonNull(dynamicFilterService, "dynamicFilterService is null"); + outputDataSize = new long[sinkPartitioningScheme.getPartitionCount()]; + sinkOutputSelectorBuilder = ExchangeSourceOutputSelector.builder(ImmutableSet.of(exchange.getId())); + ImmutableMap.Builder remoteSources = ImmutableMap.builder(); + ImmutableSet.Builder remoteSourceIds = ImmutableSet.builder(); + for (RemoteSourceNode remoteSource : stage.getFragment().getRemoteSourceNodes()) { + remoteSourceIds.add(remoteSource.getId()); + remoteSource.getSourceFragmentIds().forEach(fragmentId -> remoteSources.put(fragmentId, remoteSource)); + } + this.remoteSourceIds = remoteSourceIds.build(); + this.remoteSources = remoteSources.buildOrThrow(); + } + + public StageId getStageId() + { + return stage.getStageId(); + } + + public PlanFragmentId getStageFragmentId() + { + return stage.getFragment().getId(); + } + + public StageState getState() + { + return stage.getState(); + } + + public StageInfo getStageInfo() + { + return stage.getStageInfo(); + } + + public Exchange getExchange() + { + return exchange; + } + + public Optional addPartition(int partitionId, NodeRequirements nodeRequirements) + { + if (getState().isDone()) { + return Optional.empty(); + } + + ExchangeSinkHandle exchangeSinkHandle = exchange.addSink(partitionId); + Session session = queryStateMachine.getSession(); + DataSize defaultTaskMemory = stage.getFragment().getPartitioning().equals(COORDINATOR_DISTRIBUTION) ? 
+ getFaultTolerantExecutionDefaultCoordinatorTaskMemory(session) : + getFaultTolerantExecutionDefaultTaskMemory(session); + StagePartition partition = new StagePartition( + taskDescriptorStorage, + stage.getStageId(), + partitionId, + exchangeSinkHandle, + remoteSourceIds, + nodeRequirements, + partitionMemoryEstimator.getInitialMemoryRequirements(session, defaultTaskMemory), + maxTaskExecutionAttempts); + checkState(partitions.putIfAbsent(partitionId, partition) == null, "partition with id %s already exist in stage %s", partitionId, stage.getStageId()); + getSourceOutputSelectors().forEach((partition::updateExchangeSourceOutputSelector)); + remainingPartitions.add(partitionId); + + return Optional.of(new ScheduledTask(stage.getStageId(), partitionId, SPECULATIVE_EXECUTION_PRIORITY + schedulingPriority)); + } + + public void updatePartition(int partitionId, PlanNodeId planNodeId, List splits, boolean noMoreSplits) + { + if (getState().isDone()) { + return; + } + + StagePartition partition = getStagePartition(partitionId); + partition.addSplits(planNodeId, splits, noMoreSplits); + } + + public Optional sealPartition(int partitionId) + { + if (getState().isDone()) { + return Optional.empty(); + } + + StagePartition partition = getStagePartition(partitionId); + partition.seal(partitionId); + + if (!partition.isRunning()) { + // if partition is not yet running update its priority as it is no longer speculative + return Optional.of(new ScheduledTask(stage.getStageId(), partitionId, schedulingPriority)); + } + + // TODO: split into smaller partitions here if necessary (for example if a task for a given partition failed with out of memory) + + return Optional.empty(); + } + + public void noMorePartitions() + { + if (getState().isDone()) { + return; + } + + noMorePartitions = true; + if (remainingPartitions.isEmpty()) { + stage.finish(); + // TODO close exchange early + taskSource.close(); + } + } + + public Optional schedule(int partitionId, InternalNode node) + { + if 
(getState().isDone()) { + return Optional.empty(); + } + + StagePartition partition = getStagePartition(partitionId); + verify(partition.getRemainingAttempts() >= 0, "remaining attempts is expected to be greater than or equal to zero: %s", partition.getRemainingAttempts()); + + if (partition.isFinished()) { + return Optional.empty(); + } + + Map outputSelectors = getSourceOutputSelectors(); + + ListMultimap splits = ArrayListMultimap.create(); + splits.putAll(partition.getSplits()); + outputSelectors.forEach((planNodeId, outputSelector) -> splits.put(planNodeId, createOutputSelectorSplit(outputSelector))); + + Set noMoreSplits = new HashSet<>(); + for (RemoteSourceNode remoteSource : stage.getFragment().getRemoteSourceNodes()) { + ExchangeSourceOutputSelector selector = outputSelectors.get(remoteSource.getId()); + if (selector != null && selector.isFinal() && partition.isNoMoreSplits(remoteSource.getId())) { + noMoreSplits.add(remoteSource.getId()); + } + } + for (PlanNodeId partitionedSource : stage.getFragment().getPartitionedSources()) { + if (partition.isNoMoreSplits(partitionedSource)) { + noMoreSplits.add(partitionedSource); + } + } + + int attempt = maxTaskExecutionAttempts - partition.getRemainingAttempts(); + ExchangeSinkInstanceHandle exchangeSinkInstanceHandle = exchange.instantiateSink(partition.getExchangeSinkHandle(), attempt); + SpoolingOutputBuffers outputBuffers = SpoolingOutputBuffers.createInitial(exchangeSinkInstanceHandle, sinkPartitioningScheme.getPartitionCount()); + Optional task = stage.createTask( + node, + partitionId, + attempt, + sinkPartitioningScheme.getBucketToPartitionMap(), + outputBuffers, + splits, + noMoreSplits, + Optional.of(partition.getMemoryRequirements().getRequiredMemory())); + task.ifPresent(remoteTask -> partition.addTask(remoteTask, outputBuffers)); + return task; + } + + private Map getSourceOutputSelectors() + { + ImmutableMap.Builder result = ImmutableMap.builder(); + for (RemoteSourceNode remoteSource : 
stage.getFragment().getRemoteSourceNodes()) { + ExchangeSourceOutputSelector mergedSelector = null; + for (PlanFragmentId sourceFragmentId : remoteSource.getSourceFragmentIds()) { + ExchangeSourceOutputSelector sourceFragmentSelector = sourceOutputSelectors.get(sourceFragmentId); + if (sourceFragmentSelector == null) { + continue; + } + if (mergedSelector == null) { + mergedSelector = sourceFragmentSelector; + } + else { + mergedSelector = mergedSelector.merge(sourceFragmentSelector); + } + } + if (mergedSelector != null) { + result.put(remoteSource.getId(), mergedSelector); + } + } + return result.buildOrThrow(); + } + + public void updateExchangeSinkInstanceHandle(TaskId taskId) + { + if (getState().isDone()) { + return; + } + StagePartition partition = getStagePartition(taskId.getPartitionId()); + ExchangeSinkInstanceHandle exchangeSinkInstanceHandle = exchange.updateSinkInstanceHandle(partition.getExchangeSinkHandle(), taskId.getAttemptId()); + partition.updateExchangeSinkInstanceHandle(taskId, exchangeSinkInstanceHandle); + } + + public void taskFinished(TaskId taskId, TaskStatus taskStatus) + { + if (getState().isDone()) { + return; + } + + int partitionId = taskId.getPartitionId(); + StagePartition partition = getStagePartition(partitionId); + exchange.sinkFinished(partition.getExchangeSinkHandle(), taskId.getAttemptId()); + SpoolingOutputStats.Snapshot outputStats = partition.taskFinished(taskId); + + if (!remainingPartitions.remove(partitionId)) { + // a different task for the same partition finished before + return; + } + + updateOutputSize(outputStats); + + partitionMemoryEstimator.registerPartitionFinished( + queryStateMachine.getSession(), + partition.getMemoryRequirements(), + taskStatus.getPeakMemoryReservation(), + true, + Optional.empty()); + + sinkOutputSelectorBuilder.include(exchange.getId(), taskId.getPartitionId(), taskId.getAttemptId()); + + if (noMorePartitions && remainingPartitions.isEmpty() && !stage.getState().isDone()) { + 
dynamicFilterService.stageCannotScheduleMoreTasks(stage.getStageId(), 0, partitions.size()); + exchange.noMoreSinks(); + exchange.allRequiredSinksFinished(); + verify(finalSinkOutputSelector == null, "finalOutputSelector is already set"); + sinkOutputSelectorBuilder.setPartitionCount(exchange.getId(), partitions.size()); + sinkOutputSelectorBuilder.setFinal(); + finalSinkOutputSelector = sinkOutputSelectorBuilder.build(); + sinkOutputSelectorBuilder = null; + stage.finish(); + } + } + + private void updateOutputSize(SpoolingOutputStats.Snapshot taskOutputStats) + { + for (int partitionId = 0; partitionId < sinkPartitioningScheme.getPartitionCount(); partitionId++) { + long partitionSizeInBytes = taskOutputStats.getPartitionSizeInBytes(partitionId); + checkArgument(partitionSizeInBytes >= 0, "partitionSizeInBytes must be greater than or equal to zero: %s", partitionSizeInBytes); + outputDataSize[partitionId] += partitionSizeInBytes; + } + } + + public List taskFailed(TaskId taskId, ExecutionFailureInfo failureInfo, TaskStatus taskStatus) + { + if (getState().isDone()) { + return ImmutableList.of(); + } + + int partitionId = taskId.getPartitionId(); + StagePartition partition = getStagePartition(partitionId); + partition.taskFailed(taskId); + + RuntimeException failure = failureInfo.toException(); + ErrorCode errorCode = failureInfo.getErrorCode(); + partitionMemoryEstimator.registerPartitionFinished( + queryStateMachine.getSession(), + partition.getMemoryRequirements(), + taskStatus.getPeakMemoryReservation(), + false, + Optional.ofNullable(errorCode)); + + // update memory limits for next attempt + MemoryRequirements currentMemoryLimits = partition.getMemoryRequirements(); + MemoryRequirements newMemoryLimits = partitionMemoryEstimator.getNextRetryMemoryRequirements( + queryStateMachine.getSession(), + partition.getMemoryRequirements(), + taskStatus.getPeakMemoryReservation(), + errorCode); + partition.setMemoryRequirements(newMemoryLimits); + log.debug( + 
"Computed next memory requirements for task from stage %s; previous=%s; new=%s; peak=%s; estimator=%s", + stage.getStageId(), + currentMemoryLimits, + newMemoryLimits, + taskStatus.getPeakMemoryReservation(), + partitionMemoryEstimator); + + if (errorCode != null && isOutOfMemoryError(errorCode) && newMemoryLimits.getRequiredMemory().toBytes() * 0.99 <= taskStatus.getPeakMemoryReservation().toBytes()) { + String message = format( + "Cannot allocate enough memory for task %s. Reported peak memory reservation: %s. Maximum possible reservation: %s.", + taskId, + taskStatus.getPeakMemoryReservation(), + newMemoryLimits.getRequiredMemory()); + stage.fail(new TrinoException(() -> errorCode, message, failure)); + return ImmutableList.of(); + } + + if (partition.getRemainingAttempts() == 0 || (errorCode != null && errorCode.getType() == USER_ERROR)) { + stage.fail(failure); + // stage failed, don't reschedule + return ImmutableList.of(); + } + + if (!partition.isSealed()) { + // don't reschedule speculative tasks + return ImmutableList.of(); + } + + // TODO: split into smaller partitions here if necessary (for example if a task for a given partition failed with out of memory) + + // reschedule a task + return ImmutableList.of(new ScheduledTask(stage.getStageId(), partitionId, schedulingPriority)); + } + + public MemoryRequirements getMemoryRequirements(int partitionId) + { + return getStagePartition(partitionId).getMemoryRequirements(); + } + + public Optional getNodeRequirements(int partitionId) + { + return getStagePartition(partitionId).getNodeRequirements(); + } + + public OutputDataSizeEstimate getOutputDataSize() + { + // TODO enable speculative execution + checkState(stage.getState() == StageState.FINISHED, "stage %s is expected to be in FINISHED state, got %s", stage.getStageId(), stage.getState()); + return new OutputDataSizeEstimate(ImmutableLongArray.copyOf(outputDataSize)); + } + + public ExchangeSourceOutputSelector getSinkOutputSelector() + { + if 
(finalSinkOutputSelector != null) { + return finalSinkOutputSelector; + } + return sinkOutputSelectorBuilder.build(); + } + + public void setSourceOutputSelector(PlanFragmentId sourceFragmentId, ExchangeSourceOutputSelector selector) + { + sourceOutputSelectors.put(sourceFragmentId, selector); + RemoteSourceNode remoteSourceNode = remoteSources.get(sourceFragmentId); + verify(remoteSourceNode != null, "remoteSourceNode is null for fragment: %s", sourceFragmentId); + ExchangeSourceOutputSelector mergedSelector = selector; + for (PlanFragmentId fragmentId : remoteSourceNode.getSourceFragmentIds()) { + if (fragmentId.equals(sourceFragmentId)) { + continue; + } + ExchangeSourceOutputSelector fragmentSelector = sourceOutputSelectors.get(fragmentId); + if (fragmentSelector != null) { + mergedSelector = mergedSelector.merge(fragmentSelector); + } + } + ExchangeSourceOutputSelector finalMergedSelector = mergedSelector; + remainingPartitions.forEach((java.util.function.IntConsumer) value -> { + StagePartition partition = partitions.get(value); + verify(partition != null, "partition not found: %s", value); + partition.updateExchangeSourceOutputSelector(remoteSourceNode.getId(), finalMergedSelector); + }); + } + + public void abort() + { + Closer closer = createStageExecutionCloser(); + closer.register(stage::abort); + try { + closer.close(); + } + catch (IOException e) { + throw new UncheckedIOException(e); + } + } + + public void fail(Throwable t) + { + Closer closer = createStageExecutionCloser(); + closer.register(() -> stage.fail(t)); + try { + closer.close(); + } + catch (IOException e) { + throw new UncheckedIOException(e); + } + } + + private Closer createStageExecutionCloser() + { + Closer closer = Closer.create(); + closer.register(taskSource); + closer.register(exchange); + return closer; + } + + private StagePartition getStagePartition(int partitionId) + { + StagePartition partition = partitions.get(partitionId); + checkState(partition != null, "partition with id 
%s does not exist in stage %s", partitionId, stage.getStageId()); + return partition; + } + } + + private static class StagePartition + { + private final TaskDescriptorStorage taskDescriptorStorage; + private final StageId stageId; + private final int partitionId; + private final ExchangeSinkHandle exchangeSinkHandle; + private final Set remoteSourceIds; + + // empty when task descriptor is closed and stored in TaskDescriptorStorage + private Optional openTaskDescriptor; + private MemoryRequirements memoryRequirements; + private int remainingAttempts; + + private final Map tasks = new HashMap<>(); + private final Map taskOutputBuffers = new HashMap<>(); + private final Set runningTasks = new HashSet<>(); + private final Set finalSelectors = new HashSet<>(); + private final Set noMoreSplits = new HashSet<>(); + private boolean finished; + + public StagePartition( + TaskDescriptorStorage taskDescriptorStorage, + StageId stageId, + int partitionId, + ExchangeSinkHandle exchangeSinkHandle, + Set remoteSourceIds, + NodeRequirements nodeRequirements, + MemoryRequirements memoryRequirements, + int maxTaskExecutionAttempts) + { + this.taskDescriptorStorage = requireNonNull(taskDescriptorStorage, "taskDescriptorStorage is null"); + this.stageId = requireNonNull(stageId, "stageId is null"); + this.partitionId = partitionId; + this.exchangeSinkHandle = requireNonNull(exchangeSinkHandle, "exchangeSinkHandle is null"); + this.remoteSourceIds = ImmutableSet.copyOf(requireNonNull(remoteSourceIds, "remoteSourceIds is null")); + requireNonNull(nodeRequirements, "nodeRequirements is null"); + this.openTaskDescriptor = Optional.of(new OpenTaskDescriptor(ImmutableListMultimap.of(), ImmutableSet.of(), nodeRequirements)); + this.memoryRequirements = requireNonNull(memoryRequirements, "memoryRequirements is null"); + this.remainingAttempts = maxTaskExecutionAttempts; + } + + public int getPartitionId() + { + return partitionId; + } + + public ExchangeSinkHandle getExchangeSinkHandle() + 
{ + return exchangeSinkHandle; + } + + public void addSplits(PlanNodeId planNodeId, List splits, boolean noMoreSplits) + { + checkState(openTaskDescriptor.isPresent(), "openTaskDescriptor is empty"); + openTaskDescriptor = Optional.of(openTaskDescriptor.get().update(planNodeId, splits, noMoreSplits)); + if (noMoreSplits) { + this.noMoreSplits.add(planNodeId); + } + for (RemoteTask task : tasks.values()) { + task.addSplits(ImmutableListMultimap.builder() + .putAll(planNodeId, splits) + .build()); + if (noMoreSplits && isFinalOutputSelectorDelivered(planNodeId)) { + task.noMoreSplits(planNodeId); + } + } + } + + private boolean isFinalOutputSelectorDelivered(PlanNodeId planNodeId) + { + if (!remoteSourceIds.contains(planNodeId)) { + // not a remote source; input selector concept not applicable + return true; + } + return finalSelectors.contains(planNodeId); + } + + public void seal(int partitionId) + { + checkState(openTaskDescriptor.isPresent(), "openTaskDescriptor is empty"); + TaskDescriptor taskDescriptor = openTaskDescriptor.get().createTaskDescriptor(partitionId); + openTaskDescriptor = Optional.empty(); + // a task may finish before task descriptor is sealed + if (!finished) { + taskDescriptorStorage.put(stageId, taskDescriptor); + } + } + + public ListMultimap getSplits() + { + if (finished) { + return ImmutableListMultimap.of(); + } + return openTaskDescriptor.map(OpenTaskDescriptor::getSplits) + .or(() -> taskDescriptorStorage.get(stageId, partitionId).map(TaskDescriptor::getSplits)) + // execution is finished + .orElse(ImmutableListMultimap.of()); + } + + public boolean isNoMoreSplits(PlanNodeId planNodeId) + { + if (finished) { + return true; + } + return openTaskDescriptor.map(taskDescriptor -> taskDescriptor.getNoMoreSplits().contains(planNodeId)) + // task descriptor is sealed, no more splits are expected + .orElse(true); + } + + public boolean isSealed() + { + return openTaskDescriptor.isEmpty(); + } + + /** + * Returns {@link Optional#empty()} when 
execution is finished + */ + public Optional getNodeRequirements() + { + if (finished) { + return Optional.empty(); + } + if (openTaskDescriptor.isPresent()) { + return openTaskDescriptor.map(OpenTaskDescriptor::getNodeRequirements); + } + Optional taskDescriptor = taskDescriptorStorage.get(stageId, partitionId); + if (taskDescriptor.isPresent()) { + return taskDescriptor.map(TaskDescriptor::getNodeRequirements); + } + return Optional.empty(); + } + + public MemoryRequirements getMemoryRequirements() + { + return memoryRequirements; + } + + public void setMemoryRequirements(MemoryRequirements memoryRequirements) + { + this.memoryRequirements = requireNonNull(memoryRequirements, "memoryRequirements is null"); + } + + public int getRemainingAttempts() + { + return remainingAttempts; + } + + public void addTask(RemoteTask remoteTask, SpoolingOutputBuffers outputBuffers) + { + TaskId taskId = remoteTask.getTaskId(); + tasks.put(taskId, remoteTask); + taskOutputBuffers.put(taskId, outputBuffers); + runningTasks.add(taskId); + } + + public SpoolingOutputStats.Snapshot taskFinished(TaskId taskId) + { + RemoteTask remoteTask = tasks.get(taskId); + checkArgument(remoteTask != null, "task not found: %s", taskId); + SpoolingOutputStats.Snapshot outputStats = remoteTask.retrieveAndDropSpoolingOutputStats(); + runningTasks.remove(taskId); + tasks.values().forEach(RemoteTask::abort); + finished = true; + // task descriptor has been created + if (isSealed()) { + taskDescriptorStorage.remove(stageId, partitionId); + } + return outputStats; + } + + public void taskFailed(TaskId taskId) + { + runningTasks.remove(taskId); + remainingAttempts--; + } + + public void updateExchangeSinkInstanceHandle(TaskId taskId, ExchangeSinkInstanceHandle handle) + { + SpoolingOutputBuffers outputBuffers = taskOutputBuffers.get(taskId); + checkArgument(outputBuffers != null, "output buffers not found: %s", taskId); + RemoteTask remoteTask = tasks.get(taskId); + checkArgument(remoteTask != null, "task 
not found: %s", taskId); + SpoolingOutputBuffers updatedOutputBuffers = outputBuffers.withExchangeSinkInstanceHandle(handle); + taskOutputBuffers.put(taskId, updatedOutputBuffers); + remoteTask.setOutputBuffers(updatedOutputBuffers); + } + + public void updateExchangeSourceOutputSelector(PlanNodeId planNodeId, ExchangeSourceOutputSelector selector) + { + if (selector.isFinal()) { + finalSelectors.add(planNodeId); + } + for (TaskId taskId : runningTasks) { + RemoteTask task = tasks.get(taskId); + verify(task != null, "task is null: %s", taskId); + task.addSplits(ImmutableListMultimap.of( + planNodeId, + createOutputSelectorSplit(selector))); + if (selector.isFinal() && noMoreSplits.contains(planNodeId)) { + task.noMoreSplits(planNodeId); + } + } + } + + public boolean isRunning() + { + return !runningTasks.isEmpty(); + } + + public boolean isFinished() + { + return finished; + } + } + + private static Split createOutputSelectorSplit(ExchangeSourceOutputSelector selector) + { + return new Split(REMOTE_CATALOG_HANDLE, new RemoteSplit(new SpoolingExchangeInput(ImmutableList.of(), Optional.of(selector)))); + } + + private static class OpenTaskDescriptor + { + private final ListMultimap splits; + private final Set noMoreSplits; + private final NodeRequirements nodeRequirements; + + private OpenTaskDescriptor(ListMultimap splits, Set noMoreSplits, NodeRequirements nodeRequirements) + { + this.splits = ImmutableListMultimap.copyOf(requireNonNull(splits, "splits is null")); + this.noMoreSplits = ImmutableSet.copyOf(requireNonNull(noMoreSplits, "noMoreSplits is null")); + this.nodeRequirements = requireNonNull(nodeRequirements, "nodeRequirements is null"); + } + + public ListMultimap getSplits() + { + return splits; + } + + public Set getNoMoreSplits() + { + return noMoreSplits; + } + + public NodeRequirements getNodeRequirements() + { + return nodeRequirements; + } + + public OpenTaskDescriptor update(PlanNodeId planNodeId, List splits, boolean noMoreSplits) + { + 
ListMultimap updatedSplits = ImmutableListMultimap.builder() + .putAll(this.splits) + .putAll(planNodeId, splits) + .build(); + + Set updatedNoMoreSplits = this.noMoreSplits; + if (noMoreSplits && !updatedNoMoreSplits.contains(planNodeId)) { + updatedNoMoreSplits = ImmutableSet.builder() + .addAll(this.noMoreSplits) + .add(planNodeId) + .build(); + } + return new OpenTaskDescriptor( + updatedSplits, + updatedNoMoreSplits, + nodeRequirements); + } + + public TaskDescriptor createTaskDescriptor(int partitionId) + { + Set missingNoMoreSplits = Sets.difference(splits.keySet(), noMoreSplits); + checkState(missingNoMoreSplits.isEmpty(), "missing no more splits for plan nodes: %s", missingNoMoreSplits); + return new TaskDescriptor( + partitionId, + splits, + nodeRequirements); + } + } + + private record ScheduledTask(StageId stageId, int partitionId, int priority) + { + public ScheduledTask + { + requireNonNull(stageId, "stageId is null"); + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + ScheduledTask that = (ScheduledTask) o; + return partitionId == that.partitionId && Objects.equals(stageId, that.stageId); + } + + @Override + public int hashCode() + { + return Objects.hash(stageId, partitionId); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("stageId", stageId) + .add("partitionId", partitionId) + .add("priority", priority) + .toString(); + } + } + + private static class SchedulingDelayer + { + private final long minRetryDelayInMillis; + private final long maxRetryDelayInMillis; + private final double retryDelayScaleFactor; + private final Stopwatch stopwatch; + + private long currentDelayInMillis; + + private SchedulingDelayer(Duration minRetryDelay, Duration maxRetryDelay, double retryDelayScaleFactor, Stopwatch stopwatch) + { + this.minRetryDelayInMillis = requireNonNull(minRetryDelay, "minRetryDelay is 
null").toMillis(); + this.maxRetryDelayInMillis = requireNonNull(maxRetryDelay, "maxRetryDelay is null").toMillis(); + checkArgument(retryDelayScaleFactor >= 1, "retryDelayScaleFactor is expected to be greater than or equal to 1: %s", retryDelayScaleFactor); + this.retryDelayScaleFactor = retryDelayScaleFactor; + this.stopwatch = requireNonNull(stopwatch, "stopwatch is null"); + } + + public void startOrProlongDelayIfNecessary() + { + if (stopwatch.isRunning()) { + if (stopwatch.elapsed(MILLISECONDS) > currentDelayInMillis) { + // we are past previous delay period and still getting failures; let's make it longer + stopwatch.reset().start(); + currentDelayInMillis = min(round(currentDelayInMillis * retryDelayScaleFactor), maxRetryDelayInMillis); + } + } + else { + // initialize delaying of tasks scheduling + stopwatch.start(); + currentDelayInMillis = minRetryDelayInMillis; + } + } + + public long getRemainingDelayInMillis() + { + if (stopwatch.isRunning()) { + return max(0, currentDelayInMillis - stopwatch.elapsed(MILLISECONDS)); + } + return 0; + } + } + + private interface Event + { + Event ABORT = listener -> { + throw new UnsupportedOperationException(); + }; + + Event WAKE_UP = listener -> { + throw new UnsupportedOperationException(); + }; + + void accept(EventListener listener); + } + + private interface EventListener + { + void onRemoteTaskCompleted(RemoteTaskCompletedEvent event); + + void onRemoteTaskExchangeSinkUpdateRequired(RemoteTaskExchangeSinkUpdateRequiredEvent event); + + void onPartitionsAdded(PartitionsAddedEvent event); + + void onPartitionsUpdated(PartitionsUpdatedEvent event); + + void onPartitionsSealed(PartitionsSealedEvent event); + + void onNoMorePartitions(NoMorePartitionsEvent event); + + void onTaskSourceFailure(TaskSourceFailureEvent event); + } + + private static class RemoteTaskCompletedEvent + extends RemoteTaskEvent + { + public RemoteTaskCompletedEvent(TaskStatus taskStatus) + { + super(taskStatus); + } + + @Override + public 
void accept(EventListener listener) + { + listener.onRemoteTaskCompleted(this); + } + } + + private static class RemoteTaskExchangeSinkUpdateRequiredEvent + extends RemoteTaskEvent + { + protected RemoteTaskExchangeSinkUpdateRequiredEvent(TaskStatus taskStatus) + { + super(taskStatus); + } + + @Override + public void accept(EventListener listener) + { + listener.onRemoteTaskExchangeSinkUpdateRequired(this); + } + } + + private abstract static class RemoteTaskEvent + implements Event + { + private final TaskStatus taskStatus; + + protected RemoteTaskEvent(TaskStatus taskStatus) + { + this.taskStatus = requireNonNull(taskStatus, "taskStatus is null"); + } + + public TaskStatus getTaskStatus() + { + return taskStatus; + } + } + + private static class PartitionsAddedEvent + extends TaskSourceEvent + { + private final List partitions; + + public PartitionsAddedEvent(StageId stageId, List partitions) + { + super(stageId); + this.partitions = ImmutableList.copyOf(requireNonNull(partitions, "partitions is null")); + } + + public List getPartitions() + { + return partitions; + } + + @Override + public void accept(EventListener listener) + { + listener.onPartitionsAdded(this); + } + } + + private static class NoMorePartitionsEvent + extends TaskSourceEvent + { + public NoMorePartitionsEvent(StageId stageId) + { + super(stageId); + } + + @Override + public void accept(EventListener listener) + { + listener.onNoMorePartitions(this); + } + } + + private static class PartitionsUpdatedEvent + extends TaskSourceEvent + { + private final List partitionUpdates; + + public PartitionsUpdatedEvent(StageId stageId, List partitionUpdates) + { + super(stageId); + this.partitionUpdates = ImmutableList.copyOf(requireNonNull(partitionUpdates, "partitionUpdates is null")); + } + + public List getPartitionUpdates() + { + return partitionUpdates; + } + + @Override + public void accept(EventListener listener) + { + listener.onPartitionsUpdated(this); + } + } + + private static class 
PartitionsSealedEvent + extends TaskSourceEvent + { + private final ImmutableIntArray partitionIds; + + public PartitionsSealedEvent(StageId stageId, ImmutableIntArray partitionIds) + { + super(stageId); + this.partitionIds = requireNonNull(partitionIds, "partitionIds is null"); + } + + public ImmutableIntArray getPartitionIds() + { + return partitionIds; + } + + @Override + public void accept(EventListener listener) + { + listener.onPartitionsSealed(this); + } + } + + private static class TaskSourceFailureEvent + extends TaskSourceEvent + { + private final Throwable failure; + + public TaskSourceFailureEvent(StageId stageId, Throwable failure) + { + super(stageId); + this.failure = requireNonNull(failure, "failure is null"); + } + + public Throwable getFailure() + { + return failure; + } + + @Override + public void accept(EventListener listener) + { + listener.onTaskSourceFailure(this); + } + } + + private abstract static class TaskSourceEvent + implements Event + { + private final StageId stageId; + + protected TaskSourceEvent(StageId stageId) + { + this.stageId = requireNonNull(stageId, "stageId is null"); + } + + public StageId getStageId() + { + return stageId; + } + } +} diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/EventDrivenTaskSource.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/EventDrivenTaskSource.java new file mode 100644 index 000000000000..d6ffa62e943f --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/EventDrivenTaskSource.java @@ -0,0 +1,517 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.execution.scheduler; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ListMultimap; +import com.google.common.collect.SetMultimap; +import com.google.common.io.Closer; +import com.google.common.primitives.ImmutableIntArray; +import com.google.common.util.concurrent.FutureCallback; +import com.google.common.util.concurrent.Futures; +import com.google.common.util.concurrent.ListenableFuture; +import io.trino.connector.CatalogHandle; +import io.trino.exchange.SpoolingExchangeInput; +import io.trino.metadata.Split; +import io.trino.spi.exchange.Exchange; +import io.trino.spi.exchange.ExchangeSourceHandle; +import io.trino.spi.exchange.ExchangeSourceHandleSource; +import io.trino.spi.exchange.ExchangeSourceHandleSource.ExchangeSourceHandleBatch; +import io.trino.split.RemoteSplit; +import io.trino.split.SplitSource; +import io.trino.split.SplitSource.SplitBatch; +import io.trino.sql.planner.plan.PlanFragmentId; +import io.trino.sql.planner.plan.PlanNodeId; + +import javax.annotation.concurrent.GuardedBy; +import javax.annotation.concurrent.ThreadSafe; + +import java.io.Closeable; +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.concurrent.Executor; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.Function; +import java.util.function.LongConsumer; +import 
java.util.function.Supplier; +import java.util.function.ToIntFunction; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.base.Verify.verify; +import static com.google.common.collect.ImmutableListMultimap.toImmutableListMultimap; +import static com.google.common.collect.ImmutableSetMultimap.toImmutableSetMultimap; +import static com.google.common.util.concurrent.MoreExecutors.directExecutor; +import static io.airlift.concurrent.MoreFutures.toListenableFuture; +import static io.trino.operator.ExchangeOperator.REMOTE_CATALOG_HANDLE; +import static java.util.Objects.requireNonNull; + +@ThreadSafe +class EventDrivenTaskSource + implements Closeable +{ + private final Map sourceExchanges; + private final Map remoteSources; + private final Supplier> splitSourceSupplier; + @GuardedBy("assignerLock") + private final SplitAssigner assigner; + @GuardedBy("assignerLock") + private final Callback callback; + private final Executor executor; + private final int splitBatchSize; + private final long targetExchangeSplitSizeInBytes; + private final FaultTolerantPartitioningScheme sourcePartitioningScheme; + private final LongConsumer getSplitTimeRecorder; + private final SetMultimap remoteSourceFragments; + + @GuardedBy("this") + private boolean started; + @GuardedBy("this") + private boolean closed; + @GuardedBy("this") + private final Closer closer = Closer.create(); + + private final Object assignerLock = new Object(); + + @GuardedBy("assignerLock") + private final Set finishedFragments = new HashSet<>(); + @GuardedBy("assignerLock") + private final Set allSources = new HashSet<>(); + @GuardedBy("assignerLock") + private final Set finishedSources = new HashSet<>(); + + EventDrivenTaskSource( + Map sourceExchanges, + Map remoteSources, + Supplier> splitSourceSupplier, + SplitAssigner assigner, + Callback callback, + Executor executor, + int splitBatchSize, + long 
targetExchangeSplitSizeInBytes, + FaultTolerantPartitioningScheme sourcePartitioningScheme, + LongConsumer getSplitTimeRecorder) + { + this.sourceExchanges = ImmutableMap.copyOf(requireNonNull(sourceExchanges, "sourceExchanges is null")); + this.remoteSources = ImmutableMap.copyOf(requireNonNull(remoteSources, "remoteSources is null")); + checkArgument( + sourceExchanges.keySet().equals(remoteSources.keySet()), + "sourceExchanges and remoteSources are expected to contain the same set of keys: %s != %s", + sourceExchanges.keySet(), + remoteSources.keySet()); + this.splitSourceSupplier = requireNonNull(splitSourceSupplier, "splitSourceSupplier is null"); + this.assigner = requireNonNull(assigner, "assigner is null"); + this.callback = requireNonNull(callback, "callback is null"); + this.executor = requireNonNull(executor, "executor is null"); + this.splitBatchSize = splitBatchSize; + this.targetExchangeSplitSizeInBytes = targetExchangeSplitSizeInBytes; + this.sourcePartitioningScheme = requireNonNull(sourcePartitioningScheme, "sourcePartitioningScheme is null"); + this.getSplitTimeRecorder = requireNonNull(getSplitTimeRecorder, "getSplitTimeRecorder is null"); + remoteSourceFragments = remoteSources.entrySet().stream() + .collect(toImmutableSetMultimap(Map.Entry::getValue, Map.Entry::getKey)); + } + + public synchronized void start() + { + checkState(!started, "already started"); + checkState(!closed, "already closed"); + started = true; + try { + List splitLoaders = new ArrayList<>(); + for (Map.Entry entry : sourceExchanges.entrySet()) { + PlanFragmentId fragmentId = entry.getKey(); + PlanNodeId remoteSourceNodeId = getRemoteSourceNode(fragmentId); + // doesn't have to be synchronized by assignerLock until the loaders are started + allSources.add(remoteSourceNodeId); + ExchangeSourceHandleSource handleSource = closer.register(entry.getValue().getSourceHandles()); + ExchangeSplitSource splitSource = closer.register(new ExchangeSplitSource(handleSource, 
targetExchangeSplitSizeInBytes)); + SplitLoader splitLoader = closer.register(createExchangeSplitLoader(fragmentId, remoteSourceNodeId, splitSource)); + splitLoaders.add(splitLoader); + } + for (Map.Entry entry : splitSourceSupplier.get().entrySet()) { + PlanNodeId planNodeId = entry.getKey(); + // doesn't have to be synchronized by assignerLock until the loaders are started + allSources.add(planNodeId); + SplitLoader splitLoader = closer.register(createTableScanSplitLoader(planNodeId, entry.getValue())); + splitLoaders.add(splitLoader); + } + if (splitLoaders.isEmpty()) { + executor.execute(() -> { + try { + synchronized (assignerLock) { + assigner.finish().update(callback); + } + } + catch (Throwable t) { + fail(t); + } + }); + } + else { + splitLoaders.forEach(SplitLoader::start); + } + } + catch (Throwable t) { + try { + closer.close(); + } + catch (Throwable closerFailure) { + if (closerFailure != t) { + t.addSuppressed(closerFailure); + } + } + throw t; + } + } + + private SplitLoader createExchangeSplitLoader(PlanFragmentId fragmentId, PlanNodeId remoteSourceNodeId, ExchangeSplitSource splitSource) + { + return new SplitLoader( + splitSource, + executor, + ExchangeSplitSource::getSplitPartition, + new SplitLoader.Callback() + { + @Override + public void update(ListMultimap splits, boolean noMoreSplitsForFragment) + { + try { + synchronized (assignerLock) { + if (noMoreSplitsForFragment) { + finishedFragments.add(fragmentId); + } + boolean noMoreSplitsForRemoteSource = finishedFragments.containsAll(remoteSourceFragments.get(remoteSourceNodeId)); + assigner.assign(remoteSourceNodeId, splits, noMoreSplitsForRemoteSource).update(callback); + if (noMoreSplitsForRemoteSource) { + finishedSources.add(remoteSourceNodeId); + } + if (finishedSources.containsAll(allSources)) { + assigner.finish().update(callback); + } + } + } + catch (Throwable t) { + fail(t); + } + } + + @Override + public void failed(Throwable t) + { + fail(t); + } + }, + splitBatchSize, + 
getSplitTimeRecorder); + } + + private SplitLoader createTableScanSplitLoader(PlanNodeId planNodeId, SplitSource splitSource) + { + return new SplitLoader( + splitSource, + executor, + this::getSplitPartition, + new SplitLoader.Callback() + { + @Override + public void update(ListMultimap splits, boolean noMoreSplits) + { + try { + synchronized (assignerLock) { + assigner.assign(planNodeId, splits, noMoreSplits).update(callback); + if (noMoreSplits) { + finishedSources.add(planNodeId); + } + if (finishedSources.containsAll(allSources)) { + assigner.finish().update(callback); + } + } + } + catch (Throwable t) { + fail(t); + } + } + + @Override + public void failed(Throwable t) + { + fail(t); + } + }, + splitBatchSize, + getSplitTimeRecorder); + } + + private PlanNodeId getRemoteSourceNode(PlanFragmentId fragmentId) + { + PlanNodeId planNodeId = remoteSources.get(fragmentId); + verify(planNodeId != null, "remote source not found for fragment: %s", fragmentId); + return planNodeId; + } + + private int getSplitPartition(Split split) + { + return sourcePartitioningScheme.getPartition(split); + } + + private void fail(Throwable failure) + { + callback.failed(failure); + close(); + } + + @Override + public synchronized void close() + { + if (closed) { + return; + } + closed = true; + try { + closer.close(); + } + catch (IOException e) { + throw new RuntimeException(e); + } + } + + interface Callback + { + void partitionsAdded(List partitions); + + void noMorePartitions(); + + void partitionsUpdated(List partitionUpdates); + + void partitionsSealed(ImmutableIntArray partitionIds); + + void failed(Throwable t); + } + + record Partition(int partitionId, NodeRequirements nodeRequirements) + { + public Partition + { + requireNonNull(nodeRequirements, "nodeRequirements is null"); + } + } + + record PartitionUpdate(int partitionId, PlanNodeId planNodeId, List splits, boolean noMoreSplits) + { + public PartitionUpdate + { + requireNonNull(planNodeId, "planNodeId is null"); + 
splits = ImmutableList.copyOf(requireNonNull(splits, "splits is null")); + } + } + + private static class ExchangeSplitSource + implements SplitSource + { + private final ExchangeSourceHandleSource handleSource; + private final long targetSplitSizeInBytes; + private final AtomicBoolean finished = new AtomicBoolean(); + + private ExchangeSplitSource(ExchangeSourceHandleSource handleSource, long targetSplitSizeInBytes) + { + this.handleSource = requireNonNull(handleSource, "handleSource is null"); + this.targetSplitSizeInBytes = targetSplitSizeInBytes; + } + + @Override + public CatalogHandle getCatalogHandle() + { + return REMOTE_CATALOG_HANDLE; + } + + @Override + public ListenableFuture getNextBatch(int maxSize) + { + ListenableFuture sourceHandlesFuture = toListenableFuture(handleSource.getNextBatch()); + return Futures.transform( + sourceHandlesFuture, + batch -> { + List handles = batch.handles(); + ListMultimap partitionToHandles = handles.stream() + .collect(toImmutableListMultimap(ExchangeSourceHandle::getPartitionId, Function.identity())); + ImmutableList.Builder splits = ImmutableList.builder(); + for (int partition : partitionToHandles.keySet()) { + splits.addAll(createRemoteSplits(partitionToHandles.get(partition))); + } + if (batch.lastBatch()) { + finished.set(true); + } + return new SplitBatch(splits.build(), batch.lastBatch()); + }, directExecutor()); + } + + private List createRemoteSplits(List handles) + { + ImmutableList.Builder result = ImmutableList.builder(); + ImmutableList.Builder currentSplitHandles = ImmutableList.builder(); + long currentSplitHandlesSize = 0; + long currentSplitHandlesCount = 0; + for (ExchangeSourceHandle handle : handles) { + if (currentSplitHandlesCount > 0 && currentSplitHandlesSize + handle.getDataSizeInBytes() > targetSplitSizeInBytes) { + result.add(createRemoteSplit(currentSplitHandles.build())); + currentSplitHandles = ImmutableList.builder(); + currentSplitHandlesSize = 0; + currentSplitHandlesCount = 0; + } + 
currentSplitHandles.add(handle); + currentSplitHandlesSize += handle.getDataSizeInBytes(); + currentSplitHandlesCount++; + } + if (currentSplitHandlesCount > 0) { + result.add(createRemoteSplit(currentSplitHandles.build())); + } + return result.build(); + } + + private static Split createRemoteSplit(List handles) + { + return new Split(REMOTE_CATALOG_HANDLE, new RemoteSplit(new SpoolingExchangeInput(handles, Optional.empty()))); + } + + private static int getSplitPartition(Split split) + { + RemoteSplit remoteSplit = (RemoteSplit) split.getConnectorSplit(); + SpoolingExchangeInput exchangeInput = (SpoolingExchangeInput) remoteSplit.getExchangeInput(); + List handles = exchangeInput.getExchangeSourceHandles(); + return handles.get(0).getPartitionId(); + } + + @Override + public void close() + { + handleSource.close(); + } + + @Override + public boolean isFinished() + { + return finished.get(); + } + + @Override + public Optional> getTableExecuteSplitsInfo() + { + return Optional.empty(); + } + } + + private static class SplitLoader + implements Closeable + { + private final SplitSource splitSource; + private final Executor executor; + private final ToIntFunction splitToPartition; + private final Callback callback; + private final int splitBatchSize; + private final LongConsumer getSplitTimeRecorder; + + @GuardedBy("this") + private boolean started; + @GuardedBy("this") + private boolean closed; + @GuardedBy("this") + private ListenableFuture splitLoadingFuture; + + public SplitLoader( + SplitSource splitSource, + Executor executor, + ToIntFunction splitToPartition, + Callback callback, + int splitBatchSize, + LongConsumer getSplitTimeRecorder) + { + this.splitSource = requireNonNull(splitSource, "splitSource is null"); + this.executor = requireNonNull(executor, "executor is null"); + this.splitToPartition = requireNonNull(splitToPartition, "splitToPartition is null"); + this.callback = requireNonNull(callback, "callback is null"); + this.splitBatchSize = 
splitBatchSize; + this.getSplitTimeRecorder = requireNonNull(getSplitTimeRecorder, "getSplitTimeRecorder is null"); + } + + public synchronized void start() + { + checkState(!started, "already started"); + checkState(!closed, "already closed"); + started = true; + processNext(); + } + + private synchronized void processNext() + { + if (closed) { + return; + } + verify(splitLoadingFuture == null || splitLoadingFuture.isDone(), "splitLoadingFuture is still running"); + long start = System.currentTimeMillis(); + splitLoadingFuture = splitSource.getNextBatch(splitBatchSize); + Futures.addCallback(splitLoadingFuture, new FutureCallback<>() + { + @Override + public void onSuccess(SplitBatch result) + { + try { + getSplitTimeRecorder.accept(System.currentTimeMillis() - start); + ListMultimap splits = result.getSplits().stream() + .collect(toImmutableListMultimap(splitToPartition::applyAsInt, Function.identity())); + callback.update(splits, result.isLastBatch()); + if (!result.isLastBatch()) { + processNext(); + } + } + catch (Throwable t) { + callback.failed(t); + } + } + + @Override + public void onFailure(Throwable t) + { + callback.failed(t); + } + }, executor); + } + + @Override + public synchronized void close() + { + if (closed) { + return; + } + closed = true; + if (splitLoadingFuture != null) { + splitLoadingFuture.cancel(true); + splitLoadingFuture = null; + } + splitSource.close(); + } + + public interface Callback + { + void update(ListMultimap splits, boolean noMoreSplits); + + void failed(Throwable t); + } + } +} diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/EventDrivenTaskSourceFactory.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/EventDrivenTaskSourceFactory.java new file mode 100644 index 000000000000..a0a873bf2142 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/EventDrivenTaskSourceFactory.java @@ -0,0 +1,220 @@ +/* + * Licensed under the Apache License, Version 2.0 (the 
"License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.execution.scheduler; + +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import io.trino.Session; +import io.trino.execution.ForQueryExecution; +import io.trino.execution.QueryManagerConfig; +import io.trino.execution.scheduler.EventDrivenTaskSource.Callback; +import io.trino.metadata.InternalNodeManager; +import io.trino.spi.HostAddress; +import io.trino.spi.Node; +import io.trino.spi.exchange.Exchange; +import io.trino.sql.planner.MergePartitioningHandle; +import io.trino.sql.planner.PartitioningHandle; +import io.trino.sql.planner.PlanFragment; +import io.trino.sql.planner.SplitSourceFactory; +import io.trino.sql.planner.plan.PlanFragmentId; +import io.trino.sql.planner.plan.PlanNode; +import io.trino.sql.planner.plan.PlanNodeId; +import io.trino.sql.planner.plan.PlanVisitor; +import io.trino.sql.planner.plan.RemoteSourceNode; +import io.trino.sql.planner.plan.TableWriterNode; + +import javax.inject.Inject; + +import java.util.Map; +import java.util.Set; +import java.util.concurrent.Executor; +import java.util.concurrent.ExecutorService; +import java.util.function.LongConsumer; + +import static com.google.common.base.Verify.verify; +import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static io.trino.SystemSessionProperties.getFaultTolerantExecutionMaxTaskSplitCount; +import static io.trino.SystemSessionProperties.getFaultTolerantExecutionTargetTaskInputSize; +import static 
io.trino.SystemSessionProperties.getFaultTolerantExecutionTargetTaskSplitCount; +import static io.trino.SystemSessionProperties.getFaultTolerantPreserveInputPartitionsInWriteStage; +import static io.trino.sql.planner.SystemPartitioningHandle.COORDINATOR_DISTRIBUTION; +import static io.trino.sql.planner.SystemPartitioningHandle.FIXED_ARBITRARY_DISTRIBUTION; +import static io.trino.sql.planner.SystemPartitioningHandle.FIXED_HASH_DISTRIBUTION; +import static io.trino.sql.planner.SystemPartitioningHandle.SCALED_WRITER_DISTRIBUTION; +import static io.trino.sql.planner.SystemPartitioningHandle.SINGLE_DISTRIBUTION; +import static io.trino.sql.planner.SystemPartitioningHandle.SOURCE_DISTRIBUTION; +import static io.trino.sql.planner.plan.ExchangeNode.Type.REPLICATE; +import static java.util.Objects.requireNonNull; + +public class EventDrivenTaskSourceFactory +{ + private final SplitSourceFactory splitSourceFactory; + private final Executor executor; + private final InternalNodeManager nodeManager; + private final int splitBatchSize; + + @Inject + public EventDrivenTaskSourceFactory( + SplitSourceFactory splitSourceFactory, + @ForQueryExecution ExecutorService executor, + InternalNodeManager nodeManager, + QueryManagerConfig queryManagerConfig) + { + this( + splitSourceFactory, + executor, + nodeManager, + requireNonNull(queryManagerConfig, "queryManagerConfig is null").getScheduleSplitBatchSize()); + } + + public EventDrivenTaskSourceFactory( + SplitSourceFactory splitSourceFactory, + Executor executor, + InternalNodeManager nodeManager, + int splitBatchSize) + { + this.splitSourceFactory = requireNonNull(splitSourceFactory, "splitSourceFactory is null"); + this.executor = requireNonNull(executor, "executor is null"); + this.nodeManager = requireNonNull(nodeManager, "nodeManager is null"); + this.splitBatchSize = splitBatchSize; + } + + public EventDrivenTaskSource create( + Callback callback, + Session session, + PlanFragment fragment, + Map sourceExchanges, + 
FaultTolerantPartitioningScheme sourcePartitioningScheme, + LongConsumer getSplitTimeRecorder, + Map outputDataSizeEstimates) + { + ImmutableMap.Builder remoteSources = ImmutableMap.builder(); + for (RemoteSourceNode remoteSource : fragment.getRemoteSourceNodes()) { + for (PlanFragmentId sourceFragment : remoteSource.getSourceFragmentIds()) { + remoteSources.put(sourceFragment, remoteSource.getId()); + } + } + long targetPartitionSizeInBytes = getFaultTolerantExecutionTargetTaskInputSize(session).toBytes(); + // TODO: refactor to define explicitly + long standardSplitSizeInBytes = targetPartitionSizeInBytes / getFaultTolerantExecutionTargetTaskSplitCount(session); + int maxTaskSplitCount = getFaultTolerantExecutionMaxTaskSplitCount(session); + return new EventDrivenTaskSource( + sourceExchanges, + remoteSources.buildOrThrow(), + () -> splitSourceFactory.createSplitSources(session, fragment), + createSplitAssigner( + session, + fragment, + outputDataSizeEstimates, + sourcePartitioningScheme, + targetPartitionSizeInBytes, + standardSplitSizeInBytes, + maxTaskSplitCount), + callback, + executor, + splitBatchSize, + standardSplitSizeInBytes, + sourcePartitioningScheme, + getSplitTimeRecorder); + } + + private SplitAssigner createSplitAssigner( + Session session, + PlanFragment fragment, + Map outputDataSizeEstimates, + FaultTolerantPartitioningScheme sourcePartitioningScheme, + long targetPartitionSizeInBytes, + long standardSplitSizeInBytes, + int maxArbitraryDistributionTaskSplitCount) + { + PartitioningHandle partitioning = fragment.getPartitioning(); + + Set partitionedRemoteSources = fragment.getRemoteSourceNodes().stream() + .filter(node -> node.getExchangeType() != REPLICATE) + .map(PlanNode::getId) + .collect(toImmutableSet()); + Set partitionedSources = ImmutableSet.builder() + .addAll(partitionedRemoteSources) + .addAll(fragment.getPartitionedSources()) + .build(); + Set replicatedSources = fragment.getRemoteSourceNodes().stream() + .filter(node -> 
node.getExchangeType() == REPLICATE) + .map(PlanNode::getId) + .collect(toImmutableSet()); + + boolean coordinatorOnly = partitioning.equals(COORDINATOR_DISTRIBUTION); + if (partitioning.equals(SINGLE_DISTRIBUTION) || coordinatorOnly) { + ImmutableSet hostRequirement = ImmutableSet.of(); + if (coordinatorOnly) { + Node currentNode = nodeManager.getCurrentNode(); + verify(currentNode.isCoordinator(), "current node is expected to be a coordinator"); + hostRequirement = ImmutableSet.of(currentNode.getHostAndPort()); + } + return new SingleDistributionSplitAssigner( + hostRequirement, + ImmutableSet.builder() + .addAll(partitionedSources) + .addAll(replicatedSources) + .build()); + } + if (partitioning.equals(FIXED_ARBITRARY_DISTRIBUTION) || partitioning.equals(SCALED_WRITER_DISTRIBUTION) || partitioning.equals(SOURCE_DISTRIBUTION)) { + return new ArbitraryDistributionSplitAssigner( + partitioning.getCatalogHandle(), + partitionedSources, + replicatedSources, + targetPartitionSizeInBytes, + standardSplitSizeInBytes, + maxArbitraryDistributionTaskSplitCount); + } + if (partitioning.equals(FIXED_HASH_DISTRIBUTION) || partitioning.getCatalogHandle().isPresent() || + (partitioning.getConnectorHandle() instanceof MergePartitioningHandle)) { + return new HashDistributionSplitAssigner( + partitioning.getCatalogHandle(), + partitionedSources, + replicatedSources, + getFaultTolerantExecutionTargetTaskInputSize(session).toBytes(), + outputDataSizeEstimates, + sourcePartitioningScheme, + getFaultTolerantPreserveInputPartitionsInWriteStage(session) && isWriteFragment(fragment)); + } + + // other partitioning handles are not expected to be set as a fragment partitioning + throw new IllegalArgumentException("Unexpected partitioning: " + partitioning); + } + + private static boolean isWriteFragment(PlanFragment fragment) + { + PlanVisitor visitor = new PlanVisitor<>() + { + @Override + protected Boolean visitPlan(PlanNode node, Void context) + { + for (PlanNode child : 
node.getSources()) { + if (child.accept(this, context)) { + return true; + } + } + return false; + } + + @Override + public Boolean visitTableWriter(TableWriterNode node, Void context) + { + return true; + } + }; + + return fragment.getRoot().accept(visitor, null); + } +} diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/FaultTolerantPartitioningScheme.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/FaultTolerantPartitioningScheme.java index 3fa17a05d54a..b56cc12da7ed 100644 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/FaultTolerantPartitioningScheme.java +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/FaultTolerantPartitioningScheme.java @@ -65,15 +65,23 @@ public Optional getBucketToPartitionMap() public int getPartition(Split split) { - checkState(bucketToPartitionMap.isPresent(), "bucketToPartitionMap is expected to be present"); - checkState(splitToBucketFunction.isPresent(), "splitToBucketFunction is expected to be present"); - int bucket = splitToBucketFunction.get().applyAsInt(split); - checkState( - bucketToPartitionMap.get().length > bucket, - "invalid bucketToPartitionMap size (%s), bucket to partition mapping not found for bucket %s", - bucketToPartitionMap.get().length, - bucket); - return bucketToPartitionMap.get()[bucket]; + if (splitToBucketFunction.isPresent()) { + checkState(bucketToPartitionMap.isPresent(), "bucketToPartitionMap is expected to be present"); + int bucket = splitToBucketFunction.get().applyAsInt(split); + checkState( + bucketToPartitionMap.get().length > bucket, + "invalid bucketToPartitionMap size (%s), bucket to partition mapping not found for bucket %s", + bucketToPartitionMap.get().length, + bucket); + return bucketToPartitionMap.get()[bucket]; + } + checkState(partitionCount == 1, "partitionCount is expected to be set to 1: %s", partitionCount); + return 0; + } + + public boolean isExplicitPartitionToNodeMappingPresent() + { + return 
partitionToNodeMap.isPresent(); } public Optional getNodeRequirement(int partition) diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/FaultTolerantQueryScheduler.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/FaultTolerantQueryScheduler.java index 1da65d6a7df6..8f07c9b2f9bd 100644 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/FaultTolerantQueryScheduler.java +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/FaultTolerantQueryScheduler.java @@ -76,6 +76,10 @@ import static java.util.concurrent.TimeUnit.MILLISECONDS; import static java.util.concurrent.TimeUnit.SECONDS; +/** + * Deprecated in favor of {@link EventDrivenFaultTolerantQueryScheduler} + */ +@Deprecated public class FaultTolerantQueryScheduler implements QueryScheduler { diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/FaultTolerantStageScheduler.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/FaultTolerantStageScheduler.java index b1699317ba94..42897802fd95 100644 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/FaultTolerantStageScheduler.java +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/FaultTolerantStageScheduler.java @@ -106,6 +106,10 @@ import static java.lang.String.format; import static java.util.Objects.requireNonNull; +/** + * Deprecated in favor of {@link EventDrivenFaultTolerantQueryScheduler} + */ +@Deprecated public class FaultTolerantStageScheduler { private static final Logger log = Logger.get(FaultTolerantStageScheduler.class); diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/HashDistributionSplitAssigner.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/HashDistributionSplitAssigner.java new file mode 100644 index 000000000000..f536412d26ff --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/HashDistributionSplitAssigner.java @@ -0,0 +1,236 @@ +/* + * Licensed under the 
Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.execution.scheduler; + +import com.google.common.collect.ArrayListMultimap; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import com.google.common.collect.ListMultimap; +import io.trino.connector.CatalogHandle; +import io.trino.execution.scheduler.EventDrivenTaskSource.Partition; +import io.trino.execution.scheduler.EventDrivenTaskSource.PartitionUpdate; +import io.trino.metadata.InternalNode; +import io.trino.metadata.Split; +import io.trino.spi.HostAddress; +import io.trino.sql.planner.plan.PlanNodeId; + +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.OptionalInt; +import java.util.PriorityQueue; +import java.util.Set; +import java.util.function.Function; +import java.util.stream.IntStream; + +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.base.Verify.verify; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.ImmutableMap.toImmutableMap; +import static java.util.Objects.requireNonNull; + +class HashDistributionSplitAssigner + implements SplitAssigner +{ + private final Optional catalogRequirement; + private final Set replicatedSources; + private final Set allSources; + private final FaultTolerantPartitioningScheme 
sourcePartitioningScheme; + private final Map outputPartitionToTaskPartition; + + private final Set createdTaskPartitions = new HashSet<>(); + private final Set completedSources = new HashSet<>(); + private final ListMultimap replicatedSplits = ArrayListMultimap.create(); + + private int nextTaskPartitionId; + + HashDistributionSplitAssigner( + Optional catalogRequirement, + Set partitionedSources, + Set replicatedSources, + long targetPartitionSizeInBytes, + Map outputDataSizeEstimates, + FaultTolerantPartitioningScheme sourcePartitioningScheme, + boolean preserveOutputPartitioning) + { + this.catalogRequirement = requireNonNull(catalogRequirement, "catalogRequirement is null"); + this.replicatedSources = ImmutableSet.copyOf(requireNonNull(replicatedSources, "replicatedSources is null")); + allSources = ImmutableSet.builder() + .addAll(partitionedSources) + .addAll(replicatedSources) + .build(); + this.sourcePartitioningScheme = requireNonNull(sourcePartitioningScheme, "sourcePartitioningScheme is null"); + outputPartitionToTaskPartition = createOutputPartitionToTaskPartition( + sourcePartitioningScheme, + partitionedSources, + outputDataSizeEstimates, + preserveOutputPartitioning, + targetPartitionSizeInBytes); + } + + @Override + public AssignmentResult assign(PlanNodeId planNodeId, ListMultimap splits, boolean noMoreSplits) + { + AssignmentResult.Builder assignment = AssignmentResult.builder(); + + if (replicatedSources.contains(planNodeId)) { + replicatedSplits.putAll(planNodeId, splits.values()); + for (Integer partitionId : createdTaskPartitions) { + assignment.updatePartition(new PartitionUpdate(partitionId, planNodeId, ImmutableList.copyOf(splits.values()), noMoreSplits)); + } + } + else { + for (Integer outputPartitionId : splits.keySet()) { + TaskPartition taskPartition = outputPartitionToTaskPartition.get(outputPartitionId); + verify(taskPartition != null, "taskPartition not found for outputPartitionId: %s", outputPartitionId); + if 
(!taskPartition.isIdAssigned()) { + // Assigns lazily to ensure task ids are incremental and with no gaps. + // Gaps can occur when scanning over a bucketed table as some buckets may contain no data. + taskPartition.assignId(nextTaskPartitionId++); + } + int taskPartitionId = taskPartition.getId(); + if (!createdTaskPartitions.contains(taskPartitionId)) { + Set hostRequirement = sourcePartitioningScheme.getNodeRequirement(outputPartitionId) + .map(InternalNode::getHostAndPort) + .map(ImmutableSet::of) + .orElse(ImmutableSet.of()); + assignment.addPartition(new Partition( + taskPartitionId, + new NodeRequirements(catalogRequirement, hostRequirement))); + for (PlanNodeId replicatedSource : replicatedSplits.keySet()) { + assignment.updatePartition(new PartitionUpdate(taskPartitionId, replicatedSource, replicatedSplits.get(replicatedSource), completedSources.contains(replicatedSource))); + } + for (PlanNodeId completedSource : completedSources) { + assignment.updatePartition(new PartitionUpdate(taskPartitionId, completedSource, ImmutableList.of(), true)); + } + createdTaskPartitions.add(taskPartitionId); + } + assignment.updatePartition(new PartitionUpdate(taskPartitionId, planNodeId, splits.get(outputPartitionId), false)); + } + } + + if (noMoreSplits) { + completedSources.add(planNodeId); + for (Integer taskPartition : createdTaskPartitions) { + assignment.updatePartition(new PartitionUpdate(taskPartition, planNodeId, ImmutableList.of(), true)); + } + if (completedSources.containsAll(allSources)) { + if (createdTaskPartitions.isEmpty()) { + assignment.addPartition(new Partition( + 0, + new NodeRequirements(catalogRequirement, ImmutableSet.of()))); + for (PlanNodeId replicatedSource : replicatedSplits.keySet()) { + assignment.updatePartition(new PartitionUpdate(0, replicatedSource, replicatedSplits.get(replicatedSource), true)); + } + for (PlanNodeId completedSource : completedSources) { + assignment.updatePartition(new PartitionUpdate(0, completedSource, 
ImmutableList.of(), true)); + } + createdTaskPartitions.add(0); + } + for (Integer taskPartition : createdTaskPartitions) { + assignment.sealPartition(taskPartition); + } + assignment.setNoMorePartitions(); + replicatedSplits.clear(); + } + } + + return assignment.build(); + } + + @Override + public AssignmentResult finish() + { + checkState(!createdTaskPartitions.isEmpty(), "createdTaskPartitions is not expected to be empty"); + return AssignmentResult.builder().build(); + } + + private static Map createOutputPartitionToTaskPartition( + FaultTolerantPartitioningScheme sourcePartitioningScheme, + Set partitionedSources, + Map outputDataSizeEstimates, + boolean preserveOutputPartitioning, + long targetPartitionSizeInBytes) + { + int partitionCount = sourcePartitioningScheme.getPartitionCount(); + if (sourcePartitioningScheme.isExplicitPartitionToNodeMappingPresent() || + partitionedSources.isEmpty() || + !outputDataSizeEstimates.keySet().containsAll(partitionedSources) || + preserveOutputPartitioning) { + // if bucket scheme is set explicitly or if estimates are missing create one task partition per output partition + return IntStream.range(0, partitionCount) + .boxed() + .collect(toImmutableMap(Function.identity(), (key) -> new TaskPartition())); + } + + List partitionedSourcesEstimates = outputDataSizeEstimates.entrySet().stream() + .filter(entry -> partitionedSources.contains(entry.getKey())) + .map(Map.Entry::getValue) + .collect(toImmutableList()); + OutputDataSizeEstimate mergedEstimate = OutputDataSizeEstimate.merge(partitionedSourcesEstimates); + ImmutableMap.Builder result = ImmutableMap.builder(); + PriorityQueue assignments = new PriorityQueue<>(); + assignments.add(new PartitionAssignment(new TaskPartition(), 0)); + for (int outputPartitionId = 0; outputPartitionId < partitionCount; outputPartitionId++) { + long outputPartitionSize = mergedEstimate.getPartitionSizeInBytes(outputPartitionId); + if (assignments.peek().assignedDataSizeInBytes() + 
outputPartitionSize > targetPartitionSizeInBytes + && assignments.size() < partitionCount) { + assignments.add(new PartitionAssignment(new TaskPartition(), 0)); + } + PartitionAssignment assignment = assignments.poll(); + result.put(outputPartitionId, assignment.taskPartition()); + assignments.add(new PartitionAssignment(assignment.taskPartition(), assignment.assignedDataSizeInBytes() + outputPartitionSize)); + } + return result.buildOrThrow(); + } + + private record PartitionAssignment(TaskPartition taskPartition, long assignedDataSizeInBytes) + implements Comparable + { + public PartitionAssignment(TaskPartition taskPartition, long assignedDataSizeInBytes) + { + this.taskPartition = requireNonNull(taskPartition, "taskPartition is null"); + this.assignedDataSizeInBytes = assignedDataSizeInBytes; + } + + @Override + public int compareTo(PartitionAssignment other) + { + return Long.compare(assignedDataSizeInBytes, other.assignedDataSizeInBytes); + } + } + + private static class TaskPartition + { + private OptionalInt id = OptionalInt.empty(); + + public void assignId(int id) + { + this.id = OptionalInt.of(id); + } + + public boolean isIdAssigned() + { + return id.isPresent(); + } + + public int getId() + { + checkState(id.isPresent(), "id is expected to be assigned"); + return id.getAsInt(); + } + } +} diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/OutputDataSizeEstimate.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/OutputDataSizeEstimate.java new file mode 100644 index 000000000000..b1fd57610faa --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/OutputDataSizeEstimate.java @@ -0,0 +1,61 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.execution.scheduler; + +import com.google.common.primitives.ImmutableLongArray; + +import java.util.Collection; + +import static com.google.common.base.Preconditions.checkArgument; +import static java.util.Objects.requireNonNull; + +public class OutputDataSizeEstimate +{ + private final ImmutableLongArray partitionDataSizes; + + public OutputDataSizeEstimate(ImmutableLongArray partitionDataSizes) + { + this.partitionDataSizes = requireNonNull(partitionDataSizes, "partitionDataSizes is null"); + } + + public long getPartitionSizeInBytes(int partitionId) + { + return partitionDataSizes.get(partitionId); + } + + public static OutputDataSizeEstimate merge(Collection estimates) + { + int partitionCount = getPartitionCount(estimates); + long[] merged = new long[partitionCount]; + for (OutputDataSizeEstimate estimate : estimates) { + for (int partitionId = 0; partitionId < partitionCount; partitionId++) { + merged[partitionId] += estimate.getPartitionSizeInBytes(partitionId); + } + } + return new OutputDataSizeEstimate(ImmutableLongArray.copyOf(merged)); + } + + private static int getPartitionCount(Collection estimates) + { + int[] partitionCounts = estimates.stream() + .mapToInt(estimate -> estimate.partitionDataSizes.length()) + .distinct() + .toArray(); + checkArgument(partitionCounts.length <= 1, "partition count is expected to match"); + if (partitionCounts.length == 0) { + return 0; + } + return partitionCounts[0]; + } +} diff --git 
a/core/trino-main/src/main/java/io/trino/execution/scheduler/SingleDistributionSplitAssigner.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/SingleDistributionSplitAssigner.java new file mode 100644 index 000000000000..c58e7e9f40de --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/SingleDistributionSplitAssigner.java @@ -0,0 +1,91 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.execution.scheduler; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; +import com.google.common.collect.ListMultimap; +import io.trino.execution.scheduler.EventDrivenTaskSource.Partition; +import io.trino.execution.scheduler.EventDrivenTaskSource.PartitionUpdate; +import io.trino.metadata.Split; +import io.trino.spi.HostAddress; +import io.trino.sql.planner.plan.PlanNodeId; + +import java.util.HashSet; +import java.util.Optional; +import java.util.Set; + +import static com.google.common.base.Preconditions.checkState; +import static java.util.Objects.requireNonNull; + +class SingleDistributionSplitAssigner + implements SplitAssigner +{ + private final Set hostRequirement; + private final Set allSources; + + private boolean partitionAdded; + private final Set completedSources = new HashSet<>(); + + SingleDistributionSplitAssigner(Set hostRequirement, Set allSources) + { + this.hostRequirement = ImmutableSet.copyOf(requireNonNull(hostRequirement, "hostRequirement is 
null")); + this.allSources = ImmutableSet.copyOf(requireNonNull(allSources, "allSources is null")); + } + + @Override + public AssignmentResult assign(PlanNodeId planNodeId, ListMultimap splits, boolean noMoreSplits) + { + AssignmentResult.Builder assignment = AssignmentResult.builder(); + if (!partitionAdded) { + partitionAdded = true; + assignment.addPartition(new Partition(0, new NodeRequirements(Optional.empty(), hostRequirement))); + assignment.setNoMorePartitions(); + } + if (!splits.isEmpty()) { + checkState(!completedSources.contains(planNodeId), "source is finished: %s", planNodeId); + assignment.updatePartition(new PartitionUpdate( + 0, + planNodeId, + ImmutableList.copyOf(splits.values()), + false)); + } + if (noMoreSplits) { + assignment.updatePartition(new PartitionUpdate( + 0, + planNodeId, + ImmutableList.of(), + true)); + completedSources.add(planNodeId); + } + if (completedSources.containsAll(allSources)) { + assignment.sealPartition(0); + } + return assignment.build(); + } + + @Override + public AssignmentResult finish() + { + AssignmentResult.Builder result = AssignmentResult.builder(); + if (!partitionAdded) { + partitionAdded = true; + result + .addPartition(new Partition(0, new NodeRequirements(Optional.empty(), hostRequirement))) + .sealPartition(0) + .setNoMorePartitions(); + } + return result.build(); + } +} diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/SplitAssigner.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/SplitAssigner.java new file mode 100644 index 000000000000..143419dcbdf3 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/SplitAssigner.java @@ -0,0 +1,119 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.execution.scheduler; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ListMultimap; +import com.google.common.primitives.ImmutableIntArray; +import com.google.errorprone.annotations.CanIgnoreReturnValue; +import io.trino.execution.scheduler.EventDrivenTaskSource.Partition; +import io.trino.execution.scheduler.EventDrivenTaskSource.PartitionUpdate; +import io.trino.metadata.Split; +import io.trino.sql.planner.plan.PlanNodeId; + +import javax.annotation.concurrent.NotThreadSafe; + +import java.util.List; + +import static java.util.Objects.requireNonNull; + +/** + * An implementation is not required to be thread safe + */ +@NotThreadSafe +interface SplitAssigner +{ + AssignmentResult assign(PlanNodeId planNodeId, ListMultimap splits, boolean noMoreSplits); + + AssignmentResult finish(); + + record AssignmentResult( + List partitionsAdded, + boolean noMorePartitions, + List partitionUpdates, + ImmutableIntArray sealedPartitions) + { + public AssignmentResult + { + partitionsAdded = ImmutableList.copyOf(requireNonNull(partitionsAdded, "partitionsAdded is null")); + partitionUpdates = ImmutableList.copyOf(requireNonNull(partitionUpdates, "partitionUpdates is null")); + } + + public void update(EventDrivenTaskSource.Callback callback) + { + if (!partitionsAdded.isEmpty()) { + callback.partitionsAdded(partitionsAdded); + } + if (noMorePartitions) { + callback.noMorePartitions(); + } + if (!partitionUpdates.isEmpty()) { + callback.partitionsUpdated(partitionUpdates); + } + if 
(!sealedPartitions.isEmpty()) { + callback.partitionsSealed(sealedPartitions); + } + } + + public static AssignmentResult.Builder builder() + { + return new AssignmentResult.Builder(); + } + + public static class Builder + { + private final ImmutableList.Builder partitionsAdded = ImmutableList.builder(); + private boolean noMorePartitions; + private final ImmutableList.Builder partitionUpdates = ImmutableList.builder(); + private final ImmutableIntArray.Builder sealedPartitions = ImmutableIntArray.builder(); + + @CanIgnoreReturnValue + public AssignmentResult.Builder addPartition(Partition partition) + { + partitionsAdded.add(partition); + return this; + } + + @CanIgnoreReturnValue + public AssignmentResult.Builder setNoMorePartitions() + { + this.noMorePartitions = true; + return this; + } + + @CanIgnoreReturnValue + public AssignmentResult.Builder updatePartition(PartitionUpdate partitionUpdate) + { + partitionUpdates.add(partitionUpdate); + return this; + } + + @CanIgnoreReturnValue + public AssignmentResult.Builder sealPartition(int partitionId) + { + sealedPartitions.add(partitionId); + return this; + } + + public AssignmentResult build() + { + return new AssignmentResult( + partitionsAdded.build(), + noMorePartitions, + partitionUpdates.build(), + sealedPartitions.build()); + } + } + } +} diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/StageManager.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/StageManager.java index f539bdc6963d..3affd4c3367b 100644 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/StageManager.java +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/StageManager.java @@ -27,17 +27,11 @@ import io.trino.execution.StageInfo; import io.trino.execution.TableInfo; import io.trino.execution.TaskId; -import io.trino.metadata.CatalogInfo; import io.trino.metadata.Metadata; -import io.trino.metadata.TableProperties; -import io.trino.metadata.TableSchema; import 
io.trino.spi.QueryId; import io.trino.sql.planner.PlanFragment; import io.trino.sql.planner.SubPlan; import io.trino.sql.planner.plan.PlanFragmentId; -import io.trino.sql.planner.plan.PlanNode; -import io.trino.sql.planner.plan.PlanNodeId; -import io.trino.sql.planner.plan.TableScanNode; import java.util.List; import java.util.Map; @@ -51,7 +45,6 @@ import static com.google.common.collect.ImmutableSet.toImmutableSet; import static io.trino.execution.BasicStageStats.aggregateBasicStageStats; import static io.trino.execution.SqlStage.createSqlStage; -import static io.trino.sql.planner.optimizations.PlanNodeSearcher.searchFrom; import static java.lang.Integer.parseInt; import static java.util.Objects.requireNonNull; import static java.util.concurrent.TimeUnit.MILLISECONDS; @@ -91,7 +84,7 @@ static StageManager create( SqlStage stage = createSqlStage( getStageId(session.getQueryId(), fragment.getId()), fragment, - extractTableInfo(session, metadata, fragment), + TableInfo.extract(session, metadata, fragment), taskFactory, session, summarizeTaskInfo, @@ -129,27 +122,6 @@ static StageManager create( return stageManager; } - private static Map extractTableInfo(Session session, Metadata metadata, PlanFragment fragment) - { - return searchFrom(fragment.getRoot()) - .where(TableScanNode.class::isInstance) - .findAll() - .stream() - .map(TableScanNode.class::cast) - .collect(toImmutableMap(PlanNode::getId, node -> getTableInfo(session, metadata, node))); - } - - private static TableInfo getTableInfo(Session session, Metadata metadata, TableScanNode node) - { - TableSchema tableSchema = metadata.getTableSchema(session, node.getTable()); - TableProperties tableProperties = metadata.getTableProperties(session, node.getTable()); - Optional connectorName = metadata.listCatalogs(session).stream() - .filter(catalogInfo -> catalogInfo.getCatalogName().equals(tableSchema.getCatalogName())) - .map(CatalogInfo::getConnectorName) - .findFirst(); - return new TableInfo(connectorName, 
tableSchema.getQualifiedName(), tableProperties.getPredicate()); - } - private static StageId getStageId(QueryId queryId, PlanFragmentId fragmentId) { // TODO: refactor fragment id to be based on an integer diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/StageTaskSourceFactory.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/StageTaskSourceFactory.java index ab9e2eadf5e3..c80e12f77d0f 100644 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/StageTaskSourceFactory.java +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/StageTaskSourceFactory.java @@ -102,6 +102,10 @@ import static io.trino.sql.planner.plan.ExchangeNode.Type.REPLICATE; import static java.util.Objects.requireNonNull; +/** + * Deprecated in favor of {@link EventDrivenTaskSourceFactory} + */ +@Deprecated public class StageTaskSourceFactory implements TaskSourceFactory { diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/TaskSourceFactory.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/TaskSourceFactory.java index 85bd6e238da1..34661234f083 100644 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/TaskSourceFactory.java +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/TaskSourceFactory.java @@ -21,6 +21,10 @@ import java.util.function.LongConsumer; +/** + * Deprecated in favor of {@link EventDrivenTaskSourceFactory} + */ +@Deprecated public interface TaskSourceFactory { TaskSource create( diff --git a/core/trino-main/src/main/java/io/trino/server/CoordinatorModule.java b/core/trino-main/src/main/java/io/trino/server/CoordinatorModule.java index e70b2028b5ec..610d3b09a11c 100644 --- a/core/trino-main/src/main/java/io/trino/server/CoordinatorModule.java +++ b/core/trino-main/src/main/java/io/trino/server/CoordinatorModule.java @@ -63,6 +63,7 @@ import io.trino.execution.resourcegroups.ResourceGroupManager; import 
io.trino.execution.scheduler.BinPackingNodeAllocatorService; import io.trino.execution.scheduler.ConstantPartitionMemoryEstimator; +import io.trino.execution.scheduler.EventDrivenTaskSourceFactory; import io.trino.execution.scheduler.FixedCountNodeAllocatorService; import io.trino.execution.scheduler.NodeAllocatorService; import io.trino.execution.scheduler.NodeSchedulerConfig; @@ -326,6 +327,7 @@ protected void setup(Binder binder) newExporter(binder).export(SplitSchedulerStats.class).withGeneratedName(); binder.bind(TaskSourceFactory.class).to(StageTaskSourceFactory.class).in(Scopes.SINGLETON); + binder.bind(EventDrivenTaskSourceFactory.class).in(Scopes.SINGLETON); binder.bind(TaskDescriptorStorage.class).in(Scopes.SINGLETON); newExporter(binder).export(TaskDescriptorStorage.class).withGeneratedName(); diff --git a/core/trino-main/src/test/java/io/trino/execution/TestQueryManagerConfig.java b/core/trino-main/src/test/java/io/trino/execution/TestQueryManagerConfig.java index 54febc8d5bd7..bcdefa55a00a 100644 --- a/core/trino-main/src/test/java/io/trino/execution/TestQueryManagerConfig.java +++ b/core/trino-main/src/test/java/io/trino/execution/TestQueryManagerConfig.java @@ -77,7 +77,8 @@ public void testDefaults() .setFaultTolerantExecutionMaxTaskSplitCount(256) .setFaultTolerantExecutionTaskDescriptorStorageMaxMemory(DataSize.ofBytes(Math.round(AVAILABLE_HEAP_MEMORY * 0.15))) .setFaultTolerantExecutionPartitionCount(50) - .setFaultTolerantPreserveInputPartitionsInWriteStage(true)); + .setFaultTolerantPreserveInputPartitionsInWriteStage(true) + .setFaultTolerantExecutionEventDrivenSchedulerEnabled(true)); } @Test @@ -123,6 +124,7 @@ public void testExplicitPropertyMappings() .put("fault-tolerant-execution-task-descriptor-storage-max-memory", "3GB") .put("fault-tolerant-execution-partition-count", "123") .put("fault-tolerant-execution-preserve-input-partitions-in-write-stage", "false") + .put("experimental.fault-tolerant-execution-event-driven-scheduler-enabled", 
"false") .buildOrThrow(); QueryManagerConfig expected = new QueryManagerConfig() @@ -164,7 +166,8 @@ public void testExplicitPropertyMappings() .setFaultTolerantExecutionMaxTaskSplitCount(22) .setFaultTolerantExecutionTaskDescriptorStorageMaxMemory(DataSize.of(3, GIGABYTE)) .setFaultTolerantExecutionPartitionCount(123) - .setFaultTolerantPreserveInputPartitionsInWriteStage(false); + .setFaultTolerantPreserveInputPartitionsInWriteStage(false) + .setFaultTolerantExecutionEventDrivenSchedulerEnabled(false); assertFullMapping(properties, expected); } diff --git a/core/trino-main/src/test/java/io/trino/execution/buffer/TestSpoolingExchangeOutputBuffer.java b/core/trino-main/src/test/java/io/trino/execution/buffer/TestSpoolingExchangeOutputBuffer.java index 6205ec6505f4..41134ae39587 100644 --- a/core/trino-main/src/test/java/io/trino/execution/buffer/TestSpoolingExchangeOutputBuffer.java +++ b/core/trino-main/src/test/java/io/trino/execution/buffer/TestSpoolingExchangeOutputBuffer.java @@ -22,7 +22,11 @@ import io.trino.execution.StageId; import io.trino.execution.TaskId; import io.trino.memory.context.LocalMemoryContext; +import io.trino.spi.Page; +import io.trino.spi.PageBuilder; import io.trino.spi.QueryId; +import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.TestingBlockEncodingSerde; import io.trino.spi.exchange.ExchangeSink; import io.trino.spi.exchange.ExchangeSinkInstanceHandle; import org.testng.annotations.Test; @@ -38,6 +42,7 @@ import static io.trino.execution.buffer.BufferState.FINISHED; import static io.trino.execution.buffer.BufferState.FLUSHING; import static io.trino.execution.buffer.BufferState.NO_MORE_BUFFERS; +import static io.trino.spi.type.VarcharType.VARCHAR; import static java.util.Objects.requireNonNull; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertFalse; @@ -224,13 +229,13 @@ public void testEnqueueAfterFinish() OutputBuffer outputBuffer = createSpoolingExchangeOutputBuffer(exchangeSink, 2); 
assertEquals(outputBuffer.getState(), NO_MORE_BUFFERS); - outputBuffer.enqueue(0, ImmutableList.of(utf8Slice("page1"))); - outputBuffer.enqueue(1, ImmutableList.of(utf8Slice("page2"), utf8Slice("page3"))); + outputBuffer.enqueue(0, ImmutableList.of(createPage("page1"))); + outputBuffer.enqueue(1, ImmutableList.of(createPage("page2"), createPage("page3"))); ImmutableListMultimap expectedDataBufferState = ImmutableListMultimap.builder() - .put(0, utf8Slice("page1")) - .put(1, utf8Slice("page2")) - .put(1, utf8Slice("page3")) + .put(0, createPage("page1")) + .put(1, createPage("page2")) + .put(1, createPage("page3")) .build(); assertEquals(exchangeSink.getDataBuffer(), expectedDataBufferState); @@ -238,12 +243,12 @@ public void testEnqueueAfterFinish() outputBuffer.setNoMorePages(); assertEquals(outputBuffer.getState(), FLUSHING); // the buffer is flushing, this page is expected to be rejected - outputBuffer.enqueue(0, ImmutableList.of(utf8Slice("page4"))); + outputBuffer.enqueue(0, ImmutableList.of(createPage("page4"))); assertEquals(exchangeSink.getDataBuffer(), expectedDataBufferState); finish.complete(null); assertEquals(outputBuffer.getState(), FINISHED); - outputBuffer.enqueue(0, ImmutableList.of(utf8Slice("page5"))); + outputBuffer.enqueue(0, ImmutableList.of(createPage("page5"))); assertEquals(exchangeSink.getDataBuffer(), expectedDataBufferState); } @@ -257,13 +262,13 @@ public void testEnqueueAfterAbort() OutputBuffer outputBuffer = createSpoolingExchangeOutputBuffer(exchangeSink, 2); assertEquals(outputBuffer.getState(), NO_MORE_BUFFERS); - outputBuffer.enqueue(0, ImmutableList.of(utf8Slice("page1"))); - outputBuffer.enqueue(1, ImmutableList.of(utf8Slice("page2"), utf8Slice("page3"))); + outputBuffer.enqueue(0, ImmutableList.of(createPage("page1"))); + outputBuffer.enqueue(1, ImmutableList.of(createPage("page2"), createPage("page3"))); ImmutableListMultimap expectedDataBufferState = ImmutableListMultimap.builder() - .put(0, utf8Slice("page1")) - .put(1, 
utf8Slice("page2")) - .put(1, utf8Slice("page3")) + .put(0, createPage("page1")) + .put(1, createPage("page2")) + .put(1, createPage("page3")) .build(); assertEquals(exchangeSink.getDataBuffer(), expectedDataBufferState); @@ -271,12 +276,12 @@ public void testEnqueueAfterAbort() outputBuffer.abort(); assertEquals(outputBuffer.getState(), ABORTED); // the buffer is flushing, this page is expected to be rejected - outputBuffer.enqueue(0, ImmutableList.of(utf8Slice("page4"))); + outputBuffer.enqueue(0, ImmutableList.of(createPage("page4"))); assertEquals(exchangeSink.getDataBuffer(), expectedDataBufferState); abort.complete(null); assertEquals(outputBuffer.getState(), ABORTED); - outputBuffer.enqueue(0, ImmutableList.of(utf8Slice("page5"))); + outputBuffer.enqueue(0, ImmutableList.of(createPage("page5"))); assertEquals(exchangeSink.getDataBuffer(), expectedDataBufferState); } @@ -299,6 +304,19 @@ private static void assertBlocked(ListenableFuture blocked) assertFalse(blocked.isDone()); } + private static Slice createPage(String value) + { + PageBuilder pageBuilder = new PageBuilder(ImmutableList.of(VARCHAR)); + pageBuilder.declarePosition(); + Slice valueSlice = utf8Slice(value); + BlockBuilder blockBuilder = pageBuilder.getBlockBuilder(0); + blockBuilder.writeBytes(valueSlice, 0, valueSlice.length()); + blockBuilder.closeEntry(); + Page page = pageBuilder.build(); + PagesSerde pagesSerde = new PagesSerdeFactory(new TestingBlockEncodingSerde(), false).createPagesSerde(); + return pagesSerde.serialize(pagesSerde.newContext(), page); + } + private static class TestingExchangeSink implements ExchangeSink { diff --git a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestArbitraryDistributionSplitAssigner.java b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestArbitraryDistributionSplitAssigner.java new file mode 100644 index 000000000000..dc03725e1cad --- /dev/null +++ 
b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestArbitraryDistributionSplitAssigner.java @@ -0,0 +1,707 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.execution.scheduler; + +import com.google.common.collect.ArrayListMultimap; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableListMultimap; +import com.google.common.collect.ImmutableSet; +import com.google.common.collect.ImmutableSetMultimap; +import com.google.common.collect.ListMultimap; +import com.google.common.collect.Multimaps; +import com.google.common.collect.SetMultimap; +import com.google.common.collect.Sets; +import io.trino.connector.CatalogHandle; +import io.trino.metadata.Split; +import io.trino.spi.HostAddress; +import io.trino.sql.planner.plan.PlanNodeId; +import org.testng.annotations.Test; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.OptionalInt; +import java.util.Set; +import java.util.concurrent.ThreadLocalRandom; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Function; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static java.util.Collections.shuffle; +import static 
java.util.Objects.requireNonNull; +import static org.assertj.core.api.Assertions.assertThat; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertTrue; + +public class TestArbitraryDistributionSplitAssigner +{ + private static final int FUZZ_TESTING_INVOCATION_COUNT = 100; + + private static final CatalogHandle TESTING_CATALOG_HANDLE = CatalogHandle.createRootCatalogHandle("testing"); + private static final int STANDARD_SPLIT_SIZE_IN_BYTES = 1; + + private static final PlanNodeId PARTITIONED_1 = new PlanNodeId("partitioned-1"); + private static final PlanNodeId PARTITIONED_2 = new PlanNodeId("partitioned-2"); + private static final PlanNodeId REPLICATED_1 = new PlanNodeId("replicated-1"); + private static final PlanNodeId REPLICATED_2 = new PlanNodeId("replicated-2"); + + private static final HostAddress HOST_1 = HostAddress.fromParts("localhost", 8081); + private static final HostAddress HOST_2 = HostAddress.fromParts("localhost", 8082); + private static final HostAddress HOST_3 = HostAddress.fromParts("localhost", 8083); + + @Test + public void testEmpty() + { + // single partitioned source + SplitAssigner splitAssigner = createSplitAssigner(ImmutableSet.of(PARTITIONED_1), ImmutableSet.of(), 100, false); + TestingTaskSourceCallback callback = new TestingTaskSourceCallback(); + splitAssigner.assign(PARTITIONED_1, ImmutableListMultimap.of(), true).update(callback); + splitAssigner.finish().update(callback); + List taskDescriptors = callback.getTaskDescriptors(); + assertThat(taskDescriptors).hasSize(1); + assertTaskDescriptor(taskDescriptors.get(0), 0, ImmutableListMultimap.of()); + + // single replicated source + splitAssigner = createSplitAssigner(ImmutableSet.of(), ImmutableSet.of(REPLICATED_1), 100, false); + callback = new TestingTaskSourceCallback(); + splitAssigner.assign(REPLICATED_1, ImmutableListMultimap.of(), true).update(callback); + splitAssigner.finish().update(callback); + taskDescriptors = 
callback.getTaskDescriptors(); + assertThat(taskDescriptors).hasSize(1); + assertTaskDescriptor(taskDescriptors.get(0), 0, ImmutableListMultimap.of()); + + // partitioned and replicated source + splitAssigner = createSplitAssigner(ImmutableSet.of(PARTITIONED_1), ImmutableSet.of(REPLICATED_1), 100, true); + callback = new TestingTaskSourceCallback(); + splitAssigner.assign(REPLICATED_1, ImmutableListMultimap.of(), true).update(callback); + splitAssigner.assign(PARTITIONED_1, ImmutableListMultimap.of(), true).update(callback); + splitAssigner.finish().update(callback); + taskDescriptors = callback.getTaskDescriptors(); + assertThat(taskDescriptors).hasSize(1); + assertTaskDescriptor(taskDescriptors.get(0), 0, ImmutableListMultimap.of()); + + splitAssigner = createSplitAssigner(ImmutableSet.of(PARTITIONED_1), ImmutableSet.of(REPLICATED_1), 100, true); + callback = new TestingTaskSourceCallback(); + splitAssigner.assign(PARTITIONED_1, ImmutableListMultimap.of(), true).update(callback); + splitAssigner.assign(REPLICATED_1, ImmutableListMultimap.of(), true).update(callback); + splitAssigner.finish().update(callback); + taskDescriptors = callback.getTaskDescriptors(); + assertThat(taskDescriptors).hasSize(1); + assertTaskDescriptor(taskDescriptors.get(0), 0, ImmutableListMultimap.of()); + + splitAssigner = createSplitAssigner(ImmutableSet.of(PARTITIONED_1, PARTITIONED_2), ImmutableSet.of(REPLICATED_1, REPLICATED_2), 100, true); + callback = new TestingTaskSourceCallback(); + splitAssigner.assign(REPLICATED_1, ImmutableListMultimap.of(), true).update(callback); + splitAssigner.assign(PARTITIONED_1, ImmutableListMultimap.of(), true).update(callback); + splitAssigner.assign(PARTITIONED_2, ImmutableListMultimap.of(), true).update(callback); + splitAssigner.assign(REPLICATED_2, ImmutableListMultimap.of(), true).update(callback); + splitAssigner.finish().update(callback); + taskDescriptors = callback.getTaskDescriptors(); + assertThat(taskDescriptors).hasSize(1); + 
assertTaskDescriptor(taskDescriptors.get(0), 0, ImmutableListMultimap.of()); + } + + @Test + public void testNoHostRequirement() + { + // no splits + testAssigner( + ImmutableSet.of(PARTITIONED_1), + ImmutableSet.of(), + ImmutableList.of(new SplitBatch(PARTITIONED_1, ImmutableList.of(), true)), + 1, + false); + + // single partitioned source + testAssigner( + ImmutableSet.of(PARTITIONED_1), + ImmutableSet.of(), + ImmutableList.of( + new SplitBatch(PARTITIONED_1, ImmutableList.of(createSplit(1)), true)), + // one split per partition + 1, + true); + testAssigner( + ImmutableSet.of(PARTITIONED_1), + ImmutableSet.of(), + ImmutableList.of( + new SplitBatch(PARTITIONED_1, ImmutableList.of(createSplit(1), createSplit(2)), true)), + 1, + false); + testAssigner( + ImmutableSet.of(PARTITIONED_1), + ImmutableSet.of(), + ImmutableList.of( + new SplitBatch(PARTITIONED_1, ImmutableList.of(createSplit(1), createSplit(2)), false), + new SplitBatch(PARTITIONED_1, ImmutableList.of(createSplit(3), createSplit(4)), false), + new SplitBatch(PARTITIONED_1, ImmutableList.of(createSplit(5)), true)), + // two splits per partition + 2, + true); + + // multiple partitioned sources + testAssigner( + ImmutableSet.of(PARTITIONED_1, PARTITIONED_2), + ImmutableSet.of(), + ImmutableList.of( + new SplitBatch(PARTITIONED_1, ImmutableList.of(createSplit(1)), true), + new SplitBatch(PARTITIONED_2, ImmutableList.of(createSplit(2)), true)), + 1, + false); + testAssigner( + ImmutableSet.of(PARTITIONED_1, PARTITIONED_2), + ImmutableSet.of(), + ImmutableList.of( + new SplitBatch(PARTITIONED_1, ImmutableList.of(createSplit(1)), false), + new SplitBatch(PARTITIONED_1, ImmutableList.of(createSplit(2), createSplit(3)), false), + new SplitBatch(PARTITIONED_2, ImmutableList.of(createSplit(4)), false), + new SplitBatch(PARTITIONED_1, ImmutableList.of(createSplit(5)), true), + new SplitBatch(PARTITIONED_2, ImmutableList.of(createSplit(6)), true)), + 1, + true); + testAssigner( + ImmutableSet.of(PARTITIONED_1, 
PARTITIONED_2), + ImmutableSet.of(), + ImmutableList.of( + new SplitBatch(PARTITIONED_2, ImmutableList.of(createSplit(1)), true), + new SplitBatch(PARTITIONED_1, ImmutableList.of(createSplit(2), createSplit(3)), false), + new SplitBatch(PARTITIONED_1, ImmutableList.of(createSplit(4)), false), + new SplitBatch(PARTITIONED_1, ImmutableList.of(createSplit(5)), false), + new SplitBatch(PARTITIONED_1, ImmutableList.of(createSplit(6)), true)), + 2, + false); + + // single replicated source + testAssigner( + ImmutableSet.of(PARTITIONED_1, PARTITIONED_2), + ImmutableSet.of(REPLICATED_1), + ImmutableList.of( + new SplitBatch(PARTITIONED_1, ImmutableList.of(createSplit(1)), true), + new SplitBatch(REPLICATED_1, ImmutableList.of(createSplit(2)), true), + new SplitBatch(PARTITIONED_2, ImmutableList.of(createSplit(3)), true)), + 1, + true); + testAssigner( + ImmutableSet.of(PARTITIONED_1, PARTITIONED_2), + ImmutableSet.of(REPLICATED_1), + ImmutableList.of( + new SplitBatch(REPLICATED_1, ImmutableList.of(createSplit(1)), false), + new SplitBatch(REPLICATED_1, ImmutableList.of(createSplit(2)), false), + new SplitBatch(PARTITIONED_1, ImmutableList.of(createSplit(3)), true), + new SplitBatch(PARTITIONED_2, ImmutableList.of(createSplit(4)), true), + new SplitBatch(REPLICATED_1, ImmutableList.of(createSplit(5)), true)), + 1, + false); + testAssigner( + ImmutableSet.of(PARTITIONED_1, PARTITIONED_2), + ImmutableSet.of(REPLICATED_1), + ImmutableList.of( + new SplitBatch(PARTITIONED_1, ImmutableList.of(createSplit(1)), true), + new SplitBatch(PARTITIONED_2, ImmutableList.of(createSplit(2)), true), + new SplitBatch(REPLICATED_1, ImmutableList.of(createSplit(3)), true)), + 2, + false); + + // multiple replicated sources + testAssigner( + ImmutableSet.of(PARTITIONED_1, PARTITIONED_2), + ImmutableSet.of(REPLICATED_1, REPLICATED_2), + ImmutableList.of( + new SplitBatch(PARTITIONED_1, ImmutableList.of(createSplit(1)), true), + new SplitBatch(REPLICATED_1, ImmutableList.of(createSplit(2)), 
true), + new SplitBatch(REPLICATED_2, ImmutableList.of(createSplit(3)), true), + new SplitBatch(PARTITIONED_2, ImmutableList.of(createSplit(4)), true)), + 1, + true); + testAssigner( + ImmutableSet.of(PARTITIONED_1, PARTITIONED_2), + ImmutableSet.of(REPLICATED_1, REPLICATED_2), + ImmutableList.of( + new SplitBatch(PARTITIONED_1, ImmutableList.of(createSplit(1)), true), + new SplitBatch(REPLICATED_1, ImmutableList.of(createSplit(2)), true), + new SplitBatch(PARTITIONED_2, ImmutableList.of(createSplit(3)), true), + new SplitBatch(REPLICATED_2, ImmutableList.of(createSplit(4)), true)), + 1, + true); + testAssigner( + ImmutableSet.of(PARTITIONED_1, PARTITIONED_2), + ImmutableSet.of(REPLICATED_1, REPLICATED_2), + ImmutableList.of( + new SplitBatch(PARTITIONED_1, ImmutableList.of(createSplit(1)), false), + new SplitBatch(REPLICATED_1, ImmutableList.of(createSplit(2), createSplit(3), createSplit(4), createSplit(5)), false), + new SplitBatch(PARTITIONED_2, ImmutableList.of(createSplit(6), createSplit(7)), false), + new SplitBatch(PARTITIONED_2, ImmutableList.of(createSplit(8), createSplit(9)), false), + new SplitBatch(REPLICATED_1, ImmutableList.of(createSplit(10), createSplit(11), createSplit(12), createSplit(13)), false), + new SplitBatch(PARTITIONED_2, ImmutableList.of(createSplit(14), createSplit(15)), false), + new SplitBatch(REPLICATED_1, ImmutableList.of(createSplit(16), createSplit(17), createSplit(18), createSplit(19)), true), + new SplitBatch(PARTITIONED_1, ImmutableList.of(createSplit(20)), false), + new SplitBatch(REPLICATED_2, ImmutableList.of(createSplit(21), createSplit(22)), false), + new SplitBatch(PARTITIONED_1, ImmutableList.of(createSplit(23)), true), + new SplitBatch(REPLICATED_2, ImmutableList.of(createSplit(24), createSplit(25)), true), + new SplitBatch(PARTITIONED_2, ImmutableList.of(createSplit(26), createSplit(27)), true)), + 3, + true); + } + + @Test + public void testWithHostRequirement() + { + // single partitioned source + testAssigner( + 
ImmutableSet.of(PARTITIONED_1), + ImmutableSet.of(), + ImmutableList.of( + new SplitBatch(PARTITIONED_1, ImmutableList.of(createSplit(1, ImmutableList.of(HOST_1))), true)), + // one split per partition + 1, + true); + testAssigner( + ImmutableSet.of(PARTITIONED_1), + ImmutableSet.of(), + ImmutableList.of( + new SplitBatch(PARTITIONED_1, ImmutableList.of(createSplit(1, ImmutableList.of(HOST_1, HOST_2))), true)), + // one split per partition + 1, + false); + testAssigner( + ImmutableSet.of(PARTITIONED_1), + ImmutableSet.of(), + ImmutableList.of( + new SplitBatch(PARTITIONED_1, ImmutableList.of(createSplit(1, ImmutableList.of(HOST_1, HOST_2)), createSplit(2, ImmutableList.of(HOST_2))), true)), + 1, + true); + testAssigner( + ImmutableSet.of(PARTITIONED_1), + ImmutableSet.of(), + ImmutableList.of( + new SplitBatch(PARTITIONED_1, ImmutableList.of(createSplit(1, ImmutableList.of(HOST_1, HOST_2))), false), + new SplitBatch(PARTITIONED_1, ImmutableList.of(createSplit(2, ImmutableList.of(HOST_1))), true)), + // two splits per partition + 2, + false); + + testAssigner( + ImmutableSet.of(PARTITIONED_1), + ImmutableSet.of(), + ImmutableList.of( + new SplitBatch(PARTITIONED_1, ImmutableList.of(createSplit(1, ImmutableList.of(HOST_1, HOST_2)), createSplit(2, ImmutableList.of(HOST_1, HOST_2))), false), + new SplitBatch(PARTITIONED_1, ImmutableList.of(createSplit(3, ImmutableList.of(HOST_3)), createSplit(4, ImmutableList.of(HOST_1))), false), + new SplitBatch(PARTITIONED_1, ImmutableList.of(createSplit(5)), true)), + // two splits per partition + 2, + true); + + // multiple partitioned sources + testAssigner( + ImmutableSet.of(PARTITIONED_1, PARTITIONED_2), + ImmutableSet.of(), + ImmutableList.of( + new SplitBatch(PARTITIONED_1, ImmutableList.of(createSplit(1, ImmutableList.of(HOST_3))), true), + new SplitBatch(PARTITIONED_2, ImmutableList.of(createSplit(2, ImmutableList.of(HOST_3))), true)), + 1, + false); + testAssigner( + ImmutableSet.of(PARTITIONED_1, PARTITIONED_2), + 
ImmutableSet.of(), + ImmutableList.of( + new SplitBatch(PARTITIONED_1, ImmutableList.of(createSplit(1, ImmutableList.of(HOST_3))), false), + new SplitBatch(PARTITIONED_1, ImmutableList.of(createSplit(2, ImmutableList.of(HOST_3)), createSplit(3, ImmutableList.of(HOST_2))), false), + new SplitBatch(PARTITIONED_2, ImmutableList.of(createSplit(4, ImmutableList.of(HOST_1))), false), + new SplitBatch(PARTITIONED_1, ImmutableList.of(createSplit(5)), true), + new SplitBatch(PARTITIONED_2, ImmutableList.of(createSplit(6, ImmutableList.of(HOST_3))), true)), + 1, + true); + testAssigner( + ImmutableSet.of(PARTITIONED_1, PARTITIONED_2), + ImmutableSet.of(), + ImmutableList.of( + new SplitBatch(PARTITIONED_2, ImmutableList.of(createSplit(1, ImmutableList.of(HOST_1, HOST_2))), true), + new SplitBatch(PARTITIONED_1, ImmutableList.of(createSplit(2), createSplit(3, ImmutableList.of(HOST_3))), false), + new SplitBatch(PARTITIONED_1, ImmutableList.of(createSplit(4, ImmutableList.of(HOST_1, HOST_2))), false), + new SplitBatch(PARTITIONED_1, ImmutableList.of(createSplit(5, ImmutableList.of(HOST_3))), false), + new SplitBatch(PARTITIONED_1, ImmutableList.of(createSplit(6, ImmutableList.of(HOST_1, HOST_2))), true)), + 2, + false); + + // single replicated source + testAssigner( + ImmutableSet.of(PARTITIONED_1, PARTITIONED_2), + ImmutableSet.of(REPLICATED_1), + ImmutableList.of( + new SplitBatch(PARTITIONED_1, ImmutableList.of(createSplit(1, ImmutableList.of(HOST_3))), true), + new SplitBatch(REPLICATED_1, ImmutableList.of(createSplit(2)), true), + new SplitBatch(PARTITIONED_2, ImmutableList.of(createSplit(3, ImmutableList.of(HOST_2))), true)), + 1, + true); + testAssigner( + ImmutableSet.of(PARTITIONED_1, PARTITIONED_2), + ImmutableSet.of(REPLICATED_1), + ImmutableList.of( + new SplitBatch(REPLICATED_1, ImmutableList.of(createSplit(1)), false), + new SplitBatch(REPLICATED_1, ImmutableList.of(createSplit(2)), false), + new SplitBatch(PARTITIONED_1, ImmutableList.of(createSplit(3, 
ImmutableList.of(HOST_3))), true), + new SplitBatch(PARTITIONED_2, ImmutableList.of(createSplit(4, ImmutableList.of(HOST_3))), true), + new SplitBatch(REPLICATED_1, ImmutableList.of(createSplit(5)), true)), + 1, + false); + testAssigner( + ImmutableSet.of(PARTITIONED_1, PARTITIONED_2), + ImmutableSet.of(REPLICATED_1), + ImmutableList.of( + new SplitBatch(PARTITIONED_1, ImmutableList.of(createSplit(1, ImmutableList.of(HOST_1))), true), + new SplitBatch(PARTITIONED_2, ImmutableList.of(createSplit(2, ImmutableList.of(HOST_2))), true), + new SplitBatch(REPLICATED_1, ImmutableList.of(createSplit(3)), true)), + 2, + true); + + // multiple replicated sources + testAssigner( + ImmutableSet.of(PARTITIONED_1, PARTITIONED_2), + ImmutableSet.of(REPLICATED_1, REPLICATED_2), + ImmutableList.of( + new SplitBatch(PARTITIONED_1, ImmutableList.of(createSplit(1, ImmutableList.of(HOST_1))), true), + new SplitBatch(REPLICATED_1, ImmutableList.of(createSplit(2)), true), + new SplitBatch(REPLICATED_2, ImmutableList.of(createSplit(3)), true), + new SplitBatch(PARTITIONED_2, ImmutableList.of(createSplit(4, ImmutableList.of(HOST_1))), true)), + 1, + false); + testAssigner( + ImmutableSet.of(PARTITIONED_1, PARTITIONED_2), + ImmutableSet.of(REPLICATED_1, REPLICATED_2), + ImmutableList.of( + new SplitBatch(PARTITIONED_1, ImmutableList.of(createSplit(1, ImmutableList.of(HOST_2))), true), + new SplitBatch(REPLICATED_1, ImmutableList.of(createSplit(2)), true), + new SplitBatch(PARTITIONED_2, ImmutableList.of(createSplit(3, ImmutableList.of(HOST_2))), true), + new SplitBatch(REPLICATED_2, ImmutableList.of(createSplit(4)), true)), + 1, + true); + testAssigner( + ImmutableSet.of(PARTITIONED_1, PARTITIONED_2), + ImmutableSet.of(REPLICATED_1, REPLICATED_2), + ImmutableList.of( + new SplitBatch(PARTITIONED_1, ImmutableList.of(createSplit(1, ImmutableList.of(HOST_2))), false), + new SplitBatch(REPLICATED_1, ImmutableList.of(createSplit(2), createSplit(3), createSplit(4), createSplit(5)), false), + new 
SplitBatch(PARTITIONED_2, ImmutableList.of(createSplit(6, ImmutableList.of(HOST_2, HOST_3)), createSplit(7)), false), + new SplitBatch(PARTITIONED_2, ImmutableList.of(createSplit(8), createSplit(9, ImmutableList.of(HOST_2, HOST_3))), false), + new SplitBatch(REPLICATED_1, ImmutableList.of(createSplit(10), createSplit(11), createSplit(12), createSplit(13)), false), + new SplitBatch(PARTITIONED_2, ImmutableList.of(createSplit(14), createSplit(15, ImmutableList.of(HOST_1, HOST_3))), false), + new SplitBatch(REPLICATED_1, ImmutableList.of(createSplit(16), createSplit(17), createSplit(18), createSplit(19)), true), + new SplitBatch(PARTITIONED_1, ImmutableList.of(createSplit(20, ImmutableList.of(HOST_1, HOST_3))), false), + new SplitBatch(REPLICATED_2, ImmutableList.of(createSplit(21), createSplit(22)), false), + new SplitBatch(PARTITIONED_1, ImmutableList.of(createSplit(23, ImmutableList.of(HOST_2, HOST_3))), true), + new SplitBatch(REPLICATED_2, ImmutableList.of(createSplit(24), createSplit(25)), true), + new SplitBatch(PARTITIONED_2, ImmutableList.of(createSplit(26), createSplit(27, ImmutableList.of(HOST_1, HOST_3))), true)), + 3, + false); + } + + @Test + public void fuzzTestingNoHostRequirement() + { + for (int i = 0; i < FUZZ_TESTING_INVOCATION_COUNT; i++) { + fuzzTesting(false); + } + } + + @Test + public void fuzzTestingWithHostRequirement() + { + for (int i = 0; i < FUZZ_TESTING_INVOCATION_COUNT; i++) { + fuzzTesting(true); + } + } + + private void fuzzTesting(boolean withHostRequirements) + { + Set partitionedSources = new HashSet<>(); + Set replicatedSources = new HashSet<>(); + partitionedSources.add(PARTITIONED_1); + if (ThreadLocalRandom.current().nextBoolean()) { + partitionedSources.add(PARTITIONED_2); + } + if (ThreadLocalRandom.current().nextDouble() > 0.2) { + replicatedSources.add(REPLICATED_1); + } + if (ThreadLocalRandom.current().nextDouble() > 0.5) { + replicatedSources.add(REPLICATED_2); + } + Set allSources = ImmutableSet.builder() + 
.addAll(partitionedSources) + .addAll(replicatedSources) + .build(); + + List batches = new ArrayList<>(); + Map splitCount = allSources.stream() + .collect(Collectors.toMap(Function.identity(), planNodeId -> ThreadLocalRandom.current().nextInt(100))); + + AtomicInteger nextSplitId = new AtomicInteger(); + while (!splitCount.isEmpty()) { + List remainingSources = ImmutableList.copyOf(splitCount.keySet()); + PlanNodeId source = remainingSources.get(ThreadLocalRandom.current().nextInt(remainingSources.size())); + int batchSize = ThreadLocalRandom.current().nextInt(5); + int remaining = splitCount.compute(source, (key, value) -> value - batchSize); + if (remaining <= 0) { + splitCount.remove(source); + } + List splits = IntStream.range(0, batchSize) + .mapToObj(value -> generateSplit(nextSplitId, replicatedSources.contains(source), withHostRequirements)) + .collect(toImmutableList()); + batches.add(new SplitBatch(source, splits, remaining <= 0)); + } + + int splitsPerPartition = ThreadLocalRandom.current().nextInt(3); + testAssigner(partitionedSources, replicatedSources, batches, splitsPerPartition, ThreadLocalRandom.current().nextBoolean()); + } + + private Split generateSplit(AtomicInteger nextSplitId, boolean replicated, boolean withHostRequirements) + { + if (replicated || !withHostRequirements || ThreadLocalRandom.current().nextDouble() > 0.5) { + return createSplit(nextSplitId.getAndIncrement()); + } + List allHosts = new ArrayList<>(); + allHosts.add(HOST_1); + allHosts.add(HOST_2); + allHosts.add(HOST_3); + shuffle(allHosts); + List addresses = ImmutableList.copyOf(allHosts.subList(0, ThreadLocalRandom.current().nextInt(1, allHosts.size()))); + return createSplit(nextSplitId.getAndIncrement(), addresses); + } + + private static void testAssigner( + Set partitionedSources, + Set replicatedSources, + List batches, + int partitionedSplitsPerPartition, + boolean verifyMaxTaskSplitCount) + { + SplitAssigner splitAssigner = createSplitAssigner(partitionedSources, 
replicatedSources, partitionedSplitsPerPartition, verifyMaxTaskSplitCount); + TestingTaskSourceCallback callback = new TestingTaskSourceCallback(); + ListMultimap expectedReplicatedSplits = ArrayListMultimap.create(); + Map> expectedPartitionedSplits = new HashMap<>(); + Set finishedReplicatedSources = new HashSet<>(); + Map, PartitionAssignment> currentSplitAssignments = new HashMap<>(); + AtomicInteger nextPartitionId = new AtomicInteger(); + for (SplitBatch batch : batches) { + PlanNodeId planNodeId = batch.getPlanNodeId(); + List splits = batch.getSplits(); + boolean noMoreSplits = batch.isNoMoreSplits(); + boolean replicated = replicatedSources.contains(planNodeId); + if (replicated) { + expectedReplicatedSplits.putAll(planNodeId, splits); + if (noMoreSplits) { + finishedReplicatedSources.add(planNodeId); + } + } + else { + for (Split split : splits) { + Optional hostRequirement = Optional.empty(); + if (!split.isRemotelyAccessible()) { + int splitCount = Integer.MAX_VALUE; + for (HostAddress hostAddress : split.getConnectorSplit().getAddresses()) { + PartitionAssignment currentAssignment = currentSplitAssignments.get(Optional.of(hostAddress)); + if (currentAssignment == null) { + hostRequirement = Optional.of(hostAddress); + break; + } + else if (currentAssignment.getSplits().size() < splitCount) { + splitCount = currentAssignment.getSplits().size(); + hostRequirement = Optional.of(hostAddress); + } + } + } + PartitionAssignment currentAssignment = currentSplitAssignments.get(hostRequirement); + if (currentAssignment != null && currentAssignment.getSplits().size() + 1 > partitionedSplitsPerPartition) { + expectedPartitionedSplits.computeIfAbsent(currentAssignment.getPartitionId(), key -> ArrayListMultimap.create()).putAll(currentAssignment.getSplits()); + currentSplitAssignments.remove(hostRequirement); + } + currentSplitAssignments + .computeIfAbsent(hostRequirement, key -> new PartitionAssignment(nextPartitionId.getAndIncrement())) + .getSplits() + 
.put(planNodeId, split); + } + } + splitAssigner.assign(planNodeId, createSplitsMultimap(splits), noMoreSplits).update(callback); + callback.checkContainsSplits(planNodeId, splits, replicated); + + if (finishedReplicatedSources.containsAll(replicatedSources)) { + Set openAssignments = currentSplitAssignments.values().stream() + .map(PartitionAssignment::getPartitionId) + .collect(toImmutableSet()); + for (int partitionId = 0; partitionId < nextPartitionId.get(); partitionId++) { + if (!openAssignments.contains(partitionId)) { + assertTrue(callback.isSealed(partitionId)); + } + } + } + } + splitAssigner.finish().update(callback); + for (PartitionAssignment assignment : currentSplitAssignments.values()) { + expectedPartitionedSplits.computeIfAbsent(assignment.getPartitionId(), key -> ArrayListMultimap.create()).putAll(assignment.getSplits()); + } + List taskDescriptors = callback.getTaskDescriptors(); + int expectedPartitionCount = nextPartitionId.get(); + if (expectedPartitionCount == 0) { + // a single partition is always created + assertThat(taskDescriptors).hasSize(1); + TaskDescriptor taskDescriptor = taskDescriptors.get(0); + assertTaskDescriptor( + taskDescriptor, + taskDescriptor.getPartitionId(), + ImmutableListMultimap.builder() + .putAll(expectedReplicatedSplits) + .build()); + } + else { + assertThat(taskDescriptors).hasSize(expectedPartitionCount); + for (TaskDescriptor taskDescriptor : taskDescriptors) { + assertTaskDescriptor( + taskDescriptor, + taskDescriptor.getPartitionId(), + ImmutableListMultimap.builder() + .putAll(expectedReplicatedSplits) + .putAll(expectedPartitionedSplits.getOrDefault(taskDescriptor.getPartitionId(), ImmutableListMultimap.of())) + .build()); + } + } + } + + private static Split createSplit(int id) + { + return new Split(TESTING_CATALOG_HANDLE, new TestingConnectorSplit(id, OptionalInt.empty(), Optional.empty())); + } + + private static Split createSplit(int id, List addresses) + { + return new Split(TESTING_CATALOG_HANDLE, 
new TestingConnectorSplit(id, OptionalInt.empty(), Optional.of(addresses))); + } + + private static ListMultimap createSplitsMultimap(List splits) + { + int nextPartitionId = 0; + ImmutableListMultimap.Builder result = ImmutableListMultimap.builder(); + for (Split split : splits) { + result.put(nextPartitionId++, split); + } + return result.build(); + } + + private static void assertTaskDescriptor( + TaskDescriptor taskDescriptor, + int expectedPartitionId, + ListMultimap expectedSplits) + { + assertEquals(taskDescriptor.getPartitionId(), expectedPartitionId); + assertSplitsEqual(taskDescriptor.getSplits(), expectedSplits); + Set hostRequirement = null; + for (Split split : taskDescriptor.getSplits().values()) { + if (!split.isRemotelyAccessible()) { + if (hostRequirement == null) { + hostRequirement = ImmutableSet.copyOf(split.getAddresses()); + } + else { + hostRequirement = Sets.intersection(hostRequirement, ImmutableSet.copyOf(split.getAddresses())); + } + } + } + assertEquals(taskDescriptor.getNodeRequirements().getCatalogHandle(), Optional.of(TESTING_CATALOG_HANDLE)); + assertThat(taskDescriptor.getNodeRequirements().getAddresses()).containsAnyElementsOf(hostRequirement == null ? 
ImmutableSet.of() : hostRequirement); + } + + private static void assertSplitsEqual(ListMultimap actual, ListMultimap expected) + { + SetMultimap actualSplitIds = ImmutableSetMultimap.copyOf(Multimaps.transformValues(actual, TestingConnectorSplit::getSplitId)); + SetMultimap expectedSplitIds = ImmutableSetMultimap.copyOf(Multimaps.transformValues(expected, TestingConnectorSplit::getSplitId)); + assertEquals(actualSplitIds, expectedSplitIds); + } + + private static ArbitraryDistributionSplitAssigner createSplitAssigner( + Set partitionedSources, + Set replicatedSources, + int partitionedSplitsPerPartition, + boolean verifyMaxTaskSplitCount) + { + long targetPartitionSizeInBytes = Long.MAX_VALUE; + int maxTaskSplitCount = Integer.MAX_VALUE; + // make sure both limits are tested + if (verifyMaxTaskSplitCount) { + maxTaskSplitCount = partitionedSplitsPerPartition; + } + else { + targetPartitionSizeInBytes = STANDARD_SPLIT_SIZE_IN_BYTES * partitionedSplitsPerPartition; + } + return new ArbitraryDistributionSplitAssigner( + Optional.of(TESTING_CATALOG_HANDLE), + partitionedSources, + replicatedSources, + targetPartitionSizeInBytes, + STANDARD_SPLIT_SIZE_IN_BYTES, + maxTaskSplitCount); + } + + private static class SplitBatch + { + private final PlanNodeId planNodeId; + private final List splits; + private final boolean noMoreSplits; + + public SplitBatch(PlanNodeId planNodeId, List splits, boolean noMoreSplits) + { + this.planNodeId = requireNonNull(planNodeId, "planNodeId is null"); + this.splits = ImmutableList.copyOf(requireNonNull(splits, "splits is null")); + this.noMoreSplits = noMoreSplits; + } + + public PlanNodeId getPlanNodeId() + { + return planNodeId; + } + + public List getSplits() + { + return splits; + } + + public boolean isNoMoreSplits() + { + return noMoreSplits; + } + } + + private static class PartitionAssignment + { + private final int partitionId; + private final ListMultimap splits = ArrayListMultimap.create(); + + private PartitionAssignment(int 
partitionId) + { + this.partitionId = partitionId; + } + + public int getPartitionId() + { + return partitionId; + } + + public ListMultimap getSplits() + { + return splits; + } + } +} diff --git a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestEventDrivenTaskSource.java b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestEventDrivenTaskSource.java new file mode 100644 index 000000000000..f9a021b2daa7 --- /dev/null +++ b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestEventDrivenTaskSource.java @@ -0,0 +1,1031 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.execution.scheduler; + +import com.google.common.collect.ArrayListMultimap; +import com.google.common.collect.HashMultimap; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableListMultimap; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import com.google.common.collect.ListMultimap; +import com.google.common.collect.Multimaps; +import com.google.common.collect.SetMultimap; +import com.google.common.primitives.ImmutableIntArray; +import com.google.common.util.concurrent.ListenableFuture; +import com.google.common.util.concurrent.ListeningScheduledExecutorService; +import com.google.common.util.concurrent.SettableFuture; +import com.google.common.util.concurrent.UncheckedTimeoutException; +import io.trino.connector.CatalogHandle; +import io.trino.exchange.SpoolingExchangeInput; +import io.trino.execution.scheduler.EventDrivenTaskSource.Partition; +import io.trino.execution.scheduler.EventDrivenTaskSource.PartitionUpdate; +import io.trino.metadata.Split; +import io.trino.spi.connector.ConnectorSplit; +import io.trino.spi.exchange.Exchange; +import io.trino.spi.exchange.ExchangeId; +import io.trino.spi.exchange.ExchangeSinkHandle; +import io.trino.spi.exchange.ExchangeSinkInstanceHandle; +import io.trino.spi.exchange.ExchangeSourceHandle; +import io.trino.spi.exchange.ExchangeSourceHandleSource; +import io.trino.split.RemoteSplit; +import io.trino.split.SplitSource; +import io.trino.sql.planner.plan.PlanFragmentId; +import io.trino.sql.planner.plan.PlanNodeId; +import org.testng.annotations.AfterClass; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +import javax.annotation.concurrent.GuardedBy; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashMap; +import java.util.HashSet; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import 
java.util.OptionalInt; +import java.util.Queue; +import java.util.Set; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.ThreadLocalRandom; +import java.util.concurrent.TimeoutException; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicLong; +import java.util.stream.IntStream; + +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.util.concurrent.MoreExecutors.listeningDecorator; +import static io.airlift.concurrent.Threads.daemonThreadsNamed; +import static io.trino.operator.ExchangeOperator.REMOTE_CATALOG_HANDLE; +import static java.lang.Math.max; +import static java.util.Objects.requireNonNull; +import static java.util.concurrent.Executors.newScheduledThreadPool; +import static java.util.concurrent.TimeUnit.MILLISECONDS; +import static java.util.concurrent.TimeUnit.MINUTES; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertTrue; +import static org.testng.Assert.fail; + +public class TestEventDrivenTaskSource +{ + private static final int INVOCATION_COUNT = 20; + + private static final CatalogHandle TESTING_CATALOG_HANDLE = CatalogHandle.createRootCatalogHandle("testing"); + + private static final PlanNodeId PLAN_NODE_1 = new PlanNodeId("plan-node-1"); + private static final PlanNodeId PLAN_NODE_2 = new PlanNodeId("plan-node-2"); + private static final PlanNodeId PLAN_NODE_3 = new PlanNodeId("plan-node-3"); + private static final PlanNodeId PLAN_NODE_4 = new PlanNodeId("plan-node-4"); + + private static final PlanFragmentId FRAGMENT_1 = new PlanFragmentId("fragment-1"); + private static final PlanFragmentId FRAGMENT_2 = new
PlanFragmentId("fragment-2"); + private static final PlanFragmentId FRAGMENT_3 = new PlanFragmentId("fragment-3"); + + private final AtomicInteger nextId = new AtomicInteger(); + + private ListeningScheduledExecutorService executor; + + @BeforeClass + public void setUp() + { + executor = listeningDecorator(newScheduledThreadPool(10, daemonThreadsNamed("dispatcher-query-%s"))); + } + + @AfterClass(alwaysRun = true) + public void tearDown() + { + if (executor != null) { + executor.shutdownNow(); + executor = null; + } + } + + @Test(invocationCount = INVOCATION_COUNT) + public void testHappyPath() + { + // no inputs + testStageTaskSourceSuccess( + ImmutableListMultimap.of(), + ImmutableMap.of(), + ImmutableListMultimap.of()); + // single split + testStageTaskSourceSuccess( + ImmutableListMultimap.of(), + ImmutableMap.of(), + ImmutableListMultimap.of(PLAN_NODE_1, createSplit(0))); + // multiple splits + testStageTaskSourceSuccess( + ImmutableListMultimap.of(), + ImmutableMap.of(), + ImmutableListMultimap.builder() + .putAll(PLAN_NODE_1, createSplit(0), createSplit(0), createSplit(1)) + .build()); + testStageTaskSourceSuccess( + ImmutableListMultimap.of(), + ImmutableMap.of(), + ImmutableListMultimap.builder() + .putAll(PLAN_NODE_1, createSplit(0)) + .putAll(PLAN_NODE_2, createSplit(0)) + .build()); + testStageTaskSourceSuccess( + ImmutableListMultimap.of(), + ImmutableMap.of(), + ImmutableListMultimap.builder() + .putAll(PLAN_NODE_1, createSplit(0)) + .putAll(PLAN_NODE_2, createSplit(0), createSplit(1)) + .build()); + testStageTaskSourceSuccess( + ImmutableListMultimap.of(), + ImmutableMap.of(), + ImmutableListMultimap.builder() + .putAll(PLAN_NODE_1, createSplit(0), createSplit(3), createSplit(4)) + .putAll(PLAN_NODE_2, createSplit(0), createSplit(1)) + .build()); + // single source handle + testStageTaskSourceSuccess( + ImmutableListMultimap.of(FRAGMENT_1, createSourceHandle(1)), + ImmutableMap.of(FRAGMENT_1, PLAN_NODE_1), + ImmutableListMultimap.of()); + // multiple 
source handles + testStageTaskSourceSuccess( + ImmutableListMultimap.builder() + .putAll(FRAGMENT_1, createSourceHandle(1), createSourceHandle(1)) + .build(), + ImmutableMap.of(FRAGMENT_1, PLAN_NODE_1), + ImmutableListMultimap.of()); + testStageTaskSourceSuccess( + ImmutableListMultimap.builder() + .putAll(FRAGMENT_1, createSourceHandle(1), createSourceHandle(1)) + .build(), + ImmutableMap.builder() + .put(FRAGMENT_1, PLAN_NODE_1) + .put(FRAGMENT_2, PLAN_NODE_2) + .buildOrThrow(), + ImmutableListMultimap.of()); + testStageTaskSourceSuccess( + ImmutableListMultimap.builder() + .putAll(FRAGMENT_1, createSourceHandle(1), createSourceHandle(1)) + .putAll(FRAGMENT_2, createSourceHandle(1), createSourceHandle(3)) + .build(), + ImmutableMap.builder() + .put(FRAGMENT_1, PLAN_NODE_1) + .put(FRAGMENT_2, PLAN_NODE_2) + .buildOrThrow(), + ImmutableListMultimap.of()); + testStageTaskSourceSuccess( + ImmutableListMultimap.builder() + .putAll(FRAGMENT_1, createSourceHandle(1), createSourceHandle(1)) + .putAll(FRAGMENT_2, createSourceHandle(1), createSourceHandle(3)) + .putAll(FRAGMENT_3, createSourceHandle(4)) + .build(), + ImmutableMap.builder() + .put(FRAGMENT_1, PLAN_NODE_1) + .put(FRAGMENT_2, PLAN_NODE_1) + .put(FRAGMENT_3, PLAN_NODE_2) + .buildOrThrow(), + ImmutableListMultimap.of()); + // multiple source handles and splits + testStageTaskSourceSuccess( + ImmutableListMultimap.builder() + .putAll(FRAGMENT_1, createSourceHandle(1), createSourceHandle(1)) + .build(), + ImmutableMap.of(FRAGMENT_1, PLAN_NODE_1), + ImmutableListMultimap.builder() + .putAll(PLAN_NODE_3, createSplit(0)) + .putAll(PLAN_NODE_4, createSplit(0)) + .build()); + testStageTaskSourceSuccess( + ImmutableListMultimap.builder() + .putAll(FRAGMENT_1, createSourceHandle(1), createSourceHandle(1)) + .putAll(FRAGMENT_2, createSourceHandle(1), createSourceHandle(3)) + .build(), + ImmutableMap.builder() + .put(FRAGMENT_1, PLAN_NODE_3) + .put(FRAGMENT_2, PLAN_NODE_4) + .buildOrThrow(), + 
ImmutableListMultimap.builder() + .putAll(PLAN_NODE_1, createSplit(0), createSplit(3), createSplit(4)) + .putAll(PLAN_NODE_2, createSplit(0), createSplit(1)) + .build()); + testStageTaskSourceSuccess( + ImmutableListMultimap.builder() + .putAll(FRAGMENT_1, createSourceHandle(1), createSourceHandle(1)) + .putAll(FRAGMENT_2, createSourceHandle(1), createSourceHandle(3)) + .putAll(FRAGMENT_3, createSourceHandle(4)) + .build(), + ImmutableMap.builder() + .put(FRAGMENT_1, PLAN_NODE_1) + .put(FRAGMENT_2, PLAN_NODE_1) + .put(FRAGMENT_3, PLAN_NODE_2) + .buildOrThrow(), + ImmutableListMultimap.builder() + .putAll(PLAN_NODE_3, createSplit(0), createSplit(3), createSplit(4)) + .putAll(PLAN_NODE_4, createSplit(0), createSplit(1)) + .build()); + } + + @Test(invocationCount = INVOCATION_COUNT) + public void stressTest() + { + Set allFragments = ImmutableSet.of(FRAGMENT_1, FRAGMENT_2, FRAGMENT_3); + Map remoteSources = ImmutableMap.of(FRAGMENT_1, PLAN_NODE_1, FRAGMENT_2, PLAN_NODE_1, FRAGMENT_3, PLAN_NODE_2); + Set splitSources = ImmutableSet.of(PLAN_NODE_3, PLAN_NODE_4); + + ListMultimap sourceHandles = ArrayListMultimap.create(); + for (PlanFragmentId fragmentId : allFragments) { + int numberOfHandles = ThreadLocalRandom.current().nextInt(100); + for (int i = 0; i < numberOfHandles; i++) { + int partition = ThreadLocalRandom.current().nextInt(10); + sourceHandles.put(fragmentId, createSourceHandle(partition)); + } + } + + ListMultimap splits = ArrayListMultimap.create(); + for (PlanNodeId planNodeId : splitSources) { + int numberOfSplits = ThreadLocalRandom.current().nextInt(100); + for (int i = 0; i < numberOfSplits; i++) { + int partition = ThreadLocalRandom.current().nextInt(10); + splits.put(planNodeId, createSplit(partition)); + } + } + + testStageTaskSourceSuccess(sourceHandles, remoteSources, splits); + } + + @Test(invocationCount = INVOCATION_COUNT) + public void testFailures() + { + RuntimeException failure = new RuntimeException(); + testStageTaskSourceFailure( + 
Optional.of(new FailingExchangeSourceHandleSource(failure, false, executor)), + Optional.empty(), + Optional.empty(), + Optional.empty(), + failure); + testStageTaskSourceFailure( + Optional.of(new FailingExchangeSourceHandleSource(failure, true, executor)), + Optional.empty(), + Optional.empty(), + Optional.empty(), + failure); + testStageTaskSourceFailure( + Optional.empty(), + Optional.of(new FailingSplitSource(failure, false, executor)), + Optional.empty(), + Optional.empty(), + failure); + testStageTaskSourceFailure( + Optional.empty(), + Optional.of(new FailingSplitSource(failure, true, executor)), + Optional.empty(), + Optional.empty(), + failure); + testStageTaskSourceFailure( + Optional.empty(), + Optional.empty(), + Optional.of(new FailingSplitAssigner(Optional.of(failure), Optional.empty())), + Optional.empty(), + failure); + testStageTaskSourceFailure( + Optional.empty(), + Optional.empty(), + Optional.of(new FailingSplitAssigner(Optional.empty(), Optional.of(failure))), + Optional.empty(), + failure); + testStageTaskSourceFailure( + Optional.empty(), + Optional.empty(), + Optional.empty(), + Optional.of(new FailingTaskSourceCallback(failure)), + failure); + } + + private void testStageTaskSourceSuccess( + ListMultimap sourceHandles, + Map remoteSources, + ListMultimap splits) + { + testStageTaskSourceSuccess( + sourceHandles, + remoteSources, + splits, + ImmutableMap.of(), + ImmutableMap.of(), + Optional.empty(), + Optional.empty()); + } + + private void testStageTaskSourceFailure( + Optional failingHandleSource, + Optional failingSplitSource, + Optional failingSplitAssigner, + Optional failingCallback, + RuntimeException expectedFailure) + { + assertThatThrownBy(() -> testStageTaskSourceSuccess( + ImmutableListMultimap.builder() + .putAll(FRAGMENT_1, createSourceHandle(1), createSourceHandle(1)) + .putAll(FRAGMENT_2, createSourceHandle(1), createSourceHandle(3)) + .build(), + ImmutableMap.builder() + .put(FRAGMENT_1, PLAN_NODE_1) + .put(FRAGMENT_2, 
PLAN_NODE_1) + .put(FRAGMENT_3, PLAN_NODE_2) + .buildOrThrow(), + ImmutableListMultimap.builder() + .putAll(PLAN_NODE_3, createSplit(0), createSplit(3), createSplit(4)) + .build(), + failingHandleSource.map(source -> ImmutableMap.of(FRAGMENT_3, source)).orElse(ImmutableMap.of()), + failingSplitSource.map(source -> ImmutableMap.of(PLAN_NODE_4, source)).orElse(ImmutableMap.of()), + failingSplitAssigner, + failingCallback)) + .isEqualTo(expectedFailure); + } + + private void testStageTaskSourceSuccess( + ListMultimap sourceHandles, + Map remoteSources, + ListMultimap splits, + Map failingHandleSources, + Map failingSplitSources, + Optional failingSplitAssigner, + Optional failingCallback) + { + List handleSources = new ArrayList<>(); + Map exchanges = new HashMap<>(); + Multimaps.asMap(sourceHandles).forEach(((fragmentId, handles) -> { + TestingExchangeSourceHandleSource handleSource = new TestingExchangeSourceHandleSource(executor, handles); + handleSources.add(handleSource); + exchanges.put(fragmentId, new TestingExchange(handleSource)); + })); + failingHandleSources.forEach(((fragmentId, handleSource) -> { + handleSources.add(handleSource); + exchanges.put(fragmentId, new TestingExchange(handleSource)); + })); + remoteSources.keySet().forEach(fragmentId -> { + if (!exchanges.containsKey(fragmentId)) { + TestingExchangeSourceHandleSource handleSource = new TestingExchangeSourceHandleSource(executor, ImmutableList.of()); + handleSources.add(handleSource); + exchanges.put(fragmentId, new TestingExchange(handleSource)); + } + }); + + Map splitSources = new HashMap<>(); + Multimaps.asMap(splits).forEach(((planNodeId, connectorSplits) -> splitSources.put(planNodeId, new TestingSplitSource(executor, connectorSplits)))); + splitSources.putAll(failingSplitSources); + + EventDrivenTaskSource.Callback taskSourceCallback = failingCallback.orElse(new TestingTaskSourceCallback()); + int partitionCount = getPartitionCount(sourceHandles.values(), splits.values()); + 
FaultTolerantPartitioningScheme partitioningScheme = createPartitioningScheme(partitionCount); + AtomicLong getSplitInvocations = new AtomicLong(); + Set allSources = ImmutableSet.builder() + .addAll(remoteSources.values()) + .addAll(splits.keySet()) + .build(); + List taskDescriptors = null; + RuntimeException failure = null; + TestingSplitAssigner testingSplitAssigner = new TestingSplitAssigner(allSources); + try (EventDrivenTaskSource taskSource = new EventDrivenTaskSource( + exchanges, + remoteSources, + () -> splitSources, + failingSplitAssigner.orElse(testingSplitAssigner), + taskSourceCallback, + executor, + 1, + 1, + partitioningScheme, + (getSplitDuration) -> getSplitInvocations.incrementAndGet())) { + taskSource.start(); + try { + if (taskSourceCallback instanceof FailingTaskSourceCallback callback) { + taskDescriptors = callback.getTaskDescriptors().get(1, MINUTES); + } + else if (taskSourceCallback instanceof TestingTaskSourceCallback callback) { + taskDescriptors = callback.getTaskDescriptorsFuture().get(1, MINUTES); + } + else { + fail("unexpected callback: " + taskSourceCallback.getClass()); + } + assertTrue(testingSplitAssigner.isFinished()); + } + catch (InterruptedException e) { + Thread.currentThread().interrupt(); + failure = new RuntimeException(e); + } + catch (ExecutionException e) { + if (e.getCause() instanceof RuntimeException runtimeException) { + failure = runtimeException; + } + else { + failure = new RuntimeException(e); + } + } + catch (TimeoutException e) { + failure = new UncheckedTimeoutException(e); + } + } + + for (ExchangeSourceHandleSource handleSource : handleSources) { + if (handleSource instanceof TestingExchangeSourceHandleSource source) { + assertTrue(source.isClosed()); + } + else if (handleSource instanceof FailingExchangeSourceHandleSource source) { + assertTrue(source.isClosed()); + } + else { + fail("unexpected handle source: " + handleSource.getClass()); + } + } + for (SplitSource splitSource : splitSources.values()) 
{ + if (splitSource instanceof TestingSplitSource source) { + assertTrue(source.isClosed()); + } + else if (splitSource instanceof FailingSplitSource source) { + assertTrue(source.isClosed()); + } + else { + fail("unexpected split source: " + splitSource.getClass()); + } + } + + if (failure != null) { + throw failure; + } + + assertThat(taskDescriptors) + .isNotNull() + .isNotEmpty(); + + Map> expectedHandles = new HashMap<>(); + Map> expectedSplits = new HashMap<>(); + for (Map.Entry entry : sourceHandles.entries()) { + TestingExchangeSourceHandle handle = (TestingExchangeSourceHandle) entry.getValue(); + PlanNodeId planNodeId = remoteSources.get(entry.getKey()); + expectedHandles.computeIfAbsent(handle.getPartitionId(), key -> HashMultimap.create()).put(planNodeId, handle); + } + for (Map.Entry entry : splits.entries()) { + TestingConnectorSplit split = (TestingConnectorSplit) entry.getValue(); + expectedSplits.computeIfAbsent(split.getBucket().orElseThrow(), key -> HashMultimap.create()).put(entry.getKey(), split); + } + + Map> actualHandles = new HashMap<>(); + Map> actualSplits = new HashMap<>(); + for (TaskDescriptor taskDescriptor : taskDescriptors) { + int partitionId = taskDescriptor.getPartitionId(); + for (Map.Entry entry : taskDescriptor.getSplits().entries()) { + if (entry.getValue().getCatalogHandle().equals(REMOTE_CATALOG_HANDLE)) { + RemoteSplit remoteSplit = (RemoteSplit) entry.getValue().getConnectorSplit(); + SpoolingExchangeInput input = (SpoolingExchangeInput) remoteSplit.getExchangeInput(); + for (ExchangeSourceHandle handle : input.getExchangeSourceHandles()) { + assertEquals(handle.getPartitionId(), partitionId); + actualHandles.computeIfAbsent(partitionId, key -> HashMultimap.create()).put(entry.getKey(), (TestingExchangeSourceHandle) handle); + } + } + else { + TestingConnectorSplit split = (TestingConnectorSplit) entry.getValue().getConnectorSplit(); + assertEquals(split.getBucket().orElseThrow(), partitionId); + 
actualSplits.computeIfAbsent(partitionId, key -> HashMultimap.create()).put(entry.getKey(), split); + } + } + } + + assertEquals(actualHandles, expectedHandles); + assertEquals(actualSplits, expectedSplits); + } + + private static FaultTolerantPartitioningScheme createPartitioningScheme(int partitionCount) + { + return new FaultTolerantPartitioningScheme( + partitionCount, + Optional.of(IntStream.range(0, partitionCount).toArray()), + Optional.of(split -> ((TestingConnectorSplit) split.getConnectorSplit()).getBucket().orElseThrow()), + Optional.empty()); + } + + private static int getPartitionCount(Collection sourceHandles, Collection splits) + { + int maxPartitionId = sourceHandles.stream() + .mapToInt(ExchangeSourceHandle::getPartitionId) + .max() + .orElse(-1); + maxPartitionId = max(maxPartitionId, splits.stream() + .map(TestingConnectorSplit.class::cast) + .map(TestingConnectorSplit::getBucket) + .mapToInt(OptionalInt::orElseThrow) + .max() + .orElse(-1)); + return max(maxPartitionId + 1, 1); + } + + private TestingExchangeSourceHandle createSourceHandle(int partitionId) + { + return new TestingExchangeSourceHandle(nextId.getAndIncrement(), partitionId, 0); + } + + private TestingConnectorSplit createSplit(int partitionId) + { + return new TestingConnectorSplit(nextId.getAndIncrement(), OptionalInt.of(partitionId), Optional.empty()); + } + + private static class TestingSplitSource + implements SplitSource + { + private final ScheduledExecutorService executor; + @GuardedBy("this") + private final Queue remainingSplits; + @GuardedBy("this") + private SettableFuture currentFuture; + @GuardedBy("this") + private boolean finished; + @GuardedBy("this") + private boolean closed; + + public TestingSplitSource(ScheduledExecutorService executor, List splits) + { + this.executor = requireNonNull(executor, "executor is null"); + remainingSplits = new LinkedList<>(splits); + } + + @Override + public CatalogHandle getCatalogHandle() + { + return TESTING_CATALOG_HANDLE; + } 
+ + @Override + public synchronized ListenableFuture getNextBatch(int maxSize) + { + checkState(!closed, "closed"); + checkState(currentFuture == null || currentFuture.isDone(), "currentFuture is still running"); + currentFuture = SettableFuture.create(); + long delay = ThreadLocalRandom.current().nextInt(3); + if (delay == 0) { + setNextBatch(); + } + else { + executor.schedule(this::setNextBatch, delay, MILLISECONDS); + } + return currentFuture; + } + + private void setNextBatch() + { + SettableFuture future; + SplitBatch batch; + synchronized (this) { + future = currentFuture; + ConnectorSplit split = remainingSplits.poll(); + boolean lastBatch = remainingSplits.isEmpty(); + batch = new SplitBatch(split == null ? ImmutableList.of() : ImmutableList.of(new Split(TESTING_CATALOG_HANDLE, split)), lastBatch); + if (lastBatch) { + finished = true; + } + } + if (future != null) { + future.set(batch); + } + } + + @Override + public synchronized void close() + { + if (closed) { + return; + } + closed = true; + if (currentFuture != null) { + currentFuture.cancel(true); + currentFuture = null; + } + remainingSplits.clear(); + } + + @Override + public synchronized boolean isFinished() + { + return finished || closed; + } + + @Override + public Optional> getTableExecuteSplitsInfo() + { + return Optional.empty(); + } + + public synchronized boolean isClosed() + { + return closed; + } + } + + private static class FailingSplitSource + implements SplitSource + { + private final RuntimeException failure; + private final boolean failFuture; + private final ScheduledExecutorService executor; + private final AtomicBoolean closed = new AtomicBoolean(); + + private FailingSplitSource(RuntimeException failure, boolean failFuture, ScheduledExecutorService executor) + { + this.failure = requireNonNull(failure, "failure is null"); + this.failFuture = failFuture; + this.executor = requireNonNull(executor, "executor is null"); + } + + @Override + public CatalogHandle getCatalogHandle() + { 
+ throw new UnsupportedOperationException(); + } + + @Override + public ListenableFuture getNextBatch(int maxSize) + { + if (!failFuture) { + throw failure; + } + SettableFuture future = SettableFuture.create(); + long delay = ThreadLocalRandom.current().nextInt(3); + if (delay == 0) { + future.setException(failure); + } + else { + executor.schedule(() -> future.setException(failure), delay, MILLISECONDS); + } + return future; + } + + @Override + public void close() + { + closed.set(true); + } + + public boolean isClosed() + { + return closed.get(); + } + + @Override + public boolean isFinished() + { + throw new UnsupportedOperationException(); + } + + @Override + public Optional> getTableExecuteSplitsInfo() + { + throw new UnsupportedOperationException(); + } + } + + private static class TestingExchangeSourceHandleSource + implements ExchangeSourceHandleSource + { + private final ScheduledExecutorService executor; + @GuardedBy("this") + private final Queue remainingHandles; + @GuardedBy("this") + private CompletableFuture currentFuture; + @GuardedBy("this") + private boolean closed; + + private TestingExchangeSourceHandleSource(ScheduledExecutorService executor, List handles) + { + this.executor = requireNonNull(executor, "executor is null"); + this.remainingHandles = new LinkedList<>(requireNonNull(handles, "handles is null")); + } + + @Override + public synchronized CompletableFuture getNextBatch() + { + checkState(!closed, "closed"); + checkState(currentFuture == null || currentFuture.isDone(), "currentFuture is still running"); + currentFuture = new CompletableFuture<>(); + long delay = ThreadLocalRandom.current().nextInt(3); + if (delay == 0) { + setNextBatch(); + } + else { + executor.schedule(this::setNextBatch, delay, MILLISECONDS); + } + return currentFuture; + } + + private void setNextBatch() + { + CompletableFuture future; + ExchangeSourceHandleBatch batch; + synchronized (this) { + future = currentFuture; + ExchangeSourceHandle handle = 
remainingHandles.poll(); + boolean lastBatch = remainingHandles.isEmpty(); + batch = new ExchangeSourceHandleBatch(handle == null ? ImmutableList.of() : ImmutableList.of(handle), lastBatch); + } + if (future != null) { + future.complete(batch); + } + } + + @Override + public synchronized void close() + { + if (closed) { + return; + } + closed = true; + if (currentFuture != null) { + currentFuture.cancel(true); + currentFuture = null; + } + remainingHandles.clear(); + } + + public synchronized boolean isClosed() + { + return closed; + } + } + + private static class FailingExchangeSourceHandleSource + implements ExchangeSourceHandleSource + { + private final RuntimeException failure; + private final boolean failFuture; + private final ScheduledExecutorService executor; + private final AtomicBoolean closed = new AtomicBoolean(); + + public FailingExchangeSourceHandleSource(RuntimeException failure, boolean failFuture, ScheduledExecutorService executor) + { + this.failure = requireNonNull(failure, "failure is null"); + this.failFuture = failFuture; + this.executor = requireNonNull(executor, "executor is null"); + } + + @Override + public CompletableFuture getNextBatch() + { + if (!failFuture) { + throw failure; + } + CompletableFuture future = new CompletableFuture<>(); + long delay = ThreadLocalRandom.current().nextInt(3); + if (delay == 0) { + future.completeExceptionally(failure); + } + else { + executor.schedule(() -> future.completeExceptionally(failure), delay, MILLISECONDS); + } + return future; + } + + @Override + public void close() + { + closed.set(true); + } + + public boolean isClosed() + { + return closed.get(); + } + } + + private static class FailingTaskSourceCallback + implements EventDrivenTaskSource.Callback + { + private final RuntimeException failure; + private final SettableFuture> taskDescriptors = SettableFuture.create(); + + private FailingTaskSourceCallback(RuntimeException failure) + { + this.failure = requireNonNull(failure, "failure is 
null"); + } + + public SettableFuture> getTaskDescriptors() + { + return taskDescriptors; + } + + @Override + public void partitionsAdded(List partitions) + { + throw failure; + } + + @Override + public void noMorePartitions() + { + throw failure; + } + + @Override + public void partitionsUpdated(List partitionUpdates) + { + throw failure; + } + + @Override + public void partitionsSealed(ImmutableIntArray partitionIds) + { + throw failure; + } + + @Override + public void failed(Throwable t) + { + taskDescriptors.setException(t); + } + } + + private static class TestingExchange + implements Exchange + { + @GuardedBy("this") + private ExchangeSourceHandleSource exchangeSourceHandleSource; + @GuardedBy("this") + private boolean closed; + + public TestingExchange(ExchangeSourceHandleSource exchangeSourceHandleSource) + { + this.exchangeSourceHandleSource = requireNonNull(exchangeSourceHandleSource, "exchangeSourceHandleSource is null"); + } + + @Override + public ExchangeId getId() + { + throw new UnsupportedOperationException(); + } + + @Override + public ExchangeSinkHandle addSink(int taskPartitionId) + { + throw new UnsupportedOperationException(); + } + + @Override + public void noMoreSinks() + { + throw new UnsupportedOperationException(); + } + + @Override + public ExchangeSinkInstanceHandle instantiateSink(ExchangeSinkHandle sinkHandle, int taskAttemptId) + { + throw new UnsupportedOperationException(); + } + + @Override + public ExchangeSinkInstanceHandle updateSinkInstanceHandle(ExchangeSinkHandle sinkHandle, int taskAttemptId) + { + throw new UnsupportedOperationException(); + } + + @Override + public void sinkFinished(ExchangeSinkHandle sinkHandle, int taskAttemptId) + { + throw new UnsupportedOperationException(); + } + + @Override + public void allRequiredSinksFinished() + { + throw new UnsupportedOperationException(); + } + + @Override + public synchronized ExchangeSourceHandleSource getSourceHandles() + { + checkState(!closed, "already closed"); + 
checkState(exchangeSourceHandleSource != null, "already retrieved"); + ExchangeSourceHandleSource result = exchangeSourceHandleSource; + exchangeSourceHandleSource = null; + return result; + } + + @Override + public synchronized void close() + { + closed = true; + } + } + + private static class TestingSplitAssigner + implements SplitAssigner + { + private final Set allSources; + + private final Set partitions = new HashSet<>(); + private final Set finishedSources = new HashSet<>(); + + private boolean finished; + + private TestingSplitAssigner(Set allSources) + { + this.allSources = ImmutableSet.copyOf(requireNonNull(allSources, "allSources is null")); + } + + @Override + public AssignmentResult assign(PlanNodeId planNodeId, ListMultimap splitsMap, boolean noMoreSplits) + { + checkState(!finished, "finished is set"); + AssignmentResult.Builder result = AssignmentResult.builder(); + Multimaps.asMap(splitsMap).forEach((partition, splits) -> { + if (partitions.add(partition)) { + result.addPartition(new Partition(partition, new NodeRequirements(Optional.empty(), ImmutableSet.of()))); + for (PlanNodeId finishedSource : finishedSources) { + result.updatePartition(new PartitionUpdate(partition, finishedSource, ImmutableList.of(), true)); + } + } + result.updatePartition(new PartitionUpdate(partition, planNodeId, splits, noMoreSplits)); + }); + if (noMoreSplits) { + finishedSources.add(planNodeId); + for (Integer partition : partitions) { + result.updatePartition(new PartitionUpdate(partition, planNodeId, ImmutableList.of(), true)); + } + } + if (finishedSources.containsAll(allSources)) { + partitions.forEach(result::sealPartition); + } + return result.build(); + } + + @Override + public AssignmentResult finish() + { + AssignmentResult.Builder result = AssignmentResult.builder(); + if (finished) { + return result.build(); + } + finished = true; + + checkState(finishedSources.containsAll(allSources)); + if (partitions.isEmpty()) { + partitions.add(0); + result + 
.addPartition(new Partition(0, new NodeRequirements(Optional.empty(), ImmutableSet.of()))) + .sealPartition(0); + } + return result.setNoMorePartitions() + .build(); + } + + public boolean isFinished() + { + return finished; + } + } + + private static class FailingSplitAssigner + implements SplitAssigner + { + private final Optional assignFailure; + private final Optional finishFailure; + + private FailingSplitAssigner(Optional assignFailure, Optional finishFailure) + { + this.assignFailure = requireNonNull(assignFailure, "assignFailure is null"); + this.finishFailure = requireNonNull(finishFailure, "finishFailure is null"); + } + + @Override + public AssignmentResult assign(PlanNodeId planNodeId, ListMultimap splits, boolean noMoreSplits) + { + if (assignFailure.isPresent()) { + throw assignFailure.get(); + } + return AssignmentResult.builder().build(); + } + + @Override + public AssignmentResult finish() + { + if (finishFailure.isPresent()) { + throw finishFailure.get(); + } + return AssignmentResult.builder().build(); + } + } +} diff --git a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestFaultTolerantStageScheduler.java b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestFaultTolerantStageScheduler.java index dbc83a99d4a4..7935ade157ce 100644 --- a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestFaultTolerantStageScheduler.java +++ b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestFaultTolerantStageScheduler.java @@ -36,7 +36,6 @@ import io.trino.execution.TestingRemoteTaskFactory; import io.trino.execution.TestingRemoteTaskFactory.TestingRemoteTask; import io.trino.execution.scheduler.TestingExchange.TestingExchangeSinkHandle; -import io.trino.execution.scheduler.TestingExchange.TestingExchangeSourceHandle; import io.trino.execution.scheduler.TestingNodeSelectorFactory.TestingNodeSupplier; import io.trino.failuredetector.NoOpFailureDetector; import io.trino.metadata.InternalNode; @@ -200,12 +199,12 @@ 
public void testHappyPath() // blocked on first source exchange assertBlocked(blocked); - sourceExchange1.setSourceHandles(ImmutableList.of(new TestingExchangeSourceHandle(0, 1))); + sourceExchange1.setSourceHandles(ImmutableList.of(new TestingExchangeSourceHandle(0, 0, 1))); // still blocked on the second source exchange assertBlocked(blocked); assertFalse(scheduler.isBlocked().isDone()); - sourceExchange2.setSourceHandles(ImmutableList.of(new TestingExchangeSourceHandle(0, 1))); + sourceExchange2.setSourceHandles(ImmutableList.of(new TestingExchangeSourceHandle(1, 0, 1))); // now unblocked assertUnblocked(blocked); assertUnblocked(scheduler.isBlocked()); @@ -343,8 +342,8 @@ public void testTasksWaitingForNodes() 2, 3); // allow for 3 tasks waiting for nodes before blocking - sourceExchange1.setSourceHandles(ImmutableList.of(new TestingExchangeSourceHandle(0, 1))); - sourceExchange2.setSourceHandles(ImmutableList.of(new TestingExchangeSourceHandle(0, 1))); + sourceExchange1.setSourceHandles(ImmutableList.of(new TestingExchangeSourceHandle(0, 0, 1))); + sourceExchange2.setSourceHandles(ImmutableList.of(new TestingExchangeSourceHandle(1, 0, 1))); scheduler.schedule(); Map tasks; @@ -421,8 +420,8 @@ public void testTaskFailure() 0, 1); - sourceExchange1.setSourceHandles(ImmutableList.of(new TestingExchangeSourceHandle(0, 1))); - sourceExchange2.setSourceHandles(ImmutableList.of(new TestingExchangeSourceHandle(0, 1))); + sourceExchange1.setSourceHandles(ImmutableList.of(new TestingExchangeSourceHandle(0, 0, 1))); + sourceExchange2.setSourceHandles(ImmutableList.of(new TestingExchangeSourceHandle(1, 0, 1))); assertUnblocked(scheduler.isBlocked()); scheduler.schedule(); @@ -481,8 +480,8 @@ public void testRetryDelay() 6, 1); - sourceExchange1.setSourceHandles(ImmutableList.of(new TestingExchangeSourceHandle(0, 1))); - sourceExchange2.setSourceHandles(ImmutableList.of(new TestingExchangeSourceHandle(0, 1))); + sourceExchange1.setSourceHandles(ImmutableList.of(new 
TestingExchangeSourceHandle(0, 0, 1))); + sourceExchange2.setSourceHandles(ImmutableList.of(new TestingExchangeSourceHandle(1, 0, 1))); assertUnblocked(scheduler.isBlocked()); scheduler.schedule(); @@ -756,8 +755,8 @@ private void testCancellation(boolean abort) 0, 1); - sourceExchange1.setSourceHandles(ImmutableList.of(new TestingExchangeSourceHandle(0, 1))); - sourceExchange2.setSourceHandles(ImmutableList.of(new TestingExchangeSourceHandle(0, 1))); + sourceExchange1.setSourceHandles(ImmutableList.of(new TestingExchangeSourceHandle(0, 0, 1))); + sourceExchange2.setSourceHandles(ImmutableList.of(new TestingExchangeSourceHandle(1, 0, 1))); assertUnblocked(scheduler.isBlocked()); scheduler.schedule(); @@ -812,8 +811,8 @@ public void testAsyncTaskSource() 2, 1); - sourceExchange1.setSourceHandles(ImmutableList.of(new TestingExchangeSourceHandle(0, 1))); - sourceExchange2.setSourceHandles(ImmutableList.of(new TestingExchangeSourceHandle(0, 1))); + sourceExchange1.setSourceHandles(ImmutableList.of(new TestingExchangeSourceHandle(0, 0, 1))); + sourceExchange2.setSourceHandles(ImmutableList.of(new TestingExchangeSourceHandle(1, 0, 1))); assertUnblocked(scheduler.isBlocked()); scheduler.schedule(); diff --git a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestHashDistributionSplitAssigner.java b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestHashDistributionSplitAssigner.java new file mode 100644 index 000000000000..589cd42d0563 --- /dev/null +++ b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestHashDistributionSplitAssigner.java @@ -0,0 +1,447 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.execution.scheduler; + +import com.google.common.collect.HashMultimap; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableListMultimap; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import com.google.common.collect.ListMultimap; +import com.google.common.collect.SetMultimap; +import com.google.common.primitives.ImmutableLongArray; +import io.trino.client.NodeVersion; +import io.trino.connector.CatalogHandle; +import io.trino.metadata.InternalNode; +import io.trino.metadata.Split; +import io.trino.sql.planner.plan.PlanNodeId; +import org.testng.annotations.Test; + +import java.net.URI; +import java.util.Arrays; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.OptionalInt; +import java.util.Set; +import java.util.function.Function; +import java.util.stream.IntStream; + +import static com.google.common.collect.ImmutableListMultimap.toImmutableListMultimap; +import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static com.google.common.collect.Sets.difference; +import static io.trino.connector.CatalogHandle.createRootCatalogHandle; +import static java.util.Objects.requireNonNull; +import static org.assertj.core.api.Assertions.assertThat; +import static org.testng.Assert.assertEquals; + +public class TestHashDistributionSplitAssigner +{ + private static final CatalogHandle TESTING_CATALOG_HANDLE = createRootCatalogHandle("testing"); + + private 
static final PlanNodeId PARTITIONED_1 = new PlanNodeId("partitioned-1"); + private static final PlanNodeId PARTITIONED_2 = new PlanNodeId("partitioned-2"); + private static final PlanNodeId REPLICATED_1 = new PlanNodeId("replicated-1"); + private static final PlanNodeId REPLICATED_2 = new PlanNodeId("replicated-2"); + + private static final InternalNode NODE_1 = new InternalNode("node1", URI.create("http://localhost:8081"), NodeVersion.UNKNOWN, false); + private static final InternalNode NODE_2 = new InternalNode("node2", URI.create("http://localhost:8082"), NodeVersion.UNKNOWN, false); + private static final InternalNode NODE_3 = new InternalNode("node3", URI.create("http://localhost:8083"), NodeVersion.UNKNOWN, false); + + @Test + public void testEmpty() + { + testAssigner( + ImmutableSet.of(PARTITIONED_1), + ImmutableSet.of(), + ImmutableList.of(new SplitBatch(PARTITIONED_1, ImmutableListMultimap.of(), true)), + 10, + Optional.empty(), + 1024, + ImmutableMap.of(), + false, + 1); + testAssigner( + ImmutableSet.of(), + ImmutableSet.of(REPLICATED_1), + ImmutableList.of(new SplitBatch(REPLICATED_1, ImmutableListMultimap.of(), true)), + 1, + Optional.empty(), + 1024, + ImmutableMap.of(REPLICATED_1, new OutputDataSizeEstimate(ImmutableLongArray.builder().add(0).build())), + false, + 1); + testAssigner( + ImmutableSet.of(PARTITIONED_1), + ImmutableSet.of(REPLICATED_1), + ImmutableList.of( + new SplitBatch(PARTITIONED_1, ImmutableListMultimap.of(), true), + new SplitBatch(REPLICATED_1, ImmutableListMultimap.of(), true)), + 10, + Optional.empty(), + 1024, + ImmutableMap.of(), + false, + 1); + testAssigner( + ImmutableSet.of(PARTITIONED_1, PARTITIONED_2), + ImmutableSet.of(REPLICATED_1, REPLICATED_2), + ImmutableList.of( + new SplitBatch(PARTITIONED_1, ImmutableListMultimap.of(), true), + new SplitBatch(REPLICATED_1, ImmutableListMultimap.of(), true), + new SplitBatch(PARTITIONED_2, ImmutableListMultimap.of(), true), + new SplitBatch(REPLICATED_2, 
ImmutableListMultimap.of(), true)), + 10, + Optional.empty(), + 1024, + ImmutableMap.of(), + false, + 1); + } + + @Test + public void testExplicitPartitionToNodeMap() + { + testAssigner( + ImmutableSet.of(PARTITIONED_1), + ImmutableSet.of(), + ImmutableList.of( + new SplitBatch(PARTITIONED_1, createSplitMap(createSplit(0, 0), createSplit(1, 0)), false), + new SplitBatch(PARTITIONED_1, createSplitMap(createSplit(2, 1), createSplit(3, 2)), true)), + 3, + Optional.of(ImmutableList.of(NODE_1, NODE_2, NODE_3)), + 1000, + ImmutableMap.of(PARTITIONED_1, new OutputDataSizeEstimate(ImmutableLongArray.of(1, 1, 1))), + false, + 3); + // some partitions missing + testAssigner( + ImmutableSet.of(PARTITIONED_1), + ImmutableSet.of(), + ImmutableList.of( + new SplitBatch(PARTITIONED_1, createSplitMap(createSplit(0, 0), createSplit(1, 0)), true)), + 3, + Optional.of(ImmutableList.of(NODE_1, NODE_2, NODE_3)), + 1000, + ImmutableMap.of(PARTITIONED_1, new OutputDataSizeEstimate(ImmutableLongArray.of(1, 1, 1))), + false, + 1); + // no splits + testAssigner( + ImmutableSet.of(PARTITIONED_1), + ImmutableSet.of(), + ImmutableList.of( + new SplitBatch(PARTITIONED_1, ImmutableListMultimap.of(), true)), + 3, + Optional.of(ImmutableList.of(NODE_1, NODE_2, NODE_3)), + 1000, + ImmutableMap.of(PARTITIONED_1, new OutputDataSizeEstimate(ImmutableLongArray.of(1, 1, 1))), + false, + 1); + } + + @Test + public void testPreserveOutputPartitioning() + { + testAssigner( + ImmutableSet.of(PARTITIONED_1), + ImmutableSet.of(), + ImmutableList.of( + new SplitBatch(PARTITIONED_1, createSplitMap(createSplit(0, 0), createSplit(1, 0)), false), + new SplitBatch(PARTITIONED_1, createSplitMap(createSplit(2, 1), createSplit(3, 2)), true)), + 3, + Optional.empty(), + 1000, + ImmutableMap.of(PARTITIONED_1, new OutputDataSizeEstimate(ImmutableLongArray.of(1, 1, 1))), + true, + 3); + // some partitions missing + testAssigner( + ImmutableSet.of(PARTITIONED_1), + ImmutableSet.of(), + ImmutableList.of( + new 
SplitBatch(PARTITIONED_1, createSplitMap(createSplit(0, 0), createSplit(1, 0)), true)), + 3, + Optional.empty(), + 1000, + ImmutableMap.of(PARTITIONED_1, new OutputDataSizeEstimate(ImmutableLongArray.of(1, 1, 1))), + true, + 1); + // no splits + testAssigner( + ImmutableSet.of(PARTITIONED_1), + ImmutableSet.of(), + ImmutableList.of( + new SplitBatch(PARTITIONED_1, ImmutableListMultimap.of(), true)), + 3, + Optional.empty(), + 1000, + ImmutableMap.of(PARTITIONED_1, new OutputDataSizeEstimate(ImmutableLongArray.of(1, 1, 1))), + true, + 1); + } + + @Test + public void testMissingEstimates() + { + testAssigner( + ImmutableSet.of(PARTITIONED_1), + ImmutableSet.of(), + ImmutableList.of( + new SplitBatch(PARTITIONED_1, createSplitMap(createSplit(0, 0), createSplit(1, 0)), false), + new SplitBatch(PARTITIONED_1, createSplitMap(createSplit(2, 1), createSplit(3, 2)), true)), + 3, + Optional.empty(), + 1000, + ImmutableMap.of(), + false, + 3); + // some partitions missing + testAssigner( + ImmutableSet.of(PARTITIONED_1), + ImmutableSet.of(), + ImmutableList.of( + new SplitBatch(PARTITIONED_1, createSplitMap(createSplit(0, 0), createSplit(1, 0)), true)), + 3, + Optional.empty(), + 1000, + ImmutableMap.of(), + false, + 1); + // no splits + testAssigner( + ImmutableSet.of(PARTITIONED_1), + ImmutableSet.of(), + ImmutableList.of( + new SplitBatch(PARTITIONED_1, ImmutableListMultimap.of(), true)), + 3, + Optional.empty(), + 1000, + ImmutableMap.of(), + false, + 1); + } + + @Test + public void testHappyPath() + { + testAssigner( + ImmutableSet.of(PARTITIONED_1), + ImmutableSet.of(), + ImmutableList.of( + new SplitBatch(PARTITIONED_1, createSplitMap(createSplit(0, 0), createSplit(1, 0)), false), + new SplitBatch(PARTITIONED_1, createSplitMap(createSplit(2, 1), createSplit(3, 2)), true)), + 3, + Optional.empty(), + 3, + ImmutableMap.of(PARTITIONED_1, new OutputDataSizeEstimate(ImmutableLongArray.of(1, 1, 1))), + false, + 1); + testAssigner( + ImmutableSet.of(PARTITIONED_1), + 
ImmutableSet.of(REPLICATED_1), + ImmutableList.of( + new SplitBatch(PARTITIONED_1, createSplitMap(createSplit(0, 0), createSplit(1, 0)), false), + new SplitBatch(REPLICATED_1, createSplitMap(createSplit(2, 0), createSplit(3, 2)), false), + new SplitBatch(REPLICATED_1, createSplitMap(createSplit(4, 1), createSplit(5, 100)), true), + new SplitBatch(PARTITIONED_1, createSplitMap(createSplit(6, 1), createSplit(7, 2)), true)), + 3, + Optional.empty(), + 3, + ImmutableMap.of(PARTITIONED_1, new OutputDataSizeEstimate(ImmutableLongArray.of(1, 1, 1))), + false, + 1); + testAssigner( + ImmutableSet.of(PARTITIONED_1), + ImmutableSet.of(REPLICATED_1), + ImmutableList.of( + new SplitBatch(REPLICATED_1, createSplitMap(createSplit(2, 0), createSplit(3, 2)), false), + new SplitBatch(REPLICATED_1, createSplitMap(createSplit(4, 1), createSplit(5, 100)), true), + new SplitBatch(PARTITIONED_1, createSplitMap(createSplit(0, 0), createSplit(1, 0)), false), + new SplitBatch(PARTITIONED_1, createSplitMap(createSplit(6, 1), createSplit(7, 2)), true)), + 3, + Optional.empty(), + 1, + ImmutableMap.of(PARTITIONED_1, new OutputDataSizeEstimate(ImmutableLongArray.of(1, 1, 1))), + false, + 3); + testAssigner( + ImmutableSet.of(PARTITIONED_1), + ImmutableSet.of(REPLICATED_1), + ImmutableList.of( + new SplitBatch(REPLICATED_1, createSplitMap(createSplit(2, 0), createSplit(3, 2)), false), + new SplitBatch(PARTITIONED_1, createSplitMap(createSplit(0, 0), createSplit(1, 0)), false), + new SplitBatch(REPLICATED_1, createSplitMap(createSplit(4, 1), createSplit(5, 100)), true), + new SplitBatch(PARTITIONED_1, createSplitMap(createSplit(6, 1), createSplit(7, 2)), true)), + 3, + Optional.empty(), + 1, + ImmutableMap.of(PARTITIONED_1, new OutputDataSizeEstimate(ImmutableLongArray.of(1, 1, 1))), + false, + 3); + testAssigner( + ImmutableSet.of(PARTITIONED_1), + ImmutableSet.of(REPLICATED_1, REPLICATED_2), + ImmutableList.of( + new SplitBatch(REPLICATED_2, createSplitMap(createSplit(11, 1), createSplit(12, 
100)), true), + new SplitBatch(REPLICATED_1, createSplitMap(createSplit(2, 0), createSplit(3, 2)), false), + new SplitBatch(PARTITIONED_1, createSplitMap(createSplit(0, 0), createSplit(1, 0)), false), + new SplitBatch(REPLICATED_1, createSplitMap(createSplit(4, 1), createSplit(5, 100)), true), + new SplitBatch(PARTITIONED_1, createSplitMap(createSplit(6, 1), createSplit(7, 2)), true)), + 3, + Optional.empty(), + 1, + ImmutableMap.of(PARTITIONED_1, new OutputDataSizeEstimate(ImmutableLongArray.of(1, 1, 1))), + false, + 3); + testAssigner( + ImmutableSet.of(PARTITIONED_1, PARTITIONED_2), + ImmutableSet.of(REPLICATED_1, REPLICATED_2), + ImmutableList.of( + new SplitBatch(REPLICATED_2, createSplitMap(createSplit(11, 1), createSplit(12, 100)), true), + new SplitBatch(REPLICATED_1, createSplitMap(createSplit(2, 0), createSplit(3, 2)), false), + new SplitBatch(PARTITIONED_1, createSplitMap(createSplit(0, 0), createSplit(1, 0)), false), + new SplitBatch(PARTITIONED_2, createSplitMap(), true), + new SplitBatch(REPLICATED_1, createSplitMap(createSplit(4, 1), createSplit(5, 100)), true), + new SplitBatch(PARTITIONED_1, createSplitMap(createSplit(6, 1), createSplit(7, 2)), true)), + 3, + Optional.empty(), + 1, + ImmutableMap.of( + PARTITIONED_1, new OutputDataSizeEstimate(ImmutableLongArray.of(1, 1, 1)), + PARTITIONED_2, new OutputDataSizeEstimate(ImmutableLongArray.of(1, 1, 1))), + false, + 3); + } + + private static void testAssigner( + Set partitionedSources, + Set replicatedSources, + List batches, + int splitPartitionCount, + Optional> partitionToNodeMap, + long targetPartitionSizeInBytes, + Map outputDataSizeEstimates, + boolean preserveOutputPartitioning, + int expectedTaskCount) + { + FaultTolerantPartitioningScheme partitioningScheme = createPartitioningScheme(splitPartitionCount, partitionToNodeMap); + HashDistributionSplitAssigner assigner = new HashDistributionSplitAssigner( + Optional.of(TESTING_CATALOG_HANDLE), + partitionedSources, + replicatedSources, + 
targetPartitionSizeInBytes, + outputDataSizeEstimates, + partitioningScheme, + preserveOutputPartitioning); + TestingTaskSourceCallback callback = new TestingTaskSourceCallback(); + SetMultimap partitionedSplitIds = HashMultimap.create(); + Set replicatedSplitIds = new HashSet<>(); + for (SplitBatch batch : batches) { + assigner.assign(batch.getPlanNodeId(), batch.getSplits(), batch.isNoMoreSplits()).update(callback); + boolean replicated = replicatedSources.contains(batch.getPlanNodeId()); + callback.checkContainsSplits(batch.getPlanNodeId(), batch.getSplits().values(), replicated); + for (Map.Entry entry : batch.getSplits().entries()) { + int splitId = TestingConnectorSplit.getSplitId(entry.getValue()); + if (replicated) { + assertThat(replicatedSplitIds).doesNotContain(splitId); + replicatedSplitIds.add(splitId); + } + else { + partitionedSplitIds.put(entry.getKey(), splitId); + } + } + } + assigner.finish().update(callback); + List taskDescriptors = callback.getTaskDescriptors(); + assertThat(taskDescriptors).hasSize(expectedTaskCount); + for (TaskDescriptor taskDescriptor : taskDescriptors) { + int partitionId = taskDescriptor.getPartitionId(); + NodeRequirements nodeRequirements = taskDescriptor.getNodeRequirements(); + assertEquals(nodeRequirements.getCatalogHandle(), Optional.of(TESTING_CATALOG_HANDLE)); + partitionToNodeMap.ifPresent(partitionToNode -> { + if (!taskDescriptor.getSplits().isEmpty()) { + InternalNode node = partitionToNode.get(partitionId); + assertThat(nodeRequirements.getAddresses()).containsExactly(node.getHostAndPort()); + } + }); + Set taskDescriptorSplitIds = taskDescriptor.getSplits().values().stream() + .map(TestingConnectorSplit::getSplitId) + .collect(toImmutableSet()); + assertThat(taskDescriptorSplitIds).containsAll(replicatedSplitIds); + Set taskDescriptorPartitionedSplitIds = difference(taskDescriptorSplitIds, replicatedSplitIds); + Set taskDescriptorSplitPartitions = new HashSet<>(); + for (Split split : 
taskDescriptor.getSplits().values()) { + int splitId = TestingConnectorSplit.getSplitId(split); + if (taskDescriptorPartitionedSplitIds.contains(splitId)) { + int splitPartition = partitioningScheme.getPartition(split); + taskDescriptorSplitPartitions.add(splitPartition); + } + } + for (Integer splitPartition : taskDescriptorSplitPartitions) { + assertThat(taskDescriptorPartitionedSplitIds).containsAll(partitionedSplitIds.get(splitPartition)); + } + } + } + + private static ListMultimap createSplitMap(Split... splits) + { + return Arrays.stream(splits) + .collect(toImmutableListMultimap(split -> ((TestingConnectorSplit) split.getConnectorSplit()).getBucket().orElseThrow(), Function.identity())); + } + + private static FaultTolerantPartitioningScheme createPartitioningScheme(int partitionCount, Optional> partitionToNodeMap) + { + return new FaultTolerantPartitioningScheme( + partitionCount, + Optional.of(IntStream.range(0, partitionCount).toArray()), + Optional.of(split -> ((TestingConnectorSplit) split.getConnectorSplit()).getBucket().orElseThrow()), + partitionToNodeMap); + } + + private static Split createSplit(int id, int partition) + { + return new Split(TESTING_CATALOG_HANDLE, new TestingConnectorSplit(id, OptionalInt.of(partition), Optional.empty())); + } + + private static class SplitBatch + { + private final PlanNodeId planNodeId; + private final ListMultimap splits; + private final boolean noMoreSplits; + + public SplitBatch(PlanNodeId planNodeId, ListMultimap splits, boolean noMoreSplits) + { + this.planNodeId = requireNonNull(planNodeId, "planNodeId is null"); + this.splits = ImmutableListMultimap.copyOf(requireNonNull(splits, "splits is null")); + this.noMoreSplits = noMoreSplits; + } + + public PlanNodeId getPlanNodeId() + { + return planNodeId; + } + + public ListMultimap getSplits() + { + return splits; + } + + public boolean isNoMoreSplits() + { + return noMoreSplits; + } + } +} diff --git 
a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestSingleDistributionSplitAssigner.java b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestSingleDistributionSplitAssigner.java new file mode 100644 index 000000000000..588e125a444b --- /dev/null +++ b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestSingleDistributionSplitAssigner.java @@ -0,0 +1,145 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.execution.scheduler; + +import com.google.common.collect.ImmutableListMultimap; +import com.google.common.collect.ImmutableSet; +import io.trino.connector.CatalogHandle; +import io.trino.metadata.Split; +import io.trino.spi.HostAddress; +import io.trino.sql.planner.plan.PlanNodeId; +import org.testng.annotations.Test; + +import java.util.Optional; +import java.util.OptionalInt; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertFalse; +import static org.testng.Assert.assertTrue; + +public class TestSingleDistributionSplitAssigner +{ + private static final CatalogHandle TESTING_CATALOG_HANDLE = CatalogHandle.createRootCatalogHandle("testing"); + + private static final PlanNodeId PLAN_NODE_1 = new PlanNodeId("plan-node-1"); + private static final PlanNodeId PLAN_NODE_2 = new PlanNodeId("plan-node-2"); + + @Test + public void testNoSources() + { + ImmutableSet hostRequirement = 
ImmutableSet.of(HostAddress.fromParts("localhost", 8080)); + SplitAssigner splitAssigner = new SingleDistributionSplitAssigner(hostRequirement, ImmutableSet.of()); + TestingTaskSourceCallback callback = new TestingTaskSourceCallback(); + + splitAssigner.finish().update(callback); + + assertEquals(callback.getPartitionCount(), 1); + assertEquals(callback.getNodeRequirements(0), new NodeRequirements(Optional.empty(), hostRequirement)); + assertTrue(callback.isSealed(0)); + assertTrue(callback.isNoMorePartitions()); + } + + @Test + public void testEmptySource() + { + ImmutableSet hostRequirement = ImmutableSet.of(HostAddress.fromParts("localhost", 8080)); + SplitAssigner splitAssigner = new SingleDistributionSplitAssigner( + hostRequirement, + ImmutableSet.of(PLAN_NODE_1)); + TestingTaskSourceCallback callback = new TestingTaskSourceCallback(); + + splitAssigner.assign(PLAN_NODE_1, ImmutableListMultimap.of(), true).update(callback); + splitAssigner.finish().update(callback); + + assertEquals(callback.getPartitionCount(), 1); + assertEquals(callback.getNodeRequirements(0), new NodeRequirements(Optional.empty(), hostRequirement)); + assertThat(callback.getSplitIds(0, PLAN_NODE_1)).isEmpty(); + assertTrue(callback.isNoMoreSplits(0, PLAN_NODE_1)); + assertTrue(callback.isSealed(0)); + assertTrue(callback.isNoMorePartitions()); + } + + @Test + public void testSingleSource() + { + SplitAssigner splitAssigner = new SingleDistributionSplitAssigner( + ImmutableSet.of(), + ImmutableSet.of(PLAN_NODE_1)); + TestingTaskSourceCallback callback = new TestingTaskSourceCallback(); + + assertEquals(callback.getPartitionCount(), 0); + assertFalse(callback.isNoMorePartitions()); + + splitAssigner.assign(PLAN_NODE_1, ImmutableListMultimap.of(0, createSplit(1)), false).update(callback); + splitAssigner.finish().update(callback); + assertEquals(callback.getPartitionCount(), 1); + assertThat(callback.getSplitIds(0, PLAN_NODE_1)).containsExactly(1); + 
assertTrue(callback.isNoMorePartitions()); + + splitAssigner.assign(PLAN_NODE_1, ImmutableListMultimap.of(0, createSplit(2), 1, createSplit(3)), false).update(callback); + splitAssigner.finish().update(callback); + assertEquals(callback.getPartitionCount(), 1); + assertThat(callback.getSplitIds(0, PLAN_NODE_1)).containsExactly(1, 2, 3); + + assertFalse(callback.isNoMoreSplits(0, PLAN_NODE_1)); + assertFalse(callback.isSealed(0)); + splitAssigner.assign(PLAN_NODE_1, ImmutableListMultimap.of(0, createSplit(4)), true).update(callback); + splitAssigner.finish().update(callback); + assertTrue(callback.isNoMoreSplits(0, PLAN_NODE_1)); + assertTrue(callback.isSealed(0)); + } + + @Test + public void testMultipleSources() + { + SplitAssigner splitAssigner = new SingleDistributionSplitAssigner( + ImmutableSet.of(), + ImmutableSet.of(PLAN_NODE_1, PLAN_NODE_2)); + TestingTaskSourceCallback callback = new TestingTaskSourceCallback(); + + assertEquals(callback.getPartitionCount(), 0); + assertFalse(callback.isNoMorePartitions()); + + splitAssigner.assign(PLAN_NODE_1, ImmutableListMultimap.of(0, createSplit(1)), false).update(callback); + splitAssigner.finish().update(callback); + assertEquals(callback.getPartitionCount(), 1); + assertThat(callback.getSplitIds(0, PLAN_NODE_1)).containsExactly(1); + assertTrue(callback.isNoMorePartitions()); + + splitAssigner.assign(PLAN_NODE_2, ImmutableListMultimap.of(0, createSplit(2), 1, createSplit(3)), false).update(callback); + splitAssigner.finish().update(callback); + assertEquals(callback.getPartitionCount(), 1); + assertThat(callback.getSplitIds(0, PLAN_NODE_2)).containsExactly(2, 3); + + assertFalse(callback.isNoMoreSplits(0, PLAN_NODE_1)); + splitAssigner.assign(PLAN_NODE_1, ImmutableListMultimap.of(2, createSplit(4)), true).update(callback); + splitAssigner.finish().update(callback); + assertThat(callback.getSplitIds(0, PLAN_NODE_1)).containsExactly(1, 4); + assertTrue(callback.isNoMoreSplits(0, PLAN_NODE_1)); + + 
assertFalse(callback.isNoMoreSplits(0, PLAN_NODE_2)); + assertFalse(callback.isSealed(0)); + splitAssigner.assign(PLAN_NODE_2, ImmutableListMultimap.of(3, createSplit(5)), true).update(callback); + splitAssigner.finish().update(callback); + assertThat(callback.getSplitIds(0, PLAN_NODE_2)).containsExactly(2, 3, 5); + assertTrue(callback.isNoMoreSplits(0, PLAN_NODE_2)); + assertTrue(callback.isSealed(0)); + } + + private Split createSplit(int id) + { + return new Split(TESTING_CATALOG_HANDLE, new TestingConnectorSplit(id, OptionalInt.empty(), Optional.empty())); + } +} diff --git a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestStageTaskSourceFactory.java b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestStageTaskSourceFactory.java index 58d256e66c5a..413922938be9 100644 --- a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestStageTaskSourceFactory.java +++ b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestStageTaskSourceFactory.java @@ -31,7 +31,6 @@ import io.trino.execution.scheduler.StageTaskSourceFactory.HashDistributionTaskSource; import io.trino.execution.scheduler.StageTaskSourceFactory.SingleDistributionTaskSource; import io.trino.execution.scheduler.StageTaskSourceFactory.SourceDistributionTaskSource; -import io.trino.execution.scheduler.TestingExchange.TestingExchangeSourceHandle; import io.trino.metadata.InMemoryNodeManager; import io.trino.metadata.InternalNode; import io.trino.metadata.InternalNodeManager; @@ -39,12 +38,10 @@ import io.trino.spi.HostAddress; import io.trino.spi.QueryId; import io.trino.spi.SplitWeight; -import io.trino.spi.connector.ConnectorSplit; import io.trino.spi.exchange.ExchangeSourceHandle; import io.trino.split.RemoteSplit; import io.trino.split.SplitSource; import io.trino.sql.planner.plan.PlanNodeId; -import org.openjdk.jol.info.ClassLayout; import org.testng.annotations.Test; import java.net.URI; @@ -52,12 +49,10 @@ import java.util.Arrays; import java.util.List; 
import java.util.Map; -import java.util.Objects; import java.util.Optional; import java.util.OptionalInt; import java.util.stream.IntStream; -import static com.google.common.base.MoreObjects.toStringHelper; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableListMultimap.toImmutableListMultimap; import static com.google.common.collect.Iterables.getOnlyElement; @@ -65,14 +60,11 @@ import static com.google.common.util.concurrent.MoreExecutors.directExecutor; import static io.airlift.concurrent.MoreFutures.getDone; import static io.airlift.concurrent.MoreFutures.getFutureValue; -import static io.airlift.slice.SizeOf.estimatedSizeOf; -import static io.airlift.slice.SizeOf.sizeOf; import static io.airlift.units.DataSize.Unit.BYTE; import static io.airlift.units.DataSize.Unit.GIGABYTE; import static io.trino.execution.scheduler.StageTaskSourceFactory.createRemoteSplits; import static io.trino.operator.ExchangeOperator.REMOTE_CATALOG_HANDLE; import static io.trino.testing.TestingHandles.TEST_CATALOG_HANDLE; -import static java.lang.Math.toIntExact; import static java.util.Collections.nCopies; import static java.util.Objects.requireNonNull; import static org.assertj.core.api.Assertions.assertThat; @@ -95,9 +87,9 @@ public class TestStageTaskSourceFactory public void testSingleDistributionTaskSource() { ListMultimap sources = ImmutableListMultimap.builder() - .put(PLAN_NODE_1, new TestingExchangeSourceHandle(0, 123)) - .put(PLAN_NODE_2, new TestingExchangeSourceHandle(0, 321)) - .put(PLAN_NODE_1, new TestingExchangeSourceHandle(0, 222)) + .put(PLAN_NODE_1, new TestingExchangeSourceHandle(0, 0, 123)) + .put(PLAN_NODE_2, new TestingExchangeSourceHandle(1, 0, 321)) + .put(PLAN_NODE_1, new TestingExchangeSourceHandle(2, 0, 222)) .build(); TaskSource taskSource = new SingleDistributionTaskSource(createRemoteSplits(sources), new InMemoryNodeManager(), false); @@ -119,9 +111,9 @@ public void 
testSingleDistributionTaskSource() public void testCoordinatorDistributionTaskSource() { ListMultimap sources = ImmutableListMultimap.builder() - .put(PLAN_NODE_1, new TestingExchangeSourceHandle(0, 123)) - .put(PLAN_NODE_2, new TestingExchangeSourceHandle(0, 321)) - .put(PLAN_NODE_1, new TestingExchangeSourceHandle(0, 222)) + .put(PLAN_NODE_1, new TestingExchangeSourceHandle(0, 0, 123)) + .put(PLAN_NODE_2, new TestingExchangeSourceHandle(1, 0, 321)) + .put(PLAN_NODE_1, new TestingExchangeSourceHandle(2, 0, 222)) .build(); InternalNodeManager nodeManager = new InMemoryNodeManager(); TaskSource taskSource = new SingleDistributionTaskSource(createRemoteSplits(sources), nodeManager, true); @@ -152,12 +144,12 @@ public void testArbitraryDistributionTaskSource() assertThat(tasks).isEmpty(); assertTrue(taskSource.isFinished()); - TestingExchangeSourceHandle sourceHandle1 = new TestingExchangeSourceHandle(0, 1); - TestingExchangeSourceHandle sourceHandle2 = new TestingExchangeSourceHandle(0, 2); - TestingExchangeSourceHandle sourceHandle3 = new TestingExchangeSourceHandle(0, 3); - TestingExchangeSourceHandle sourceHandle4 = new TestingExchangeSourceHandle(0, 4); - TestingExchangeSourceHandle sourceHandle123 = new TestingExchangeSourceHandle(0, 123); - TestingExchangeSourceHandle sourceHandle321 = new TestingExchangeSourceHandle(0, 321); + TestingExchangeSourceHandle sourceHandle1 = new TestingExchangeSourceHandle(0, 0, 1); + TestingExchangeSourceHandle sourceHandle2 = new TestingExchangeSourceHandle(1, 0, 2); + TestingExchangeSourceHandle sourceHandle3 = new TestingExchangeSourceHandle(2, 0, 3); + TestingExchangeSourceHandle sourceHandle4 = new TestingExchangeSourceHandle(3, 0, 4); + TestingExchangeSourceHandle sourceHandle123 = new TestingExchangeSourceHandle(4, 0, 123); + TestingExchangeSourceHandle sourceHandle321 = new TestingExchangeSourceHandle(5, 0, 321); Multimap nonReplicatedSources = ImmutableListMultimap.of(PLAN_NODE_1, sourceHandle3); taskSource = new 
ArbitraryDistributionTaskSource( nonReplicatedSources, @@ -168,7 +160,7 @@ public void testArbitraryDistributionTaskSource() assertThat(tasks).hasSize(1); assertEquals(tasks.get(0).getPartitionId(), 0); assertEquals(tasks.get(0).getNodeRequirements(), new NodeRequirements(Optional.empty(), ImmutableSet.of())); - assertEquals(extractSourceHandles(tasks.get(0).getSplits()), ImmutableListMultimap.of(PLAN_NODE_1, new TestingExchangeSourceHandle(0, 3))); + assertEquals(extractSourceHandles(tasks.get(0).getSplits()), ImmutableListMultimap.of(PLAN_NODE_1, sourceHandle3)); nonReplicatedSources = ImmutableListMultimap.of(PLAN_NODE_1, sourceHandle123); taskSource = new ArbitraryDistributionTaskSource( @@ -179,7 +171,7 @@ public void testArbitraryDistributionTaskSource() assertThat(tasks).hasSize(1); assertEquals(tasks.get(0).getPartitionId(), 0); assertEquals(tasks.get(0).getNodeRequirements(), new NodeRequirements(Optional.empty(), ImmutableSet.of())); - assertEquals(extractSourceHandles(tasks.get(0).getSplits()), ImmutableListMultimap.of(PLAN_NODE_1, new TestingExchangeSourceHandle(0, 123))); + assertEquals(extractSourceHandles(tasks.get(0).getSplits()), ImmutableListMultimap.of(PLAN_NODE_1, sourceHandle123)); nonReplicatedSources = ImmutableListMultimap.of( PLAN_NODE_1, sourceHandle123, @@ -192,10 +184,10 @@ public void testArbitraryDistributionTaskSource() assertThat(tasks).hasSize(2); assertEquals(tasks.get(0).getPartitionId(), 0); assertEquals(tasks.get(0).getNodeRequirements(), new NodeRequirements(Optional.empty(), ImmutableSet.of())); - assertEquals(extractSourceHandles(tasks.get(0).getSplits()), ImmutableListMultimap.of(PLAN_NODE_1, new TestingExchangeSourceHandle(0, 123))); + assertEquals(extractSourceHandles(tasks.get(0).getSplits()), ImmutableListMultimap.of(PLAN_NODE_1, sourceHandle123)); assertEquals(tasks.get(1).getPartitionId(), 1); assertEquals(tasks.get(1).getNodeRequirements(), new NodeRequirements(Optional.empty(), ImmutableSet.of())); - 
assertEquals(extractSourceHandles(tasks.get(1).getSplits()), ImmutableListMultimap.of(PLAN_NODE_2, new TestingExchangeSourceHandle(0, 321))); + assertEquals(extractSourceHandles(tasks.get(1).getSplits()), ImmutableListMultimap.of(PLAN_NODE_2, sourceHandle321)); nonReplicatedSources = ImmutableListMultimap.of( PLAN_NODE_1, sourceHandle1, @@ -212,11 +204,11 @@ public void testArbitraryDistributionTaskSource() assertEquals( extractSourceHandles(tasks.get(0).getSplits()), ImmutableListMultimap.of( - PLAN_NODE_1, new TestingExchangeSourceHandle(0, 1), - PLAN_NODE_1, new TestingExchangeSourceHandle(0, 2))); + PLAN_NODE_1, sourceHandle1, + PLAN_NODE_1, sourceHandle2)); assertEquals(tasks.get(1).getPartitionId(), 1); assertEquals(tasks.get(1).getNodeRequirements(), new NodeRequirements(Optional.empty(), ImmutableSet.of())); - assertEquals(extractSourceHandles(tasks.get(1).getSplits()), ImmutableListMultimap.of(PLAN_NODE_2, new TestingExchangeSourceHandle(0, 4))); + assertEquals(extractSourceHandles(tasks.get(1).getSplits()), ImmutableListMultimap.of(PLAN_NODE_2, sourceHandle4)); nonReplicatedSources = ImmutableListMultimap.of( PLAN_NODE_1, sourceHandle1, @@ -230,13 +222,13 @@ PLAN_NODE_1, new TestingExchangeSourceHandle(0, 1), assertThat(tasks).hasSize(3); assertEquals(tasks.get(0).getPartitionId(), 0); assertEquals(tasks.get(0).getNodeRequirements(), new NodeRequirements(Optional.empty(), ImmutableSet.of())); - assertEquals(extractSourceHandles(tasks.get(0).getSplits()), ImmutableListMultimap.of(PLAN_NODE_1, new TestingExchangeSourceHandle(0, 1))); + assertEquals(extractSourceHandles(tasks.get(0).getSplits()), ImmutableListMultimap.of(PLAN_NODE_1, sourceHandle1)); assertEquals(tasks.get(1).getPartitionId(), 1); assertEquals(tasks.get(1).getNodeRequirements(), new NodeRequirements(Optional.empty(), ImmutableSet.of())); - assertEquals(extractSourceHandles(tasks.get(1).getSplits()), ImmutableListMultimap.of(PLAN_NODE_1, new TestingExchangeSourceHandle(0, 3))); + 
assertEquals(extractSourceHandles(tasks.get(1).getSplits()), ImmutableListMultimap.of(PLAN_NODE_1, sourceHandle3)); assertEquals(tasks.get(2).getPartitionId(), 2); assertEquals(tasks.get(2).getNodeRequirements(), new NodeRequirements(Optional.empty(), ImmutableSet.of())); - assertEquals(extractSourceHandles(tasks.get(2).getSplits()), ImmutableListMultimap.of(PLAN_NODE_2, new TestingExchangeSourceHandle(0, 4))); + assertEquals(extractSourceHandles(tasks.get(2).getSplits()), ImmutableListMultimap.of(PLAN_NODE_2, sourceHandle4)); // with replicated sources nonReplicatedSources = ImmutableListMultimap.of( @@ -256,15 +248,15 @@ PLAN_NODE_1, new TestingExchangeSourceHandle(0, 1), assertEquals( extractSourceHandles(tasks.get(0).getSplits()), ImmutableListMultimap.of( - PLAN_NODE_1, new TestingExchangeSourceHandle(0, 1), - PLAN_NODE_1, new TestingExchangeSourceHandle(0, 2), - PLAN_NODE_2, new TestingExchangeSourceHandle(0, 321))); + PLAN_NODE_1, sourceHandle1, + PLAN_NODE_1, sourceHandle2, + PLAN_NODE_2, sourceHandle321)); assertEquals(tasks.get(1).getPartitionId(), 1); assertEquals(tasks.get(1).getNodeRequirements(), new NodeRequirements(Optional.empty(), ImmutableSet.of())); assertEquals( extractSourceHandles(tasks.get(1).getSplits()), ImmutableListMultimap.of( - PLAN_NODE_1, new TestingExchangeSourceHandle(0, 4), + PLAN_NODE_1, sourceHandle4, PLAN_NODE_2, sourceHandle321)); } @@ -286,12 +278,12 @@ public void testHashDistributionTaskSource() taskSource = createHashDistributionTaskSource( ImmutableMap.of(), ImmutableListMultimap.of( - PLAN_NODE_1, new TestingExchangeSourceHandle(0, 1), - PLAN_NODE_1, new TestingExchangeSourceHandle(1, 1), - PLAN_NODE_2, new TestingExchangeSourceHandle(0, 1), - PLAN_NODE_2, new TestingExchangeSourceHandle(3, 1)), + PLAN_NODE_1, new TestingExchangeSourceHandle(0, 0, 1), + PLAN_NODE_1, new TestingExchangeSourceHandle(1, 1, 1), + PLAN_NODE_2, new TestingExchangeSourceHandle(2, 0, 1), + PLAN_NODE_2, new TestingExchangeSourceHandle(3, 3, 1)), 
ImmutableListMultimap.of( - PLAN_NODE_3, new TestingExchangeSourceHandle(0, 1)), + PLAN_NODE_3, new TestingExchangeSourceHandle(4, 0, 1)), 1, createPartitioningScheme(4), 0, @@ -305,23 +297,23 @@ PLAN_NODE_3, new TestingExchangeSourceHandle(0, 1)), assertEquals(extractSourceHandles( tasks.get(0).getSplits()), ImmutableListMultimap.of( - PLAN_NODE_1, new TestingExchangeSourceHandle(0, 1), - PLAN_NODE_2, new TestingExchangeSourceHandle(0, 1), - PLAN_NODE_3, new TestingExchangeSourceHandle(0, 1))); + PLAN_NODE_1, new TestingExchangeSourceHandle(0, 0, 1), + PLAN_NODE_2, new TestingExchangeSourceHandle(2, 0, 1), + PLAN_NODE_3, new TestingExchangeSourceHandle(4, 0, 1))); assertEquals(tasks.get(1).getPartitionId(), 1); assertEquals(tasks.get(1).getNodeRequirements(), new NodeRequirements(Optional.of(TEST_CATALOG_HANDLE), ImmutableSet.of())); assertEquals( extractSourceHandles(tasks.get(1).getSplits()), ImmutableListMultimap.of( - PLAN_NODE_1, new TestingExchangeSourceHandle(1, 1), - PLAN_NODE_3, new TestingExchangeSourceHandle(0, 1))); + PLAN_NODE_1, new TestingExchangeSourceHandle(1, 1, 1), + PLAN_NODE_3, new TestingExchangeSourceHandle(4, 0, 1))); assertEquals(tasks.get(2).getPartitionId(), 2); assertEquals(tasks.get(2).getNodeRequirements(), new NodeRequirements(Optional.of(TEST_CATALOG_HANDLE), ImmutableSet.of())); assertEquals( extractSourceHandles(tasks.get(2).getSplits()), ImmutableListMultimap.of( - PLAN_NODE_2, new TestingExchangeSourceHandle(3, 1), - PLAN_NODE_3, new TestingExchangeSourceHandle(0, 1))); + PLAN_NODE_2, new TestingExchangeSourceHandle(3, 3, 1), + PLAN_NODE_3, new TestingExchangeSourceHandle(4, 0, 1))); Split bucketedSplit1 = createBucketedSplit(0, 0); Split bucketedSplit2 = createBucketedSplit(0, 2); @@ -334,7 +326,7 @@ PLAN_NODE_4, new TestingSplitSource(TEST_CATALOG_HANDLE, ImmutableList.of(bucket PLAN_NODE_5, new TestingSplitSource(TEST_CATALOG_HANDLE, ImmutableList.of(bucketedSplit4))), ImmutableListMultimap.of(), ImmutableListMultimap.of( - 
PLAN_NODE_3, new TestingExchangeSourceHandle(0, 1)), + PLAN_NODE_3, new TestingExchangeSourceHandle(0, 0, 1)), 1, createPartitioningScheme(4, 4), 0, @@ -346,31 +338,31 @@ PLAN_NODE_3, new TestingExchangeSourceHandle(0, 1)), assertEquals(tasks.get(0).getPartitionId(), 0); assertEquals(tasks.get(0).getNodeRequirements(), new NodeRequirements(Optional.of(TEST_CATALOG_HANDLE), ImmutableSet.of(NODE_ADDRESS))); assertEquals(extractCatalogSplits(tasks.get(0).getSplits()), ImmutableListMultimap.of(PLAN_NODE_4, bucketedSplit1)); - assertEquals(extractSourceHandles(tasks.get(0).getSplits()), ImmutableListMultimap.of(PLAN_NODE_3, new TestingExchangeSourceHandle(0, 1))); + assertEquals(extractSourceHandles(tasks.get(0).getSplits()), ImmutableListMultimap.of(PLAN_NODE_3, new TestingExchangeSourceHandle(0, 0, 1))); assertEquals(tasks.get(1).getPartitionId(), 1); assertEquals(tasks.get(1).getNodeRequirements(), new NodeRequirements(Optional.of(TEST_CATALOG_HANDLE), ImmutableSet.of(NODE_ADDRESS))); assertEquals(extractCatalogSplits(tasks.get(1).getSplits()), ImmutableListMultimap.of(PLAN_NODE_5, bucketedSplit4)); - assertEquals(extractSourceHandles(tasks.get(1).getSplits()), ImmutableListMultimap.of(PLAN_NODE_3, new TestingExchangeSourceHandle(0, 1))); + assertEquals(extractSourceHandles(tasks.get(1).getSplits()), ImmutableListMultimap.of(PLAN_NODE_3, new TestingExchangeSourceHandle(0, 0, 1))); assertEquals(tasks.get(2).getPartitionId(), 2); assertEquals(tasks.get(2).getNodeRequirements(), new NodeRequirements(Optional.of(TEST_CATALOG_HANDLE), ImmutableSet.of(NODE_ADDRESS))); assertEquals(extractCatalogSplits(tasks.get(2).getSplits()), ImmutableListMultimap.of(PLAN_NODE_4, bucketedSplit2)); - assertEquals(extractSourceHandles(tasks.get(2).getSplits()), ImmutableListMultimap.of(PLAN_NODE_3, new TestingExchangeSourceHandle(0, 1))); + assertEquals(extractSourceHandles(tasks.get(2).getSplits()), ImmutableListMultimap.of(PLAN_NODE_3, new TestingExchangeSourceHandle(0, 0, 1))); 
assertEquals(tasks.get(3).getPartitionId(), 3); assertEquals(tasks.get(3).getNodeRequirements(), new NodeRequirements(Optional.of(TEST_CATALOG_HANDLE), ImmutableSet.of(NODE_ADDRESS))); assertEquals(extractCatalogSplits(tasks.get(3).getSplits()), ImmutableListMultimap.of(PLAN_NODE_4, bucketedSplit3)); - assertEquals(extractSourceHandles(tasks.get(3).getSplits()), ImmutableListMultimap.of(PLAN_NODE_3, new TestingExchangeSourceHandle(0, 1))); + assertEquals(extractSourceHandles(tasks.get(3).getSplits()), ImmutableListMultimap.of(PLAN_NODE_3, new TestingExchangeSourceHandle(0, 0, 1))); taskSource = createHashDistributionTaskSource( ImmutableMap.of( PLAN_NODE_4, new TestingSplitSource(TEST_CATALOG_HANDLE, ImmutableList.of(bucketedSplit1, bucketedSplit2, bucketedSplit3)), PLAN_NODE_5, new TestingSplitSource(TEST_CATALOG_HANDLE, ImmutableList.of(bucketedSplit4))), ImmutableListMultimap.of( - PLAN_NODE_1, new TestingExchangeSourceHandle(0, 1), - PLAN_NODE_1, new TestingExchangeSourceHandle(1, 1), - PLAN_NODE_2, new TestingExchangeSourceHandle(0, 1), - PLAN_NODE_2, new TestingExchangeSourceHandle(3, 1)), + PLAN_NODE_1, new TestingExchangeSourceHandle(0, 0, 1), + PLAN_NODE_1, new TestingExchangeSourceHandle(1, 1, 1), + PLAN_NODE_2, new TestingExchangeSourceHandle(2, 0, 1), + PLAN_NODE_2, new TestingExchangeSourceHandle(3, 3, 1)), ImmutableListMultimap.of( - PLAN_NODE_3, new TestingExchangeSourceHandle(0, 1)), + PLAN_NODE_3, new TestingExchangeSourceHandle(4, 0, 1)), 1, createPartitioningScheme(4, 4), 0, @@ -383,36 +375,36 @@ PLAN_NODE_3, new TestingExchangeSourceHandle(0, 1)), assertEquals(tasks.get(0).getNodeRequirements(), new NodeRequirements(Optional.of(TEST_CATALOG_HANDLE), ImmutableSet.of(NODE_ADDRESS))); assertEquals(extractCatalogSplits(tasks.get(0).getSplits()), ImmutableListMultimap.of(PLAN_NODE_4, bucketedSplit1)); assertEquals(extractSourceHandles(tasks.get(0).getSplits()), ImmutableListMultimap.of( - PLAN_NODE_1, new TestingExchangeSourceHandle(0, 1), - 
PLAN_NODE_2, new TestingExchangeSourceHandle(0, 1), - PLAN_NODE_3, new TestingExchangeSourceHandle(0, 1))); + PLAN_NODE_1, new TestingExchangeSourceHandle(0, 0, 1), + PLAN_NODE_2, new TestingExchangeSourceHandle(2, 0, 1), + PLAN_NODE_3, new TestingExchangeSourceHandle(4, 0, 1))); assertEquals(tasks.get(1).getPartitionId(), 1); assertEquals(tasks.get(1).getNodeRequirements(), new NodeRequirements(Optional.of(TEST_CATALOG_HANDLE), ImmutableSet.of(NODE_ADDRESS))); assertEquals(extractCatalogSplits(tasks.get(1).getSplits()), ImmutableListMultimap.of(PLAN_NODE_5, bucketedSplit4)); assertEquals(extractSourceHandles(tasks.get(1).getSplits()), ImmutableListMultimap.of( - PLAN_NODE_1, new TestingExchangeSourceHandle(1, 1), - PLAN_NODE_3, new TestingExchangeSourceHandle(0, 1))); + PLAN_NODE_1, new TestingExchangeSourceHandle(1, 1, 1), + PLAN_NODE_3, new TestingExchangeSourceHandle(4, 0, 1))); assertEquals(tasks.get(2).getPartitionId(), 2); assertEquals(tasks.get(2).getNodeRequirements(), new NodeRequirements(Optional.of(TEST_CATALOG_HANDLE), ImmutableSet.of(NODE_ADDRESS))); assertEquals(extractCatalogSplits(tasks.get(2).getSplits()), ImmutableListMultimap.of(PLAN_NODE_4, bucketedSplit2)); - assertEquals(extractSourceHandles(tasks.get(2).getSplits()), ImmutableListMultimap.of(PLAN_NODE_3, new TestingExchangeSourceHandle(0, 1))); + assertEquals(extractSourceHandles(tasks.get(2).getSplits()), ImmutableListMultimap.of(PLAN_NODE_3, new TestingExchangeSourceHandle(4, 0, 1))); assertEquals(tasks.get(3).getPartitionId(), 3); assertEquals(tasks.get(3).getNodeRequirements(), new NodeRequirements(Optional.of(TEST_CATALOG_HANDLE), ImmutableSet.of(NODE_ADDRESS))); assertEquals(extractCatalogSplits(tasks.get(3).getSplits()), ImmutableListMultimap.of(PLAN_NODE_4, bucketedSplit3)); assertEquals(extractSourceHandles(tasks.get(3).getSplits()), ImmutableListMultimap.of( - PLAN_NODE_2, new TestingExchangeSourceHandle(3, 1), - PLAN_NODE_3, new TestingExchangeSourceHandle(0, 1))); + PLAN_NODE_2, 
new TestingExchangeSourceHandle(3, 3, 1), + PLAN_NODE_3, new TestingExchangeSourceHandle(4, 0, 1))); taskSource = createHashDistributionTaskSource( ImmutableMap.of( PLAN_NODE_4, new TestingSplitSource(TEST_CATALOG_HANDLE, ImmutableList.of(bucketedSplit1, bucketedSplit2, bucketedSplit3)), PLAN_NODE_5, new TestingSplitSource(TEST_CATALOG_HANDLE, ImmutableList.of(bucketedSplit4))), ImmutableListMultimap.of( - PLAN_NODE_1, new TestingExchangeSourceHandle(0, 1), - PLAN_NODE_1, new TestingExchangeSourceHandle(1, 1), - PLAN_NODE_2, new TestingExchangeSourceHandle(0, 1)), + PLAN_NODE_1, new TestingExchangeSourceHandle(0, 0, 1), + PLAN_NODE_1, new TestingExchangeSourceHandle(1, 1, 1), + PLAN_NODE_2, new TestingExchangeSourceHandle(2, 0, 1)), ImmutableListMultimap.of( - PLAN_NODE_3, new TestingExchangeSourceHandle(0, 1)), + PLAN_NODE_3, new TestingExchangeSourceHandle(3, 0, 1)), 2, createPartitioningScheme(2, 4), 0, DataSize.of(0, BYTE)); @@ -426,17 +418,17 @@ PLAN_NODE_3, new TestingExchangeSourceHandle(0, 1)), PLAN_NODE_4, bucketedSplit1, PLAN_NODE_4, bucketedSplit2)); assertEquals(extractSourceHandles(tasks.get(0).getSplits()), ImmutableListMultimap.of( - PLAN_NODE_1, new TestingExchangeSourceHandle(0, 1), - PLAN_NODE_2, new TestingExchangeSourceHandle(0, 1), - PLAN_NODE_3, new TestingExchangeSourceHandle(0, 1))); + PLAN_NODE_1, new TestingExchangeSourceHandle(0, 0, 1), + PLAN_NODE_2, new TestingExchangeSourceHandle(2, 0, 1), + PLAN_NODE_3, new TestingExchangeSourceHandle(3, 0, 1))); assertEquals(tasks.get(1).getPartitionId(), 1); assertEquals(tasks.get(1).getNodeRequirements(), new NodeRequirements(Optional.of(TEST_CATALOG_HANDLE), ImmutableSet.of(NODE_ADDRESS))); assertEquals(extractCatalogSplits(tasks.get(1).getSplits()), ImmutableListMultimap.of( PLAN_NODE_4, bucketedSplit3, PLAN_NODE_5, bucketedSplit4)); assertEquals(extractSourceHandles(tasks.get(1).getSplits()), ImmutableListMultimap.of( - PLAN_NODE_1, new TestingExchangeSourceHandle(1, 1), - PLAN_NODE_3, new 
TestingExchangeSourceHandle(0, 1))); + PLAN_NODE_1, new TestingExchangeSourceHandle(1, 1, 1), + PLAN_NODE_3, new TestingExchangeSourceHandle(3, 0, 1))); // join based on split target split weight taskSource = createHashDistributionTaskSource( @@ -444,13 +436,13 @@ PLAN_NODE_1, new TestingExchangeSourceHandle(1, 1), PLAN_NODE_4, new TestingSplitSource(TEST_CATALOG_HANDLE, ImmutableList.of(bucketedSplit1, bucketedSplit2, bucketedSplit3)), PLAN_NODE_5, new TestingSplitSource(TEST_CATALOG_HANDLE, ImmutableList.of(bucketedSplit4))), ImmutableListMultimap.of( - PLAN_NODE_1, new TestingExchangeSourceHandle(0, 1), - PLAN_NODE_1, new TestingExchangeSourceHandle(1, 1), - PLAN_NODE_2, new TestingExchangeSourceHandle(1, 1), - PLAN_NODE_2, new TestingExchangeSourceHandle(2, 1), - PLAN_NODE_2, new TestingExchangeSourceHandle(3, 1)), + PLAN_NODE_1, new TestingExchangeSourceHandle(0, 0, 1), + PLAN_NODE_1, new TestingExchangeSourceHandle(1, 1, 1), + PLAN_NODE_2, new TestingExchangeSourceHandle(2, 1, 1), + PLAN_NODE_2, new TestingExchangeSourceHandle(3, 2, 1), + PLAN_NODE_2, new TestingExchangeSourceHandle(4, 3, 1)), ImmutableListMultimap.of( - PLAN_NODE_3, new TestingExchangeSourceHandle(17, 1)), + PLAN_NODE_3, new TestingExchangeSourceHandle(5, 17, 1)), 2, createPartitioningScheme(4, 4), 2 * STANDARD_WEIGHT, @@ -465,19 +457,19 @@ PLAN_NODE_3, new TestingExchangeSourceHandle(17, 1)), PLAN_NODE_4, bucketedSplit1, PLAN_NODE_5, bucketedSplit4)); assertEquals(extractSourceHandles(tasks.get(0).getSplits()), ImmutableListMultimap.of( - PLAN_NODE_1, new TestingExchangeSourceHandle(0, 1), - PLAN_NODE_1, new TestingExchangeSourceHandle(1, 1), - PLAN_NODE_2, new TestingExchangeSourceHandle(1, 1), - PLAN_NODE_3, new TestingExchangeSourceHandle(17, 1))); + PLAN_NODE_1, new TestingExchangeSourceHandle(0, 0, 1), + PLAN_NODE_1, new TestingExchangeSourceHandle(1, 1, 1), + PLAN_NODE_2, new TestingExchangeSourceHandle(2, 1, 1), + PLAN_NODE_3, new TestingExchangeSourceHandle(5, 17, 1))); 
assertEquals(tasks.get(1).getPartitionId(), 1); assertEquals(tasks.get(1).getNodeRequirements(), new NodeRequirements(Optional.of(TEST_CATALOG_HANDLE), ImmutableSet.of(NODE_ADDRESS))); assertEquals(extractCatalogSplits(tasks.get(1).getSplits()), ImmutableListMultimap.of( PLAN_NODE_4, bucketedSplit2, PLAN_NODE_4, bucketedSplit3)); assertEquals(extractSourceHandles(tasks.get(1).getSplits()), ImmutableListMultimap.of( - PLAN_NODE_2, new TestingExchangeSourceHandle(2, 1), - PLAN_NODE_2, new TestingExchangeSourceHandle(3, 1), - PLAN_NODE_3, new TestingExchangeSourceHandle(17, 1))); + PLAN_NODE_2, new TestingExchangeSourceHandle(3, 2, 1), + PLAN_NODE_2, new TestingExchangeSourceHandle(4, 3, 1), + PLAN_NODE_3, new TestingExchangeSourceHandle(5, 17, 1))); // join based on target exchange size taskSource = createHashDistributionTaskSource( @@ -485,13 +477,13 @@ PLAN_NODE_2, new TestingExchangeSourceHandle(3, 1), PLAN_NODE_4, new TestingSplitSource(TEST_CATALOG_HANDLE, ImmutableList.of(bucketedSplit1, bucketedSplit2, bucketedSplit3)), PLAN_NODE_5, new TestingSplitSource(TEST_CATALOG_HANDLE, ImmutableList.of(bucketedSplit4))), ImmutableListMultimap.of( - PLAN_NODE_1, new TestingExchangeSourceHandle(0, 20), - PLAN_NODE_1, new TestingExchangeSourceHandle(1, 30), - PLAN_NODE_2, new TestingExchangeSourceHandle(1, 20), - PLAN_NODE_2, new TestingExchangeSourceHandle(2, 99), - PLAN_NODE_2, new TestingExchangeSourceHandle(3, 30)), + PLAN_NODE_1, new TestingExchangeSourceHandle(0, 0, 20), + PLAN_NODE_1, new TestingExchangeSourceHandle(1, 1, 30), + PLAN_NODE_2, new TestingExchangeSourceHandle(2, 1, 20), + PLAN_NODE_2, new TestingExchangeSourceHandle(3, 2, 99), + PLAN_NODE_2, new TestingExchangeSourceHandle(4, 3, 30)), ImmutableListMultimap.of( - PLAN_NODE_3, new TestingExchangeSourceHandle(17, 1)), + PLAN_NODE_3, new TestingExchangeSourceHandle(5, 17, 1)), 2, createPartitioningScheme(4, 4), 100 * STANDARD_WEIGHT, @@ -506,22 +498,22 @@ PLAN_NODE_3, new TestingExchangeSourceHandle(17, 
1)), PLAN_NODE_4, bucketedSplit1, PLAN_NODE_5, bucketedSplit4)); assertEquals(extractSourceHandles(tasks.get(0).getSplits()), ImmutableListMultimap.of( - PLAN_NODE_1, new TestingExchangeSourceHandle(0, 20), - PLAN_NODE_1, new TestingExchangeSourceHandle(1, 30), - PLAN_NODE_2, new TestingExchangeSourceHandle(1, 20), - PLAN_NODE_3, new TestingExchangeSourceHandle(17, 1))); + PLAN_NODE_1, new TestingExchangeSourceHandle(0, 0, 20), + PLAN_NODE_1, new TestingExchangeSourceHandle(1, 1, 30), + PLAN_NODE_2, new TestingExchangeSourceHandle(2, 1, 20), + PLAN_NODE_3, new TestingExchangeSourceHandle(5, 17, 1))); assertEquals(tasks.get(1).getPartitionId(), 1); assertEquals(tasks.get(1).getNodeRequirements(), new NodeRequirements(Optional.of(TEST_CATALOG_HANDLE), ImmutableSet.of(NODE_ADDRESS))); assertEquals(extractCatalogSplits(tasks.get(1).getSplits()), ImmutableListMultimap.of(PLAN_NODE_4, bucketedSplit2)); assertEquals(extractSourceHandles(tasks.get(1).getSplits()), ImmutableListMultimap.of( - PLAN_NODE_2, new TestingExchangeSourceHandle(2, 99), - PLAN_NODE_3, new TestingExchangeSourceHandle(17, 1))); + PLAN_NODE_2, new TestingExchangeSourceHandle(3, 2, 99), + PLAN_NODE_3, new TestingExchangeSourceHandle(5, 17, 1))); assertEquals(tasks.get(2).getPartitionId(), 2); assertEquals(tasks.get(2).getNodeRequirements(), new NodeRequirements(Optional.of(TEST_CATALOG_HANDLE), ImmutableSet.of(NODE_ADDRESS))); assertEquals(extractCatalogSplits(tasks.get(2).getSplits()), ImmutableListMultimap.of(PLAN_NODE_4, bucketedSplit3)); assertEquals(extractSourceHandles(tasks.get(2).getSplits()), ImmutableListMultimap.of( - PLAN_NODE_2, new TestingExchangeSourceHandle(3, 30), - PLAN_NODE_3, new TestingExchangeSourceHandle(17, 1))); + PLAN_NODE_2, new TestingExchangeSourceHandle(4, 3, 30), + PLAN_NODE_3, new TestingExchangeSourceHandle(5, 17, 1))); } private static HashDistributionTaskSource createHashDistributionTaskSource( @@ -591,7 +583,7 @@ public void testSourceDistributionTaskSource() 
PLAN_NODE_1, split3)); assertTrue(taskSource.isFinished()); - ImmutableListMultimap replicatedSources = ImmutableListMultimap.of(PLAN_NODE_2, new TestingExchangeSourceHandle(0, 1)); + ImmutableListMultimap replicatedSources = ImmutableListMultimap.of(PLAN_NODE_2, new TestingExchangeSourceHandle(0, 0, 1)); taskSource = createSourceDistributionTaskSource( ImmutableList.of(split1, split2, split3), replicatedSources, @@ -805,7 +797,7 @@ PLAN_NODE_1, new TestingSplitSource(TEST_CATALOG_HANDLE, splitsFuture1, 0), PLAN_NODE_2, new TestingSplitSource(TEST_CATALOG_HANDLE, splitsFuture2, 0)), ImmutableListMultimap.of(), ImmutableListMultimap.of( - PLAN_NODE_3, new TestingExchangeSourceHandle(0, 1)), + PLAN_NODE_3, new TestingExchangeSourceHandle(0, 0, 1)), 1, createPartitioningScheme(4, 4), 0, @@ -954,100 +946,4 @@ private static ListMultimap extractCatalogSplits(ListMultimap }); return result.build(); } - - private static class TestingConnectorSplit - implements ConnectorSplit - { - private static final int INSTANCE_SIZE = toIntExact(ClassLayout.parseClass(TestingConnectorSplit.class).instanceSize()); - - private final int id; - private final OptionalInt bucket; - private final Optional> addresses; - private final SplitWeight weight; - - public TestingConnectorSplit(int id, OptionalInt bucket, Optional> addresses) - { - this(id, bucket, addresses, SplitWeight.standard().getRawValue()); - } - - public TestingConnectorSplit(int id, OptionalInt bucket, Optional> addresses, long weight) - { - this.id = id; - this.bucket = requireNonNull(bucket, "bucket is null"); - this.addresses = addresses.map(ImmutableList::copyOf); - this.weight = SplitWeight.fromRawValue(weight); - } - - public int getId() - { - return id; - } - - public OptionalInt getBucket() - { - return bucket; - } - - @Override - public boolean isRemotelyAccessible() - { - return addresses.isEmpty(); - } - - @Override - public List getAddresses() - { - return addresses.orElse(ImmutableList.of()); - } - - @Override - 
public SplitWeight getSplitWeight() - { - return weight; - } - - @Override - public Object getInfo() - { - return null; - } - - @Override - public long getRetainedSizeInBytes() - { - return INSTANCE_SIZE - + sizeOf(bucket) - + sizeOf(addresses, value -> estimatedSizeOf(value, HostAddress::getRetainedSizeInBytes)); - } - - @Override - public boolean equals(Object o) - { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - TestingConnectorSplit that = (TestingConnectorSplit) o; - return id == that.id && weight == that.weight && Objects.equals(bucket, that.bucket) && Objects.equals(addresses, that.addresses); - } - - @Override - public int hashCode() - { - return Objects.hash(id, bucket, addresses, weight); - } - - @Override - public String toString() - { - return toStringHelper(this) - .add("id", id) - .add("bucket", bucket) - .add("addresses", addresses) - .add("weight", weight) - .toString(); - } - } } diff --git a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestingConnectorSplit.java b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestingConnectorSplit.java new file mode 100644 index 000000000000..15f80164b634 --- /dev/null +++ b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestingConnectorSplit.java @@ -0,0 +1,132 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.execution.scheduler; + +import com.google.common.collect.ImmutableList; +import io.trino.metadata.Split; +import io.trino.spi.HostAddress; +import io.trino.spi.SplitWeight; +import io.trino.spi.connector.ConnectorSplit; +import org.openjdk.jol.info.ClassLayout; + +import java.util.List; +import java.util.Objects; +import java.util.Optional; +import java.util.OptionalInt; + +import static com.google.common.base.MoreObjects.toStringHelper; +import static io.airlift.slice.SizeOf.estimatedSizeOf; +import static io.airlift.slice.SizeOf.sizeOf; +import static java.util.Objects.requireNonNull; + +class TestingConnectorSplit + implements ConnectorSplit +{ + private static final long INSTANCE_SIZE = ClassLayout.parseClass(TestingConnectorSplit.class).instanceSize(); + + private final int id; + private final OptionalInt bucket; + private final Optional> addresses; + private final SplitWeight weight; + + public TestingConnectorSplit(int id, OptionalInt bucket, Optional> addresses) + { + this(id, bucket, addresses, SplitWeight.standard().getRawValue()); + } + + public TestingConnectorSplit(int id, OptionalInt bucket, Optional> addresses, long weight) + { + this.id = id; + this.bucket = requireNonNull(bucket, "bucket is null"); + this.addresses = addresses.map(ImmutableList::copyOf); + this.weight = SplitWeight.fromRawValue(weight); + } + + public int getId() + { + return id; + } + + public OptionalInt getBucket() + { + return bucket; + } + + @Override + public boolean isRemotelyAccessible() + { + return addresses.isEmpty(); + } + + @Override + public List getAddresses() + { + return addresses.orElse(ImmutableList.of()); + } + + @Override + public SplitWeight getSplitWeight() + { + return weight; + } + + @Override + public Object getInfo() + { + return null; + } + + @Override + public long getRetainedSizeInBytes() + { + return INSTANCE_SIZE + + sizeOf(bucket) + + sizeOf(addresses, value -> estimatedSizeOf(value, HostAddress::getRetainedSizeInBytes)); + } 
+ + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + TestingConnectorSplit that = (TestingConnectorSplit) o; + return id == that.id && weight == that.weight && Objects.equals(bucket, that.bucket) && Objects.equals(addresses, that.addresses); + } + + @Override + public int hashCode() + { + return Objects.hash(id, bucket, addresses, weight); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("id", id) + .add("bucket", bucket) + .add("addresses", addresses) + .add("weight", weight) + .toString(); + } + + public static int getSplitId(Split split) + { + return ((TestingConnectorSplit) split.getConnectorSplit()).getId(); + } +} diff --git a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestingExchange.java b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestingExchange.java index 6c260ef06d36..c5bdaaf61f2a 100644 --- a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestingExchange.java +++ b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestingExchange.java @@ -21,7 +21,6 @@ import io.trino.spi.exchange.ExchangeSinkInstanceHandle; import io.trino.spi.exchange.ExchangeSourceHandle; import io.trino.spi.exchange.ExchangeSourceHandleSource; -import org.openjdk.jol.info.ClassLayout; import java.util.List; import java.util.Objects; @@ -32,7 +31,6 @@ import static com.google.common.base.MoreObjects.toStringHelper; import static com.google.common.collect.Sets.newConcurrentHashSet; import static io.trino.spi.exchange.ExchangeId.createRandomExchangeId; -import static java.lang.Math.toIntExact; import static java.util.Objects.requireNonNull; public class TestingExchange @@ -153,72 +151,6 @@ public int getAttemptId() } } - public static class TestingExchangeSourceHandle - implements ExchangeSourceHandle - { - private static final int INSTANCE_SIZE = 
toIntExact(ClassLayout.parseClass(TestingExchangeSourceHandle.class).instanceSize()); - - private final int partitionId; - private final long sizeInBytes; - - public TestingExchangeSourceHandle(int partitionId, long sizeInBytes) - { - this.partitionId = partitionId; - this.sizeInBytes = sizeInBytes; - } - - @Override - public int getPartitionId() - { - return partitionId; - } - - @Override - public long getDataSizeInBytes() - { - return sizeInBytes; - } - - @Override - public long getRetainedSizeInBytes() - { - return INSTANCE_SIZE; - } - - public long getSizeInBytes() - { - return sizeInBytes; - } - - @Override - public boolean equals(Object o) - { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - TestingExchangeSourceHandle that = (TestingExchangeSourceHandle) o; - return partitionId == that.partitionId && sizeInBytes == that.sizeInBytes; - } - - @Override - public int hashCode() - { - return Objects.hash(partitionId, sizeInBytes); - } - - @Override - public String toString() - { - return toStringHelper(this) - .add("partitionId", partitionId) - .add("sizeInBytes", sizeInBytes) - .toString(); - } - } - public static class TestingExchangeSinkHandle implements ExchangeSinkHandle { diff --git a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestingExchangeSourceHandle.java b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestingExchangeSourceHandle.java new file mode 100644 index 000000000000..20ece535f2c3 --- /dev/null +++ b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestingExchangeSourceHandle.java @@ -0,0 +1,90 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.execution.scheduler; + +import io.trino.spi.exchange.ExchangeSourceHandle; +import org.openjdk.jol.info.ClassLayout; + +import java.util.Objects; + +import static com.google.common.base.MoreObjects.toStringHelper; + +public class TestingExchangeSourceHandle + implements ExchangeSourceHandle +{ + private static final long INSTANCE_SIZE = ClassLayout.parseClass(TestingExchangeSourceHandle.class).instanceSize(); + + private final int id; + private final int partitionId; + private final long sizeInBytes; + + public TestingExchangeSourceHandle(int id, int partitionId, long sizeInBytes) + { + this.id = id; + this.partitionId = partitionId; + this.sizeInBytes = sizeInBytes; + } + + public int getId() + { + return id; + } + + @Override + public int getPartitionId() + { + return partitionId; + } + + @Override + public long getDataSizeInBytes() + { + return sizeInBytes; + } + + @Override + public long getRetainedSizeInBytes() + { + return INSTANCE_SIZE; + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + TestingExchangeSourceHandle that = (TestingExchangeSourceHandle) o; + return id == that.id && partitionId == that.partitionId && sizeInBytes == that.sizeInBytes; + } + + @Override + public int hashCode() + { + return Objects.hash(id, partitionId, sizeInBytes); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("id", id) + .add("partitionId", partitionId) + .add("sizeInBytes", sizeInBytes) + 
.toString(); + } +} diff --git a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestingTaskSourceCallback.java b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestingTaskSourceCallback.java new file mode 100644 index 000000000000..eee3324318f3 --- /dev/null +++ b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestingTaskSourceCallback.java @@ -0,0 +1,192 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.execution.scheduler; + +import com.google.common.collect.ArrayListMultimap; +import com.google.common.collect.HashMultimap; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableListMultimap; +import com.google.common.collect.ListMultimap; +import com.google.common.collect.SetMultimap; +import com.google.common.collect.Sets; +import com.google.common.primitives.ImmutableIntArray; +import com.google.common.util.concurrent.Futures; +import com.google.common.util.concurrent.ListenableFuture; +import com.google.common.util.concurrent.SettableFuture; +import com.google.common.util.concurrent.UncheckedExecutionException; +import io.trino.metadata.Split; +import io.trino.sql.planner.plan.PlanNodeId; + +import java.util.Collection; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.ExecutionException; +import java.util.stream.Collectors; + +import static 
com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Verify.verify; +import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static org.assertj.core.api.Assertions.assertThat; + +class TestingTaskSourceCallback + implements EventDrivenTaskSource.Callback +{ + private final Map nodeRequirements = new HashMap<>(); + private final Map> splits = new HashMap<>(); + private final SetMultimap noMoreSplits = HashMultimap.create(); + private final Set sealedPartitions = new HashSet<>(); + private boolean noMorePartitions; + private final SettableFuture> taskDescriptors = SettableFuture.create(); + + public ListenableFuture> getTaskDescriptorsFuture() + { + return taskDescriptors; + } + + public List getTaskDescriptors() + { + try { + return Futures.getDone(taskDescriptors); + } + catch (ExecutionException e) { + throw new UncheckedExecutionException(e); + } + } + + public synchronized int getPartitionCount() + { + return nodeRequirements.size(); + } + + public synchronized NodeRequirements getNodeRequirements(int partition) + { + NodeRequirements result = nodeRequirements.get(partition); + checkArgument(result != null, "partition not found: %s", partition); + return result; + } + + public synchronized Set getSplitIds(int partition, PlanNodeId planNodeId) + { + ListMultimap partitionSplits = splits.getOrDefault(partition, ImmutableListMultimap.of()); + return partitionSplits.get(planNodeId).stream() + .map(split -> (TestingConnectorSplit) split.getConnectorSplit()) + .map(TestingConnectorSplit::getId) + .collect(toImmutableSet()); + } + + public synchronized boolean isNoMoreSplits(int partition, PlanNodeId planNodeId) + { + return noMoreSplits.get(partition).contains(planNodeId); + } + + public synchronized boolean isSealed(int partition) + { + return sealedPartitions.contains(partition); + } + + public synchronized boolean isNoMorePartitions() + { + return noMorePartitions; + } + + public void 
checkContainsSplits(PlanNodeId planNodeId, Collection splits, boolean replicated) + { + Set expectedSplitIds = splits.stream() + .map(TestingConnectorSplit::getSplitId) + .collect(Collectors.toSet()); + for (int partitionId = 0; partitionId < getPartitionCount(); partitionId++) { + Set partitionSplitIds = getSplitIds(partitionId, planNodeId); + if (replicated) { + assertThat(partitionSplitIds).containsAll(expectedSplitIds); + } + else { + expectedSplitIds.removeAll(partitionSplitIds); + } + } + if (!replicated) { + assertThat(expectedSplitIds).isEmpty(); + } + } + + @Override + public synchronized void partitionsAdded(List partitions) + { + verify(!noMorePartitions, "noMorePartitions is set"); + for (EventDrivenTaskSource.Partition partition : partitions) { + verify(nodeRequirements.put(partition.partitionId(), partition.nodeRequirements()) == null, "partition already exist: %s", partition.partitionId()); + } + } + + @Override + public synchronized void noMorePartitions() + { + noMorePartitions = true; + checkFinished(); + } + + @Override + public synchronized void partitionsUpdated(List partitionUpdates) + { + for (EventDrivenTaskSource.PartitionUpdate partitionUpdate : partitionUpdates) { + int partitionId = partitionUpdate.partitionId(); + verify(nodeRequirements.get(partitionId) != null, "partition does not exist: %s", partitionId); + verify(!sealedPartitions.contains(partitionId), "partition is sealed: %s", partitionId); + PlanNodeId planNodeId = partitionUpdate.planNodeId(); + if (!partitionUpdate.splits().isEmpty()) { + verify(!noMoreSplits.get(partitionId).contains(planNodeId), "noMoreSplits is set for partition %s and plan node %s", partitionId, planNodeId); + splits.computeIfAbsent(partitionId, (key) -> ArrayListMultimap.create()).putAll(planNodeId, partitionUpdate.splits()); + } + if (partitionUpdate.noMoreSplits()) { + noMoreSplits.put(partitionId, planNodeId); + } + } + } + + @Override + public synchronized void partitionsSealed(ImmutableIntArray 
partitionIds) + { + partitionIds.forEach(sealedPartitions::add); + checkFinished(); + } + + private synchronized void checkFinished() + { + if (noMorePartitions && sealedPartitions.containsAll(nodeRequirements.keySet())) { + verify(sealedPartitions.equals(nodeRequirements.keySet()), "unknown sealed partitions: %s", Sets.difference(sealedPartitions, nodeRequirements.keySet())); + ImmutableList.Builder result = ImmutableList.builder(); + for (Integer partitionId : sealedPartitions) { + ListMultimap taskSplits = splits.getOrDefault(partitionId, ImmutableListMultimap.of()); + verify( + noMoreSplits.get(partitionId).containsAll(taskSplits.keySet()), + "no more split is missing for partition %s: %s", + partitionId, + Sets.difference(taskSplits.keySet(), noMoreSplits.get(partitionId))); + result.add(new TaskDescriptor( + partitionId, + taskSplits, + nodeRequirements.get(partitionId))); + } + taskDescriptors.set(result.build()); + } + } + + @Override + public synchronized void failed(Throwable t) + { + taskDescriptors.setException(t); + } +} diff --git a/core/trino-spi/src/main/java/io/trino/spi/exchange/ExchangeSourceOutputSelector.java b/core/trino-spi/src/main/java/io/trino/spi/exchange/ExchangeSourceOutputSelector.java index f8e4837beec1..3325388bbced 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/exchange/ExchangeSourceOutputSelector.java +++ b/core/trino-spi/src/main/java/io/trino/spi/exchange/ExchangeSourceOutputSelector.java @@ -142,6 +142,21 @@ public void checkValidTransition(ExchangeSourceOutputSelector newSelector) } } + public ExchangeSourceOutputSelector merge(ExchangeSourceOutputSelector other) + { + Map values = new HashMap<>(this.values); + other.values.forEach((exchangeId, value) -> { + Slice currentValue = values.putIfAbsent(exchangeId, value); + if (currentValue != null) { + throw new IllegalArgumentException("duplicated selector for exchange: " + exchangeId); + } + }); + return new ExchangeSourceOutputSelector( + this.version + other.version, + 
values, + this.finalSelector && other.finalSelector); + } + private int getPartitionCount(ExchangeId exchangeId) { Slice values = this.values.get(exchangeId);