diff --git a/python/pyspark/sql/tests/test_pandas_map.py b/python/pyspark/sql/tests/test_pandas_map.py
index d53face70220..53a536f5d8da 100644
--- a/python/pyspark/sql/tests/test_pandas_map.py
+++ b/python/pyspark/sql/tests/test_pandas_map.py
@@ -15,9 +15,12 @@
 # limitations under the License.
 #
 import os
+import shutil
+import tempfile
 import time
 import unittest
 
+from pyspark.sql import Row
 from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas, have_pyarrow, \
     pandas_requirement_message, pyarrow_requirement_message
 
@@ -112,6 +115,25 @@ def func(iterator):
         expected = df.collect()
         self.assertEqual(actual, expected)
 
+    # SPARK-33277
+    def test_map_in_pandas_with_column_vector(self):
+        path = tempfile.mkdtemp()
+        shutil.rmtree(path)
+
+        try:
+            self.spark.range(0, 200000, 1, 1).write.parquet(path)
+
+            def func(iterator):
+                for pdf in iterator:
+                    yield pd.DataFrame({'id': [0] * len(pdf)})
+
+            for offheap in ["true", "false"]:
+                with self.sql_conf({"spark.sql.columnVector.offheap.enabled": offheap}):
+                    self.assertEquals(
+                        self.spark.read.parquet(path).mapInPandas(func, 'id long').head(), Row(0))
+        finally:
+            shutil.rmtree(path)
+
 
 if __name__ == "__main__":
     from pyspark.sql.tests.test_pandas_map import *  # noqa: F401
diff --git a/python/pyspark/sql/tests/test_pandas_udf_scalar.py b/python/pyspark/sql/tests/test_pandas_udf_scalar.py
index 2eb2dec00106..a170f5532ee9 100644
--- a/python/pyspark/sql/tests/test_pandas_udf_scalar.py
+++ b/python/pyspark/sql/tests/test_pandas_udf_scalar.py
@@ -1152,6 +1152,25 @@ def test_datasource_with_udf(self):
         finally:
             shutil.rmtree(path)
 
+    # SPARK-33277
+    def test_pandas_udf_with_column_vector(self):
+        path = tempfile.mkdtemp()
+        shutil.rmtree(path)
+
+        try:
+            self.spark.range(0, 200000, 1, 1).write.parquet(path)
+
+            @pandas_udf(LongType())
+            def udf(x):
+                return pd.Series([0] * len(x))
+
+            for offheap in ["true", "false"]:
+                with self.sql_conf({"spark.sql.columnVector.offheap.enabled": offheap}):
+                    self.assertEquals(
+                        self.spark.read.parquet(path).select(udf('id')).head(), Row(0))
+        finally:
+            shutil.rmtree(path)
+
 
 if __name__ == "__main__":
     from pyspark.sql.tests.test_pandas_udf_scalar import *  # noqa: F401
diff --git a/python/pyspark/sql/tests/test_udf.py b/python/pyspark/sql/tests/test_udf.py
index bfc55dff9454..5b1fc604565a 100644
--- a/python/pyspark/sql/tests/test_udf.py
+++ b/python/pyspark/sql/tests/test_udf.py
@@ -674,6 +674,26 @@ def test_udf_cache(self):
         self.assertEqual(df.select(udf(func)("id"))._jdf.queryExecution()
                          .withCachedData().getClass().getSimpleName(), 'InMemoryRelation')
 
+    # SPARK-33277
+    def test_udf_with_column_vector(self):
+        path = tempfile.mkdtemp()
+        shutil.rmtree(path)
+
+        try:
+            self.spark.range(0, 100000, 1, 1).write.parquet(path)
+
+            def f(x):
+                return 0
+
+            fUdf = udf(f, LongType())
+
+            for offheap in ["true", "false"]:
+                with self.sql_conf({"spark.sql.columnVector.offheap.enabled": offheap}):
+                    self.assertEquals(
+                        self.spark.read.parquet(path).select(fUdf('id')).head(), Row(0))
+        finally:
+            shutil.rmtree(path)
+
 
 class UDFInitializationTests(unittest.TestCase):
     def tearDown(self):
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExec.scala
index 7c476ab03c00..e3b2c55c4a6a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExec.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.execution.python
 
 import java.io.File
+import java.util.concurrent.atomic.{AtomicBoolean, AtomicReference}
 
 import scala.collection.mutable.ArrayBuffer
 
@@ -89,6 +90,7 @@ trait EvalPythonExec extends UnaryExecNode {
 
     inputRDD.mapPartitions { iter =>
       val context = TaskContext.get()
+      val contextAwareIterator = new ContextAwareIterator(iter, context)
 
       // The queue used to buffer input rows so we can drain it to
       // combine input with output from Python.
@@ -120,7 +122,7 @@ trait EvalPythonExec extends UnaryExecNode {
       }.toSeq)
 
       // Add rows to queue to join later with the result.
-      val projectedRowIter = iter.map { inputRow =>
+      val projectedRowIter = contextAwareIterator.map { inputRow =>
         queue.add(inputRow.asInstanceOf[UnsafeRow])
         projection(inputRow)
       }
@@ -137,3 +139,53 @@ trait EvalPythonExec extends UnaryExecNode {
     }
   }
 }
+
+/**
+ * A TaskContext-aware iterator.
+ *
+ * As the Python evaluation consumes the parent iterator in a separate thread,
+ * it could consume more data from the parent even after the task ends and the parent is closed.
+ * Thus, we should use ContextAwareIterator to stop consuming after the task ends.
+ */
+class ContextAwareIterator[IN](iter: Iterator[IN], context: TaskContext) extends Iterator[IN] {
+
+  private val thread = new AtomicReference[Thread]()
+
+  if (iter.hasNext) {
+    val failed = new AtomicBoolean(false)
+
+    context.addTaskFailureListener { (_, _) =>
+      failed.set(true)
+    }
+
+    context.addTaskCompletionListener[Unit] { _ =>
+      var thread = this.thread.get()
+
+      // Wait a while since the writer thread might not have started consuming the iterator yet.
+      while (thread == null && !failed.get()) {
+        // Use `context.wait()` instead of `Thread.sleep()` here since the task completion
+        // listener works under `synchronized(context)`. We might need to improve this in the
+        // future: it's a bad idea to hold an implicit lock when calling a user's listener
+        // because it can easily cause a surprising deadlock.
+        context.wait(10)
+
+        thread = this.thread.get()
+      }
+
+      if (thread != null && thread != Thread.currentThread()) {
+        // Wait until the writer thread ends.
+        while (thread.isAlive) {
+          // Use `context.wait()` instead of `Thread.sleep()` for the same reason as above.
+          context.wait(10)
+        }
+      }
+    }
+  }
+
+  override def hasNext: Boolean = {
+    thread.set(Thread.currentThread())
+    !context.isCompleted() && !context.isInterrupted() && iter.hasNext
+  }
+
+  override def next(): IN = iter.next()
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInPandasExec.scala
index 2bb808119c0a..7fc18f885a2d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInPandasExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInPandasExec.scala
@@ -61,16 +61,17 @@ case class MapInPandasExec(
       val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf)
       val outputTypes = child.schema
 
+      val context = TaskContext.get()
+      val contextAwareIterator = new ContextAwareIterator(inputIter, context)
+
       // Here we wrap it via another row so that Python sides understand it
       // as a DataFrame.
-      val wrappedIter = inputIter.map(InternalRow(_))
+      val wrappedIter = contextAwareIterator.map(InternalRow(_))
 
       // DO NOT use iter.grouped(). See BatchIterator.
       val batchIter =
         if (batchSize > 0) new BatchIterator(wrappedIter, batchSize) else Iterator(wrappedIter)
 
-      val context = TaskContext.get()
-
       val columnarBatchIter = new ArrowPythonRunner(
         chainedFunc,
         PythonEvalType.SQL_MAP_PANDAS_ITER_UDF,
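
A minimal sketch of the usage pattern the patch applies in both exec nodes: the partition's input
iterator is wrapped in ContextAwareIterator before any rows are handed to the Python runner's
writer thread, so hasNext stops returning true once the task has completed or been interrupted.
The helper name wrapForPythonRunner below is illustrative only and does not appear in the patch.

    import org.apache.spark.TaskContext
    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.execution.python.ContextAwareIterator

    // Illustrative helper (not part of the patch): wrap the partition's input iterator so the
    // separate writer thread stops pulling rows from the parent after the task ends.
    def wrapForPythonRunner(inputIter: Iterator[InternalRow]): Iterator[InternalRow] = {
      val context = TaskContext.get()
      new ContextAwareIterator(inputIter, context)
    }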