diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 00d7e18320a5..8de66f602a60 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -6394,6 +6394,17 @@ def test_invalid_args(self):
                 df.withColumn('mean_v', mean_udf(df['v']).over(ow))
 
 
+class DataSourceV2Tests(ReusedSQLTestCase):
+    def test_pyspark_udf_SPARK_25213(self):
+        from pyspark.sql.functions import udf
+
+        df = self.spark.read.format("org.apache.spark.sql.sources.v2.SimpleDataSourceV2").load()
+        result = df.withColumn('x', udf(lambda x: x, 'int')(df['i']))
+        rows = list(map(lambda r: r.asDict(), result.collect()))
+        expected = [{'i': i, 'j': -i, 'x': i} for i in range(10)]
+        self.assertEqual(rows, expected)
+
+
 if __name__ == "__main__":
     from pyspark.sql.tests import *
     if xmlrunner:
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
index fe713ff6c785..9b582469962a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.datasources.v2
 import scala.collection.mutable
 
 import org.apache.spark.sql.{sources, Strategy}
-import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, AttributeSet, Expression}
+import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, AttributeSet, Expression, NamedExpression, PythonUDF}
 import org.apache.spark.sql.catalyst.planning.PhysicalOperation
 import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LogicalPlan, Repartition}
 import org.apache.spark.sql.execution.{FilterExec, ProjectExec, SparkPlan}
@@ -104,6 +104,9 @@ object DataSourceV2Strategy extends Strategy {
     }
   }
 
+  private def hasScalarPythonUDF(e: Expression): Boolean = {
+    e.find(PythonUDF.isScalarPythonUDF).isDefined
+  }
 
   override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
     case PhysicalOperation(project, filters, relation: DataSourceV2Relation) =>
@@ -130,10 +133,22 @@
         config)
 
       val filterCondition = postScanFilters.reduceLeftOption(And)
-      val withFilter = filterCondition.map(FilterExec(_, scan)).getOrElse(scan)
+
+      val withFilter = if (filterCondition.exists(hasScalarPythonUDF)) {
+        // add a projection before FilterExec to ensure that the rows are converted to unsafe
+        val filterExpr = filterCondition.get
+        FilterExec(filterExpr, ProjectExec(filterExpr.references.toSeq, scan))
+      } else {
+        filterCondition.map(FilterExec(_, scan)).getOrElse(scan)
+      }
 
       // always add the projection, which will produce unsafe rows required by some operators
-      ProjectExec(project, withFilter) :: Nil
+      if (project.exists(hasScalarPythonUDF)) {
+        val references = project.map(_.references).reduce(_ ++ _).toSeq
+        ProjectExec(project, ProjectExec(references, withFilter)) :: Nil
+      } else {
+        ProjectExec(project, withFilter) :: Nil
+      }
 
     case r: StreamingDataSourceV2Relation =>
       // TODO: support operator pushdown for streaming data sources.
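
For context (illustrative, not part of the patch): the failure this change guards against is a scalar Python UDF planned directly on top of a DataSourceV2 scan. Python UDF evaluation buffers its input in a row queue that expects UnsafeRow, while a DSv2 reader may emit generic InternalRows, so without an intervening projection the query fails with a ClassCastException. A minimal PySpark sketch of the scenario, assuming a Spark build with the SimpleDataSourceV2 test source on its classpath (the same source the new test uses):

    from pyspark.sql import SparkSession
    from pyspark.sql.functions import udf

    spark = SparkSession.builder.master("local[2]").getOrCreate()
    # Any DSv2 source whose reader returns non-unsafe rows exercises this path.
    df = spark.read.format("org.apache.spark.sql.sources.v2.SimpleDataSourceV2").load()
    # Before this patch, the UDF evaluation ran directly over the scan output;
    # the extra ProjectExec added by DataSourceV2Strategy now guarantees that
    # the UDF operator receives UnsafeRows.
    df.withColumn('x', udf(lambda x: x, 'int')(df['i'])).show()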