Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

package org.apache.spark.rdd

import java.util.concurrent.atomic.AtomicReference

import org.apache.spark.unsafe.types.UTF8String

/**
Expand All @@ -40,26 +42,33 @@ private[spark] object InputFileBlockHolder {
/**
* The thread variable for the name of the current file being read. This is used by
* the InputFileName function in Spark SQL.
*
* @note `inputBlock` works somewhat complicatedly. It guarantees that `initialValue`
* is called at the start of a task. Therefore, one atomic reference is created in the task
* thread. After that, read and write happen to the same atomic reference across the parent and
* children threads. This is in order to support a case where write happens in a child thread
* but read happens at its parent thread, for instance, Python UDF execution. See SPARK-28153.
*/
private[this] val inputBlock: InheritableThreadLocal[FileBlock] =
new InheritableThreadLocal[FileBlock] {
override protected def initialValue(): FileBlock = new FileBlock
private[this] val inputBlock: InheritableThreadLocal[AtomicReference[FileBlock]] =
new InheritableThreadLocal[AtomicReference[FileBlock]] {
override protected def initialValue(): AtomicReference[FileBlock] =
new AtomicReference(new FileBlock)
}

/**
* Returns the holding file name or empty string if it is unknown.
*/
def getInputFilePath: UTF8String = inputBlock.get().filePath
def getInputFilePath: UTF8String = inputBlock.get().get().filePath

/**
* Returns the starting offset of the block currently being read, or -1 if it is unknown.
*/
def getStartOffset: Long = inputBlock.get().startOffset
def getStartOffset: Long = inputBlock.get().get().startOffset

/**
* Returns the length of the block being read, or -1 if it is unknown.
*/
def getLength: Long = inputBlock.get().length
def getLength: Long = inputBlock.get().get().length

/**
* Sets the thread-local input block.
Expand All @@ -68,11 +77,17 @@ private[spark] object InputFileBlockHolder {
require(filePath != null, "filePath cannot be null")
require(startOffset >= 0, s"startOffset ($startOffset) cannot be negative")
require(length >= 0, s"length ($length) cannot be negative")
inputBlock.set(new FileBlock(UTF8String.fromString(filePath), startOffset, length))
inputBlock.get().set(new FileBlock(UTF8String.fromString(filePath), startOffset, length))
}

/**
* Clears the input file block to default value.
*/
def unset(): Unit = inputBlock.remove()

/**
* Initializes thread local by explicitly getting the value. It triggers ThreadLocal's
* initialValue in the parent thread.
*/
def initialize(): Unit = inputBlock.get()
}
1 change: 1 addition & 0 deletions core/src/main/scala/org/apache/spark/scheduler/Task.scala
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ private[spark] abstract class Task[T](
taskContext
}

InputFileBlockHolder.initialize()
TaskContext.setTaskContext(context)
taskThread = Thread.currentThread()

Expand Down
6 changes: 6 additions & 0 deletions python/pyspark/sql/tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,12 @@ def test_array_repeat(self):
df.select(array_repeat("id", lit(3))).toDF("val").collect(),
)

def test_input_file_name_udf(self):
df = self.spark.read.text('python/test_support/hello/hello.txt')
df = df.select(udf(lambda x: x)("value"), input_file_name().alias('file'))
file_name = df.collect()[0].file
self.assertTrue("python/test_support/hello/hello.txt" in file_name)


if __name__ == "__main__":
import unittest
Expand Down