diff --git a/core/trino-main/src/main/java/io/trino/execution/NodeTaskMap.java b/core/trino-main/src/main/java/io/trino/execution/NodeTaskMap.java index a3f77279a155..edb1141e1ac1 100644 --- a/core/trino-main/src/main/java/io/trino/execution/NodeTaskMap.java +++ b/core/trino-main/src/main/java/io/trino/execution/NodeTaskMap.java @@ -24,7 +24,8 @@ import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicInteger; -import java.util.function.IntConsumer; +import java.util.concurrent.atomic.AtomicLong; +import java.util.function.Consumer; import static com.google.common.base.MoreObjects.toStringHelper; import static java.util.Objects.requireNonNull; @@ -47,9 +48,9 @@ public void addTask(InternalNode node, RemoteTask task) createOrGetNodeTasks(node).addTask(task); } - public int getPartitionedSplitsOnNode(InternalNode node) + public PartitionedSplitsInfo getPartitionedSplitsOnNode(InternalNode node) { - return createOrGetNodeTasks(node).getPartitionedSplitCount(); + return createOrGetNodeTasks(node).getPartitionedSplitsInfo(); } public PartitionedSplitCountTracker createPartitionedSplitCountTracker(InternalNode node, TaskId taskId) @@ -80,6 +81,7 @@ private static class NodeTasks { private final Set remoteTasks = Sets.newConcurrentHashSet(); private final AtomicInteger nodeTotalPartitionedSplitCount = new AtomicInteger(); + private final AtomicLong nodeTotalPartitionedSplitWeight = new AtomicLong(); private final FinalizerService finalizerService; public NodeTasks(FinalizerService finalizerService) @@ -87,9 +89,9 @@ public NodeTasks(FinalizerService finalizerService) this.finalizerService = requireNonNull(finalizerService, "finalizerService is null"); } - private int getPartitionedSplitCount() + private PartitionedSplitsInfo getPartitionedSplitsInfo() { - return nodeTotalPartitionedSplitCount.get(); + return PartitionedSplitsInfo.forSplitCountAndWeightSum(nodeTotalPartitionedSplitCount.get(), 
nodeTotalPartitionedSplitWeight.get()); } private void addTask(RemoteTask task) @@ -112,8 +114,8 @@ public PartitionedSplitCountTracker createPartitionedSplitCountTracker(TaskId ta { requireNonNull(taskId, "taskId is null"); - TaskPartitionedSplitCountTracker tracker = new TaskPartitionedSplitCountTracker(taskId); - PartitionedSplitCountTracker partitionedSplitCountTracker = new PartitionedSplitCountTracker(tracker::setPartitionedSplitCount); + TaskPartitionedSplitCountTracker tracker = new TaskPartitionedSplitCountTracker(taskId, nodeTotalPartitionedSplitCount, nodeTotalPartitionedSplitWeight); + PartitionedSplitCountTracker partitionedSplitCountTracker = new PartitionedSplitCountTracker(tracker); // when partitionedSplitCountTracker is garbage collected, run the cleanup method on the tracker // Note: tracker cannot have a reference to partitionedSplitCountTracker @@ -123,41 +125,66 @@ public PartitionedSplitCountTracker createPartitionedSplitCountTracker(TaskId ta } @ThreadSafe - private class TaskPartitionedSplitCountTracker + private static class TaskPartitionedSplitCountTracker + implements Consumer { private final TaskId taskId; + private final AtomicInteger nodeTotalPartitionedSplitCount; + private final AtomicLong nodeTotalPartitionedSplitWeight; private final AtomicInteger localPartitionedSplitCount = new AtomicInteger(); + private final AtomicLong localPartitionedSplitWeight = new AtomicLong(); - public TaskPartitionedSplitCountTracker(TaskId taskId) + public TaskPartitionedSplitCountTracker(TaskId taskId, AtomicInteger nodeTotalPartitionedSplitCount, AtomicLong nodeTotalPartitionedSplitWeight) { this.taskId = requireNonNull(taskId, "taskId is null"); + this.nodeTotalPartitionedSplitCount = requireNonNull(nodeTotalPartitionedSplitCount, "nodeTotalPartitionedSplitCount is null"); + this.nodeTotalPartitionedSplitWeight = requireNonNull(nodeTotalPartitionedSplitWeight, "nodeTotalPartitionedSplitWeight is null"); } - public synchronized void 
setPartitionedSplitCount(int partitionedSplitCount) + @Override + public synchronized void accept(PartitionedSplitsInfo partitionedSplits) { - if (partitionedSplitCount < 0) { - int oldValue = localPartitionedSplitCount.getAndSet(0); - nodeTotalPartitionedSplitCount.addAndGet(-oldValue); - throw new IllegalArgumentException("partitionedSplitCount is negative"); + if (partitionedSplits == null || partitionedSplits.getCount() < 0 || partitionedSplits.getWeightSum() < 0) { + clearLocalSplitInfo(false); + requireNonNull(partitionedSplits, "partitionedSplits is null"); // throw NPE if null, otherwise negative value + throw new IllegalArgumentException("Invalid negative value: " + partitionedSplits); } - int oldValue = localPartitionedSplitCount.getAndSet(partitionedSplitCount); - nodeTotalPartitionedSplitCount.addAndGet(partitionedSplitCount - oldValue); + int newCount = partitionedSplits.getCount(); + long newWeight = partitionedSplits.getWeightSum(); + int countDelta = newCount - localPartitionedSplitCount.getAndSet(newCount); + long weightDelta = newWeight - localPartitionedSplitWeight.getAndSet(newWeight); + if (countDelta != 0) { + nodeTotalPartitionedSplitCount.addAndGet(countDelta); + } + if (weightDelta != 0) { + nodeTotalPartitionedSplitWeight.addAndGet(weightDelta); + } } - public void cleanup() + private void clearLocalSplitInfo(boolean reportAsLeaked) { - int leakedSplits = localPartitionedSplitCount.getAndSet(0); - if (leakedSplits == 0) { + int leakedCount = localPartitionedSplitCount.getAndSet(0); + long leakedWeight = localPartitionedSplitWeight.getAndSet(0); + if (leakedCount == 0 && leakedWeight == 0) { return; } - log.error("BUG! %s for %s leaked with %s partitioned splits. Cleaning up so server can continue to function.", - getClass().getName(), - taskId, - leakedSplits); + if (reportAsLeaked) { + log.error("BUG! %s for %s leaked with %s partitioned splits (weight: %s). 
Cleaning up so server can continue to function.", + getClass().getName(), + taskId, + leakedCount, + leakedWeight); + } - nodeTotalPartitionedSplitCount.addAndGet(-leakedSplits); + nodeTotalPartitionedSplitCount.addAndGet(-leakedCount); + nodeTotalPartitionedSplitWeight.addAndGet(-leakedWeight); + } + + public void cleanup() + { + clearLocalSplitInfo(true); } @Override @@ -166,6 +193,7 @@ public String toString() return toStringHelper(this) .add("taskId", taskId) .add("splits", localPartitionedSplitCount) + .add("weight", localPartitionedSplitWeight) .toString(); } } @@ -173,16 +201,16 @@ public String toString() public static class PartitionedSplitCountTracker { - private final IntConsumer splitSetter; + private final Consumer splitSetter; - public PartitionedSplitCountTracker(IntConsumer splitSetter) + public PartitionedSplitCountTracker(Consumer splitSetter) { this.splitSetter = requireNonNull(splitSetter, "splitSetter is null"); } - public void setPartitionedSplitCount(int partitionedSplitCount) + public void setPartitionedSplits(PartitionedSplitsInfo partitionedSplits) { - splitSetter.accept(partitionedSplitCount); + splitSetter.accept(partitionedSplits); } @Override diff --git a/core/trino-main/src/main/java/io/trino/execution/PartitionedSplitsInfo.java b/core/trino-main/src/main/java/io/trino/execution/PartitionedSplitsInfo.java new file mode 100644 index 000000000000..ec7df6f66156 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/execution/PartitionedSplitsInfo.java @@ -0,0 +1,77 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.execution; + +import static com.google.common.base.MoreObjects.toStringHelper; + +public final class PartitionedSplitsInfo +{ + private static final PartitionedSplitsInfo NO_SPLITS_INFO = new PartitionedSplitsInfo(0, 0); + + private final int count; + private final long weightSum; + + private PartitionedSplitsInfo(int splitCount, long splitsWeightSum) + { + this.count = splitCount; + this.weightSum = splitsWeightSum; + } + + public int getCount() + { + return count; + } + + public long getWeightSum() + { + return weightSum; + } + + @Override + public int hashCode() + { + return (count * 31) + Long.hashCode(weightSum); + } + + @Override + public boolean equals(Object other) + { + if (!(other instanceof PartitionedSplitsInfo)) { + return false; + } + PartitionedSplitsInfo otherInfo = (PartitionedSplitsInfo) other; + return this == otherInfo || (this.count == otherInfo.count && this.weightSum == otherInfo.weightSum); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("count", count) + .add("weightSum", weightSum) + .toString(); + } + + public static PartitionedSplitsInfo forSplitCountAndWeightSum(int splitCount, long weightSum) + { + // Avoid allocating for the "no splits" case, also mask potential race condition between + // count and weight updates that might yield a positive weight with a count of 0 + return splitCount == 0 ? 
NO_SPLITS_INFO : new PartitionedSplitsInfo(splitCount, weightSum); + } + + public static PartitionedSplitsInfo forZeroSplits() + { + return NO_SPLITS_INFO; + } +} diff --git a/core/trino-main/src/main/java/io/trino/execution/RemoteTask.java b/core/trino-main/src/main/java/io/trino/execution/RemoteTask.java index 03ceebc275b6..3320faf99c44 100644 --- a/core/trino-main/src/main/java/io/trino/execution/RemoteTask.java +++ b/core/trino-main/src/main/java/io/trino/execution/RemoteTask.java @@ -55,15 +55,15 @@ public interface RemoteTask */ void addFinalTaskInfoListener(StateChangeListener stateChangeListener); - ListenableFuture whenSplitQueueHasSpace(int threshold); + ListenableFuture whenSplitQueueHasSpace(long weightThreshold); void cancel(); void abort(); - int getPartitionedSplitCount(); + PartitionedSplitsInfo getPartitionedSplitsInfo(); - int getQueuedPartitionedSplitCount(); + PartitionedSplitsInfo getQueuedPartitionedSplitsInfo(); int getUnacknowledgedPartitionedSplitCount(); } diff --git a/core/trino-main/src/main/java/io/trino/execution/SqlTask.java b/core/trino-main/src/main/java/io/trino/execution/SqlTask.java index cc6e9363b6a6..6463190c42e3 100644 --- a/core/trino-main/src/main/java/io/trino/execution/SqlTask.java +++ b/core/trino-main/src/main/java/io/trino/execution/SqlTask.java @@ -280,7 +280,9 @@ private TaskStatus createTaskStatus(TaskHolder taskHolder) } int queuedPartitionedDrivers = 0; + long queuedPartitionedSplitsWeight = 0L; int runningPartitionedDrivers = 0; + long runningPartitionedSplitsWeight = 0L; DataSize physicalWrittenDataSize = DataSize.ofBytes(0); DataSize userMemoryReservation = DataSize.ofBytes(0); DataSize systemMemoryReservation = DataSize.ofBytes(0); @@ -294,7 +296,9 @@ private TaskStatus createTaskStatus(TaskHolder taskHolder) TaskInfo taskInfo = taskHolder.getFinalTaskInfo(); TaskStats taskStats = taskInfo.getStats(); queuedPartitionedDrivers = taskStats.getQueuedPartitionedDrivers(); + queuedPartitionedSplitsWeight = 
taskStats.getQueuedPartitionedSplitsWeight(); runningPartitionedDrivers = taskStats.getRunningPartitionedDrivers(); + runningPartitionedSplitsWeight = taskStats.getRunningPartitionedSplitsWeight(); physicalWrittenDataSize = taskStats.getPhysicalWrittenDataSize(); userMemoryReservation = taskStats.getUserMemoryReservation(); systemMemoryReservation = taskStats.getSystemMemoryReservation(); @@ -308,7 +312,9 @@ else if (taskHolder.getTaskExecution() != null) { for (PipelineContext pipelineContext : taskContext.getPipelineContexts()) { PipelineStatus pipelineStatus = pipelineContext.getPipelineStatus(); queuedPartitionedDrivers += pipelineStatus.getQueuedPartitionedDrivers(); + queuedPartitionedSplitsWeight += pipelineStatus.getQueuedPartitionedSplitsWeight(); runningPartitionedDrivers += pipelineStatus.getRunningPartitionedDrivers(); + runningPartitionedSplitsWeight += pipelineStatus.getRunningPartitionedSplitsWeight(); physicalWrittenBytes += pipelineContext.getPhysicalWrittenDataSize(); } physicalWrittenDataSize = succinctBytes(physicalWrittenBytes); @@ -338,7 +344,9 @@ else if (taskHolder.getTaskExecution() != null) { revocableMemoryReservation, fullGcCount, fullGcTime, - dynamicFiltersVersion); + dynamicFiltersVersion, + queuedPartitionedSplitsWeight, + runningPartitionedSplitsWeight); } private TaskStats getTaskStats(TaskHolder taskHolder) diff --git a/core/trino-main/src/main/java/io/trino/execution/SqlTaskExecution.java b/core/trino-main/src/main/java/io/trino/execution/SqlTaskExecution.java index 5eef0dea2fa6..a8ff84d44ceb 100644 --- a/core/trino-main/src/main/java/io/trino/execution/SqlTaskExecution.java +++ b/core/trino-main/src/main/java/io/trino/execution/SqlTaskExecution.java @@ -37,6 +37,7 @@ import io.trino.operator.PipelineExecutionStrategy; import io.trino.operator.StageExecutionDescriptor; import io.trino.operator.TaskContext; +import io.trino.spi.SplitWeight; import io.trino.sql.planner.LocalExecutionPlanner.LocalExecutionPlan; import 
io.trino.sql.planner.plan.PlanNodeId; @@ -371,7 +372,7 @@ private void mergeIntoPendingSplits(PlanNodeId planNodeId, Set s DriverSplitRunnerFactory partitionedDriverFactory = driverRunnerFactoriesWithSplitLifeCycle.get(planNodeId); PendingSplitsForPlanNode pendingSplitsForPlanNode = pendingSplitsByPlanNode.get(planNodeId); - partitionedDriverFactory.splitsAdded(scheduledSplits.size()); + partitionedDriverFactory.splitsAdded(scheduledSplits.size(), SplitWeight.rawValueSum(scheduledSplits, scheduledSplit -> scheduledSplit.getSplit().getSplitWeight())); for (ScheduledSplit scheduledSplit : scheduledSplits) { Lifespan lifespan = scheduledSplit.getSplit().getLifespan(); checkLifespan(partitionedDriverFactory.getPipelineExecutionStrategy(), lifespan); @@ -933,7 +934,8 @@ public DriverSplitRunner createDriverRunner(@Nullable ScheduledSplit partitioned status.incrementPendingCreation(pipelineContext.getPipelineId(), lifespan); // create driver context immediately so the driver existence is recorded in the stats // the number of drivers is used to balance work across nodes - DriverContext driverContext = pipelineContext.addDriverContext(lifespan); + long splitWeight = partitionedSplit == null ? 
0 : partitionedSplit.getSplit().getSplitWeight().getRawValue(); + DriverContext driverContext = pipelineContext.addDriverContext(lifespan, splitWeight); return new DriverSplitRunner(this, driverContext, partitionedSplit, lifespan); } @@ -1003,9 +1005,9 @@ public OptionalInt getDriverInstances() return driverFactory.getDriverInstances(); } - public void splitsAdded(int count) + public void splitsAdded(int count, long weightSum) { - pipelineContext.splitsAdded(count); + pipelineContext.splitsAdded(count, weightSum); } } diff --git a/core/trino-main/src/main/java/io/trino/execution/TaskStatus.java b/core/trino-main/src/main/java/io/trino/execution/TaskStatus.java index 797bfdefe472..dfcd4136aac9 100644 --- a/core/trino-main/src/main/java/io/trino/execution/TaskStatus.java +++ b/core/trino-main/src/main/java/io/trino/execution/TaskStatus.java @@ -55,7 +55,9 @@ public class TaskStatus private final Set completedDriverGroups; private final int queuedPartitionedDrivers; + private final long queuedPartitionedSplitsWeight; private final int runningPartitionedDrivers; + private final long runningPartitionedSplitsWeight; private final boolean outputBufferOverutilized; private final DataSize physicalWrittenDataSize; private final DataSize memoryReservation; @@ -88,7 +90,9 @@ public TaskStatus( @JsonProperty("revocableMemoryReservation") DataSize revocableMemoryReservation, @JsonProperty("fullGcCount") long fullGcCount, @JsonProperty("fullGcTime") Duration fullGcTime, - @JsonProperty("dynamicFiltersVersion") long dynamicFiltersVersion) + @JsonProperty("dynamicFiltersVersion") long dynamicFiltersVersion, + @JsonProperty("queuedPartitionedSplitsWeight") long queuedPartitionedSplitsWeight, + @JsonProperty("runningPartitionedSplitsWeight") long runningPartitionedSplitsWeight) { this.taskId = requireNonNull(taskId, "taskId is null"); this.taskInstanceId = requireNonNull(taskInstanceId, "taskInstanceId is null"); @@ -102,9 +106,13 @@ public TaskStatus( 
checkArgument(queuedPartitionedDrivers >= 0, "queuedPartitionedDrivers must be positive"); this.queuedPartitionedDrivers = queuedPartitionedDrivers; + checkArgument(queuedPartitionedSplitsWeight >= 0, "queuedPartitionedSplitsWeight must be positive"); + this.queuedPartitionedSplitsWeight = queuedPartitionedSplitsWeight; checkArgument(runningPartitionedDrivers >= 0, "runningPartitionedDrivers must be positive"); this.runningPartitionedDrivers = runningPartitionedDrivers; + checkArgument(runningPartitionedSplitsWeight >= 0, "runningPartitionedSplitsWeight must be positive"); + this.runningPartitionedSplitsWeight = runningPartitionedSplitsWeight; this.outputBufferOverutilized = outputBufferOverutilized; @@ -230,6 +238,18 @@ public long getDynamicFiltersVersion() return dynamicFiltersVersion; } + @JsonProperty + public long getQueuedPartitionedSplitsWeight() + { + return queuedPartitionedSplitsWeight; + } + + @JsonProperty + public long getRunningPartitionedSplitsWeight() + { + return runningPartitionedSplitsWeight; + } + @Override public String toString() { @@ -259,7 +279,9 @@ public static TaskStatus initialTaskStatus(TaskId taskId, URI location, String n DataSize.ofBytes(0), 0, new Duration(0, MILLISECONDS), - INITIAL_DYNAMIC_FILTERS_VERSION); + INITIAL_DYNAMIC_FILTERS_VERSION, + 0L, + 0L); } public static TaskStatus failWith(TaskStatus taskStatus, TaskState state, List exceptions) @@ -282,6 +304,8 @@ public static TaskStatus failWith(TaskStatus taskStatus, TaskState state, List nodeTotalSplitCount; + private final Map nodeTotalSplitsInfo; private final Map stageQueuedSplitInfo; public NodeAssignmentStats(NodeTaskMap nodeTaskMap, NodeMap nodeMap, List existingTasks) { this.nodeTaskMap = requireNonNull(nodeTaskMap, "nodeTaskMap is null"); int nodeMapSize = requireNonNull(nodeMap, "nodeMap is null").getNodesByHostAndPort().size(); - this.nodeTotalSplitCount = new HashMap<>(nodeMapSize); + this.nodeTotalSplitsInfo = new HashMap<>(nodeMapSize); this.stageQueuedSplitInfo 
= new HashMap<>(nodeMapSize); for (RemoteTask task : existingTasks) { - checkArgument(stageQueuedSplitInfo.put(task.getNodeId(), new PendingSplitInfo(task.getQueuedPartitionedSplitCount(), task.getUnacknowledgedPartitionedSplitCount())) == null, "A single stage may not have multiple tasks running on the same node"); + checkArgument(stageQueuedSplitInfo.put(task.getNodeId(), new PendingSplitInfo(task.getQueuedPartitionedSplitsInfo(), task.getUnacknowledgedPartitionedSplitCount())) == null, "A single stage may not have multiple tasks running on the same node"); } // pre-populate the assignment counts with zeros if (existingTasks.size() < nodeMapSize) { - Function createEmptySplitInfo = (ignored) -> new PendingSplitInfo(0, 0); + Function createEmptySplitInfo = (ignored) -> new PendingSplitInfo(PartitionedSplitsInfo.forZeroSplits(), 0); for (InternalNode node : nodeMap.getNodesByHostAndPort().values()) { stageQueuedSplitInfo.computeIfAbsent(node.getNodeIdentifier(), createEmptySplitInfo); } } } - public int getTotalSplitCount(InternalNode node) + public long getTotalSplitsWeight(InternalNode node) { - int nodeTotalSplits = nodeTotalSplitCount.computeIfAbsent(node, nodeTaskMap::getPartitionedSplitsOnNode); + PartitionedSplitsInfo nodeTotalSplits = nodeTotalSplitsInfo.computeIfAbsent(node, nodeTaskMap::getPartitionedSplitsOnNode); PendingSplitInfo stageInfo = stageQueuedSplitInfo.get(node.getNodeIdentifier()); - return nodeTotalSplits + (stageInfo == null ? 0 : stageInfo.getAssignedSplitCount()); + if (stageInfo == null) { + return nodeTotalSplits.getWeightSum(); + } + return addExact(nodeTotalSplits.getWeightSum(), stageInfo.getAssignedSplitsWeight()); } - public int getQueuedSplitCountForStage(InternalNode node) + public long getQueuedSplitsWeightForStage(InternalNode node) { PendingSplitInfo stageInfo = stageQueuedSplitInfo.get(node.getNodeIdentifier()); - return stageInfo == null ? 0 : stageInfo.getQueuedSplitCount(); + return stageInfo == null ? 
0 : stageInfo.getQueuedSplitsWeight(); } public int getUnacknowledgedSplitCountForStage(InternalNode node) @@ -70,14 +76,14 @@ public int getUnacknowledgedSplitCountForStage(InternalNode node) return stageInfo == null ? 0 : stageInfo.getUnacknowledgedSplitCount(); } - public void addAssignedSplit(InternalNode node) + public void addAssignedSplit(InternalNode node, SplitWeight splitWeight) { - getOrCreateStageSplitInfo(node).addAssignedSplit(); + getOrCreateStageSplitInfo(node).addAssignedSplit(splitWeight); } - public void removeAssignedSplit(InternalNode node) + public void removeAssignedSplit(InternalNode node, SplitWeight splitWeight) { - getOrCreateStageSplitInfo(node).removeAssignedSplit(); + getOrCreateStageSplitInfo(node).removeAssignedSplit(splitWeight); } private PendingSplitInfo getOrCreateStageSplitInfo(InternalNode node) @@ -86,7 +92,7 @@ private PendingSplitInfo getOrCreateStageSplitInfo(InternalNode node) // Avoids the extra per-invocation lambda allocation of computeIfAbsent since assigning a split to an existing task more common than creating a new task PendingSplitInfo stageInfo = stageQueuedSplitInfo.get(nodeId); if (stageInfo == null) { - stageInfo = new PendingSplitInfo(0, 0); + stageInfo = new PendingSplitInfo(PartitionedSplitsInfo.forZeroSplits(), 0); stageQueuedSplitInfo.put(nodeId, stageInfo); } return stageInfo; @@ -95,12 +101,15 @@ private PendingSplitInfo getOrCreateStageSplitInfo(InternalNode node) private static final class PendingSplitInfo { private final int queuedSplitCount; + private final long queuedSplitsWeight; private final int unacknowledgedSplitCount; private int assignedSplits; + private long assignedSplitsWeight; - private PendingSplitInfo(int queuedSplitCount, int unacknowledgedSplitCount) + private PendingSplitInfo(PartitionedSplitsInfo queuedSplitsInfo, int unacknowledgedSplitCount) { - this.queuedSplitCount = queuedSplitCount; + this.queuedSplitCount = requireNonNull(queuedSplitsInfo, "queuedSplitsInfo is 
null").getCount(); + this.queuedSplitsWeight = queuedSplitsInfo.getWeightSum(); this.unacknowledgedSplitCount = unacknowledgedSplitCount; } @@ -109,24 +118,36 @@ public int getAssignedSplitCount() return assignedSplits; } + public long getAssignedSplitsWeight() + { + return assignedSplitsWeight; + } + public int getQueuedSplitCount() { return queuedSplitCount + assignedSplits; } + public long getQueuedSplitsWeight() + { + return addExact(queuedSplitsWeight, assignedSplitsWeight); + } + public int getUnacknowledgedSplitCount() { return unacknowledgedSplitCount + assignedSplits; } - public void addAssignedSplit() + public void addAssignedSplit(SplitWeight splitWeight) { assignedSplits++; + assignedSplitsWeight = addExact(assignedSplitsWeight, splitWeight.getRawValue()); } - public void removeAssignedSplit() + public void removeAssignedSplit(SplitWeight splitWeight) { assignedSplits--; + assignedSplitsWeight -= splitWeight.getRawValue(); } } } diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/NodeScheduler.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/NodeScheduler.java index 73a7951962e6..852d07df0ff2 100644 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/NodeScheduler.java +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/NodeScheduler.java @@ -26,6 +26,7 @@ import io.trino.metadata.InternalNode; import io.trino.metadata.Split; import io.trino.spi.HostAddress; +import io.trino.spi.SplitWeight; import javax.inject.Inject; @@ -47,6 +48,7 @@ import static com.google.common.util.concurrent.Futures.immediateVoidFuture; import static com.google.common.util.concurrent.MoreExecutors.directExecutor; import static io.airlift.concurrent.MoreFutures.whenAnyCompleteCancelOthers; +import static java.lang.Math.addExact; import static java.util.Objects.requireNonNull; public class NodeScheduler @@ -150,8 +152,8 @@ public static List selectExactNodes(NodeMap nodeMap, List splits, List existingTasks, @@ -164,28 
+166,42 @@ public static SplitPlacementResult selectDistributionNodes( for (Split split : splits) { // node placement is forced by the bucket to node map InternalNode node = bucketNodeMap.getAssignedNode(split).get(); + SplitWeight splitWeight = split.getSplitWeight(); // if node is full, don't schedule now, which will push back on the scheduling of splits - if (assignmentStats.getUnacknowledgedSplitCountForStage(node) < maxUnacknowledgedSplitsPerTask && - (assignmentStats.getTotalSplitCount(node) < maxSplitsPerNode || assignmentStats.getQueuedSplitCountForStage(node) < maxPendingSplitsPerTask)) { + if (canAssignSplitToDistributionNode(assignmentStats, node, maxSplitsWeightPerNode, maxPendingSplitsWeightPerTask, maxUnacknowledgedSplitsPerTask, splitWeight)) { assignments.put(node, split); - assignmentStats.addAssignedSplit(node); + assignmentStats.addAssignedSplit(node, splitWeight); } else { blockedNodes.add(node); } } - ListenableFuture blocked = toWhenHasSplitQueueSpaceFuture(blockedNodes, existingTasks, calculateLowWatermark(maxPendingSplitsPerTask)); + ListenableFuture blocked = toWhenHasSplitQueueSpaceFuture(blockedNodes, existingTasks, calculateLowWatermark(maxPendingSplitsWeightPerTask)); return new SplitPlacementResult(blocked, ImmutableMultimap.copyOf(assignments)); } - public static int calculateLowWatermark(int maxPendingSplitsPerTask) + private static boolean canAssignSplitToDistributionNode(NodeAssignmentStats assignmentStats, InternalNode node, long maxSplitsWeightPerNode, long maxPendingSplitsWeightPerTask, int maxUnacknowledgedSplitsPerTask, SplitWeight splitWeight) { - return (int) Math.ceil(maxPendingSplitsPerTask / 2.0); + return assignmentStats.getUnacknowledgedSplitCountForStage(node) < maxUnacknowledgedSplitsPerTask && + (canAssignSplitBasedOnWeight(assignmentStats.getTotalSplitsWeight(node), maxSplitsWeightPerNode, splitWeight) || + canAssignSplitBasedOnWeight(assignmentStats.getQueuedSplitsWeightForStage(node), 
maxPendingSplitsWeightPerTask, splitWeight)); } - public static ListenableFuture toWhenHasSplitQueueSpaceFuture(Set blockedNodes, List existingTasks, int spaceThreshold) + public static boolean canAssignSplitBasedOnWeight(long currentWeight, long weightLimit, SplitWeight splitWeight) + { + // Nodes or tasks that are configured to accept any splits (ie: weightLimit > 0) should always accept at least one split when + // empty (ie: currentWeight == 0) to ensure that forward progress can be made if split weights are huge + return addExact(currentWeight, splitWeight.getRawValue()) <= weightLimit || (currentWeight == 0 && weightLimit > 0); + } + + public static long calculateLowWatermark(long maxPendingSplitsWeightPerTask) + { + return (long) Math.ceil(maxPendingSplitsWeightPerTask * 0.5); + } + + public static ListenableFuture toWhenHasSplitQueueSpaceFuture(Set blockedNodes, List existingTasks, long weightSpaceThreshold) { if (blockedNodes.isEmpty()) { return immediateVoidFuture(); @@ -198,7 +214,7 @@ public static ListenableFuture toWhenHasSplitQueueSpaceFuture(Set remoteTask.whenSplitQueueHasSpace(spaceThreshold)) + .map(remoteTask -> remoteTask.whenSplitQueueHasSpace(weightSpaceThreshold)) .collect(toImmutableList()); if (blockedFutures.isEmpty()) { return immediateVoidFuture(); @@ -206,13 +222,13 @@ public static ListenableFuture toWhenHasSplitQueueSpaceFuture(Set toWhenHasSplitQueueSpaceFuture(List existingTasks, int spaceThreshold) + public static ListenableFuture toWhenHasSplitQueueSpaceFuture(List existingTasks, long weightSpaceThreshold) { if (existingTasks.isEmpty()) { return immediateVoidFuture(); } List> stateChangeFutures = existingTasks.stream() - .map(remoteTask -> remoteTask.whenSplitQueueHasSpace(spaceThreshold)) + .map(remoteTask -> remoteTask.whenSplitQueueHasSpace(weightSpaceThreshold)) .collect(toImmutableList()); return asVoid(whenAnyCompleteCancelOthers(stateChangeFutures)); } diff --git 
a/core/trino-main/src/main/java/io/trino/execution/scheduler/TopologyAwareNodeSelector.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/TopologyAwareNodeSelector.java index 3d042cb024a5..df4a259e08c6 100644 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/TopologyAwareNodeSelector.java +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/TopologyAwareNodeSelector.java @@ -25,6 +25,7 @@ import io.trino.metadata.InternalNodeManager; import io.trino.metadata.Split; import io.trino.spi.HostAddress; +import io.trino.spi.SplitWeight; import io.trino.spi.TrinoException; import javax.annotation.Nullable; @@ -39,6 +40,7 @@ import static com.google.common.base.Preconditions.checkArgument; import static io.trino.execution.scheduler.NetworkLocation.ROOT_LOCATION; import static io.trino.execution.scheduler.NodeScheduler.calculateLowWatermark; +import static io.trino.execution.scheduler.NodeScheduler.canAssignSplitBasedOnWeight; import static io.trino.execution.scheduler.NodeScheduler.getAllNodes; import static io.trino.execution.scheduler.NodeScheduler.randomizedNodes; import static io.trino.execution.scheduler.NodeScheduler.selectDistributionNodes; @@ -58,8 +60,8 @@ public class TopologyAwareNodeSelector private final boolean includeCoordinator; private final AtomicReference> nodeMap; private final int minCandidates; - private final int maxSplitsPerNode; - private final int maxPendingSplitsPerTask; + private final long maxSplitsWeightPerNode; + private final long maxPendingSplitsWeightPerTask; private final int maxUnacknowledgedSplitsPerTask; private final List topologicalSplitCounters; private final NetworkTopology networkTopology; @@ -70,8 +72,8 @@ public TopologyAwareNodeSelector( boolean includeCoordinator, Supplier nodeMap, int minCandidates, - int maxSplitsPerNode, - int maxPendingSplitsPerTask, + long maxSplitsWeightPerNode, + long maxPendingSplitsWeightPerTask, int maxUnacknowledgedSplitsPerTask, List 
topologicalSplitCounters, NetworkTopology networkTopology) @@ -81,8 +83,8 @@ public TopologyAwareNodeSelector( this.includeCoordinator = includeCoordinator; this.nodeMap = new AtomicReference<>(nodeMap); this.minCandidates = minCandidates; - this.maxSplitsPerNode = maxSplitsPerNode; - this.maxPendingSplitsPerTask = maxPendingSplitsPerTask; + this.maxSplitsWeightPerNode = maxSplitsWeightPerNode; + this.maxPendingSplitsWeightPerTask = maxPendingSplitsWeightPerTask; this.maxUnacknowledgedSplitsPerTask = maxUnacknowledgedSplitsPerTask; checkArgument(maxUnacknowledgedSplitsPerTask > 0, "maxUnacknowledgedSplitsPerTask must be > 0, found: %s", maxUnacknowledgedSplitsPerTask); this.topologicalSplitCounters = requireNonNull(topologicalSplitCounters, "topologicalSplitCounters is null"); @@ -126,16 +128,17 @@ public SplitPlacementResult computeAssignments(Set splits, List blockedExactNodes = new HashSet<>(); boolean splitWaitingForAnyNode = false; for (Split split : splits) { + SplitWeight splitWeight = split.getSplitWeight(); if (!split.isRemotelyAccessible()) { List candidateNodes = selectExactNodes(nodeMap, split.getAddresses(), includeCoordinator); if (candidateNodes.isEmpty()) { log.debug("No nodes available to schedule %s. 
Available nodes %s", split, nodeMap.getNodesByHost().keys()); throw new TrinoException(NO_NODES_AVAILABLE, "No nodes available to run query"); } - InternalNode chosenNode = bestNodeSplitCount(candidateNodes.iterator(), minCandidates, maxPendingSplitsPerTask, assignmentStats); + InternalNode chosenNode = bestNodeSplitCount(splitWeight, candidateNodes.iterator(), minCandidates, maxPendingSplitsWeightPerTask, assignmentStats); if (chosenNode != null) { assignment.put(chosenNode, split); - assignmentStats.addAssignedSplit(chosenNode); + assignmentStats.addAssignedSplit(chosenNode, splitWeight); } // Exact node set won't matter, if a split is waiting for any node else if (!splitWaitingForAnyNode) { @@ -169,7 +172,7 @@ else if (!splitWaitingForAnyNode) { continue; } Set nodes = nodeMap.getWorkersByNetworkPath().get(location); - chosenNode = bestNodeSplitCount(new ResettableRandomizedIterator<>(nodes), minCandidates, calculateMaxPendingSplits(i, depth), assignmentStats); + chosenNode = bestNodeSplitCount(splitWeight, new ResettableRandomizedIterator<>(nodes), minCandidates, calculateMaxPendingSplitsWeightPerTask(i, depth), assignmentStats); if (chosenNode != null) { chosenDepth = i; break; @@ -179,7 +182,7 @@ else if (!splitWaitingForAnyNode) { } if (chosenNode != null) { assignment.put(chosenNode, split); - assignmentStats.addAssignedSplit(chosenNode); + assignmentStats.addAssignedSplit(chosenNode, splitWeight); topologicCounters[chosenDepth]++; } else { @@ -193,7 +196,7 @@ else if (!splitWaitingForAnyNode) { } ListenableFuture blocked; - int maxPendingForWildcardNetworkAffinity = calculateMaxPendingSplits(0, topologicalSplitCounters.size() - 1); + long maxPendingForWildcardNetworkAffinity = calculateMaxPendingSplitsWeightPerTask(0, topologicalSplitCounters.size() - 1); if (splitWaitingForAnyNode) { blocked = toWhenHasSplitQueueSpaceFuture(existingTasks, calculateLowWatermark(maxPendingForWildcardNetworkAffinity)); } @@ -208,40 +211,43 @@ else if 
(!splitWaitingForAnyNode) { * splitAffinity. A split with zero affinity can only fill half the queue, whereas one that matches * exactly can fill the entire queue. */ - private int calculateMaxPendingSplits(int splitAffinity, int totalDepth) + private long calculateMaxPendingSplitsWeightPerTask(int splitAffinity, int totalDepth) { if (totalDepth == 0) { - return maxPendingSplitsPerTask; + return maxPendingSplitsWeightPerTask; } // Use half the queue for any split // Reserve the other half for splits that have some amount of network affinity double queueFraction = 0.5 * (1.0 + splitAffinity / (double) totalDepth); - return (int) Math.ceil(maxPendingSplitsPerTask * queueFraction); + return (long) Math.ceil(maxPendingSplitsWeightPerTask * queueFraction); } @Override public SplitPlacementResult computeAssignments(Set splits, List existingTasks, BucketNodeMap bucketNodeMap) { - return selectDistributionNodes(nodeMap.get().get(), nodeTaskMap, maxSplitsPerNode, maxPendingSplitsPerTask, maxUnacknowledgedSplitsPerTask, splits, existingTasks, bucketNodeMap); + return selectDistributionNodes(nodeMap.get().get(), nodeTaskMap, maxSplitsWeightPerNode, maxPendingSplitsWeightPerTask, maxUnacknowledgedSplitsPerTask, splits, existingTasks, bucketNodeMap); } @Nullable - private InternalNode bestNodeSplitCount(Iterator candidates, int minCandidatesWhenFull, int maxPendingSplitsPerTask, NodeAssignmentStats assignmentStats) + private InternalNode bestNodeSplitCount(SplitWeight splitWeight, Iterator candidates, int minCandidatesWhenFull, long maxPendingSplitsWeightPerTask, NodeAssignmentStats assignmentStats) { InternalNode bestQueueNotFull = null; - int min = Integer.MAX_VALUE; + long minWeight = Long.MAX_VALUE; int fullCandidatesConsidered = 0; while (candidates.hasNext() && (fullCandidatesConsidered < minCandidatesWhenFull || bestQueueNotFull == null)) { InternalNode node = candidates.next(); - boolean hasUnacknowledgedSplitSpace = 
assignmentStats.getUnacknowledgedSplitCountForStage(node) < maxUnacknowledgedSplitsPerTask; - if (hasUnacknowledgedSplitSpace && assignmentStats.getTotalSplitCount(node) < maxSplitsPerNode) { + if (assignmentStats.getUnacknowledgedSplitCountForStage(node) >= maxUnacknowledgedSplitsPerTask) { + fullCandidatesConsidered++; + continue; + } + if (canAssignSplitBasedOnWeight(assignmentStats.getTotalSplitsWeight(node), maxSplitsWeightPerNode, splitWeight)) { return node; } fullCandidatesConsidered++; - int totalSplitCount = assignmentStats.getQueuedSplitCountForStage(node); - if (hasUnacknowledgedSplitSpace && totalSplitCount < min && totalSplitCount < maxPendingSplitsPerTask) { - min = totalSplitCount; + long taskQueuedWeight = assignmentStats.getQueuedSplitsWeightForStage(node); + if (taskQueuedWeight < minWeight && canAssignSplitBasedOnWeight(taskQueuedWeight, maxPendingSplitsWeightPerTask, splitWeight)) { + minWeight = taskQueuedWeight; bestQueueNotFull = node; } } diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/TopologyAwareNodeSelectorFactory.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/TopologyAwareNodeSelectorFactory.java index 2efb8b017314..901dbd4013e7 100644 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/TopologyAwareNodeSelectorFactory.java +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/TopologyAwareNodeSelectorFactory.java @@ -28,6 +28,7 @@ import io.trino.metadata.InternalNode; import io.trino.metadata.InternalNodeManager; import io.trino.spi.HostAddress; +import io.trino.spi.SplitWeight; import javax.inject.Inject; @@ -59,8 +60,8 @@ public class TopologyAwareNodeSelectorFactory private final InternalNodeManager nodeManager; private final int minCandidates; private final boolean includeCoordinator; - private final int maxSplitsPerNode; - private final int maxPendingSplitsPerTask; + private final long maxSplitsWeightPerNode; + private final long maxPendingSplitsWeightPerTask; 
private final NodeTaskMap nodeTaskMap; private final List placementCounters; @@ -84,10 +85,12 @@ public TopologyAwareNodeSelectorFactory( this.nodeManager = nodeManager; this.minCandidates = schedulerConfig.getMinCandidates(); this.includeCoordinator = schedulerConfig.isIncludeCoordinator(); - this.maxSplitsPerNode = schedulerConfig.getMaxSplitsPerNode(); - this.maxPendingSplitsPerTask = schedulerConfig.getMaxPendingSplitsPerTask(); this.nodeTaskMap = requireNonNull(nodeTaskMap, "nodeTaskMap is null"); + int maxSplitsPerNode = schedulerConfig.getMaxSplitsPerNode(); + int maxPendingSplitsPerTask = schedulerConfig.getMaxPendingSplitsPerTask(); checkArgument(maxSplitsPerNode >= maxPendingSplitsPerTask, "maxSplitsPerNode must be > maxPendingSplitsPerTask"); + this.maxSplitsWeightPerNode = SplitWeight.rawValueForStandardSplitCount(maxSplitsPerNode); + this.maxPendingSplitsWeightPerTask = SplitWeight.rawValueForStandardSplitCount(maxPendingSplitsPerTask); Builder placementCounters = ImmutableList.builder(); ImmutableMap.Builder placementCountersByName = ImmutableMap.builder(); @@ -129,8 +132,8 @@ public NodeSelector createNodeSelector(Session session, Optional ca includeCoordinator, nodeMap, minCandidates, - maxSplitsPerNode, - maxPendingSplitsPerTask, + maxSplitsWeightPerNode, + maxPendingSplitsWeightPerTask, getMaxUnacknowledgedSplitsPerTask(session), placementCounters, networkTopology); diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/UniformNodeSelector.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/UniformNodeSelector.java index 4e315084d898..73733206bbeb 100644 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/UniformNodeSelector.java +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/UniformNodeSelector.java @@ -28,6 +28,7 @@ import io.trino.metadata.InternalNodeManager; import io.trino.metadata.Split; import io.trino.spi.HostAddress; +import io.trino.spi.SplitWeight; import 
io.trino.spi.TrinoException; import java.net.InetAddress; @@ -51,7 +52,7 @@ import static io.trino.execution.scheduler.NodeScheduler.selectNodes; import static io.trino.execution.scheduler.NodeScheduler.toWhenHasSplitQueueSpaceFuture; import static io.trino.spi.StandardErrorCode.NO_NODES_AVAILABLE; -import static java.util.Comparator.comparingInt; +import static java.util.Comparator.comparingLong; import static java.util.Objects.requireNonNull; public class UniformNodeSelector @@ -64,8 +65,8 @@ public class UniformNodeSelector private final boolean includeCoordinator; private final AtomicReference> nodeMap; private final int minCandidates; - private final int maxSplitsPerNode; - private final int maxPendingSplitsPerTask; + private final long maxSplitsWeightPerNode; + private final long maxPendingSplitsWeightPerTask; private final int maxUnacknowledgedSplitsPerTask; private final boolean optimizedLocalScheduling; @@ -75,8 +76,8 @@ public UniformNodeSelector( boolean includeCoordinator, Supplier nodeMap, int minCandidates, - int maxSplitsPerNode, - int maxPendingSplitsPerTask, + long maxSplitsWeightPerNode, + long maxPendingSplitsWeightPerTask, int maxUnacknowledgedSplitsPerTask, boolean optimizedLocalScheduling) { @@ -85,8 +86,8 @@ public UniformNodeSelector( this.includeCoordinator = includeCoordinator; this.nodeMap = new AtomicReference<>(nodeMap); this.minCandidates = minCandidates; - this.maxSplitsPerNode = maxSplitsPerNode; - this.maxPendingSplitsPerTask = maxPendingSplitsPerTask; + this.maxSplitsWeightPerNode = maxSplitsWeightPerNode; + this.maxPendingSplitsWeightPerTask = maxPendingSplitsWeightPerTask; this.maxUnacknowledgedSplitsPerTask = maxUnacknowledgedSplitsPerTask; checkArgument(maxUnacknowledgedSplitsPerTask > 0, "maxUnacknowledgedSplitsPerTask must be > 0, found: %s", maxUnacknowledgedSplitsPerTask); this.optimizedLocalScheduling = optimizedLocalScheduling; @@ -138,12 +139,12 @@ public SplitPlacementResult computeAssignments(Set splits, List 
candidateNodes = selectExactNodes(nodeMap, split.getAddresses(), includeCoordinator); Optional chosenNode = candidateNodes.stream() - .filter(ownerNode -> assignmentStats.getTotalSplitCount(ownerNode) < maxSplitsPerNode && assignmentStats.getUnacknowledgedSplitCountForStage(ownerNode) < maxUnacknowledgedSplitsPerTask) - .min(comparingInt(assignmentStats::getTotalSplitCount)); + .filter(ownerNode -> assignmentStats.getTotalSplitsWeight(ownerNode) < maxSplitsWeightPerNode && assignmentStats.getUnacknowledgedSplitCountForStage(ownerNode) < maxUnacknowledgedSplitsPerTask) + .min(comparingLong(assignmentStats::getTotalSplitsWeight)); if (chosenNode.isPresent()) { assignment.put(chosenNode.get(), split); - assignmentStats.addAssignedSplit(chosenNode.get()); + assignmentStats.addAssignedSplit(chosenNode.get(), split.getSplitWeight()); splitsToBeRedistributed = true; continue; } @@ -171,28 +172,28 @@ public SplitPlacementResult computeAssignments(Set splits, List blocked; if (splitWaitingForAnyNode) { - blocked = toWhenHasSplitQueueSpaceFuture(existingTasks, calculateLowWatermark(maxPendingSplitsPerTask)); + blocked = toWhenHasSplitQueueSpaceFuture(existingTasks, calculateLowWatermark(maxPendingSplitsWeightPerTask)); } else { - blocked = toWhenHasSplitQueueSpaceFuture(blockedExactNodes, existingTasks, calculateLowWatermark(maxPendingSplitsPerTask)); + blocked = toWhenHasSplitQueueSpaceFuture(blockedExactNodes, existingTasks, calculateLowWatermark(maxPendingSplitsWeightPerTask)); } if (splitsToBeRedistributed) { @@ -222,7 +223,7 @@ else if (!splitWaitingForAnyNode) { @Override public SplitPlacementResult computeAssignments(Set splits, List existingTasks, BucketNodeMap bucketNodeMap) { - return selectDistributionNodes(nodeMap.get().get(), nodeTaskMap, maxSplitsPerNode, maxPendingSplitsPerTask, maxUnacknowledgedSplitsPerTask, splits, existingTasks, bucketNodeMap); + return selectDistributionNodes(nodeMap.get().get(), nodeTaskMap, maxSplitsWeightPerNode, 
maxPendingSplitsWeightPerTask, maxUnacknowledgedSplitsPerTask, splits, existingTasks, bucketNodeMap); } /** @@ -250,12 +251,12 @@ private void equateDistribution(Multimap assignment, NodeAs IndexedPriorityQueue maxNodes = new IndexedPriorityQueue<>(); for (InternalNode node : assignment.keySet()) { - maxNodes.addOrUpdate(node, assignmentStats.getTotalSplitCount(node)); + maxNodes.addOrUpdate(node, assignmentStats.getTotalSplitsWeight(node)); } IndexedPriorityQueue minNodes = new IndexedPriorityQueue<>(); for (InternalNode node : allNodes) { - minNodes.addOrUpdate(node, Long.MAX_VALUE - assignmentStats.getTotalSplitCount(node)); + minNodes.addOrUpdate(node, Long.MAX_VALUE - assignmentStats.getTotalSplitsWeight(node)); } while (true) { @@ -274,24 +275,24 @@ private void equateDistribution(Multimap assignment, NodeAs // The difference of 5 between node with maximum and minimum splits is a tradeoff between ratio of // misassigned splits and assignment uniformity. Using larger numbers doesn't reduce the number of // misassigned splits greatly (in absolute values). 
- if (assignmentStats.getTotalSplitCount(maxNode) - assignmentStats.getTotalSplitCount(minNode) <= 5) { + if (assignmentStats.getTotalSplitsWeight(maxNode) - assignmentStats.getTotalSplitsWeight(minNode) <= SplitWeight.rawValueForStandardSplitCount(5)) { return; } // move split from max to min - redistributeSplit(assignment, maxNode, minNode, nodeMap.getNodesByHost()); - assignmentStats.removeAssignedSplit(maxNode); - assignmentStats.addAssignedSplit(minNode); + Split redistributed = redistributeSplit(assignment, maxNode, minNode, nodeMap.getNodesByHost()); + assignmentStats.removeAssignedSplit(maxNode, redistributed.getSplitWeight()); + assignmentStats.addAssignedSplit(minNode, redistributed.getSplitWeight()); // add max back into maxNodes only if it still has assignments if (assignment.containsKey(maxNode)) { - maxNodes.addOrUpdate(maxNode, assignmentStats.getTotalSplitCount(maxNode)); + maxNodes.addOrUpdate(maxNode, assignmentStats.getTotalSplitsWeight(maxNode)); } // Add or update both the Priority Queues with the updated node priorities - maxNodes.addOrUpdate(minNode, assignmentStats.getTotalSplitCount(minNode)); - minNodes.addOrUpdate(minNode, Long.MAX_VALUE - assignmentStats.getTotalSplitCount(minNode)); - minNodes.addOrUpdate(maxNode, Long.MAX_VALUE - assignmentStats.getTotalSplitCount(maxNode)); + maxNodes.addOrUpdate(minNode, assignmentStats.getTotalSplitsWeight(minNode)); + minNodes.addOrUpdate(minNode, Long.MAX_VALUE - assignmentStats.getTotalSplitsWeight(minNode)); + minNodes.addOrUpdate(maxNode, Long.MAX_VALUE - assignmentStats.getTotalSplitsWeight(maxNode)); } } @@ -301,7 +302,7 @@ private void equateDistribution(Multimap assignment, NodeAs * simultaneously. If a Non-local split cannot be found in the maxNode, any split is selected randomly and reassigned. 
*/ @VisibleForTesting - public static void redistributeSplit(Multimap assignment, InternalNode fromNode, InternalNode toNode, SetMultimap nodesByHost) + public static Split redistributeSplit(Multimap assignment, InternalNode fromNode, InternalNode toNode, SetMultimap nodesByHost) { Iterator splitIterator = assignment.get(fromNode).iterator(); Split splitToBeRedistributed = null; @@ -320,6 +321,7 @@ public static void redistributeSplit(Multimap assignment, I } splitIterator.remove(); assignment.put(toNode, splitToBeRedistributed); + return splitToBeRedistributed; } /** diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/UniformNodeSelectorFactory.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/UniformNodeSelectorFactory.java index 484239732c62..bfbcaf326692 100644 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/UniformNodeSelectorFactory.java +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/UniformNodeSelectorFactory.java @@ -26,6 +26,7 @@ import io.trino.metadata.InternalNode; import io.trino.metadata.InternalNodeManager; import io.trino.spi.HostAddress; +import io.trino.spi.SplitWeight; import javax.inject.Inject; @@ -56,8 +57,8 @@ public class UniformNodeSelectorFactory private final InternalNodeManager nodeManager; private final int minCandidates; private final boolean includeCoordinator; - private final int maxSplitsPerNode; - private final int maxPendingSplitsPerTask; + private final long maxSplitsWeightPerNode; + private final long maxPendingSplitsWeightPerTask; private final boolean optimizedLocalScheduling; private final NodeTaskMap nodeTaskMap; private final Duration nodeMapMemoizationDuration; @@ -85,11 +86,13 @@ public UniformNodeSelectorFactory( this.nodeManager = nodeManager; this.minCandidates = config.getMinCandidates(); this.includeCoordinator = config.isIncludeCoordinator(); - this.maxSplitsPerNode = config.getMaxSplitsPerNode(); - this.maxPendingSplitsPerTask = 
config.getMaxPendingSplitsPerTask(); this.optimizedLocalScheduling = config.getOptimizedLocalScheduling(); this.nodeTaskMap = requireNonNull(nodeTaskMap, "nodeTaskMap is null"); + int maxSplitsPerNode = config.getMaxSplitsPerNode(); + int maxPendingSplitsPerTask = config.getMaxPendingSplitsPerTask(); checkArgument(maxSplitsPerNode >= maxPendingSplitsPerTask, "maxSplitsPerNode must be > maxPendingSplitsPerTask"); + this.maxSplitsWeightPerNode = SplitWeight.rawValueForStandardSplitCount(maxSplitsPerNode); + this.maxPendingSplitsWeightPerTask = SplitWeight.rawValueForStandardSplitCount(maxPendingSplitsPerTask); this.nodeMapMemoizationDuration = nodeMapMemoizationDuration; } @@ -116,8 +119,8 @@ public NodeSelector createNodeSelector(Session session, Optional ca includeCoordinator, nodeMap, minCandidates, - maxSplitsPerNode, - maxPendingSplitsPerTask, + maxSplitsWeightPerNode, + maxPendingSplitsWeightPerTask, getMaxUnacknowledgedSplitsPerTask(session), optimizedLocalScheduling); } diff --git a/core/trino-main/src/main/java/io/trino/metadata/Split.java b/core/trino-main/src/main/java/io/trino/metadata/Split.java index a3d586fac72f..43695d41d44d 100644 --- a/core/trino-main/src/main/java/io/trino/metadata/Split.java +++ b/core/trino-main/src/main/java/io/trino/metadata/Split.java @@ -18,6 +18,7 @@ import io.trino.connector.CatalogName; import io.trino.execution.Lifespan; import io.trino.spi.HostAddress; +import io.trino.spi.SplitWeight; import io.trino.spi.connector.ConnectorSplit; import java.util.List; @@ -75,6 +76,11 @@ public boolean isRemotelyAccessible() return connectorSplit.isRemotelyAccessible(); } + public SplitWeight getSplitWeight() + { + return connectorSplit.getSplitWeight(); + } + @Override public String toString() { diff --git a/core/trino-main/src/main/java/io/trino/operator/DriverContext.java b/core/trino-main/src/main/java/io/trino/operator/DriverContext.java index b2cdb2698d73..2ad090628eb3 100644 --- 
a/core/trino-main/src/main/java/io/trino/operator/DriverContext.java +++ b/core/trino-main/src/main/java/io/trino/operator/DriverContext.java @@ -78,13 +78,15 @@ public class DriverContext private final List operatorContexts = new CopyOnWriteArrayList<>(); private final Lifespan lifespan; + private final long splitWeight; public DriverContext( PipelineContext pipelineContext, Executor notificationExecutor, ScheduledExecutorService yieldExecutor, MemoryTrackingContext driverMemoryContext, - Lifespan lifespan) + Lifespan lifespan, + long splitWeight) { this.pipelineContext = requireNonNull(pipelineContext, "pipelineContext is null"); this.notificationExecutor = requireNonNull(notificationExecutor, "notificationExecutor is null"); @@ -92,6 +94,8 @@ public DriverContext( this.driverMemoryContext = requireNonNull(driverMemoryContext, "driverMemoryContext is null"); this.lifespan = requireNonNull(lifespan, "lifespan is null"); this.yieldSignal = new DriverYieldSignal(); + this.splitWeight = splitWeight; + checkArgument(splitWeight >= 0, "splitWeight must be >= 0, found: %s", splitWeight); } public TaskId getTaskId() @@ -99,6 +103,11 @@ public TaskId getTaskId() return pipelineContext.getTaskId(); } + public long getSplitWeight() + { + return splitWeight; + } + public OperatorContext addOperatorContext(int operatorId, PlanNodeId planNodeId, String operatorType) { checkArgument(operatorId >= 0, "operatorId is negative"); diff --git a/core/trino-main/src/main/java/io/trino/operator/PipelineContext.java b/core/trino-main/src/main/java/io/trino/operator/PipelineContext.java index 2a9a09def876..16eddcfc852d 100644 --- a/core/trino-main/src/main/java/io/trino/operator/PipelineContext.java +++ b/core/trino-main/src/main/java/io/trino/operator/PipelineContext.java @@ -69,7 +69,9 @@ public class PipelineContext private final List drivers = new CopyOnWriteArrayList<>(); private final AtomicInteger totalSplits = new AtomicInteger(); + private final AtomicLong totalSplitsWeight = new 
AtomicLong(); private final AtomicInteger completedDrivers = new AtomicInteger(); + private final AtomicLong completedSplitsWeight = new AtomicLong(); private final AtomicReference executionStartTime = new AtomicReference<>(); private final AtomicReference lastExecutionStartTime = new AtomicReference<>(); @@ -145,17 +147,19 @@ public boolean isOutputPipeline() public DriverContext addDriverContext() { - return addDriverContext(Lifespan.taskWide()); + return addDriverContext(Lifespan.taskWide(), 0); } - public DriverContext addDriverContext(Lifespan lifespan) + public DriverContext addDriverContext(Lifespan lifespan, long splitWeight) { + checkArgument(partitioned || splitWeight == 0, "Only partitioned splits should have weights"); DriverContext driverContext = new DriverContext( this, notificationExecutor, yieldExecutor, pipelineMemoryContext.newMemoryTrackingContext(), - lifespan); + lifespan, + splitWeight); drivers.add(driverContext); return driverContext; } @@ -165,10 +169,13 @@ public Session getSession() return taskContext.getSession(); } - public void splitsAdded(int count) + public void splitsAdded(int count, long weightSum) { - checkArgument(count >= 0); + checkArgument(count >= 0 && weightSum >= 0); totalSplits.addAndGet(count); + if (partitioned && weightSum != 0) { + totalSplitsWeight.addAndGet(weightSum); + } } public void driverFinished(DriverContext driverContext) @@ -185,6 +192,9 @@ public void driverFinished(DriverContext driverContext) DriverStats driverStats = driverContext.getDriverStats(); completedDrivers.getAndIncrement(); + if (partitioned) { + completedSplitsWeight.addAndGet(driverContext.getSplitWeight()); + } queuedTime.add(driverStats.getQueuedTime().roundTo(NANOSECONDS)); elapsedTime.add(driverStats.getElapsedTime().roundTo(NANOSECONDS)); @@ -319,7 +329,15 @@ public long getPhysicalWrittenDataSize() public PipelineStatus getPipelineStatus() { - return getPipelineStatus(drivers.iterator(), totalSplits.get(), completedDrivers.get(), 
partitioned); + return getPipelineStatus(drivers.iterator(), totalSplits.get(), completedDrivers.get(), getActivePartitionedSplitsWeight(), partitioned); + } + + private long getActivePartitionedSplitsWeight() + { + if (partitioned) { + return totalSplitsWeight.get() - completedSplitsWeight.get(); + } + return 0; } public PipelineStats getPipelineStats() @@ -335,7 +353,7 @@ public PipelineStats getPipelineStats() int completedDrivers = this.completedDrivers.get(); List driverContexts = ImmutableList.copyOf(this.drivers); int totalSplits = this.totalSplits.get(); - PipelineStatus pipelineStatus = getPipelineStatus(driverContexts.iterator(), totalSplits, completedDrivers, partitioned); + PipelineStatus pipelineStatus = getPipelineStatus(driverContexts.iterator(), totalSplits, completedDrivers, getActivePartitionedSplitsWeight(), partitioned); int totalDrivers = completedDrivers + driverContexts.size(); @@ -437,8 +455,10 @@ public PipelineStats getPipelineStats() totalDrivers, pipelineStatus.getQueuedDrivers(), pipelineStatus.getQueuedPartitionedDrivers(), + pipelineStatus.getQueuedPartitionedSplitsWeight(), pipelineStatus.getRunningDrivers(), pipelineStatus.getRunningPartitionedDrivers(), + pipelineStatus.getRunningPartitionedSplitsWeight(), pipelineStatus.getBlockedDrivers(), completedDrivers, @@ -495,10 +515,12 @@ public MemoryTrackingContext getPipelineMemoryContext() return pipelineMemoryContext; } - private static PipelineStatus getPipelineStatus(Iterator driverContextsIterator, int totalSplits, int completedDrivers, boolean partitioned) + private static PipelineStatus getPipelineStatus(Iterator driverContextsIterator, int totalSplits, int completedDrivers, long activePartitionedSplitsWeight, boolean partitioned) { int runningDrivers = 0; int blockedDrivers = 0; + long runningPartitionedSplitsWeight = 0L; + long blockedPartitionedSplitsWeight = 0L; // When a split for a partitioned pipeline is delivered to a worker, // conceptually, the worker would have an 
additional driver. // The queuedDrivers field in PipelineStatus is supposed to represent this. @@ -514,24 +536,44 @@ private static PipelineStatus getPipelineStatus(Iterator driverCo } else if (driverContext.isFullyBlocked()) { blockedDrivers++; + if (partitioned) { + blockedPartitionedSplitsWeight += driverContext.getSplitWeight(); + } } else { runningDrivers++; + if (partitioned) { + runningPartitionedSplitsWeight += driverContext.getSplitWeight(); + } } } int queuedDrivers; + int queuedPartitionedSplits; + int runningPartitionedSplits; + long queuedPartitionedSplitsWeight; if (partitioned) { queuedDrivers = totalSplits - runningDrivers - blockedDrivers - completedDrivers; if (queuedDrivers < 0) { // It is possible to observe negative here because inputs to the above expression was not taken in a snapshot. queuedDrivers = 0; } + queuedPartitionedSplitsWeight = activePartitionedSplitsWeight - runningPartitionedSplitsWeight - blockedPartitionedSplitsWeight; + if (queuedDrivers == 0 || queuedPartitionedSplitsWeight < 0) { + // negative or inconsistent count vs weight inputs might occur + queuedPartitionedSplitsWeight = 0; + } + queuedPartitionedSplits = queuedDrivers; + runningPartitionedSplits = runningDrivers; } else { queuedDrivers = physicallyQueuedDrivers; + queuedPartitionedSplits = 0; + queuedPartitionedSplitsWeight = 0; + runningPartitionedSplits = 0; + runningPartitionedSplitsWeight = 0; } - return new PipelineStatus(queuedDrivers, runningDrivers, blockedDrivers, partitioned ? queuedDrivers : 0, partitioned ? 
runningDrivers : 0); + return new PipelineStatus(queuedDrivers, runningDrivers, blockedDrivers, queuedPartitionedSplits, queuedPartitionedSplitsWeight, runningPartitionedSplits, runningPartitionedSplitsWeight); } } diff --git a/core/trino-main/src/main/java/io/trino/operator/PipelineStats.java b/core/trino-main/src/main/java/io/trino/operator/PipelineStats.java index ea6956ee1370..da2f2c0e70d1 100644 --- a/core/trino-main/src/main/java/io/trino/operator/PipelineStats.java +++ b/core/trino-main/src/main/java/io/trino/operator/PipelineStats.java @@ -47,8 +47,10 @@ public class PipelineStats private final int totalDrivers; private final int queuedDrivers; private final int queuedPartitionedDrivers; + private final long queuedPartitionedSplitsWeight; private final int runningDrivers; private final int runningPartitionedDrivers; + private final long runningPartitionedSplitsWeight; private final int blockedDrivers; private final int completedDrivers; @@ -100,8 +102,10 @@ public PipelineStats( @JsonProperty("totalDrivers") int totalDrivers, @JsonProperty("queuedDrivers") int queuedDrivers, @JsonProperty("queuedPartitionedDrivers") int queuedPartitionedDrivers, + @JsonProperty("queuedPartitionedSplitsWeight") long queuedPartitionedSplitsWeight, @JsonProperty("runningDrivers") int runningDrivers, @JsonProperty("runningPartitionedDrivers") int runningPartitionedDrivers, + @JsonProperty("runningPartitionedSplitsWeight") long runningPartitionedSplitsWeight, @JsonProperty("blockedDrivers") int blockedDrivers, @JsonProperty("completedDrivers") int completedDrivers, @@ -154,10 +158,14 @@ public PipelineStats( this.queuedDrivers = queuedDrivers; checkArgument(queuedPartitionedDrivers >= 0, "queuedPartitionedDrivers is negative"); this.queuedPartitionedDrivers = queuedPartitionedDrivers; + checkArgument(queuedPartitionedSplitsWeight >= 0, "queuedPartitionedSplitsWeight must be positive"); + this.queuedPartitionedSplitsWeight = queuedPartitionedSplitsWeight; 
checkArgument(runningDrivers >= 0, "runningDrivers is negative"); this.runningDrivers = runningDrivers; checkArgument(runningPartitionedDrivers >= 0, "runningPartitionedDrivers is negative"); this.runningPartitionedDrivers = runningPartitionedDrivers; + checkArgument(runningPartitionedSplitsWeight >= 0, "runningPartitionedSplitsWeight must be positive"); + this.runningPartitionedSplitsWeight = runningPartitionedSplitsWeight; checkArgument(blockedDrivers >= 0, "blockedDrivers is negative"); this.blockedDrivers = blockedDrivers; checkArgument(completedDrivers >= 0, "completedDrivers is negative"); @@ -260,6 +268,12 @@ public int getQueuedPartitionedDrivers() return queuedPartitionedDrivers; } + @JsonProperty + public long getQueuedPartitionedSplitsWeight() + { + return queuedPartitionedSplitsWeight; + } + @JsonProperty public int getRunningDrivers() { @@ -272,6 +286,12 @@ public int getRunningPartitionedDrivers() return runningPartitionedDrivers; } + @JsonProperty + public long getRunningPartitionedSplitsWeight() + { + return runningPartitionedSplitsWeight; + } + @JsonProperty public int getBlockedDrivers() { @@ -440,8 +460,10 @@ public PipelineStats summarize() totalDrivers, queuedDrivers, queuedPartitionedDrivers, + queuedPartitionedSplitsWeight, runningDrivers, runningPartitionedDrivers, + runningPartitionedSplitsWeight, blockedDrivers, completedDrivers, userMemoryReservation, diff --git a/core/trino-main/src/main/java/io/trino/operator/PipelineStatus.java b/core/trino-main/src/main/java/io/trino/operator/PipelineStatus.java index 993f3c8739a9..41969556e617 100644 --- a/core/trino-main/src/main/java/io/trino/operator/PipelineStatus.java +++ b/core/trino-main/src/main/java/io/trino/operator/PipelineStatus.java @@ -22,15 +22,19 @@ public final class PipelineStatus private final int runningDrivers; private final int blockedDrivers; private final int queuedPartitionedDrivers; + private final long queuedPartitionedSplitsWeight; private final int 
runningPartitionedDrivers; + private final long runningPartitionedSplitsWeight; - public PipelineStatus(int queuedDrivers, int runningDrivers, int blockedDrivers, int queuedPartitionedDrivers, int runningPartitionedDrivers) + public PipelineStatus(int queuedDrivers, int runningDrivers, int blockedDrivers, int queuedPartitionedDrivers, long queuedPartitionedSplitsWeight, int runningPartitionedDrivers, long runningPartitionedSplitsWeight) { this.queuedDrivers = queuedDrivers; this.runningDrivers = runningDrivers; this.blockedDrivers = blockedDrivers; this.queuedPartitionedDrivers = queuedPartitionedDrivers; + this.queuedPartitionedSplitsWeight = queuedPartitionedSplitsWeight; this.runningPartitionedDrivers = runningPartitionedDrivers; + this.runningPartitionedSplitsWeight = runningPartitionedSplitsWeight; } public int getQueuedDrivers() @@ -53,8 +57,18 @@ public int getQueuedPartitionedDrivers() return queuedPartitionedDrivers; } + public long getQueuedPartitionedSplitsWeight() + { + return queuedPartitionedSplitsWeight; + } + public int getRunningPartitionedDrivers() { return runningPartitionedDrivers; } + + public long getRunningPartitionedSplitsWeight() + { + return runningPartitionedSplitsWeight; + } } diff --git a/core/trino-main/src/main/java/io/trino/operator/TaskContext.java b/core/trino-main/src/main/java/io/trino/operator/TaskContext.java index c1e66c0dce69..58259ab39266 100644 --- a/core/trino-main/src/main/java/io/trino/operator/TaskContext.java +++ b/core/trino-main/src/main/java/io/trino/operator/TaskContext.java @@ -420,8 +420,10 @@ public TaskStats getTaskStats() int totalDrivers = 0; int queuedDrivers = 0; int queuedPartitionedDrivers = 0; + long queuedPartitionedSplitsWeight = 0; int runningDrivers = 0; int runningPartitionedDrivers = 0; + long runningPartitionedSplitsWeight = 0; int blockedDrivers = 0; int completedDrivers = 0; @@ -455,8 +457,10 @@ public TaskStats getTaskStats() totalDrivers += pipeline.getTotalDrivers(); queuedDrivers += 
pipeline.getQueuedDrivers(); queuedPartitionedDrivers += pipeline.getQueuedPartitionedDrivers(); + queuedPartitionedSplitsWeight += pipeline.getQueuedPartitionedSplitsWeight(); runningDrivers += pipeline.getRunningDrivers(); runningPartitionedDrivers += pipeline.getRunningPartitionedDrivers(); + runningPartitionedSplitsWeight += pipeline.getRunningPartitionedSplitsWeight(); blockedDrivers += pipeline.getBlockedDrivers(); completedDrivers += pipeline.getCompletedDrivers(); @@ -541,8 +545,10 @@ public TaskStats getTaskStats() totalDrivers, queuedDrivers, queuedPartitionedDrivers, + queuedPartitionedSplitsWeight, runningDrivers, runningPartitionedDrivers, + runningPartitionedSplitsWeight, blockedDrivers, completedDrivers, cumulativeUserMemory.get(), diff --git a/core/trino-main/src/main/java/io/trino/operator/TaskStats.java b/core/trino-main/src/main/java/io/trino/operator/TaskStats.java index 068a98c575b4..1acf63089771 100644 --- a/core/trino-main/src/main/java/io/trino/operator/TaskStats.java +++ b/core/trino-main/src/main/java/io/trino/operator/TaskStats.java @@ -45,8 +45,10 @@ public class TaskStats private final int totalDrivers; private final int queuedDrivers; private final int queuedPartitionedDrivers; + private final long queuedPartitionedSplitsWeight; private final int runningDrivers; private final int runningPartitionedDrivers; + private final long runningPartitionedSplitsWeight; private final int blockedDrivers; private final int completedDrivers; @@ -97,8 +99,10 @@ public TaskStats(DateTime createTime, DateTime endTime) 0, 0, 0, + 0L, 0, 0, + 0L, 0, 0, 0.0, @@ -141,8 +145,10 @@ public TaskStats( @JsonProperty("totalDrivers") int totalDrivers, @JsonProperty("queuedDrivers") int queuedDrivers, @JsonProperty("queuedPartitionedDrivers") int queuedPartitionedDrivers, + @JsonProperty("queuedPartitionedSplitsWeight") long queuedPartitionedSplitsWeight, @JsonProperty("runningDrivers") int runningDrivers, @JsonProperty("runningPartitionedDrivers") int 
runningPartitionedDrivers, + @JsonProperty("runningPartitionedSplitsWeight") long runningPartitionedSplitsWeight, @JsonProperty("blockedDrivers") int blockedDrivers, @JsonProperty("completedDrivers") int completedDrivers, @@ -195,11 +201,15 @@ public TaskStats( this.queuedDrivers = queuedDrivers; checkArgument(queuedPartitionedDrivers >= 0, "queuedPartitionedDrivers is negative"); this.queuedPartitionedDrivers = queuedPartitionedDrivers; + checkArgument(queuedPartitionedSplitsWeight >= 0, "queuedPartitionedSplitsWeight is negative"); + this.queuedPartitionedSplitsWeight = queuedPartitionedSplitsWeight; checkArgument(runningDrivers >= 0, "runningDrivers is negative"); this.runningDrivers = runningDrivers; checkArgument(runningPartitionedDrivers >= 0, "runningPartitionedDrivers is negative"); this.runningPartitionedDrivers = runningPartitionedDrivers; + checkArgument(runningPartitionedSplitsWeight >= 0, "runningPartitionedSplitsWeight is negative"); + this.runningPartitionedSplitsWeight = runningPartitionedSplitsWeight; checkArgument(blockedDrivers >= 0, "blockedDrivers is negative"); this.blockedDrivers = blockedDrivers; @@ -469,12 +479,24 @@ public int getQueuedPartitionedDrivers() return queuedPartitionedDrivers; } + @JsonProperty + public long getQueuedPartitionedSplitsWeight() + { + return queuedPartitionedSplitsWeight; + } + @JsonProperty public int getRunningPartitionedDrivers() { return runningPartitionedDrivers; } + @JsonProperty + public long getRunningPartitionedSplitsWeight() + { + return runningPartitionedSplitsWeight; + } + @JsonProperty public int getFullGcCount() { @@ -500,8 +522,10 @@ public TaskStats summarize() totalDrivers, queuedDrivers, queuedPartitionedDrivers, + queuedPartitionedSplitsWeight, runningDrivers, runningPartitionedDrivers, + runningPartitionedSplitsWeight, blockedDrivers, completedDrivers, cumulativeUserMemory, @@ -544,8 +568,10 @@ public TaskStats summarizeFinal() totalDrivers, queuedDrivers, queuedPartitionedDrivers, +
queuedPartitionedSplitsWeight, runningDrivers, runningPartitionedDrivers, + runningPartitionedSplitsWeight, blockedDrivers, completedDrivers, cumulativeUserMemory, diff --git a/core/trino-main/src/main/java/io/trino/server/remotetask/HttpRemoteTask.java b/core/trino-main/src/main/java/io/trino/server/remotetask/HttpRemoteTask.java index a54017659f04..89237a990912 100644 --- a/core/trino-main/src/main/java/io/trino/server/remotetask/HttpRemoteTask.java +++ b/core/trino-main/src/main/java/io/trino/server/remotetask/HttpRemoteTask.java @@ -37,6 +37,7 @@ import io.trino.execution.FutureStateChange; import io.trino.execution.Lifespan; import io.trino.execution.NodeTaskMap.PartitionedSplitCountTracker; +import io.trino.execution.PartitionedSplitsInfo; import io.trino.execution.RemoteTask; import io.trino.execution.ScheduledSplit; import io.trino.execution.StateMachine.StateChangeListener; @@ -52,6 +53,7 @@ import io.trino.operator.TaskStats; import io.trino.server.DynamicFilterService; import io.trino.server.TaskUpdateRequest; +import io.trino.spi.SplitWeight; import io.trino.sql.planner.PlanFragment; import io.trino.sql.planner.plan.DynamicFilterId; import io.trino.sql.planner.plan.PlanNode; @@ -68,7 +70,7 @@ import java.util.Map.Entry; import java.util.Objects; import java.util.Optional; -import java.util.OptionalInt; +import java.util.OptionalLong; import java.util.Set; import java.util.concurrent.Executor; import java.util.concurrent.Future; @@ -98,6 +100,7 @@ import static io.trino.execution.TaskStatus.failWith; import static io.trino.server.remotetask.RequestErrorTracker.logError; import static io.trino.util.Failures.toFailure; +import static java.lang.Math.addExact; import static java.util.Objects.requireNonNull; import static java.util.concurrent.TimeUnit.MILLISECONDS; import static java.util.concurrent.TimeUnit.NANOSECONDS; @@ -137,6 +140,8 @@ public final class HttpRemoteTask @GuardedBy("this") private volatile int pendingSourceSplitCount; @GuardedBy("this") + 
private volatile long pendingSourceSplitsWeight; + @GuardedBy("this") private final SetMultimap pendingNoMoreSplitsForLifespan = HashMultimap.create(); @GuardedBy("this") // The keys of this map represent all plan nodes that have "no more splits". @@ -148,7 +153,7 @@ public final class HttpRemoteTask @GuardedBy("this") private boolean splitQueueHasSpace = true; @GuardedBy("this") - private OptionalInt whenSplitQueueHasSpaceThreshold = OptionalInt.empty(); + private OptionalLong whenSplitQueueHasSpaceThreshold = OptionalLong.empty(); private final boolean summarizeTaskInfo; @@ -229,12 +234,20 @@ public HttpRemoteTask( ScheduledSplit scheduledSplit = new ScheduledSplit(nextSplitId.getAndIncrement(), entry.getKey(), entry.getValue()); pendingSplits.put(entry.getKey(), scheduledSplit); } - pendingSourceSplitCount = planFragment.getPartitionedSources().stream() - .filter(initialSplits::containsKey) - .mapToInt(partitionedSource -> initialSplits.get(partitionedSource).size()) - .sum(); maxUnacknowledgedSplits = getMaxUnacknowledgedSplitsPerTask(session); + int pendingSourceSplitCount = 0; + long pendingSourceSplitsWeight = 0; + for (PlanNodeId planNodeId : planFragment.getPartitionedSources()) { + Collection tableScanSplits = initialSplits.get(planNodeId); + if (tableScanSplits != null && !tableScanSplits.isEmpty()) { + pendingSourceSplitCount += tableScanSplits.size(); + pendingSourceSplitsWeight = addExact(pendingSourceSplitsWeight, SplitWeight.rawValueSum(tableScanSplits, Split::getSplitWeight)); + } + } + this.pendingSourceSplitCount = pendingSourceSplitCount; + this.pendingSourceSplitsWeight = pendingSourceSplitsWeight; + List bufferStates = outputBuffers.getBuffers() .keySet().stream() .map(outputId -> new BufferInfo(outputId, false, 0, 0, PageBufferInfo.empty())) @@ -286,7 +299,7 @@ public HttpRemoteTask( cleanUpTask(); } else { - partitionedSplitCountTracker.setPartitionedSplitCount(getPartitionedSplitCount()); + 
partitionedSplitCountTracker.setPartitionedSplits(getPartitionedSplitsInfo()); updateSplitQueueSpace(); } }); @@ -297,7 +310,7 @@ public HttpRemoteTask( outboundDynamicFilterIds, outboundDynamicFiltersCollector::updateDomains); - partitionedSplitCountTracker.setPartitionedSplitCount(getPartitionedSplitCount()); + partitionedSplitCountTracker.setPartitionedSplits(getPartitionedSplitsInfo()); updateSplitQueueSpace(); } } @@ -354,17 +367,23 @@ public synchronized void addSplits(Multimap splitsBySource) for (Entry> entry : splitsBySource.asMap().entrySet()) { PlanNodeId sourceId = entry.getKey(); Collection splits = entry.getValue(); + boolean isPartitionedSource = planFragment.isPartitionedSources(sourceId); checkState(!noMoreSplits.containsKey(sourceId), "noMoreSplits has already been set for %s", sourceId); int added = 0; + long addedWeight = 0; for (Split split : splits) { if (pendingSplits.put(sourceId, new ScheduledSplit(nextSplitId.getAndIncrement(), sourceId, split))) { - added++; + if (isPartitionedSource) { + added++; + addedWeight = addExact(addedWeight, split.getSplitWeight().getRawValue()); + } } } - if (planFragment.isPartitionedSources(sourceId)) { + if (isPartitionedSource) { pendingSourceSplitCount += added; - partitionedSplitCountTracker.setPartitionedSplitCount(getPartitionedSplitCount()); + pendingSourceSplitsWeight = addExact(pendingSourceSplitsWeight, addedWeight); + partitionedSplitCountTracker.setPartitionedSplits(getPartitionedSplitsInfo()); } needsUpdate = true; } @@ -408,23 +427,37 @@ public synchronized void setOutputBuffers(OutputBuffers newOutputBuffers) } @Override - public int getPartitionedSplitCount() + public PartitionedSplitsInfo getPartitionedSplitsInfo() { TaskStatus taskStatus = getTaskStatus(); if (taskStatus.getState().isDone()) { - return 0; + return PartitionedSplitsInfo.forZeroSplits(); } - return getPendingSourceSplitCount() + taskStatus.getQueuedPartitionedDrivers() + taskStatus.getRunningPartitionedDrivers(); + 
PartitionedSplitsInfo unacknowledgedSplitsInfo = getUnacknowledgedPartitionedSplitsInfo(); + int count = unacknowledgedSplitsInfo.getCount() + taskStatus.getQueuedPartitionedDrivers() + taskStatus.getRunningPartitionedDrivers(); + long weight = unacknowledgedSplitsInfo.getWeightSum() + taskStatus.getQueuedPartitionedSplitsWeight() + taskStatus.getRunningPartitionedSplitsWeight(); + return PartitionedSplitsInfo.forSplitCountAndWeightSum(count, weight); + } + + @SuppressWarnings("FieldAccessNotGuarded") + public PartitionedSplitsInfo getUnacknowledgedPartitionedSplitsInfo() + { + int count = pendingSourceSplitCount; + long weight = pendingSourceSplitsWeight; + return PartitionedSplitsInfo.forSplitCountAndWeightSum(count, weight); } @Override - public int getQueuedPartitionedSplitCount() + public PartitionedSplitsInfo getQueuedPartitionedSplitsInfo() { TaskStatus taskStatus = getTaskStatus(); if (taskStatus.getState().isDone()) { - return 0; + return PartitionedSplitsInfo.forZeroSplits(); } - return getPendingSourceSplitCount() + taskStatus.getQueuedPartitionedDrivers(); + PartitionedSplitsInfo unacknowledgedSplitsInfo = getUnacknowledgedPartitionedSplitsInfo(); + int count = unacknowledgedSplitsInfo.getCount() + taskStatus.getQueuedPartitionedDrivers(); + long weight = unacknowledgedSplitsInfo.getWeightSum() + taskStatus.getQueuedPartitionedSplitsWeight(); + return PartitionedSplitsInfo.forSplitCountAndWeightSum(count, weight); } @Override @@ -439,6 +472,21 @@ private int getPendingSourceSplitCount() return pendingSourceSplitCount; } + private long getQueuedPartitionedSplitsWeight() + { + TaskStatus taskStatus = getTaskStatus(); + if (taskStatus.getState().isDone()) { + return 0; + } + return getPendingSourceSplitsWeight() + taskStatus.getQueuedPartitionedSplitsWeight(); + } + + @SuppressWarnings("FieldAccessNotGuarded") + private long getPendingSourceSplitsWeight() + { + return pendingSourceSplitsWeight; + } + @Override public void 
addStateChangeListener(StateChangeListener stateChangeListener) { @@ -454,13 +502,13 @@ public void addFinalTaskInfoListener(StateChangeListener stateChangeLi } @Override - public synchronized ListenableFuture whenSplitQueueHasSpace(int threshold) + public synchronized ListenableFuture whenSplitQueueHasSpace(long weightThreshold) { if (whenSplitQueueHasSpaceThreshold.isPresent()) { - checkArgument(threshold == whenSplitQueueHasSpaceThreshold.getAsInt(), "Multiple split queue space notification thresholds not supported"); + checkArgument(weightThreshold == whenSplitQueueHasSpaceThreshold.getAsLong(), "Multiple split queue space notification thresholds not supported"); } else { - whenSplitQueueHasSpaceThreshold = OptionalInt.of(threshold); + whenSplitQueueHasSpaceThreshold = OptionalLong.of(weightThreshold); updateSplitQueueSpace(); } if (splitQueueHasSpace) { @@ -473,7 +521,7 @@ private synchronized void updateSplitQueueSpace() { // Must check whether the unacknowledged split count threshold is reached even without listeners registered yet splitQueueHasSpace = getUnacknowledgedPartitionedSplitCount() < maxUnacknowledgedSplits && - (whenSplitQueueHasSpaceThreshold.isEmpty() || getQueuedPartitionedSplitCount() < whenSplitQueueHasSpaceThreshold.getAsInt()); + (whenSplitQueueHasSpaceThreshold.isEmpty() || getQueuedPartitionedSplitsWeight() < whenSplitQueueHasSpaceThreshold.getAsLong()); // Only trigger notifications if a listener might be registered if (splitQueueHasSpace && whenSplitQueueHasSpaceThreshold.isPresent()) { whenSplitQueueHasSpace.complete(null, executor); @@ -487,10 +535,15 @@ private synchronized void processTaskUpdate(TaskInfo newValue, List // remove acknowledged splits, which frees memory for (TaskSource source : sources) { PlanNodeId planNodeId = source.getPlanNodeId(); + boolean isPartitionedSource = planFragment.isPartitionedSources(planNodeId); int removed = 0; + long removedWeight = 0; for (ScheduledSplit split : source.getSplits()) { if 
(pendingSplits.remove(planNodeId, split)) { - removed++; + if (isPartitionedSource) { + removed++; + removedWeight = addExact(removedWeight, split.getSplit().getSplitWeight().getRawValue()); + } } } if (source.isNoMoreSplits()) { @@ -499,12 +552,13 @@ private synchronized void processTaskUpdate(TaskInfo newValue, List for (Lifespan lifespan : source.getNoMoreSplitsForLifespan()) { pendingNoMoreSplitsForLifespan.remove(planNodeId, lifespan); } - if (planFragment.isPartitionedSources(planNodeId)) { + if (isPartitionedSource) { pendingSourceSplitCount -= removed; + pendingSourceSplitsWeight -= removedWeight; } } // Update node level split tracker before split queue space to ensure it's up to date before waking up the scheduler - partitionedSplitCountTracker.setPartitionedSplitCount(getPartitionedSplitCount()); + partitionedSplitCountTracker.setPartitionedSplits(getPartitionedSplitsInfo()); updateSplitQueueSpace(); } @@ -639,7 +693,8 @@ private synchronized void cleanUpTask() // clear pending splits to free memory pendingSplits.clear(); pendingSourceSplitCount = 0; - partitionedSplitCountTracker.setPartitionedSplitCount(getPartitionedSplitCount()); + pendingSourceSplitsWeight = 0; + partitionedSplitCountTracker.setPartitionedSplits(PartitionedSplitsInfo.forZeroSplits()); splitQueueHasSpace = true; whenSplitQueueHasSpace.complete(null, executor); diff --git a/core/trino-main/src/test/java/io/trino/execution/MockRemoteTaskFactory.java b/core/trino-main/src/test/java/io/trino/execution/MockRemoteTaskFactory.java index f12cfecd4114..8d1adc1e2f26 100644 --- a/core/trino-main/src/test/java/io/trino/execution/MockRemoteTaskFactory.java +++ b/core/trino-main/src/test/java/io/trino/execution/MockRemoteTaskFactory.java @@ -37,6 +37,7 @@ import io.trino.metadata.Split; import io.trino.operator.TaskContext; import io.trino.operator.TaskStats; +import io.trino.spi.SplitWeight; import io.trino.spi.memory.MemoryPoolId; import io.trino.spiller.SpillSpaceTracker; import 
io.trino.sql.planner.Partitioning; @@ -85,6 +86,7 @@ import static io.trino.sql.planner.SystemPartitioningHandle.SOURCE_DISTRIBUTION; import static io.trino.testing.TestingHandles.TEST_TABLE_HANDLE; import static io.trino.util.Failures.toFailures; +import static java.lang.Math.addExact; import static java.util.Objects.requireNonNull; import static java.util.concurrent.TimeUnit.MILLISECONDS; @@ -217,7 +219,7 @@ public MockRemoteTask( this.nodeId = requireNonNull(nodeId, "nodeId is null"); splits.putAll(initialSplits); this.partitionedSplitCountTracker = requireNonNull(partitionedSplitCountTracker, "partitionedSplitCountTracker is null"); - partitionedSplitCountTracker.setPartitionedSplitCount(getPartitionedSplitCount()); + partitionedSplitCountTracker.setPartitionedSplits(getPartitionedSplitsInfo()); updateSplitQueueSpace(); } @@ -261,7 +263,9 @@ public TaskInfo getTaskInfo() DataSize.ofBytes(0), 0, new Duration(0, MILLISECONDS), - INITIAL_DYNAMIC_FILTERS_VERSION), + INITIAL_DYNAMIC_FILTERS_VERSION, + 0L, + 0L), DateTime.now(), outputBuffer.getInfo(), ImmutableSet.of(), @@ -273,6 +277,8 @@ public TaskInfo getTaskInfo() public TaskStatus getTaskStatus() { TaskStats stats = taskContext.getTaskStats(); + PartitionedSplitsInfo combinedSplitsInfo = getPartitionedSplitsInfo(); + PartitionedSplitsInfo queuedSplitsInfo = getQueuedPartitionedSplitsInfo(); return new TaskStatus(taskStateMachine.getTaskId(), TASK_INSTANCE_ID, nextTaskInfoVersion.get(), @@ -281,8 +287,8 @@ public TaskStatus getTaskStatus() nodeId, ImmutableSet.of(), ImmutableList.of(), - stats.getQueuedPartitionedDrivers(), - stats.getRunningPartitionedDrivers(), + queuedSplitsInfo.getCount(), + combinedSplitsInfo.getCount() - queuedSplitsInfo.getCount(), isOutputBufferOverUtilized, stats.getPhysicalWrittenDataSize(), stats.getUserMemoryReservation(), @@ -290,12 +296,14 @@ public TaskStatus getTaskStatus() stats.getRevocableMemoryReservation(), 0, new Duration(0, MILLISECONDS), - 
INITIAL_DYNAMIC_FILTERS_VERSION); + INITIAL_DYNAMIC_FILTERS_VERSION, + queuedSplitsInfo.getWeightSum(), + combinedSplitsInfo.getWeightSum() - queuedSplitsInfo.getWeightSum()); } private synchronized void updateSplitQueueSpace() { - if (unacknowledgedSplits < maxUnacknowledgedSplits && getQueuedPartitionedSplitCount() < 9) { + if (unacknowledgedSplits < maxUnacknowledgedSplits && getQueuedPartitionedSplitsInfo().getWeightSum() < 900L) { if (!whenSplitQueueHasSpace.isDone()) { whenSplitQueueHasSpace.set(null); } @@ -324,7 +332,7 @@ public synchronized void clearSplits() { unacknowledgedSplits = 0; splits.clear(); - partitionedSplitCountTracker.setPartitionedSplitCount(getPartitionedSplitCount()); + partitionedSplitCountTracker.setPartitionedSplits(PartitionedSplitsInfo.forZeroSplits()); runningDrivers = 0; updateSplitQueueSpace(); } @@ -371,7 +379,7 @@ public void addSplits(Multimap splits) synchronized (this) { this.splits.putAll(splits); } - partitionedSplitCountTracker.setPartitionedSplitCount(getPartitionedSplitCount()); + partitionedSplitCountTracker.setPartitionedSplits(getPartitionedSplitsInfo()); updateSplitQueueSpace(); } @@ -422,7 +430,7 @@ public void addFinalTaskInfoListener(StateChangeListener stateChangeLi } @Override - public synchronized ListenableFuture whenSplitQueueHasSpace(int threshold) + public synchronized ListenableFuture whenSplitQueueHasSpace(long weightThreshold) { return nonCancellationPropagating(whenSplitQueueHasSpace); } @@ -441,28 +449,45 @@ public void abort() } @Override - public int getPartitionedSplitCount() + public PartitionedSplitsInfo getPartitionedSplitsInfo() { if (taskStateMachine.getState().isDone()) { - return 0; + return PartitionedSplitsInfo.forZeroSplits(); } synchronized (this) { int count = 0; - for (PlanNodeId partitionedSource : fragment.getPartitionedSources()) { - Collection partitionedSplits = splits.get(partitionedSource); + long weight = 0; + for (PlanNodeId tableScanPlanNodeId : 
fragment.getPartitionedSources()) { + Collection partitionedSplits = splits.get(tableScanPlanNodeId); count += partitionedSplits.size(); + weight = addExact(weight, SplitWeight.rawValueSum(partitionedSplits, Split::getSplitWeight)); } - return count; + return PartitionedSplitsInfo.forSplitCountAndWeightSum(count, weight); } } @Override - public synchronized int getQueuedPartitionedSplitCount() + public synchronized PartitionedSplitsInfo getQueuedPartitionedSplitsInfo() { if (taskStateMachine.getState().isDone()) { - return 0; + return PartitionedSplitsInfo.forZeroSplits(); } - return getPartitionedSplitCount() - runningDrivers; + // Let's consider the first drivers encountered to be "running" + int remainingRunning = runningDrivers; + int queuedCount = 0; + long queuedWeight = 0; + for (PlanNodeId tableScanPlanNodeId : fragment.getPartitionedSources()) { + for (Split split : splits.get(tableScanPlanNodeId)) { + if (remainingRunning > 0) { + remainingRunning--; + } + else { + queuedCount++; + queuedWeight = addExact(queuedWeight, split.getSplitWeight().getRawValue()); + } + } + } + return PartitionedSplitsInfo.forSplitCountAndWeightSum(queuedCount, queuedWeight); } @Override diff --git a/core/trino-main/src/test/java/io/trino/execution/TestNodeScheduler.java b/core/trino-main/src/test/java/io/trino/execution/TestNodeScheduler.java index 06f5b1b3e372..06930259922f 100644 --- a/core/trino-main/src/test/java/io/trino/execution/TestNodeScheduler.java +++ b/core/trino-main/src/test/java/io/trino/execution/TestNodeScheduler.java @@ -42,6 +42,7 @@ import io.trino.metadata.InternalNode; import io.trino.metadata.Split; import io.trino.spi.HostAddress; +import io.trino.spi.SplitWeight; import io.trino.spi.connector.ConnectorSplit; import io.trino.sql.planner.plan.PlanNodeId; import io.trino.testing.TestingSession; @@ -341,7 +342,7 @@ public void testMaxSplitsPerNode() remoteTask1.abort(); remoteTask2.abort(); - assertEquals(nodeTaskMap.getPartitionedSplitsOnNode(newNode), 0); 
+ assertEquals(nodeTaskMap.getPartitionedSplitsOnNode(newNode), PartitionedSplitsInfo.forZeroSplits()); } @Test @@ -407,7 +408,7 @@ public void testMaxSplitsPerNodePerTask() for (RemoteTask task : tasks) { task.abort(); } - assertEquals(nodeTaskMap.getPartitionedSplitsOnNode(newNode), 0); + assertEquals(nodeTaskMap.getPartitionedSplitsOnNode(newNode), PartitionedSplitsInfo.forZeroSplits()); } @Test @@ -424,13 +425,13 @@ public void testTaskCompletion() ImmutableList.of(new Split(CONNECTOR_ID, new TestSplitRemote(), Lifespan.taskWide())), nodeTaskMap.createPartitionedSplitCountTracker(chosenNode, taskId)); nodeTaskMap.addTask(chosenNode, remoteTask); - assertEquals(nodeTaskMap.getPartitionedSplitsOnNode(chosenNode), 1); + assertEquals(nodeTaskMap.getPartitionedSplitsOnNode(chosenNode).getCount(), 1); remoteTask.abort(); MILLISECONDS.sleep(100); // Sleep until cache expires - assertEquals(nodeTaskMap.getPartitionedSplitsOnNode(chosenNode), 0); + assertEquals(nodeTaskMap.getPartitionedSplitsOnNode(chosenNode), PartitionedSplitsInfo.forZeroSplits()); remoteTask.abort(); - assertEquals(nodeTaskMap.getPartitionedSplitsOnNode(chosenNode), 0); + assertEquals(nodeTaskMap.getPartitionedSplitsOnNode(chosenNode), PartitionedSplitsInfo.forZeroSplits()); } @Test @@ -457,12 +458,12 @@ public void testSplitCount() nodeTaskMap.addTask(chosenNode, remoteTask1); nodeTaskMap.addTask(chosenNode, remoteTask2); - assertEquals(nodeTaskMap.getPartitionedSplitsOnNode(chosenNode), 3); + assertEquals(nodeTaskMap.getPartitionedSplitsOnNode(chosenNode).getCount(), 3); remoteTask1.abort(); - assertEquals(nodeTaskMap.getPartitionedSplitsOnNode(chosenNode), 1); + assertEquals(nodeTaskMap.getPartitionedSplitsOnNode(chosenNode).getCount(), 1); remoteTask2.abort(); - assertEquals(nodeTaskMap.getPartitionedSplitsOnNode(chosenNode), 0); + assertEquals(nodeTaskMap.getPartitionedSplitsOnNode(chosenNode), PartitionedSplitsInfo.forZeroSplits()); } @Test @@ -869,6 +870,7 @@ private static class 
TestSplitLocal implements ConnectorSplit { private final HostAddress address; + private final SplitWeight splitWeight; private TestSplitLocal() { @@ -876,8 +878,14 @@ private TestSplitLocal() } private TestSplitLocal(HostAddress address) + { + this(address, SplitWeight.standard()); + } + + private TestSplitLocal(HostAddress address, SplitWeight splitWeight) { this.address = requireNonNull(address, "address is null"); + this.splitWeight = requireNonNull(splitWeight, "splitWeight is null"); } @Override @@ -898,6 +906,12 @@ public Object getInfo() return this; } + @Override + public SplitWeight getSplitWeight() + { + return splitWeight; + } + @Override public String toString() { @@ -933,6 +947,7 @@ private static class TestSplitRemote implements ConnectorSplit { private final List hosts; + private final SplitWeight splitWeight; TestSplitRemote() { @@ -944,8 +959,14 @@ private static class TestSplitRemote } TestSplitRemote(HostAddress host) + { + this(host, SplitWeight.standard()); + } + + TestSplitRemote(HostAddress host, SplitWeight splitWeight) { this.hosts = ImmutableList.of(requireNonNull(host, "host is null")); + this.splitWeight = requireNonNull(splitWeight, "splitWeight is null"); } @Override @@ -965,6 +986,12 @@ public Object getInfo() { return this; } + + @Override + public SplitWeight getSplitWeight() + { + return splitWeight; + } } private static class TestNetworkTopology diff --git a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestSourcePartitionedScheduler.java b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestSourcePartitionedScheduler.java index 99da1c6909de..a5d6aee9523c 100644 --- a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestSourcePartitionedScheduler.java +++ b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestSourcePartitionedScheduler.java @@ -26,6 +26,7 @@ import io.trino.execution.MockRemoteTaskFactory; import io.trino.execution.MockRemoteTaskFactory.MockRemoteTask; import 
io.trino.execution.NodeTaskMap; +import io.trino.execution.PartitionedSplitsInfo; import io.trino.execution.RemoteTask; import io.trino.execution.SqlStageExecution; import io.trino.execution.StageId; @@ -194,7 +195,8 @@ public void testScheduleSplitsOneAtATime() } for (RemoteTask remoteTask : stage.getAllTasks()) { - assertEquals(remoteTask.getPartitionedSplitCount(), 20); + PartitionedSplitsInfo splitsInfo = remoteTask.getPartitionedSplitsInfo(); + assertEquals(splitsInfo.getCount(), 20); } stage.abort(); @@ -231,7 +233,8 @@ public void testScheduleSplitsBatched() } for (RemoteTask remoteTask : stage.getAllTasks()) { - assertEquals(remoteTask.getPartitionedSplitCount(), 20); + PartitionedSplitsInfo splitsInfo = remoteTask.getPartitionedSplitsInfo(); + assertEquals(splitsInfo.getCount(), 20); } stage.abort(); @@ -263,7 +266,8 @@ public void testScheduleSplitsBlock() } for (RemoteTask remoteTask : stage.getAllTasks()) { - assertEquals(remoteTask.getPartitionedSplitCount(), 20); + PartitionedSplitsInfo splitsInfo = remoteTask.getPartitionedSplitsInfo(); + assertEquals(splitsInfo.getCount(), 20); } // todo rewrite MockRemoteTask to fire a tate transition when splits are cleared, and then validate blocked future completes @@ -295,7 +299,8 @@ public void testScheduleSplitsBlock() } for (RemoteTask remoteTask : stage.getAllTasks()) { - assertEquals(remoteTask.getPartitionedSplitCount(), 20); + PartitionedSplitsInfo splitsInfo = remoteTask.getPartitionedSplitsInfo(); + assertEquals(splitsInfo.getCount(), 20); } stage.abort(); @@ -367,7 +372,8 @@ public void testBalancedSplitAssignment() assertEquals(scheduleResult.getNewTasks().size(), 3); assertEquals(firstStage.getAllTasks().size(), 3); for (RemoteTask remoteTask : firstStage.getAllTasks()) { - assertEquals(remoteTask.getPartitionedSplitCount(), 5); + PartitionedSplitsInfo splitsInfo = remoteTask.getPartitionedSplitsInfo(); + assertEquals(splitsInfo.getCount(), 5); } // Add new node @@ -385,7 +391,7 @@ public void 
testBalancedSplitAssignment() assertEquals(scheduleResult.getNewTasks().size(), 1); assertEquals(secondStage.getAllTasks().size(), 1); RemoteTask task = secondStage.getAllTasks().get(0); - assertEquals(task.getPartitionedSplitCount(), 5); + assertEquals(task.getPartitionedSplitsInfo().getCount(), 5); firstStage.abort(); secondStage.abort(); @@ -421,6 +427,10 @@ public void testNewTaskScheduledWhenChildStageBufferIsUnderutilized() assertEquals(scheduleResult.getBlockedReason().get(), SPLIT_QUEUES_FULL); assertEquals(scheduleResult.getNewTasks().size(), 3); assertEquals(scheduleResult.getSplitsScheduled(), 300); + for (RemoteTask remoteTask : scheduleResult.getNewTasks()) { + PartitionedSplitsInfo splitsInfo = remoteTask.getPartitionedSplitsInfo(); + assertEquals(splitsInfo.getCount(), 100); + } // new node added - the pending splits should go to it since the child tasks are not blocked nodeManager.addNode(CONNECTOR_ID, new InternalNode("other4", URI.create("http://127.0.0.4:14"), NodeVersion.UNKNOWN, false)); @@ -460,6 +470,10 @@ public void testNoNewTaskScheduledWhenChildStageBufferIsOverutilized() assertEquals(scheduleResult.getBlockedReason().get(), SPLIT_QUEUES_FULL); assertEquals(scheduleResult.getNewTasks().size(), 3); assertEquals(scheduleResult.getSplitsScheduled(), 300); + for (RemoteTask remoteTask : scheduleResult.getNewTasks()) { + PartitionedSplitsInfo splitsInfo = remoteTask.getPartitionedSplitsInfo(); + assertEquals(splitsInfo.getCount(), 100); + } // new node added but 1 child's output buffer is overutilized - so lockdown the tasks nodeManager.addNode(CONNECTOR_ID, new InternalNode("other4", URI.create("http://127.0.0.4:14"), NodeVersion.UNKNOWN, false)); @@ -514,7 +528,7 @@ public void testDynamicFiltersUnblockedOnBlockedBuildSource() private static void assertPartitionedSplitCount(SqlStageExecution stage, int expectedPartitionedSplitCount) { - assertEquals(stage.getAllTasks().stream().mapToInt(RemoteTask::getPartitionedSplitCount).sum(), 
expectedPartitionedSplitCount); + assertEquals(stage.getAllTasks().stream().mapToInt(remoteTask -> remoteTask.getPartitionedSplitsInfo().getCount()).sum(), expectedPartitionedSplitCount); } private static void assertEffectivelyFinished(ScheduleResult scheduleResult, StageScheduler scheduler) diff --git a/core/trino-main/src/test/java/io/trino/operator/TestPipelineStats.java b/core/trino-main/src/test/java/io/trino/operator/TestPipelineStats.java index 8998c3eaa634..d2e4c900bca7 100644 --- a/core/trino-main/src/test/java/io/trino/operator/TestPipelineStats.java +++ b/core/trino-main/src/test/java/io/trino/operator/TestPipelineStats.java @@ -44,8 +44,10 @@ public class TestPipelineStats 1, 2, 1, + 21L, 3, 2, + 22L, 19, 4, @@ -105,8 +107,10 @@ public static void assertExpectedPipelineStats(PipelineStats actual) assertEquals(actual.getTotalDrivers(), 1); assertEquals(actual.getQueuedDrivers(), 2); assertEquals(actual.getQueuedPartitionedDrivers(), 1); + assertEquals(actual.getQueuedPartitionedSplitsWeight(), 21L); assertEquals(actual.getRunningDrivers(), 3); assertEquals(actual.getRunningPartitionedDrivers(), 2); + assertEquals(actual.getRunningPartitionedSplitsWeight(), 22L); assertEquals(actual.getBlockedDrivers(), 19); assertEquals(actual.getCompletedDrivers(), 4); diff --git a/core/trino-main/src/test/java/io/trino/operator/TestTaskStats.java b/core/trino-main/src/test/java/io/trino/operator/TestTaskStats.java index 2890fd719194..59184d013401 100644 --- a/core/trino-main/src/test/java/io/trino/operator/TestTaskStats.java +++ b/core/trino-main/src/test/java/io/trino/operator/TestTaskStats.java @@ -40,8 +40,10 @@ public class TestTaskStats 6, 7, 5, + 28L, 8, 6, + 29L, 24, 10, @@ -103,8 +105,10 @@ public static void assertExpectedTaskStats(TaskStats actual) assertEquals(actual.getTotalDrivers(), 6); assertEquals(actual.getQueuedDrivers(), 7); assertEquals(actual.getQueuedPartitionedDrivers(), 5); + assertEquals(actual.getQueuedPartitionedSplitsWeight(), 28L); 
assertEquals(actual.getRunningDrivers(), 8); assertEquals(actual.getRunningPartitionedDrivers(), 6); + assertEquals(actual.getRunningPartitionedSplitsWeight(), 29L); assertEquals(actual.getBlockedDrivers(), 24); assertEquals(actual.getCompletedDrivers(), 10); diff --git a/core/trino-main/src/test/java/io/trino/operator/TestingOperatorContext.java b/core/trino-main/src/test/java/io/trino/operator/TestingOperatorContext.java index 9be63526dc20..b147efadd27d 100644 --- a/core/trino-main/src/test/java/io/trino/operator/TestingOperatorContext.java +++ b/core/trino-main/src/test/java/io/trino/operator/TestingOperatorContext.java @@ -53,7 +53,8 @@ public static OperatorContext create(ScheduledExecutorService scheduledExecutor) executor, scheduledExecutor, pipelineMemoryContext, - Lifespan.taskWide()); + Lifespan.taskWide(), + 0L); OperatorContext operatorContext = driverContext.addOperatorContext( 1, diff --git a/core/trino-main/src/test/java/io/trino/server/remotetask/TestHttpRemoteTask.java b/core/trino-main/src/test/java/io/trino/server/remotetask/TestHttpRemoteTask.java index f6dbf96a9340..d7c934be0085 100644 --- a/core/trino-main/src/test/java/io/trino/server/remotetask/TestHttpRemoteTask.java +++ b/core/trino-main/src/test/java/io/trino/server/remotetask/TestHttpRemoteTask.java @@ -762,7 +762,9 @@ private TaskStatus buildTaskStatus() initialTaskStatus.getRevocableMemoryReservation(), initialTaskStatus.getFullGcCount(), initialTaskStatus.getFullGcTime(), - dynamicFilterDomains.map(VersionedDynamicFilterDomains::getVersion).orElse(INITIAL_DYNAMIC_FILTERS_VERSION)); + dynamicFilterDomains.map(VersionedDynamicFilterDomains::getVersion).orElse(INITIAL_DYNAMIC_FILTERS_VERSION), + initialTaskStatus.getQueuedPartitionedSplitsWeight(), + initialTaskStatus.getRunningPartitionedSplitsWeight()); } private static class DynamicFiltersFetchRequest diff --git a/core/trino-spi/src/main/java/io/trino/spi/SplitWeight.java b/core/trino-spi/src/main/java/io/trino/spi/SplitWeight.java 
new file mode 100644 index 000000000000..05d3c075bebf --- /dev/null +++ b/core/trino-spi/src/main/java/io/trino/spi/SplitWeight.java @@ -0,0 +1,112 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.spi; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonValue; + +import java.math.BigDecimal; +import java.util.Collection; +import java.util.function.Function; + +import static java.lang.Math.addExact; +import static java.lang.Math.multiplyExact; + +public final class SplitWeight +{ + private static final long UNIT_VALUE = 100; + private static final int UNIT_SCALE = 2; // Decimal scale such that (10 ^ UNIT_SCALE) == UNIT_VALUE + private static final SplitWeight STANDARD_WEIGHT = new SplitWeight(UNIT_VALUE); + + private final long value; + + private SplitWeight(long value) + { + if (value <= 0) { + throw new IllegalArgumentException("value must be > 0, found: " + value); + } + this.value = value; + } + + /** + * @return The internal integer representation for this weight value + */ + @JsonValue + public long getRawValue() + { + return value; + } + + @Override + public boolean equals(Object other) + { + if (!(other instanceof SplitWeight)) { + return false; + } + return this.value == ((SplitWeight) other).value; + } + + @Override + public int hashCode() + { + return Long.hashCode(value); + } + + @Override + public String toString() + { + if (value == UNIT_VALUE) { + return "1"; + } + 
return BigDecimal.valueOf(value, -UNIT_SCALE).stripTrailingZeros().toPlainString(); + } + + @JsonCreator + public static SplitWeight fromRawValue(long value) + { + return value == UNIT_VALUE ? STANDARD_WEIGHT : new SplitWeight(value); + } + + public static SplitWeight fromProportion(double weight) + { + if (weight <= 0 || !Double.isFinite(weight)) { + throw new IllegalArgumentException("Invalid weight: " + weight); + } + // Must round up to avoid small weights rounding to 0 + return fromRawValue((long) Math.ceil(weight * UNIT_VALUE)); + } + + public static SplitWeight standard() + { + return STANDARD_WEIGHT; + } + + public static long rawValueForStandardSplitCount(int splitCount) + { + if (splitCount < 0) { + throw new IllegalArgumentException("splitCount must be >= 0, found: " + splitCount); + } + return multiplyExact(splitCount, UNIT_VALUE); + } + + public static long rawValueSum(Collection collection, Function getter) + { + long sum = 0; + for (T item : collection) { + long value = getter.apply(item).getRawValue(); + sum = addExact(sum, value); + } + return sum; + } +} diff --git a/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorSplit.java b/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorSplit.java index 996d16b679b0..0b5fa396eb99 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorSplit.java +++ b/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorSplit.java @@ -14,6 +14,7 @@ package io.trino.spi.connector; import io.trino.spi.HostAddress; +import io.trino.spi.SplitWeight; import java.util.List; @@ -24,4 +25,9 @@ public interface ConnectorSplit List getAddresses(); Object getInfo(); + + default SplitWeight getSplitWeight() + { + return SplitWeight.standard(); + } } diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveConfig.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveConfig.java index 62cde05dc0c3..dbe3ae56a901 100644 --- 
a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveConfig.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveConfig.java @@ -29,6 +29,8 @@ import javax.annotation.Nullable; import javax.validation.constraints.AssertTrue; +import javax.validation.constraints.DecimalMax; +import javax.validation.constraints.DecimalMin; import javax.validation.constraints.Max; import javax.validation.constraints.Min; import javax.validation.constraints.NotNull; @@ -153,6 +155,9 @@ public class HiveConfig private boolean legacyHiveViewTranslation; private DataSize targetMaxFileSize = DataSize.of(1, GIGABYTE); + private boolean sizeBasedSplitWeightsEnabled = true; + private double minimumAssignedSplitWeight = 0.05; + public int getMaxInitialSplits() { return maxInitialSplits; @@ -1085,4 +1090,31 @@ public boolean isLegacyHiveViewTranslation() { return this.legacyHiveViewTranslation; } + + @Config("hive.size-based-split-weights-enabled") + public HiveConfig setSizeBasedSplitWeightsEnabled(boolean sizeBasedSplitWeightsEnabled) + { + this.sizeBasedSplitWeightsEnabled = sizeBasedSplitWeightsEnabled; + return this; + } + + public boolean isSizeBasedSplitWeightsEnabled() + { + return sizeBasedSplitWeightsEnabled; + } + + @Config("hive.minimum-assigned-split-weight") + @ConfigDescription("Minimum weight that a split can be assigned when size based split weights are enabled") + public HiveConfig setMinimumAssignedSplitWeight(double minimumAssignedSplitWeight) + { + this.minimumAssignedSplitWeight = minimumAssignedSplitWeight; + return this; + } + + @DecimalMax("1") + @DecimalMin(value = "0", inclusive = false) + public double getMinimumAssignedSplitWeight() + { + return minimumAssignedSplitWeight; + } } diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveSessionProperties.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveSessionProperties.java index 2c6c006ac3af..f1c83bcdbb15 100644 --- 
a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveSessionProperties.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveSessionProperties.java @@ -41,6 +41,7 @@ import static io.trino.plugin.base.session.PropertyMetadataUtil.durationProperty; import static io.trino.spi.StandardErrorCode.INVALID_SESSION_PROPERTY; import static io.trino.spi.session.PropertyMetadata.booleanProperty; +import static io.trino.spi.session.PropertyMetadata.doubleProperty; import static io.trino.spi.session.PropertyMetadata.enumProperty; import static io.trino.spi.session.PropertyMetadata.integerProperty; import static io.trino.spi.session.PropertyMetadata.stringProperty; @@ -110,6 +111,8 @@ public final class HiveSessionProperties private static final String DYNAMIC_FILTERING_PROBE_BLOCKING_TIMEOUT = "dynamic_filtering_probe_blocking_timeout"; private static final String OPTIMIZE_SYMLINK_LISTING = "optimize_symlink_listing"; private static final String LEGACY_HIVE_VIEW_TRANSLATION = "legacy_hive_view_translation"; + public static final String SIZE_BASED_SPLIT_WEIGHTS_ENABLED = "size_based_split_weights_enabled"; + public static final String MINIMUM_ASSIGNED_SPLIT_WEIGHT = "minimum_assigned_split_weight"; private final List> sessionProperties; @@ -458,6 +461,21 @@ public HiveSessionProperties( LEGACY_HIVE_VIEW_TRANSLATION, "Use legacy Hive view translation mechanism", hiveConfig.isLegacyHiveViewTranslation(), + false), + booleanProperty( + SIZE_BASED_SPLIT_WEIGHTS_ENABLED, + "Enable estimating split weights based on size in bytes", + hiveConfig.isSizeBasedSplitWeightsEnabled(), + false), + doubleProperty( + MINIMUM_ASSIGNED_SPLIT_WEIGHT, + "Minimum assigned split weight when size based split weighting is enabled", + hiveConfig.getMinimumAssignedSplitWeight(), + value -> { + if (!Double.isFinite(value) || value <= 0 || value > 1) { + throw new TrinoException(INVALID_SESSION_PROPERTY, format("%s must be > 0 and <= 1.0: %s", MINIMUM_ASSIGNED_SPLIT_WEIGHT, value)); + } 
+ }, false)); } @@ -765,4 +783,14 @@ public static boolean isLegacyHiveViewTranslation(ConnectorSession session) { return session.getProperty(LEGACY_HIVE_VIEW_TRANSLATION, Boolean.class); } + + public static boolean isSizeBasedSplitWeightsEnabled(ConnectorSession session) + { + return session.getProperty(SIZE_BASED_SPLIT_WEIGHTS_ENABLED, Boolean.class); + } + + public static double getMinimumAssignedSplitWeight(ConnectorSession session) + { + return session.getProperty(MINIMUM_ASSIGNED_SPLIT_WEIGHT, Double.class); + } } diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveSplit.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveSplit.java index 5bf3ae0d17ea..50661112dfb7 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveSplit.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveSplit.java @@ -19,6 +19,7 @@ import com.google.common.collect.ImmutableMap; import io.trino.plugin.hive.util.HiveBucketing.BucketingVersion; import io.trino.spi.HostAddress; +import io.trino.spi.SplitWeight; import io.trino.spi.connector.ConnectorSplit; import java.util.List; @@ -55,6 +56,7 @@ public class HiveSplit private final boolean s3SelectPushdownEnabled; private final Optional acidInfo; private final long splitNumber; + private final SplitWeight splitWeight; @JsonCreator public HiveSplit( @@ -77,7 +79,8 @@ public HiveSplit( @JsonProperty("bucketValidation") Optional bucketValidation, @JsonProperty("s3SelectPushdownEnabled") boolean s3SelectPushdownEnabled, @JsonProperty("acidInfo") Optional acidInfo, - @JsonProperty("splitNumber") long splitNumber) + @JsonProperty("splitNumber") long splitNumber, + @JsonProperty("splitWeight") SplitWeight splitWeight) { checkArgument(start >= 0, "start must be positive"); checkArgument(length >= 0, "length must be positive"); @@ -115,6 +118,7 @@ public HiveSplit( this.s3SelectPushdownEnabled = s3SelectPushdownEnabled; this.acidInfo = acidInfo; this.splitNumber = splitNumber; + 
this.splitWeight = requireNonNull(splitWeight, "splitWeight is null"); } @JsonProperty @@ -244,6 +248,13 @@ public long getSplitNumber() return splitNumber; } + @JsonProperty + @Override + public SplitWeight getSplitWeight() + { + return splitWeight; + } + @Override public Object getInfo() { diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveSplitSource.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveSplitSource.java index 68c5758adcc8..379bcdeb5931 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveSplitSource.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveSplitSource.java @@ -23,6 +23,7 @@ import io.trino.plugin.hive.InternalHiveSplit.InternalHiveBlock; import io.trino.plugin.hive.util.AsyncQueue; import io.trino.plugin.hive.util.AsyncQueue.BorrowResult; +import io.trino.plugin.hive.util.SizeBasedSplitWeightProvider; import io.trino.plugin.hive.util.ThrottledAsyncQueue; import io.trino.spi.TrinoException; import io.trino.spi.connector.ConnectorPartitionHandle; @@ -56,6 +57,8 @@ import static io.trino.plugin.hive.HiveErrorCode.HIVE_UNKNOWN_ERROR; import static io.trino.plugin.hive.HiveSessionProperties.getMaxInitialSplitSize; import static io.trino.plugin.hive.HiveSessionProperties.getMaxSplitSize; +import static io.trino.plugin.hive.HiveSessionProperties.getMinimumAssignedSplitWeight; +import static io.trino.plugin.hive.HiveSessionProperties.isSizeBasedSplitWeightsEnabled; import static io.trino.plugin.hive.HiveSplitSource.StateKind.CLOSED; import static io.trino.plugin.hive.HiveSplitSource.StateKind.FAILED; import static io.trino.plugin.hive.HiveSplitSource.StateKind.INITIAL; @@ -88,6 +91,7 @@ class HiveSplitSource private final CounterStat highMemorySplitSourceCounter; private final AtomicBoolean loggedHighMemoryWarning = new AtomicBoolean(); + private final HiveSplitWeightProvider splitWeightProvider; private HiveSplitSource( ConnectorSession session, @@ -114,6 +118,7 @@ private 
HiveSplitSource( this.maxInitialSplitSize = getMaxInitialSplitSize(session); this.remainingInitialSplits = new AtomicInteger(maxInitialSplits); this.numberOfProcessedSplits = new AtomicLong(0); + this.splitWeightProvider = isSizeBasedSplitWeightsEnabled(session) ? new SizeBasedSplitWeightProvider(getMinimumAssignedSplitWeight(session), maxSplitSize) : HiveSplitWeightProvider.uniformStandardWeightProvider(); } public static HiveSplitSource allAtOnce( @@ -384,7 +389,8 @@ else if (maxSplitBytes * 2 >= remainingBlockBytes) { internalSplit.getBucketValidation(), internalSplit.isS3SelectPushdownEnabled(), internalSplit.getAcidInfo(), - numberOfProcessedSplits.getAndIncrement())); + numberOfProcessedSplits.getAndIncrement(), + splitWeightProvider.weightForSplitSizeInBytes(splitBytes))); internalSplit.increaseStart(splitBytes); diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveSplitWeightProvider.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveSplitWeightProvider.java new file mode 100644 index 000000000000..03f6fa418d84 --- /dev/null +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveSplitWeightProvider.java @@ -0,0 +1,26 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.plugin.hive; + +import io.trino.spi.SplitWeight; + +public interface HiveSplitWeightProvider +{ + SplitWeight weightForSplitSizeInBytes(long splitSizeInBytes); + + static HiveSplitWeightProvider uniformStandardWeightProvider() + { + return (splitSizeInBytes) -> SplitWeight.standard(); + } +} diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/util/SizeBasedSplitWeightProvider.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/util/SizeBasedSplitWeightProvider.java new file mode 100644 index 000000000000..4eb53c3435e4 --- /dev/null +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/util/SizeBasedSplitWeightProvider.java @@ -0,0 +1,45 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.plugin.hive.util; + +import io.airlift.units.DataSize; +import io.trino.plugin.hive.HiveSplitWeightProvider; +import io.trino.spi.SplitWeight; + +import static com.google.common.base.Preconditions.checkArgument; +import static java.util.Objects.requireNonNull; + +public class SizeBasedSplitWeightProvider + implements HiveSplitWeightProvider +{ + private final double minimumWeight; + private final double targetSplitSizeInBytes; + + public SizeBasedSplitWeightProvider(double minimumWeight, DataSize targetSplitSize) + { + checkArgument(Double.isFinite(minimumWeight) && minimumWeight > 0 && minimumWeight <= 1, "minimumWeight must be > 0 and <= 1, found: %s", minimumWeight); + this.minimumWeight = minimumWeight; + long targetSizeInBytes = requireNonNull(targetSplitSize, "targetSplitSize is null").toBytes(); + checkArgument(targetSizeInBytes > 0, "targetSplitSize must be > 0, found: %s", targetSplitSize); + this.targetSplitSizeInBytes = (double) targetSizeInBytes; + } + + @Override + public SplitWeight weightForSplitSizeInBytes(long splitSizeInBytes) + { + double computedWeight = splitSizeInBytes / targetSplitSizeInBytes; + // Clamp the value be between the minimum weight and 1.0 (standard weight) + return SplitWeight.fromProportion(Math.min(Math.max(computedWeight, minimumWeight), 1.0)); + } +} diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveConfig.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveConfig.java index da57726b9488..9cd67bcd111f 100644 --- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveConfig.java +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveConfig.java @@ -104,7 +104,9 @@ public void testDefaults() .setDynamicFilteringProbeBlockingTimeout(new Duration(0, TimeUnit.MINUTES)) .setTimestampPrecision(HiveTimestampPrecision.DEFAULT_PRECISION) .setOptimizeSymlinkListing(true) - .setLegacyHiveViewTranslation(false)); + .setLegacyHiveViewTranslation(false) + 
.setSizeBasedSplitWeightsEnabled(true) + .setMinimumAssignedSplitWeight(0.05)); } @Test @@ -180,6 +182,8 @@ public void testExplicitPropertyMappings() .put("hive.timestamp-precision", "NANOSECONDS") .put("hive.optimize-symlink-listing", "false") .put("hive.legacy-hive-view-translation", "true") + .put("hive.size-based-split-weights-enabled", "false") + .put("hive.minimum-assigned-split-weight", "1.0") .build(); HiveConfig expected = new HiveConfig() @@ -251,7 +255,9 @@ public void testExplicitPropertyMappings() .setDynamicFilteringProbeBlockingTimeout(new Duration(10, TimeUnit.SECONDS)) .setTimestampPrecision(HiveTimestampPrecision.NANOSECONDS) .setOptimizeSymlinkListing(false) - .setLegacyHiveViewTranslation(true); + .setLegacyHiveViewTranslation(true) + .setSizeBasedSplitWeightsEnabled(false) + .setMinimumAssignedSplitWeight(1.0); assertFullMapping(properties, expected); } diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveConnectorTest.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveConnectorTest.java index e5b392faa4a7..427f35d64b4e 100644 --- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveConnectorTest.java +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveConnectorTest.java @@ -190,7 +190,9 @@ protected QueryRunner createQueryRunner() .setHiveProperties(ImmutableMap.of( "hive.allow-register-partition-procedure", "true", // Reduce writer sort buffer size to ensure SortingFileWriter gets used - "hive.writer-sort-buffer-size", "1MB")) + "hive.writer-sort-buffer-size", "1MB", + // Make weighted split scheduling more conservative to avoid OOMs in test + "hive.minimum-assigned-split-weight", "0.5")) .setInitialTables(REQUIRED_TPCH_TABLES) .build(); diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHivePageSink.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHivePageSink.java index 619e7f33f804..9fa3f0ecad6c 100644 --- 
a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHivePageSink.java +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHivePageSink.java @@ -25,6 +25,7 @@ import io.trino.plugin.hive.metastore.HivePageSinkMetadata; import io.trino.spi.Page; import io.trino.spi.PageBuilder; +import io.trino.spi.SplitWeight; import io.trino.spi.block.BlockBuilder; import io.trino.spi.connector.ConnectorPageSink; import io.trino.spi.connector.ConnectorPageSource; @@ -242,7 +243,8 @@ private static ConnectorPageSource createPageSource(HiveTransactionHandle transa Optional.empty(), false, Optional.empty(), - 0); + 0, + SplitWeight.standard()); ConnectorTableHandle table = new HiveTableHandle(SCHEMA_NAME, TABLE_NAME, ImmutableMap.of(), ImmutableList.of(), ImmutableList.of(), Optional.empty()); HivePageSourceProvider provider = new HivePageSourceProvider( TYPE_MANAGER, diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveSplit.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveSplit.java index e6712b48d0c3..fdef347215a3 100644 --- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveSplit.java +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveSplit.java @@ -21,6 +21,7 @@ import io.trino.plugin.base.TypeDeserializer; import io.trino.plugin.hive.HiveColumnHandle.ColumnType; import io.trino.spi.HostAddress; +import io.trino.spi.SplitWeight; import io.trino.spi.type.TestingTypeManager; import io.trino.spi.type.Type; import org.apache.hadoop.fs.Path; @@ -82,7 +83,8 @@ public void testJsonRoundTrip() Optional.empty(), false, Optional.of(acidInfo), - 555534); + 555534, + SplitWeight.fromProportion(2.0)); // some non-standard value String json = codec.toJson(expected); HiveSplit actual = codec.fromJson(json); @@ -104,5 +106,6 @@ public void testJsonRoundTrip() assertEquals(actual.isS3SelectPushdownEnabled(), expected.isS3SelectPushdownEnabled()); assertEquals(actual.getAcidInfo().get(), 
expected.getAcidInfo().get()); assertEquals(actual.getSplitNumber(), expected.getSplitNumber()); + assertEquals(actual.getSplitWeight(), expected.getSplitWeight()); } } diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestNodeLocalDynamicSplitPruning.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestNodeLocalDynamicSplitPruning.java index b1dd33262922..6266fd5c472f 100644 --- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestNodeLocalDynamicSplitPruning.java +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestNodeLocalDynamicSplitPruning.java @@ -24,6 +24,7 @@ import io.trino.plugin.hive.orc.OrcWriterConfig; import io.trino.plugin.hive.parquet.ParquetReaderConfig; import io.trino.plugin.hive.parquet.ParquetWriterConfig; +import io.trino.spi.SplitWeight; import io.trino.spi.connector.ColumnHandle; import io.trino.spi.connector.ConnectorPageSource; import io.trino.spi.connector.DynamicFilter; @@ -110,7 +111,8 @@ private static ConnectorPageSource createTestingPageSource(HiveTransactionHandle Optional.empty(), false, Optional.empty(), - 0); + 0, + SplitWeight.standard()); TableHandle tableHandle = new TableHandle( new CatalogName(HIVE_CATALOG_NAME), diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/benchmark/AbstractFileFormat.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/benchmark/AbstractFileFormat.java index cf48e225641e..1b51d4b22669 100644 --- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/benchmark/AbstractFileFormat.java +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/benchmark/AbstractFileFormat.java @@ -31,6 +31,7 @@ import io.trino.plugin.hive.HiveTypeName; import io.trino.plugin.hive.ReaderPageSource; import io.trino.plugin.hive.TableToPartitionMapping; +import io.trino.spi.SplitWeight; import io.trino.spi.connector.ColumnHandle; import io.trino.spi.connector.ConnectorPageSource; import io.trino.spi.connector.ConnectorSession; @@ -153,7 +154,8 @@ public 
ConnectorPageSource createGenericReader( Optional.empty(), false, Optional.empty(), - 0); + 0, + SplitWeight.standard()); return factory.createPageSource( TestingConnectorTransactionHandle.INSTANCE, diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/util/TestSizeBasedSplitWeightProvider.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/util/TestSizeBasedSplitWeightProvider.java new file mode 100644 index 000000000000..29acc684a226 --- /dev/null +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/util/TestSizeBasedSplitWeightProvider.java @@ -0,0 +1,57 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.plugin.hive.util; + +import io.airlift.units.DataSize; +import io.trino.spi.SplitWeight; +import org.testng.annotations.Test; + +import static io.airlift.units.DataSize.Unit.MEGABYTE; +import static org.testng.Assert.assertEquals; + +public class TestSizeBasedSplitWeightProvider +{ + @Test + public void testSimpleProportions() + { + SizeBasedSplitWeightProvider provider = new SizeBasedSplitWeightProvider(0.01, DataSize.of(64, MEGABYTE)); + assertEquals(provider.weightForSplitSizeInBytes(DataSize.of(64, MEGABYTE).toBytes()), SplitWeight.fromProportion(1)); + assertEquals(provider.weightForSplitSizeInBytes(DataSize.of(32, MEGABYTE).toBytes()), SplitWeight.fromProportion(0.5)); + assertEquals(provider.weightForSplitSizeInBytes(DataSize.of(16, MEGABYTE).toBytes()), SplitWeight.fromProportion(0.25)); + } + + @Test + public void testMinimumAndMaximumSplitWeightHandling() + { + double minimumWeight = 0.05; + DataSize targetSplitSize = DataSize.of(64, MEGABYTE); + SizeBasedSplitWeightProvider provider = new SizeBasedSplitWeightProvider(minimumWeight, targetSplitSize); + assertEquals(provider.weightForSplitSizeInBytes(1), SplitWeight.fromProportion(minimumWeight)); + + DataSize largerThanTarget = DataSize.of(128, MEGABYTE); + assertEquals(provider.weightForSplitSizeInBytes(largerThanTarget.toBytes()), SplitWeight.standard()); + } + + @Test(expectedExceptions = IllegalArgumentException.class, expectedExceptionsMessageRegExp = "^minimumWeight must be > 0 and <= 1, found: 1\\.01$") + public void testInvalidMinimumWeight() + { + new SizeBasedSplitWeightProvider(1.01, DataSize.of(64, MEGABYTE)); + } + + @Test(expectedExceptions = IllegalArgumentException.class, expectedExceptionsMessageRegExp = "^targetSplitSize must be > 0, found:.*$") + public void testInvalidTargetSplitSize() + { + new SizeBasedSplitWeightProvider(0.01, DataSize.ofBytes(0)); + } +}