[SPARK-24624][SQL][PYTHON] Support mixture of Python UDF and Scalar Pandas UDF #21650
Changes from 12 commits
@@ -4763,17 +4763,6 @@ def test_vectorized_udf_invalid_length(self):
                     'Result vector from pandas_udf was not the required length'):
                 df.select(raise_exception(col('id'))).collect()
 
-    def test_vectorized_udf_mix_udf(self):
-        from pyspark.sql.functions import pandas_udf, udf, col
-        df = self.spark.range(10)
-        row_by_row_udf = udf(lambda x: x, LongType())
-        pd_udf = pandas_udf(lambda x: x, LongType())
-        with QuietTest(self.sc):
-            with self.assertRaisesRegexp(
-                    Exception,
-                    'Can not mix vectorized and non-vectorized UDFs'):
-                df.select(row_by_row_udf(col('id')), pd_udf(col('id'))).collect()
-
     def test_vectorized_udf_chained(self):
         from pyspark.sql.functions import pandas_udf, col
         df = self.spark.range(10)
@@ -5060,6 +5049,166 @@ def test_type_annotation(self):
         df = self.spark.range(1).select(pandas_udf(f=_locals['noop'], returnType='bigint')('id'))
         self.assertEqual(df.first()[0], 0)
 
+    def test_mixed_udf(self):
+        import pandas as pd
+        from pyspark.sql.functions import col, udf, pandas_udf
+
+        df = self.spark.range(0, 1).toDF('v')
+
+        # Test mixture of multiple UDFs and Pandas UDFs.
+
+        @udf('int')
+        def f1(x):
+            assert type(x) == int
+            return x + 1
+
+        @pandas_udf('int')
+        def f2(x):
+            assert type(x) == pd.Series
+            return x + 10
+
+        @udf('int')
+        def f3(x):
+            assert type(x) == int
+            return x + 100
+
+        @pandas_udf('int')
+        def f4(x):
+            assert type(x) == pd.Series
+            return x + 1000
+
+        # Test single expression with chained UDFs
+        df_chained_1 = df.withColumn('f2_f1', f2(f1(df['v'])))
+        df_chained_2 = df.withColumn('f3_f2_f1', f3(f2(f1(df['v']))))
+        df_chained_3 = df.withColumn('f4_f3_f2_f1', f4(f3(f2(f1(df['v'])))))
+        df_chained_4 = df.withColumn('f4_f2_f1', f4(f2(f1(df['v']))))
+        df_chained_5 = df.withColumn('f4_f3_f1', f4(f3(f1(df['v']))))
+
+        expected_chained_1 = df.withColumn('f2_f1', df['v'] + 11)
+        expected_chained_2 = df.withColumn('f3_f2_f1', df['v'] + 111)
+        expected_chained_3 = df.withColumn('f4_f3_f2_f1', df['v'] + 1111)
+        expected_chained_4 = df.withColumn('f4_f2_f1', df['v'] + 1011)
+        expected_chained_5 = df.withColumn('f4_f3_f1', df['v'] + 1101)
+
+        self.assertEquals(expected_chained_1.collect(), df_chained_1.collect())
+        self.assertEquals(expected_chained_2.collect(), df_chained_2.collect())
+        self.assertEquals(expected_chained_3.collect(), df_chained_3.collect())
+        self.assertEquals(expected_chained_4.collect(), df_chained_4.collect())
+        self.assertEquals(expected_chained_5.collect(), df_chained_5.collect())
+
+        # Test multiple mixed UDF expressions in a single projection
+        df_multi_1 = df \
+            .withColumn('f1', f1(col('v'))) \
+            .withColumn('f2', f2(col('v'))) \
+            .withColumn('f3', f3(col('v'))) \
+            .withColumn('f4', f4(col('v'))) \
+            .withColumn('f2_f1', f2(col('f1'))) \
+            .withColumn('f3_f1', f3(col('f1'))) \
+            .withColumn('f4_f1', f4(col('f1'))) \
+            .withColumn('f3_f2', f3(col('f2'))) \
+            .withColumn('f4_f2', f4(col('f2'))) \
+            .withColumn('f4_f3', f4(col('f3'))) \
+            .withColumn('f3_f2_f1', f3(col('f2_f1'))) \
+            .withColumn('f4_f2_f1', f4(col('f2_f1'))) \
+            .withColumn('f4_f3_f1', f4(col('f3_f1'))) \
+            .withColumn('f4_f3_f2', f4(col('f3_f2'))) \
+            .withColumn('f4_f3_f2_f1', f4(col('f3_f2_f1')))
+
+        # Test mixed udfs in a single expression
+        df_multi_2 = df \
+            .withColumn('f1', f1(col('v'))) \
+            .withColumn('f2', f2(col('v'))) \
+            .withColumn('f3', f3(col('v'))) \
+            .withColumn('f4', f4(col('v'))) \
+            .withColumn('f2_f1', f2(f1(col('v')))) \
+            .withColumn('f3_f1', f3(f1(col('v')))) \
+            .withColumn('f4_f1', f4(f1(col('v')))) \
+            .withColumn('f3_f2', f3(f2(col('v')))) \
+            .withColumn('f4_f2', f4(f2(col('v')))) \
+            .withColumn('f4_f3', f4(f3(col('v')))) \
+            .withColumn('f3_f2_f1', f3(f2(f1(col('v'))))) \
+            .withColumn('f4_f2_f1', f4(f2(f1(col('v'))))) \
+            .withColumn('f4_f3_f1', f4(f3(f1(col('v'))))) \
+            .withColumn('f4_f3_f2', f4(f3(f2(col('v'))))) \
+            .withColumn('f4_f3_f2_f1', f4(f3(f2(f1(col('v'))))))
+
+        expected = df \
+            .withColumn('f1', df['v'] + 1) \
+            .withColumn('f2', df['v'] + 10) \
+            .withColumn('f3', df['v'] + 100) \
+            .withColumn('f4', df['v'] + 1000) \
+            .withColumn('f2_f1', df['v'] + 11) \
+            .withColumn('f3_f1', df['v'] + 101) \
+            .withColumn('f4_f1', df['v'] + 1001) \
+            .withColumn('f3_f2', df['v'] + 110) \
+            .withColumn('f4_f2', df['v'] + 1010) \
+            .withColumn('f4_f3', df['v'] + 1100) \
+            .withColumn('f3_f2_f1', df['v'] + 111) \
+            .withColumn('f4_f2_f1', df['v'] + 1011) \
+            .withColumn('f4_f3_f1', df['v'] + 1101) \
+            .withColumn('f4_f3_f2', df['v'] + 1110) \
+            .withColumn('f4_f3_f2_f1', df['v'] + 1111)
+
+        self.assertEquals(expected.collect(), df_multi_1.collect())
+        self.assertEquals(expected.collect(), df_multi_2.collect())
+
+    def test_mixed_udf_and_sql(self):
+        import pandas as pd
+        from pyspark.sql import Column
+        from pyspark.sql.functions import udf, pandas_udf
+
+        df = self.spark.range(0, 1).toDF('v')
+
+        # Test mixture of UDFs, Pandas UDFs and SQL expression.
+
+        @udf('int')
+        def f1(x):
+            assert type(x) == int
+            return x + 1
+
+        def f2(x):
+            assert type(x) == Column
+            return x + 10
+
+        @pandas_udf('int')
+        def f3(x):
+            assert type(x) == pd.Series
+            return x + 100
+
+        df1 = df.withColumn('f1', f1(df['v'])) \
+            .withColumn('f2', f2(df['v'])) \
+            .withColumn('f3', f3(df['v'])) \
+            .withColumn('f1_f2', f1(f2(df['v']))) \
+            .withColumn('f1_f3', f1(f3(df['v']))) \
+            .withColumn('f2_f1', f2(f1(df['v']))) \
+            .withColumn('f2_f3', f2(f3(df['v']))) \
+            .withColumn('f3_f1', f3(f1(df['v']))) \

Member
Looks like the combinations of f1 and f3 duplicate a few of the tests above.

Contributor (Author)
Yeah, the way the test is written, I am trying to cover many combinations, so there are some duplicated cases. Do you prefer that I remove these?

Member
Yeah... I know it's minor since the elapsed time will be virtually the same, but build/test time has recently been an issue, and I wonder whether there is a better way than avoiding duplicated tests for now.

Member
It was discussed here: #21845

Contributor (Author)
I see. I don't think it's necessary (we would only remove a few cases and, as you said, the test time is virtually the same), and keeping them helps the readability of the tests (so it doesn't look like some cases were missed). But if that's the preferred practice, I can remove the duplicated cases in the next commit.

Member
I am okay with leaving it here since it's clear they are virtually the same, but let's remove duplicated tests, or keep tests orthogonal, next time.

Contributor (Author)
Gotcha. I will keep that in mind next time.

+            .withColumn('f3_f2', f3(f2(df['v']))) \
+            .withColumn('f1_f2_f3', f1(f2(f3(df['v'])))) \
+            .withColumn('f1_f3_f2', f1(f3(f2(df['v'])))) \
+            .withColumn('f2_f1_f3', f2(f1(f3(df['v'])))) \
+            .withColumn('f2_f3_f1', f2(f3(f1(df['v'])))) \
+            .withColumn('f3_f1_f2', f3(f1(f2(df['v'])))) \
+            .withColumn('f3_f2_f1', f3(f2(f1(df['v']))))
+
+        expected = df.withColumn('f1', df['v'] + 1) \
+            .withColumn('f2', df['v'] + 10) \
+            .withColumn('f3', df['v'] + 100) \
+            .withColumn('f1_f2', df['v'] + 11) \
+            .withColumn('f1_f3', df['v'] + 101) \
+            .withColumn('f2_f1', df['v'] + 11) \
+            .withColumn('f2_f3', df['v'] + 110) \
+            .withColumn('f3_f1', df['v'] + 101) \
+            .withColumn('f3_f2', df['v'] + 110) \
+            .withColumn('f1_f2_f3', df['v'] + 111) \
+            .withColumn('f1_f3_f2', df['v'] + 111) \
+            .withColumn('f2_f1_f3', df['v'] + 111) \
+            .withColumn('f2_f3_f1', df['v'] + 111) \
+            .withColumn('f3_f1_f2', df['v'] + 111) \
+            .withColumn('f3_f2_f1', df['v'] + 111)
+
+        self.assertEquals(expected.collect(), df1.collect())
 
 
 @unittest.skipIf(
     not _have_pandas or not _have_pyarrow,

@@ -5487,6 +5636,21 @@ def dummy_pandas_udf(df):
                                                  F.col('temp0.key') == F.col('temp1.key'))
         self.assertEquals(res.count(), 5)
 
+    def test_mixed_scalar_udfs_followed_by_grouby_apply(self):
+        import pandas as pd

Member
Not a big deal at all really, but I would swap the import order (third-party first, then pyspark).

+        from pyspark.sql.functions import udf, pandas_udf, PandasUDFType
+
+        df = self.spark.range(0, 10).toDF('v1')
+        df = df.withColumn('v2', udf(lambda x: x + 1, 'int')(df['v1'])) \
+            .withColumn('v3', pandas_udf(lambda x: x + 2, 'int')(df['v1']))
+
+        result = df.groupby() \
+            .apply(pandas_udf(lambda x: pd.DataFrame([x.sum().sum()]),
+                              'sum int',
+                              PandasUDFType.GROUPED_MAP))
+
+        self.assertEquals(result.collect()[0]['sum'], 165)
 
 
 @unittest.skipIf(
     not _have_pandas or not _have_pyarrow,
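
The new tests above pin down the user-visible behavior this patch enables: row-at-a-time Python UDFs and Scalar Pandas UDFs can now appear in the same projection and even be chained in one expression. A minimal, self-contained sketch of that usage (not part of the patch; it assumes a local SparkSession, and the function names `plus_one` and `times_two` are illustrative):

```python
import pandas as pd

from pyspark.sql import SparkSession
from pyspark.sql.functions import col, pandas_udf, udf

spark = SparkSession.builder.master("local[1]").getOrCreate()
df = spark.range(10).toDF("v")

# Row-at-a-time Python UDF: called once per value with a plain int.
plus_one = udf(lambda x: x + 1, "long")

# Scalar Pandas UDF: called once per batch with a pandas.Series.
times_two = pandas_udf(lambda s: s * 2, "long")

# Before this change, mixing the two kinds in one query raised
# "Can not mix vectorized and non-vectorized UDFs"; both forms now work,
# side by side or chained.
df.select(
    plus_one(col("v")).alias("plus_one"),
    times_two(col("v")).alias("times_two"),
    times_two(plus_one(col("v"))).alias("chained"),
).show()
```

The planner-side change that makes this possible is in the Scala diff below.
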
@@ -21,6 +21,7 @@ import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
 
 import org.apache.spark.api.python.PythonEvalType
+import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project}
@@ -94,36 +95,60 @@ object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] {
  */
 object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper {
 
-  private def hasPythonUDF(e: Expression): Boolean = {
+  private case class EvalTypeHolder(private var evalType: Int = -1) {
+    def isSet: Boolean = evalType >= 0
+
+    def set(evalType: Int): Unit = {
+      if (isSet && evalType != this.evalType) {
+        throw new IllegalStateException("Cannot reset eval type to a different value")
+      } else {
+        this.evalType = evalType
+      }
+    }
+
+    def get(): Int = {
+      if (!isSet) {
+        throw new IllegalStateException("Eval type is not set")
+      } else {
+        evalType
+      }
+    }
+  }
+
+  private def hasScalarPythonUDF(e: Expression): Boolean = {
     e.find(PythonUDF.isScalarPythonUDF).isDefined
   }
 
   private def canEvaluateInPython(e: PythonUDF): Boolean = {
     e.children match {
       // single PythonUDF child could be chained and evaluated in Python
-      case Seq(u: PythonUDF) => canEvaluateInPython(u)
+      case Seq(u: PythonUDF) => e.evalType == u.evalType && canEvaluateInPython(u)
       // Python UDF can't be evaluated directly in JVM
-      case children => !children.exists(hasPythonUDF)
+      case children => !children.exists(hasScalarPythonUDF)
     }
   }
 
-  private def collectEvaluatableUDF(expr: Expression): Seq[PythonUDF] = expr match {
-    case udf: PythonUDF if PythonUDF.isScalarPythonUDF(udf) && canEvaluateInPython(udf) => Seq(udf)
-    case e => e.children.flatMap(collectEvaluatableUDF)
+  private def collectEvaluableUDFs(
+      expr: Expression,
+      firstEvalType: EvalTypeHolder): Seq[PythonUDF] = expr match {
+    case udf: PythonUDF if PythonUDF.isScalarPythonUDF(udf)
+      && (!firstEvalType.isSet || firstEvalType.get == udf.evalType)
+      && canEvaluateInPython(udf) =>
+      firstEvalType.set(udf.evalType)
+      Seq(udf)
+    case e => e.children.flatMap(collectEvaluableUDFs(_, firstEvalType))
   }
 
   def apply(plan: SparkPlan): SparkPlan = plan transformUp {
     // AggregateInPandasExec and FlatMapGroupsInPandas can be evaluated directly in python worker
     // Therefore we don't need to extract the UDFs
     case plan: FlatMapGroupsInPandasExec => plan
     case plan: SparkPlan => extract(plan)
   }
 
   /**
    * Extract all the PythonUDFs from the current operator and evaluate them before the operator.
    */
   private def extract(plan: SparkPlan): SparkPlan = {
-    val udfs = plan.expressions.flatMap(collectEvaluatableUDF)
+    val firstEvalType = new EvalTypeHolder
+    val udfs = plan.expressions.flatMap(collectEvaluableUDFs(_, firstEvalType))
       // ignore the PythonUDF that come from second/third aggregate, which is not used
       .filter(udf => udf.references.subsetOf(plan.inputSet))
     if (udfs.isEmpty) {
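
The heart of the change is above: `collectEvaluableUDFs` now threads an `EvalTypeHolder` through the expression walk, so a single extraction pass only collects scalar Python UDFs that share the eval type of the first evaluable UDF it encounters, and `canEvaluateInPython` only chains UDFs of the same eval type. A rough, runnable Python analogue of that collection rule, for illustration only (the `Expr` model and the eval-type constants below are made up, not Spark APIs, and the `canEvaluateInPython` check is omitted):

```python
from dataclasses import dataclass, field
from typing import List, Optional

# Stand-ins for PythonEvalType constants; the real values live in PySpark/the JVM.
SQL_BATCHED_UDF = 100
SQL_SCALAR_PANDAS_UDF = 200

@dataclass
class Expr:
    name: str
    eval_type: Optional[int] = None          # None means "not a Python UDF"
    children: List["Expr"] = field(default_factory=list)

def collect_evaluable_udfs(exprs: List[Expr]) -> List[Expr]:
    """Collect scalar Python UDFs, but only those matching the eval type of the
    first one found, so a single extraction never mixes the two kinds."""
    first_eval_type: Optional[int] = None
    collected: List[Expr] = []

    def visit(e: Expr) -> None:
        nonlocal first_eval_type
        if e.eval_type is not None and first_eval_type in (None, e.eval_type):
            first_eval_type = e.eval_type
            collected.append(e)
        else:
            # Non-matching UDFs and plain expressions are skipped, but their
            # children are still searched, mirroring the `case e =>` branch.
            for child in e.children:
                visit(child)

    for e in exprs:
        visit(e)
    return collected

# One projection containing a batched UDF and a Scalar Pandas UDF: only the
# eval type seen first (batched) is collected in this pass.
exprs = [Expr("f1", SQL_BATCHED_UDF), Expr("f2", SQL_SCALAR_PANDAS_UDF)]
print([e.name for e in collect_evaluable_udfs(exprs)])   # ['f1']
```
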
@@ -167,7 +192,8 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper {
       case (vectorizedUdfs, plainUdfs) if vectorizedUdfs.isEmpty =>
         BatchEvalPythonExec(plainUdfs, child.output ++ resultAttrs, child)
       case _ =>
-        throw new IllegalArgumentException("Can not mix vectorized and non-vectorized UDFs")
+        throw new AnalysisException(
+          "Expected either Scalar Pandas UDFs or Batched UDFs but got both")
     }
 
     attributeMap ++= validUdfs.zip(resultAttrs)
@@ -205,7 +231,7 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper {
     case filter: FilterExec =>
       val (candidates, nonDeterministic) =
         splitConjunctivePredicates(filter.condition).partition(_.deterministic)
-      val (pushDown, rest) = candidates.partition(!hasPythonUDF(_))
+      val (pushDown, rest) = candidates.partition(!hasScalarPythonUDF(_))
       if (pushDown.nonEmpty) {
         val newChild = FilterExec(pushDown.reduceLeft(And), filter.child)
         FilterExec((rest ++ nonDeterministic).reduceLeft(And), newChild)

Member
This looks like it is testing udf + udf.

Contributor (Author)
Yeah, the way the test is written, I am trying to test many combinations, so some combinations might not be mixed UDFs. Do you prefer that I remove these cases?