From 27fd57d4deccbf409ae11cf02e32f44eed683b51 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Tue, 14 Aug 2018 14:22:50 +0000 Subject: [PATCH 01/14] Fix file strategy to exclude python UDF filters --- python/pyspark/sql/tests.py | 18 ++++++++++++++++++ .../datasources/FileSourceStrategy.scala | 6 +++++- 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 00d7e18320a5..3ff1a73c0310 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3367,6 +3367,24 @@ def test_ignore_column_of_all_nulls(self): finally: shutil.rmtree(path) + def test_datasource_with_udf_filter_lit_input(self): + # SPARK-24721 + path = tempfile.mkdtemp() + shutil.rmtree(path) + try: + from pyspark.sql.functions import udf, lit, col + + self.spark.range(1).write.mode("overwrite").format('csv').save(path) + df = self.spark.read.csv(path) + # Test that filter with lit inputs works with data source + result1 = df.filter(udf(lambda x: False, 'boolean')(lit(1))) + result2 = df.filter(udf(lambda : False, 'boolean')()) + + self.assertEquals(0, result1.count()) + self.assertEquals(0, result2.count()) + finally: + shutil.rmtree(path) + def test_repr_behaviors(self): import re pattern = re.compile(r'^ *\|', re.MULTILINE) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala index fe27b78bf336..f7c4130d9046 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala @@ -145,12 +145,16 @@ object FileSourceStrategy extends Strategy with Logging { // - bucket keys only - optionally used to prune files to read // - keys stored in the data only - optionally used to skip groups of data in files // - filters that need to be evaluated again after the scan + val filterSet = ExpressionSet(filters) + // SPARK-24721: Filter out Python UDFs, otherwise ExtractPythonUDF rule will throw exception + val validFilters = filters.filter(_.collectFirst{case e: PythonUDF => e}.isEmpty) + // The attribute name of predicate could be different than the one in schema in case of // case insensitive, we should change them to match the one in schema, so we do not need to // worry about case sensitivity anymore. 
- val normalizedFilters = filters.map { e => + val normalizedFilters = validFilters.map { e => e transform { case a: AttributeReference => a.withName(l.output.find(_.semanticEquals(a)).get.name) From 499acde2df4be5a2af11336ebb6ff4fa7074ec18 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Tue, 14 Aug 2018 15:08:02 +0000 Subject: [PATCH 02/14] Remove white space --- .../spark/sql/execution/datasources/FileSourceStrategy.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala index f7c4130d9046..2848675635ca 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala @@ -145,7 +145,6 @@ object FileSourceStrategy extends Strategy with Logging { // - bucket keys only - optionally used to prune files to read // - keys stored in the data only - optionally used to skip groups of data in files // - filters that need to be evaluated again after the scan - val filterSet = ExpressionSet(filters) // SPARK-24721: Filter out Python UDFs, otherwise ExtractPythonUDF rule will throw exception From 595399eca621d410ae3b0f0de83f94934aecdedb Mon Sep 17 00:00:00 2001 From: Li Jin Date: Tue, 14 Aug 2018 15:09:14 +0000 Subject: [PATCH 03/14] Fix style --- .../spark/sql/execution/datasources/FileSourceStrategy.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala index 2848675635ca..d082f5f6c65b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala @@ -148,7 +148,7 @@ object FileSourceStrategy extends Strategy with Logging { val filterSet = ExpressionSet(filters) // SPARK-24721: Filter out Python UDFs, otherwise ExtractPythonUDF rule will throw exception - val validFilters = filters.filter(_.collectFirst{case e: PythonUDF => e}.isEmpty) + val validFilters = filters.filter(_.collectFirst{ case e: PythonUDF => e }.isEmpty) // The attribute name of predicate could be different than the one in schema in case of // case insensitive, we should change them to match the one in schema, so we do not need to From 470d63af306a8f4eb25859a0394e62a0d6cb9b41 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Tue, 14 Aug 2018 15:12:04 +0000 Subject: [PATCH 04/14] Fix python style --- python/pyspark/sql/tests.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 3ff1a73c0310..a7e181e20b1b 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3378,7 +3378,7 @@ def test_datasource_with_udf_filter_lit_input(self): df = self.spark.read.csv(path) # Test that filter with lit inputs works with data source result1 = df.filter(udf(lambda x: False, 'boolean')(lit(1))) - result2 = df.filter(udf(lambda : False, 'boolean')()) + result2 = df.filter(udf(lambda: False, 'boolean')()) self.assertEquals(0, result1.count()) self.assertEquals(0, result2.count()) From cfd568e2fe429c7959264a759bc2fd0b34b03eea Mon Sep 17 00:00:00 2001 From: Li Jin Date: Wed, 15 Aug 2018 19:56:35 +0000 Subject: 
[PATCH 05/14] Add test for data source and data source v2; Fix projection in EvalPythonExec; Move logic to ExtractPythonUDFs --- python/pyspark/sql/tests.py | 29 +++++++++++++------ .../datasources/FileSourceStrategy.scala | 5 +--- .../sql/execution/python/EvalPythonExec.scala | 7 +++-- .../execution/python/ExtractPythonUDFs.scala | 6 +++- 4 files changed, 30 insertions(+), 17 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index a7e181e20b1b..d866a18df420 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3367,21 +3367,32 @@ def test_ignore_column_of_all_nulls(self): finally: shutil.rmtree(path) + # SPARK-24721 def test_datasource_with_udf_filter_lit_input(self): - # SPARK-24721 + import pandas as pd + import numpy as np + from pyspark.sql.functions import udf, pandas_udf, lit, col + path = tempfile.mkdtemp() shutil.rmtree(path) try: - from pyspark.sql.functions import udf, lit, col - self.spark.range(1).write.mode("overwrite").format('csv').save(path) - df = self.spark.read.csv(path) - # Test that filter with lit inputs works with data source - result1 = df.filter(udf(lambda x: False, 'boolean')(lit(1))) - result2 = df.filter(udf(lambda: False, 'boolean')()) + filesource_df = self.spark.read.csv(path) + datasource_df = self.spark.read.format("org.apache.spark.sql.sources.SimpleScanSource") \ + .option('from', 0).option('to', 1).load() + datasource_v2_df = self.spark.read.format("org.apache.spark.sql.sources.v2.SimpleDataSourceV2") \ + .load() + + filter1 = udf(lambda: False, 'boolean')() + filter2 = udf(lambda x: False, 'boolean')(lit(1)) + filter3 = pandas_udf(lambda x: pd.Series(np.repeat(False, len(x))), 'boolean')(lit(1)) + + for df in [filesource_df, datasource_df, datasource_v2_df]: + for f in [filter1, filter2, filter3]: + result = df.filter(f) + result.explain(True) + self.assertEquals(0, result.count()) - self.assertEquals(0, result1.count()) - self.assertEquals(0, result2.count()) finally: shutil.rmtree(path) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala index d082f5f6c65b..fe27b78bf336 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala @@ -147,13 +147,10 @@ object FileSourceStrategy extends Strategy with Logging { // - filters that need to be evaluated again after the scan val filterSet = ExpressionSet(filters) - // SPARK-24721: Filter out Python UDFs, otherwise ExtractPythonUDF rule will throw exception - val validFilters = filters.filter(_.collectFirst{ case e: PythonUDF => e }.isEmpty) - // The attribute name of predicate could be different than the one in schema in case of // case insensitive, we should change them to match the one in schema, so we do not need to // worry about case sensitivity anymore. 
- val normalizedFilters = validFilters.map { e => + val normalizedFilters = filters.map { e => e transform { case a: AttributeReference => a.withName(l.output.find(_.semanticEquals(a)).get.name) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExec.scala index 04c7dfdd4e20..d8c60bb6f85e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExec.scala @@ -117,15 +117,16 @@ abstract class EvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chil } }.toArray }.toArray - val projection = newMutableProjection(allInputs, child.output) + val projection = UnsafeProjection.create(allInputs, child.output) val schema = StructType(dataTypes.zipWithIndex.map { case (dt, i) => StructField(s"_$i", dt) }) // Add rows to queue to join later with the result. val projectedRowIter = iter.map { inputRow => - queue.add(inputRow.asInstanceOf[UnsafeRow]) - projection(inputRow) + val unsafeRow = projection(inputRow) + queue.add(unsafeRow.asInstanceOf[UnsafeRow]) + unsafeRow } val outputRowIterator = evaluate( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala index cb75874be32e..25db06f58597 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala @@ -26,7 +26,8 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project} import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.execution.{FilterExec, ProjectExec, SparkPlan} +import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanExec /** @@ -133,6 +134,9 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper { } def apply(plan: SparkPlan): SparkPlan = plan transformUp { + // SPARK-24721: Ignore Python UDFs in DataSourceScan and DataSourceV2Scan + case plan: DataSourceScanExec => plan + case plan: DataSourceV2ScanExec => plan case plan: SparkPlan => extract(plan) } From 24eebdcf5944531fab273713b3526b0d0f809007 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Wed, 15 Aug 2018 20:07:02 +0000 Subject: [PATCH 06/14] Fix style --- python/pyspark/sql/tests.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index d866a18df420..02be182336ec 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3378,9 +3378,11 @@ def test_datasource_with_udf_filter_lit_input(self): try: self.spark.range(1).write.mode("overwrite").format('csv').save(path) filesource_df = self.spark.read.csv(path) - datasource_df = self.spark.read.format("org.apache.spark.sql.sources.SimpleScanSource") \ + datasource_df = self.spark.read \ + .format("org.apache.spark.sql.sources.SimpleScanSource") \ .option('from', 0).option('to', 1).load() - datasource_v2_df = self.spark.read.format("org.apache.spark.sql.sources.v2.SimpleDataSourceV2") \ + datasource_v2_df = self.spark.read \ + .format("org.apache.spark.sql.sources.v2.SimpleDataSourceV2") \ .load() 
filter1 = udf(lambda: False, 'boolean')() From 96d257010bce377da68030fb6a0f521152c3f4f3 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Wed, 15 Aug 2018 21:22:07 +0000 Subject: [PATCH 07/14] Split tests; Fix EvalPythonExec --- python/pyspark/sql/tests.py | 35 +++++++++++++++---- .../sql/execution/python/EvalPythonExec.scala | 10 +++--- 2 files changed, 35 insertions(+), 10 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 02be182336ec..3a69d4e8602d 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3369,9 +3369,7 @@ def test_ignore_column_of_all_nulls(self): # SPARK-24721 def test_datasource_with_udf_filter_lit_input(self): - import pandas as pd - import numpy as np - from pyspark.sql.functions import udf, pandas_udf, lit, col + from pyspark.sql.functions import udf, lit, col path = tempfile.mkdtemp() shutil.rmtree(path) @@ -3387,12 +3385,10 @@ def test_datasource_with_udf_filter_lit_input(self): filter1 = udf(lambda: False, 'boolean')() filter2 = udf(lambda x: False, 'boolean')(lit(1)) - filter3 = pandas_udf(lambda x: pd.Series(np.repeat(False, len(x))), 'boolean')(lit(1)) for df in [filesource_df, datasource_df, datasource_v2_df]: - for f in [filter1, filter2, filter3]: + for f in [filter1, filter2]: result = df.filter(f) - result.explain(True) self.assertEquals(0, result.count()) finally: @@ -5300,6 +5296,33 @@ def f3(x): self.assertEquals(expected.collect(), df1.collect()) + def test_datasource_with_udf_filter_lit_input(self): + # Same as SQLTests.test_datasource_with_udf_filter_lit_input, but with Pandas UDF + # This needs to a separate test because Arrow dependency is optional + import pandas as pd + import numpy as np + from pyspark.sql.functions import pandas_udf, lit, col + + path = tempfile.mkdtemp() + shutil.rmtree(path) + try: + self.spark.range(1).write.mode("overwrite").format('csv').save(path) + filesource_df = self.spark.read.csv(path) + datasource_df = self.spark.read \ + .format("org.apache.spark.sql.sources.SimpleScanSource") \ + .option('from', 0).option('to', 1).load() + datasource_v2_df = self.spark.read \ + .format("org.apache.spark.sql.sources.v2.SimpleDataSourceV2") \ + .load() + + f = pandas_udf(lambda x: pd.Series(np.repeat(False, len(x))), 'boolean')(lit(1)) + + for df in [filesource_df, datasource_df, datasource_v2_df]: + result = df.filter(f) + self.assertEquals(0, result.count()) + + finally: + shutil.rmtree(path) @unittest.skipIf( not _have_pandas or not _have_pyarrow, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExec.scala index d8c60bb6f85e..c6005f1568d4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExec.scala @@ -117,16 +117,18 @@ abstract class EvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chil } }.toArray }.toArray - val projection = UnsafeProjection.create(allInputs, child.output) + + // Project input rows to unsafe row so we can put it in the row queue + val unsafeProjection = UnsafeProjection.create(child.output, child.output) + val prunedProjection = newMutableProjection(allInputs, child.output) val schema = StructType(dataTypes.zipWithIndex.map { case (dt, i) => StructField(s"_$i", dt) }) // Add rows to queue to join later with the result. 
val projectedRowIter = iter.map { inputRow => - val unsafeRow = projection(inputRow) - queue.add(unsafeRow.asInstanceOf[UnsafeRow]) - unsafeRow + queue.add(unsafeProjection(inputRow)) + prunedProjection(inputRow) } val outputRowIterator = evaluate( From 195ba5b48ec7a4c8906253ac1b8ecdaff9e65f9e Mon Sep 17 00:00:00 2001 From: Li Jin Date: Thu, 16 Aug 2018 13:41:30 +0000 Subject: [PATCH 08/14] Fix python style --- python/pyspark/sql/tests.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 3a69d4e8602d..2211bbd76e72 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -5324,6 +5324,7 @@ def test_datasource_with_udf_filter_lit_input(self): finally: shutil.rmtree(path) + @unittest.skipIf( not _have_pandas or not _have_pyarrow, _pandas_requirement_message or _pyarrow_requirement_message) From f64854e266afd096613db291ae725bcddd1b5659 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Thu, 23 Aug 2018 13:40:44 +0000 Subject: [PATCH 09/14] wip --- python/pyspark/sql/tests.py | 4 ++++ .../spark/sql/execution/python/EvalPythonExec.scala | 9 +++------ .../apache/spark/sql/sources/v2/DataSourceV2Suite.scala | 1 + 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 2211bbd76e72..9d5cb7f692c3 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3383,6 +3383,10 @@ def test_datasource_with_udf_filter_lit_input(self): .format("org.apache.spark.sql.sources.v2.SimpleDataSourceV2") \ .load() + datasource_df.show() + datasource_v2_df.show() + + filter1 = udf(lambda: False, 'boolean')() filter2 = udf(lambda x: False, 'boolean')(lit(1)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExec.scala index c6005f1568d4..04c7dfdd4e20 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExec.scala @@ -117,18 +117,15 @@ abstract class EvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chil } }.toArray }.toArray - - // Project input rows to unsafe row so we can put it in the row queue - val unsafeProjection = UnsafeProjection.create(child.output, child.output) - val prunedProjection = newMutableProjection(allInputs, child.output) + val projection = newMutableProjection(allInputs, child.output) val schema = StructType(dataTypes.zipWithIndex.map { case (dt, i) => StructField(s"_$i", dt) }) // Add rows to queue to join later with the result. 
val projectedRowIter = iter.map { inputRow => - queue.add(unsafeProjection(inputRow)) - prunedProjection(inputRow) + queue.add(inputRow.asInstanceOf[UnsafeRow]) + projection(inputRow) } val outputRowIterator = evaluate( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala index 5edeff553eb1..d2e1434e3d7f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala @@ -55,6 +55,7 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { Seq(classOf[SimpleDataSourceV2], classOf[JavaSimpleDataSourceV2]).foreach { cls => withClue(cls.getName) { val df = spark.read.format(cls.getName).load() + df.explain() checkAnswer(df, (0 until 10).map(i => Row(i, -i))) checkAnswer(df.select('j), (0 until 10).map(i => Row(-i))) checkAnswer(df.filter('i > 5), (6 until 10).map(i => Row(i, -i))) From 4e63704c3c4f94b244d56c8774c1b22acc7a1d06 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Thu, 23 Aug 2018 15:45:32 +0000 Subject: [PATCH 10/14] Disable data source v2 test for now --- python/pyspark/sql/tests.py | 22 +++++++++---------- .../spark/sql/sources/TableScanSuite.scala | 2 ++ .../sql/sources/v2/DataSourceV2Suite.scala | 3 ++- 3 files changed, 14 insertions(+), 13 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 9d5cb7f692c3..6336a3087917 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3379,18 +3379,15 @@ def test_datasource_with_udf_filter_lit_input(self): datasource_df = self.spark.read \ .format("org.apache.spark.sql.sources.SimpleScanSource") \ .option('from', 0).option('to', 1).load() - datasource_v2_df = self.spark.read \ - .format("org.apache.spark.sql.sources.v2.SimpleDataSourceV2") \ - .load() - - datasource_df.show() - datasource_v2_df.show() - + # TODO: Enable data source v2 after SPARK-25213 is fixed + # datasource_v2_df = self.spark.read \ + # .format("org.apache.spark.sql.sources.v2.SimpleDataSourceV2") \ + # .load() filter1 = udf(lambda: False, 'boolean')() filter2 = udf(lambda x: False, 'boolean')(lit(1)) - for df in [filesource_df, datasource_df, datasource_v2_df]: + for df in [filesource_df, datasource_df]: for f in [filter1, filter2]: result = df.filter(f) self.assertEquals(0, result.count()) @@ -5315,13 +5312,14 @@ def test_datasource_with_udf_filter_lit_input(self): datasource_df = self.spark.read \ .format("org.apache.spark.sql.sources.SimpleScanSource") \ .option('from', 0).option('to', 1).load() - datasource_v2_df = self.spark.read \ - .format("org.apache.spark.sql.sources.v2.SimpleDataSourceV2") \ - .load() + # TODO: Enable data source v2 after SPARK-25213 is fixed + # datasource_v2_df = self.spark.read \ + # .format("org.apache.spark.sql.sources.v2.SimpleDataSourceV2") \ + # .load() f = pandas_udf(lambda x: pd.Series(np.repeat(False, len(x))), 'boolean')(lit(1)) - for df in [filesource_df, datasource_df, datasource_v2_df]: + for df in [filesource_df, datasource_df]: result = df.filter(f) self.assertEquals(0, result.count()) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala index 17690e3df915..13a126ff963d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala @@ -28,6 +28,8 @@ import org.apache.spark.sql.types._ class DefaultSource extends SimpleScanSource +// This class is used by pyspark tests. If this class is modified/moved, make sure pyspark +// tests still pass. class SimpleScanSource extends RelationProvider { override def createRelation( sqlContext: SQLContext, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala index d2e1434e3d7f..dd0a31dd4c34 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala @@ -371,7 +371,8 @@ class SimpleSinglePartitionSource extends DataSourceV2 with BatchReadSupportProv } } - +// This class is used by pyspark tests. If this class is modified/moved, make sure pyspark +// tests still pass. class SimpleDataSourceV2 extends DataSourceV2 with BatchReadSupportProvider { class ReadSupport extends SimpleReadSupport { From ca0195e4b3984116bbd3f519ee4b1116567c8607 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Thu, 23 Aug 2018 17:36:47 +0000 Subject: [PATCH 11/14] Add require_test_compiled --- python/pyspark/sql/tests.py | 11 +++++++++++ python/pyspark/sql/utils.py | 16 ++++++++++++++++ 2 files changed, 27 insertions(+) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 6336a3087917..3047efbd6613 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -68,8 +68,16 @@ # If Arrow version requirement is not satisfied, skip related tests. _pyarrow_requirement_message = _exception_message(e) +_test_not_compiled_message = None +try: + from pyspark.sql.utils import require_test_compiled + require_test_compiled() +except Exception as e: + _test_not_compiled_message = _exception_message(e) + _have_pandas = _pandas_requirement_message is None _have_pyarrow = _pyarrow_requirement_message is None +_test_compiled = _test_not_compiled_message is None from pyspark import SparkContext from pyspark.sql import SparkSession, SQLContext, HiveContext, Column, Row @@ -3368,6 +3376,7 @@ def test_ignore_column_of_all_nulls(self): shutil.rmtree(path) # SPARK-24721 + @unittest.skipIf(not _test_compiled, _test_not_compiled_message) def test_datasource_with_udf_filter_lit_input(self): from pyspark.sql.functions import udf, lit, col @@ -5297,6 +5306,8 @@ def f3(x): self.assertEquals(expected.collect(), df1.collect()) + # SPARK-24721 + @unittest.skipIf(not _test_compiled, _test_not_compiled_message) def test_datasource_with_udf_filter_lit_input(self): # Same as SQLTests.test_datasource_with_udf_filter_lit_input, but with Pandas UDF # This needs to a separate test because Arrow dependency is optional diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py index bb9ce02c4b60..46d4b1f2168d 100644 --- a/python/pyspark/sql/utils.py +++ b/python/pyspark/sql/utils.py @@ -152,6 +152,22 @@ def require_minimum_pyarrow_version(): "your version was %s." 
% (minimum_pyarrow_version, pyarrow.__version__)) +def require_test_compiled(): + """ Raise Exception if test classes are not compiled + """ + import os + try: + spark_home = os.environ['SPARK_HOME'] + except KeyError: + raise RuntimeError('SPARK_HOME is not defined in environment') + + test_class_path = os.path.join( + spark_home, 'sql', 'core', 'target', 'scala-2.11', 'test-classes') + if not os.path.isdir(test_class_path): + raise RuntimeError( + "%s doesn't exist. Spark sql test classes are not compiled." % test_class_path) + + class ForeachBatchFunction(object): """ This is the Python implementation of Java interface 'ForeachBatchFunction'. This wraps From 7076820b3f154bd270daf755959da7bba71b1f59 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Thu, 23 Aug 2018 17:39:49 +0000 Subject: [PATCH 12/14] Revert small changes --- .../org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala index dd0a31dd4c34..f6c3e0ce82e3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala @@ -55,7 +55,6 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { Seq(classOf[SimpleDataSourceV2], classOf[JavaSimpleDataSourceV2]).foreach { cls => withClue(cls.getName) { val df = spark.read.format(cls.getName).load() - df.explain() checkAnswer(df, (0 until 10).map(i => Row(i, -i))) checkAnswer(df.select('j), (0 until 10).map(i => Row(-i))) checkAnswer(df.filter('i > 5), (6 until 10).map(i => Row(i, -i))) From d440cbffee135da42a54da95388b01cf17ab16df Mon Sep 17 00:00:00 2001 From: Li Jin Date: Mon, 27 Aug 2018 20:29:31 +0000 Subject: [PATCH 13/14] Move ExtractPythonUDFs to end of optimize stage --- python/pyspark/sql/tests.py | 20 ++++++------ .../spark/sql/execution/QueryExecution.scala | 1 - .../spark/sql/execution/SparkOptimizer.scala | 5 +-- .../spark/sql/execution/SparkPlanner.scala | 1 + .../spark/sql/execution/SparkStrategies.scala | 15 +++++++++ .../python/ArrowEvalPythonExec.scala | 9 +++++- .../python/BatchEvalPythonExec.scala | 7 +++++ .../execution/python/ExtractPythonUDFs.scala | 31 ++++++++----------- 8 files changed, 56 insertions(+), 33 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 3047efbd6613..57ca8ffd16ca 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3388,15 +3388,14 @@ def test_datasource_with_udf_filter_lit_input(self): datasource_df = self.spark.read \ .format("org.apache.spark.sql.sources.SimpleScanSource") \ .option('from', 0).option('to', 1).load() - # TODO: Enable data source v2 after SPARK-25213 is fixed - # datasource_v2_df = self.spark.read \ - # .format("org.apache.spark.sql.sources.v2.SimpleDataSourceV2") \ - # .load() + datasource_v2_df = self.spark.read \ + .format("org.apache.spark.sql.sources.v2.SimpleDataSourceV2") \ + .load() filter1 = udf(lambda: False, 'boolean')() filter2 = udf(lambda x: False, 'boolean')(lit(1)) - for df in [filesource_df, datasource_df]: + for df in [filesource_df, datasource_df, datasource_v2_df]: for f in [filter1, filter2]: result = df.filter(f) self.assertEquals(0, result.count()) @@ -5309,7 +5308,7 @@ def f3(x): # SPARK-24721 @unittest.skipIf(not _test_compiled, _test_not_compiled_message) def 
test_datasource_with_udf_filter_lit_input(self): - # Same as SQLTests.test_datasource_with_udf_filter_lit_input, but with Pandas UDF + # Same as SQLTests.test_datasource_with_udf_filter_lit_input, but with Pantestdas UDF # This needs to a separate test because Arrow dependency is optional import pandas as pd import numpy as np @@ -5323,14 +5322,13 @@ def test_datasource_with_udf_filter_lit_input(self): datasource_df = self.spark.read \ .format("org.apache.spark.sql.sources.SimpleScanSource") \ .option('from', 0).option('to', 1).load() - # TODO: Enable data source v2 after SPARK-25213 is fixed - # datasource_v2_df = self.spark.read \ - # .format("org.apache.spark.sql.sources.v2.SimpleDataSourceV2") \ - # .load() + datasource_v2_df = self.spark.read \ + .format("org.apache.spark.sql.sources.v2.SimpleDataSourceV2") \ + .load() f = pandas_udf(lambda x: pd.Series(np.repeat(False, len(x))), 'boolean')(lit(1)) - for df in [filesource_df, datasource_df]: + for df in [filesource_df, datasource_df, datasource_v2_df]: result = df.filter(f) self.assertEquals(0, result.count()) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index 3112b306c365..64f49e2d0d4e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -89,7 +89,6 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) { /** A sequence of rules that will be applied in order to the physical plan before execution. */ protected def preparations: Seq[Rule[SparkPlan]] = Seq( - python.ExtractPythonUDFs, PlanSubqueries(sparkSession), EnsureRequirements(sparkSession.sessionState.conf), CollapseCodegenStages(sparkSession.sessionState.conf), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala index 969def762405..6c6d344240ce 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala @@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.catalog.SessionCatalog import org.apache.spark.sql.catalyst.optimizer.Optimizer import org.apache.spark.sql.execution.datasources.PruneFileSourcePartitions import org.apache.spark.sql.execution.datasources.parquet.ParquetSchemaPruning -import org.apache.spark.sql.execution.python.ExtractPythonUDFFromAggregate +import org.apache.spark.sql.execution.python.{ExtractPythonUDFFromAggregate, ExtractPythonUDFs} class SparkOptimizer( catalog: SessionCatalog, @@ -31,7 +31,8 @@ class SparkOptimizer( override def defaultBatches: Seq[Batch] = (preOptimizationBatches ++ super.defaultBatches :+ Batch("Optimize Metadata Only Query", Once, OptimizeMetadataOnlyQuery(catalog)) :+ - Batch("Extract Python UDF from Aggregate", Once, ExtractPythonUDFFromAggregate) :+ + Batch("Extract Python UDFs", Once, + Seq(ExtractPythonUDFFromAggregate, ExtractPythonUDFs): _*) :+ Batch("Prune File Source Table Partitions", Once, PruneFileSourcePartitions) :+ Batch("Parquet Schema Pruning", Once, ParquetSchemaPruning)) ++ postHocOptimizationBatches :+ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala index 75f5ec0e253d..2a4a1c8ef343 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala @@ -36,6 +36,7 @@ class SparkPlanner( override def strategies: Seq[Strategy] = experimentalMethods.extraStrategies ++ extraPlanningStrategies ++ ( + PythonEvals :: DataSourceV2Strategy :: FileSourceStrategy :: DataSourceStrategy(conf) :: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 4c39990acb62..dbc6db62bd82 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -32,6 +32,7 @@ import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableS import org.apache.spark.sql.execution.command._ import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide} +import org.apache.spark.sql.execution.python._ import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.sources.MemoryPlanV2 import org.apache.spark.sql.internal.SQLConf @@ -517,6 +518,20 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } } + /** + * Strategy to convert EvalPython logical operator to physical operator. + */ + object PythonEvals extends Strategy { + override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case ArrowEvalPython(udfs, output, child) => + ArrowEvalPythonExec(udfs, output, planLater(child)) :: Nil + case BatchEvalPython(udfs, output, child) => + BatchEvalPythonExec(udfs, output, planLater(child)) :: Nil + case _ => + Nil + } + } + object BasicOperators extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case d: DataWritingCommand => DataWritingCommandExec(d, planLater(d.query)) :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala index 0bc21c0986e6..6a03f860f8f9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala @@ -23,6 +23,7 @@ import org.apache.spark.TaskContext import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, UnaryNode} import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.arrow.ArrowUtils import org.apache.spark.sql.types.StructType @@ -57,7 +58,13 @@ private class BatchIterator[T](iter: Iterator[T], batchSize: Int) } /** - * A physical plan that evaluates a [[PythonUDF]], + * A logical plan that evaluates a [[PythonUDF]]. + */ +case class ArrowEvalPython(udfs: Seq[PythonUDF], output: Seq[Attribute], child: LogicalPlan) + extends UnaryNode + +/** + * A physical plan that evaluates a [[PythonUDF]]. 
*/ case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], child: SparkPlan) extends EvalPythonExec(udfs, output, child) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala index f4d83e8dc7c2..2054c700957e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala @@ -25,9 +25,16 @@ import org.apache.spark.TaskContext import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, UnaryNode} import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.types.{StructField, StructType} +/** + * A logical plan that evaluates a [[PythonUDF]] + */ +case class BatchEvalPython(udfs: Seq[PythonUDF], output: Seq[Attribute], child: LogicalPlan) + extends UnaryNode + /** * A physical plan that evaluates a [[PythonUDF]] */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala index 25db06f58597..90b5325919e9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala @@ -24,10 +24,8 @@ import org.apache.spark.api.python.PythonEvalType import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression -import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project} +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LogicalPlan, Project} import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.execution._ -import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanExec /** @@ -94,7 +92,7 @@ object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] { * This has the limitation that the input to the Python UDF is not allowed include attributes from * multiple child operators. */ -object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper { +object ExtractPythonUDFs extends Rule[LogicalPlan] with PredicateHelper { private type EvalType = Int private type EvalTypeChecker = EvalType => Boolean @@ -133,17 +131,14 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper { expressions.flatMap(collectEvaluableUDFs) } - def apply(plan: SparkPlan): SparkPlan = plan transformUp { - // SPARK-24721: Ignore Python UDFs in DataSourceScan and DataSourceV2Scan - case plan: DataSourceScanExec => plan - case plan: DataSourceV2ScanExec => plan - case plan: SparkPlan => extract(plan) + def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + case plan: LogicalPlan => extract(plan) } /** * Extract all the PythonUDFs from the current operator and evaluate them before the operator. 
*/ - private def extract(plan: SparkPlan): SparkPlan = { + private def extract(plan: LogicalPlan): LogicalPlan = { val udfs = collectEvaluableUDFsFromExpressions(plan.expressions) // ignore the PythonUDF that come from second/third aggregate, which is not used .filter(udf => udf.references.subsetOf(plan.inputSet)) @@ -155,7 +150,7 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper { val prunedChildren = plan.children.map { child => val allNeededOutput = inputsForPlan.intersect(child.outputSet).toSeq if (allNeededOutput.length != child.output.length) { - ProjectExec(allNeededOutput, child) + Project(allNeededOutput, child) } else { child } @@ -184,9 +179,9 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper { _.evalType == PythonEvalType.SQL_SCALAR_PANDAS_UDF ) match { case (vectorizedUdfs, plainUdfs) if plainUdfs.isEmpty => - ArrowEvalPythonExec(vectorizedUdfs, child.output ++ resultAttrs, child) + ArrowEvalPython(vectorizedUdfs, child.output ++ resultAttrs, child) case (vectorizedUdfs, plainUdfs) if vectorizedUdfs.isEmpty => - BatchEvalPythonExec(plainUdfs, child.output ++ resultAttrs, child) + BatchEvalPython(plainUdfs, child.output ++ resultAttrs, child) case _ => throw new AnalysisException( "Expected either Scalar Pandas UDFs or Batched UDFs but got both") @@ -213,7 +208,7 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper { val newPlan = extract(rewritten) if (newPlan.output != plan.output) { // Trim away the new UDF value if it was only used for filtering or something. - ProjectExec(plan.output, newPlan) + Project(plan.output, newPlan) } else { newPlan } @@ -222,15 +217,15 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper { // Split the original FilterExec to two FilterExecs. Only push down the first few predicates // that are all deterministic. 
- private def trySplitFilter(plan: SparkPlan): SparkPlan = { + private def trySplitFilter(plan: LogicalPlan): LogicalPlan = { plan match { - case filter: FilterExec => + case filter: Filter => val (candidates, nonDeterministic) = splitConjunctivePredicates(filter.condition).partition(_.deterministic) val (pushDown, rest) = candidates.partition(!hasScalarPythonUDF(_)) if (pushDown.nonEmpty) { - val newChild = FilterExec(pushDown.reduceLeft(And), filter.child) - FilterExec((rest ++ nonDeterministic).reduceLeft(And), newChild) + val newChild = Filter(pushDown.reduceLeft(And), filter.child) + Filter((rest ++ nonDeterministic).reduceLeft(And), newChild) } else { filter } From 2325a4f18a2bc6cc95d96bc5ac6790749b3e927e Mon Sep 17 00:00:00 2001 From: Li Jin Date: Mon, 27 Aug 2018 21:16:10 +0000 Subject: [PATCH 14/14] Add more tests; Fix lint; Remove hardcoded scala-2.11 from require_test_compiled --- python/pyspark/sql/tests.py | 64 ++++++++++++++++++++++++++----------- python/pyspark/sql/utils.py | 7 ++-- 2 files changed, 51 insertions(+), 20 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 57ca8ffd16ca..81c0af0b3d81 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3377,29 +3377,42 @@ def test_ignore_column_of_all_nulls(self): # SPARK-24721 @unittest.skipIf(not _test_compiled, _test_not_compiled_message) - def test_datasource_with_udf_filter_lit_input(self): + def test_datasource_with_udf(self): from pyspark.sql.functions import udf, lit, col path = tempfile.mkdtemp() shutil.rmtree(path) + try: self.spark.range(1).write.mode("overwrite").format('csv').save(path) - filesource_df = self.spark.read.csv(path) + filesource_df = self.spark.read.option('inferSchema', True).csv(path).toDF('i') datasource_df = self.spark.read \ .format("org.apache.spark.sql.sources.SimpleScanSource") \ - .option('from', 0).option('to', 1).load() + .option('from', 0).option('to', 1).load().toDF('i') datasource_v2_df = self.spark.read \ - .format("org.apache.spark.sql.sources.v2.SimpleDataSourceV2") \ - .load() + .format("org.apache.spark.sql.sources.v2.SimpleDataSourceV2") \ + .load().toDF('i', 'j') + + c1 = udf(lambda x: x + 1, 'int')(lit(1)) + c2 = udf(lambda x: x + 1, 'int')(col('i')) - filter1 = udf(lambda: False, 'boolean')() - filter2 = udf(lambda x: False, 'boolean')(lit(1)) + f1 = udf(lambda x: False, 'boolean')(lit(1)) + f2 = udf(lambda x: False, 'boolean')(col('i')) + + for df in [filesource_df, datasource_df, datasource_v2_df]: + result = df.withColumn('c', c1) + expected = df.withColumn('c', lit(2)) + self.assertEquals(expected.collect(), result.collect()) for df in [filesource_df, datasource_df, datasource_v2_df]: - for f in [filter1, filter2]: + result = df.withColumn('c', c2) + expected = df.withColumn('c', col('i') + 1) + self.assertEquals(expected.collect(), result.collect()) + + for df in [filesource_df, datasource_df, datasource_v2_df]: + for f in [f1, f2]: result = df.filter(f) self.assertEquals(0, result.count()) - finally: shutil.rmtree(path) @@ -5307,8 +5320,8 @@ def f3(x): # SPARK-24721 @unittest.skipIf(not _test_compiled, _test_not_compiled_message) - def test_datasource_with_udf_filter_lit_input(self): - # Same as SQLTests.test_datasource_with_udf_filter_lit_input, but with Pantestdas UDF + def test_datasource_with_udf(self): + # Same as SQLTests.test_datasource_with_udf, but with Pandas UDF # This needs to a separate test because Arrow dependency is optional import pandas as pd import numpy as np @@ -5316,22 +5329,37 @@ def 
test_datasource_with_udf_filter_lit_input(self): path = tempfile.mkdtemp() shutil.rmtree(path) + try: self.spark.range(1).write.mode("overwrite").format('csv').save(path) - filesource_df = self.spark.read.csv(path) + filesource_df = self.spark.read.option('inferSchema', True).csv(path).toDF('i') datasource_df = self.spark.read \ .format("org.apache.spark.sql.sources.SimpleScanSource") \ - .option('from', 0).option('to', 1).load() + .option('from', 0).option('to', 1).load().toDF('i') datasource_v2_df = self.spark.read \ - .format("org.apache.spark.sql.sources.v2.SimpleDataSourceV2") \ - .load() + .format("org.apache.spark.sql.sources.v2.SimpleDataSourceV2") \ + .load().toDF('i', 'j') + + c1 = pandas_udf(lambda x: x + 1, 'int')(lit(1)) + c2 = pandas_udf(lambda x: x + 1, 'int')(col('i')) - f = pandas_udf(lambda x: pd.Series(np.repeat(False, len(x))), 'boolean')(lit(1)) + f1 = pandas_udf(lambda x: pd.Series(np.repeat(False, len(x))), 'boolean')(lit(1)) + f2 = pandas_udf(lambda x: pd.Series(np.repeat(False, len(x))), 'boolean')(col('i')) + + for df in [filesource_df, datasource_df, datasource_v2_df]: + result = df.withColumn('c', c1) + expected = df.withColumn('c', lit(2)) + self.assertEquals(expected.collect(), result.collect()) for df in [filesource_df, datasource_df, datasource_v2_df]: - result = df.filter(f) - self.assertEquals(0, result.count()) + result = df.withColumn('c', c2) + expected = df.withColumn('c', col('i') + 1) + self.assertEquals(expected.collect(), result.collect()) + for df in [filesource_df, datasource_df, datasource_v2_df]: + for f in [f1, f2]: + result = df.filter(f) + self.assertEquals(0, result.count()) finally: shutil.rmtree(path) diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py index 46d4b1f2168d..bdb3a1467f1d 100644 --- a/python/pyspark/sql/utils.py +++ b/python/pyspark/sql/utils.py @@ -156,14 +156,17 @@ def require_test_compiled(): """ Raise Exception if test classes are not compiled """ import os + import glob try: spark_home = os.environ['SPARK_HOME'] except KeyError: raise RuntimeError('SPARK_HOME is not defined in environment') test_class_path = os.path.join( - spark_home, 'sql', 'core', 'target', 'scala-2.11', 'test-classes') - if not os.path.isdir(test_class_path): + spark_home, 'sql', 'core', 'target', '*', 'test-classes') + paths = glob.glob(test_class_path) + + if len(paths) == 0: raise RuntimeError( "%s doesn't exist. Spark sql test classes are not compiled." % test_class_path)
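
For reference, below is a minimal end-to-end sketch of the SPARK-24721 scenario that the tests in these patches exercise: filtering a data-source-backed DataFrame with a Python UDF whose inputs are literals (or nothing), which previously made the ExtractPythonUDF rule fail while planning the scan. It is a standalone illustration, not part of the patch set; the local-mode session, temporary path, and column name `i` are assumptions chosen to mirror the tests above.

# Hedged sketch of the SPARK-24721 scenario covered by the tests above.
# Assumes a local SparkSession and a scratch directory; names are illustrative only.
import shutil
import tempfile

from pyspark.sql import SparkSession
from pyspark.sql.functions import lit, udf

spark = SparkSession.builder.master("local[1]").getOrCreate()
path = tempfile.mkdtemp()
try:
    # Write a tiny CSV-backed table, then read it back through the file source.
    spark.range(1).write.mode("overwrite").format("csv").save(path)
    df = spark.read.option("inferSchema", True).csv(path).toDF("i")

    # Python UDF filters over literal / no input: before these patches, planning
    # the file source scan with such a filter raised an error in ExtractPythonUDF.
    filter_lit = udf(lambda x: False, "boolean")(lit(1))
    filter_noarg = udf(lambda: False, "boolean")()

    assert df.filter(filter_lit).count() == 0
    assert df.filter(filter_noarg).count() == 0
finally:
    shutil.rmtree(path)
    spark.stop()

The final revision reaches this behavior by extracting Python UDFs at the end of the optimizer stage (ExtractPythonUDFFromAggregate followed by ExtractPythonUDFs) into the new ArrowEvalPython/BatchEvalPython logical nodes, which the added PythonEvals strategy plans into the existing physical operators, so FileSourceStrategy and the data source scans never encounter a PythonUDF in their filter expressions.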