Closed

Commits (31 total; changes shown from 21 commits)
4cd559f
compute checksum for shuffle
JiexingLi Mar 10, 2025
53d11af
faster checksum
JiexingLi Mar 12, 2025
c7675b1
default on
JiexingLi Mar 14, 2025
7b89c44
fix compile error
JiexingLi Mar 14, 2025
64dd36b
add contructor
JiexingLi Mar 14, 2025
422e370
address comments
JiexingLi Apr 21, 2025
89901ca
move config
JiexingLi Apr 21, 2025
a1c50fa
address comments
JiexingLi Apr 23, 2025
db59634
add license headers
JiexingLi Apr 24, 2025
04e08eb
address comments
ivoson Aug 26, 2025
c9c28e6
address comments
ivoson Aug 26, 2025
d82bad2
Merge branch 'master' into shuffle-checksum
ivoson Aug 26, 2025
2575d52
address comments
ivoson Sep 1, 2025
3b99edb
fix code stype issue
ivoson Sep 2, 2025
74266a5
debug flaky ut
ivoson Sep 3, 2025
22c79c8
Revert "debug flaky ut"
ivoson Sep 3, 2025
df48158
to resolve conclits
ivoson Sep 3, 2025
cf28940
Merge branch 'apache:master' into shuffle-checksum
ivoson Sep 3, 2025
4cfaac8
fix ut
ivoson Sep 3, 2025
786fdd3
address comments
ivoson Sep 4, 2025
602729c
address comments
ivoson Sep 4, 2025
137f254
Update core/src/main/scala/org/apache/spark/MapOutputTracker.scala
ivoson Sep 5, 2025
dde16d4
Update core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSui…
ivoson Sep 5, 2025
f7d9dfa
Update core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleW…
ivoson Sep 5, 2025
5aabe70
Update core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeS…
ivoson Sep 5, 2025
2fd0a94
Update core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffle…
ivoson Sep 5, 2025
bbe26bf
address comments
ivoson Sep 5, 2025
1a8e9f7
Update core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSui…
cloud-fan Sep 5, 2025
97af717
fix ut
ivoson Sep 5, 2025
5e01c52
address comments
ivoson Sep 8, 2025
ce29311
fix mima test failure
ivoson Sep 8, 2025
@@ -0,0 +1,80 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.shuffle.checksum

import scala.util.control.NonFatal

import org.apache.spark.internal.Logging

/**
 * A class for computing a checksum over input (key, value) pairs. The checksum is independent
 * of the order of the input (key, value) pairs: a checksum is first computed for each row,
 * then the XOR and SUM of all row checksums are combined and mixed into the final checksum.
 */
abstract class RowBasedChecksum() extends Serializable with Logging {
  private val ROTATE_POSITIONS = 27
  private var hasError: Boolean = false
  private var checksumXor: Long = 0
  private var checksumSum: Long = 0

  /**
   * Returns the checksum value. It returns the default checksum value (0) if any errors
   * were encountered during the checksum computation.
   */
  def getValue: Long = {
    if (!hasError) {
      // Rotate `checksumSum` before combining, so the two values are transformed into a
      // single, strong composite checksum whose bit patterns are thoroughly mixed.
      checksumXor ^ rotateLeft(checksumSum)
    } else {
      0
    }
  }

  /** Updates the row-based checksum with the given (key, value) pair. Not thread safe. */
  def update(key: Any, value: Any): Unit = {
    if (!hasError) {
      try {
        val rowChecksumValue = calculateRowChecksum(key, value)
        checksumXor = checksumXor ^ rowChecksumValue
        checksumSum += rowChecksumValue
      } catch {
        case NonFatal(e) =>
          logError("Checksum computation encountered error: ", e)
          hasError = true
      }
    }
  }

  /** Computes and returns the checksum value for the given (key, value) pair. */
  protected def calculateRowChecksum(key: Any, value: Any): Long

  // Rotates the value left by `ROTATE_POSITIONS` bit positions.
  private def rotateLeft(value: Long): Long = {
    (value << ROTATE_POSITIONS) | (value >>> (64 - ROTATE_POSITIONS))
  }
}

object RowBasedChecksum {
  def getAggregatedChecksumValue(rowBasedChecksums: Array[RowBasedChecksum]): Long = {
    Option(rowBasedChecksums)
      .map(_.foldLeft(0L)((acc, c) => acc * 31L + c.getValue))
      .getOrElse(0L)
  }
}
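For reference, here is a minimal sketch of how a concrete checksum and the per-writer aggregation fit together. `HashRowBasedChecksum` and `RowBasedChecksumDemo` are hypothetical names used only for illustration and are not part of this PR; the subclasses the shuffle writers actually install may derive row checksums differently.

```scala
package org.apache.spark.shuffle.checksum

// Hypothetical subclass for illustration: derives a row checksum from the hash
// codes of the key and the value. Real implementations may differ.
class HashRowBasedChecksum extends RowBasedChecksum {
  override protected def calculateRowChecksum(key: Any, value: Any): Long = {
    val k = if (key == null) 0L else key.hashCode().toLong
    val v = if (value == null) 0L else value.hashCode().toLong
    k * 31L + v
  }
}

object RowBasedChecksumDemo {
  def main(args: Array[String]): Unit = {
    // XOR and SUM are both commutative, so feeding the same rows in a
    // different order yields the same per-partition checksum.
    val a = new HashRowBasedChecksum
    Seq(("k1", 1), ("k2", 2), ("k3", 3)).foreach { case (k, v) => a.update(k, v) }

    val b = new HashRowBasedChecksum
    Seq(("k3", 3), ("k1", 1), ("k2", 2)).foreach { case (k, v) => b.update(k, v) }
    assert(a.getValue == b.getValue)

    // Aggregation across a writer's partitions: the fold multiplies the
    // accumulator by 31 per element, so here the partition order does matter.
    val perPartition: Array[RowBasedChecksum] = Array(a, b)
    println(RowBasedChecksum.getAggregatedChecksumValue(perPartition))
  }
}
```

The order independence within a partition means that spill or merge order cannot change a map task's per-partition checksum, while the aggregation step still distinguishes which partition produced which value.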
@@ -32,6 +32,7 @@
import scala.Tuple2;
import scala.collection.Iterator;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.io.Closeables;

import org.apache.spark.internal.SparkLogger;
@@ -53,6 +54,7 @@
import org.apache.spark.scheduler.MapStatus$;
import org.apache.spark.serializer.Serializer;
import org.apache.spark.serializer.SerializerInstance;
import org.apache.spark.shuffle.checksum.RowBasedChecksum;
import org.apache.spark.shuffle.ShuffleWriteMetricsReporter;
import org.apache.spark.shuffle.ShuffleWriter;
import org.apache.spark.storage.*;
@@ -104,6 +106,13 @@ final class BypassMergeSortShuffleWriter<K, V>
private long[] partitionLengths;
/** Checksum calculator for each partition. Empty when shuffle checksum disabled. */
private final Checksum[] partitionChecksums;
/**
* Row-based checksum calculator for each partition. Unlike the Checksum above,
* RowBasedChecksum is independent of the input row order; it is used to detect
* whether different task attempts of the same partition produce different
* output data.
*/
private final RowBasedChecksum[] rowBasedChecksums;

/**
* Are we in the process of stopping? Because map tasks can call stop() with success = true
@@ -132,6 +141,7 @@ final class BypassMergeSortShuffleWriter<K, V>
this.serializer = dep.serializer();
this.shuffleExecutorComponents = shuffleExecutorComponents;
this.partitionChecksums = createPartitionChecksums(numPartitions, conf);
this.rowBasedChecksums = dep.rowBasedChecksums();
}

@Override
@@ -144,7 +154,7 @@ public void write(Iterator<Product2<K, V>> records) throws IOException {
partitionLengths = mapOutputWriter.commitAllPartitions(
ShuffleChecksumHelper.EMPTY_CHECKSUM_VALUE).getPartitionLengths();
mapStatus = MapStatus$.MODULE$.apply(
blockManager.shuffleServerId(), partitionLengths, mapId);
blockManager.shuffleServerId(), partitionLengths, mapId, getAggregatedChecksumValue());
return;
}
final SerializerInstance serInstance = serializer.newInstance();
@@ -171,7 +181,11 @@ public void write(Iterator<Product2<K, V>> records) throws IOException {
while (records.hasNext()) {
final Product2<K, V> record = records.next();
final K key = record._1();
partitionWriters[partitioner.getPartition(key)].write(key, record._2());
final int partitionId = partitioner.getPartition(key);
partitionWriters[partitionId].write(key, record._2());
if (rowBasedChecksums.length > 0) {
rowBasedChecksums[partitionId].update(key, record._2());
}
}

for (int i = 0; i < numPartitions; i++) {
Expand All @@ -182,7 +196,7 @@ public void write(Iterator<Product2<K, V>> records) throws IOException {

partitionLengths = writePartitionedData(mapOutputWriter);
mapStatus = MapStatus$.MODULE$.apply(
blockManager.shuffleServerId(), partitionLengths, mapId);
blockManager.shuffleServerId(), partitionLengths, mapId, getAggregatedChecksumValue());
} catch (Exception e) {
try {
mapOutputWriter.abort(e);
Expand All @@ -199,6 +213,17 @@ public long[] getPartitionLengths() {
return partitionLengths;
}

// For test only.
@VisibleForTesting
RowBasedChecksum[] getRowBasedChecksums() {
return rowBasedChecksums;
}

@VisibleForTesting
long getAggregatedChecksumValue() {
return RowBasedChecksum.getAggregatedChecksumValue(rowBasedChecksums);
}

/**
* Concatenate all of the per-partition files into a single combined file.
*
@@ -59,9 +59,11 @@
import org.apache.spark.shuffle.api.ShufflePartitionWriter;
import org.apache.spark.shuffle.api.SingleSpillShuffleMapOutputWriter;
import org.apache.spark.shuffle.api.WritableByteChannelWrapper;
import org.apache.spark.shuffle.checksum.RowBasedChecksum;
import org.apache.spark.storage.BlockManager;
import org.apache.spark.storage.TimeTrackingOutputStream;
import org.apache.spark.unsafe.Platform;
import org.apache.spark.util.ExposedBufferByteArrayOutputStream;
import org.apache.spark.util.Utils;

@Private
@@ -93,15 +95,16 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
@Nullable private long[] partitionLengths;
private long peakMemoryUsedBytes = 0;

/** Subclass of ByteArrayOutputStream that exposes `buf` directly. */
private static final class MyByteArrayOutputStream extends ByteArrayOutputStream {
MyByteArrayOutputStream(int size) { super(size); }
public byte[] getBuf() { return buf; }
}

private MyByteArrayOutputStream serBuffer;
private ExposedBufferByteArrayOutputStream serBuffer;
private SerializationStream serOutputStream;

/**
* Row-based checksum calculator for each partition. RowBasedChecksum is independent
* of the input row order; it is used to detect whether different task attempts
* of the same partition produce different output data.
*/
private final RowBasedChecksum[] rowBasedChecksums;

/**
* Are we in the process of stopping? Because map tasks can call stop() with success = true
* and then call stop() with success = false if they get an exception, we want to make sure
@@ -141,6 +144,7 @@ public UnsafeShuffleWriter(
(int) (long) sparkConf.get(package$.MODULE$.SHUFFLE_SORT_INIT_BUFFER_SIZE());
this.mergeBufferSizeInBytes =
(int) (long) sparkConf.get(package$.MODULE$.SHUFFLE_FILE_MERGE_BUFFER_SIZE()) * 1024;
this.rowBasedChecksums = dep.rowBasedChecksums();
open();
}

@@ -162,6 +166,17 @@ public long getPeakMemoryUsedBytes() {
return peakMemoryUsedBytes;
}

// For test only.
@VisibleForTesting
RowBasedChecksum[] getRowBasedChecksums() {
return rowBasedChecksums;
}

@VisibleForTesting
long getAggregatedChecksumValue() {
return RowBasedChecksum.getAggregatedChecksumValue(rowBasedChecksums);
}

/**
* This convenience method should only be called in test code.
*/
@@ -210,7 +225,7 @@ private void open() throws SparkException {
partitioner.numPartitions(),
sparkConf,
writeMetrics);
serBuffer = new MyByteArrayOutputStream(DEFAULT_INITIAL_SER_BUFFER_SIZE);
serBuffer = new ExposedBufferByteArrayOutputStream(DEFAULT_INITIAL_SER_BUFFER_SIZE);
serOutputStream = serializer.serializeStream(serBuffer);
}

@@ -233,7 +248,7 @@ void closeAndWriteOutput() throws IOException {
}
}
mapStatus = MapStatus$.MODULE$.apply(
blockManager.shuffleServerId(), partitionLengths, mapId);
blockManager.shuffleServerId(), partitionLengths, mapId, getAggregatedChecksumValue());
}

@VisibleForTesting
@@ -251,6 +266,9 @@ void insertRecordIntoSorter(Product2<K, V> record) throws IOException {

sorter.insertRecord(
serBuffer.getBuf(), Platform.BYTE_ARRAY_OFFSET, serializedRecordSize, partitionId);
if (rowBasedChecksums.length > 0) {
rowBasedChecksums[partitionId].update(key, record._2());
}
}

@VisibleForTesting
@@ -0,0 +1,26 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.util;

import java.io.ByteArrayOutputStream;

/** Subclass of ByteArrayOutputStream that exposes `buf` directly. */
public final class ExposedBufferByteArrayOutputStream extends ByteArrayOutputStream {
public ExposedBufferByteArrayOutputStream(int size) { super(size); }
public byte[] getBuf() { return buf; }
}
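A note on the design choice: `ByteArrayOutputStream.toByteArray()` returns a fresh copy of the internal buffer on every call, whereas `getBuf()` exposes the backing array itself, so a caller can hand the array plus the valid length (`size()`) onward without an extra allocation, which is how UnsafeShuffleWriter feeds the serialized record to its sorter. The snippet below is a hypothetical usage sketch; the `ExposedBufferDemo` object is not part of the PR.

```scala
import org.apache.spark.util.ExposedBufferByteArrayOutputStream

object ExposedBufferDemo {
  def main(args: Array[String]): Unit = {
    val out = new ExposedBufferByteArrayOutputStream(64)
    val bytes = "hello".getBytes("UTF-8")
    out.write(bytes, 0, bytes.length)

    // The backing array may be larger than the data written; only the first
    // size() bytes are meaningful.
    val backing = out.getBuf
    val validLength = out.size()
    assert(new String(backing, 0, validLength, "UTF-8") == "hello")
  }
}
```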
28 changes: 27 additions & 1 deletion core/src/main/scala/org/apache/spark/Dependency.scala
@@ -29,6 +29,7 @@ import org.apache.spark.internal.LogKeys._
import org.apache.spark.rdd.RDD
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.{ShuffleHandle, ShuffleWriteProcessor}
import org.apache.spark.shuffle.checksum.RowBasedChecksum
import org.apache.spark.storage.BlockManagerId
import org.apache.spark.util.Utils

@@ -59,6 +60,9 @@ abstract class NarrowDependency[T](_rdd: RDD[T]) extends Dependency[T] {
override def rdd: RDD[T] = _rdd
}

object ShuffleDependency {
private val EMPTY_ROW_BASED_CHECKSUMS: Array[RowBasedChecksum] = Array.empty
Review comment (Contributor):
nit: We can make this private[spark] and use it in other places within this PR

Reply (Contributor):
done

}

/**
* :: DeveloperApi ::
@@ -74,6 +78,7 @@ abstract class NarrowDependency[T](_rdd: RDD[T]) extends Dependency[T] {
* @param aggregator map/reduce-side aggregator for RDD's shuffle
* @param mapSideCombine whether to perform partial aggregation (also known as map-side combine)
* @param shuffleWriterProcessor the processor to control the write behavior in ShuffleMapTask
* @param rowBasedChecksums the row-based checksums for each shuffle partition
*/
@DeveloperApi
class ShuffleDependency[K: ClassTag, V: ClassTag, C: ClassTag](
@@ -83,9 +88,30 @@ class ShuffleDependency[K: ClassTag, V: ClassTag, C: ClassTag](
val keyOrdering: Option[Ordering[K]] = None,
val aggregator: Option[Aggregator[K, V, C]] = None,
val mapSideCombine: Boolean = false,
val shuffleWriterProcessor: ShuffleWriteProcessor = new ShuffleWriteProcessor)
val shuffleWriterProcessor: ShuffleWriteProcessor = new ShuffleWriteProcessor,
val rowBasedChecksums: Array[RowBasedChecksum] = ShuffleDependency.EMPTY_ROW_BASED_CHECKSUMS)
extends Dependency[Product2[K, V]] with Logging {

def this(
rdd: RDD[_ <: Product2[K, V]],
partitioner: Partitioner,
serializer: Serializer,
keyOrdering: Option[Ordering[K]],
aggregator: Option[Aggregator[K, V, C]],
mapSideCombine: Boolean,
shuffleWriterProcessor: ShuffleWriteProcessor) = {
this(
rdd,
partitioner,
serializer,
keyOrdering,
aggregator,
mapSideCombine,
shuffleWriterProcessor,
ShuffleDependency.EMPTY_ROW_BASED_CHECKSUMS
)
}

if (mapSideCombine) {
require(aggregator.isDefined, "Map-side combine without Aggregator specified!")
}
17 changes: 16 additions & 1 deletion core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -23,7 +23,7 @@ import java.util.concurrent.{ConcurrentHashMap, LinkedBlockingQueue, ThreadPoolE
import java.util.concurrent.locks.ReentrantReadWriteLock

import scala.collection
import scala.collection.mutable.{HashMap, ListBuffer, Map}
import scala.collection.mutable.{HashMap, ListBuffer, Map, Set}
import scala.concurrent.{ExecutionContext, Future}
import scala.concurrent.duration.Duration
import scala.jdk.CollectionConverters._
@@ -99,6 +99,12 @@ private class ShuffleStatus(
*/
val mapStatusesDeleted = new Array[MapStatus](numPartitions)

/**
* Keep the indices of the Map tasks whose checksums are different across retries.
* Exposed for testing.
*/
private[spark] val checksumMismatchIndices: Set[Int] = Set()

/**
* MergeStatus for each shuffle partition when push-based shuffle is enabled. The index of the
* array is the shuffle partition id (reduce id). Each value in the array is the MergeStatus for
@@ -169,6 +175,15 @@ private class ShuffleStatus(
} else {
mapIdToMapIndex.remove(currentMapStatus.mapId)
}
logDebug(s"Checksum of map output for task ${status.mapId} is ${status.checksumValue}")

val preStatus =
if (mapStatuses(mapIndex) != null) mapStatuses(mapIndex) else mapStatusesDeleted(mapIndex)
if (preStatus != null && preStatus.checksumValue != status.checksumValue) {
logInfo(s"Checksum of map output changes from ${preStatus.checksumValue} to " +
s"${status.checksumValue} for task ${status.mapId}.")
checksumMismatchIndices.add(mapIndex)
}
Review comment (Contributor):
There are three main cases here:

  • task reattempt due to stage reattempt after downstream stages have consumed output.
  • task reattempt due to stage reattempt before downstream stages have consumed output (missing partitions detected during stage attempt completion).
  • speculative tasks.

For the latter two, we don't need to track it in checksumMismatchIndices.

Reply (Contributor, Author):
Yes, for case 1 we need to track the mismatches. checksumMismatchIndices will be used (in the next PR) to roll back the downstream stages if we detect checksum mismatches in their upstream stages.

For case 2, if downstream stages have not consumed the output, they have not started yet. In that case the rollback is a no-op, and it doesn't hurt to record the mismatches here.

For case 3, I think we do need to record the mismatches. Consider a situation where all partitions of a stage have finished while some speculative tasks are still running. Since all outputs have been produced, the downstream stage can start and read the data. Later, some speculative tasks finish, and the new mapStatus overrides the old mapStatus with a new data location. In the downstream stage, the not-yet-started or retried tasks would read the new data, while the finished and running tasks would read the old data, resulting in inconsistency.

Reply (Contributor):
> For case 2, if downstream stages have not consumed the output, they have not started yet. In that case the rollback is a no-op, and it doesn't hurt to record the mismatches here.

It is unclear how checksumMismatchIndices will be used; perhaps it is fine to record it, but my question would be why record it at all? Is it due to the complexity of detecting case (2)?

> For case 3, I think we do need to record the mismatches. Consider a situation where all partitions of a stage have finished while some speculative tasks are still running. Since all outputs have been produced, the downstream stage can start and read the data. Later, some speculative tasks finish, and the new mapStatus overrides the old mapStatus with a new data location. In the downstream stage, the not-yet-started or retried tasks would read the new data, while the finished and running tasks would read the old data, resulting in inconsistency.

That is fair, this is indeed possible.

mapStatuses(mapIndex) = status
mapIdToMapIndex(status.mapId) = mapIndex
}
Expand Down