
Commit 3a80cc5

HyukjinKwon authored and ueshin committed
[SPARK-23122][PYTHON][SQL] Deprecate register* for UDFs in SQLContext and Catalog in PySpark
## What changes were proposed in this pull request? This PR proposes to deprecate `register*` for UDFs in `SQLContext` and `Catalog` in Spark 2.3.0. These are inconsistent with Scala / Java APIs and also these basically do the same things with `spark.udf.register*`. Also, this PR moves the logcis from `[sqlContext|spark.catalog].register*` to `spark.udf.register*` and reuse the docstring. This PR also handles minor doc corrections. It also includes #20158 ## How was this patch tested? Manually tested, manually checked the API documentation and tests added to check if deprecated APIs call the aliases correctly. Author: hyukjinkwon <[email protected]> Closes #20288 from HyukjinKwon/deprecate-udf. (cherry picked from commit 39d244d) Signed-off-by: Takuya UESHIN <[email protected]>
1 parent f2688ef commit 3a80cc5
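
For readers migrating, a minimal sketch of the call-site change this commit encourages (an assumption: `spark` is an active `SparkSession`, and `sqlContext` is the deprecated `SQLContext` wrapper around it):

```python
from pyspark.sql.types import IntegerType

# Deprecated as of 2.3.0 -- both now emit a DeprecationWarning and delegate to spark.udf:
#   sqlContext.registerFunction("strlen", lambda s: len(s), IntegerType())
#   spark.catalog.registerFunction("strlen", lambda s: len(s), IntegerType())

# Preferred: register through the session's shared UDF registry.
strlen = spark.udf.register("strlen", lambda s: len(s), IntegerType())
spark.sql("SELECT strlen('test')").collect()  # [Row(strlen(test)=4)]
```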

8 files changed: +234 −210 lines changed

dev/sparktestsupport/modules.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -400,6 +400,7 @@ def __hash__(self):
         "pyspark.sql.functions",
         "pyspark.sql.readwriter",
         "pyspark.sql.streaming",
+        "pyspark.sql.udf",
         "pyspark.sql.window",
         "pyspark.sql.tests",
     ]
```

python/pyspark/sql/catalog.py

Lines changed: 8 additions & 83 deletions
```diff
@@ -224,92 +224,17 @@ def dropGlobalTempView(self, viewName):
         """
         self._jcatalog.dropGlobalTempView(viewName)
 
-    @ignore_unicode_prefix
     @since(2.0)
     def registerFunction(self, name, f, returnType=None):
-        """Registers a Python function (including lambda function) or a :class:`UserDefinedFunction`
-        as a UDF. The registered UDF can be used in SQL statements.
-
-        :func:`spark.udf.register` is an alias for :func:`spark.catalog.registerFunction`.
-
-        In addition to a name and the function itself, `returnType` can be optionally specified.
-        1) When f is a Python function, `returnType` defaults to a string. The produced object must
-        match the specified type. 2) When f is a :class:`UserDefinedFunction`, Spark uses the return
-        type of the given UDF as the return type of the registered UDF. The input parameter
-        `returnType` is None by default. If given by users, the value must be None.
-
-        :param name: name of the UDF in SQL statements.
-        :param f: a Python function, or a wrapped/native UserDefinedFunction. The UDF can be either
-            row-at-a-time or vectorized.
-        :param returnType: the return type of the registered UDF.
-        :return: a wrapped/native :class:`UserDefinedFunction`
-
-        >>> strlen = spark.catalog.registerFunction("stringLengthString", len)
-        >>> spark.sql("SELECT stringLengthString('test')").collect()
-        [Row(stringLengthString(test)=u'4')]
-
-        >>> spark.sql("SELECT 'foo' AS text").select(strlen("text")).collect()
-        [Row(stringLengthString(text)=u'3')]
-
-        >>> from pyspark.sql.types import IntegerType
-        >>> _ = spark.catalog.registerFunction("stringLengthInt", len, IntegerType())
-        >>> spark.sql("SELECT stringLengthInt('test')").collect()
-        [Row(stringLengthInt(test)=4)]
-
-        >>> from pyspark.sql.types import IntegerType
-        >>> _ = spark.udf.register("stringLengthInt", len, IntegerType())
-        >>> spark.sql("SELECT stringLengthInt('test')").collect()
-        [Row(stringLengthInt(test)=4)]
-
-        >>> from pyspark.sql.types import IntegerType
-        >>> from pyspark.sql.functions import udf
-        >>> slen = udf(lambda s: len(s), IntegerType())
-        >>> _ = spark.udf.register("slen", slen)
-        >>> spark.sql("SELECT slen('test')").collect()
-        [Row(slen(test)=4)]
-
-        >>> import random
-        >>> from pyspark.sql.functions import udf
-        >>> from pyspark.sql.types import IntegerType
-        >>> random_udf = udf(lambda: random.randint(0, 100), IntegerType()).asNondeterministic()
-        >>> new_random_udf = spark.catalog.registerFunction("random_udf", random_udf)
-        >>> spark.sql("SELECT random_udf()").collect()  # doctest: +SKIP
-        [Row(random_udf()=82)]
-        >>> spark.range(1).select(new_random_udf()).collect()  # doctest: +SKIP
-        [Row(<lambda>()=26)]
-
-        >>> from pyspark.sql.functions import pandas_udf, PandasUDFType
-        >>> @pandas_udf("integer", PandasUDFType.SCALAR)  # doctest: +SKIP
-        ... def add_one(x):
-        ...     return x + 1
-        ...
-        >>> _ = spark.udf.register("add_one", add_one)  # doctest: +SKIP
-        >>> spark.sql("SELECT add_one(id) FROM range(3)").collect()  # doctest: +SKIP
-        [Row(add_one(id)=1), Row(add_one(id)=2), Row(add_one(id)=3)]
-        """
+        """An alias for :func:`spark.udf.register`.
+        See :meth:`pyspark.sql.UDFRegistration.register`.
 
-        # This is to check whether the input function is a wrapped/native UserDefinedFunction
-        if hasattr(f, 'asNondeterministic'):
-            if returnType is not None:
-                raise TypeError(
-                    "Invalid returnType: None is expected when f is a UserDefinedFunction, "
-                    "but got %s." % returnType)
-            if f.evalType not in [PythonEvalType.SQL_BATCHED_UDF,
-                                  PythonEvalType.SQL_PANDAS_SCALAR_UDF]:
-                raise ValueError(
-                    "Invalid f: f must be either SQL_BATCHED_UDF or SQL_PANDAS_SCALAR_UDF")
-            register_udf = UserDefinedFunction(f.func, returnType=f.returnType, name=name,
-                                               evalType=f.evalType,
-                                               deterministic=f.deterministic)
-            return_udf = f
-        else:
-            if returnType is None:
-                returnType = StringType()
-            register_udf = UserDefinedFunction(f, returnType=returnType, name=name,
-                                               evalType=PythonEvalType.SQL_BATCHED_UDF)
-            return_udf = register_udf._wrapped()
-        self._jsparkSession.udf().registerPython(name, register_udf._judf)
-        return return_udf
+        .. note:: Deprecated in 2.3.0. Use :func:`spark.udf.register` instead.
+        """
+        warnings.warn(
+            "Deprecated in 2.3.0. Use spark.udf.register instead.",
+            DeprecationWarning)
+        return self._sparkSession.udf.register(name, f, returnType)
 
     @since(2.0)
     def isCached(self, tableName):
```
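
As a quick sketch of the new behavior implemented above (assuming an active `SparkSession` named `spark`), the deprecated alias still returns a usable UDF but now emits a `DeprecationWarning`:

```python
import warnings

from pyspark.sql.types import IntegerType

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # Deprecated alias: delegates to spark.udf.register under the hood.
    strlen = spark.catalog.registerFunction("strlen", lambda s: len(s), IntegerType())

assert any(issubclass(w.category, DeprecationWarning) for w in caught)
spark.sql("SELECT strlen('test')").collect()  # [Row(strlen(test)=4)]
```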

python/pyspark/sql/context.py

Lines changed: 18 additions & 119 deletions
```diff
@@ -29,9 +29,10 @@
 from pyspark.sql.readwriter import DataFrameReader
 from pyspark.sql.streaming import DataStreamReader
 from pyspark.sql.types import IntegerType, Row, StringType
+from pyspark.sql.udf import UDFRegistration
 from pyspark.sql.utils import install_exception_handler
 
-__all__ = ["SQLContext", "HiveContext", "UDFRegistration"]
+__all__ = ["SQLContext", "HiveContext"]
 
 
 class SQLContext(object):
@@ -147,7 +148,7 @@ def udf(self):
 
         :return: :class:`UDFRegistration`
         """
-        return UDFRegistration(self)
+        return self.sparkSession.udf
 
     @since(1.4)
     def range(self, start, end=None, step=1, numPartitions=None):
@@ -172,113 +173,29 @@ def range(self, start, end=None, step=1, numPartitions=None):
         """
         return self.sparkSession.range(start, end, step, numPartitions)
 
-    @ignore_unicode_prefix
     @since(1.2)
     def registerFunction(self, name, f, returnType=None):
-        """Registers a Python function (including lambda function) or a :class:`UserDefinedFunction`
-        as a UDF. The registered UDF can be used in SQL statements.
-
-        :func:`spark.udf.register` is an alias for :func:`sqlContext.registerFunction`.
-
-        In addition to a name and the function itself, `returnType` can be optionally specified.
-        1) When f is a Python function, `returnType` defaults to a string. The produced object must
-        match the specified type. 2) When f is a :class:`UserDefinedFunction`, Spark uses the return
-        type of the given UDF as the return type of the registered UDF. The input parameter
-        `returnType` is None by default. If given by users, the value must be None.
-
-        :param name: name of the UDF in SQL statements.
-        :param f: a Python function, or a wrapped/native UserDefinedFunction. The UDF can be either
-            row-at-a-time or vectorized.
-        :param returnType: the return type of the registered UDF.
-        :return: a wrapped/native :class:`UserDefinedFunction`
-
-        >>> strlen = sqlContext.registerFunction("stringLengthString", lambda x: len(x))
-        >>> sqlContext.sql("SELECT stringLengthString('test')").collect()
-        [Row(stringLengthString(test)=u'4')]
-
-        >>> sqlContext.sql("SELECT 'foo' AS text").select(strlen("text")).collect()
-        [Row(stringLengthString(text)=u'3')]
-
-        >>> from pyspark.sql.types import IntegerType
-        >>> _ = sqlContext.registerFunction("stringLengthInt", lambda x: len(x), IntegerType())
-        >>> sqlContext.sql("SELECT stringLengthInt('test')").collect()
-        [Row(stringLengthInt(test)=4)]
-
-        >>> from pyspark.sql.types import IntegerType
-        >>> _ = sqlContext.udf.register("stringLengthInt", lambda x: len(x), IntegerType())
-        >>> sqlContext.sql("SELECT stringLengthInt('test')").collect()
-        [Row(stringLengthInt(test)=4)]
-
-        >>> from pyspark.sql.types import IntegerType
-        >>> from pyspark.sql.functions import udf
-        >>> slen = udf(lambda s: len(s), IntegerType())
-        >>> _ = sqlContext.udf.register("slen", slen)
-        >>> sqlContext.sql("SELECT slen('test')").collect()
-        [Row(slen(test)=4)]
-
-        >>> import random
-        >>> from pyspark.sql.functions import udf
-        >>> from pyspark.sql.types import IntegerType
-        >>> random_udf = udf(lambda: random.randint(0, 100), IntegerType()).asNondeterministic()
-        >>> new_random_udf = sqlContext.registerFunction("random_udf", random_udf)
-        >>> sqlContext.sql("SELECT random_udf()").collect()  # doctest: +SKIP
-        [Row(random_udf()=82)]
-        >>> sqlContext.range(1).select(new_random_udf()).collect()  # doctest: +SKIP
-        [Row(<lambda>()=26)]
-
-        >>> from pyspark.sql.functions import pandas_udf, PandasUDFType
-        >>> @pandas_udf("integer", PandasUDFType.SCALAR)  # doctest: +SKIP
-        ... def add_one(x):
-        ...     return x + 1
-        ...
-        >>> _ = sqlContext.udf.register("add_one", add_one)  # doctest: +SKIP
-        >>> sqlContext.sql("SELECT add_one(id) FROM range(3)").collect()  # doctest: +SKIP
-        [Row(add_one(id)=1), Row(add_one(id)=2), Row(add_one(id)=3)]
+        """An alias for :func:`spark.udf.register`.
+        See :meth:`pyspark.sql.UDFRegistration.register`.
+
+        .. note:: Deprecated in 2.3.0. Use :func:`spark.udf.register` instead.
         """
-        return self.sparkSession.catalog.registerFunction(name, f, returnType)
+        warnings.warn(
+            "Deprecated in 2.3.0. Use spark.udf.register instead.",
+            DeprecationWarning)
+        return self.sparkSession.udf.register(name, f, returnType)
 
-    @ignore_unicode_prefix
     @since(2.1)
     def registerJavaFunction(self, name, javaClassName, returnType=None):
-        """Register a java UDF so it can be used in SQL statements.
-
-        In addition to a name and the function itself, the return type can be optionally specified.
-        When the return type is not specified we would infer it via reflection.
-        :param name: name of the UDF
-        :param javaClassName: fully qualified name of java class
-        :param returnType: a :class:`pyspark.sql.types.DataType` object
-
-        >>> sqlContext.registerJavaFunction("javaStringLength",
-        ...   "test.org.apache.spark.sql.JavaStringLength", IntegerType())
-        >>> sqlContext.sql("SELECT javaStringLength('test')").collect()
-        [Row(UDF:javaStringLength(test)=4)]
-        >>> sqlContext.registerJavaFunction("javaStringLength2",
-        ...   "test.org.apache.spark.sql.JavaStringLength")
-        >>> sqlContext.sql("SELECT javaStringLength2('test')").collect()
-        [Row(UDF:javaStringLength2(test)=4)]
+        """An alias for :func:`spark.udf.registerJavaFunction`.
+        See :meth:`pyspark.sql.UDFRegistration.registerJavaFunction`.
 
+        .. note:: Deprecated in 2.3.0. Use :func:`spark.udf.registerJavaFunction` instead.
         """
-        jdt = None
-        if returnType is not None:
-            jdt = self.sparkSession._jsparkSession.parseDataType(returnType.json())
-        self.sparkSession._jsparkSession.udf().registerJava(name, javaClassName, jdt)
-
-    @ignore_unicode_prefix
-    @since(2.3)
-    def registerJavaUDAF(self, name, javaClassName):
-        """Register a java UDAF so it can be used in SQL statements.
-
-        :param name: name of the UDAF
-        :param javaClassName: fully qualified name of java class
-
-        >>> sqlContext.registerJavaUDAF("javaUDAF",
-        ...   "test.org.apache.spark.sql.MyDoubleAvg")
-        >>> df = sqlContext.createDataFrame([(1, "a"),(2, "b"), (3, "a")],["id", "name"])
-        >>> df.registerTempTable("df")
-        >>> sqlContext.sql("SELECT name, javaUDAF(id) as avg from df group by name").collect()
-        [Row(name=u'b', avg=102.0), Row(name=u'a', avg=102.0)]
-        """
-        self.sparkSession._jsparkSession.udf().registerJavaUDAF(name, javaClassName)
+        warnings.warn(
+            "Deprecated in 2.3.0. Use spark.udf.registerJavaFunction instead.",
+            DeprecationWarning)
+        return self.sparkSession.udf.registerJavaFunction(name, javaClassName, returnType)
 
     # TODO(andrew): delete this once we refactor things to take in SparkSession
     def _inferSchema(self, rdd, samplingRatio=None):
@@ -590,24 +507,6 @@ def refreshTable(self, tableName):
         self._ssql_ctx.refreshTable(tableName)
 
 
-class UDFRegistration(object):
-    """Wrapper for user-defined function registration."""
-
-    def __init__(self, sqlContext):
-        self.sqlContext = sqlContext
-
-    def register(self, name, f, returnType=None):
-        return self.sqlContext.registerFunction(name, f, returnType)
-
-    def registerJavaFunction(self, name, javaClassName, returnType=None):
-        self.sqlContext.registerJavaFunction(name, javaClassName, returnType)
-
-    def registerJavaUDAF(self, name, javaClassName):
-        self.sqlContext.registerJavaUDAF(name, javaClassName)
-
-    register.__doc__ = SQLContext.registerFunction.__doc__
-
-
 def _test():
     import os
     import doctest
```
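
The Java registration path moves the same way. A sketch of the replacement call, reusing the class name from the removed doctest (assumes `test.org.apache.spark.sql.JavaStringLength` is compiled and on the driver classpath, as in Spark's own test suite):

```python
from pyspark.sql.types import IntegerType

# Preferred in 2.3.0+; sqlContext.registerJavaFunction(...) now warns and delegates here.
spark.udf.registerJavaFunction(
    "javaStringLength", "test.org.apache.spark.sql.JavaStringLength", IntegerType())
spark.sql("SELECT javaStringLength('test')").collect()
# [Row(UDF:javaStringLength(test)=4)]
```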

python/pyspark/sql/functions.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -2103,7 +2103,7 @@ def udf(f=None, returnType=StringType()):
     >>> import random
     >>> random_udf = udf(lambda: int(random.random() * 100), IntegerType()).asNondeterministic()
 
-    .. note:: The user-defined functions do not support conditional expressions or short curcuiting
+    .. note:: The user-defined functions do not support conditional expressions or short circuiting
         in boolean expressions and it ends up with being executed all internally. If the functions
         can fail on special rows, the workaround is to incorporate the condition into the functions.
 
@@ -2231,7 +2231,7 @@ def pandas_udf(f=None, returnType=None, functionType=None):
     ...     return pd.Series(np.random.randn(len(v))
     >>> random = random.asNondeterministic()  # doctest: +SKIP
 
-    .. note:: The user-defined functions do not support conditional expressions or short curcuiting
+    .. note:: The user-defined functions do not support conditional expressions or short circuiting
         in boolean expressions and it ends up with being executed all internally. If the functions
         can fail on special rows, the workaround is to incorporate the condition into the functions.
     """
```

python/pyspark/sql/group.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -212,7 +212,8 @@ def apply(self, udf):
         This function does not support partial aggregation, and requires shuffling all the data in
         the :class:`DataFrame`.
 
-        :param udf: A function object returned by :meth:`pyspark.sql.functions.pandas_udf`
+        :param udf: a group map user-defined function returned by
+            :meth:`pyspark.sql.functions.pandas_udf`.
 
         >>> from pyspark.sql.functions import pandas_udf, PandasUDFType
         >>> df = spark.createDataFrame(
```
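
To illustrate the group map UDF the reworded parameter refers to, a minimal sketch (assumptions: pandas and pyarrow are installed, `spark` is a `SparkSession`, and the constant is spelled `PandasUDFType.GROUPED_MAP` as in the released 2.3 API; the exact spelling may differ at this commit):

```python
from pyspark.sql.functions import pandas_udf, PandasUDFType

df = spark.createDataFrame([(1, 1.0), (1, 2.0), (2, 3.0)], ("id", "v"))

@pandas_udf("id long, v double", PandasUDFType.GROUPED_MAP)
def subtract_mean(pdf):
    # pdf is a pandas.DataFrame holding every row of one group.
    return pdf.assign(v=pdf.v - pdf.v.mean())

df.groupby("id").apply(subtract_mean).show()
```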

python/pyspark/sql/session.py

Lines changed: 3 additions & 3 deletions
```diff
@@ -29,7 +29,6 @@
 
 from pyspark import since
 from pyspark.rdd import RDD, ignore_unicode_prefix
-from pyspark.sql.catalog import Catalog
 from pyspark.sql.conf import RuntimeConfig
 from pyspark.sql.dataframe import DataFrame
 from pyspark.sql.readwriter import DataFrameReader
@@ -280,6 +279,7 @@ def catalog(self):
 
         :return: :class:`Catalog`
         """
+        from pyspark.sql.catalog import Catalog
         if not hasattr(self, "_catalog"):
             self._catalog = Catalog(self)
         return self._catalog
@@ -291,8 +291,8 @@ def udf(self):
 
         :return: :class:`UDFRegistration`
         """
-        from pyspark.sql.context import UDFRegistration
-        return UDFRegistration(self._wrapped)
+        from pyspark.sql.udf import UDFRegistration
+        return UDFRegistration(self)
 
     @since(2.0)
     def range(self, start, end=None, step=1, numPartitions=None):
```
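
A small sketch of what this refactoring means at the API surface (assuming a `SparkSession` named `spark`): `SparkSession.udf`, the deprecated `SQLContext.udf`, and the `Catalog` aliases all funnel into the one `UDFRegistration` in the new `pyspark.sql.udf` module:

```python
from pyspark.sql.udf import UDFRegistration

# spark.udf is now backed by the new module; SQLContext.udf simply delegates to it.
assert isinstance(spark.udf, UDFRegistration)
```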

python/pyspark/sql/tests.py

Lines changed: 20 additions & 0 deletions
```diff
@@ -372,6 +372,12 @@ def test_udf(self):
         [row] = self.spark.sql("SELECT twoArgs('test', 1)").collect()
         self.assertEqual(row[0], 5)
 
+        # This is to check if a deprecated 'SQLContext.registerFunction' can call its alias.
+        sqlContext = self.spark._wrapped
+        sqlContext.registerFunction("oneArg", lambda x: len(x), IntegerType())
+        [row] = sqlContext.sql("SELECT oneArg('test')").collect()
+        self.assertEqual(row[0], 4)
+
     def test_udf2(self):
         self.spark.catalog.registerFunction("strlen", lambda string: len(string), IntegerType())
         self.spark.createDataFrame(self.sc.parallelize([Row(a="test")]))\
@@ -577,11 +583,25 @@ def test_udf_registration_returns_udf(self):
             df.select(add_three("id").alias("plus_three")).collect()
         )
 
+        # This is to check if a 'SQLContext.udf' can call its alias.
+        sqlContext = self.spark._wrapped
+        add_four = sqlContext.udf.register("add_four", lambda x: x + 4, IntegerType())
+
+        self.assertListEqual(
+            df.selectExpr("add_four(id) AS plus_four").collect(),
+            df.select(add_four("id").alias("plus_four")).collect()
+        )
+
     def test_non_existed_udf(self):
         spark = self.spark
         self.assertRaisesRegexp(AnalysisException, "Can not load class non_existed_udf",
                                 lambda: spark.udf.registerJavaFunction("udf1", "non_existed_udf"))
 
+        # This is to check if a deprecated 'SQLContext.registerJavaFunction' can call its alias.
+        sqlContext = spark._wrapped
+        self.assertRaisesRegexp(AnalysisException, "Can not load class non_existed_udf",
+                                lambda: sqlContext.registerJavaFunction("udf1", "non_existed_udf"))
+
     def test_non_existed_udaf(self):
         spark = self.spark
         self.assertRaisesRegexp(AnalysisException, "Can not load class non_existed_udaf",
```
