@@ -229,39 +229,35 @@ private[sql] abstract class BaseWriterContainer(
229229
230230 protected val dataSchema = relation.dataSchema
231231
232- protected val outputCommitterClass : Class [_ <: FileOutputCommitter ] =
233- relation.outputCommitterClass
234-
235232 protected val outputWriterClass : Class [_ <: OutputWriter ] = relation.outputWriterClass
236233
234+ private var outputFormatClass : Class [_ <: OutputFormat [_, _]] = _
235+
237236 def driverSideSetup (): Unit = {
238237 setupIDs(0 , 0 , 0 )
239238 setupConf()
240239 taskAttemptContext = newTaskAttemptContext(serializableConf.value, taskAttemptId)
241- outputCommitter = newOutputCommitter(outputCommitterClass, outputPath, taskAttemptContext)
240+ relation.prepareForWrite(job)
241+ outputFormatClass = job.getOutputFormatClass
242+ outputCommitter = newOutputCommitter(taskAttemptContext)
242243 outputCommitter.setupJob(jobContext)
243244 }
244245
245246 def executorSideSetup (taskContext : TaskContext ): Unit = {
246247 setupIDs(taskContext.stageId(), taskContext.partitionId(), taskContext.attemptNumber())
247248 setupConf()
248249 taskAttemptContext = newTaskAttemptContext(serializableConf.value, taskAttemptId)
249- outputCommitter = newOutputCommitter(outputCommitterClass, outputPath, taskAttemptContext)
250+ outputCommitter = newOutputCommitter(taskAttemptContext)
250251 outputCommitter.setupTask(taskAttemptContext)
251252 initWriters()
252253 }
253254
254- private def newOutputCommitter (
255- clazz : Class [_ <: FileOutputCommitter ],
256- path : String ,
257- context : TaskAttemptContext ): FileOutputCommitter = {
258- val ctor = outputCommitterClass.getConstructor(classOf [Path ], classOf [TaskAttemptContext ])
259- ctor.setAccessible(true )
260-
261- val hadoopPath = new Path (path)
262- val fs = hadoopPath.getFileSystem(serializableConf.value)
263- val qualified = fs.makeQualified(hadoopPath)
264- ctor.newInstance(qualified, context)
255+ private def newOutputCommitter (context : TaskAttemptContext ): FileOutputCommitter = {
256+ outputFormatClass.newInstance().getOutputCommitter(context) match {
257+ case f : FileOutputCommitter => f
258+ case f => sys.error(
259+ s " FileOutputCommitter or its subclass is expected, but got a ${f.getClass.getName}. " )
260+ }
265261 }
266262
267263 private def setupIDs (jobId : Int , splitId : Int , attemptId : Int ): Unit = {
0 commit comments