@@ -17,17 +17,21 @@
 
 package org.apache.spark.rdd
 
-import scala.collection.mutable.ArrayBuffer
-import scala.collection.mutable.HashSet
+import org.apache.hadoop.fs.FileSystem
+import org.apache.hadoop.mapred._
+import org.apache.hadoop.util.Progressable
+
+import scala.collection.mutable.{ArrayBuffer, HashSet}
 import scala.util.Random
 
-import org.scalatest.FunSuite
 import com.google.common.io.Files
-import org.apache.hadoop.mapreduce._
-import org.apache.hadoop.conf.{Configuration, Configurable}
-
-import org.apache.spark.SparkContext._
+import org.apache.hadoop.conf.{Configurable, Configuration}
+import org.apache.hadoop.mapreduce.{JobContext => NewJobContext, OutputCommitter => NewOutputCommitter,
+  OutputFormat => NewOutputFormat, RecordWriter => NewRecordWriter,
+  TaskAttemptContext => NewTaskAttempContext}
 import org.apache.spark.{Partitioner, SharedSparkContext}
+import org.apache.spark.SparkContext._
+import org.scalatest.FunSuite
 
 class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext {
   test("aggregateByKey") {
@@ -467,7 +471,7 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext {
     val pairs = sc.parallelize(Array((new Integer(1), new Integer(1))))
 
     // No error, non-configurable formats still work
-    pairs.saveAsNewAPIHadoopFile[FakeFormat]("ignored")
+    pairs.saveAsNewAPIHadoopFile[NewFakeFormat]("ignored")
 
     /*
       Check that configurable formats get configured:
@@ -478,6 +482,17 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext {
     pairs.saveAsNewAPIHadoopFile[ConfigTestFormat]("ignored")
   }
 
+  test("saveAsHadoopFile should respect configured output committers") {
+    val pairs = sc.parallelize(Array((new Integer(1), new Integer(1))))
+    val conf = new JobConf()
+    conf.setOutputCommitter(classOf[FakeOutputCommitter])
+
+    FakeOutputCommitter.ran = false
+    pairs.saveAsHadoopFile("ignored", pairs.keyClass, pairs.valueClass, classOf[FakeOutputFormat], conf)
+
+    assert(FakeOutputCommitter.ran, "OutputCommitter was never called")
+  }
+
   test("lookup") {
     val pairs = sc.parallelize(Array((1,2), (3,4), (5,6), (5,7)))
 
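The new test above leans on the old-API committer lookup: JobConf.setOutputCommitter records the committer class in the job configuration, and the mapred write path later retrieves and instantiates it through JobConf.getOutputCommitter instead of defaulting to FileOutputCommitter. A minimal standalone sketch of that handshake (illustrative only; FakeOutputCommitter is the stub defined further down in this file):

    import org.apache.hadoop.mapred.JobConf

    val conf = new JobConf()
    // Record the committer class in the job configuration.
    conf.setOutputCommitter(classOf[FakeOutputCommitter])
    // The write path reads it back: this instantiates FakeOutputCommitter
    // rather than the FileOutputCommitter default.
    val committer = conf.getOutputCommitter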
@@ -621,40 +636,86 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext {
   and the test will therefore throw InstantiationException when saveAsNewAPIHadoopFile
   tries to instantiate them with Class.newInstance.
 */
+
+/*
+ * Original Hadoop API
+ */
 class FakeWriter extends RecordWriter[Integer, Integer] {
+  override def write(key: Integer, value: Integer): Unit = ()
 
-  def close(p1: TaskAttemptContext) = ()
+  override def close(reporter: Reporter): Unit = ()
+}
+
+class FakeOutputCommitter() extends OutputCommitter() {
+  override def setupJob(jobContext: JobContext): Unit = ()
+
+  override def needsTaskCommit(taskContext: TaskAttemptContext): Boolean = true
+
+  override def setupTask(taskContext: TaskAttemptContext): Unit = ()
+
+  override def commitTask(taskContext: TaskAttemptContext): Unit = {
+    FakeOutputCommitter.ran = true
+    ()
+  }
+
+  override def abortTask(taskContext: TaskAttemptContext): Unit = ()
+}
+
+/*
+ * Used to communicate state between the test harness and the OutputCommitter.
+ */
+object FakeOutputCommitter {
+  var ran = false
+}
+
+class FakeOutputFormat() extends OutputFormat[Integer, Integer]() {
+  override def getRecordWriter(
+      ignored: FileSystem,
+      job: JobConf, name: String,
+      progress: Progressable): RecordWriter[Integer, Integer] = {
+    new FakeWriter()
+  }
+
+  override def checkOutputSpecs(ignored: FileSystem, job: JobConf): Unit = ()
+}
+
+/*
+ * New-style Hadoop API
+ */
+class NewFakeWriter extends NewRecordWriter[Integer, Integer] {
+
+  def close(p1: NewTaskAttempContext) = ()
 
   def write(p1: Integer, p2: Integer) = ()
 
 }
 
-class FakeCommitter extends OutputCommitter {
-  def setupJob(p1: JobContext) = ()
+class NewFakeCommitter extends NewOutputCommitter {
+  def setupJob(p1: NewJobContext) = ()
 
-  def needsTaskCommit(p1: TaskAttemptContext): Boolean = false
+  def needsTaskCommit(p1: NewTaskAttempContext): Boolean = false
 
-  def setupTask(p1: TaskAttemptContext) = ()
+  def setupTask(p1: NewTaskAttempContext) = ()
 
-  def commitTask(p1: TaskAttemptContext) = ()
+  def commitTask(p1: NewTaskAttempContext) = ()
 
-  def abortTask(p1: TaskAttemptContext) = ()
+  def abortTask(p1: NewTaskAttempContext) = ()
 }
 
-class FakeFormat() extends OutputFormat[Integer, Integer]() {
+class NewFakeFormat() extends NewOutputFormat[Integer, Integer]() {
 
-  def checkOutputSpecs(p1: JobContext) = ()
+  def checkOutputSpecs(p1: NewJobContext) = ()
 
-  def getRecordWriter(p1: TaskAttemptContext): RecordWriter[Integer, Integer] = {
-    new FakeWriter()
+  def getRecordWriter(p1: NewTaskAttempContext): NewRecordWriter[Integer, Integer] = {
+    new NewFakeWriter()
   }
 
-  def getOutputCommitter(p1: TaskAttemptContext): OutputCommitter = {
-    new FakeCommitter()
+  def getOutputCommitter(p1: NewTaskAttempContext): NewOutputCommitter = {
+    new NewFakeCommitter()
   }
 }
 
-class ConfigTestFormat() extends FakeFormat() with Configurable {
+class ConfigTestFormat() extends NewFakeFormat() with Configurable {
 
   var setConfCalled = false
   def setConf(p1: Configuration) = {
@@ -664,7 +725,7 @@ class ConfigTestFormat() extends FakeFormat() with Configurable {
 
   def getConf: Configuration = null
 
-  override def getRecordWriter(p1: TaskAttemptContext): RecordWriter[Integer, Integer] = {
+  override def getRecordWriter(p1: NewTaskAttempContext): NewRecordWriter[Integer, Integer] = {
     assert(setConfCalled, "setConf was never called")
     super.getRecordWriter(p1)
   }
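One design note on the flag in FakeOutputCommitter's companion object: these tests run Spark in local mode, so the task that calls commitTask shares a JVM with the test harness, and a singleton var is enough to observe the call. A reduced sketch of the same pattern, with hypothetical names and assuming single-JVM execution:

    // Side channel between the code under test and the assertion.
    object CommitProbe {
      @volatile var fired = false
    }

    class ProbedCommitter extends FakeOutputCommitter {
      override def commitTask(taskContext: TaskAttemptContext): Unit = {
        CommitProbe.fired = true
      }
    }

    // After a local-mode job configured with ProbedCommitter:
    // assert(CommitProbe.fired)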