diff --git a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala index ca1eb1f4e4a9a..d5e853613b05b 100644 --- a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala @@ -66,6 +66,7 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi */ def takeAsync(num: Int): FutureAction[Seq[T]] = self.withScope { val f = new ComplexFutureAction[Seq[T]] + val callSite = self.context.getCallSite f.run { // This is a blocking action so we should use "AsyncRDDActions.futureExecutionContext" which @@ -73,6 +74,7 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi val results = new ArrayBuffer[T](num) val totalParts = self.partitions.length var partsScanned = 0 + self.context.setCallSite(callSite) while (results.size < num && partsScanned < totalParts) { // The number of partitions to try in this iteration. It is ok for this number to be // greater than totalParts because we actually cap it at totalParts in runJob.