
Commit 77bcd77

Aaditya Ramesh authored and zsxwing committed
[SPARK-19525][CORE] Add RDD checkpoint compression support
## What changes were proposed in this pull request?

This PR adds RDD checkpoint compression support and a new config `spark.checkpoint.compress` to enable/disable it. Credit goes to aramesh117. Closes #17024.

## How was this patch tested?

The new unit test.

Author: Shixiong Zhu <[email protected]>
Author: Aaditya Ramesh <[email protected]>

Closes #17789 from zsxwing/pr17024.
1 parent ebff519 · commit 77bcd77

File tree

3 files changed: +69 −2 lines changed


core/src/main/scala/org/apache/spark/internal/config/package.scala

Lines changed: 6 additions & 0 deletions
```diff
@@ -272,4 +272,10 @@ package object config {
     .booleanConf
     .createWithDefault(false)
 
+  private[spark] val CHECKPOINT_COMPRESS =
+    ConfigBuilder("spark.checkpoint.compress")
+      .doc("Whether to compress RDD checkpoints. Generally a good idea. Compression will use " +
+        "spark.io.compression.codec.")
+      .booleanConf
+      .createWithDefault(false)
 }
```
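
With the flag defined above, turning on compressed checkpoints from an application is a one-line config change. A minimal sketch, e.g. as pasted into `spark-shell`; the app name, master, data, and checkpoint directory are illustrative:

```scala
import org.apache.spark.{SparkConf, SparkContext}

// Enable compressed checkpoints; the bytes on disk are then run through
// whatever codec spark.io.compression.codec selects (lz4 by default).
val conf = new SparkConf()
  .setAppName("checkpoint-compression-demo") // illustrative app name
  .setMaster("local[2]")
  .set("spark.checkpoint.compress", "true")

val sc = new SparkContext(conf)
sc.setCheckpointDir("/tmp/checkpoints") // illustrative directory

val data = sc.parallelize(1 to 1000)
data.checkpoint() // mark the RDD for reliable checkpointing
data.count()      // first action writes the (now compressed) checkpoint files
sc.stop()
```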

core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala

Lines changed: 22 additions & 2 deletions
```diff
@@ -18,6 +18,7 @@
 package org.apache.spark.rdd
 
 import java.io.{FileNotFoundException, IOException}
+import java.util.concurrent.TimeUnit
 
 import scala.reflect.ClassTag
 import scala.util.control.NonFatal
@@ -27,6 +28,8 @@ import org.apache.hadoop.fs.Path
 import org.apache.spark._
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.internal.Logging
+import org.apache.spark.internal.config.CHECKPOINT_COMPRESS
+import org.apache.spark.io.CompressionCodec
 import org.apache.spark.util.{SerializableConfiguration, Utils}
 
 /**
@@ -119,6 +122,7 @@ private[spark] object ReliableCheckpointRDD extends Logging {
       originalRDD: RDD[T],
       checkpointDir: String,
       blockSize: Int = -1): ReliableCheckpointRDD[T] = {
+    val checkpointStartTimeNs = System.nanoTime()
 
     val sc = originalRDD.sparkContext
 
@@ -140,6 +144,10 @@ private[spark] object ReliableCheckpointRDD extends Logging {
       writePartitionerToCheckpointDir(sc, originalRDD.partitioner.get, checkpointDirPath)
     }
 
+    val checkpointDurationMs =
+      TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - checkpointStartTimeNs)
+    logInfo(s"Checkpointing took $checkpointDurationMs ms.")
+
     val newRDD = new ReliableCheckpointRDD[T](
       sc, checkpointDirPath.toString, originalRDD.partitioner)
     if (newRDD.partitions.length != originalRDD.partitions.length) {
@@ -169,7 +177,12 @@ private[spark] object ReliableCheckpointRDD extends Logging {
     val bufferSize = env.conf.getInt("spark.buffer.size", 65536)
 
     val fileOutputStream = if (blockSize < 0) {
-      fs.create(tempOutputPath, false, bufferSize)
+      val fileStream = fs.create(tempOutputPath, false, bufferSize)
+      if (env.conf.get(CHECKPOINT_COMPRESS)) {
+        CompressionCodec.createCodec(env.conf).compressedOutputStream(fileStream)
+      } else {
+        fileStream
+      }
     } else {
       // This is mainly for testing purpose
       fs.create(tempOutputPath, false, bufferSize,
@@ -273,7 +286,14 @@ private[spark] object ReliableCheckpointRDD extends Logging {
     val env = SparkEnv.get
     val fs = path.getFileSystem(broadcastedConf.value.value)
     val bufferSize = env.conf.getInt("spark.buffer.size", 65536)
-    val fileInputStream = fs.open(path, bufferSize)
+    val fileInputStream = {
+      val fileStream = fs.open(path, bufferSize)
+      if (env.conf.get(CHECKPOINT_COMPRESS)) {
+        CompressionCodec.createCodec(env.conf).compressedInputStream(fileStream)
+      } else {
+        fileStream
+      }
+    }
     val serializer = env.serializer.newInstance()
     val deserializeStream = serializer.deserializeStream(fileInputStream)
```

core/src/test/scala/org/apache/spark/CheckpointSuite.scala

Lines changed: 41 additions & 0 deletions
```diff
@@ -21,8 +21,10 @@ import java.io.File
 
 import scala.reflect.ClassTag
 
+import com.google.common.io.ByteStreams
 import org.apache.hadoop.fs.Path
 
+import org.apache.spark.io.CompressionCodec
 import org.apache.spark.rdd._
 import org.apache.spark.storage.{BlockId, StorageLevel, TestBlockId}
 import org.apache.spark.util.Utils
@@ -580,3 +582,42 @@ object CheckpointSuite {
     ).asInstanceOf[RDD[(K, Array[Iterable[V]])]]
   }
 }
+
+class CheckpointCompressionSuite extends SparkFunSuite with LocalSparkContext {
+
+  test("checkpoint compression") {
+    val checkpointDir = Utils.createTempDir()
+    try {
+      val conf = new SparkConf()
+        .set("spark.checkpoint.compress", "true")
+        .set("spark.ui.enabled", "false")
+      sc = new SparkContext("local", "test", conf)
+      sc.setCheckpointDir(checkpointDir.toString)
+      val rdd = sc.makeRDD(1 to 20, numSlices = 1)
+      rdd.checkpoint()
+      assert(rdd.collect().toSeq === (1 to 20))
+
+      // Verify that RDD is checkpointed
+      assert(rdd.firstParent.isInstanceOf[ReliableCheckpointRDD[_]])
+
+      val checkpointPath = new Path(rdd.getCheckpointFile.get)
+      val fs = checkpointPath.getFileSystem(sc.hadoopConfiguration)
+      val checkpointFile =
+        fs.listStatus(checkpointPath).map(_.getPath).find(_.getName.startsWith("part-")).get
+
+      // Verify the checkpoint file is compressed, in other words, can be decompressed
+      val compressedInputStream = CompressionCodec.createCodec(conf)
+        .compressedInputStream(fs.open(checkpointFile))
+      try {
+        ByteStreams.toByteArray(compressedInputStream)
+      } finally {
+        compressedInputStream.close()
+      }
+
+      // Verify that the compressed content can be read back
+      assert(rdd.collect().toSeq === (1 to 20))
+    } finally {
+      Utils.deleteRecursively(checkpointDir)
+    }
+  }
+}
```
