Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 56 additions & 28 deletions core/trino-main/src/main/java/io/trino/execution/NodeTaskMap.java
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.IntConsumer;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Consumer;

import static com.google.common.base.MoreObjects.toStringHelper;
import static java.util.Objects.requireNonNull;
Expand All @@ -47,9 +48,9 @@ public void addTask(InternalNode node, RemoteTask task)
createOrGetNodeTasks(node).addTask(task);
}

public int getPartitionedSplitsOnNode(InternalNode node)
public PartitionedSplitsInfo getPartitionedSplitsOnNode(InternalNode node)
{
return createOrGetNodeTasks(node).getPartitionedSplitCount();
return createOrGetNodeTasks(node).getPartitionedSplitsInfo();
}

public PartitionedSplitCountTracker createPartitionedSplitCountTracker(InternalNode node, TaskId taskId)
Expand Down Expand Up @@ -80,16 +81,17 @@ private static class NodeTasks
{
private final Set<RemoteTask> remoteTasks = Sets.newConcurrentHashSet();
private final AtomicInteger nodeTotalPartitionedSplitCount = new AtomicInteger();
private final AtomicLong nodeTotalPartitionedSplitWeight = new AtomicLong();
private final FinalizerService finalizerService;

public NodeTasks(FinalizerService finalizerService)
{
this.finalizerService = requireNonNull(finalizerService, "finalizerService is null");
}

private int getPartitionedSplitCount()
private PartitionedSplitsInfo getPartitionedSplitsInfo()
{
return nodeTotalPartitionedSplitCount.get();
return PartitionedSplitsInfo.forSplitCountAndWeightSum(nodeTotalPartitionedSplitCount.get(), nodeTotalPartitionedSplitWeight.get());
}

private void addTask(RemoteTask task)
Expand All @@ -112,8 +114,8 @@ public PartitionedSplitCountTracker createPartitionedSplitCountTracker(TaskId ta
{
requireNonNull(taskId, "taskId is null");

TaskPartitionedSplitCountTracker tracker = new TaskPartitionedSplitCountTracker(taskId);
PartitionedSplitCountTracker partitionedSplitCountTracker = new PartitionedSplitCountTracker(tracker::setPartitionedSplitCount);
TaskPartitionedSplitCountTracker tracker = new TaskPartitionedSplitCountTracker(taskId, nodeTotalPartitionedSplitCount, nodeTotalPartitionedSplitWeight);
PartitionedSplitCountTracker partitionedSplitCountTracker = new PartitionedSplitCountTracker(tracker);

// when partitionedSplitCountTracker is garbage collected, run the cleanup method on the tracker
// Note: tracker cannot have a reference to partitionedSplitCountTracker
Expand All @@ -123,41 +125,66 @@ public PartitionedSplitCountTracker createPartitionedSplitCountTracker(TaskId ta
}

@ThreadSafe
private class TaskPartitionedSplitCountTracker
private static class TaskPartitionedSplitCountTracker
implements Consumer<PartitionedSplitsInfo>
{
private final TaskId taskId;
private final AtomicInteger nodeTotalPartitionedSplitCount;
private final AtomicLong nodeTotalPartitionedSplitWeight;
private final AtomicInteger localPartitionedSplitCount = new AtomicInteger();
private final AtomicLong localPartitionedSplitWeight = new AtomicLong();

public TaskPartitionedSplitCountTracker(TaskId taskId)
public TaskPartitionedSplitCountTracker(TaskId taskId, AtomicInteger nodeTotalPartitionedSplitCount, AtomicLong nodeTotalPartitionedSplitWeight)
{
this.taskId = requireNonNull(taskId, "taskId is null");
this.nodeTotalPartitionedSplitCount = requireNonNull(nodeTotalPartitionedSplitCount, "nodeTotalPartitionedSplitCount is null");
this.nodeTotalPartitionedSplitWeight = requireNonNull(nodeTotalPartitionedSplitWeight, "nodeTotalPartitionedSplitWeight is null");
}

public synchronized void setPartitionedSplitCount(int partitionedSplitCount)
@Override
public synchronized void accept(PartitionedSplitsInfo partitionedSplits)
{
if (partitionedSplitCount < 0) {
int oldValue = localPartitionedSplitCount.getAndSet(0);
nodeTotalPartitionedSplitCount.addAndGet(-oldValue);
throw new IllegalArgumentException("partitionedSplitCount is negative");
if (partitionedSplits == null || partitionedSplits.getCount() < 0 || partitionedSplits.getWeightSum() < 0) {
clearLocalSplitInfo(false);
requireNonNull(partitionedSplits, "partitionedSplits is null"); // throw NPE if null, otherwise negative value
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems weird that this method would have side effects (clearLocalSplitInfo()) if the preconditions are not met (partitionedSplits != null & partitionedSplits.XXX >= 0)

What's the purpose of calling clearLocalSplitInfo under those conditions instead of just bailing out?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The TaskPartitionedSplitCountTracker is sort of a specific sub-view of the global NodeTaskMap total split state per worker, so when we're going to bail out and throw an exception we want to undo the effect of any task-specific values on the global state first. The existing logic was doing the same thing when / if a negative value was encountered for localPartitionedSplitCount, now it's been extended to consider negative weights or null arguments too.

throw new IllegalArgumentException("Invalid negative value: " + partitionedSplits);
}

int oldValue = localPartitionedSplitCount.getAndSet(partitionedSplitCount);
nodeTotalPartitionedSplitCount.addAndGet(partitionedSplitCount - oldValue);
int newCount = partitionedSplits.getCount();
long newWeight = partitionedSplits.getWeightSum();
int countDelta = newCount - localPartitionedSplitCount.getAndSet(newCount);
long weightDelta = newWeight - localPartitionedSplitWeight.getAndSet(newWeight);
if (countDelta != 0) {
nodeTotalPartitionedSplitCount.addAndGet(countDelta);
}
if (weightDelta != 0) {
nodeTotalPartitionedSplitWeight.addAndGet(weightDelta);
}
}

public void cleanup()
private void clearLocalSplitInfo(boolean reportAsLeaked)
{
int leakedSplits = localPartitionedSplitCount.getAndSet(0);
if (leakedSplits == 0) {
int leakedCount = localPartitionedSplitCount.getAndSet(0);
long leakedWeight = localPartitionedSplitWeight.getAndSet(0);
if (leakedCount == 0 && leakedWeight == 0) {
return;
}

log.error("BUG! %s for %s leaked with %s partitioned splits. Cleaning up so server can continue to function.",
getClass().getName(),
taskId,
leakedSplits);
if (reportAsLeaked) {
log.error("BUG! %s for %s leaked with %s partitioned splits (weight: %s). Cleaning up so server can continue to function.",
getClass().getName(),
taskId,
leakedCount,
leakedWeight);
}

nodeTotalPartitionedSplitCount.addAndGet(-leakedSplits);
nodeTotalPartitionedSplitCount.addAndGet(-leakedCount);
nodeTotalPartitionedSplitWeight.addAndGet(-leakedWeight);
}

public void cleanup()
{
clearLocalSplitInfo(true);
}

@Override
Expand All @@ -166,23 +193,24 @@ public String toString()
return toStringHelper(this)
.add("taskId", taskId)
.add("splits", localPartitionedSplitCount)
.add("weight", localPartitionedSplitWeight)
.toString();
}
}
}

public static class PartitionedSplitCountTracker
{
private final IntConsumer splitSetter;
private final Consumer<PartitionedSplitsInfo> splitSetter;

public PartitionedSplitCountTracker(IntConsumer splitSetter)
public PartitionedSplitCountTracker(Consumer<PartitionedSplitsInfo> splitSetter)
{
this.splitSetter = requireNonNull(splitSetter, "splitSetter is null");
}

public void setPartitionedSplitCount(int partitionedSplitCount)
public void setPartitionedSplits(PartitionedSplitsInfo partitionedSplits)
{
splitSetter.accept(partitionedSplitCount);
splitSetter.accept(partitionedSplits);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.execution;

import static com.google.common.base.MoreObjects.toStringHelper;

public final class PartitionedSplitsInfo
{
private static final PartitionedSplitsInfo NO_SPLITS_INFO = new PartitionedSplitsInfo(0, 0);

private final int count;
private final long weightSum;

private PartitionedSplitsInfo(int splitCount, long splitsWeightSum)
{
this.count = splitCount;
this.weightSum = splitsWeightSum;
}

public int getCount()
{
return count;
}

public long getWeightSum()
{
return weightSum;
}

@Override
public int hashCode()
{
return (count * 31) + Long.hashCode(weightSum);
Comment thread
pettyjamesm marked this conversation as resolved.
Outdated
}

@Override
public boolean equals(Object other)
{
if (!(other instanceof PartitionedSplitsInfo)) {
return false;
}
PartitionedSplitsInfo otherInfo = (PartitionedSplitsInfo) other;
return this == otherInfo || (this.count == otherInfo.count && this.weightSum == otherInfo.weightSum);
}

@Override
public String toString()
{
return toStringHelper(this)
.add("count", count)
.add("weightSum", weightSum)
.toString();
}

public static PartitionedSplitsInfo forSplitCountAndWeightSum(int splitCount, long weightSum)
{
// Avoid allocating for the "no splits" case, also mask potential race condition between
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the race condition because we track count and weight separately? We should consider wrapping them in a holder object that gets updated atomically -- or just make the update to those fields atomic via synchronized. Reasoning about correctness under race conditions is hard.

Copy link
Copy Markdown
Member Author

@pettyjamesm pettyjamesm Oct 7, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The potential for a partially inconsistent view on this mostly comes from the worker side PipelineContext which increments / decrements AtomicLongs separately and also traverses non-snapshotable state as part of PipelineContext#getPipelineStatus. In that scenario, the race is basically unavoidable because you can't snapshot the whole pipeline state to get a consistent view without a huge locked region and significant performance risk.

In this case, we know that minor data races are possible but essentially benign. If what you observe in that snapshot is count=0, weight > 0, you know that with a tiny change to the timing you could also have observed either count=0, weight=0 or count=n, weight > 0). In that situation, choosing to discard the weight seemed preferable to synthesizing a fake split count to simulate the latter scenario.

Incidentally, the global NodeTaskMap can have similar inconsistent snapshot states which you could address by synchronizing updates, but that would also create a potentially concerning lock contention bottleneck since every task update and split assignment operation would need to serialize on that lock.

// count and weight updates that might yield a positive weight with a count of 0
return splitCount == 0 ? NO_SPLITS_INFO : new PartitionedSplitsInfo(splitCount, weightSum);
}

public static PartitionedSplitsInfo forZeroSplits()
{
return NO_SPLITS_INFO;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -55,15 +55,15 @@ public interface RemoteTask
*/
void addFinalTaskInfoListener(StateChangeListener<TaskInfo> stateChangeListener);

ListenableFuture<Void> whenSplitQueueHasSpace(int threshold);
ListenableFuture<Void> whenSplitQueueHasSpace(long weightThreshold);

void cancel();

void abort();

int getPartitionedSplitCount();
PartitionedSplitsInfo getPartitionedSplitsInfo();

int getQueuedPartitionedSplitCount();
PartitionedSplitsInfo getQueuedPartitionedSplitsInfo();

int getUnacknowledgedPartitionedSplitCount();
}
10 changes: 9 additions & 1 deletion core/trino-main/src/main/java/io/trino/execution/SqlTask.java
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,9 @@ private TaskStatus createTaskStatus(TaskHolder taskHolder)
}

int queuedPartitionedDrivers = 0;
long queuedPartitionedSplitsWeight = 0L;
int runningPartitionedDrivers = 0;
long runningPartitionedSplitsWeight = 0L;
DataSize physicalWrittenDataSize = DataSize.ofBytes(0);
DataSize userMemoryReservation = DataSize.ofBytes(0);
DataSize systemMemoryReservation = DataSize.ofBytes(0);
Expand All @@ -294,7 +296,9 @@ private TaskStatus createTaskStatus(TaskHolder taskHolder)
TaskInfo taskInfo = taskHolder.getFinalTaskInfo();
TaskStats taskStats = taskInfo.getStats();
queuedPartitionedDrivers = taskStats.getQueuedPartitionedDrivers();
queuedPartitionedSplitsWeight = taskStats.getQueuedPartitionedSplitsWeight();
runningPartitionedDrivers = taskStats.getRunningPartitionedDrivers();
runningPartitionedSplitsWeight = taskStats.getRunningPartitionedSplitsWeight();
physicalWrittenDataSize = taskStats.getPhysicalWrittenDataSize();
userMemoryReservation = taskStats.getUserMemoryReservation();
systemMemoryReservation = taskStats.getSystemMemoryReservation();
Expand All @@ -308,7 +312,9 @@ else if (taskHolder.getTaskExecution() != null) {
for (PipelineContext pipelineContext : taskContext.getPipelineContexts()) {
PipelineStatus pipelineStatus = pipelineContext.getPipelineStatus();
queuedPartitionedDrivers += pipelineStatus.getQueuedPartitionedDrivers();
queuedPartitionedSplitsWeight += pipelineStatus.getQueuedPartitionedSplitsWeight();
runningPartitionedDrivers += pipelineStatus.getRunningPartitionedDrivers();
runningPartitionedSplitsWeight += pipelineStatus.getRunningPartitionedSplitsWeight();
physicalWrittenBytes += pipelineContext.getPhysicalWrittenDataSize();
}
physicalWrittenDataSize = succinctBytes(physicalWrittenBytes);
Expand Down Expand Up @@ -338,7 +344,9 @@ else if (taskHolder.getTaskExecution() != null) {
revocableMemoryReservation,
fullGcCount,
fullGcTime,
dynamicFiltersVersion);
dynamicFiltersVersion,
queuedPartitionedSplitsWeight,
runningPartitionedSplitsWeight);
}

private TaskStats getTaskStats(TaskHolder taskHolder)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import io.trino.operator.PipelineExecutionStrategy;
import io.trino.operator.StageExecutionDescriptor;
import io.trino.operator.TaskContext;
import io.trino.spi.SplitWeight;
import io.trino.sql.planner.LocalExecutionPlanner.LocalExecutionPlan;
import io.trino.sql.planner.plan.PlanNodeId;

Expand Down Expand Up @@ -371,7 +372,7 @@ private void mergeIntoPendingSplits(PlanNodeId planNodeId, Set<ScheduledSplit> s
DriverSplitRunnerFactory partitionedDriverFactory = driverRunnerFactoriesWithSplitLifeCycle.get(planNodeId);
PendingSplitsForPlanNode pendingSplitsForPlanNode = pendingSplitsByPlanNode.get(planNodeId);

partitionedDriverFactory.splitsAdded(scheduledSplits.size());
partitionedDriverFactory.splitsAdded(scheduledSplits.size(), SplitWeight.rawValueSum(scheduledSplits, scheduledSplit -> scheduledSplit.getSplit().getSplitWeight()));
for (ScheduledSplit scheduledSplit : scheduledSplits) {
Lifespan lifespan = scheduledSplit.getSplit().getLifespan();
checkLifespan(partitionedDriverFactory.getPipelineExecutionStrategy(), lifespan);
Expand Down Expand Up @@ -933,7 +934,8 @@ public DriverSplitRunner createDriverRunner(@Nullable ScheduledSplit partitioned
status.incrementPendingCreation(pipelineContext.getPipelineId(), lifespan);
// create driver context immediately so the driver existence is recorded in the stats
// the number of drivers is used to balance work across nodes
DriverContext driverContext = pipelineContext.addDriverContext(lifespan);
long splitWeight = partitionedSplit == null ? 0 : partitionedSplit.getSplit().getSplitWeight().getRawValue();
DriverContext driverContext = pipelineContext.addDriverContext(lifespan, splitWeight);
return new DriverSplitRunner(this, driverContext, partitionedSplit, lifespan);
}

Expand Down Expand Up @@ -1003,9 +1005,9 @@ public OptionalInt getDriverInstances()
return driverFactory.getDriverInstances();
}

public void splitsAdded(int count)
public void splitsAdded(int count, long weightSum)
{
pipelineContext.splitsAdded(count);
pipelineContext.splitsAdded(count, weightSum);
}
}

Expand Down
Loading