25 changes: 21 additions & 4 deletions python/pyspark/sql/catalog.py
@@ -227,15 +227,15 @@ def dropGlobalTempView(self, viewName):
@ignore_unicode_prefix
@since(2.0)
def registerFunction(self, name, f, returnType=StringType()):
"""Registers a python function (including lambda function) as a UDF
"""Registers a Python function (including lambda function) or a wrapped/native UDF
cloud-fan (Contributor), Jan 4, 2018:

I'm really confused when reading this documentation; it would be much clearer to me if we could just say

Registers a Python function (including lambda function) or a :class:`UserDefinedFunction` as a UDF

This wrapping logic was added in #16534; is it really worth it?

Member:

It indeed added some complexity. However, I believe nothing is blocked by #16534 now if I understand correctly.

The change in #16534 is quite nice because IMHO Python users probably use help() and dir() more frequently than reading the API docs on the website. For a set of UDFs provided as a library, I think that's well worth keeping.

How about leaving this wrapper logic as is for now, and bringing this discussion back when something is actually blocked (or made too complicated) by it?
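For readers following this thread, below is a minimal, PySpark-free sketch of the behaviour being discussed: help() on a plain callable object documents its class, while a functools.wraps-style wrapper (which is what _wrapped() builds) surfaces the underlying function's docstring. The names CallableHolder, slen, and wrapped are illustrative stand-ins, not the actual pyspark code.

import functools
import pydoc


def slen(s):
    """Return the length of a string value."""
    return len(s)


class CallableHolder(object):
    """Stand-in for UserDefinedFunction: callable, but pydoc documents the class."""
    def __init__(self, func):
        self.func = func

    def __call__(self, *args):
        return self.func(*args)


holder = CallableHolder(slen)


@functools.wraps(slen)
def wrapped(*args):
    # Thin wrapper: behaves like the holder but keeps slen's name and docstring.
    return holder(*args)


# help(holder) documents CallableHolder; help(wrapped) documents slen.
assert wrapped.__doc__ == "Return the length of a string value."
print(pydoc.render_doc(holder).splitlines()[0])
print(pydoc.render_doc(wrapped).splitlines()[0])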

Member:

Another idea just in case it helps:

Registers a Python function as a UDF or a user defined function.

HyukjinKwon (Member), Jan 4, 2018:

BTW, to be honest, I remember giving several quick tries at the time to get rid of the wrapper while keeping the docstring, but I failed to come up with a good alternative.

It might be good to check whether there is a clever way to get rid of the wrapper but keep the doc.

Contributor:

SGTM

so it can be used in SQL statements.

In addition to a name and the function itself, the return type can be optionally specified.
When the return type is not given it defaults to a string and conversion will automatically
be done. For any other return type, the produced object must match the specified type.

:param name: name of the UDF
:param f: python function
:param f: a Python function, or a wrapped/native UserDefinedFunction
:param returnType: a :class:`pyspark.sql.types.DataType` object
:return: a wrapped :class:`UserDefinedFunction`

@@ -255,9 +255,26 @@ def registerFunction(self, name, f, returnType=StringType()):
>>> _ = spark.udf.register("stringLengthInt", len, IntegerType())
>>> spark.sql("SELECT stringLengthInt('test')").collect()
[Row(stringLengthInt(test)=4)]
Member:

Let's fix the doc for this too. It says :param f: python function, but we could describe that it takes a native Python function, a wrapped function, or a UserDefinedFunction too.

Member (Author):

ok


>>> import random
>>> from pyspark.sql.functions import udf
>>> from pyspark.sql.types import IntegerType, StringType
>>> random_udf = udf(lambda: random.randint(0, 100), IntegerType()).asNondeterministic()
>>> newRandom_udf = spark.catalog.registerFunction("random_udf", random_udf, StringType())
>>> spark.sql("SELECT random_udf()").collect() # doctest: +SKIP
[Row(random_udf()=u'82')]
>>> spark.range(1).select(newRandom_udf()).collect() # doctest: +SKIP
[Row(random_udf()=u'62')]
"""
udf = UserDefinedFunction(f, returnType=returnType, name=name,
evalType=PythonEvalType.SQL_BATCHED_UDF)

# This is to check whether the input function is a wrapped/native UserDefinedFunction
if hasattr(f, 'asNondeterministic'):
Member:

Actually, this one is what made me suggest the wrapper._unwrapped = lambda: self approach.

So, here f can be either a wrapped function or a UserDefinedFunction, and I thought it's not quite clear what we expect here with hasattr(f, 'asNondeterministic').

Could we at least leave some comments saying that this can be both a wrapped function of a UserDefinedFunction and a UserDefinedFunction itself?

Member (Author):

Will add a comment.
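For reference, here is a rough, self-contained sketch of the wrapper._unwrapped = lambda: self alternative mentioned above. FakeUDF and register_function are simplified stand-ins for illustration only, not the code in this PR.

class FakeUDF(object):
    """Simplified stand-in for UserDefinedFunction (illustration only)."""
    def __init__(self, func, deterministic=True):
        self.func = func
        self.deterministic = deterministic

    def _wrapped(self):
        def wrapper(*args):
            return self.func(*args)
        # Explicit back-reference, instead of duck-typing on asNondeterministic.
        wrapper._unwrapped = lambda: self
        return wrapper


def register_function(name, f):
    # Recover the underlying UDF if f is a wrapped one; else wrap the plain function.
    udf_obj = f._unwrapped() if hasattr(f, '_unwrapped') else FakeUDF(f)
    print("registering %s (deterministic=%s)" % (name, udf_obj.deterministic))
    return udf_obj._wrapped()


plus_one = register_function("plus_one", lambda x: x + 1)
plus_two = register_function(
    "plus_two", FakeUDF(lambda x: x + 2, deterministic=False)._wrapped())
assert plus_one(1) == 2 and plus_two(1) == 3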

udf = UserDefinedFunction(f.func, returnType=returnType, name=name,
evalType=PythonEvalType.SQL_BATCHED_UDF,
Contributor:

cc @ueshin @icexelloss, shall we support registering pandas UDFs here too?

Contributor:

Seems we can support it by just changing evalType=PythonEvalType.SQL_BATCHED_UDF to evalType=f.evalType.

Member:

+1 but I think there's no way to use a group map UDF in SQL syntax if I understood correctly. I think we can safely fail fast for now as well.

Contributor:

SGTM

Member (Author):

Will support pandas UDFs in a separate PR.

Contributor:

+1 too

deterministic=f.deterministic)
else:
udf = UserDefinedFunction(f, returnType=returnType, name=name,
evalType=PythonEvalType.SQL_BATCHED_UDF)
self._jsparkSession.udf().registerPython(name, udf._judf)
return udf._wrapped()
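To make the follow-up idea from the thread above concrete, here is a rough sketch of propagating the incoming UDF's evalType and failing fast for eval types that have no SQL call syntax. The resolve_eval_type helper and the constant values are local, illustrative assumptions; a real follow-up would use the PythonEvalType constants and may look quite different.

# Illustrative placeholders; the real code would use PythonEvalType constants.
SQL_BATCHED_UDF = 100
SQL_GROUPED_MAP_PANDAS_UDF = 201  # assumed value, for this sketch only


def resolve_eval_type(f):
    """Pick the eval type to register with, rejecting UDFs unusable from SQL."""
    eval_type = getattr(f, 'evalType', SQL_BATCHED_UDF)
    if eval_type == SQL_GROUPED_MAP_PANDAS_UDF:
        # Group map UDFs have no SQL invocation syntax, so fail fast.
        raise ValueError("grouped map pandas UDFs cannot be registered for SQL use")
    return eval_type


class _GroupedMapStub(object):
    evalType = SQL_GROUPED_MAP_PANDAS_UDF


assert resolve_eval_type(len) == SQL_BATCHED_UDF  # plain function: default eval type
try:
    resolve_eval_type(_GroupedMapStub())
except ValueError as e:
    print(e)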

14 changes: 12 additions & 2 deletions python/pyspark/sql/context.py
@@ -175,15 +175,15 @@ def range(self, start, end=None, step=1, numPartitions=None):
@ignore_unicode_prefix
@since(1.2)
def registerFunction(self, name, f, returnType=StringType()):
"""Registers a python function (including lambda function) as a UDF
"""Registers a Python function (including lambda function) or a wrapped/native UDF
so it can be used in SQL statements.

In addition to a name and the function itself, the return type can be optionally specified.
When the return type is not given it defaults to a string and conversion will automatically
be done. For any other return type, the produced object must match the specified type.

:param name: name of the UDF
:param f: python function
:param f: a Python function, or a wrapped/native UserDefinedFunction
:param returnType: a :class:`pyspark.sql.types.DataType` object
:return: a wrapped :class:`UserDefinedFunction`

@@ -203,6 +203,16 @@ def registerFunction(self, name, f, returnType=StringType()):
>>> _ = sqlContext.udf.register("stringLengthInt", lambda x: len(x), IntegerType())
>>> sqlContext.sql("SELECT stringLengthInt('test')").collect()
[Row(stringLengthInt(test)=4)]

>>> import random
>>> from pyspark.sql.functions import udf
>>> from pyspark.sql.types import IntegerType, StringType
>>> random_udf = udf(lambda: random.randint(0, 100), IntegerType()).asNondeterministic()
>>> newRandom_udf = sqlContext.registerFunction("random_udf", random_udf, StringType())
>>> sqlContext.sql("SELECT random_udf()").collect() # doctest: +SKIP
[Row(random_udf()=u'82')]
>>> sqlContext.range(1).select(newRandom_udf()).collect() # doctest: +SKIP
[Row(random_udf()=u'62')]
"""
return self.sparkSession.catalog.registerFunction(name, f, returnType)

49 changes: 35 additions & 14 deletions python/pyspark/sql/tests.py
@@ -378,6 +378,41 @@ def test_udf2(self):
[res] = self.spark.sql("SELECT strlen(a) FROM test WHERE strlen(a) > 1").collect()
self.assertEqual(4, res[0])

def test_udf3(self):
twoargs = self.spark.catalog.registerFunction(
"twoArgs", UserDefinedFunction(lambda x, y: len(x) + y), IntegerType())
self.assertEqual(twoargs.deterministic, True)
[row] = self.spark.sql("SELECT twoArgs('test', 1)").collect()
self.assertEqual(row[0], 5)

def test_nondeterministic_udf(self):
from pyspark.sql.functions import udf
import random
udf_random_col = udf(lambda: int(100 * random.random()), IntegerType()).asNondeterministic()
self.assertEqual(udf_random_col.deterministic, False)
df = self.spark.createDataFrame([Row(1)]).select(udf_random_col().alias('RAND'))
udf_add_ten = udf(lambda rand: rand + 10, IntegerType())
[row] = df.withColumn('RAND_PLUS_TEN', udf_add_ten('RAND')).collect()
self.assertEqual(row[0] + 10, row[1])

def test_nondeterministic_udf2(self):
import random
from pyspark.sql.functions import udf
random_udf = udf(lambda: random.randint(6, 6), IntegerType()).asNondeterministic()
self.assertEqual(random_udf.deterministic, False)
random_udf1 = self.spark.catalog.registerFunction("randInt", random_udf, StringType())
self.assertEqual(random_udf1.deterministic, False)
[row] = self.spark.sql("SELECT randInt()").collect()
self.assertEqual(row[0], "6")
[row] = self.spark.range(1).select(random_udf1()).collect()
self.assertEqual(row[0], "6")
[row] = self.spark.range(1).select(random_udf()).collect()
self.assertEqual(row[0], 6)
# render_doc() reproduces the help() exception without printing output
pydoc.render_doc(udf(lambda: random.randint(6, 6), IntegerType()))
Contributor:

what does it do?

Member:

Can we put these tests there, or make them separate from test_nondeterministic_udf? Adding comments is also fine with me.

Member (Author):

Will add a comment.

pydoc.render_doc(random_udf)
pydoc.render_doc(random_udf1)
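As a small, PySpark-free illustration of why the test above uses pydoc.render_doc instead of help(): render_doc builds the same documentation text that help() would display but returns it as a string, so any exception raised along the help() path surfaces without printing anything. The add_one function below is an arbitrary example.

import pydoc


def add_one(x):
    """Return x + 1."""
    return x + 1


doc_text = pydoc.render_doc(add_one)   # same content help(add_one) would page
assert "Return x + 1." in doc_text     # nothing is printed to stdout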

def test_chained_udf(self):
self.spark.catalog.registerFunction("double", lambda x: x + x, IntegerType())
[row] = self.spark.sql("SELECT double(1)").collect()
@@ -435,15 +470,6 @@ def test_udf_with_array_type(self):
self.assertEqual(list(range(3)), l1)
self.assertEqual(1, l2)

def test_nondeterministic_udf(self):
from pyspark.sql.functions import udf
import random
udf_random_col = udf(lambda: int(100 * random.random()), IntegerType()).asNondeterministic()
df = self.spark.createDataFrame([Row(1)]).select(udf_random_col().alias('RAND'))
udf_add_ten = udf(lambda rand: rand + 10, IntegerType())
[row] = df.withColumn('RAND_PLUS_TEN', udf_add_ten('RAND')).collect()
self.assertEqual(row[0] + 10, row[1])

def test_broadcast_in_udf(self):
bar = {"a": "aa", "b": "bb", "c": "abc"}
foo = self.sc.broadcast(bar)
@@ -567,15 +593,13 @@ def test_read_multiple_orc_file(self):

def test_udf_with_input_file_name(self):
from pyspark.sql.functions import udf, input_file_name
from pyspark.sql.types import StringType
sourceFile = udf(lambda path: path, StringType())
filePath = "python/test_support/sql/people1.json"
row = self.spark.read.json(filePath).select(sourceFile(input_file_name())).first()
self.assertTrue(row[0].find("people1.json") != -1)

def test_udf_with_input_file_name_for_hadooprdd(self):
from pyspark.sql.functions import udf, input_file_name
from pyspark.sql.types import StringType

def filename(path):
return path
@@ -635,7 +659,6 @@ def test_udf_with_string_return_type(self):

def test_udf_shouldnt_accept_noncallable_object(self):
from pyspark.sql.functions import UserDefinedFunction
from pyspark.sql.types import StringType

non_callable = None
self.assertRaises(TypeError, UserDefinedFunction, non_callable, StringType())
@@ -1299,7 +1322,6 @@ def test_between_function(self):
df.filter(df.a.between(df.b, df.c)).collect())

def test_struct_type(self):
from pyspark.sql.types import StructType, StringType, StructField
struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None)
struct2 = StructType([StructField("f1", StringType(), True),
StructField("f2", StringType(), True, None)])
@@ -1368,7 +1390,6 @@ def test_parse_datatype_string(self):
_parse_datatype_string("a INT, c DOUBLE"))

def test_metadata_null(self):
from pyspark.sql.types import StructType, StringType, StructField
schema = StructType([StructField("f1", StringType(), True, None),
StructField("f2", StringType(), True, {'a': None})])
rdd = self.sc.parallelize([["a", "b"], ["c", "d"]])
21 changes: 14 additions & 7 deletions python/pyspark/sql/udf.py
@@ -56,7 +56,8 @@ def _create_udf(f, returnType, evalType):
)

# Set the name of the UserDefinedFunction object to be the name of function f
udf_obj = UserDefinedFunction(f, returnType=returnType, name=None, evalType=evalType)
udf_obj = UserDefinedFunction(
f, returnType=returnType, name=None, evalType=evalType, deterministic=True)
return udf_obj._wrapped()


@@ -67,8 +68,10 @@ class UserDefinedFunction(object):
.. versionadded:: 1.3
"""
def __init__(self, func,
returnType=StringType(), name=None,
evalType=PythonEvalType.SQL_BATCHED_UDF):
returnType=StringType(),
name=None,
evalType=PythonEvalType.SQL_BATCHED_UDF,
deterministic=True):
if not callable(func):
raise TypeError(
"Invalid function: not a function or callable (__call__ is not defined): "
@@ -92,7 +95,7 @@ def __init__(self, func,
func.__name__ if hasattr(func, '__name__')
else func.__class__.__name__)
self.evalType = evalType
self._deterministic = True
self.deterministic = deterministic

@property
def returnType(self):
@@ -130,14 +133,17 @@ def _create_judf(self):
wrapped_func = _wrap_function(sc, self.func, self.returnType)
jdt = spark._jsparkSession.parseDataType(self.returnType.json())
judf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonFunction(
self._name, wrapped_func, jdt, self.evalType, self._deterministic)
self._name, wrapped_func, jdt, self.evalType, self.deterministic)
return judf

def __call__(self, *cols):
judf = self._judf
sc = SparkContext._active_spark_context
return Column(judf.apply(_to_seq(sc, cols, _to_java_column)))

# This function is for improving the online help system in the interactive interpreter.
# For example, the built-in help / pydoc.help. It wraps the UDF with the docstring and
# argument annotation. (See: SPARK-19161)
Member:

I think we can put this in the docstring of _wrapped, between L148 and L150.

Member (Author):

I do not want to expose these comments in the doc.

def _wrapped(self):
"""
Wrap this udf with a function and attach docstring from func
@@ -162,7 +168,8 @@ def wrapper(*args):
wrapper.func = self.func
wrapper.returnType = self.returnType
wrapper.evalType = self.evalType
wrapper.asNondeterministic = self.asNondeterministic
wrapper.deterministic = self.deterministic
wrapper.asNondeterministic = lambda: self.asNondeterministic()._wrapped()
Member:

Can we do:

       wrapper.asNondeterministic = functools.wraps(
           self.asNondeterministic)(lambda: self.asNondeterministic()._wrapped())

So that it can produce proper pydoc when we do help(udf(lambda: 1, "integer").asNondeterministic) (not help(udf(lambda: 1, "integer").asNondeterministic())).

Member (Author):

Good to know the difference.

Member (Author):

I will leave this unchanged. Maybe you can submit a follow-up PR to address it?

Member:

Definitely. Will give it a try within the following week, though...
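A small, self-contained sketch of the suggestion above, with FakeUDF as a stand-in rather than PySpark code: applying functools.wraps(self.asNondeterministic) to the replacement lambda copies the method's name and docstring onto it, so help(wrapped_udf.asNondeterministic) stays informative.

import functools
import pydoc


class FakeUDF(object):
    """Stand-in for UserDefinedFunction (illustration only)."""
    def __init__(self):
        self.deterministic = True

    def asNondeterministic(self):
        """Updates this UDF to be nondeterministic."""
        self.deterministic = False
        return self


udf_obj = FakeUDF()

# A bare lambda has no name or docstring, so help() on it is unhelpful.
bare = lambda: udf_obj.asNondeterministic()

# With functools.wraps, the name and docstring are copied from the method.
documented = functools.wraps(udf_obj.asNondeterministic)(
    lambda: udf_obj.asNondeterministic())

assert bare.__doc__ is None
assert documented.__doc__ == "Updates this UDF to be nondeterministic."
print(pydoc.render_doc(documented).splitlines()[0])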


return wrapper

@@ -172,5 +179,5 @@ def asNondeterministic(self):

.. versionadded:: 2.3
"""
self._deterministic = False
self.deterministic = False
HyukjinKwon (Member), Jan 3, 2018:

Can we call it udfDeterministic to be consistent with the Scala side?

The opposite works fine for me too, if that's possible in any way.

Member (Author):

deterministic is used in UserDefinedFunction.scala. Users can use it to check whether this UDF is deterministic or not.

return self
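As a usage note on the public deterministic flag discussed here, the sketch below shows how a caller would inspect it, mirroring the assertions in the tests above. It assumes a PySpark installation, but no SparkSession is needed just to build the UDF and read the flag, since neither touches the JVM.

import random

from pyspark.sql.functions import udf
from pyspark.sql.types import IntegerType

# The wrapped function exposes the same flag as the underlying UserDefinedFunction.
random_udf = udf(lambda: random.randint(0, 100), IntegerType()).asNondeterministic()
assert random_udf.deterministic is False

plus_one = udf(lambda x: x + 1, IntegerType())
assert plus_one.deterministic is True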