@@ -18,6 +18,7 @@
 package org.apache.spark.rdd

 import java.io.{FileNotFoundException, IOException}
+import java.security.PrivilegedExceptionAction
 import java.text.SimpleDateFormat
 import java.util.{Date, Locale}

@@ -29,6 +30,7 @@ import org.apache.hadoop.mapred._
 import org.apache.hadoop.mapred.lib.CombineFileSplit
 import org.apache.hadoop.mapreduce.TaskType
 import org.apache.hadoop.mapreduce.lib.input.FileInputFormat
+import org.apache.hadoop.security.UserGroupInformation
 import org.apache.hadoop.util.ReflectionUtils

 import org.apache.spark._
@@ -124,6 +126,8 @@ class HadoopRDD[K, V](
       minPartitions)
   }

+  private val doAsUserName = UserGroupInformation.getCurrentUser.getUserName
+
   protected val jobConfCacheKey: String = "rdd_%d_job_conf".format(id)

   protected val inputFormatCacheKey: String = "rdd_%d_input_format".format(id)
@@ -220,7 +224,7 @@ class HadoopRDD[K, V](
     }
   }

-  override def compute(theSplit: Partition, context: TaskContext): InterruptibleIterator[(K, V)] = {
+  def doCompute(theSplit: Partition, context: TaskContext): InterruptibleIterator[(K, V)] = {
     val iter = new NextIterator[(K, V)] {

       private val split = theSplit.asInstanceOf[HadoopPartition]
@@ -326,7 +330,7 @@ class HadoopRDD[K, V](
         if (getBytesReadCallback.isDefined) {
           updateBytesRead()
         } else if (split.inputSplit.value.isInstanceOf[FileSplit] ||
-            split.inputSplit.value.isInstanceOf[CombineFileSplit]) {
+                   split.inputSplit.value.isInstanceOf[CombineFileSplit]) {
           // If we can't get the bytes read from the FS stats, fall back to the split size,
           // which may be inaccurate.
           try {
@@ -342,6 +346,29 @@ class HadoopRDD[K, V](
     new InterruptibleIterator[(K, V)](context, iter)
   }

+  override def compute(theSplit: Partition, context: TaskContext): InterruptibleIterator[(K, V)] = {
+    val ugi = UserGroupInformation.getCurrentUser
+
+    if (ugi.getUserName == doAsUserName) {
+      doCompute(theSplit, context)
+    } else {
+      val doAsAction = new PrivilegedExceptionAction[InterruptibleIterator[(K, V)]]() {
+        override def run(): InterruptibleIterator[(K, V)] = {
+          try {
+            doCompute(theSplit, context)
+          } catch {
+            case e: Exception =>
+              log.error("Error computing HadoopRDD partition", e)
+              throw e
+          }
+        }
+      }
+
+      val proxyUgi = UserGroupInformation.createProxyUser(doAsUserName, ugi)
+      proxyUgi.doAs(doAsAction)
+    }
+  }
+
 /** Maps over a partition, providing the InputSplit that was used as the base of the partition. */
 @DeveloperApi
 def mapPartitionsWithInputSplit[U: ClassTag](
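
For readers unfamiliar with Hadoop impersonation, the sketch below isolates the proxy-user pattern this patch applies in `compute`: run the action directly when the current user already matches the captured `doAsUserName`, otherwise wrap the current credentials in a proxy UGI and execute under `doAs`. The object name `ProxyUserSketch`, the `runAs` helper, and the user "alice" are illustrative only and not part of the patch.

import java.security.PrivilegedExceptionAction

import org.apache.hadoop.security.UserGroupInformation

object ProxyUserSketch {

  // Run `body` as `user`, mirroring the patched HadoopRDD.compute: skip the
  // proxy when the current user already matches, otherwise doAs a proxy UGI.
  def runAs[T](user: String)(body: => T): T = {
    val currentUgi = UserGroupInformation.getCurrentUser
    if (currentUgi.getUserName == user) {
      body
    } else {
      // The proxy UGI carries `user` as the effective user while keeping the
      // current credentials as the real (authenticating) user.
      val proxyUgi = UserGroupInformation.createProxyUser(user, currentUgi)
      proxyUgi.doAs(new PrivilegedExceptionAction[T] {
        override def run(): T = body
      })
    }
  }

  def main(args: Array[String]): Unit = {
    // "alice" is a placeholder target user; inside the block,
    // getCurrentUser reports alice rather than the launching user.
    runAs("alice") {
      println(s"Running as: ${UserGroupInformation.getCurrentUser.getUserName}")
    }
  }
}

Note that createProxyUser alone does not grant the right to impersonate: on a secured cluster, the real user must also be authorized server-side via the hadoop.proxyuser.<realUser>.hosts and hadoop.proxyuser.<realUser>.groups settings in core-site.xml.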