Closed

Changes from 9 commits (31 commits total)
4cd559f  compute checksum for shuffle (JiexingLi, Mar 10, 2025)
53d11af  faster checksum (JiexingLi, Mar 12, 2025)
c7675b1  default on (JiexingLi, Mar 14, 2025)
7b89c44  fix compile error (JiexingLi, Mar 14, 2025)
64dd36b  add contructor (JiexingLi, Mar 14, 2025)
422e370  address comments (JiexingLi, Apr 21, 2025)
89901ca  move config (JiexingLi, Apr 21, 2025)
a1c50fa  address comments (JiexingLi, Apr 23, 2025)
db59634  add license headers (JiexingLi, Apr 24, 2025)
04e08eb  address comments (ivoson, Aug 26, 2025)
c9c28e6  address comments (ivoson, Aug 26, 2025)
d82bad2  Merge branch 'master' into shuffle-checksum (ivoson, Aug 26, 2025)
2575d52  address comments (ivoson, Sep 1, 2025)
3b99edb  fix code stype issue (ivoson, Sep 2, 2025)
74266a5  debug flaky ut (ivoson, Sep 3, 2025)
22c79c8  Revert "debug flaky ut" (ivoson, Sep 3, 2025)
df48158  to resolve conclits (ivoson, Sep 3, 2025)
cf28940  Merge branch 'apache:master' into shuffle-checksum (ivoson, Sep 3, 2025)
4cfaac8  fix ut (ivoson, Sep 3, 2025)
786fdd3  address comments (ivoson, Sep 4, 2025)
602729c  address comments (ivoson, Sep 4, 2025)
137f254  Update core/src/main/scala/org/apache/spark/MapOutputTracker.scala (ivoson, Sep 5, 2025)
dde16d4  Update core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSui… (ivoson, Sep 5, 2025)
f7d9dfa  Update core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleW… (ivoson, Sep 5, 2025)
5aabe70  Update core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeS… (ivoson, Sep 5, 2025)
2fd0a94  Update core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffle… (ivoson, Sep 5, 2025)
bbe26bf  address comments (ivoson, Sep 5, 2025)
1a8e9f7  Update core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSui… (cloud-fan, Sep 5, 2025)
97af717  fix ut (ivoson, Sep 5, 2025)
5e01c52  address comments (ivoson, Sep 8, 2025)
ce29311  fix mima test failure (ivoson, Sep 8, 2025)
@@ -0,0 +1,114 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.shuffle.checksum

import java.io.ObjectOutputStream
import java.util.zip.Checksum

import scala.util.control.NonFatal

import org.apache.spark.internal.Logging
import org.apache.spark.network.shuffle.checksum.ShuffleChecksumHelper
import org.apache.spark.util.MyByteArrayOutputStream

/**
* A class for computing checksum for input (key, value) pairs. The checksum is independent of
* the order of the input (key, value) pairs. It is done by computing a checksum for each row
* first, and then computing the XOR for all the row checksums.
Contributor: Let's update the classdoc. We now also leverage the sum to handle duplicated values better.

Contributor: updated

*/
abstract class RowBasedChecksum() extends Serializable with Logging {
private var hasError: Boolean = false
private var checksumValue: Long = 0
/** Returns the checksum value computed. It returns the default checksum value (0) if there
* are any errors encountered during the checksum computation.
*/
def getValue: Long = {
if (!hasError) checksumValue else 0
}

/** Updates the row-based checksum with the given (key, value) pair */
def update(key: Any, value: Any): Unit = {
if (!hasError) {
try {
val rowChecksumValue = calculateRowChecksum(key, value)
checksumValue = checksumValue ^ rowChecksumValue
peter-toth (Contributor), May 1, 2025: XOR has problems when the same (key, value) pair is used multiple times. Should we track the number of pairs as well?

attilapiros (Contributor), May 1, 2025: I think this is a good point but hard to compute as this will be a bit more stateful.

attilapiros (Contributor), May 1, 2025: What about, in addition to the bitwise XOR (currently checksumValue), also calculating a SUM, and when getValue is called, combining those two into one number with an extra XOR (or just adding them together after multiplying one by a prime number)?

peter-toth (Contributor), May 2, 2025: Yeah, that's what I was referring to. But combining the number of pairs (count, not the sum) into the final checksum should be fine.
Update: No, combining just the count of pairs into the final checksum still has problems with duplicates.

Contributor: Perhaps something like this might work? https://en.wikipedia.org/wiki/Fowler%E2%80%93Noll%E2%80%93Vo_hash_function

It will be more expensive than XOR, but should handle order and duplication.

Contributor: @mridulm I think what we need is order insensitivity (within one partition the order of rows should not matter); FNV, as far as I can see, is sensitive to the order.

Contributor: You are right, order sensitivity matters at the beginning of the task, not at the end!

Sum + XOR, or sum + XOR + multiplication with some XOR folding to generate the final hash, might be cheap. I can't think of other alternatives that would work well and yet be reasonably robust to duplication.

Contributor: Hi folks, I am working with Jiexing to follow up on the PR.

Do you think something like the below, combining both the sum and the XOR of the hash codes, would be helpful to address the concerns? cc @peter-toth @attilapiros @mridulm @cloud-fan

private var checksumValue: Long = 0
private var sum: Long = 0

def rotateLeft(value: Long, k: Int): Long = {
  (value << k) | (value >>> (64 - k))
}

def getValue: Long = {
  if (!hasError) {
    checksumValue ^ rotateLeft(sum, 27)
  } else {
    0
  }
}

def update(key: Any, value: Any): Unit = {
  ...
  val rowChecksumValue = calculateRowChecksum(key, value)
  checksumValue = checksumValue ^ rowChecksumValue
  sum += rowChecksumValue
  ...
}
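To make the duplicate concern above concrete, here is a minimal self-contained sketch (illustration only, not part of the PR; rowHash is a hypothetical stand-in for calculateRowChecksum). It shows that plain XOR cannot distinguish a pair seen once from the same pair seen three times, while the proposed XOR plus rotated-sum combination can:

object ChecksumCombineDemo {
  // Rotate-left on a 64-bit value; >>> avoids sign extension.
  private def rotateLeft(value: Long, k: Int): Long =
    (value << k) | (value >>> (64 - k))

  // Hypothetical per-row hash standing in for calculateRowChecksum.
  private def rowHash(key: Any, value: Any): Long = (key, value).hashCode().toLong

  // Returns (xor-only combination, xor + rotated-sum combination) over the rows.
  def combine(rows: Seq[(Any, Any)]): (Long, Long) = {
    var xor = 0L
    var sum = 0L
    rows.foreach { case (k, v) =>
      val h = rowHash(k, v)
      xor ^= h
      sum += h
    }
    (xor, xor ^ rotateLeft(sum, 27))
  }

  def main(args: Array[String]): Unit = {
    val once   = combine(Seq(("a", 1)))
    val thrice = combine(Seq(("a", 1), ("a", 1), ("a", 1)))
    println(once._1 == thrice._1) // true: with XOR alone, an odd number of repeats collapses to a single occurrence
    println(once._2 == thrice._2) // false here: the sum component exposes the duplicates
  }
}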

} catch {
case NonFatal(e) =>
logError("Checksum computation encountered error: ", e)
hasError = true
}
}
}

/** Computes and returns the checksum value for the given (key, value) pair */
protected def calculateRowChecksum(key: Any, value: Any): Long
}

/**
* A concrete implementation of RowBasedChecksum. The checksum for each row is
* computed by first converting the (key, value) pair to a byte array using OutputStreams,
* and then computing the checksum for the byte array.
* Note that this checksum computation is very expensive, and it is used only in tests
* in the core component. A much cheaper implementation of RowBasedChecksum is in
* UnsafeRowChecksum.
*
* @param checksumAlgorithm the algorithm used for computing checksum.
*/
class OutputStreamRowBasedChecksum(checksumAlgorithm: String)
Contributor: Just a question: I don't see this used anywhere except in tests, why not have it in core/src/test instead?

Contributor: moved to test package.

extends RowBasedChecksum() {

private val DEFAULT_INITIAL_SER_BUFFER_SIZE = 32 * 1024

@transient private lazy val serBuffer =
new MyByteArrayOutputStream(DEFAULT_INITIAL_SER_BUFFER_SIZE)
@transient private lazy val objOut = new ObjectOutputStream(serBuffer)

@transient
protected lazy val checksum: Checksum =
ShuffleChecksumHelper.getChecksumByAlgorithm(checksumAlgorithm)

override protected def calculateRowChecksum(key: Any, value: Any): Long = {
assert(checksum != null, "Checksum is null")

// Converts the (key, value) pair into byte array.
objOut.reset()
serBuffer.reset()
objOut.writeObject((key, value))
objOut.flush()
serBuffer.flush()

// Computes and returns the checksum for the byte array.
checksum.reset()
checksum.update(serBuffer.getBuf, 0, serBuffer.size())
checksum.getValue
}
}
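As a quick illustration of the order-independence property from the classdoc (not part of the diff; this assumes the class above is on the classpath and that "ADLER32" is one of the algorithm names accepted by ShuffleChecksumHelper.getChecksumByAlgorithm):

val a = new OutputStreamRowBasedChecksum("ADLER32")
a.update("k1", 1)
a.update("k2", 2)

val b = new OutputStreamRowBasedChecksum("ADLER32")
b.update("k2", 2) // same pairs, different insertion order
b.update("k1", 1)

assert(a.getValue == b.getValue) // XOR-combined row checksums are order-insensitive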

object RowBasedChecksum {
def createPartitionRowBasedChecksums(
Contributor: NIT: add a comment to say it is for testing only, or better, move it to a helper class used only in the tests.

Contributor: Maybe into ShuffleChecksumTestHelper? But its name suggests it is only for shuffle checksums. So what about an extra rename to ChecksumTestHelper?

Contributor: Oh, I see the comment above:

 Note that this checksum computation is very expensive, and it is used only in tests
 in the core component. A much cheaper implementation of RowBasedChecksum is in
 UnsafeRowChecksum.

And I can see your comment:

I can't use UnsafeRowChecksum.scala in the test because the test is in core, while the unsafe row is in sql. So I added OutputStreamRowBasedChecksum for the tests in core.

But you can move this class and object to the test code of the core module, can't you?

Contributor: I see you cannot move the whole object, so let's just move the method.

Contributor: Moved the method to ShuffleChecksumTestHelper; didn't rename the class, as currently all the newly added classes/components are in the shuffle package.

numPartitions: Int,
checksumAlgorithm: String): Array[RowBasedChecksum] = {
Array.tabulate(numPartitions)(_ => new OutputStreamRowBasedChecksum(checksumAlgorithm))
}

def getAggregatedChecksumValue(rowBasedChecksums: Array[RowBasedChecksum]): Long = {
Option(rowBasedChecksums)
.map(_.foldLeft(0L)((acc, c) => acc * 31L + c.getValue))
.getOrElse(0L)
}
}
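For reference, the aggregation above is a base-31 polynomial fold over the per-partition checksum values, so it is deliberately sensitive to which partition a value belongs to. A small sketch with made-up values (illustration only):

def aggregate(values: Array[Long]): Long =
  values.foldLeft(0L)((acc, v) => acc * 31L + v)

val perPartition = Array(7L, 11L, 13L)
val swapped      = Array(11L, 7L, 13L) // same values, partitions 0 and 1 exchanged
println(aggregate(perPartition) == aggregate(swapped)) // false: partition position affects the aggregate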
@@ -53,6 +53,7 @@
import org.apache.spark.scheduler.MapStatus$;
import org.apache.spark.serializer.Serializer;
import org.apache.spark.serializer.SerializerInstance;
import org.apache.spark.shuffle.checksum.RowBasedChecksum;
import org.apache.spark.shuffle.ShuffleWriteMetricsReporter;
import org.apache.spark.shuffle.ShuffleWriter;
import org.apache.spark.storage.*;
@@ -104,6 +105,14 @@ final class BypassMergeSortShuffleWriter<K, V>
private long[] partitionLengths;
/** Checksum calculator for each partition. Empty when shuffle checksum disabled. */
private final Checksum[] partitionChecksums;
/**
* Checksum calculator for each partition. Different from the above Checksum,
* RowBasedChecksum is independent of the input row order, which is used to
* detect whether different task attempts of the same partition produce different
* output data or not.
*/
private final RowBasedChecksum[] rowBasedChecksums;
private final SparkConf conf;
Contributor: This conf is not needed.

Contributor: removed


/**
* Are we in the process of stopping? Because map tasks can call stop() with success = true
@@ -132,6 +141,8 @@ final class BypassMergeSortShuffleWriter<K, V>
this.serializer = dep.serializer();
this.shuffleExecutorComponents = shuffleExecutorComponents;
this.partitionChecksums = createPartitionChecksums(numPartitions, conf);
this.rowBasedChecksums = dep.rowBasedChecksums();
this.conf = conf;
Contributor: Same here: this conf is not needed.

Contributor: removed

}

@Override
@@ -144,7 +155,7 @@ public void write(Iterator<Product2<K, V>> records) throws IOException {
partitionLengths = mapOutputWriter.commitAllPartitions(
ShuffleChecksumHelper.EMPTY_CHECKSUM_VALUE).getPartitionLengths();
mapStatus = MapStatus$.MODULE$.apply(
blockManager.shuffleServerId(), partitionLengths, mapId);
blockManager.shuffleServerId(), partitionLengths, mapId, getAggregatedChecksumValue());
return;
}
final SerializerInstance serInstance = serializer.newInstance();
@@ -171,7 +182,11 @@ public void write(Iterator<Product2<K, V>> records) throws IOException {
while (records.hasNext()) {
final Product2<K, V> record = records.next();
final K key = record._1();
partitionWriters[partitioner.getPartition(key)].write(key, record._2());
final int partitionId = partitioner.getPartition(key);
partitionWriters[partitionId].write(key, record._2());
if (rowBasedChecksums.length > 0) {
rowBasedChecksums[partitionId].update(key, record._2());
}
}

for (int i = 0; i < numPartitions; i++) {
@@ -182,7 +197,7 @@ public void write(Iterator<Product2<K, V>> records) throws IOException {

partitionLengths = writePartitionedData(mapOutputWriter);
mapStatus = MapStatus$.MODULE$.apply(
blockManager.shuffleServerId(), partitionLengths, mapId);
blockManager.shuffleServerId(), partitionLengths, mapId, getAggregatedChecksumValue());
} catch (Exception e) {
try {
mapOutputWriter.abort(e);
@@ -199,6 +214,14 @@ public long[] getPartitionLengths() {
return partitionLengths;
}

public RowBasedChecksum[] getRowBasedChecksums() {
Contributor: NIT: a comment to say it is for testing only.

Contributor: Done

return rowBasedChecksums;
}

public long getAggregatedChecksumValue() {
return RowBasedChecksum.getAggregatedChecksumValue(rowBasedChecksums);
}

/**
* Concatenate all of the per-partition files into a single combined file.
*
@@ -60,9 +60,11 @@
import org.apache.spark.shuffle.api.ShufflePartitionWriter;
import org.apache.spark.shuffle.api.SingleSpillShuffleMapOutputWriter;
import org.apache.spark.shuffle.api.WritableByteChannelWrapper;
import org.apache.spark.shuffle.checksum.RowBasedChecksum;
import org.apache.spark.storage.BlockManager;
import org.apache.spark.storage.TimeTrackingOutputStream;
import org.apache.spark.unsafe.Platform;
import org.apache.spark.util.MyByteArrayOutputStream;
import org.apache.spark.util.Utils;

@Private
@@ -94,15 +96,16 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
@Nullable private long[] partitionLengths;
private long peakMemoryUsedBytes = 0;

/** Subclass of ByteArrayOutputStream that exposes `buf` directly. */
private static final class MyByteArrayOutputStream extends ByteArrayOutputStream {
MyByteArrayOutputStream(int size) { super(size); }
public byte[] getBuf() { return buf; }
}

private MyByteArrayOutputStream serBuffer;
private SerializationStream serOutputStream;

/**
* RowBasedChecksum calculator for each partition. RowBasedChecksum is independent
* of the input row order, which is used to detect whether different task attempts
* of the same partition produce different output data or not.
*/
private final RowBasedChecksum[] rowBasedChecksums;

/**
* Are we in the process of stopping? Because map tasks can call stop() with success = true
* and then call stop() with success = false if they get an exception, we want to make sure
@@ -142,6 +145,7 @@ public UnsafeShuffleWriter(
(int) (long) sparkConf.get(package$.MODULE$.SHUFFLE_SORT_INIT_BUFFER_SIZE());
this.mergeBufferSizeInBytes =
(int) (long) sparkConf.get(package$.MODULE$.SHUFFLE_FILE_MERGE_BUFFER_SIZE()) * 1024;
this.rowBasedChecksums = dep.rowBasedChecksums();
open();
}

@@ -163,6 +167,13 @@ public long getPeakMemoryUsedBytes() {
return peakMemoryUsedBytes;
}

public RowBasedChecksum[] getRowBasedChecksums() {
return rowBasedChecksums;
}
public long getAggregatedChecksumValue() {
Contributor: Nit: missing new line to separate those methods.

Contributor: updated

return RowBasedChecksum.getAggregatedChecksumValue(rowBasedChecksums);
}

/**
* This convenience method should only be called in test code.
*/
@@ -234,7 +245,7 @@ void closeAndWriteOutput() throws IOException {
}
}
mapStatus = MapStatus$.MODULE$.apply(
blockManager.shuffleServerId(), partitionLengths, mapId);
blockManager.shuffleServerId(), partitionLengths, mapId, getAggregatedChecksumValue());
}

@VisibleForTesting
@@ -252,6 +263,9 @@ void insertRecordIntoSorter(Product2<K, V> record) throws IOException {

sorter.insertRecord(
serBuffer.getBuf(), Platform.BYTE_ARRAY_OFFSET, serializedRecordSize, partitionId);
if (rowBasedChecksums.length > 0) {
rowBasedChecksums[partitionId].update(key, record._2());
}
}

@VisibleForTesting
@@ -330,7 +344,8 @@ private long[] mergeSpillsUsingStandardWriter(SpillInfo[] spills) throws IOExcep
logger.debug("Using slow merge");
mergeSpillsWithFileStream(spills, mapWriter, compressionCodec);
}
partitionLengths = mapWriter.commitAllPartitions(sorter.getChecksums()).getPartitionLengths();
partitionLengths =
mapWriter.commitAllPartitions(sorter.getChecksums()).getPartitionLengths();
Contributor: Nit: Is this just an unnecessary change? If yes, please revert it!

Contributor: reverted

} catch (Exception e) {
try {
mapWriter.abort(e);
@@ -0,0 +1,26 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.util;

import java.io.ByteArrayOutputStream;

/** Subclass of ByteArrayOutputStream that exposes `buf` directly. */
public final class MyByteArrayOutputStream extends ByteArrayOutputStream {
Contributor: Rename? MyByteArrayOutputStream was fine when it was internal to the class. Something like ExposedBufferByteArrayOutputStream or some such?

Contributor: Sounds good. Updated

public MyByteArrayOutputStream(int size) { super(size); }
public byte[] getBuf() { return buf; }
}
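A hedged usage sketch (mirroring how OutputStreamRowBasedChecksum uses this class; note the review above suggests renaming it, e.g. to ExposedBufferByteArrayOutputStream): getBuf() exposes the live backing array, so callers can read the serialized bytes without the defensive copy that toByteArray() would make, and only the first size() bytes are valid:

import java.io.ObjectOutputStream
import org.apache.spark.util.MyByteArrayOutputStream

val serBuffer = new MyByteArrayOutputStream(32 * 1024)
val objOut = new ObjectOutputStream(serBuffer)
objOut.writeObject(("key", "value"))
objOut.flush()

val validLength = serBuffer.size()   // number of bytes actually written
val backingArray = serBuffer.getBuf  // no copy; the array may be longer than validLength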
28 changes: 27 additions & 1 deletion core/src/main/scala/org/apache/spark/Dependency.scala
@@ -29,6 +29,7 @@ import org.apache.spark.internal.LogKeys._
import org.apache.spark.rdd.RDD
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.{ShuffleHandle, ShuffleWriteProcessor}
import org.apache.spark.shuffle.checksum.RowBasedChecksum
import org.apache.spark.storage.BlockManagerId
import org.apache.spark.util.Utils

@@ -59,6 +60,9 @@ abstract class NarrowDependency[T](_rdd: RDD[T]) extends Dependency[T] {
override def rdd: RDD[T] = _rdd
}

object ShuffleDependency {
private val EmptyRowBasedChecksums: Array[RowBasedChecksum] = Array.empty
Contributor: Suggested change:
- private val EmptyRowBasedChecksums: Array[RowBasedChecksum] = Array.empty
+ private val EMPTY_ROW_BASED_CHECKSUMS: Array[RowBasedChecksum] = Array.empty

Contributor: Done

}

/**
* :: DeveloperApi ::
@@ -74,6 +78,7 @@ abstract class NarrowDependency[T](_rdd: RDD[T]) extends Dependency[T] {
* @param aggregator map/reduce-side aggregator for RDD's shuffle
* @param mapSideCombine whether to perform partial aggregation (also known as map-side combine)
* @param shuffleWriterProcessor the processor to control the write behavior in ShuffleMapTask
* @param rowBasedChecksums the row-based checksums for each shuffle partition
*/
@DeveloperApi
class ShuffleDependency[K: ClassTag, V: ClassTag, C: ClassTag](
@@ -83,9 +88,30 @@ class ShuffleDependency[K: ClassTag, V: ClassTag, C: ClassTag](
val keyOrdering: Option[Ordering[K]] = None,
val aggregator: Option[Aggregator[K, V, C]] = None,
val mapSideCombine: Boolean = false,
val shuffleWriterProcessor: ShuffleWriteProcessor = new ShuffleWriteProcessor)
val shuffleWriterProcessor: ShuffleWriteProcessor = new ShuffleWriteProcessor,
val rowBasedChecksums: Array[RowBasedChecksum] = ShuffleDependency.EmptyRowBasedChecksums)
extends Dependency[Product2[K, V]] with Logging {

def this(
rdd: RDD[_ <: Product2[K, V]],
partitioner: Partitioner,
serializer: Serializer,
keyOrdering: Option[Ordering[K]],
aggregator: Option[Aggregator[K, V, C]],
mapSideCombine: Boolean,
shuffleWriterProcessor: ShuffleWriteProcessor) = {
this(
rdd,
partitioner,
serializer,
keyOrdering,
aggregator,
mapSideCombine,
shuffleWriterProcessor,
Array.empty
Contributor: Suggested change:
- Array.empty
+ EMPTY_ROW_BASED_CHECKSUMS

Contributor: done

)
}

if (mapSideCombine) {
require(aggregator.isDefined, "Map-side combine without Aggregator specified!")
}
13 changes: 12 additions & 1 deletion core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -23,7 +23,7 @@ import java.util.concurrent.{ConcurrentHashMap, LinkedBlockingQueue, ThreadPoolE
import java.util.concurrent.locks.ReentrantReadWriteLock

import scala.collection
import scala.collection.mutable.{HashMap, ListBuffer, Map}
import scala.collection.mutable.{HashMap, ListBuffer, Map, Set}
import scala.concurrent.{ExecutionContext, Future}
import scala.concurrent.duration.Duration
import scala.jdk.CollectionConverters._
@@ -99,6 +99,11 @@ private class ShuffleStatus(
*/
val mapStatusesDeleted = new Array[MapStatus](numPartitions)

/**
* Keep the indices of the Map tasks whose checksums are different across retries.
*/
private[this] val checksumMismatchIndices : Set[Int] = Set()

/**
* MergeStatus for each shuffle partition when push-based shuffle is enabled. The index of the
* array is the shuffle partition id (reduce id). Each value in the array is the MergeStatus for
@@ -169,6 +174,12 @@
} else {
mapIdToMapIndex.remove(currentMapStatus.mapId)
}

val preStatus =
if (mapStatuses(mapIndex) != null) mapStatuses(mapIndex) else mapStatusesDeleted(mapIndex)
if (preStatus != null && preStatus.checksumValue != status.checksumValue) {
checksumMismatchIndices.add(mapIndex)
}
Contributor: There are three main cases here:

  • task reattempt due to stage reattempt after downstream stages have consumed output.
  • task reattempt due to stage reattempt before downstream stages have consumed output (missing partitions detected during stage attempt completion).
  • speculative tasks.

For the latter two, we don't need to track it in checksumMismatchIndices.

Contributor (Author): Yes, for case 1, we need to track the mismatches. checksumMismatchIndices will be used (in the next PR) to roll back the downstream stages if we detect checksum mismatches for their upstream stages.

For case 2, if downstream stages have not consumed the output, that means they have not started. In this case, the rollback is a no-op, and it doesn't hurt to record the mismatches here.

For case 3, I think we need to record the mismatches. Assume a situation where all partitions of a stage have finished while some speculative tasks are still running. As all outputs have been produced, the downstream stage can start and read the data. Later, some speculative tasks finish, and the new mapStatus overrides the old mapStatus with a new data location. For the downstream stage, the not-yet-started or retried tasks would read the new data, while the finished and running tasks would read the old data, resulting in inconsistency.

Contributor: > For case 2, if downstream stages have not consumed the output, that means they have not started. In this case, the rollback is a no-op, and it doesn't hurt to record the mismatches here.

It is unclear how checksumMismatchIndices will be used; perhaps it is fine to record it, but my query would be: why record it at all? Is it due to the complexity of detecting case (2)?

> For case 3, I think we need to record the mismatches. Assume a situation where all partitions of a stage have finished while some speculative tasks are still running. As all outputs have been produced, the downstream stage can start and read the data. Later, some speculative tasks finish, and the new mapStatus overrides the old mapStatus with a new data location. For the downstream stage, the not-yet-started or retried tasks would read the new data, while the finished and running tasks would read the old data, resulting in inconsistency.

That is fair; this is indeed possible.

mapStatuses(mapIndex) = status
mapIdToMapIndex(status.mapId) = mapIndex
}