Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions python/pyspark/sql/tests/test_pandas_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,12 @@
# limitations under the License.
#
import os
import shutil
import tempfile
import time
import unittest

from pyspark.sql import Row
from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas, have_pyarrow, \
pandas_requirement_message, pyarrow_requirement_message

Expand Down Expand Up @@ -112,6 +115,25 @@ def func(iterator):
expected = df.collect()
self.assertEqual(actual, expected)

# SPARK-33277
def test_map_in_pandas_with_column_vector(self):
path = tempfile.mkdtemp()
shutil.rmtree(path)

try:
self.spark.range(0, 200000, 1, 1).write.parquet(path)

def func(iterator):
for pdf in iterator:
yield pd.DataFrame({'id': [0] * len(pdf)})

for offheap in ["true", "false"]:
with self.sql_conf({"spark.sql.columnVector.offheap.enabled": offheap}):
self.assertEquals(
self.spark.read.parquet(path).mapInPandas(func, 'id long').head(), Row(0))
finally:
shutil.rmtree(path)


if __name__ == "__main__":
from pyspark.sql.tests.test_pandas_map import * # noqa: F401
Expand Down
19 changes: 19 additions & 0 deletions python/pyspark/sql/tests/test_pandas_udf_scalar.py
Original file line number Diff line number Diff line change
Expand Up @@ -1152,6 +1152,25 @@ def test_datasource_with_udf(self):
finally:
shutil.rmtree(path)

# SPARK-33277
def test_pandas_udf_with_column_vector(self):
path = tempfile.mkdtemp()
shutil.rmtree(path)

try:
self.spark.range(0, 200000, 1, 1).write.parquet(path)

@pandas_udf(LongType())
def udf(x):
return pd.Series([0] * len(x))

for offheap in ["true", "false"]:
with self.sql_conf({"spark.sql.columnVector.offheap.enabled": offheap}):
self.assertEquals(
self.spark.read.parquet(path).select(udf('id')).head(), Row(0))
finally:
shutil.rmtree(path)


if __name__ == "__main__":
from pyspark.sql.tests.test_pandas_udf_scalar import * # noqa: F401
Expand Down
20 changes: 20 additions & 0 deletions python/pyspark/sql/tests/test_udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -674,6 +674,26 @@ def test_udf_cache(self):
self.assertEqual(df.select(udf(func)("id"))._jdf.queryExecution()
.withCachedData().getClass().getSimpleName(), 'InMemoryRelation')

# SPARK-33277
def test_udf_with_column_vector(self):
path = tempfile.mkdtemp()
shutil.rmtree(path)

try:
self.spark.range(0, 100000, 1, 1).write.parquet(path)

def f(x):
return 0

fUdf = udf(f, LongType())

for offheap in ["true", "false"]:
with self.sql_conf({"spark.sql.columnVector.offheap.enabled": offheap}):
self.assertEquals(
self.spark.read.parquet(path).select(fUdf('id')).head(), Row(0))
finally:
shutil.rmtree(path)


class UDFInitializationTests(unittest.TestCase):
def tearDown(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.spark.sql.execution.python

import java.io.File
import java.util.concurrent.atomic.{AtomicBoolean, AtomicReference}

import scala.collection.mutable.ArrayBuffer

Expand Down Expand Up @@ -89,6 +90,7 @@ trait EvalPythonExec extends UnaryExecNode {

inputRDD.mapPartitions { iter =>
val context = TaskContext.get()
val contextAwareIterator = new ContextAwareIterator(iter, context)

// The queue used to buffer input rows so we can drain it to
// combine input with output from Python.
Expand Down Expand Up @@ -120,7 +122,7 @@ trait EvalPythonExec extends UnaryExecNode {
}.toSeq)

// Add rows to queue to join later with the result.
val projectedRowIter = iter.map { inputRow =>
val projectedRowIter = contextAwareIterator.map { inputRow =>
queue.add(inputRow.asInstanceOf[UnsafeRow])
projection(inputRow)
}
Expand All @@ -137,3 +139,53 @@ trait EvalPythonExec extends UnaryExecNode {
}
}
}

/**
* A TaskContext aware iterator.
*
* As the Python evaluation consumes the parent iterator in a separate thread,
* it could consume more data from the parent even after the task ends and the parent is closed.
* Thus, we should use ContextAwareIterator to stop consuming after the task ends.
*/
class ContextAwareIterator[IN](iter: Iterator[IN], context: TaskContext) extends Iterator[IN] {

private val thread = new AtomicReference[Thread]()

if (iter.hasNext) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will this change the thread that iter.hasNext is running? We can add the listeners without checking it.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually this is to make sure the upstream iterator is initialized. The upstream iterator must be initialized earlier as it might register another completion listener and the listener should run later than this one.

val failed = new AtomicBoolean(false)

context.addTaskFailureListener { (_, _) =>
failed.set(true)
}

context.addTaskCompletionListener[Unit] { _ =>
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This assumes the task completion listener to stop thread runs before this one. Otherwise, it would hang forever. I'm wondering if there is any better solution to avoid this implicit assumption.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The task completion lister will wait for the thread to stop within this listener, and the thread will stop soon as it checks !context.isCompleted() && !context.isInterrupted().

var thread = this.thread.get()

// Wait for a while since the writer thread might not reach to consuming the iterator yet.
while (thread == null && !failed.get()) {
// Use `context.wait()` instead of `Thread.sleep()` here since the task completion lister
// works under `synchronized(context)`. We might need to consider to improve in the future.
// It's a bad idea to hold an implicit lock when calling user's listener because it's
// pretty easy to cause surprising deadlock.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a bit scary. Is there a better way?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a bad idea to hold an implicit lock when calling user's listener because it's pretty easy to cause surprising deadlock.

Maybe we can fix this first. The this listener doesn't need to rely on an implicit lock.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see. Let me change the strategy here.

context.wait(10)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did you mean Thread.sleep(10)? Object.wait is not supposed to use like this.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I do mean wait. This will run within synchronized(context) and we should release the lock for the writer thread while waiting.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't realize it. It's better to not rely on this in a listener. This is something we should consider to improve in future. It's a bad idea to hold an implicit lock when calling user's listener because it's pretty easy to cause surprising deadlock.


thread = this.thread.get()
}

if (thread != null && thread != Thread.currentThread()) {
// Wait until the writer thread ends.
while (thread.isAlive) {
// Use `context.wait()` instead of `Thread.sleep()` with the same reason above.
context.wait(10)
}
}
}
}

override def hasNext: Boolean = {
thread.set(Thread.currentThread())
!context.isCompleted() && !context.isInterrupted() && iter.hasNext
}

override def next(): IN = iter.next()
}
Original file line number Diff line number Diff line change
Expand Up @@ -61,16 +61,17 @@ case class MapInPandasExec(
val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf)
val outputTypes = child.schema

val context = TaskContext.get()
val contextAwareIterator = new ContextAwareIterator(inputIter, context)

// Here we wrap it via another row so that Python sides understand it
// as a DataFrame.
val wrappedIter = inputIter.map(InternalRow(_))
val wrappedIter = contextAwareIterator.map(InternalRow(_))

// DO NOT use iter.grouped(). See BatchIterator.
val batchIter =
if (batchSize > 0) new BatchIterator(wrappedIter, batchSize) else Iterator(wrappedIter)

val context = TaskContext.get()

val columnarBatchIter = new ArrowPythonRunner(
chainedFunc,
PythonEvalType.SQL_MAP_PANDAS_ITER_UDF,
Expand Down