18 changes: 14 additions & 4 deletions core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -90,6 +90,15 @@ private[spark] object PythonRunner {
}
}

/**
* Enumerate the type of command that will be sent to the Python worker
*/
private[spark] object PythonEvalType {
val NON_UDF = 0
val SQL_BATCHED_UDF = 1
val SQL_VECTORIZED_UDF = 2
}
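Not part of the diff: the Python worker has to interpret the int that PythonRunner writes, so a mirror of these constants is needed on the Python side. A minimal sketch, assuming a plain holder class; the class name and module placement are illustrative, not taken from this PR.

# Hedged sketch: Python-side mirror of the Scala PythonEvalType constants above.
# Only the values come from the diff; the class name and location are assumed.
class PythonEvalType(object):
    NON_UDF = 0
    SQL_BATCHED_UDF = 1
    SQL_VECTORIZED_UDF = 2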

/**
* A helper class to run Python mapPartition/UDFs in Spark.
*
@@ -310,7 +319,7 @@ private[spark] class PythonRunner(
dataOut.flush()
// Serialized command:
if (isUDF) {
dataOut.writeInt(1)
dataOut.writeInt(PythonEvalType.SQL_BATCHED_UDF)
dataOut.writeInt(funcs.length)
funcs.zip(argOffsets).foreach { case (chained, offsets) =>
dataOut.writeInt(offsets.length)
@@ -324,7 +333,7 @@
}
}
} else {
dataOut.writeInt(0)
dataOut.writeInt(PythonEvalType.NON_UDF)
val command = funcs.head.funcs.head.command
dataOut.writeInt(command.length)
dataOut.write(command)
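For context, not part of the diff: the ints above are written with Java's DataOutputStream.writeInt, i.e. big-endian. Below is a small self-contained sketch of the framing for the UDF branch using Python's struct module; the concrete counts and offsets are illustrative, and anything beyond the writes visible in this hunk is an assumption.

# Hedged sketch of the UDF-branch framing written by PythonRunner above.
# DataOutputStream.writeInt is big-endian, hence ">i"; values are illustrative.
import struct

def write_int(chunks, value):
    chunks.append(struct.pack(">i", value))

frame = []
write_int(frame, 1)   # PythonEvalType.SQL_BATCHED_UDF
write_int(frame, 1)   # funcs.length: one (chained) UDF
write_int(frame, 2)   # offsets.length: this UDF reads two argument columns
write_int(frame, 0)   # assumed: the offsets themselves follow their length
write_int(frame, 1)
print(b"".join(frame))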
@@ -382,7 +391,8 @@ private[spark] class PythonRunner(
}

/** Thrown for exceptions in user Python code. */
private class PythonException(msg: String, cause: Exception) extends RuntimeException(msg, cause)
private[spark] class PythonException(msg: String, cause: Exception)
extends RuntimeException(msg, cause)

/**
* Form an RDD[(Array[Byte], Array[Byte])] from key-value pairs returned from Python.
@@ -399,7 +409,7 @@ private class PairwiseRDD(prev: RDD[Array[Byte]]) extends RDD[(Long, Array[Byte]
val asJavaPairRDD : JavaPairRDD[Long, Array[Byte]] = JavaPairRDD.fromRDD(this)
}

private object SpecialLengths {
private[spark] object SpecialLengths {
val END_OF_DATA_SECTION = -1
val PYTHON_EXCEPTION_THROWN = -2
val TIMING_DATA = -3
33 changes: 33 additions & 0 deletions python/pyspark/serializers.py
@@ -573,6 +573,39 @@ def __repr__(self):
return "UTF8Deserializer(%s)" % self.use_unicode


class VectorizedSerializer(Serializer):
Review comment (Member): ArrowVectorizedSerializer?


"""
(De)serializes a vectorized (Apache Arrow) stream.
"""

def load_stream(self, stream):
import pyarrow as pa
reader = pa.open_stream(stream)
for batch in reader:
vectors = [batch[col].to_pandas() for col in xrange(batch.num_columns)]
yield [batch.num_rows] + vectors

def dump_stream(self, iterator, stream):
import pandas as pd
import pyarrow as pa
# the schema is set at worker.py#read_vectorized_udfs
writer = pa.RecordBatchStreamWriter(stream, self.schema)
names = self.schema.names
types = [f.type for f in self.schema]
# todo: verify the type of the arrays returned by UDF.
try:
for arrays in iterator:
vectors = [pa.Array.from_pandas(array.astype(t.to_pandas_dtype(), copy=False),
mask=pd.isnull(array), type=t)
for array, t in zip(arrays, types)]
batch = pa.RecordBatch.from_arrays(vectors, names)
writer.write_batch(batch)
finally:
# todo: does arrow close the socket?
writer.close()
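Not part of the diff: a minimal usage sketch of the serializer, assuming pandas and pyarrow are installed and that the caller attaches .schema externally, as the comment in dump_stream says worker.py#read_vectorized_udfs does. The column name "x" and the in-memory BytesIO stream are illustrative only.

# Hedged usage sketch: write one batch with VectorizedSerializer.dump_stream,
# then read the Arrow stream back directly with pyarrow to inspect it.
import io
import pandas as pd
import pyarrow as pa

ser = VectorizedSerializer()
ser.schema = pa.schema([pa.field("x", pa.int64())])   # normally set in worker.py

buf = io.BytesIO()
ser.dump_stream(iter([[pd.Series([1, 2, 3])]]), buf)  # one batch with one column

reader = pa.open_stream(io.BytesIO(buf.getvalue()))
for batch in reader:
    print(batch.num_rows, batch[0].to_pandas().tolist())  # -> 3 [1, 2, 3]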


def read_long(stream):
length = stream.read(8)
if not length:
Expand Down
74 changes: 66 additions & 8 deletions python/pyspark/sql/functions.py
@@ -25,6 +25,20 @@
if sys.version < "3":
from itertools import imap as map

_have_pandas = False
try:
import pandas
_have_pandas = True
except:
pass

_have_arrow = False
try:
import pyarrow
_have_arrow = True
except:
pass

from pyspark import since, SparkContext
from pyspark.rdd import _prepare_for_python_RDD, ignore_unicode_prefix
from pyspark.serializers import PickleSerializer, AutoBatchedSerializer
@@ -2032,7 +2046,7 @@ class UserDefinedFunction(object):

.. versionadded:: 1.3
"""
def __init__(self, func, returnType, name=None):
def __init__(self, func, returnType, name=None, vectorized=False):
if not callable(func):
raise TypeError(
"Not a function or callable (__call__ is not defined): "
@@ -2046,6 +2060,7 @@ def __init__(self, func, returnType, name=None):
self._name = name or (
func.__name__ if hasattr(func, '__name__')
else func.__class__.__name__)
self._vectorized = vectorized

@property
def returnType(self):
@@ -2077,7 +2092,7 @@ def _create_judf(self):
wrapped_func = _wrap_function(sc, self.func, self.returnType)
jdt = spark._jsparkSession.parseDataType(self.returnType.json())
judf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonFunction(
self._name, wrapped_func, jdt)
self._name, wrapped_func, jdt, self._vectorized)
return judf

def __call__(self, *cols):
@@ -2111,6 +2126,53 @@ def wrapper(*args):
return wrapper


def _udf(f, returnType, vectorized):
udf_obj = UserDefinedFunction(f, returnType, vectorized=vectorized)
return udf_obj._wrapped()


if _have_pandas and _have_arrow:

@since(2.3)
def pandas_udf(f=None, returnType=StringType()):
Review comment (Member): Instead of hiding pandas_udf when pandas and Arrow are not installed, should we throw an exception when users without pandas and Arrow try to use it? (A hedged sketch of that alternative follows after this function.)

"""
Creates a :class:`Column` expression representing a vectorized user defined function (UDF).

.. note:: The vectorized user-defined functions must be deterministic. Due to optimization,
duplicate invocations may be eliminated or the function may even be invoked more times
than it is present in the query.
Review comment (Member): Should we explain more about what the vectorized UDF is and its expected input parameters and outputs?


:param f: python function if used as a standalone function
:param returnType: a :class:`pyspark.sql.types.DataType` object

>>> from pyspark.sql.types import LongType
>>> add = pandas_udf(lambda x, y: x + y, LongType())
>>> @pandas_udf(returnType=LongType())
... def mul(x, y):
... return x * y
...
>>> import pandas as pd
>>> ones = pandas_udf(lambda size: pd.Series(1).repeat(size), LongType())

>>> df = spark.createDataFrame([(1, 2), (3, 4)], ("a", "b"))
>>> df.select(add("a", "b").alias("add(a, b)"), mul("a", "b"), ones().alias("ones")).show()
+---------+---------+----+
|add(a, b)|mul(a, b)|ones|
+---------+---------+----+
| 3| 2| 1|
| 7| 12| 1|
+---------+---------+----+
"""
# decorator @pandas_udf, @pandas_udf() or @pandas_udf(dataType())
if f is None or isinstance(f, (str, DataType)):
# If DataType has been passed as a positional argument
# for decorator use it as a returnType
return_type = f or returnType
return functools.partial(_udf, returnType=return_type, vectorized=True)
else:
return _udf(f=f, returnType=returnType, vectorized=True)
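Not part of the diff: a sketch of the alternative suggested in the review comment above, where pandas_udf is always defined and fails loudly at use time instead of being hidden when pandas or Arrow is missing. The helper name require_pandas_and_arrow is hypothetical, not from this PR.

# Hedged sketch of the reviewer's alternative: raise a clear error on use.
# require_pandas_and_arrow is a hypothetical helper name.
def require_pandas_and_arrow():
    try:
        import pandas   # noqa: F401
        import pyarrow  # noqa: F401
    except ImportError as e:
        raise ImportError(
            "pandas_udf requires both pandas and pyarrow to be installed: %s" % e)

pandas_udf (or _udf with vectorized=True) would then call require_pandas_and_arrow() up front, rather than only existing when both imports succeed at module load time.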


@since(1.3)
def udf(f=None, returnType=StringType()):
"""Creates a :class:`Column` expression representing a user defined function (UDF).
Expand Down Expand Up @@ -2142,18 +2204,14 @@ def udf(f=None, returnType=StringType()):
| 8| JOHN DOE| 22|
+----------+--------------+------------+
"""
def _udf(f, returnType=StringType()):
udf_obj = UserDefinedFunction(f, returnType)
return udf_obj._wrapped()

# decorator @udf, @udf() or @udf(dataType())
if f is None or isinstance(f, (str, DataType)):
# If DataType has been passed as a positional argument
# for decorator use it as a returnType
return_type = f or returnType
return functools.partial(_udf, returnType=return_type)
return functools.partial(_udf, returnType=return_type, vectorized=False)
else:
return _udf(f=f, returnType=returnType)
return _udf(f=f, returnType=returnType, vectorized=False)


blacklist = ['map', 'since', 'ignore_unicode_prefix']
Expand Down
156 changes: 154 additions & 2 deletions python/pyspark/sql/tests.py
@@ -63,11 +63,10 @@
from pyspark.sql.types import _array_signed_int_typecode_ctype_mappings, _array_type_mappings
from pyspark.sql.types import _array_unsigned_int_typecode_ctype_mappings
from pyspark.tests import QuietTest, ReusedPySparkTestCase, SparkSubmitTests
from pyspark.sql.functions import UserDefinedFunction, sha2, lit
from pyspark.sql.functions import UserDefinedFunction, sha2, lit, col, expr
from pyspark.sql.window import Window
from pyspark.sql.utils import AnalysisException, ParseException, IllegalArgumentException


_have_arrow = False
try:
import pyarrow
@@ -76,6 +75,9 @@
# No Arrow, but that's okay, we'll skip those tests
pass

if _have_pandas and _have_arrow:
from pyspark.sql.functions import pandas_udf


class UTCOffsetTimezone(datetime.tzinfo):
"""
@@ -3122,6 +3124,156 @@ def test_filtered_frame(self):
self.assertTrue(pdf.empty)


@unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed")
class VectorizedUDFTests(ReusedPySparkTestCase):

@classmethod
def setUpClass(cls):
ReusedPySparkTestCase.setUpClass()
cls.spark = SparkSession(cls.sc)

@classmethod
def tearDownClass(cls):
ReusedPySparkTestCase.tearDownClass()
cls.spark.stop()

def test_vectorized_udf_basic(self):
df = self.spark.range(10).select(
col('id').cast('string').alias('str'),
col('id').cast('int').alias('int'),
col('id').alias('long'),
col('id').cast('float').alias('float'),
col('id').cast('double').alias('double'),
col('id').cast('boolean').alias('bool'))
f = lambda x: x
str_f = pandas_udf(f, StringType())
int_f = pandas_udf(f, IntegerType())
long_f = pandas_udf(f, LongType())
float_f = pandas_udf(f, FloatType())
double_f = pandas_udf(f, DoubleType())
bool_f = pandas_udf(f, BooleanType())
res = df.select(str_f(col('str')), int_f(col('int')),
long_f(col('long')), float_f(col('float')),
double_f(col('double')), bool_f(col('bool')))
self.assertEquals(df.collect(), res.collect())

def test_vectorized_udf_null_boolean(self):
data = [(True,), (True,), (None,), (False,)]
schema = StructType().add("bool", BooleanType())
df = self.spark.createDataFrame(data, schema)
bool_f = pandas_udf(lambda x: x, BooleanType())
res = df.select(bool_f(col('bool')))
self.assertEquals(df.collect(), res.collect())

def test_vectorized_udf_null_byte(self):
data = [(None,), (2,), (3,), (4,)]
schema = StructType().add("byte", ByteType())
df = self.spark.createDataFrame(data, schema)
byte_f = pandas_udf(lambda x: x, ByteType())
res = df.select(byte_f(col('byte')))
self.assertEquals(df.collect(), res.collect())

def test_vectorized_udf_null_short(self):
data = [(None,), (2,), (3,), (4,)]
schema = StructType().add("short", ShortType())
df = self.spark.createDataFrame(data, schema)
short_f = pandas_udf(lambda x: x, ShortType())
res = df.select(short_f(col('short')))
self.assertEquals(df.collect(), res.collect())

def test_vectorized_udf_null_int(self):
data = [(None,), (2,), (3,), (4,)]
schema = StructType().add("int", IntegerType())
df = self.spark.createDataFrame(data, schema)
int_f = pandas_udf(lambda x: x, IntegerType())
res = df.select(int_f(col('int')))
self.assertEquals(df.collect(), res.collect())

def test_vectorized_udf_null_long(self):
data = [(None,), (2,), (3,), (4,)]
schema = StructType().add("long", LongType())
df = self.spark.createDataFrame(data, schema)
long_f = pandas_udf(lambda x: x, LongType())
res = df.select(long_f(col('long')))
self.assertEquals(df.collect(), res.collect())

def test_vectorized_udf_null_float(self):
data = [(3.0,), (5.0,), (-1.0,), (None,)]
schema = StructType().add("float", FloatType())
df = self.spark.createDataFrame(data, schema)
float_f = pandas_udf(lambda x: x, FloatType())
res = df.select(float_f(col('float')))
self.assertEquals(df.collect(), res.collect())

def test_vectorized_udf_null_double(self):
data = [(3.0,), (5.0,), (-1.0,), (None,)]
schema = StructType().add("double", DoubleType())
df = self.spark.createDataFrame(data, schema)
double_f = pandas_udf(lambda x: x, DoubleType())
res = df.select(double_f(col('double')))
self.assertEquals(df.collect(), res.collect())

def test_vectorized_udf_null_string(self):
data = [("foo",), (None,), ("bar",), ("bar",)]
schema = StructType().add("str", StringType())
df = self.spark.createDataFrame(data, schema)
str_f = pandas_udf(lambda x: x, StringType())
res = df.select(str_f(col('str')))
self.assertEquals(df.collect(), res.collect())

def test_vectorized_udf_zero_parameter(self):
import pandas as pd
df = self.spark.range(100000)
f0 = pandas_udf(lambda size: pd.Series(1).repeat(size), LongType())
res = df.select(f0())
self.assertEquals(df.select(lit(1)).collect(), res.collect())

def test_vectorized_udf_datatype_string(self):
import pandas as pd
df = self.spark.range(100000)
f0 = pandas_udf(lambda size: pd.Series(1).repeat(size), "long")
res = df.select(f0())
self.assertEquals(df.select(lit(1)).collect(), res.collect())

def test_vectorized_udf_complex(self):
df = self.spark.range(10).select(
col('id').cast('int').alias('a'),
col('id').cast('int').alias('b'),
col('id').cast('double').alias('c'))
add = pandas_udf(lambda x, y: x + y, IntegerType())
power2 = pandas_udf(lambda x: 2 ** x, IntegerType())
mul = pandas_udf(lambda x, y: x * y, DoubleType())
res = df.select(add(col('a'), col('b')), power2(col('a')), mul(col('b'), col('c')))
expected = df.select(expr('a + b'), expr('power(2, a)'), expr('b * c'))
self.assertEquals(expected.collect(), res.collect())

def test_vectorized_udf_exception(self):
df = self.spark.range(10)
raise_exception = pandas_udf(lambda x: x * (1 / 0), LongType())
with QuietTest(self.sc):
with self.assertRaisesRegexp(Exception, 'division( or modulo)? by zero'):
df.select(raise_exception(col('id'))).collect()

def test_vectorized_udf_invalid_length(self):
import pandas as pd
df = self.spark.range(10)
raise_exception = pandas_udf(lambda size: pd.Series(1), LongType())
with QuietTest(self.sc):
with self.assertRaisesRegexp(
Exception,
'The length of returned value should be the same as input value'):
df.select(raise_exception()).collect()
Review comment (Member): Also add a test for mixing udf and vectorized udf?
Reply (Member Author): Sure, I'll add a test.


def test_vectorized_udf_mix_udf(self):
from pyspark.sql.functions import udf
df = self.spark.range(10)
row_by_row_udf = udf(lambda x: x, LongType())
pd_udf = pandas_udf(lambda x: x, LongType())
with QuietTest(self.sc):
with self.assertRaisesRegexp(Exception, 'cannot mix vectorized udf and normal udf'):
df.select(row_by_row_udf(col('id')), pd_udf(col('id'))).collect()


if __name__ == "__main__":
from pyspark.sql.tests import *
if xmlrunner: