From a2a3f8205394aa3d48a22972d00263897e0035b1 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Wed, 23 Aug 2017 18:05:05 +0900 Subject: [PATCH 1/8] Introduce vectorized UDF in Python. --- .../apache/spark/api/python/PythonRDD.scala | 18 +- python/pyspark/serializers.py | 33 ++ python/pyspark/sql/functions.py | 47 ++- python/pyspark/sql/tests.py | 126 ++++++- python/pyspark/sql/types.py | 29 ++ python/pyspark/worker.py | 65 +++- .../python/BatchEvalPythonExec.scala | 199 +++++++---- .../sql/execution/python/PythonUDF.scala | 3 +- .../python/UserDefinedPythonFunction.scala | 5 +- .../python/VectorizedPythonRunner.scala | 329 ++++++++++++++++++ 10 files changed, 756 insertions(+), 98 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/python/VectorizedPythonRunner.scala diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 33771011fe36..b612c119ed8d 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -90,6 +90,15 @@ private[spark] object PythonRunner { } } +/** + * Enumerate the type of command that will be sent to the Python worker + */ +private[spark] object PythonEvalType { + val NON_UDF = 0 + val SQL_BATCHED_UDF = 1 + val SQL_VECTORIZED_UDF = 2 +} + /** * A helper class to run Python mapPartition/UDFs in Spark. * @@ -310,7 +319,7 @@ private[spark] class PythonRunner( dataOut.flush() // Serialized command: if (isUDF) { - dataOut.writeInt(1) + dataOut.writeInt(PythonEvalType.SQL_BATCHED_UDF) dataOut.writeInt(funcs.length) funcs.zip(argOffsets).foreach { case (chained, offsets) => dataOut.writeInt(offsets.length) @@ -324,7 +333,7 @@ private[spark] class PythonRunner( } } } else { - dataOut.writeInt(0) + dataOut.writeInt(PythonEvalType.NON_UDF) val command = funcs.head.funcs.head.command dataOut.writeInt(command.length) dataOut.write(command) @@ -382,7 +391,8 @@ private[spark] class PythonRunner( } /** Thrown for exceptions in user Python code. */ -private class PythonException(msg: String, cause: Exception) extends RuntimeException(msg, cause) +private[spark] class PythonException(msg: String, cause: Exception) + extends RuntimeException(msg, cause) /** * Form an RDD[(Array[Byte], Array[Byte])] from key-value pairs returned from Python. @@ -399,7 +409,7 @@ private class PairwiseRDD(prev: RDD[Array[Byte]]) extends RDD[(Long, Array[Byte] val asJavaPairRDD : JavaPairRDD[Long, Array[Byte]] = JavaPairRDD.fromRDD(this) } -private object SpecialLengths { +private[spark] object SpecialLengths { val END_OF_DATA_SECTION = -1 val PYTHON_EXCEPTION_THROWN = -2 val TIMING_DATA = -3 diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index d5c2a7518b18..ef43778e5200 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -573,6 +573,39 @@ def __repr__(self): return "UTF8Deserializer(%s)" % self.use_unicode +class VectorizedSerializer(Serializer): + + """ + (De)serializes a vectorized(Apache Arrow) stream. 
+ """ + + def load_stream(self, stream): + import pyarrow as pa + reader = pa.open_stream(stream) + for batch in reader: + vectors = [batch[col].to_pandas() for col in xrange(batch.num_columns)] + yield [batch.num_rows] + vectors + + def dump_stream(self, iterator, stream): + import pandas as pd + import pyarrow as pa + # the schema is set at worker.py#read_vectorized_udfs + writer = pa.RecordBatchStreamWriter(stream, self.schema) + names = self.schema.names + types = [f.type for f in self.schema] + # todo: verify the type of the arrays returned by UDF. + try: + for arrays in iterator: + vectors = [pa.Array.from_pandas(array.astype(t.to_pandas_dtype(), copy=False), + mask=pd.isnull(array), type=t) + for array, t in zip(arrays, types)] + batch = pa.RecordBatch.from_arrays(vectors, names) + writer.write_batch(batch) + finally: + # todo: does arrow close the socket? + writer.close() + + def read_long(stream): length = stream.read(8) if not length: diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 0e76182e0e02..77fc03c43b44 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2032,7 +2032,7 @@ class UserDefinedFunction(object): .. versionadded:: 1.3 """ - def __init__(self, func, returnType, name=None): + def __init__(self, func, returnType, name=None, vectorized=False): if not callable(func): raise TypeError( "Not a function or callable (__call__ is not defined): " @@ -2046,6 +2046,7 @@ def __init__(self, func, returnType, name=None): self._name = name or ( func.__name__ if hasattr(func, '__name__') else func.__class__.__name__) + self._vectorized = vectorized @property def returnType(self): @@ -2077,7 +2078,7 @@ def _create_judf(self): wrapped_func = _wrap_function(sc, self.func, self.returnType) jdt = spark._jsparkSession.parseDataType(self.returnType.json()) judf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonFunction( - self._name, wrapped_func, jdt) + self._name, wrapped_func, jdt, self._vectorized) return judf def __call__(self, *cols): @@ -2111,8 +2112,40 @@ def wrapper(*args): return wrapper +@since(2.3) +def pandas_udf(f=None, returnType=None): + """Creates a :class:`Column` expression representing a vectorized user defined function (UDF). + + .. note:: The vectorized user-defined functions must be deterministic. Due to optimization, + duplicate invocations may be eliminated or the function may even be invoked more times than + it is present in the query. + + :param f: python function if used as a standalone function + :param returnType: a :class:`pyspark.sql.types.DataType` object + + >>> from pyspark.sql.types import LongType + >>> add = pandas_udf(lambda x, y: x + y, LongType()) + >>> @pandas_udf(returnType=LongType()) + ... def mul(x, y): + ... return x * y + ... + >>> import pandas as pd + >>> ones = pandas_udf(lambda size: pd.Series(1).repeat(size), LongType()) + + >>> df = spark.createDataFrame([(1, 2), (3, 4)], ("a", "b")) + >>> df.select(add("a", "b").alias("add(a, b)"), mul("a", "b"), ones().alias("ones")).show() + +---------+---------+----+ + |add(a, b)|mul(a, b)|ones| + +---------+---------+----+ + | 3| 2| 1| + | 7| 12| 1| + +---------+---------+----+ + """ + return udf(f, returnType, vectorized=True) + + @since(1.3) -def udf(f=None, returnType=StringType()): +def udf(f=None, returnType=StringType(), vectorized=False): """Creates a :class:`Column` expression representing a user defined function (UDF). .. note:: The user-defined functions must be deterministic. 
Due to optimization, @@ -2142,8 +2175,8 @@ def udf(f=None, returnType=StringType()): | 8| JOHN DOE| 22| +----------+--------------+------------+ """ - def _udf(f, returnType=StringType()): - udf_obj = UserDefinedFunction(f, returnType) + def _udf(f, returnType=StringType(), vectorized=False): + udf_obj = UserDefinedFunction(f, returnType, vectorized=vectorized) return udf_obj._wrapped() # decorator @udf, @udf() or @udf(dataType()) @@ -2151,9 +2184,9 @@ def _udf(f, returnType=StringType()): # If DataType has been passed as a positional argument # for decorator use it as a returnType return_type = f or returnType - return functools.partial(_udf, returnType=return_type) + return functools.partial(_udf, returnType=return_type, vectorized=vectorized) else: - return _udf(f=f, returnType=returnType) + return _udf(f=f, returnType=returnType, vectorized=vectorized) blacklist = ['map', 'since', 'ignore_unicode_prefix'] diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 3d87ccfc03dd..87134f74c87c 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -63,7 +63,7 @@ from pyspark.sql.types import _array_signed_int_typecode_ctype_mappings, _array_type_mappings from pyspark.sql.types import _array_unsigned_int_typecode_ctype_mappings from pyspark.tests import QuietTest, ReusedPySparkTestCase, SparkSubmitTests -from pyspark.sql.functions import UserDefinedFunction, sha2, lit +from pyspark.sql.functions import UserDefinedFunction, sha2, lit, col, expr, pandas_udf from pyspark.sql.window import Window from pyspark.sql.utils import AnalysisException, ParseException, IllegalArgumentException @@ -3122,6 +3122,130 @@ def test_filtered_frame(self): self.assertTrue(pdf.empty) +@unittest.skipIf(not _have_arrow, "Arrow not installed") +class VectorizedUDFTests(ReusedPySparkTestCase): + + @classmethod + def setUpClass(cls): + ReusedPySparkTestCase.setUpClass() + cls.spark = SparkSession(cls.sc) + + @classmethod + def tearDownClass(cls): + ReusedPySparkTestCase.tearDownClass() + cls.spark.stop() + + def test_vectorized_udf_basic(self): + df = self.spark.range(10).select( + col('id').cast('string').alias('str'), + col('id').cast('int').alias('int'), + col('id').alias('long'), + col('id').cast('float').alias('float'), + col('id').cast('double').alias('double'), + col('id').cast('boolean').alias('bool')) + f = lambda x: x + str_f = pandas_udf(f, StringType()) + int_f = pandas_udf(f, IntegerType()) + long_f = pandas_udf(f, LongType()) + float_f = pandas_udf(f, FloatType()) + double_f = pandas_udf(f, DoubleType()) + bool_f = pandas_udf(f, BooleanType()) + res = df.select(str_f(col('str')), int_f(col('int')), + long_f(col('long')), float_f(col('float')), + double_f(col('double')), bool_f(col('bool'))) + self.assertEquals(df.collect(), res.collect()) + + def test_vectorized_udf_null_boolean(self): + data = [(True,), (True,), (None,), (False,)] + schema = StructType().add("bool", BooleanType()) + df = self.spark.createDataFrame(data, schema) + bool_f = pandas_udf(lambda x: x, BooleanType()) + res = df.select(bool_f(col('bool'))) + self.assertEquals(df.collect(), res.collect()) + + def test_vectorized_udf_null_byte(self): + data = [(None,), (2,), (3,), (4,)] + schema = StructType().add("byte", ByteType()) + df = self.spark.createDataFrame(data, schema) + byte_f = pandas_udf(lambda x: x, ByteType()) + res = df.select(byte_f(col('byte'))) + self.assertEquals(df.collect(), res.collect()) + + def test_vectorized_udf_null_short(self): + data = [(None,), (2,), (3,), (4,)] + schema = 
StructType().add("short", ShortType()) + df = self.spark.createDataFrame(data, schema) + short_f = pandas_udf(lambda x: x, ShortType()) + res = df.select(short_f(col('short'))) + self.assertEquals(df.collect(), res.collect()) + + def test_vectorized_udf_null_int(self): + data = [(None,), (2,), (3,), (4,)] + schema = StructType().add("int", IntegerType()) + df = self.spark.createDataFrame(data, schema) + int_f = pandas_udf(lambda x: x, IntegerType()) + res = df.select(int_f(col('int'))) + self.assertEquals(df.collect(), res.collect()) + + def test_vectorized_udf_null_long(self): + data = [(None,), (2,), (3,), (4,)] + schema = StructType().add("long", LongType()) + df = self.spark.createDataFrame(data, schema) + long_f = pandas_udf(lambda x: x, LongType()) + res = df.select(long_f(col('long'))) + self.assertEquals(df.collect(), res.collect()) + + def test_vectorized_udf_null_float(self): + data = [(3.0,), (5.0,), (-1.0,), (None,)] + schema = StructType().add("float", FloatType()) + df = self.spark.createDataFrame(data, schema) + float_f = pandas_udf(lambda x: x, FloatType()) + res = df.select(float_f(col('float'))) + self.assertEquals(df.collect(), res.collect()) + + def test_vectorized_udf_null_double(self): + data = [(3.0,), (5.0,), (-1.0,), (None,)] + schema = StructType().add("double", DoubleType()) + df = self.spark.createDataFrame(data, schema) + double_f = pandas_udf(lambda x: x, DoubleType()) + res = df.select(double_f(col('double'))) + self.assertEquals(df.collect(), res.collect()) + + def test_vectorized_udf_null_string(self): + data = [("foo",), (None,), ("bar",), ("bar",)] + schema = StructType().add("str", StringType()) + df = self.spark.createDataFrame(data, schema) + str_f = pandas_udf(lambda x: x, StringType()) + res = df.select(str_f(col('str'))) + self.assertEquals(df.collect(), res.collect()) + + def test_vectorized_udf_zero_parameter(self): + import pandas as pd + df = self.spark.range(100000) + f0 = pandas_udf(lambda size: pd.Series(1).repeat(size), LongType()) + res = df.select(f0()) + self.assertEquals(df.select(lit(1)).collect(), res.collect()) + + def test_vectorized_udf_complex(self): + df = self.spark.range(10).select( + col('id').cast('int').alias('a'), + col('id').cast('int').alias('b'), + col('id').cast('double').alias('c')) + add = pandas_udf(lambda x, y: x + y, IntegerType()) + power2 = pandas_udf(lambda x: 2 ** x, IntegerType()) + mul = pandas_udf(lambda x, y: x * y, DoubleType()) + res = df.select(add(col('a'), col('b')), power2(col('a')), mul(col('b'), col('c'))) + expected = df.select(expr('a + b'), expr('power(2, a)'), expr('b * c')) + self.assertEquals(expected.collect(), res.collect()) + + def test_vectorized_udf_exception(self): + df = self.spark.range(10) + raise_exception = pandas_udf(lambda x: x * (1 / 0), LongType()) + with QuietTest(self.sc): + with self.assertRaisesRegexp(Exception, 'division( or modulo)? 
by zero'): + df.select(raise_exception(col('id'))).collect() + + if __name__ == "__main__": from pyspark.sql.tests import * if xmlrunner: diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 51bf7bef4976..1d795b3b848a 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -1582,6 +1582,35 @@ def convert(self, obj, gateway_client): register_input_converter(DateConverter()) +def toArrowType(dt): + import pyarrow as pa + if type(dt) == BooleanType: + arrow_type = pa.bool_() + elif type(dt) == ByteType: + arrow_type = pa.int8() + elif type(dt) == ShortType: + arrow_type = pa.int16() + elif type(dt) == IntegerType: + arrow_type = pa.int32() + elif type(dt) == LongType: + arrow_type = pa.int64() + elif type(dt) == FloatType: + arrow_type = pa.float32() + elif type(dt) == DoubleType: + arrow_type = pa.float64() + elif type(dt) == DecimalType: + arrow_type = pa.decimal(dt.precision, dt.scale) + elif type(dt) == StringType: + arrow_type = pa.string() + return arrow_type + + +def toArrowSchema(types): + import pyarrow as pa + fields = [pa.field("c_" + str(i), toArrowType(types[i])) for i in range(len(types))] + return pa.schema(fields) + + def _test(): import doctest from pyspark.context import SparkContext diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index baaa3fe074e9..d7d0afa1ef24 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -25,18 +25,29 @@ import socket import traceback +if sys.version < '3': + from itertools import imap as map + from pyspark.accumulators import _accumulatorRegistry from pyspark.broadcast import Broadcast, _broadcastRegistry from pyspark.taskcontext import TaskContext from pyspark.files import SparkFiles from pyspark.serializers import write_with_length, write_int, read_long, \ - write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, BatchedSerializer + write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, \ + BatchedSerializer, VectorizedSerializer from pyspark import shuffle +from pyspark.sql.types import toArrowSchema pickleSer = PickleSerializer() utf8_deserializer = UTF8Deserializer() +class PythonEvalType(object): + NON_UDF = 0 + SQL_BATCHED_UDF = 1 + SQL_VECTORIZED_UDF = 2 + + def report_times(outfile, boot, init, finish): write_int(SpecialLengths.TIMING_DATA, outfile) write_long(int(1000 * boot), outfile) @@ -71,18 +82,22 @@ def wrap_udf(f, return_type): return lambda *a: f(*a) -def read_single_udf(pickleSer, infile): +def read_single_udf(pickleSer, infile, vectorized=False): num_arg = read_int(infile) arg_offsets = [read_int(infile) for i in range(num_arg)] - row_func = None + func = None for i in range(read_int(infile)): f, return_type = read_command(pickleSer, infile) - if row_func is None: - row_func = f + if func is None: + func = f else: - row_func = chain(row_func, f) - # the last returnType will be the return type of UDF - return arg_offsets, wrap_udf(row_func, return_type) + func = chain(func, f) + if (vectorized): + # todo: shall we do conversion depends on data type? 
+ return arg_offsets, func, return_type + else: + # the last returnType will be the return type of UDF + return arg_offsets, wrap_udf(func, return_type) def read_udfs(pickleSer, infile): @@ -107,6 +122,34 @@ def read_udfs(pickleSer, infile): return func, None, ser, ser +def read_vectorized_udfs(pickleSer, infile): + num_udfs = read_int(infile) + udfs = {} + call_udfs = [] + types = [] + for i in range(num_udfs): + arg_offsets, udf, return_type = read_single_udf(pickleSer, infile, True) + types.append(return_type) + udfs['f%d' % i] = udf + # the first value of the inputs is the number of elements in this batch, and we only + # need it for 0-parameter UDF. + if arg_offsets: + args = ["a[%d]" % (o + 1) for o in arg_offsets] + else: + args = ["a[0]"] + call_udfs.append("f%d(%s)" % (i, ", ".join(args))) + # Create function like this: + # lambda a: [f0(a[0]), f1(a[1], a[2]), f2(a[3])] + mapper_str = "lambda a: [%s]" % (", ".join(call_udfs)) + mapper = eval(mapper_str, udfs) + + func = lambda _, it: map(mapper, it) + ser = VectorizedSerializer() + ser.schema = toArrowSchema(types) + # profiling is not supported for UDF + return func, None, ser, ser + + def main(infile, outfile): try: boot_time = time.time() @@ -159,8 +202,10 @@ def main(infile, outfile): _broadcastRegistry.pop(bid) _accumulatorRegistry.clear() - is_sql_udf = read_int(infile) - if is_sql_udf: + mode = read_int(infile) + if mode == PythonEvalType.SQL_VECTORIZED_UDF: + func, profiler, deserializer, serializer = read_vectorized_udfs(pickleSer, infile) + elif mode == PythonEvalType.SQL_BATCHED_UDF: func, profiler, deserializer, serializer = read_udfs(pickleSer, infile) else: func, profiler, deserializer, serializer = read_command(pickleSer, infile) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala index 3e176e2cde5b..a14462f63d17 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala @@ -62,6 +62,7 @@ import org.apache.spark.util.Utils */ case class BatchEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], child: SparkPlan) extends SparkPlan { + assert(udfs.map(_.vectorized).distinct.length == 1, "cannot mix vectorized udf and normal udf") def children: Seq[SparkPlan] = child :: Nil @@ -84,86 +85,138 @@ case class BatchEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536) val reuseWorker = inputRDD.conf.getBoolean("spark.python.worker.reuse", defaultValue = true) - inputRDD.mapPartitions { iter => - EvaluatePython.registerPicklers() // register pickler for Row - - // The queue used to buffer input rows so we can drain it to - // combine input with output from Python. 
- val queue = HybridRowQueue(TaskContext.get().taskMemoryManager(), - new File(Utils.getLocalDir(SparkEnv.get.conf)), child.output.length) - TaskContext.get().addTaskCompletionListener({ ctx => - queue.close() - }) - - val (pyFuncs, inputs) = udfs.map(collectFunctions).unzip - - // flatten all the arguments - val allInputs = new ArrayBuffer[Expression] - val dataTypes = new ArrayBuffer[DataType] - val argOffsets = inputs.map { input => - input.map { e => - if (allInputs.exists(_.semanticEquals(e))) { - allInputs.indexWhere(_.semanticEquals(e)) - } else { - allInputs += e - dataTypes += e.dataType - allInputs.length - 1 - } + if (udfs.head.vectorized) { + inputRDD.mapPartitions { iter => + val context = TaskContext.get() + + // The queue used to buffer input rows so we can drain it to + // combine input with output from Python. + val queue = HybridRowQueue(context.taskMemoryManager(), + new File(Utils.getLocalDir(SparkEnv.get.conf)), child.output.length) + context.addTaskCompletionListener({ ctx => + queue.close() + }) + + val (pyFuncs, inputs) = udfs.map(collectFunctions).unzip + + // flatten all the arguments + val allInputs = new ArrayBuffer[Expression] + val dataTypes = new ArrayBuffer[DataType] + val argOffsets = inputs.map { input => + input.map { e => + if (allInputs.exists(_.semanticEquals(e))) { + allInputs.indexWhere(_.semanticEquals(e)) + } else { + allInputs += e + dataTypes += e.dataType + allInputs.length - 1 + } + }.toArray }.toArray - }.toArray - val projection = newMutableProjection(allInputs, child.output) - val schema = StructType(dataTypes.map(dt => StructField("", dt))) - val needConversion = dataTypes.exists(EvaluatePython.needConversionInPython) - - // enable memo iff we serialize the row with schema (schema and class should be memorized) - val pickle = new Pickler(needConversion) - // Input iterator to Python: input rows are grouped so we send them in batches to Python. - // For each row, add it to the queue. - val inputIterator = iter.map { inputRow => - queue.add(inputRow.asInstanceOf[UnsafeRow]) - val row = projection(inputRow) - if (needConversion) { - EvaluatePython.toJava(row, schema) - } else { - // fast path for these types that does not need conversion in Python - val fields = new Array[Any](row.numFields) - var i = 0 - while (i < row.numFields) { - val dt = dataTypes(i) - fields(i) = EvaluatePython.toJava(row.get(i, dt), dt) - i += 1 - } - fields - } - }.grouped(100).map(x => pickle.dumps(x.toArray)) + val projection = newMutableProjection(allInputs, child.output) + val schema = StructType(dataTypes.zipWithIndex.map { + case (dt, index) => StructField(s"c_$index", dt) + }) - val context = TaskContext.get() + // For each row, add it to the queue. - // Output iterator for results from Python. - val outputIterator = new PythonRunner(pyFuncs, bufferSize, reuseWorker, true, argOffsets) - .compute(inputIterator, context.partitionId(), context) + val inputIterator = iter.map { inputRow => + queue.add(inputRow.asInstanceOf[UnsafeRow]) + projection(inputRow) + } - val unpickle = new Unpickler - val mutableRow = new GenericInternalRow(1) - val joined = new JoinedRow - val resultType = if (udfs.length == 1) { - udfs.head.dataType - } else { - StructType(udfs.map(u => StructField("", u.dataType, u.nullable))) + // Output iterator for results from Python. 
+ val outputIterator = + new VectorizedPythonRunner(pyFuncs, 10000, bufferSize, reuseWorker, argOffsets) + .compute(inputIterator, schema, context.partitionId(), context) + val joined = new JoinedRow + val resultProj = UnsafeProjection.create(output, output) + outputIterator.map { row => + resultProj(joined(queue.remove(), row)) + } } - val resultProj = UnsafeProjection.create(output, output) - outputIterator.flatMap { pickedResult => - val unpickledBatch = unpickle.loads(pickedResult) - unpickledBatch.asInstanceOf[java.util.ArrayList[Any]].asScala - }.map { result => - val row = if (udfs.length == 1) { - // fast path for single UDF - mutableRow(0) = EvaluatePython.fromJava(result, resultType) - mutableRow + } else { + inputRDD.mapPartitions { iter => + val context = TaskContext.get() + + EvaluatePython.registerPicklers() // register pickler for Row + + // The queue used to buffer input rows so we can drain it to + // combine input with output from Python. + val queue = HybridRowQueue(context.taskMemoryManager(), + new File(Utils.getLocalDir(SparkEnv.get.conf)), child.output.length) + context.addTaskCompletionListener({ ctx => + queue.close() + }) + + val (pyFuncs, inputs) = udfs.map(collectFunctions).unzip + + // flatten all the arguments + val allInputs = new ArrayBuffer[Expression] + val dataTypes = new ArrayBuffer[DataType] + val argOffsets = inputs.map { input => + input.map { e => + if (allInputs.exists(_.semanticEquals(e))) { + allInputs.indexWhere(_.semanticEquals(e)) + } else { + allInputs += e + dataTypes += e.dataType + allInputs.length - 1 + } + }.toArray + }.toArray + val projection = newMutableProjection(allInputs, child.output) + val schema = StructType(dataTypes.map(dt => StructField("", dt))) + val needConversion = dataTypes.exists(EvaluatePython.needConversionInPython) + + // enable memo iff we serialize the row with schema (schema and class should be memorized) + val pickle = new Pickler(needConversion) + // Input iterator to Python: input rows are grouped so we send them in batches to Python. + // For each row, add it to the queue. + val inputIterator = iter.map { inputRow => + queue.add(inputRow.asInstanceOf[UnsafeRow]) + val row = projection(inputRow) + if (needConversion) { + EvaluatePython.toJava(row, schema) + } else { + // fast path for these types that does not need conversion in Python + val fields = new Array[Any](row.numFields) + var i = 0 + while (i < row.numFields) { + val dt = dataTypes(i) + fields(i) = EvaluatePython.toJava(row.get(i, dt), dt) + i += 1 + } + fields + } + }.grouped(100).map(x => pickle.dumps(x.toArray)) + + // Output iterator for results from Python. 
+ val outputIterator = new PythonRunner(pyFuncs, bufferSize, reuseWorker, true, argOffsets) + .compute(inputIterator, context.partitionId(), context) + + val unpickle = new Unpickler + val mutableRow = new GenericInternalRow(1) + val joined = new JoinedRow + val resultType = if (udfs.length == 1) { + udfs.head.dataType } else { - EvaluatePython.fromJava(result, resultType).asInstanceOf[InternalRow] + StructType(udfs.map(u => StructField("", u.dataType, u.nullable))) + } + val resultProj = UnsafeProjection.create(output, output) + outputIterator.flatMap { pickedResult => + val unpickledBatch = unpickle.loads(pickedResult) + unpickledBatch.asInstanceOf[java.util.ArrayList[Any]].asScala + }.map { result => + val row = if (udfs.length == 1) { + // fast path for single UDF + mutableRow(0) = EvaluatePython.fromJava(result, resultType) + mutableRow + } else { + EvaluatePython.fromJava(result, resultType).asInstanceOf[InternalRow] + } + resultProj(joined(queue.remove(), row)) } - resultProj(joined(queue.remove(), row)) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala index 7ebbdb9846cc..067336d779c7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala @@ -28,7 +28,8 @@ case class PythonUDF( name: String, func: PythonFunction, dataType: DataType, - children: Seq[Expression]) + children: Seq[Expression], + vectorized: Boolean = false) extends Expression with Unevaluable with NonSQLExpression with UserDefinedExpression { override def toString: String = s"$name(${children.mkString(", ")})" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala index 0d39c8ff980f..440b6cc9c192 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala @@ -28,10 +28,11 @@ import org.apache.spark.sql.types.DataType case class UserDefinedPythonFunction( name: String, func: PythonFunction, - dataType: DataType) { + dataType: DataType, + vectorized: Boolean = false) { def builder(e: Seq[Expression]): PythonUDF = { - PythonUDF(name, func, dataType, e) + PythonUDF(name, func, dataType, e, vectorized) } /** Returns a [[Column]] that will evaluate to calling this UDF with the given input. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/VectorizedPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/VectorizedPythonRunner.scala new file mode 100644 index 000000000000..d935ae3c636f --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/VectorizedPythonRunner.scala @@ -0,0 +1,329 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.python + +import java.io.{BufferedInputStream, BufferedOutputStream, DataInputStream, DataOutputStream} +import java.net.Socket +import java.nio.charset.StandardCharsets + +import scala.collection.JavaConverters._ + +import org.apache.arrow.vector.VectorSchemaRoot +import org.apache.arrow.vector.stream.{ArrowStreamReader, ArrowStreamWriter} + +import org.apache.spark.{SparkEnv, SparkFiles, TaskContext} +import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType, PythonException, PythonRDD, SpecialLengths} +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.execution.arrow.{ArrowUtils, ArrowWriter} +import org.apache.spark.sql.execution.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector} +import org.apache.spark.sql.types._ +import org.apache.spark.util.Utils + +/** + * Similar to `PythonRunner`, but exchange data with Python worker via columnar format. + */ +class VectorizedPythonRunner( + funcs: Seq[ChainedPythonFunctions], + batchSize: Int, + bufferSize: Int, + reuse_worker: Boolean, + argOffsets: Array[Array[Int]]) extends Logging { + + require(funcs.length == argOffsets.length, "argOffsets should have the same length as funcs") + + // All the Python functions should have the same exec, version and envvars. + private val envVars = funcs.head.funcs.head.envVars + private val pythonExec = funcs.head.funcs.head.pythonExec + private val pythonVer = funcs.head.funcs.head.pythonVer + + // TODO: support accumulator in multiple UDF + private val accumulator = funcs.head.funcs.head.accumulator + + // todo: return column batch? 
+ def compute( + inputRows: Iterator[InternalRow], + schema: StructType, + partitionIndex: Int, + context: TaskContext): Iterator[InternalRow] = { + val startTime = System.currentTimeMillis + val env = SparkEnv.get + val localdir = env.blockManager.diskBlockManager.localDirs.map(f => f.getPath()).mkString(",") + envVars.put("SPARK_LOCAL_DIRS", localdir) // it's also used in monitor thread + if (reuse_worker) { + envVars.put("SPARK_REUSE_WORKER", "1") + } + val worker: Socket = env.createPythonWorker(pythonExec, envVars.asScala.toMap) + // Whether is the worker released into idle pool + @volatile var released = false + + // Start a thread to feed the process input from our parent's iterator + val writerThread = new WriterThread( + env, worker, inputRows, schema, partitionIndex, context) + + context.addTaskCompletionListener { context => + writerThread.shutdownOnTaskCompletion() + if (!reuse_worker || !released) { + try { + worker.close() + } catch { + case e: Exception => + logWarning("Failed to close worker socket", e) + } + } + } + + writerThread.start() + + val stream = new DataInputStream(new BufferedInputStream(worker.getInputStream, bufferSize)) + + val allocator = ArrowUtils.rootAllocator.newChildAllocator( + s"stdin reader for $pythonExec", 0, Long.MaxValue) + val reader = new ArrowStreamReader(stream, allocator) + + new Iterator[InternalRow] { + private val root = reader.getVectorSchemaRoot + private val vectors = root.getFieldVectors.asScala.map { vector => + new ArrowColumnVector(vector) + }.toArray[ColumnVector] + + var closed = false + + context.addTaskCompletionListener { _ => + // todo: we need something like `read.end()`, which release all the resources, but leave + // the input stream open. `reader.close` will close the socket and we can't reuse worker. + // So here we simply not close the reader, which is problematic. 
+ if (!closed) { + root.close() + allocator.close() + } + } + + private[this] var batchLoaded = true + private[this] var currentIter: Iterator[InternalRow] = Iterator.empty + + override def hasNext: Boolean = batchLoaded && (currentIter.hasNext || loadNextBatch()) || { + root.close() + allocator.close() + closed = true + false + } + + private def loadNextBatch(): Boolean = { + batchLoaded = reader.loadNextBatch() + if (batchLoaded) { + val batch = new ColumnarBatch(schema, vectors, root.getRowCount) + batch.setNumRows(root.getRowCount) + currentIter = batch.rowIterator().asScala + } else { + // end of arrow batches, handle some special signal + val signal = stream.readInt() + if (signal == SpecialLengths.PYTHON_EXCEPTION_THROWN) { + val exLength = stream.readInt() + val obj = new Array[Byte](exLength) + stream.readFully(obj) + throw new PythonException(new String(obj, StandardCharsets.UTF_8), + writerThread.exception.getOrElse(null)) + } + + assert(signal == SpecialLengths.TIMING_DATA) + // Timing data from worker + val bootTime = stream.readLong() + val initTime = stream.readLong() + val finishTime = stream.readLong() + val boot = bootTime - startTime + val init = initTime - bootTime + val finish = finishTime - initTime + val total = finishTime - startTime + logInfo("Times: total = %s, boot = %s, init = %s, finish = %s".format(total, boot, + init, finish)) + val memoryBytesSpilled = stream.readLong() + val diskBytesSpilled = stream.readLong() + context.taskMetrics.incMemoryBytesSpilled(memoryBytesSpilled) + context.taskMetrics.incDiskBytesSpilled(diskBytesSpilled) + + assert(stream.readInt() == SpecialLengths.END_OF_DATA_SECTION) + // We've finished the data section of the output, but we can still + // read some accumulator updates: + val numAccumulatorUpdates = stream.readInt() + (1 to numAccumulatorUpdates).foreach { _ => + val updateLen = stream.readInt() + val update = new Array[Byte](updateLen) + stream.readFully(update) + accumulator.add(update) + } + // Check whether the worker is ready to be re-used. + if (stream.readInt() == SpecialLengths.END_OF_STREAM) { + if (reuse_worker) { + env.releasePythonWorker(pythonExec, envVars.asScala.toMap, worker) + released = true + } + } + } + hasNext // skip empty batches if any + } + + override def next(): InternalRow = { + currentIter.next() + } + } + } + + class WriterThread( + env: SparkEnv, + worker: Socket, + inputRows: Iterator[InternalRow], + schema: StructType, + partitionIndex: Int, + context: TaskContext) + extends Thread(s"stdout writer for $pythonExec") { + + @volatile private var _exception: Exception = null + + private val pythonIncludes = funcs.flatMap(_.funcs.flatMap(_.pythonIncludes.asScala)).toSet + private val broadcastVars = funcs.flatMap(_.funcs.flatMap(_.broadcastVars.asScala)) + + setDaemon(true) + + /** Contains the exception thrown while writing the parent iterator to the Python process. */ + def exception: Option[Exception] = Option(_exception) + + /** Terminates the writer thread, ignoring any exceptions that may occur due to cleanup. 
*/ + def shutdownOnTaskCompletion() { + assert(context.isCompleted) + this.interrupt() + } + + override def run(): Unit = Utils.logUncaughtExceptions { + try { + TaskContext.setTaskContext(context) + + val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize) + val dataOut = new DataOutputStream(stream) + // Partition index + dataOut.writeInt(partitionIndex) + // Python version of driver + PythonRDD.writeUTF(pythonVer, dataOut) + // Write out the TaskContextInfo + dataOut.writeInt(context.stageId()) + dataOut.writeInt(context.partitionId()) + dataOut.writeInt(context.attemptNumber()) + dataOut.writeLong(context.taskAttemptId()) + // sparkFilesDir + PythonRDD.writeUTF(SparkFiles.getRootDirectory(), dataOut) + // Python includes (*.zip and *.egg files) + dataOut.writeInt(pythonIncludes.size) + for (include <- pythonIncludes) { + PythonRDD.writeUTF(include, dataOut) + } + // Broadcast variables + val oldBids = PythonRDD.getWorkerBroadcasts(worker) + val newBids = broadcastVars.map(_.id).toSet + // number of different broadcasts + val toRemove = oldBids.diff(newBids) + val cnt = toRemove.size + newBids.diff(oldBids).size + dataOut.writeInt(cnt) + for (bid <- toRemove) { + // remove the broadcast from worker + dataOut.writeLong(- bid - 1) // bid >= 0 + oldBids.remove(bid) + } + for (broadcast <- broadcastVars) { + if (!oldBids.contains(broadcast.id)) { + // send new broadcast + dataOut.writeLong(broadcast.id) + PythonRDD.writeUTF(broadcast.value.path, dataOut) + oldBids.add(broadcast.id) + } + } + dataOut.flush() + + // SQL_VECTORIZED_UDF means arrow mode + dataOut.writeInt(PythonEvalType.SQL_VECTORIZED_UDF) + dataOut.writeInt(funcs.length) + funcs.zip(argOffsets).foreach { case (chained, offsets) => + dataOut.writeInt(offsets.length) + offsets.foreach(dataOut.writeInt) + dataOut.writeInt(chained.funcs.length) + chained.funcs.foreach { f => + dataOut.writeInt(f.command.length) + dataOut.write(f.command) + } + } + dataOut.flush() + + val arrowSchema = ArrowUtils.toArrowSchema(schema) + val allocator = ArrowUtils.rootAllocator.newChildAllocator( + s"stdout writer for $pythonExec", 0, Long.MaxValue) + + val root = VectorSchemaRoot.create(arrowSchema, allocator) + val arrowWriter = ArrowWriter.create(root) + + var closed = false + + context.addTaskCompletionListener { _ => + if (!closed) { + root.close() + allocator.close() + } + } + + // TODO: does ArrowStreamWriter buffer data? + // TODO: who decides the dictionary? + val writer = new ArrowStreamWriter(root, null, dataOut) + writer.start() + + Utils.tryWithSafeFinally { + while (inputRows.hasNext) { + var rowCount = 0 + while (inputRows.hasNext && rowCount < batchSize) { + val row = inputRows.next() + arrowWriter.write(row) + rowCount += 1 + } + arrowWriter.finish() + writer.writeBatch() + arrowWriter.reset() + } + } { + writer.end() + root.close() + allocator.close() + closed = true + } + + dataOut.writeInt(SpecialLengths.END_OF_STREAM) + dataOut.flush() + } catch { + case e: Exception if context.isCompleted || context.isInterrupted => + logDebug("Exception thrown after task completion (likely due to cleanup)", e) + if (!worker.isClosed) { + Utils.tryLog(worker.shutdownOutput()) + } + + case e: Exception => + // We must avoid throwing exceptions here, because the thread uncaught exception handler + // will kill the whole executor (see org.apache.spark.executor.Executor). 
+ _exception = e + if (!worker.isClosed) { + Utils.tryLog(worker.shutdownOutput()) + } + } + } + } +} From a1e4f62c993cc4c35523ed947178a72a2aadb753 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Wed, 6 Sep 2017 13:46:21 +0900 Subject: [PATCH 2/8] Add check if the length of returned value is the same as input value. --- python/pyspark/sql/tests.py | 9 +++++++++ python/pyspark/worker.py | 11 +++++++++-- 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 87134f74c87c..6251ea3680e8 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3245,6 +3245,15 @@ def test_vectorized_udf_exception(self): with self.assertRaisesRegexp(Exception, 'division( or modulo)? by zero'): df.select(raise_exception(col('id'))).collect() + def test_vectorized_udf_invalid_length(self): + import pandas as pd + df = self.spark.range(10) + raise_exception = pandas_udf(lambda size: pd.Series(1), LongType()) + with QuietTest(self.sc): + with self.assertRaisesRegexp(Exception, + 'The length of returned value should be the same as input value'): + df.select(raise_exception()).collect() + if __name__ == "__main__": from pyspark.sql.tests import * diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index d7d0afa1ef24..c71f4e7fa160 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -138,9 +138,16 @@ def read_vectorized_udfs(pickleSer, infile): else: args = ["a[0]"] call_udfs.append("f%d(%s)" % (i, ", ".join(args))) + def chk_len(v, size): + if len(v) == size: + return v + else: + raise Exception("The length of returned value should be the same as input value") + call_and_chk_len = ['chk_len(%s, a[0])' % call_udf for call_udf in call_udfs] + udfs['chk_len'] = chk_len # Create function like this: - # lambda a: [f0(a[0]), f1(a[1], a[2]), f2(a[3])] - mapper_str = "lambda a: [%s]" % (", ".join(call_udfs)) + # lambda a: [chk_len(f0(a[0]), a[0]), chk_len(f1(a[1], a[2]), a[0]), ...] + mapper_str = "lambda a: [%s]" % (", ".join(call_and_chk_len)) mapper = eval(mapper_str, udfs) func = lambda _, it: map(mapper, it) From 84d2767857fd5287665fde10ae465dc11dd241f5 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Wed, 6 Sep 2017 21:26:37 +0900 Subject: [PATCH 3/8] Fix style. 
--- python/pyspark/sql/tests.py | 3 ++- python/pyspark/worker.py | 2 ++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 6251ea3680e8..214a1604d5dd 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3250,7 +3250,8 @@ def test_vectorized_udf_invalid_length(self): df = self.spark.range(10) raise_exception = pandas_udf(lambda size: pd.Series(1), LongType()) with QuietTest(self.sc): - with self.assertRaisesRegexp(Exception, + with self.assertRaisesRegexp( + Exception, 'The length of returned value should be the same as input value'): df.select(raise_exception()).collect() diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index c71f4e7fa160..aee10f129f6c 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -138,11 +138,13 @@ def read_vectorized_udfs(pickleSer, infile): else: args = ["a[0]"] call_udfs.append("f%d(%s)" % (i, ", ".join(args))) + def chk_len(v, size): if len(v) == size: return v else: raise Exception("The length of returned value should be the same as input value") + call_and_chk_len = ['chk_len(%s, a[0])' % call_udf for call_udf in call_udfs] udfs['chk_len'] = chk_len # Create function like this: From 1db6cb5309815f3177aebe18047c5137e990abf2 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Thu, 7 Sep 2017 21:16:33 +0900 Subject: [PATCH 4/8] Check if pandas is installed or not. --- python/pyspark/sql/functions.py | 87 ++++++++++++++++++++------------- 1 file changed, 53 insertions(+), 34 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 77fc03c43b44..bed0d0530a82 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -25,6 +25,14 @@ if sys.version < "3": from itertools import imap as map +_have_pandas = False +try: + import pandas + _have_pandas = True +except: + # No Pandas, but that's okay, we'll skip those tests + pass + from pyspark import since, SparkContext from pyspark.rdd import _prepare_for_python_RDD, ignore_unicode_prefix from pyspark.serializers import PickleSerializer, AutoBatchedSerializer @@ -2112,40 +2120,55 @@ def wrapper(*args): return wrapper -@since(2.3) -def pandas_udf(f=None, returnType=None): - """Creates a :class:`Column` expression representing a vectorized user defined function (UDF). - - .. note:: The vectorized user-defined functions must be deterministic. Due to optimization, - duplicate invocations may be eliminated or the function may even be invoked more times than - it is present in the query. +def _udf(f, returnType, vectorized): + udf_obj = UserDefinedFunction(f, returnType, vectorized=vectorized) + return udf_obj._wrapped() - :param f: python function if used as a standalone function - :param returnType: a :class:`pyspark.sql.types.DataType` object - >>> from pyspark.sql.types import LongType - >>> add = pandas_udf(lambda x, y: x + y, LongType()) - >>> @pandas_udf(returnType=LongType()) - ... def mul(x, y): - ... return x * y - ... 
- >>> import pandas as pd - >>> ones = pandas_udf(lambda size: pd.Series(1).repeat(size), LongType()) +if _have_pandas: - >>> df = spark.createDataFrame([(1, 2), (3, 4)], ("a", "b")) - >>> df.select(add("a", "b").alias("add(a, b)"), mul("a", "b"), ones().alias("ones")).show() - +---------+---------+----+ - |add(a, b)|mul(a, b)|ones| - +---------+---------+----+ - | 3| 2| 1| - | 7| 12| 1| - +---------+---------+----+ - """ - return udf(f, returnType, vectorized=True) + @since(2.3) + def pandas_udf(f=None, returnType=StringType()): + """ + Creates a :class:`Column` expression representing a vectorized user defined function (UDF). + + .. note:: The vectorized user-defined functions must be deterministic. Due to optimization, + duplicate invocations may be eliminated or the function may even be invoked more times + than it is present in the query. + + :param f: python function if used as a standalone function + :param returnType: a :class:`pyspark.sql.types.DataType` object + + >>> from pyspark.sql.types import LongType + >>> add = pandas_udf(lambda x, y: x + y, LongType()) + >>> @pandas_udf(returnType=LongType()) + ... def mul(x, y): + ... return x * y + ... + >>> import pandas as pd + >>> ones = pandas_udf(lambda size: pd.Series(1).repeat(size), LongType()) + + >>> df = spark.createDataFrame([(1, 2), (3, 4)], ("a", "b")) + >>> df.select(add("a", "b").alias("add(a, b)"), mul("a", "b"), ones().alias("ones")).show() + +---------+---------+----+ + |add(a, b)|mul(a, b)|ones| + +---------+---------+----+ + | 3| 2| 1| + | 7| 12| 1| + +---------+---------+----+ + """ + # decorator @pandas_udf, @pandas_udf() or @pandas_udf(dataType()) + if f is None or isinstance(f, (str, DataType)): + # If DataType has been passed as a positional argument + # for decorator use it as a returnType + return_type = f or returnType + return functools.partial(_udf, returnType=return_type, vectorized=True) + else: + return _udf(f=f, returnType=returnType, vectorized=True) @since(1.3) -def udf(f=None, returnType=StringType(), vectorized=False): +def udf(f=None, returnType=StringType()): """Creates a :class:`Column` expression representing a user defined function (UDF). .. note:: The user-defined functions must be deterministic. Due to optimization, @@ -2175,18 +2198,14 @@ def udf(f=None, returnType=StringType(), vectorized=False): | 8| JOHN DOE| 22| +----------+--------------+------------+ """ - def _udf(f, returnType=StringType(), vectorized=False): - udf_obj = UserDefinedFunction(f, returnType, vectorized=vectorized) - return udf_obj._wrapped() - # decorator @udf, @udf() or @udf(dataType()) if f is None or isinstance(f, (str, DataType)): # If DataType has been passed as a positional argument # for decorator use it as a returnType return_type = f or returnType - return functools.partial(_udf, returnType=return_type, vectorized=vectorized) + return functools.partial(_udf, returnType=return_type, vectorized=False) else: - return _udf(f=f, returnType=returnType, vectorized=vectorized) + return _udf(f=f, returnType=returnType, vectorized=False) blacklist = ['map', 'since', 'ignore_unicode_prefix'] From 3a0d4a6102b6c0180e90f701d87693617603ca29 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Thu, 7 Sep 2017 21:16:52 +0900 Subject: [PATCH 5/8] Add a test using datatype in string form. 
--- python/pyspark/sql/tests.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 214a1604d5dd..a829d6154f8f 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3226,6 +3226,13 @@ def test_vectorized_udf_zero_parameter(self): res = df.select(f0()) self.assertEquals(df.select(lit(1)).collect(), res.collect()) + def test_vectorized_udf_datatype_string(self): + import pandas as pd + df = self.spark.range(100000) + f0 = pandas_udf(lambda size: pd.Series(1).repeat(size), "long") + res = df.select(f0()) + self.assertEquals(df.select(lit(1)).collect(), res.collect()) + def test_vectorized_udf_complex(self): df = self.spark.range(10).select( col('id').cast('int').alias('a'), From 2f929d8e0ec01ca7070fc0969e5091dad4ce8350 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Thu, 7 Sep 2017 23:50:40 +0900 Subject: [PATCH 6/8] Fix tests. --- python/pyspark/sql/tests.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index a829d6154f8f..884532c5655d 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -63,10 +63,12 @@ from pyspark.sql.types import _array_signed_int_typecode_ctype_mappings, _array_type_mappings from pyspark.sql.types import _array_unsigned_int_typecode_ctype_mappings from pyspark.tests import QuietTest, ReusedPySparkTestCase, SparkSubmitTests -from pyspark.sql.functions import UserDefinedFunction, sha2, lit, col, expr, pandas_udf +from pyspark.sql.functions import UserDefinedFunction, sha2, lit, col, expr from pyspark.sql.window import Window from pyspark.sql.utils import AnalysisException, ParseException, IllegalArgumentException +if _have_pandas: + from pyspark.sql.functions import pandas_udf _have_arrow = False try: From dbc6dd2138b427d8436eb0d8bdc4ba134f254e35 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Mon, 11 Sep 2017 13:54:44 +0900 Subject: [PATCH 7/8] Fix tests. 
--- python/pyspark/sql/functions.py | 10 ++++++++-- python/pyspark/sql/tests.py | 8 ++++---- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index bed0d0530a82..91d7863b37c4 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -30,7 +30,13 @@ import pandas _have_pandas = True except: - # No Pandas, but that's okay, we'll skip those tests + pass + +_have_arrow = False +try: + import pyarrow + _have_arrow = True +except: pass from pyspark import since, SparkContext @@ -2125,7 +2131,7 @@ def _udf(f, returnType, vectorized): return udf_obj._wrapped() -if _have_pandas: +if _have_pandas and _have_arrow: @since(2.3) def pandas_udf(f=None, returnType=StringType()): diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 884532c5655d..1e163c459fce 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -67,9 +67,6 @@ from pyspark.sql.window import Window from pyspark.sql.utils import AnalysisException, ParseException, IllegalArgumentException -if _have_pandas: - from pyspark.sql.functions import pandas_udf - _have_arrow = False try: import pyarrow @@ -78,6 +75,9 @@ # No Arrow, but that's okay, we'll skip those tests pass +if _have_pandas and _have_arrow: + from pyspark.sql.functions import pandas_udf + class UTCOffsetTimezone(datetime.tzinfo): """ @@ -3124,7 +3124,7 @@ def test_filtered_frame(self): self.assertTrue(pdf.empty) -@unittest.skipIf(not _have_arrow, "Arrow not installed") +@unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed") class VectorizedUDFTests(ReusedPySparkTestCase): @classmethod From 803054e9f30b057e1b194b5625cb6216d865f1d4 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Mon, 11 Sep 2017 15:16:41 +0900 Subject: [PATCH 8/8] Add a test for mixing udf and vectorized udf. --- python/pyspark/sql/tests.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 1e163c459fce..7efde7b07064 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3264,6 +3264,15 @@ def test_vectorized_udf_invalid_length(self): 'The length of returned value should be the same as input value'): df.select(raise_exception()).collect() + def test_vectorized_udf_mix_udf(self): + from pyspark.sql.functions import udf + df = self.spark.range(10) + row_by_row_udf = udf(lambda x: x, LongType()) + pd_udf = pandas_udf(lambda x: x, LongType()) + with QuietTest(self.sc): + with self.assertRaisesRegexp(Exception, 'cannot mix vectorized udf and normal udf'): + df.select(row_by_row_udf(col('id')), pd_udf(col('id'))).collect() + if __name__ == "__main__": from pyspark.sql.tests import *
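
Usage sketch (illustrative, not part of the patches above): assuming a Spark build that
includes this series, with pandas and pyarrow installed on the workers, the new
vectorized UDF API is expected to be exercised roughly as in the pandas_udf doctest:

    import pandas as pd
    from pyspark.sql.functions import pandas_udf, col
    from pyspark.sql.types import LongType

    # `spark` is an existing SparkSession.
    df = spark.createDataFrame([(1, 2), (3, 4)], ("a", "b"))

    # The wrapped function receives a pandas.Series per argument column and
    # must return a pandas.Series of the same length (enforced in PATCH 2/8).
    multiply = pandas_udf(lambda x, y: x * y, LongType())

    # A zero-argument vectorized UDF instead receives the batch size, so it
    # can produce a Series of the right length.
    ones = pandas_udf(lambda size: pd.Series(1).repeat(size), LongType())

    df.select(multiply(col("a"), col("b")), ones()).show()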