
Commit dab44ed

Finish the PR
1 parent 0a84178 commit dab44ed

File tree

3 files changed: +74 −120 lines


core/src/main/scala/org/apache/spark/internal/config/package.scala

Lines changed: 6 additions & 0 deletions
@@ -272,4 +272,10 @@ package object config {
     .booleanConf
     .createWithDefault(false)
 
+  private[spark] val CHECKPOINT_COMPRESS =
+    ConfigBuilder("spark.checkpoint.compress")
+      .doc("Whether to compress RDD checkpoints. Generally a good idea. Compression will use " +
+        "spark.io.compression.codec.")
+      .booleanConf
+      .createWithDefault(false)
 }
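For context (not part of the diff), this is roughly how the new flag would be enabled from application code; the app name, master, codec choice, checkpoint directory, and RDD below are illustrative only:

import org.apache.spark.{SparkConf, SparkContext}

// Illustrative sketch: turn on the flag added above and pick the codec it delegates to.
val conf = new SparkConf()
  .setAppName("checkpoint-compression-example")   // hypothetical app name
  .setMaster("local[2]")
  .set("spark.checkpoint.compress", "true")       // new flag from this commit (default: false)
  .set("spark.io.compression.codec", "lz4")       // existing codec config the new flag reuses

val sc = new SparkContext(conf)
sc.setCheckpointDir("/tmp/checkpoints")           // hypothetical directory

val rdd = sc.parallelize(1 to 1000).map(_ * 2)
rdd.checkpoint()                                  // checkpoint files are now written compressed
rdd.count()                                       // materializes the RDD and triggers the checkpoint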

core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala

Lines changed: 22 additions & 28 deletions
@@ -17,7 +17,8 @@
 
 package org.apache.spark.rdd
 
-import java.io.{FileNotFoundException, InputStream, IOException, OutputStream}
+import java.io.{FileNotFoundException, IOException}
+import java.util.concurrent.TimeUnit
 
 import scala.reflect.ClassTag
 import scala.util.control.NonFatal
@@ -27,11 +28,10 @@ import org.apache.hadoop.fs.Path
 import org.apache.spark._
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.internal.Logging
+import org.apache.spark.internal.config.CHECKPOINT_COMPRESS
 import org.apache.spark.io.CompressionCodec
 import org.apache.spark.util.{SerializableConfiguration, Utils}
 
-
-
 /**
  * An RDD that reads from checkpoint files previously written to reliable storage.
  */
@@ -122,6 +122,7 @@ private[spark] object ReliableCheckpointRDD extends Logging {
       originalRDD: RDD[T],
       checkpointDir: String,
       blockSize: Int = -1): ReliableCheckpointRDD[T] = {
+    val checkpointStartTimeNs = System.nanoTime()
 
     val sc = originalRDD.sparkContext
 
@@ -136,18 +137,17 @@ private[spark] object ReliableCheckpointRDD extends Logging {
     val broadcastedConf = sc.broadcast(
       new SerializableConfiguration(sc.hadoopConfiguration))
     // TODO: This is expensive because it computes the RDD again unnecessarily (SPARK-8582)
-    val startTime = System.currentTimeMillis()
     sc.runJob(originalRDD,
       writePartitionToCheckpointFile[T](checkpointDirPath.toString, broadcastedConf) _)
 
-    logInfo(s"Checkpointing took ${System.currentTimeMillis() - startTime} ms.")
-    sc.conf.getOption("spark.checkpoint.compress.codec").foreach(codec => {
-      logInfo(s"The checkpoint compression codec is $codec.")
-    })
     if (originalRDD.partitioner.nonEmpty) {
       writePartitionerToCheckpointDir(sc, originalRDD.partitioner.get, checkpointDirPath)
     }
 
+    val checkpointDurationMs =
+      TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - checkpointStartTimeNs)
+    logInfo(s"Checkpointing took $checkpointDurationMs ms.")
+
     val newRDD = new ReliableCheckpointRDD[T](
       sc, checkpointDirPath.toString, originalRDD.partitioner)
     if (newRDD.partitions.length != originalRDD.partitions.length) {
@@ -164,7 +164,7 @@ private[spark] object ReliableCheckpointRDD extends Logging {
   def writePartitionToCheckpointFile[T: ClassTag](
       path: String,
       broadcastedConf: Broadcast[SerializableConfiguration],
-      blockSize: Int = -1)(ctx: TaskContext, iterator: Iterator[T]): Unit = {
+      blockSize: Int = -1)(ctx: TaskContext, iterator: Iterator[T]) {
     val env = SparkEnv.get
     val outputDir = new Path(path)
     val fs = outputDir.getFileSystem(broadcastedConf.value.value)
@@ -177,13 +177,11 @@ private[spark] object ReliableCheckpointRDD extends Logging {
     val bufferSize = env.conf.getInt("spark.buffer.size", 65536)
 
     val fileOutputStream = if (blockSize < 0) {
-      lazy val fileStream: OutputStream = fs.create(tempOutputPath, false, bufferSize)
-      env.conf.getOption("spark.checkpoint.compress.codec").fold(fileStream) {
-        codec => {
-          logDebug(s"Compressing using $codec.")
-          CompressionCodec.createCodec(env.conf, codec)
-            .compressedOutputStream(fileStream)
-        }
+      val fileStream = fs.create(tempOutputPath, false, bufferSize)
+      if (env.conf.get(CHECKPOINT_COMPRESS)) {
+        CompressionCodec.createCodec(env.conf).compressedOutputStream(fileStream)
+      } else {
+        fileStream
       }
     } else {
       // This is mainly for testing purpose
@@ -192,8 +190,6 @@ private[spark] object ReliableCheckpointRDD extends Logging {
     }
     val serializer = env.serializer.newInstance()
     val serializeStream = serializer.serializeStream(fileOutputStream)
-    logTrace(s"Starting to write to checkpoint file $tempOutputPath.")
-    val startTimeMs = System.currentTimeMillis()
     Utils.tryWithSafeFinally {
       serializeStream.writeAll(iterator)
     } {
@@ -214,7 +210,6 @@ private[spark] object ReliableCheckpointRDD extends Logging {
         }
       }
     }
-    logInfo(s"Checkpointing took ${System.currentTimeMillis() - startTimeMs} ms.")
   }
 
   /**
@@ -291,17 +286,16 @@ private[spark] object ReliableCheckpointRDD extends Logging {
     val env = SparkEnv.get
     val fs = path.getFileSystem(broadcastedConf.value.value)
     val bufferSize = env.conf.getInt("spark.buffer.size", 65536)
-    lazy val fileStream: InputStream = fs.open(path, bufferSize)
-    val inputStream: InputStream =
-      env.conf.getOption("spark.checkpoint.compress.codec").fold(fileStream) {
-        codec => {
-          logDebug(s"Decompressing using $codec.")
-          CompressionCodec.createCodec(env.conf, codec)
-            .compressedInputStream(fileStream)
-        }
+    val fileInputStream = {
+      val fileStream = fs.open(path, bufferSize)
+      if (env.conf.get(CHECKPOINT_COMPRESS)) {
+        CompressionCodec.createCodec(env.conf).compressedInputStream(fileStream)
+      } else {
+        fileStream
       }
+    }
     val serializer = env.serializer.newInstance()
-    val deserializeStream = serializer.deserializeStream(inputStream)
+    val deserializeStream = serializer.deserializeStream(fileInputStream)
 
     // Register an on-task-completion callback to close the input stream.
     context.addTaskCompletionListener(context => deserializeStream.close())
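The write and read paths above are symmetric: the raw Hadoop stream is wrapped in a codec stream only when CHECKPOINT_COMPRESS is set, so one config value governs both sides. A minimal standalone sketch of that pattern, using GZIP from the JDK as a stand-in for Spark's CompressionCodec (the helper names are made up for illustration):

import java.io.{InputStream, OutputStream}
import java.util.zip.{GZIPInputStream, GZIPOutputStream}

// Hypothetical helpers mirroring the wrap-if-enabled pattern in the diff above.
def maybeCompressed(out: OutputStream, compress: Boolean): OutputStream =
  if (compress) new GZIPOutputStream(out) else out

def maybeDecompressed(in: InputStream, compress: Boolean): InputStream =
  if (compress) new GZIPInputStream(in) else in

// Both helpers must see the same boolean, otherwise reads fail; this is why the
// patch keys the writer and the reader off the single CHECKPOINT_COMPRESS entry.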

core/src/test/scala/org/apache/spark/CheckpointSuite.scala

Lines changed: 46 additions & 92 deletions
@@ -29,7 +29,6 @@ import org.apache.spark.rdd._
 import org.apache.spark.storage.{BlockId, StorageLevel, TestBlockId}
 import org.apache.spark.util.Utils
 
-
 trait RDDCheckpointTester { self: SparkFunSuite =>
 
   protected val partitioner = new HashPartitioner(2)
@@ -241,42 +240,6 @@ trait RDDCheckpointTester { self: SparkFunSuite =>
   protected def generateFatPairRDD(): RDD[(Int, Int)] = {
     new FatPairRDD(sparkContext.makeRDD(1 to 100, 4), partitioner).mapValues(x => x)
   }
-
-  protected def testBasicCheckpoint(sc: SparkContext, reliableCheckpoint: Boolean): Unit = {
-    val parCollection = sc.makeRDD(1 to 4)
-    val flatMappedRDD = parCollection.flatMap(x => 1 to x)
-    checkpoint(flatMappedRDD, reliableCheckpoint)
-    assert(flatMappedRDD.dependencies.head.rdd === parCollection)
-    val result = flatMappedRDD.collect()
-    assert(flatMappedRDD.dependencies.head.rdd != parCollection)
-    assert(flatMappedRDD.collect() === result)
-  }
-
-  protected def testCompression(checkpointDir: File, compressionCodec: String): Unit = {
-    val sparkConf = new SparkConf()
-    sparkConf.set("spark.checkpoint.compress.codec", compressionCodec)
-    val sc = new SparkContext("local", "test", sparkConf)
-    sc.setCheckpointDir(checkpointDir.toString)
-    val initialSize = 20
-    // Use just one partition for now since compression works best on large data sets.
-    val collection = sc.makeRDD(1 to initialSize, numSlices = 1)
-    val flatMappedRDD = collection.flatMap(x => 1 to x)
-    checkpoint(flatMappedRDD, reliableCheckpoint = true)
-    assert(flatMappedRDD.collect().length == initialSize * (initialSize + 1)/2,
-      "The checkpoint was lossy!")
-    sc.stop()
-    val checkpointPath = new Path(flatMappedRDD.getCheckpointFile.get)
-    val fs = checkpointPath.getFileSystem(sc.hadoopConfiguration)
-    val fileStatus = fs.listStatus(checkpointPath).find(_.getPath.getName.startsWith("part-")).get
-    val compressedSize = fileStatus.getLen
-    assert(compressedSize > 0, "The checkpoint file was not written!")
-    val compressedInputStream = CompressionCodec.createCodec(sparkConf, compressionCodec)
-      .compressedInputStream(fs.open(fileStatus.getPath))
-    val uncompressedSize = ByteStreams.toByteArray(compressedInputStream).length
-    compressedInputStream.close()
-    assert(compressedSize < uncompressedSize, "The compression was not successful!")
-  }
-
 }
 
 /**
@@ -290,14 +253,10 @@ class CheckpointSuite extends SparkFunSuite with RDDCheckpointTester with LocalS
     super.beforeEach()
     checkpointDir = File.createTempFile("temp", "", Utils.createTempDir())
     checkpointDir.delete()
-  }
-
-  private def startSparkContext(): Unit = {
     sc = new SparkContext("local", "test")
     sc.setCheckpointDir(checkpointDir.toString)
   }
 
-
   override def afterEach(): Unit = {
     try {
       Utils.deleteRecursively(checkpointDir)
@@ -309,44 +268,13 @@ class CheckpointSuite extends SparkFunSuite with RDDCheckpointTester with LocalS
   override def sparkContext: SparkContext = sc
 
   runTest("basic checkpointing") { reliableCheckpoint: Boolean =>
-    startSparkContext()
-    testBasicCheckpoint(sc, reliableCheckpoint)
-  }
-
-  runTest("compression with snappy", skipLocalCheckpoint = true) { _: Boolean =>
-    val sparkConf = new SparkConf()
-    sparkConf.set("spark.checkpoint.compress.codec", "snappy")
-    sc = new SparkContext("local", "test", sparkConf)
-    sc.setCheckpointDir(checkpointDir.toString)
-    testBasicCheckpoint(sc, reliableCheckpoint = true)
-  }
-
-  runTest("compression with lz4", skipLocalCheckpoint = true) { _: Boolean =>
-    val sparkConf = new SparkConf()
-    sparkConf.set("spark.checkpoint.compress.codec", "lz4")
-    sc = new SparkContext("local", "test", sparkConf)
-    sc.setCheckpointDir(checkpointDir.toString)
-    testBasicCheckpoint(sc, reliableCheckpoint = true)
-  }
-
-  runTest("compression with lzf", skipLocalCheckpoint = true) { _: Boolean =>
-    val sparkConf = new SparkConf()
-    sparkConf.set("spark.checkpoint.compress.codec", "lzf")
-    sc = new SparkContext("local", "test", sparkConf)
-    sc.setCheckpointDir(checkpointDir.toString)
-    testBasicCheckpoint(sc, reliableCheckpoint = true)
-  }
-
-  runTest("compression size snappy", skipLocalCheckpoint = true) { _: Boolean =>
-    testCompression(checkpointDir, "snappy")
-  }
-
-  runTest("compression size lzf", skipLocalCheckpoint = true) { _: Boolean =>
-    testCompression(checkpointDir, "lzf")
-  }
-
-  runTest("compression size lz4", skipLocalCheckpoint = true) { _: Boolean =>
-    testCompression(checkpointDir, "lz4")
+    val parCollection = sc.makeRDD(1 to 4)
+    val flatMappedRDD = parCollection.flatMap(x => 1 to x)
+    checkpoint(flatMappedRDD, reliableCheckpoint)
+    assert(flatMappedRDD.dependencies.head.rdd === parCollection)
+    val result = flatMappedRDD.collect()
+    assert(flatMappedRDD.dependencies.head.rdd != parCollection)
+    assert(flatMappedRDD.collect() === result)
   }
 
   runTest("checkpointing partitioners", skipLocalCheckpoint = true) { _: Boolean =>
@@ -386,15 +314,13 @@ class CheckpointSuite extends SparkFunSuite with RDDCheckpointTester with LocalS
       }
     }
 
-    startSparkContext()
     testPartitionerCheckpointing(partitioner)
 
     // Test that corrupted partitioner file does not prevent recovery of RDD
    testPartitionerCheckpointing(partitioner, corruptPartitionerFile = true)
   }
 
   runTest("RDDs with one-to-one dependencies") { reliableCheckpoint: Boolean =>
-    startSparkContext()
     testRDD(_.map(x => x.toString), reliableCheckpoint)
     testRDD(_.flatMap(x => 1 to x), reliableCheckpoint)
     testRDD(_.filter(_ % 2 == 0), reliableCheckpoint)
@@ -408,7 +334,6 @@ class CheckpointSuite extends SparkFunSuite with RDDCheckpointTester with LocalS
   }
 
   runTest("ParallelCollectionRDD") { reliableCheckpoint: Boolean =>
-    startSparkContext()
     val parCollection = sc.makeRDD(1 to 4, 2)
     val numPartitions = parCollection.partitions.size
     checkpoint(parCollection, reliableCheckpoint)
@@ -425,7 +350,6 @@ class CheckpointSuite extends SparkFunSuite with RDDCheckpointTester with LocalS
   }
 
   runTest("BlockRDD") { reliableCheckpoint: Boolean =>
-    startSparkContext()
     val blockId = TestBlockId("id")
     val blockManager = SparkEnv.get.blockManager
     blockManager.putSingle(blockId, "test", StorageLevel.MEMORY_ONLY)
@@ -443,22 +367,19 @@ class CheckpointSuite extends SparkFunSuite with RDDCheckpointTester with LocalS
   }
 
   runTest("ShuffleRDD") { reliableCheckpoint: Boolean =>
-    startSparkContext()
     testRDD(rdd => {
       // Creating ShuffledRDD directly as PairRDDFunctions.combineByKey produces a MapPartitionedRDD
       new ShuffledRDD[Int, Int, Int](rdd.map(x => (x % 2, 1)), partitioner)
     }, reliableCheckpoint)
   }
 
   runTest("UnionRDD") { reliableCheckpoint: Boolean =>
-    startSparkContext()
     def otherRDD: RDD[Int] = sc.makeRDD(1 to 10, 1)
     testRDD(_.union(otherRDD), reliableCheckpoint)
     testRDDPartitions(_.union(otherRDD), reliableCheckpoint)
   }
 
   runTest("CartesianRDD") { reliableCheckpoint: Boolean =>
-    startSparkContext()
     def otherRDD: RDD[Int] = sc.makeRDD(1 to 10, 1)
     testRDD(new CartesianRDD(sc, _, otherRDD), reliableCheckpoint)
     testRDDPartitions(new CartesianRDD(sc, _, otherRDD), reliableCheckpoint)
@@ -482,7 +403,6 @@ class CheckpointSuite extends SparkFunSuite with RDDCheckpointTester with LocalS
   }
 
   runTest("CoalescedRDD") { reliableCheckpoint: Boolean =>
-    startSparkContext()
     testRDD(_.coalesce(2), reliableCheckpoint)
     testRDDPartitions(_.coalesce(2), reliableCheckpoint)
 
@@ -505,7 +425,6 @@ class CheckpointSuite extends SparkFunSuite with RDDCheckpointTester with LocalS
   }
 
   runTest("CoGroupedRDD") { reliableCheckpoint: Boolean =>
-    startSparkContext()
     val longLineageRDD1 = generateFatPairRDD()
 
     // Collect the RDD as sequences instead of arrays to enable equality tests in testRDD
@@ -524,7 +443,6 @@ class CheckpointSuite extends SparkFunSuite with RDDCheckpointTester with LocalS
   }
 
   runTest("ZippedPartitionsRDD") { reliableCheckpoint: Boolean =>
-    startSparkContext()
     testRDD(rdd => rdd.zip(rdd.map(x => x)), reliableCheckpoint)
     testRDDPartitions(rdd => rdd.zip(rdd.map(x => x)), reliableCheckpoint)
 
@@ -550,7 +468,6 @@ class CheckpointSuite extends SparkFunSuite with RDDCheckpointTester with LocalS
   }
 
   runTest("PartitionerAwareUnionRDD") { reliableCheckpoint: Boolean =>
-    startSparkContext()
     testRDD(rdd => {
       new PartitionerAwareUnionRDD[(Int, Int)](sc, Array(
         generateFatPairRDD(),
@@ -585,7 +502,6 @@ class CheckpointSuite extends SparkFunSuite with RDDCheckpointTester with LocalS
   }
 
   runTest("CheckpointRDD with zero partitions") { reliableCheckpoint: Boolean =>
-    startSparkContext()
     val rdd = new BlockRDD[Int](sc, Array.empty[BlockId])
     assert(rdd.partitions.size === 0)
     assert(rdd.isCheckpointed === false)
@@ -600,7 +516,6 @@ class CheckpointSuite extends SparkFunSuite with RDDCheckpointTester with LocalS
   }
 
   runTest("checkpointAllMarkedAncestors") { reliableCheckpoint: Boolean =>
-    startSparkContext()
     testCheckpointAllMarkedAncestors(reliableCheckpoint, checkpointAllMarkedAncestors = true)
     testCheckpointAllMarkedAncestors(reliableCheckpoint, checkpointAllMarkedAncestors = false)
   }
@@ -667,3 +582,42 @@ object CheckpointSuite {
     ).asInstanceOf[RDD[(K, Array[Iterable[V]])]]
   }
 }
+
+class CheckpointCompressionSuite extends SparkFunSuite with LocalSparkContext {
+
+  test("checkpoint compression") {
+    val checkpointDir = Utils.createTempDir()
+    try {
+      val conf = new SparkConf()
+        .set("spark.checkpoint.compress", "true")
+        .set("spark.ui.enabled", "false")
+      sc = new SparkContext("local", "test", conf)
+      sc.setCheckpointDir(checkpointDir.toString)
+      val rdd = sc.makeRDD(1 to 20, numSlices = 1)
+      rdd.checkpoint()
+      assert(rdd.collect().toSeq === (1 to 20))
+
+      // Verify that RDD is checkpointed
+      assert(rdd.firstParent.isInstanceOf[ReliableCheckpointRDD[_]])
+
+      val checkpointPath = new Path(rdd.getCheckpointFile.get)
+      val fs = checkpointPath.getFileSystem(sc.hadoopConfiguration)
+      val checkpointFile =
+        fs.listStatus(checkpointPath).map(_.getPath).find(_.getName.startsWith("part-")).get
+
+      // Verify the checkpoint file is compressed, in other words, can be decompressed
+      val compressedInputStream = CompressionCodec.createCodec(conf)
+        .compressedInputStream(fs.open(checkpointFile))
+      try {
+        ByteStreams.toByteArray(compressedInputStream)
+      } finally {
+        compressedInputStream.close()
+      }
+
+      // Verify that the compressed content can be read back
+      assert(rdd.collect().toSeq === (1 to 20))
+    } finally {
+      Utils.deleteRecursively(checkpointDir)
+    }
  }
+}
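To exercise only the new suite, something along these lines should work with Spark's standard build scripts; the exact sbt task name (test-only vs testOnly) depends on the sbt version in the tree, so treat both invocations as assumptions rather than part of the patch:

# Maven (ScalaTest plugin); runs only the new suite
build/mvn test -DwildcardSuites=org.apache.spark.CheckpointCompressionSuite -Dtest=none

# sbt; older trees use "test-only", newer ones "testOnly"
build/sbt "core/test-only *CheckpointCompressionSuite"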
