52 changes: 19 additions & 33 deletions python/pyspark/sql/tests/test_udf.py
@@ -23,7 +23,7 @@

from pyspark import SparkContext
from pyspark.sql import SparkSession, Column, Row
from pyspark.sql.functions import UserDefinedFunction
from pyspark.sql.functions import UserDefinedFunction, udf
Contributor Author: Add the import here, as a lot of the tests use it.

Member: Ah, yeah. It's okay, and I think it's good timing to clean up while we are here, and while it's broken down into multiple test files now.

from pyspark.sql.types import *
from pyspark.sql.utils import AnalysisException
from pyspark.testing.sqlutils import ReusedSQLTestCase, test_compiled, test_not_compiled_message
@@ -102,7 +102,6 @@ def test_udf_registration_return_type_not_none(self):

def test_nondeterministic_udf(self):
# Test that nondeterministic UDFs are evaluated only once in chained UDF evaluations
from pyspark.sql.functions import udf
import random
udf_random_col = udf(lambda: int(100 * random.random()), IntegerType()).asNondeterministic()
self.assertEqual(udf_random_col.deterministic, False)
@@ -113,7 +112,6 @@ def test_nondeterministic_udf(self):

def test_nondeterministic_udf2(self):
import random
from pyspark.sql.functions import udf
random_udf = udf(lambda: random.randint(6, 6), IntegerType()).asNondeterministic()
self.assertEqual(random_udf.deterministic, False)
random_udf1 = self.spark.catalog.registerFunction("randInt", random_udf)
@@ -132,7 +130,6 @@ def test_nondeterministic_udf2(self):

def test_nondeterministic_udf3(self):
# regression test for SPARK-23233
from pyspark.sql.functions import udf
f = udf(lambda x: x)
# Here we cache the JVM UDF instance.
self.spark.range(1).select(f("id"))
@@ -144,7 +141,7 @@ def test_nondeterministic_udf3(self):
self.assertFalse(deterministic)

def test_nondeterministic_udf_in_aggregate(self):
from pyspark.sql.functions import udf, sum
from pyspark.sql.functions import sum
import random
udf_random_col = udf(lambda: int(100 * random.random()), 'int').asNondeterministic()
df = self.spark.range(10)
@@ -181,7 +178,6 @@ def test_multiple_udfs(self):
self.assertEqual(tuple(row), (6, 5))

def test_udf_in_filter_on_top_of_outer_join(self):
from pyspark.sql.functions import udf
left = self.spark.createDataFrame([Row(a=1)])
right = self.spark.createDataFrame([Row(a=1)])
df = left.join(right, on='a', how='left_outer')
@@ -190,7 +186,6 @@ def test_udf_in_filter_on_top_of_outer_join(self):

def test_udf_in_filter_on_top_of_join(self):
# regression test for SPARK-18589
from pyspark.sql.functions import udf
left = self.spark.createDataFrame([Row(a=1)])
right = self.spark.createDataFrame([Row(b=1)])
f = udf(lambda a, b: a == b, BooleanType())
@@ -199,7 +194,6 @@ def test_udf_in_filter_on_top_of_join(self):

def test_udf_in_join_condition(self):
# regression test for SPARK-25314
from pyspark.sql.functions import udf
left = self.spark.createDataFrame([Row(a=1)])
right = self.spark.createDataFrame([Row(b=1)])
f = udf(lambda a, b: a == b, BooleanType())
@@ -211,7 +205,7 @@ def test_udf_in_join_condition(self):

def test_udf_in_left_outer_join_condition(self):
# regression test for SPARK-26147
from pyspark.sql.functions import udf, col
from pyspark.sql.functions import col
left = self.spark.createDataFrame([Row(a=1)])
right = self.spark.createDataFrame([Row(b=1)])
f = udf(lambda a: str(a), StringType())
@@ -223,7 +217,6 @@ def test_udf_in_left_outer_join_condition(self):

def test_udf_in_left_semi_join_condition(self):
# regression test for SPARK-25314
from pyspark.sql.functions import udf
left = self.spark.createDataFrame([Row(a=1, a1=1, a2=1), Row(a=2, a1=2, a2=2)])
right = self.spark.createDataFrame([Row(b=1, b1=1, b2=1)])
f = udf(lambda a, b: a == b, BooleanType())
@@ -236,7 +229,6 @@ def test_udf_in_left_semi_join_condition(self):
def test_udf_and_common_filter_in_join_condition(self):
# regression test for SPARK-25314
# test the complex scenario with both udf and common filter
from pyspark.sql.functions import udf
left = self.spark.createDataFrame([Row(a=1, a1=1, a2=1), Row(a=2, a1=2, a2=2)])
right = self.spark.createDataFrame([Row(b=1, b1=1, b2=1), Row(b=1, b1=3, b2=1)])
f = udf(lambda a, b: a == b, BooleanType())
@@ -247,7 +239,6 @@ def test_udf_and_common_filter_in_join_condition(self):
def test_udf_and_common_filter_in_left_semi_join_condition(self):
# regression test for SPARK-25314
# test the complex scenario with both udf and common filter
from pyspark.sql.functions import udf
left = self.spark.createDataFrame([Row(a=1, a1=1, a2=1), Row(a=2, a1=2, a2=2)])
right = self.spark.createDataFrame([Row(b=1, b1=1, b2=1), Row(b=1, b1=3, b2=1)])
f = udf(lambda a, b: a == b, BooleanType())
@@ -258,7 +249,6 @@ def test_udf_and_common_filter_in_left_semi_join_condition(self):
def test_udf_not_supported_in_join_condition(self):
# regression test for SPARK-25314
# test python udf is not supported in join type besides left_semi and inner join.
from pyspark.sql.functions import udf
left = self.spark.createDataFrame([Row(a=1, a1=1, a2=1), Row(a=2, a1=2, a2=2)])
right = self.spark.createDataFrame([Row(b=1, b1=1, b2=1), Row(b=1, b1=3, b2=1)])
f = udf(lambda a, b: a == b, BooleanType())
@@ -301,7 +291,7 @@ def test_broadcast_in_udf(self):

def test_udf_with_filter_function(self):
df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
from pyspark.sql.functions import udf, col
from pyspark.sql.functions import col
from pyspark.sql.types import BooleanType

my_filter = udf(lambda a: a < 2, BooleanType())
@@ -310,7 +300,7 @@ def test_udf_with_filter_function(self):

def test_udf_with_aggregate_function(self):
df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
from pyspark.sql.functions import udf, col, sum
from pyspark.sql.functions import col, sum
from pyspark.sql.types import BooleanType

my_filter = udf(lambda a: a == 1, BooleanType())
@@ -326,7 +316,7 @@ def test_udf_with_aggregate_function(self):
self.assertEqual(sel.collect(), [Row(t=4), Row(t=3)])

def test_udf_in_generate(self):
from pyspark.sql.functions import udf, explode
from pyspark.sql.functions import explode
df = self.spark.range(5)
f = udf(lambda x: list(range(x)), ArrayType(LongType()))
row = df.select(explode(f(*df))).groupBy().sum().first()
@@ -353,7 +343,6 @@ def test_udf_in_generate(self):
self.assertEqual(res[3][1], 1)

def test_udf_with_order_by_and_limit(self):
from pyspark.sql.functions import udf
my_copy = udf(lambda x: x, IntegerType())
df = self.spark.range(10).orderBy("id")
res = df.select(df.id, my_copy(df.id).alias("copy")).limit(1)
@@ -394,14 +383,14 @@ def test_non_existed_udaf(self):
lambda: spark.udf.registerJavaUDAF("udaf1", "non_existed_udaf"))

def test_udf_with_input_file_name(self):
from pyspark.sql.functions import udf, input_file_name
from pyspark.sql.functions import input_file_name
sourceFile = udf(lambda path: path, StringType())
filePath = "python/test_support/sql/people1.json"
row = self.spark.read.json(filePath).select(sourceFile(input_file_name())).first()
self.assertTrue(row[0].find("people1.json") != -1)

def test_udf_with_input_file_name_for_hadooprdd(self):
from pyspark.sql.functions import udf, input_file_name
from pyspark.sql.functions import input_file_name

def filename(path):
return path
@@ -427,9 +416,6 @@ def test_udf_defers_judf_initialization(self):
# This is separate of UDFInitializationTests
# to avoid context initialization
# when udf is called

from pyspark.sql.functions import UserDefinedFunction

f = UserDefinedFunction(lambda x: x, StringType())

self.assertIsNone(
@@ -445,8 +431,6 @@ def test_udf_defers_judf_initialization(self):
)

def test_udf_with_string_return_type(self):
from pyspark.sql.functions import UserDefinedFunction

add_one = UserDefinedFunction(lambda x: x + 1, "integer")
make_pair = UserDefinedFunction(lambda x: (-x, x), "struct<x:integer,y:integer>")
make_array = UserDefinedFunction(
@@ -460,13 +444,11 @@ def test_udf_with_string_return_type(self):
self.assertTupleEqual(expected, actual)

def test_udf_shouldnt_accept_noncallable_object(self):
from pyspark.sql.functions import UserDefinedFunction

non_callable = None
self.assertRaises(TypeError, UserDefinedFunction, non_callable, StringType())

def test_udf_with_decorator(self):
from pyspark.sql.functions import lit, udf
from pyspark.sql.functions import lit
from pyspark.sql.types import IntegerType, DoubleType

@udf(IntegerType())
@@ -523,7 +505,6 @@ def as_double(x):
)

def test_udf_wrapper(self):
from pyspark.sql.functions import udf
from pyspark.sql.types import IntegerType

def f(x):
@@ -569,7 +550,7 @@ def test_nonparam_udf_with_aggregate(self):
# SPARK-24721
@unittest.skipIf(not test_compiled, test_not_compiled_message)
def test_datasource_with_udf(self):
from pyspark.sql.functions import udf, lit, col
from pyspark.sql.functions import lit, col

path = tempfile.mkdtemp()
shutil.rmtree(path)
@@ -609,8 +590,6 @@ def test_datasource_with_udf(self):

# SPARK-25591
def test_same_accumulator_in_udfs(self):
from pyspark.sql.functions import udf

data_schema = StructType([StructField("a", IntegerType(), True),
StructField("b", IntegerType(), True)])
data = self.spark.createDataFrame([[1, 2]], schema=data_schema)
@@ -632,6 +611,15 @@ def second_udf(x):
data.collect()
self.assertEqual(test_accum.value, 101)

# SPARK-26293
def test_udf_in_subquery(self):
f = udf(lambda x: x, "long")
with self.tempView("v"):
self.spark.range(1).filter(f("id") >= 0).createTempView("v")
sql = self.spark.sql
result = sql("select i from values(0L) as data(i) where i in (select id from v)")
self.assertEqual(result.collect(), [Row(i=0)])


class UDFInitializationTests(unittest.TestCase):
def tearDown(self):
@@ -642,8 +630,6 @@ def tearDown(self):
SparkContext._active_spark_context.stop()

def test_udf_init_shouldnt_initialize_context(self):
from pyspark.sql.functions import UserDefinedFunction

UserDefinedFunction(lambda x: x, StringType())

self.assertIsNone(
@@ -60,8 +60,12 @@ private class BatchIterator[T](iter: Iterator[T], batchSize: Int)
/**
* A logical plan that evaluates a [[PythonUDF]].
*/
case class ArrowEvalPython(udfs: Seq[PythonUDF], output: Seq[Attribute], child: LogicalPlan)
extends UnaryNode
case class ArrowEvalPython(
udfs: Seq[PythonUDF],
output: Seq[Attribute],
child: LogicalPlan) extends UnaryNode {
override def producedAttributes: AttributeSet = AttributeSet(output.drop(child.output.length))
Contributor Author: A different but related fix, to make the missing attributes calculated correctly (see the sketch after this diff).

}

/**
* A physical plan that evaluates a [[PythonUDF]].
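The inline comment above is terse, so here is a minimal, self-contained sketch of why the `producedAttributes` override matters. It uses plain sets of column names rather than Spark's real `QueryPlan`/`AttributeSet` types, and the column name `pythonUDF0` is only illustrative: Catalyst derives a node's missing input roughly as `references -- inputSet -- producedAttributes`, so the UDF result columns appended after `child.output` must be declared as produced or they show up as missing.

```scala
// Toy model of the missing-input check; not Spark's Catalyst implementation.
object ProducedAttributesSketch {

  // Catalyst computes roughly: missingInput = references -- inputSet -- producedAttributes
  private def missingInput(
      references: Set[String],
      inputSet: Set[String],
      produced: Set[String]): Set[String] =
    references -- inputSet -- produced

  def main(args: Array[String]): Unit = {
    val childOutput = Seq("id")            // columns coming from the child plan
    val output = Seq("id", "pythonUDF0")   // node output: child columns + appended UDF result
    val references = output.toSet          // the node references everything it outputs

    // Without the override, producedAttributes is empty and the UDF column looks like missing input.
    println(missingInput(references, childOutput.toSet, Set.empty))
    // -> Set(pythonUDF0)

    // With producedAttributes = output.drop(child.output.length), nothing is reported missing.
    val produced = output.drop(childOutput.length).toSet
    println(missingInput(references, childOutput.toSet, produced))
    // -> Set()
  }
}
```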
@@ -32,8 +32,12 @@ import org.apache.spark.sql.types.{StructField, StructType}
/**
* A logical plan that evaluates a [[PythonUDF]]
*/
case class BatchEvalPython(udfs: Seq[PythonUDF], output: Seq[Attribute], child: LogicalPlan)
extends UnaryNode
case class BatchEvalPython(
udfs: Seq[PythonUDF],
output: Seq[Attribute],
child: LogicalPlan) extends UnaryNode {
override def producedAttributes: AttributeSet = AttributeSet(output.drop(child.output.length))
}

/**
* A physical plan that evaluates a [[PythonUDF]]
@@ -24,7 +24,7 @@ import org.apache.spark.api.python.PythonEvalType
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.Rule


@@ -131,8 +131,20 @@ object ExtractPythonUDFs extends Rule[LogicalPlan] with PredicateHelper {
expressions.flatMap(collectEvaluableUDFs)
}

def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
case plan: LogicalPlan => extract(plan)
def apply(plan: LogicalPlan): LogicalPlan = plan match {
// SPARK-26293: A subquery will be rewritten into join later, and will go through this rule
// eventually. Here we skip subquery, as Python UDF only needs to be extracted once.
case _: Subquery => plan
Contributor: Personally I found it a bit confusing when two seemingly unrelated things are put together (Subquery and ExtractPythonUDFs). I wonder if it's sufficient to make ExtractPythonUDFs idempotent?

Contributor Author: I agree it's a bit confusing, but that's how Subquery is designed to work. See how RemoveRedundantAliases catches Subquery.

It is sufficient to make ExtractPythonUDFs idempotent; skipping Subquery is just to be doubly safe, and it may give a small performance improvement, since the rule will run less often.

In general, I think we should skip Subquery here. This is why we create Subquery: we expect rules that don't want to be executed on a subquery to skip it. I'll check more rules and see if they need to skip Subquery later.

Contributor: I see. If it's common to skip Subquery in other rules, I guess it's okay to put it in here as well. But it would definitely be helpful to establish some kind of guidance, maybe something like "All optimizer rules should skip Subquery because OptimizeSubqueries will execute them anyway"?

Contributor Author: I think you have a point here. If a subquery will be converted to a join, why do we need to optimize the subquery ahead of time?

Anyway, that's something we need to discuss later. cc @dilipbiswal for the subquery question.

Member: I'm not sure it is totally okay to skip Subquery for all optimizer rules.

For ExtractPythonUDFs I think it is okay, because ExtractPythonUDFs runs after the rules in RewriteSubquery. So we can skip it here and extract Python UDFs after the subqueries are rewritten into joins.

But for the rules that run before RewriteSubquery, if we skip them on Subquery, we get no chance to apply them after the subqueries are rewritten into joins.

Member: Basically, we want to ensure this rule runs once and only once (a toy sketch of that concern follows the diff below). In the future, if we have another rule or function that calls Optimizer.this.execute(plan), this rule will need to be fixed again... We have a very strong hidden assumption in the implementation, which looks risky in the long term.

The current fix is fine for backporting to 2.4.

case _ => plan transformUp {
// A safe guard. `ExtractPythonUDFs` only runs once, so we will not hit `BatchEvalPython` and
// `ArrowEvalPython` in the input plan. However if we hit them, we must skip them, as we can't
// extract Python UDFs from them.
case p: BatchEvalPython => p
case p: ArrowEvalPython => p

case plan: LogicalPlan => extract(plan)
}
}

/**
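To make the "once and only once" concern from the review thread above concrete, here is a toy, self-contained sketch; the classes `UdfFilter`, `EvalPython`, and `Subquery` are stand-ins for illustration, not Spark's real Catalyst nodes. The rule stays idempotent by leaving already-extracted nodes untouched and by skipping the `Subquery` marker, which the optimizer revisits after rewriting the subquery into a join.

```scala
// Toy sketch of idempotent UDF extraction; not Spark's real ExtractPythonUDFs rule.
object ExtractOnceSketch {

  sealed trait Plan
  case class Leaf(name: String) extends Plan
  case class UdfFilter(child: Plan) extends Plan   // a filter that still contains a Python UDF
  case class EvalPython(child: Plan) extends Plan  // the extracted evaluation node
  case class Subquery(child: Plan) extends Plan    // marker node, like Catalyst's Subquery

  // Idempotent version: skip Subquery entirely and never re-wrap an EvalPython node.
  def extract(plan: Plan): Plan = plan match {
    case s: Subquery   => s                        // handled after the rewrite into a join
    case e: EvalPython => e                        // already extracted, leave it alone
    case UdfFilter(c)  => EvalPython(extract(c))
    case Leaf(n)       => Leaf(n)
  }

  def main(args: Array[String]): Unit = {
    val plan = UdfFilter(Leaf("range(1)"))
    val once = extract(plan)
    val twice = extract(once)
    println(once)          // EvalPython(Leaf(range(1)))
    println(twice == once) // true: running the rule again is a no-op
  }
}
```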