54 changes: 22 additions & 32 deletions python/pyspark/sql/tests.py
@@ -88,7 +88,7 @@
from pyspark.sql.types import _array_unsigned_int_typecode_ctype_mappings
from pyspark.sql.types import _merge_type
from pyspark.tests import QuietTest, ReusedPySparkTestCase, PySparkTestCase, SparkSubmitTests
-from pyspark.sql.functions import UserDefinedFunction, sha2, lit, input_file_name
+from pyspark.sql.functions import UserDefinedFunction, sha2, lit, input_file_name, udf
from pyspark.sql.window import Window
from pyspark.sql.utils import AnalysisException, ParseException, IllegalArgumentException

@@ -457,7 +457,6 @@ def test_udf_registration_return_type_not_none(self):

def test_nondeterministic_udf(self):
# Test that nondeterministic UDFs are evaluated only once in chained UDF evaluations
-from pyspark.sql.functions import udf
import random
udf_random_col = udf(lambda: int(100 * random.random()), IntegerType()).asNondeterministic()
self.assertEqual(udf_random_col.deterministic, False)
@@ -468,7 +467,6 @@ def test_nondeterministic_udf(self):

def test_nondeterministic_udf2(self):
import random
-from pyspark.sql.functions import udf
random_udf = udf(lambda: random.randint(6, 6), IntegerType()).asNondeterministic()
self.assertEqual(random_udf.deterministic, False)
random_udf1 = self.spark.catalog.registerFunction("randInt", random_udf)
@@ -487,7 +485,6 @@ def test_nondeterministic_udf2(self):

def test_nondeterministic_udf3(self):
# regression test for SPARK-23233
-from pyspark.sql.functions import udf
f = udf(lambda x: x)
# Here we cache the JVM UDF instance.
self.spark.range(1).select(f("id"))
@@ -499,7 +496,7 @@ def test_nondeterministic_udf3(self):
self.assertFalse(deterministic)

def test_nondeterministic_udf_in_aggregate(self):
-from pyspark.sql.functions import udf, sum
+from pyspark.sql.functions import sum
import random
udf_random_col = udf(lambda: int(100 * random.random()), 'int').asNondeterministic()
df = self.spark.range(10)
@@ -536,7 +533,6 @@ def test_multiple_udfs(self):
self.assertEqual(tuple(row), (6, 5))

def test_udf_in_filter_on_top_of_outer_join(self):
-from pyspark.sql.functions import udf
left = self.spark.createDataFrame([Row(a=1)])
right = self.spark.createDataFrame([Row(a=1)])
df = left.join(right, on='a', how='left_outer')
@@ -545,7 +541,6 @@ def test_udf_in_filter_on_top_of_outer_join(self):

def test_udf_in_filter_on_top_of_join(self):
# regression test for SPARK-18589
-from pyspark.sql.functions import udf
left = self.spark.createDataFrame([Row(a=1)])
right = self.spark.createDataFrame([Row(b=1)])
f = udf(lambda a, b: a == b, BooleanType())
@@ -554,7 +549,6 @@ def test_udf_in_filter_on_top_of_join(self):

def test_udf_in_join_condition(self):
# regression test for SPARK-25314
-from pyspark.sql.functions import udf
left = self.spark.createDataFrame([Row(a=1)])
right = self.spark.createDataFrame([Row(b=1)])
f = udf(lambda a, b: a == b, BooleanType())
@@ -566,7 +560,7 @@ def test_udf_in_join_condition(self):

def test_udf_in_left_outer_join_condition(self):
# regression test for SPARK-26147
-from pyspark.sql.functions import udf, col
+from pyspark.sql.functions import col
left = self.spark.createDataFrame([Row(a=1)])
right = self.spark.createDataFrame([Row(b=1)])
f = udf(lambda a: str(a), StringType())
@@ -579,7 +573,6 @@ def test_udf_in_left_outer_join_condition(self):
def test_udf_and_common_filter_in_join_condition(self):
# regression test for SPARK-25314
# test the complex scenario with both udf and common filter
-from pyspark.sql.functions import udf
left = self.spark.createDataFrame([Row(a=1, a1=1, a2=1), Row(a=2, a1=2, a2=2)])
right = self.spark.createDataFrame([Row(b=1, b1=1, b2=1), Row(b=1, b1=3, b2=1)])
f = udf(lambda a, b: a == b, BooleanType())
@@ -590,7 +583,6 @@ def test_udf_and_common_filter_in_join_condition(self):
def test_udf_not_supported_in_join_condition(self):
# regression test for SPARK-25314
# test python udf is not supported in join type except inner join.
-from pyspark.sql.functions import udf
left = self.spark.createDataFrame([Row(a=1, a1=1, a2=1), Row(a=2, a1=2, a2=2)])
right = self.spark.createDataFrame([Row(b=1, b1=1, b2=1), Row(b=1, b1=3, b2=1)])
f = udf(lambda a, b: a == b, BooleanType())
@@ -632,7 +624,7 @@ def test_broadcast_in_udf(self):

def test_udf_with_filter_function(self):
df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
-from pyspark.sql.functions import udf, col
+from pyspark.sql.functions import col
from pyspark.sql.types import BooleanType

my_filter = udf(lambda a: a < 2, BooleanType())
@@ -641,7 +633,7 @@ def test_udf_with_filter_function(self):

def test_udf_with_aggregate_function(self):
df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
-from pyspark.sql.functions import udf, col, sum
+from pyspark.sql.functions import col, sum
from pyspark.sql.types import BooleanType

my_filter = udf(lambda a: a == 1, BooleanType())
@@ -657,7 +649,7 @@ def test_udf_with_aggregate_function(self):
self.assertEqual(sel.collect(), [Row(t=4), Row(t=3)])

def test_udf_in_generate(self):
-from pyspark.sql.functions import udf, explode
+from pyspark.sql.functions import explode
df = self.spark.range(5)
f = udf(lambda x: list(range(x)), ArrayType(LongType()))
row = df.select(explode(f(*df))).groupBy().sum().first()
@@ -684,7 +676,6 @@ def test_udf_in_generate(self):
self.assertEqual(res[3][1], 1)

def test_udf_with_order_by_and_limit(self):
-from pyspark.sql.functions import udf
my_copy = udf(lambda x: x, IntegerType())
df = self.spark.range(10).orderBy("id")
res = df.select(df.id, my_copy(df.id).alias("copy")).limit(1)
@@ -803,14 +794,14 @@ def test_read_multiple_orc_file(self):
self.assertEqual(2, df.count())

def test_udf_with_input_file_name(self):
-from pyspark.sql.functions import udf, input_file_name
+from pyspark.sql.functions import input_file_name
sourceFile = udf(lambda path: path, StringType())
filePath = "python/test_support/sql/people1.json"
row = self.spark.read.json(filePath).select(sourceFile(input_file_name())).first()
self.assertTrue(row[0].find("people1.json") != -1)

def test_udf_with_input_file_name_for_hadooprdd(self):
-from pyspark.sql.functions import udf, input_file_name
+from pyspark.sql.functions import input_file_name

def filename(path):
return path
@@ -859,9 +850,6 @@ def test_udf_defers_judf_initialization(self):
# This is separate of UDFInitializationTests
# to avoid context initialization
# when udf is called

-from pyspark.sql.functions import UserDefinedFunction

f = UserDefinedFunction(lambda x: x, StringType())

self.assertIsNone(
@@ -877,8 +865,6 @@ def test_udf_defers_judf_initialization(self):
)

def test_udf_with_string_return_type(self):
-from pyspark.sql.functions import UserDefinedFunction

add_one = UserDefinedFunction(lambda x: x + 1, "integer")
make_pair = UserDefinedFunction(lambda x: (-x, x), "struct<x:integer,y:integer>")
make_array = UserDefinedFunction(
@@ -892,13 +878,11 @@ def test_udf_with_string_return_type(self):
self.assertTupleEqual(expected, actual)

def test_udf_shouldnt_accept_noncallable_object(self):
-from pyspark.sql.functions import UserDefinedFunction

non_callable = None
self.assertRaises(TypeError, UserDefinedFunction, non_callable, StringType())

def test_udf_with_decorator(self):
-from pyspark.sql.functions import lit, udf
+from pyspark.sql.functions import lit
from pyspark.sql.types import IntegerType, DoubleType

@udf(IntegerType())
@@ -955,7 +939,6 @@ def as_double(x):
)

def test_udf_wrapper(self):
-from pyspark.sql.functions import udf
from pyspark.sql.types import IntegerType

def f(x):
@@ -991,7 +974,7 @@ def __call__(self, x):
self.assertEqual(return_type, f_.returnType)

def test_validate_column_types(self):
-from pyspark.sql.functions import udf, to_json
+from pyspark.sql.functions import to_json
from pyspark.sql.column import _to_java_column

self.assertTrue("Column" in _to_java_column("a").getClass().toString())
@@ -3459,7 +3442,7 @@ def test_ignore_column_of_all_nulls(self):
# SPARK-24721
@unittest.skipIf(not _test_compiled, _test_not_compiled_message)
def test_datasource_with_udf(self):
-from pyspark.sql.functions import udf, lit, col
+from pyspark.sql.functions import lit, col

path = tempfile.mkdtemp()
shutil.rmtree(path)
@@ -3571,8 +3554,6 @@ def test_repr_behaviors(self):

# SPARK-25591
def test_same_accumulator_in_udfs(self):
-from pyspark.sql.functions import udf

data_schema = StructType([StructField("a", IntegerType(), True),
StructField("b", IntegerType(), True)])
data = self.spark.createDataFrame([[1, 2]], schema=data_schema)
@@ -3594,6 +3575,17 @@ def second_udf(x):
data.collect()
self.assertEqual(test_accum.value, 101)

+    # SPARK-26293
+    def test_udf_in_subquery(self):
+        f = udf(lambda x: x, "long")
+        try:
+            self.spark.range(1).filter(f("id") >= 0).createTempView("v")
+            sql = self.spark.sql
+            result = sql("select i from values(0L) as data(i) where i in (select id from v)")
+            self.assertEqual(result.collect(), [Row(i=0)])
+        finally:
+            self.spark.catalog.dropTempView("v")


class HiveSparkSubmitTests(SparkSubmitTests):

@@ -3771,8 +3763,6 @@ def tearDown(self):
SparkContext._active_spark_context.stop()

def test_udf_init_shouldnt_initialize_context(self):
-from pyspark.sql.functions import UserDefinedFunction

UserDefinedFunction(lambda x: x, StringType())

self.assertIsNone(
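For reference, the scenario covered by the new `test_udf_in_subquery` test above can also be run as a standalone script. The sketch below is illustrative only and not part of the patch: the local SparkSession setup, the app name, and the `identity` variable are assumptions, while the UDF, the temp view, and the SQL mirror the added test.

```python
from pyspark.sql import SparkSession
from pyspark.sql.functions import udf

spark = SparkSession.builder.master("local[1]").appName("spark-26293-repro").getOrCreate()

# Identity Python UDF with a bigint return type, as in the new test.
identity = udf(lambda x: x, "long")

# The view's rows pass through the Python UDF before they are consumed.
spark.range(1).filter(identity("id") >= 0).createOrReplaceTempView("v")

# The IN subquery is rewritten into a join during optimization, which is the
# code path the ExtractPythonUDFs change below accounts for (SPARK-26293).
result = spark.sql("select i from values(0L) as data(i) where i in (select id from v)")
print(result.collect())  # expected: [Row(i=0)] with this change applied

spark.stop()
```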
@@ -60,8 +60,12 @@ private class BatchIterator[T](iter: Iterator[T], batchSize: Int)
/**
* A logical plan that evaluates a [[PythonUDF]].
*/
-case class ArrowEvalPython(udfs: Seq[PythonUDF], output: Seq[Attribute], child: LogicalPlan)
-  extends UnaryNode
+case class ArrowEvalPython(
+    udfs: Seq[PythonUDF],
+    output: Seq[Attribute],
+    child: LogicalPlan) extends UnaryNode {
+  override def producedAttributes: AttributeSet = AttributeSet(output.drop(child.output.length))
+}

/**
* A physical plan that evaluates a [[PythonUDF]].
@@ -32,8 +32,12 @@ import org.apache.spark.sql.types.{StructField, StructType}
/**
* A logical plan that evaluates a [[PythonUDF]]
*/
-case class BatchEvalPython(udfs: Seq[PythonUDF], output: Seq[Attribute], child: LogicalPlan)
-  extends UnaryNode
+case class BatchEvalPython(
+    udfs: Seq[PythonUDF],
+    output: Seq[Attribute],
+    child: LogicalPlan) extends UnaryNode {
+  override def producedAttributes: AttributeSet = AttributeSet(output.drop(child.output.length))
+}

/**
* A physical plan that evaluates a [[PythonUDF]]
@@ -24,7 +24,7 @@ import org.apache.spark.api.python.PythonEvalType
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
-import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LogicalPlan, Project}
+import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.Rule


@@ -131,8 +131,20 @@ object ExtractPythonUDFs extends Rule[LogicalPlan] with PredicateHelper {
expressions.flatMap(collectEvaluableUDFs)
}

-  def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
-    case plan: LogicalPlan => extract(plan)
+  def apply(plan: LogicalPlan): LogicalPlan = plan match {
+    // SPARK-26293: A subquery will be rewritten into join later, and will go through this rule
+    // eventually. Here we skip subquery, as Python UDF only needs to be extracted once.
+    case _: Subquery => plan
+
+    case _ => plan transformUp {
+      // A safe guard. `ExtractPythonUDFs` only runs once, so we will not hit `BatchEvalPython` and
+      // `ArrowEvalPython` in the input plan. However if we hit them, we must skip them, as we can't
+      // extract Python UDFs from them.
+      case p: BatchEvalPython => p
+      case p: ArrowEvalPython => p
+
+      case plan: LogicalPlan => extract(plan)
+    }
}

/**
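As the inline comments above explain, the `Subquery` case skips subqueries because they are rewritten into joins and reach this rule again, while the `BatchEvalPython`/`ArrowEvalPython` cases guard against re-extracting UDFs from plans that were already processed. The effect of extraction can be inspected from the Python side through the physical plan; the snippet below is only an illustrative sketch and assumes a local SparkSession.

```python
from pyspark.sql import SparkSession
from pyspark.sql.functions import udf

spark = SparkSession.builder.master("local[1]").appName("extract-python-udfs-demo").getOrCreate()

plus_one = udf(lambda x: x + 1, "long")

# ExtractPythonUDFs pulls the Python UDF out of the projection into a separate
# evaluation node, so the plan printed here should contain a single
# BatchEvalPython entry (pandas UDFs would show ArrowEvalPython instead).
spark.range(3).select(plus_one("id").alias("plus_one")).explain()

spark.stop()
```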