diff --git a/core/src/main/scala/org/apache/spark/rdd/InputFileBlockHolder.scala b/core/src/main/scala/org/apache/spark/rdd/InputFileBlockHolder.scala index ff2f58d81142..bfe8152d4dee 100644 --- a/core/src/main/scala/org/apache/spark/rdd/InputFileBlockHolder.scala +++ b/core/src/main/scala/org/apache/spark/rdd/InputFileBlockHolder.scala @@ -17,6 +17,8 @@ package org.apache.spark.rdd +import java.util.concurrent.atomic.AtomicReference + import org.apache.spark.unsafe.types.UTF8String /** @@ -40,26 +42,33 @@ private[spark] object InputFileBlockHolder { /** * The thread variable for the name of the current file being read. This is used by * the InputFileName function in Spark SQL. + * + * @note `inputBlock` is subtle. `initialize()` is called at the start of every task, so that + * `initialValue` runs in the task thread and exactly one atomic reference is created there. + * After that, reads and writes go to that same atomic reference from the parent and its child + * threads. This supports the case where the write happens in a child thread but the read + * happens in its parent thread, for instance, Python UDF execution. See SPARK-28153. */ - private[this] val inputBlock: InheritableThreadLocal[FileBlock] = - new InheritableThreadLocal[FileBlock] { - override protected def initialValue(): FileBlock = new FileBlock + private[this] val inputBlock: InheritableThreadLocal[AtomicReference[FileBlock]] = + new InheritableThreadLocal[AtomicReference[FileBlock]] { + override protected def initialValue(): AtomicReference[FileBlock] = + new AtomicReference(new FileBlock) } /** * Returns the holding file name or empty string if it is unknown. */ - def getInputFilePath: UTF8String = inputBlock.get().filePath + def getInputFilePath: UTF8String = inputBlock.get().get().filePath /** * Returns the starting offset of the block currently being read, or -1 if it is unknown. 
*/ - def getStartOffset: Long = inputBlock.get().startOffset + def getStartOffset: Long = inputBlock.get().get().startOffset /** * Returns the length of the block being read, or -1 if it is unknown. */ - def getLength: Long = inputBlock.get().length + def getLength: Long = inputBlock.get().get().length /** * Sets the thread-local input block. @@ -68,11 +77,17 @@ private[spark] object InputFileBlockHolder { require(filePath != null, "filePath cannot be null") require(startOffset >= 0, s"startOffset ($startOffset) cannot be negative") require(length >= 0, s"length ($length) cannot be negative") - inputBlock.set(new FileBlock(UTF8String.fromString(filePath), startOffset, length)) + inputBlock.get().set(new FileBlock(UTF8String.fromString(filePath), startOffset, length)) } /** * Clears the input file block to default value. */ def unset(): Unit = inputBlock.remove() + + /** + * Initializes thread local by explicitly getting the value. It triggers ThreadLocal's + * initialValue in the parent thread. 
+ */ + def initialize(): Unit = inputBlock.get() } diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index 182f479fb0dd..daed55cc131c 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -100,6 +100,7 @@ private[spark] abstract class Task[T]( taskContext } + InputFileBlockHolder.initialize() TaskContext.setTaskContext(context) taskThread = Thread.currentThread() diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 38e7f39a2711..f65fe885ef7d 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -847,6 +847,14 @@ def test_input_file_name_reset_for_rdd(self): for result in results: self.assertEqual(result[0], '') + def test_input_file_name_udf(self): + from pyspark.sql.functions import udf, input_file_name + + df = self.spark.read.text('python/test_support/hello/hello.txt') + df = df.select(udf(lambda x: x)("value"), input_file_name().alias('file')) + file_name = df.collect()[0].file + self.assertTrue("python/test_support/hello/hello.txt" in file_name) + def test_udf_defers_judf_initialization(self): # This is separate of UDFInitializationTests # to avoid context initialization