From ada6e924597cbe247c8be47fd80df38b79cb34e5 Mon Sep 17 00:00:00 2001
From: Ryan Blue
Date: Thu, 23 Aug 2018 10:37:19 -0700
Subject: [PATCH 1/4] SPARK-25213: Add project to v2 scans before python filters.

---
 python/pyspark/sql/tests.py                   | 11 +++++++++++
 .../datasources/v2/DataSourceV2Strategy.scala | 14 ++++++++++++--
 2 files changed, 23 insertions(+), 2 deletions(-)

diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 00d7e18320a5..4f2c738bafcf 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -6394,6 +6394,17 @@ def test_invalid_args(self):
             df.withColumn('mean_v', mean_udf(df['v']).over(ow))
 
 
+class DataSourceV2Tests(ReusedSQLTestCase):
+    def test_pyspark_udf_SPARK_25213(self):
+        from pyspark.sql.functions import udf
+
+        df = self.spark.read.format("org.apache.spark.sql.sources.v2.SimpleDataSourceV2").load()
+        result = datasource_v2_df.withColumn('x', udf(lambda x: x, 'int')(datasource_v2_df['i']))
+        rows = list(map(lambda r: r.asDict(), result.collect()))
+        expected = [ {'i': i, 'j': -i, 'x': i} for i in range(10) ]
+        self.assertEqual(rows, expected)
+
+
 if __name__ == "__main__":
     from pyspark.sql.tests import *
     if xmlrunner:
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
index fe713ff6c785..cbc8a972d9b6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.datasources.v2
 import scala.collection.mutable
 
 import org.apache.spark.sql.{sources, Strategy}
-import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, AttributeSet, Expression}
+import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, AttributeSet, Expression, NamedExpression, PythonUDF}
 import org.apache.spark.sql.catalyst.planning.PhysicalOperation
 import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LogicalPlan, Repartition}
 import org.apache.spark.sql.execution.{FilterExec, ProjectExec, SparkPlan}
@@ -104,6 +104,9 @@ object DataSourceV2Strategy extends Strategy {
     }
   }
 
+  private def hasScalarPythonUDF(e: Expression): Boolean = {
+    e.find(PythonUDF.isScalarPythonUDF).isDefined
+  }
 
   override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
     case PhysicalOperation(project, filters, relation: DataSourceV2Relation) =>
@@ -130,7 +133,14 @@ object DataSourceV2Strategy extends Strategy {
         config)
 
       val filterCondition = postScanFilters.reduceLeftOption(And)
-      val withFilter = filterCondition.map(FilterExec(_, scan)).getOrElse(scan)
+
+      val withFilter = if (filterCondition.exists(hasScalarPythonUDF)) {
+        // add a projection before FilterExec to ensure that the rows are converted to unsafe
+        val filterExpr = filterCondition.get
+        FilterExec(filterExpr, ProjectExec(filterExpr.references.toSeq, scan))
+      } else {
+        filterCondition.map(FilterExec(_, scan)).getOrElse(scan)
+      }
 
       // always add the projection, which will produce unsafe rows required by some operators
       ProjectExec(project, withFilter) :: Nil
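The Scala change in this first patch wraps the filter in a ProjectExec whenever the post-scan filter condition contains a scalar Python UDF; as the in-diff comment notes, the projection is there to convert the scan's output rows to unsafe rows before Python evaluation. Below is a hedged PySpark sketch of the filter case this targets. It is not part of the patch, and it assumes a local session plus Spark's test-only source org.apache.spark.sql.sources.v2.SimpleDataSourceV2 (which, per the test's expected values, yields ten rows with i = 0..9 and j = -i) on the classpath.

```python
from pyspark.sql import SparkSession
from pyspark.sql.functions import udf

spark = (SparkSession.builder
         .master("local[2]")
         .appName("spark-25213-filter-sketch")
         .getOrCreate())

# SimpleDataSourceV2 is a test-only v2 source; its scan is not guaranteed
# to emit UnsafeRows, which is what triggers SPARK-25213.
df = spark.read.format("org.apache.spark.sql.sources.v2.SimpleDataSourceV2").load()

# A scalar Python UDF used as a filter condition directly on top of a v2
# scan. Without the ProjectExec that this patch inserts under FilterExec,
# the Python evaluation operator would receive rows straight from the scan,
# which need not be unsafe rows.
is_even = udf(lambda x: x % 2 == 0, 'boolean')
df.filter(is_even(df['i'])).show()
```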
From c2e1bc7e21a82605e884a40c7c5a553a1b711b51 Mon Sep 17 00:00:00 2001
From: Ryan Blue
Date: Thu, 23 Aug 2018 10:46:54 -0700
Subject: [PATCH 2/4] SPARK-25213: Add project to v2 scans before python projections.

---
 .../execution/datasources/v2/DataSourceV2Strategy.scala | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
index cbc8a972d9b6..9b582469962a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
@@ -143,7 +143,12 @@ object DataSourceV2Strategy extends Strategy {
       }
 
       // always add the projection, which will produce unsafe rows required by some operators
-      ProjectExec(project, withFilter) :: Nil
+      if (project.exists(hasScalarPythonUDF)) {
+        val references = project.map(_.references).reduce(_ ++ _).toSeq
+        ProjectExec(project, ProjectExec(references, withFilter)) :: Nil
+      } else {
+        ProjectExec(project, withFilter) :: Nil
+      }
 
     case r: StreamingDataSourceV2Relation =>
       // TODO: support operator pushdown for streaming data sources.

From c49157e94634ea964f78ce671d48c52dc9b43e6b Mon Sep 17 00:00:00 2001
From: Ryan Blue
Date: Thu, 23 Aug 2018 10:48:19 -0700
Subject: [PATCH 3/4] SPARK-25213: Fix pyspark style.

---
 python/pyspark/sql/tests.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 4f2c738bafcf..b365f7378004 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -6401,7 +6401,7 @@ def test_pyspark_udf_SPARK_25213(self):
         df = self.spark.read.format("org.apache.spark.sql.sources.v2.SimpleDataSourceV2").load()
         result = datasource_v2_df.withColumn('x', udf(lambda x: x, 'int')(datasource_v2_df['i']))
         rows = list(map(lambda r: r.asDict(), result.collect()))
-        expected = [ {'i': i, 'j': -i, 'x': i} for i in range(10) ]
+        expected = [{'i': i, 'j': -i, 'x': i} for i in range(10)]
         self.assertEqual(rows, expected)
 
 

From 550368eaeebdc87f2c89bad7214f2624784eeb04 Mon Sep 17 00:00:00 2001
From: Ryan Blue
Date: Thu, 23 Aug 2018 16:40:29 -0700
Subject: [PATCH 4/4] SPARK-25213: Fix PySpark test.

---
 python/pyspark/sql/tests.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index b365f7378004..8de66f602a60 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -6399,7 +6399,7 @@ def test_pyspark_udf_SPARK_25213(self):
         from pyspark.sql.functions import udf
 
         df = self.spark.read.format("org.apache.spark.sql.sources.v2.SimpleDataSourceV2").load()
-        result = datasource_v2_df.withColumn('x', udf(lambda x: x, 'int')(datasource_v2_df['i']))
+        result = df.withColumn('x', udf(lambda x: x, 'int')(df['i']))
        rows = list(map(lambda r: r.asDict(), result.collect()))
         expected = [{'i': i, 'j': -i, 'x': i} for i in range(10)]
         self.assertEqual(rows, expected)
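
Taken together, the four patches guarantee that a v2 scan is always separated from Python UDF evaluation by at least one ProjectExec, whose output is the unsafe rows those operators require. A hedged standalone restatement of the final version of the test (same assumptions as the earlier sketch: a local session and SimpleDataSourceV2 on the classpath) is:

```python
from pyspark.sql import SparkSession
from pyspark.sql.functions import udf

spark = (SparkSession.builder
         .master("local[2]")
         .appName("spark-25213-project-sketch")
         .getOrCreate())

df = spark.read.format("org.apache.spark.sql.sources.v2.SimpleDataSourceV2").load()

# With [PATCH 2/4] applied, planning produces
# ProjectExec(project, ProjectExec(references, scan)): the inner projection
# narrows the scan output to the columns the projection references and
# converts the rows to unsafe rows before the Python UDF runs.
result = df.withColumn('x', udf(lambda x: x, 'int')(df['i']))

rows = [r.asDict() for r in result.collect()]
assert rows == [{'i': i, 'j': -i, 'x': i} for i in range(10)]
```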