diff --git a/python/pyspark/sql/tests/test_udf.py b/python/pyspark/sql/tests/test_udf.py index ed298f724d55..12cf8c7de1da 100644 --- a/python/pyspark/sql/tests/test_udf.py +++ b/python/pyspark/sql/tests/test_udf.py @@ -23,7 +23,7 @@ from pyspark import SparkContext from pyspark.sql import SparkSession, Column, Row -from pyspark.sql.functions import UserDefinedFunction +from pyspark.sql.functions import UserDefinedFunction, udf from pyspark.sql.types import * from pyspark.sql.utils import AnalysisException from pyspark.testing.sqlutils import ReusedSQLTestCase, test_compiled, test_not_compiled_message @@ -102,7 +102,6 @@ def test_udf_registration_return_type_not_none(self): def test_nondeterministic_udf(self): # Test that nondeterministic UDFs are evaluated only once in chained UDF evaluations - from pyspark.sql.functions import udf import random udf_random_col = udf(lambda: int(100 * random.random()), IntegerType()).asNondeterministic() self.assertEqual(udf_random_col.deterministic, False) @@ -113,7 +112,6 @@ def test_nondeterministic_udf(self): def test_nondeterministic_udf2(self): import random - from pyspark.sql.functions import udf random_udf = udf(lambda: random.randint(6, 6), IntegerType()).asNondeterministic() self.assertEqual(random_udf.deterministic, False) random_udf1 = self.spark.catalog.registerFunction("randInt", random_udf) @@ -132,7 +130,6 @@ def test_nondeterministic_udf2(self): def test_nondeterministic_udf3(self): # regression test for SPARK-23233 - from pyspark.sql.functions import udf f = udf(lambda x: x) # Here we cache the JVM UDF instance. self.spark.range(1).select(f("id")) @@ -144,7 +141,7 @@ def test_nondeterministic_udf3(self): self.assertFalse(deterministic) def test_nondeterministic_udf_in_aggregate(self): - from pyspark.sql.functions import udf, sum + from pyspark.sql.functions import sum import random udf_random_col = udf(lambda: int(100 * random.random()), 'int').asNondeterministic() df = self.spark.range(10) @@ -181,7 +178,6 @@ def test_multiple_udfs(self): self.assertEqual(tuple(row), (6, 5)) def test_udf_in_filter_on_top_of_outer_join(self): - from pyspark.sql.functions import udf left = self.spark.createDataFrame([Row(a=1)]) right = self.spark.createDataFrame([Row(a=1)]) df = left.join(right, on='a', how='left_outer') @@ -190,7 +186,6 @@ def test_udf_in_filter_on_top_of_outer_join(self): def test_udf_in_filter_on_top_of_join(self): # regression test for SPARK-18589 - from pyspark.sql.functions import udf left = self.spark.createDataFrame([Row(a=1)]) right = self.spark.createDataFrame([Row(b=1)]) f = udf(lambda a, b: a == b, BooleanType()) @@ -199,7 +194,6 @@ def test_udf_in_filter_on_top_of_join(self): def test_udf_in_join_condition(self): # regression test for SPARK-25314 - from pyspark.sql.functions import udf left = self.spark.createDataFrame([Row(a=1)]) right = self.spark.createDataFrame([Row(b=1)]) f = udf(lambda a, b: a == b, BooleanType()) @@ -211,7 +205,7 @@ def test_udf_in_join_condition(self): def test_udf_in_left_outer_join_condition(self): # regression test for SPARK-26147 - from pyspark.sql.functions import udf, col + from pyspark.sql.functions import col left = self.spark.createDataFrame([Row(a=1)]) right = self.spark.createDataFrame([Row(b=1)]) f = udf(lambda a: str(a), StringType()) @@ -223,7 +217,6 @@ def test_udf_in_left_outer_join_condition(self): def test_udf_in_left_semi_join_condition(self): # regression test for SPARK-25314 - from pyspark.sql.functions import udf left = self.spark.createDataFrame([Row(a=1, a1=1, a2=1), Row(a=2, a1=2, a2=2)]) right = self.spark.createDataFrame([Row(b=1, b1=1, b2=1)]) f = udf(lambda a, b: a == b, BooleanType()) @@ -236,7 +229,6 @@ def test_udf_in_left_semi_join_condition(self): def test_udf_and_common_filter_in_join_condition(self): # regression test for SPARK-25314 # test the complex scenario with both udf and common filter - from pyspark.sql.functions import udf left = self.spark.createDataFrame([Row(a=1, a1=1, a2=1), Row(a=2, a1=2, a2=2)]) right = self.spark.createDataFrame([Row(b=1, b1=1, b2=1), Row(b=1, b1=3, b2=1)]) f = udf(lambda a, b: a == b, BooleanType()) @@ -247,7 +239,6 @@ def test_udf_and_common_filter_in_join_condition(self): def test_udf_and_common_filter_in_left_semi_join_condition(self): # regression test for SPARK-25314 # test the complex scenario with both udf and common filter - from pyspark.sql.functions import udf left = self.spark.createDataFrame([Row(a=1, a1=1, a2=1), Row(a=2, a1=2, a2=2)]) right = self.spark.createDataFrame([Row(b=1, b1=1, b2=1), Row(b=1, b1=3, b2=1)]) f = udf(lambda a, b: a == b, BooleanType()) @@ -258,7 +249,6 @@ def test_udf_and_common_filter_in_left_semi_join_condition(self): def test_udf_not_supported_in_join_condition(self): # regression test for SPARK-25314 # test python udf is not supported in join type besides left_semi and inner join. - from pyspark.sql.functions import udf left = self.spark.createDataFrame([Row(a=1, a1=1, a2=1), Row(a=2, a1=2, a2=2)]) right = self.spark.createDataFrame([Row(b=1, b1=1, b2=1), Row(b=1, b1=3, b2=1)]) f = udf(lambda a, b: a == b, BooleanType()) @@ -301,7 +291,7 @@ def test_broadcast_in_udf(self): def test_udf_with_filter_function(self): df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"]) - from pyspark.sql.functions import udf, col + from pyspark.sql.functions import col from pyspark.sql.types import BooleanType my_filter = udf(lambda a: a < 2, BooleanType()) @@ -310,7 +300,7 @@ def test_udf_with_filter_function(self): def test_udf_with_aggregate_function(self): df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"]) - from pyspark.sql.functions import udf, col, sum + from pyspark.sql.functions import col, sum from pyspark.sql.types import BooleanType my_filter = udf(lambda a: a == 1, BooleanType()) @@ -326,7 +316,7 @@ def test_udf_with_aggregate_function(self): self.assertEqual(sel.collect(), [Row(t=4), Row(t=3)]) def test_udf_in_generate(self): - from pyspark.sql.functions import udf, explode + from pyspark.sql.functions import explode df = self.spark.range(5) f = udf(lambda x: list(range(x)), ArrayType(LongType())) row = df.select(explode(f(*df))).groupBy().sum().first() @@ -353,7 +343,6 @@ def test_udf_in_generate(self): self.assertEqual(res[3][1], 1) def test_udf_with_order_by_and_limit(self): - from pyspark.sql.functions import udf my_copy = udf(lambda x: x, IntegerType()) df = self.spark.range(10).orderBy("id") res = df.select(df.id, my_copy(df.id).alias("copy")).limit(1) @@ -394,14 +383,14 @@ def test_non_existed_udaf(self): lambda: spark.udf.registerJavaUDAF("udaf1", "non_existed_udaf")) def test_udf_with_input_file_name(self): - from pyspark.sql.functions import udf, input_file_name + from pyspark.sql.functions import input_file_name sourceFile = udf(lambda path: path, StringType()) filePath = "python/test_support/sql/people1.json" row = self.spark.read.json(filePath).select(sourceFile(input_file_name())).first() self.assertTrue(row[0].find("people1.json") != -1) def test_udf_with_input_file_name_for_hadooprdd(self): - from pyspark.sql.functions import udf, input_file_name + from pyspark.sql.functions import input_file_name def filename(path): return path @@ -427,9 +416,6 @@ def test_udf_defers_judf_initialization(self): # This is separate of UDFInitializationTests # to avoid context initialization # when udf is called - - from pyspark.sql.functions import UserDefinedFunction - f = UserDefinedFunction(lambda x: x, StringType()) self.assertIsNone( @@ -445,8 +431,6 @@ def test_udf_defers_judf_initialization(self): ) def test_udf_with_string_return_type(self): - from pyspark.sql.functions import UserDefinedFunction - add_one = UserDefinedFunction(lambda x: x + 1, "integer") make_pair = UserDefinedFunction(lambda x: (-x, x), "struct") make_array = UserDefinedFunction( @@ -460,13 +444,11 @@ def test_udf_with_string_return_type(self): self.assertTupleEqual(expected, actual) def test_udf_shouldnt_accept_noncallable_object(self): - from pyspark.sql.functions import UserDefinedFunction - non_callable = None self.assertRaises(TypeError, UserDefinedFunction, non_callable, StringType()) def test_udf_with_decorator(self): - from pyspark.sql.functions import lit, udf + from pyspark.sql.functions import lit from pyspark.sql.types import IntegerType, DoubleType @udf(IntegerType()) @@ -523,7 +505,6 @@ def as_double(x): ) def test_udf_wrapper(self): - from pyspark.sql.functions import udf from pyspark.sql.types import IntegerType def f(x): @@ -569,7 +550,7 @@ def test_nonparam_udf_with_aggregate(self): # SPARK-24721 @unittest.skipIf(not test_compiled, test_not_compiled_message) def test_datasource_with_udf(self): - from pyspark.sql.functions import udf, lit, col + from pyspark.sql.functions import lit, col path = tempfile.mkdtemp() shutil.rmtree(path) @@ -609,8 +590,6 @@ def test_datasource_with_udf(self): # SPARK-25591 def test_same_accumulator_in_udfs(self): - from pyspark.sql.functions import udf - data_schema = StructType([StructField("a", IntegerType(), True), StructField("b", IntegerType(), True)]) data = self.spark.createDataFrame([[1, 2]], schema=data_schema) @@ -632,6 +611,15 @@ def second_udf(x): data.collect() self.assertEqual(test_accum.value, 101) + # SPARK-26293 + def test_udf_in_subquery(self): + f = udf(lambda x: x, "long") + with self.tempView("v"): + self.spark.range(1).filter(f("id") >= 0).createTempView("v") + sql = self.spark.sql + result = sql("select i from values(0L) as data(i) where i in (select id from v)") + self.assertEqual(result.collect(), [Row(i=0)]) + class UDFInitializationTests(unittest.TestCase): def tearDown(self): @@ -642,8 +630,6 @@ def tearDown(self): SparkContext._active_spark_context.stop() def test_udf_init_shouldnt_initialize_context(self): - from pyspark.sql.functions import UserDefinedFunction - UserDefinedFunction(lambda x: x, StringType()) self.assertIsNone( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala index 2b87796dc683..a5203daea9cd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala @@ -60,8 +60,12 @@ private class BatchIterator[T](iter: Iterator[T], batchSize: Int) /** * A logical plan that evaluates a [[PythonUDF]]. */ -case class ArrowEvalPython(udfs: Seq[PythonUDF], output: Seq[Attribute], child: LogicalPlan) - extends UnaryNode +case class ArrowEvalPython( + udfs: Seq[PythonUDF], + output: Seq[Attribute], + child: LogicalPlan) extends UnaryNode { + override def producedAttributes: AttributeSet = AttributeSet(output.drop(child.output.length)) +} /** * A physical plan that evaluates a [[PythonUDF]]. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala index b08b7e60e130..d3736d24e501 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala @@ -32,8 +32,12 @@ import org.apache.spark.sql.types.{StructField, StructType} /** * A logical plan that evaluates a [[PythonUDF]] */ -case class BatchEvalPython(udfs: Seq[PythonUDF], output: Seq[Attribute], child: LogicalPlan) - extends UnaryNode +case class BatchEvalPython( + udfs: Seq[PythonUDF], + output: Seq[Attribute], + child: LogicalPlan) extends UnaryNode { + override def producedAttributes: AttributeSet = AttributeSet(output.drop(child.output.length)) +} /** * A physical plan that evaluates a [[PythonUDF]] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala index 90b5325919e9..380c31baa621 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala @@ -24,7 +24,7 @@ import org.apache.spark.api.python.PythonEvalType import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression -import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LogicalPlan, Project} +import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule @@ -131,8 +131,20 @@ object ExtractPythonUDFs extends Rule[LogicalPlan] with PredicateHelper { expressions.flatMap(collectEvaluableUDFs) } - def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { - case plan: LogicalPlan => extract(plan) + def apply(plan: LogicalPlan): LogicalPlan = plan match { + // SPARK-26293: A subquery will be rewritten into join later, and will go through this rule + // eventually. Here we skip subquery, as Python UDF only needs to be extracted once. + case _: Subquery => plan + + case _ => plan transformUp { + // A safe guard. `ExtractPythonUDFs` only runs once, so we will not hit `BatchEvalPython` and + // `ArrowEvalPython` in the input plan. However if we hit them, we must skip them, as we can't + // extract Python UDFs from them. + case p: BatchEvalPython => p + case p: ArrowEvalPython => p + + case plan: LogicalPlan => extract(plan) + } } /**