[SPARK-25213][PYTHON] Add project to v2 scans before python filters. #22206
Changes from all commits
@@ -6394,6 +6394,17 @@ def test_invalid_args(self):
        df.withColumn('mean_v', mean_udf(df['v']).over(ow))

+
+class DataSourceV2Tests(ReusedSQLTestCase):
+    def test_pyspark_udf_SPARK_25213(self):
+        from pyspark.sql.functions import udf
+
+        df = self.spark.read.format("org.apache.spark.sql.sources.v2.SimpleDataSourceV2").load()
Member: I think this test will fail if the test classes are not compiled. Can we check whether the test classes are compiled and skip the test if they are not?
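A hedged sketch of how that skip could work is below (illustrative only, not part of this PR; it assumes the Scala test class can be probed through the SparkContext's py4j gateway):

    # Hypothetical sketch: skip these tests when the Scala test classes are not compiled.
    def setUp(self):
        super(DataSourceV2Tests, self).setUp()
        try:
            # Probe for the test data source class through the py4j gateway.
            jvm = self.spark.sparkContext._jvm
            jvm.java.lang.Class.forName(
                "org.apache.spark.sql.sources.v2.SimpleDataSourceV2")
        except Exception:
            self.skipTest("Spark test classes (SimpleDataSourceV2) are not compiled")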
+        result = df.withColumn('x', udf(lambda x: x, 'int')(df['i']))
Member: This only tests Project with a scalar PythonUDF? It might be better to also test the Filter case.

Contributor (Author): Agreed. I was just verifying that the fix worked before spending more time on it.
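A Filter-case test might look roughly like this sketch (hypothetical, not part of this PR; it assumes the same SimpleDataSourceV2 source, which produces rows with i in range(10) and j = -i):

    # Hypothetical sketch of a Filter case: push a boolean Python UDF into the filter.
    def test_pyspark_udf_filter(self):
        from pyspark.sql.functions import udf

        df = self.spark.read.format("org.apache.spark.sql.sources.v2.SimpleDataSourceV2").load()
        is_positive = udf(lambda x: x > 0, 'boolean')
        result = df.filter(is_positive(df['i']))

        rows = list(map(lambda r: r.asDict(), result.collect()))
        expected = [{'i': i, 'j': -i} for i in range(1, 10)]
        self.assertEqual(rows, expected)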
+        rows = list(map(lambda r: r.asDict(), result.collect()))
+        expected = [{'i': i, 'j': -i, 'x': i} for i in range(10)]
+        self.assertEqual(rows, expected)
+

if __name__ == "__main__":
    from pyspark.sql.tests import *
    if xmlrunner:
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.datasources.v2
import scala.collection.mutable

import org.apache.spark.sql.{sources, Strategy}
-import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, AttributeSet, Expression}
+import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, AttributeSet, Expression, NamedExpression, PythonUDF}
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LogicalPlan, Repartition}
import org.apache.spark.sql.execution.{FilterExec, ProjectExec, SparkPlan}
@@ -104,6 +104,9 @@ object DataSourceV2Strategy extends Strategy {
    }
  }

+  private def hasScalarPythonUDF(e: Expression): Boolean = {
+    e.find(PythonUDF.isScalarPythonUDF).isDefined
+  }

  override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
    case PhysicalOperation(project, filters, relation: DataSourceV2Relation) =>
@@ -130,10 +133,22 @@ object DataSourceV2Strategy extends Strategy {
        config)

      val filterCondition = postScanFilters.reduceLeftOption(And)
-      val withFilter = filterCondition.map(FilterExec(_, scan)).getOrElse(scan)
+
+      val withFilter = if (filterCondition.exists(hasScalarPythonUDF)) {
+        // add a projection before FilterExec to ensure that the rows are converted to unsafe
+        val filterExpr = filterCondition.get
+        FilterExec(filterExpr, ProjectExec(filterExpr.references.toSeq, scan))
+      } else {
+        filterCondition.map(FilterExec(_, scan)).getOrElse(scan)
+      }

-      // always add the projection, which will produce unsafe rows required by some operators
-      ProjectExec(project, withFilter) :: Nil
+      if (project.exists(hasScalarPythonUDF)) {
+        val references = project.map(_.references).reduce(_ ++ _).toSeq
+        ProjectExec(project, ProjectExec(references, withFilter)) :: Nil
Member: Why do we need to add an extra Project on top of Filter here?

Contributor (Author): The v2 data sources return InternalRow, not UnsafeRow, so the projection converts the rows to unsafe rows before the Python UDF runs.

Member: Oh, I see. It is also used to make sure the PythonUDF in the top Project takes unsafe row input.

Member: nit: If we already add a Project on top of Filter, we don't need to add another Project here, right?

Contributor (Author): That one was only added if there was a filter and if that filter ran a UDF. This will add an unnecessary project if both the filter and the project have Python UDFs, but I thought that was probably okay. I can add a boolean to signal whether the filter already caused one to be added, if you think it's worth it.

Member: OK. Let's leave it as it is now.

Member: +1 for leaving as is.
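To see the plan shape being discussed, one rough way is to run a query with Python UDFs in both the filter and the projection and print its physical plan (a hypothetical illustration; it assumes `spark` is an active session, and the exact operator names in the output depend on the build):

    from pyspark.sql.functions import udf

    # Hypothetical illustration: Python UDFs in both the filter and the projection
    # over a v2 source exercise both branches added in this change.
    df = spark.read.format("org.apache.spark.sql.sources.v2.SimpleDataSourceV2").load()
    keep = udf(lambda x: x % 2 == 0, 'boolean')
    double = udf(lambda x: x * 2, 'int')
    query = df.filter(keep(df['i'])).select(double(df['i']).alias('x'))

    # The printed plan should include an extra projection over the scan so that the
    # Python UDFs receive unsafe rows.
    query.explain()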
+      } else {
+        ProjectExec(project, withFilter) :: Nil
+      }

    case r: StreamingDataSourceV2Relation =>
      // TODO: support operator pushdown for streaming data sources.
Member: Not a big deal, but I would avoid the SPARK_25213 postfix at the end, just for consistency.

Contributor (Author): I like that the tests in Scala include this information somewhere. Is there a better place for it in PySpark? I'm not aware of another way to pass extra metadata, but I'm open to it if there's a better way.
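One hypothetical alternative, not settled in this thread, is to keep the JIRA id in the test's docstring rather than in its name:

    # Hypothetical alternative: record the JIRA id in a docstring instead of the test name.
    def test_pyspark_udf(self):
        """Regression test for SPARK-25213: add project to v2 scans before python filters."""
        from pyspark.sql.functions import udf

        df = self.spark.read.format("org.apache.spark.sql.sources.v2.SimpleDataSourceV2").load()
        result = df.withColumn('x', udf(lambda x: x, 'int')(df['i']))
        self.assertEqual([r['x'] for r in result.collect()], list(range(10)))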