[SPARK-22106][PYSPARK][SQL] Disable 0-parameter pandas_udf and add doctests #19325
Changes from all commits: c0eec8d, 7b0da10, 56a8409, 6dc89b0, 6fc639a
python/pyspark/sql/functions.py
@@ -2127,6 +2127,10 @@ def wrapper(*args):

 def _create_udf(f, returnType, vectorized):

     def _udf(f, returnType=StringType(), vectorized=vectorized):
+        if vectorized:
+            import inspect
+            if len(inspect.getargspec(f).args) == 0:
+                raise NotImplementedError("0-parameter pandas_udfs are not currently supported")
         udf_obj = UserDefinedFunction(f, returnType, vectorized=vectorized)
         return udf_obj._wrapped()
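For reference, the guard added here checks the Python function's signature at definition time, so a 0-parameter pandas_udf fails fast on the driver instead of failing later on the executors. A minimal sketch of the same check, runnable without Spark (the helper name has_zero_parameters is purely illustrative):

import inspect

def has_zero_parameters(f):
    # Mirrors the guard added to _udf: getargspec(f).args lists the positional
    # parameters, so an empty list means a 0-parameter function.
    # (getargspec is the Python 2-era API used in this patch; on Python 3,
    # inspect.getfullargspec is the closest equivalent.)
    return len(inspect.getargspec(f).args) == 0

assert has_zero_parameters(lambda: 1)            # _udf would raise NotImplementedError
assert not has_zero_parameters(lambda x: x + 1)  # accepted as a vectorized UDF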
@@ -2183,14 +2187,28 @@ def pandas_udf(f=None, returnType=StringType()):
     :param f: python function if used as a standalone function
     :param returnType: a :class:`pyspark.sql.types.DataType` object

-    # TODO: doctest
+    >>> from pyspark.sql.types import IntegerType, StringType
+    >>> slen = pandas_udf(lambda s: s.str.len(), IntegerType())
+    >>> @pandas_udf(returnType=StringType())
+    ... def to_upper(s):
+    ...     return s.str.upper()
+    ...
+    >>> @pandas_udf(returnType="integer")
+    ... def add_one(x):
+    ...     return x + 1
+    ...
+    >>> df = spark.createDataFrame([(1, "John Doe", 21)], ("id", "name", "age"))
+    >>> df.select(slen("name").alias("slen(name)"), to_upper("name"), add_one("age")) \\
+    ... .show() # doctest: +SKIP
Member: Seems we don't skip it actually?

Member: Looks actually we do :). Let me test this one for sure in my local before merging it. (I have ...

Member: Yeah. It is. :)

Member: I just double checked it passes. Also, checked without # doctest: +SKIP:

diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 63e9a830bbc..3265ecc974b 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -2199,7 +2199,7 @@ def pandas_udf(f=None, returnType=StringType()):
     ...
     >>> df = spark.createDataFrame([(1, "John Doe", 21)], ("id", "name", "age"))
     >>> df.select(slen("name").alias("slen(name)"), to_upper("name"), add_one("age")) \\
-    ... .show() # doctest: +SKIP
+    ... .show()
    +----------+--------------+------------+
    |slen(name)|to_upper(name)|add_one(age)|
    +----------+--------------+------------+

Member: (D'oh, not a big deal, but two spaces before inline comments..)

Member (Author): Thanks @HyukjinKwon! Sorry, I didn't notice this :( I'll make a note to fix that spacing on a related change.
+    +----------+--------------+------------+
+    |slen(name)|to_upper(name)|add_one(age)|
+    +----------+--------------+------------+
+    |         8|      JOHN DOE|          22|
+    +----------+--------------+------------+
     """
-    import inspect
-    # If function "f" does not define the optional kwargs, then wrap with a kwargs placeholder
-    if inspect.getargspec(f).keywords is None:
-        return _create_udf(lambda *a, **kwargs: f(*a), returnType=returnType, vectorized=True)
-    else:
-        return _create_udf(f, returnType=returnType, vectorized=True)
+    wrapped_udf = _create_udf(f, returnType=returnType, vectorized=True)
+
+    return wrapped_udf


 blacklist = ['map', 'since', 'ignore_unicode_prefix']
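The review thread above debates whether the trailing # doctest: +SKIP directive is honored. For reference, # doctest: +SKIP is a standard per-example doctest directive: the example remains visible in the rendered docs but is never executed by the doctest runner. A small self-contained illustration, independent of Spark (the add function is hypothetical):

import doctest

def add(a, b):
    """Add two numbers.

    >>> add(1, 2)
    3
    >>> add(1, 'not a number')  # doctest: +SKIP
    This example is displayed but never run, so it cannot fail.
    """
    return a + b

if __name__ == "__main__":
    # The skipped example is not attempted, so this reports zero failures.
    print(doctest.testmod())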
python/pyspark/sql/tests.py
@@ -3256,11 +3256,20 @@ def test_vectorized_udf_null_string(self):

     def test_vectorized_udf_zero_parameter(self):
         from pyspark.sql.functions import pandas_udf
-        import pandas as pd
-        df = self.spark.range(10)
-        f0 = pandas_udf(lambda **kwargs: pd.Series(1).repeat(kwargs['length']), LongType())
-        res = df.select(f0())
-        self.assertEquals(df.select(lit(1)).collect(), res.collect())
+        error_str = '0-parameter pandas_udfs.*not.*supported'
+        with QuietTest(self.sc):
+            with self.assertRaisesRegexp(NotImplementedError, error_str):
+                pandas_udf(lambda: 1, LongType())
+
+            with self.assertRaisesRegexp(NotImplementedError, error_str):
+                @pandas_udf
+                def zero_no_type():
+                    return 1
+
+            with self.assertRaisesRegexp(NotImplementedError, error_str):
+                @pandas_udf(LongType())
+                def zero_with_type():
+                    return 1

     def test_vectorized_udf_datatype_string(self):
         from pyspark.sql.functions import pandas_udf, col
@@ -3308,12 +3317,12 @@ def test_vectorized_udf_invalid_length(self):
         from pyspark.sql.functions import pandas_udf, col
         import pandas as pd
         df = self.spark.range(10)
-        raise_exception = pandas_udf(lambda: pd.Series(1), LongType())
+        raise_exception = pandas_udf(lambda _: pd.Series(1), LongType())
         with QuietTest(self.sc):
             with self.assertRaisesRegexp(
                     Exception,
                     'Result vector from pandas_udf was not the required length'):
-                df.select(raise_exception()).collect()
+                df.select(raise_exception(col('id'))).collect()

     def test_vectorized_udf_mix_udf(self):
         from pyspark.sql.functions import pandas_udf, udf, col
@@ -3328,22 +3337,44 @@ def test_vectorized_udf_mix_udf(self):

     def test_vectorized_udf_chained(self):
         from pyspark.sql.functions import pandas_udf, col
-        df = self.spark.range(10).toDF('x')
+        df = self.spark.range(10)
         f = pandas_udf(lambda x: x + 1, LongType())
         g = pandas_udf(lambda x: x - 1, LongType())
-        res = df.select(g(f(col('x'))))
+        res = df.select(g(f(col('id'))))
         self.assertEquals(df.collect(), res.collect())

     def test_vectorized_udf_wrong_return_type(self):
         from pyspark.sql.functions import pandas_udf, col
-        df = self.spark.range(10).toDF('x')
+        df = self.spark.range(10)
         f = pandas_udf(lambda x: x * 1.0, StringType())
         with QuietTest(self.sc):
-            with self.assertRaisesRegexp(
-                    Exception,
-                    'Invalid.*type.*string'):
-                df.select(f(col('x'))).collect()
+            with self.assertRaisesRegexp(Exception, 'Invalid.*type.*string'):
+                df.select(f(col('id'))).collect()
+
+    def test_vectorized_udf_return_scalar(self):
+        from pyspark.sql.functions import pandas_udf, col
+        df = self.spark.range(10)
+        f = pandas_udf(lambda x: 1.0, DoubleType())
+        with QuietTest(self.sc):
+            with self.assertRaisesRegexp(Exception, 'Return.*type.*pandas_udf.*Series'):
+                df.select(f(col('id'))).collect()
+
+    def test_vectorized_udf_decorator(self):
+        from pyspark.sql.functions import pandas_udf, col
+        df = self.spark.range(10)
+
+        @pandas_udf(returnType=LongType())
+        def identity(x):
+            return x
+        res = df.select(identity(col('id')))
+        self.assertEquals(df.collect(), res.collect())
+
+    def test_vectorized_udf_empty_partition(self):
+        from pyspark.sql.functions import pandas_udf, col
+        df = self.spark.createDataFrame(self.sc.parallelize([Row(id=1)], 2))
Member: Maybe I miss something, but what is this test intended to test?

Member: Oh, I see. One partition is empty and it is related to the added stuff in ...

Member (Author): Yeah, an empty partition leads to an empty iterator, so this is to make sure it can handle that.
+        f = pandas_udf(lambda x: x, LongType())
+        res = df.select(f(col('id')))
+        self.assertEquals(df.collect(), res.collect())

 if __name__ == "__main__":
     from pyspark.sql.tests import *
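As discussed in the thread above, test_vectorized_udf_empty_partition exists because parallelizing a single row across two partitions leaves one partition empty, so the vectorized UDF path must cope with an empty iterator. A minimal sketch of how such an empty partition arises, assuming a local SparkSession (the master setting below is only illustrative):

from pyspark.sql import Row, SparkSession

spark = SparkSession.builder.master("local[2]").getOrCreate()

# One Row spread over two partitions: one partition gets the row, the other stays empty.
rdd = spark.sparkContext.parallelize([Row(id=1)], 2)
print(rdd.glom().collect())        # e.g. [[], [Row(id=1)]] -- one empty partition

# Any pandas_udf evaluated over this DataFrame must handle the partition
# that yields no rows (and hence no Arrow batches) at all.
df = spark.createDataFrame(rdd)
print(df.rdd.getNumPartitions())   # 2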
Have we installed pyarrow on Jenkins? The failed test complains: ImportError: No module named pyarrow.

We could just do # doctest: +SKIP maybe.

Hmm, I thought that the Jenkins environment for unit tests would be the same for doctests and have pyarrow installed. @holdenk or @shaneknapp, do you know if that is the case?

Adding @JoshRosen too.

The doc building node (amp-jenkins-worker-01) doesn't have arrow installed for the default conda python 2.7 environment. For the python 3 environment, we're running arrow 0.4.0. I looked at the script and it seems to be agnostic to python 2 vs 3... once I know which version of python we'll be running, I can make sure that the version of arrow installed is correct.

Cool, thanks @shaneknapp!

Hm.. but shouldn't we skip those doctests because they are not hard dependencies anyway?

That's true, I see that toPandas() also skips doctests. I'll skip this now and can always enable later if we decide differently. @shaneknapp, looks like we will hold off on this, so no need to do anything to Jenkins I believe, sorry to bug you.
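The convention the thread lands on: pyarrow (like pandas) is an optional dependency, so examples and tests that need it should be skipped when it cannot be imported rather than assuming Jenkins has it installed. A hedged sketch of that pattern using plain unittest; the module and class names below are illustrative, not the actual Spark test harness:

import unittest

try:
    import pyarrow  # optional dependency, only needed for the vectorized UDF path
    _have_arrow = True
except ImportError:
    _have_arrow = False

@unittest.skipIf(not _have_arrow, "pyarrow is not installed")
class VectorizedUDFSmokeTest(unittest.TestCase):
    def test_arrow_available(self):
        # Placeholder assertion; real tests would exercise pandas_udf here.
        self.assertTrue(_have_arrow)

if __name__ == "__main__":
    unittest.main()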