11 changes: 11 additions & 0 deletions python/pyspark/sql/tests.py
@@ -6394,6 +6394,17 @@ def test_invalid_args(self):
df.withColumn('mean_v', mean_udf(df['v']).over(ow))


class DataSourceV2Tests(ReusedSQLTestCase):
    def test_pyspark_udf_SPARK_25213(self):
Member: Not a big deal, but I would avoid the SPARK_25213 postfix at the end, just for consistency.

Contributor Author: I like that the tests in Scala include this information somewhere. Is there a better place for it in PySpark? I'm not aware of another way to pass extra metadata, but I'm open to it if there's a better way.
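
One hedged, illustrative option (the shorter method name and the docstring are hypothetical, not part of this PR): keep the JIRA reference in the test's docstring, since unittest reports the first docstring line as the test's short description:

```python
# Illustrative sketch only: record the JIRA ID in the docstring instead of the
# method name. unittest surfaces the first docstring line in failure output.
class DataSourceV2Tests(ReusedSQLTestCase):
    def test_pyspark_udf(self):
        """Regression test for SPARK-25213: Python UDF over a DataSourceV2 relation."""
        ...
```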

        from pyspark.sql.functions import udf

        df = self.spark.read.format("org.apache.spark.sql.sources.v2.SimpleDataSourceV2").load()
Member: I think this test will fail if the test classes are not compiled. Can we check whether the test classes are compiled and skip the test if they aren't?
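
A rough sketch of one way to guard the test; the helper name is illustrative and probing the driver JVM's classpath through py4j is only one possible check, not an API this PR adds:

```python
def _scala_test_class_available(spark, name):
    # Probe the driver JVM's classpath through py4j; Class.forName raises if
    # the Scala test classes were never compiled onto the classpath.
    try:
        spark._jvm.java.lang.Class.forName(name)
        return True
    except Exception:
        return False

# Inside the test method, something like:
#     if not _scala_test_class_available(
#             self.spark, "org.apache.spark.sql.sources.v2.SimpleDataSourceV2"):
#         self.skipTest("Scala test classes are not compiled")
```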

        result = df.withColumn('x', udf(lambda x: x, 'int')(df['i']))
Member: This only tests Project with a scalar PythonUDF? It might be better to also test the Filter case.

Contributor Author: Agreed. I was just verifying that the fix worked before spending more time on it.
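
For reference, a hedged sketch of what a Filter-case addition inside the same test class might look like; the predicate and expected rows are illustrative and rely on SimpleDataSourceV2 producing i in range(10) with j = -i, as asserted in the test above:

```python
from pyspark.sql.functions import udf

df = self.spark.read.format(
    "org.apache.spark.sql.sources.v2.SimpleDataSourceV2").load()

# Use a Python UDF as the filter predicate so the Filter runs over the v2 scan.
keep_even = udf(lambda x: x % 2 == 0, 'boolean')
result = df.filter(keep_even(df['i']))

rows = list(map(lambda r: r.asDict(), result.collect()))
expected = [{'i': i, 'j': -i} for i in range(10) if i % 2 == 0]
self.assertEqual(rows, expected)
```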

        rows = list(map(lambda r: r.asDict(), result.collect()))
        expected = [{'i': i, 'j': -i, 'x': i} for i in range(10)]
        self.assertEqual(rows, expected)


if __name__ == "__main__":
    from pyspark.sql.tests import *
    if xmlrunner:
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.datasources.v2
import scala.collection.mutable

import org.apache.spark.sql.{sources, Strategy}
-import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, AttributeSet, Expression}
+import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, AttributeSet, Expression, NamedExpression, PythonUDF}
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LogicalPlan, Repartition}
import org.apache.spark.sql.execution.{FilterExec, ProjectExec, SparkPlan}
@@ -104,6 +104,9 @@ object DataSourceV2Strategy extends Strategy {
}
}

+  private def hasScalarPythonUDF(e: Expression): Boolean = {
+    e.find(PythonUDF.isScalarPythonUDF).isDefined
+  }

override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case PhysicalOperation(project, filters, relation: DataSourceV2Relation) =>
@@ -130,10 +133,22 @@
config)

val filterCondition = postScanFilters.reduceLeftOption(And)
-      val withFilter = filterCondition.map(FilterExec(_, scan)).getOrElse(scan)

+      val withFilter = if (filterCondition.exists(hasScalarPythonUDF)) {
+        // add a projection before FilterExec to ensure that the rows are converted to unsafe
+        val filterExpr = filterCondition.get
+        FilterExec(filterExpr, ProjectExec(filterExpr.references.toSeq, scan))
+      } else {
+        filterCondition.map(FilterExec(_, scan)).getOrElse(scan)
+      }

-      // always add the projection, which will produce unsafe rows required by some operators
-      ProjectExec(project, withFilter) :: Nil
+      if (project.exists(hasScalarPythonUDF)) {
+        val references = project.map(_.references).reduce(_ ++ _).toSeq
+        ProjectExec(project, ProjectExec(references, withFilter)) :: Nil
Member: Why do we need to add an extra Project on top of Filter here?

Contributor Author: The v2 data sources return InternalRow, not UnsafeRow. Python UDFs can't handle InternalRow, so this is intended to add a projection that converts to unsafe rows before the projection that contains a Python UDF.

Member: Oh, I see. It is also used to make sure the PythonUDF in the top Project takes unsafe row input.
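
To make the above concrete, a hedged PySpark snippet (assuming an active `spark` session and the SimpleDataSourceV2 test class on the classpath) showing where the extra projection lands in the physical plan:

```python
from pyspark.sql.functions import udf

df = spark.read.format("org.apache.spark.sql.sources.v2.SimpleDataSourceV2").load()
df.withColumn('x', udf(lambda x: x, 'int')(df['i'])).explain()
# Rough expected shape (node names and wrappers vary by Spark version):
#   Project [i, j, pythonUDF0 AS x]
#   +- BatchEvalPython [<lambda>(i)]
#      +- Project [i, j]        <- the projection added here, producing UnsafeRows
#         +- ScanV2 ...
```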

Member: nit: If we already add a Project on top of Filter, we don't need to add another Project here, right?

Contributor Author: That one was only added if there was a filter and that filter ran a UDF. This will add an unnecessary Project if both the filter and the project have Python UDFs, but I thought that was probably okay. I can add a boolean to signal whether the filter already caused one to be added, if you think it's worth it.

Member: OK. Let's leave it as it is for now.

Member: +1 for leaving it as is.
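
For illustration, a hedged snippet (again assuming an active `spark` session and SimpleDataSourceV2 on the classpath) of the case being discussed, where both the filter and the projection contain Python UDFs:

```python
from pyspark.sql.functions import udf

df = spark.read.format("org.apache.spark.sql.sources.v2.SimpleDataSourceV2").load()
ident = udf(lambda x: x, 'int')
keep_even = udf(lambda x: x % 2 == 0, 'boolean')

# Python UDF in the filter *and* in the projection: as the author notes above,
# this adds one more Project than strictly necessary; explain() shows the
# resulting (slightly redundant, but correct) plan.
df.filter(keep_even(df['i'])).withColumn('x', ident(df['i'])).explain()
```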

+      } else {
+        ProjectExec(project, withFilter) :: Nil
+      }

case r: StreamingDataSourceV2Relation =>
// TODO: support operator pushdown for streaming data sources.