[WIP][SPARK-21190][SQL][PYTHON] Vectorized UDFs in Python #19147
python/pyspark/sql/functions.py

@@ -25,6 +25,20 @@
 if sys.version < "3":
     from itertools import imap as map
 
+_have_pandas = False
+try:
+    import pandas
+    _have_pandas = True
+except:
+    pass
+
+_have_arrow = False
+try:
+    import pyarrow
+    _have_arrow = True
+except:
+    pass
+
 from pyspark import since, SparkContext
 from pyspark.rdd import _prepare_for_python_RDD, ignore_unicode_prefix
 from pyspark.serializers import PickleSerializer, AutoBatchedSerializer
@@ -2032,7 +2046,7 @@ class UserDefinedFunction(object):
 
     .. versionadded:: 1.3
     """
-    def __init__(self, func, returnType, name=None):
+    def __init__(self, func, returnType, name=None, vectorized=False):
         if not callable(func):
             raise TypeError(
                 "Not a function or callable (__call__ is not defined): "
@@ -2046,6 +2060,7 @@ def __init__(self, func, returnType, name=None):
         self._name = name or (
             func.__name__ if hasattr(func, '__name__')
             else func.__class__.__name__)
+        self._vectorized = vectorized
 
     @property
     def returnType(self):
@@ -2077,7 +2092,7 @@ def _create_judf(self):
         wrapped_func = _wrap_function(sc, self.func, self.returnType)
         jdt = spark._jsparkSession.parseDataType(self.returnType.json())
         judf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonFunction(
-            self._name, wrapped_func, jdt)
+            self._name, wrapped_func, jdt, self._vectorized)
         return judf
 
     def __call__(self, *cols):
@@ -2111,6 +2126,53 @@ def wrapper(*args):
         return wrapper
 
 
+def _udf(f, returnType, vectorized):
+    udf_obj = UserDefinedFunction(f, returnType, vectorized=vectorized)
+    return udf_obj._wrapped()
+
+
+if _have_pandas and _have_arrow:
+
+    @since(2.3)
+    def pandas_udf(f=None, returnType=StringType()):
Member
Instead of hiding …
| """ | ||
| Creates a :class:`Column` expression representing a vectorized user defined function (UDF). | ||
|
|
||
| .. note:: The vectorized user-defined functions must be deterministic. Due to optimization, | ||
| duplicate invocations may be eliminated or the function may even be invoked more times | ||
| than it is present in the query. | ||
|
Member
Should we explain more about what the vectorized UDF is and its expected input parameters and outputs?
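A minimal sketch of that contract (illustrative only, not part of this patch): each argument to a vectorized UDF arrives as a pandas.Series holding one batch of column values, and the function must return a pandas.Series of the same length; a 0-parameter UDF instead receives the batch size, as the `ones` doctest below shows.

import pandas as pd

# Hypothetical example of the Series-in/Series-out contract; the name
# multiply_columns and the data are illustrative, not part of this diff.
def multiply_columns(a, b):
    # a and b are pandas.Series of equal length; the result must be
    # a pandas.Series of that same length.
    return a * b

batch_a = pd.Series([1, 2, 3])
batch_b = pd.Series([10, 20, 30])
result = multiply_columns(batch_a, batch_b)
assert len(result) == len(batch_a)  # the length must be preserved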
+
+        :param f: python function if used as a standalone function
+        :param returnType: a :class:`pyspark.sql.types.DataType` object
+
+        >>> from pyspark.sql.types import LongType
+        >>> add = pandas_udf(lambda x, y: x + y, LongType())
+        >>> @pandas_udf(returnType=LongType())
+        ... def mul(x, y):
+        ...     return x * y
+        ...
+        >>> import pandas as pd
+        >>> ones = pandas_udf(lambda size: pd.Series(1).repeat(size), LongType())
+
+        >>> df = spark.createDataFrame([(1, 2), (3, 4)], ("a", "b"))
+        >>> df.select(add("a", "b").alias("add(a, b)"), mul("a", "b"), ones().alias("ones")).show()
+        +---------+---------+----+
+        |add(a, b)|mul(a, b)|ones|
+        +---------+---------+----+
+        |        3|        2|   1|
+        |        7|       12|   1|
+        +---------+---------+----+
+        """
+        # decorator @pandas_udf, @pandas_udf() or @pandas_udf(dataType())
+        if f is None or isinstance(f, (str, DataType)):
+            # If DataType has been passed as a positional argument
+            # for decorator use it as a returnType
+            return_type = f or returnType
+            return functools.partial(_udf, returnType=return_type, vectorized=True)
+        else:
+            return _udf(f=f, returnType=returnType, vectorized=True)
 
 
 @since(1.3)
 def udf(f=None, returnType=StringType()):
     """Creates a :class:`Column` expression representing a user defined function (UDF).
@@ -2142,18 +2204,14 @@ def udf(f=None, returnType=StringType()):
     |         8|      JOHN DOE|          22|
     +----------+--------------+------------+
     """
-    def _udf(f, returnType=StringType()):
-        udf_obj = UserDefinedFunction(f, returnType)
-        return udf_obj._wrapped()
-
     # decorator @udf, @udf() or @udf(dataType())
     if f is None or isinstance(f, (str, DataType)):
         # If DataType has been passed as a positional argument
         # for decorator use it as a returnType
         return_type = f or returnType
-        return functools.partial(_udf, returnType=return_type)
+        return functools.partial(_udf, returnType=return_type, vectorized=False)
     else:
-        return _udf(f=f, returnType=returnType)
+        return _udf(f=f, returnType=returnType, vectorized=False)
 
 
 blacklist = ['map', 'since', 'ignore_unicode_prefix']
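For context, the dispatch at the bottom of pandas_udf (and the matching branch in udf) accepts three decorator styles. A short usage sketch, assuming pandas and pyarrow are installed so that pandas_udf is defined; the function names are illustrative, not from this patch:

from pyspark.sql.functions import pandas_udf
from pyspark.sql.types import LongType

@pandas_udf                            # bare decorator: return type defaults to StringType()
def to_upper(s):
    return s.str.upper()

@pandas_udf(returnType=LongType())     # keyword form
def add_one(x):
    return x + 1

@pandas_udf("long")                    # DataType or type-string as a positional argument
def square(x):
    return x * x

# e.g. df.select(to_upper("name"), add_one("age"), square("id"))

The bare and zero-argument forms fall back to StringType(), while a str or DataType passed positionally is treated as the return type, which is why `isinstance(f, (str, DataType))` is checked before deciding whether f is the wrapped function.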
python/pyspark/sql/tests.py

@@ -63,11 +63,10 @@
 from pyspark.sql.types import _array_signed_int_typecode_ctype_mappings, _array_type_mappings
 from pyspark.sql.types import _array_unsigned_int_typecode_ctype_mappings
 from pyspark.tests import QuietTest, ReusedPySparkTestCase, SparkSubmitTests
-from pyspark.sql.functions import UserDefinedFunction, sha2, lit
+from pyspark.sql.functions import UserDefinedFunction, sha2, lit, col, expr
 from pyspark.sql.window import Window
 from pyspark.sql.utils import AnalysisException, ParseException, IllegalArgumentException
 
 _have_arrow = False
 try:
     import pyarrow
@@ -76,6 +75,9 @@
     # No Arrow, but that's okay, we'll skip those tests
     pass
 
+if _have_pandas and _have_arrow:
+    from pyspark.sql.functions import pandas_udf
+
 
 class UTCOffsetTimezone(datetime.tzinfo):
     """
@@ -3122,6 +3124,156 @@ def test_filtered_frame(self):
         self.assertTrue(pdf.empty)
 
 
+@unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed")
+class VectorizedUDFTests(ReusedPySparkTestCase):
+
+    @classmethod
+    def setUpClass(cls):
+        ReusedPySparkTestCase.setUpClass()
+        cls.spark = SparkSession(cls.sc)
+
+    @classmethod
+    def tearDownClass(cls):
+        ReusedPySparkTestCase.tearDownClass()
+        cls.spark.stop()
+
+    def test_vectorized_udf_basic(self):
+        df = self.spark.range(10).select(
+            col('id').cast('string').alias('str'),
+            col('id').cast('int').alias('int'),
+            col('id').alias('long'),
+            col('id').cast('float').alias('float'),
+            col('id').cast('double').alias('double'),
+            col('id').cast('boolean').alias('bool'))
+        f = lambda x: x
+        str_f = pandas_udf(f, StringType())
+        int_f = pandas_udf(f, IntegerType())
+        long_f = pandas_udf(f, LongType())
+        float_f = pandas_udf(f, FloatType())
+        double_f = pandas_udf(f, DoubleType())
+        bool_f = pandas_udf(f, BooleanType())
+        res = df.select(str_f(col('str')), int_f(col('int')),
+                        long_f(col('long')), float_f(col('float')),
+                        double_f(col('double')), bool_f(col('bool')))
+        self.assertEquals(df.collect(), res.collect())
+    def test_vectorized_udf_null_boolean(self):
+        data = [(True,), (True,), (None,), (False,)]
+        schema = StructType().add("bool", BooleanType())
+        df = self.spark.createDataFrame(data, schema)
+        bool_f = pandas_udf(lambda x: x, BooleanType())
+        res = df.select(bool_f(col('bool')))
+        self.assertEquals(df.collect(), res.collect())
+
+    def test_vectorized_udf_null_byte(self):
+        data = [(None,), (2,), (3,), (4,)]
+        schema = StructType().add("byte", ByteType())
+        df = self.spark.createDataFrame(data, schema)
+        byte_f = pandas_udf(lambda x: x, ByteType())
+        res = df.select(byte_f(col('byte')))
+        self.assertEquals(df.collect(), res.collect())
+
+    def test_vectorized_udf_null_short(self):
+        data = [(None,), (2,), (3,), (4,)]
+        schema = StructType().add("short", ShortType())
+        df = self.spark.createDataFrame(data, schema)
+        short_f = pandas_udf(lambda x: x, ShortType())
+        res = df.select(short_f(col('short')))
+        self.assertEquals(df.collect(), res.collect())
+
+    def test_vectorized_udf_null_int(self):
+        data = [(None,), (2,), (3,), (4,)]
+        schema = StructType().add("int", IntegerType())
+        df = self.spark.createDataFrame(data, schema)
+        int_f = pandas_udf(lambda x: x, IntegerType())
+        res = df.select(int_f(col('int')))
+        self.assertEquals(df.collect(), res.collect())
+
+    def test_vectorized_udf_null_long(self):
+        data = [(None,), (2,), (3,), (4,)]
+        schema = StructType().add("long", LongType())
+        df = self.spark.createDataFrame(data, schema)
+        long_f = pandas_udf(lambda x: x, LongType())
+        res = df.select(long_f(col('long')))
+        self.assertEquals(df.collect(), res.collect())
+
+    def test_vectorized_udf_null_float(self):
+        data = [(3.0,), (5.0,), (-1.0,), (None,)]
+        schema = StructType().add("float", FloatType())
+        df = self.spark.createDataFrame(data, schema)
+        float_f = pandas_udf(lambda x: x, FloatType())
+        res = df.select(float_f(col('float')))
+        self.assertEquals(df.collect(), res.collect())
+
+    def test_vectorized_udf_null_double(self):
+        data = [(3.0,), (5.0,), (-1.0,), (None,)]
+        schema = StructType().add("double", DoubleType())
+        df = self.spark.createDataFrame(data, schema)
+        double_f = pandas_udf(lambda x: x, DoubleType())
+        res = df.select(double_f(col('double')))
+        self.assertEquals(df.collect(), res.collect())
+
+    def test_vectorized_udf_null_string(self):
+        data = [("foo",), (None,), ("bar",), ("bar",)]
+        schema = StructType().add("str", StringType())
+        df = self.spark.createDataFrame(data, schema)
+        str_f = pandas_udf(lambda x: x, StringType())
+        res = df.select(str_f(col('str')))
+        self.assertEquals(df.collect(), res.collect())
+    def test_vectorized_udf_zero_parameter(self):
+        import pandas as pd
+        df = self.spark.range(100000)
+        f0 = pandas_udf(lambda size: pd.Series(1).repeat(size), LongType())
+        res = df.select(f0())
+        self.assertEquals(df.select(lit(1)).collect(), res.collect())
+
+    def test_vectorized_udf_datatype_string(self):
+        import pandas as pd
+        df = self.spark.range(100000)
+        f0 = pandas_udf(lambda size: pd.Series(1).repeat(size), "long")
+        res = df.select(f0())
+        self.assertEquals(df.select(lit(1)).collect(), res.collect())
+    def test_vectorized_udf_complex(self):
+        df = self.spark.range(10).select(
+            col('id').cast('int').alias('a'),
+            col('id').cast('int').alias('b'),
+            col('id').cast('double').alias('c'))
+        add = pandas_udf(lambda x, y: x + y, IntegerType())
+        power2 = pandas_udf(lambda x: 2 ** x, IntegerType())
+        mul = pandas_udf(lambda x, y: x * y, DoubleType())
+        res = df.select(add(col('a'), col('b')), power2(col('a')), mul(col('b'), col('c')))
+        expected = df.select(expr('a + b'), expr('power(2, a)'), expr('b * c'))
+        self.assertEquals(expected.collect(), res.collect())
+
+    def test_vectorized_udf_exception(self):
+        df = self.spark.range(10)
+        raise_exception = pandas_udf(lambda x: x * (1 / 0), LongType())
+        with QuietTest(self.sc):
+            with self.assertRaisesRegexp(Exception, 'division( or modulo)? by zero'):
+                df.select(raise_exception(col('id'))).collect()
+
+    def test_vectorized_udf_invalid_length(self):
+        import pandas as pd
+        df = self.spark.range(10)
+        raise_exception = pandas_udf(lambda size: pd.Series(1), LongType())
+        with QuietTest(self.sc):
+            with self.assertRaisesRegexp(
+                    Exception,
+                    'The length of returned value should be the same as input value'):
+                df.select(raise_exception()).collect()
Member
Also add a test for mixing udf and vectorized udf?

Member (Author)
Sure, I'll add a test.
+    def test_vectorized_udf_mix_udf(self):
+        from pyspark.sql.functions import udf
+        df = self.spark.range(10)
+        row_by_row_udf = udf(lambda x: x, LongType())
+        pd_udf = pandas_udf(lambda x: x, LongType())
+        with QuietTest(self.sc):
+            with self.assertRaisesRegexp(Exception, 'cannot mix vectorized udf and normal udf'):
+                df.select(row_by_row_udf(col('id')), pd_udf(col('id'))).collect()
+
+
 if __name__ == "__main__":
     from pyspark.sql.tests import *
     if xmlrunner:
Review comment
ArrowVectorizedSerializer?