[SPARK-14267] [SQL] [PYSPARK] execute multiple Python UDFs within single batch #12057
Changes from 2 commits
@@ -59,7 +59,7 @@ private[spark] class PythonRDD( | |
| val asJavaRDD: JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this) | ||
|
|
||
| override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = { | ||
| val runner = new PythonRunner(Seq(func), bufferSize, reuse_worker, false) | ||
| val runner = PythonRunner(func, bufferSize, reuse_worker) | ||
| runner.compute(firstParent.iterator(split, context), split.index, context) | ||
| } | ||
| } | ||
|
|
@@ -77,22 +77,30 @@ private[spark] case class PythonFunction( | |
| broadcastVars: JList[Broadcast[PythonBroadcast]], | ||
| accumulator: Accumulator[JList[Array[Byte]]]) | ||
|
|
||
|
|
||
| object PythonRunner { | ||
| def apply(func: PythonFunction, bufferSize: Int, reuse_worker: Boolean): PythonRunner = { | ||
| new PythonRunner(Seq(Seq(func)), bufferSize, reuse_worker, false, Seq(1)) | ||
| } | ||
| } | ||
|
|
||
| /** | ||
| * A helper class to run Python UDFs in Spark. | ||
| * A helper class to run Python mapPartition/UDFs in Spark. | ||
| */ | ||
| private[spark] class PythonRunner( | ||
| funcs: Seq[PythonFunction], | ||
| funcs: Seq[Seq[PythonFunction]], | ||
|
||
| bufferSize: Int, | ||
| reuse_worker: Boolean, | ||
| rowBased: Boolean) | ||
| isUDF: Boolean, | ||
|
Contributor: Similarly, do you mind adding scaladoc for these two new parameters?
||
| numArgs: Seq[Int]) | ||
| extends Logging { | ||
|
|
||
| // All the Python functions should have the same exec, version and envvars. | ||
| private val envVars = funcs.head.envVars | ||
| private val pythonExec = funcs.head.pythonExec | ||
| private val pythonVer = funcs.head.pythonVer | ||
| private val envVars = funcs.head.head.envVars | ||
| private val pythonExec = funcs.head.head.pythonExec | ||
| private val pythonVer = funcs.head.head.pythonVer | ||
|
|
||
| private val accumulator = funcs.head.accumulator // TODO: support accumulator in multiple UDF | ||
| private val accumulator = funcs.head.head.accumulator // TODO: support accumulator in multiple UDF | ||
|
|
||
| def compute( | ||
| inputIterator: Iterator[_], | ||
|
|
@@ -232,8 +240,8 @@ private[spark] class PythonRunner( | |
|
|
||
| @volatile private var _exception: Exception = null | ||
|
|
||
| private val pythonIncludes = funcs.flatMap(_.pythonIncludes.asScala).toSet | ||
| private val broadcastVars = funcs.flatMap(_.broadcastVars.asScala) | ||
| private val pythonIncludes = funcs.flatMap(_.flatMap(_.pythonIncludes.asScala)).toSet | ||
| private val broadcastVars = funcs.flatMap(_.flatMap(_.broadcastVars.asScala)) | ||
|
|
||
| setDaemon(true) | ||
|
|
||
|
|
@@ -284,11 +292,22 @@ private[spark] class PythonRunner( | |
| } | ||
| dataOut.flush() | ||
| // Serialized command: | ||
| dataOut.writeInt(if (rowBased) 1 else 0) | ||
| dataOut.writeInt(funcs.length) | ||
| funcs.foreach { f => | ||
| dataOut.writeInt(f.command.length) | ||
| dataOut.write(f.command) | ||
| if (isUDF) { | ||
| dataOut.writeInt(1) | ||
| dataOut.writeInt(funcs.length) | ||
| funcs.zip(numArgs).foreach { case (fs, numArg) => | ||
|
||
| dataOut.writeInt(numArg) | ||
| dataOut.writeInt(fs.length) | ||
| fs.foreach { f => | ||
| dataOut.writeInt(f.command.length) | ||
| dataOut.write(f.command) | ||
| } | ||
| } | ||
| } else { | ||
| dataOut.writeInt(0) | ||
| val command = funcs.head.head.command | ||
| dataOut.writeInt(command.length) | ||
| dataOut.write(command) | ||
| } | ||
| // Data values | ||
| PythonRDD.writeIteratorToStream(inputIterator, dataOut) | ||
|
|
||
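For reference, the command header written above can be summarized from the reader's side. The sketch below is illustrative only; the helper names `parse_command_header`, `read_int`, `read_bytes`, and `infile` are assumptions that mirror pyspark.worker conventions, and the real reading logic is in worker.py further down.

```python
# Hypothetical reader for the command header that the modified PythonRunner writes.
# Helper names (read_int, read_bytes, infile) are assumed for illustration only.
def parse_command_header(infile, read_int, read_bytes):
    commands = []
    if read_int(infile) == 1:                  # isUDF flag
        num_udfs = read_int(infile)            # number of UDFs in this batch
        for _ in range(num_udfs):
            num_arg = read_int(infile)         # arguments consumed by this UDF
            chained = []
            for _ in range(read_int(infile)):  # number of chained functions
                length = read_int(infile)
                chained.append(read_bytes(infile, length))  # pickled command bytes
            commands.append((num_arg, chained))
    else:                                      # plain mapPartitions-style command
        length = read_int(infile)
        commands.append((None, [read_bytes(infile, length)]))
    return commands
```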
@@ -29,7 +29,7 @@ | |
| from pyspark.broadcast import Broadcast, _broadcastRegistry | ||
| from pyspark.files import SparkFiles | ||
| from pyspark.serializers import write_with_length, write_int, read_long, \ | ||
| write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer | ||
| write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, AutoBatchedSerializer | ||
| from pyspark import shuffle | ||
|
|
||
| pickleSer = PickleSerializer() | ||
|
|
@@ -59,7 +59,49 @@ def read_command(serializer, file): | |
|
|
||
| def chain(f, g): | ||
| """chain two function together """ | ||
| return lambda x: g(f(x)) | ||
| return lambda *a: g(f(*a)) | ||
|
Contributor: Woah, didn't know that you could do varargs lambdas. Cool!
||
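As a quick standalone illustration of why the varargs change matters (plain Python, not part of the patch): chaining with `lambda *a` lets the composed UDF accept any number of arguments, whereas the previous single-argument lambda could only forward one.

```python
# Minimal sketch of the chain() helper above, outside of pyspark.
def chain(f, g):
    """Compose g after f, forwarding all positional arguments to f."""
    return lambda *a: g(f(*a))

add = lambda x, y: x + y
double = lambda v: v * 2

chained = chain(add, double)
assert chained(3, 4) == 14  # double(add(3, 4))
```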
|
|
||
|
|
||
| def wrap_udf(f, return_type): | ||
| return lambda *a: return_type.toInternal(f(*a)) | ||
|
|
||
|
|
||
| def read_single_udf(pickleSer, infile): | ||
| num_arg = read_int(infile) | ||
| row_func = None | ||
| for i in range(read_int(infile)): | ||
| f, return_type = read_command(pickleSer, infile) | ||
| if row_func is None: | ||
| row_func = f | ||
| else: | ||
| row_func = chain(row_func, f) | ||
| # the last returnType will be the return type of UDF | ||
| return num_arg, wrap_udf(row_func, return_type) | ||
|
|
||
|
|
||
| def read_udfs(pickleSer, infile): | ||
| num_udfs = read_int(infile) | ||
| udfs = [] | ||
|
||
| offset = 0 | ||
| for i in range(num_udfs): | ||
| num_arg, udf = read_single_udf(pickleSer, infile) | ||
| udfs.append((offset, offset + num_arg, udf)) | ||
| offset += num_arg | ||
|
|
||
| if num_udfs == 1: | ||
| udf = udfs[0][2] | ||
|
|
||
| # fast path for single UDF | ||
| def mapper(args): | ||
|
||
| return udf(*args) | ||
| else: | ||
| def mapper(args): | ||
| return tuple(udf(*args[start:end]) for start, end, udf in udfs) | ||
|
|
||
| func = lambda _, it: map(mapper, it) | ||
| ser = AutoBatchedSerializer(PickleSerializer()) | ||
| # profiling is not supported for UDF | ||
| return func, None, ser, ser | ||
|
|
||
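To make the offset bookkeeping in `read_udfs` concrete, here is a small standalone walk-through (plain Python with made-up UDFs, not pyspark code): each UDF claims a contiguous slice of the flattened argument tuple.

```python
# Two toy UDFs taking 1 and 2 arguments respectively.
udf1 = lambda a: a + 1
udf2 = lambda b, c: b * c

# (start, end, udf) triples as read_udfs would build them: offsets 0..1 and 1..3.
udfs = [(0, 1, udf1), (1, 3, udf2)]

def mapper(args):
    # Slice the flattened argument row and hand each UDF its own slice.
    return tuple(udf(*args[start:end]) for start, end, udf in udfs)

# A flattened input row (a, b, c) carrying the arguments of both UDFs.
assert mapper((10, 2, 3)) == (11, 6)
```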
|
|
||
| def main(infile, outfile): | ||
|
|
@@ -107,21 +149,10 @@ def main(infile, outfile): | |
| _broadcastRegistry.pop(bid) | ||
|
|
||
| _accumulatorRegistry.clear() | ||
| row_based = read_int(infile) | ||
| num_commands = read_int(infile) | ||
| if row_based: | ||
| profiler = None # profiling is not supported for UDF | ||
| row_func = None | ||
| for i in range(num_commands): | ||
| f, returnType, deserializer = read_command(pickleSer, infile) | ||
| if row_func is None: | ||
| row_func = f | ||
| else: | ||
| row_func = chain(row_func, f) | ||
| serializer = deserializer | ||
| func = lambda _, it: map(lambda x: returnType.toInternal(row_func(*x)), it) | ||
| is_udf = read_int(infile) | ||
| if is_udf: | ||
|
||
| func, profiler, deserializer, serializer = read_udfs(pickleSer, infile) | ||
| else: | ||
| assert num_commands == 1 | ||
| func, profiler, deserializer, serializer = read_command(pickleSer, infile) | ||
|
|
||
| init_time = time.time() | ||
|
|
||
@@ -40,7 +40,7 @@ import org.apache.spark.sql.types.{StructField, StructType} | |
| * we drain the queue to find the original input row. Note that if the Python process is way too | ||
| * slow, this could lead to the queue growing unbounded and eventually run out of memory. | ||
| */ | ||
| case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child: SparkPlan) | ||
| case class BatchPythonEvaluation(udfs: Seq[PythonUDF], output: Seq[Attribute], child: SparkPlan) | ||
| extends SparkPlan { | ||
|
|
||
| def children: Seq[SparkPlan] = child :: Nil | ||
|
|
@@ -69,11 +69,14 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child: | |
| // combine input with output from Python. | ||
| val queue = new java.util.concurrent.ConcurrentLinkedQueue[InternalRow]() | ||
|
|
||
| val (pyFuncs, children) = collectFunctions(udf) | ||
| val (pyFuncs, children) = udfs.map(collectFunctions).unzip | ||
| val numArgs = children.map(_.length) | ||
|
|
||
| val pickle = new Pickler | ||
| val currentRow = newMutableProjection(children, child.output)() | ||
| val fields = children.map(_.dataType) | ||
| // flatten all the arguments | ||
| val allChildren = children.flatMap(x => x) | ||
|
||
| val currentRow = newMutableProjection(allChildren, child.output)() | ||
| val fields = allChildren.map(_.dataType) | ||
| val schema = new StructType(fields.map(t => new StructField("", t, true)).toArray) | ||
|
|
||
| // Input iterator to Python: input rows are grouped so we send them in batches to Python. | ||
|
|
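The flattening above is the Scala counterpart of the slicing shown earlier for worker.py. A tiny plain-Python illustration (made-up column names, not Spark code) of how `numArgs` lets the worker undo the flattening:

```python
# Per-UDF argument expressions, then flattened into one projection.
children = [["a"], ["b", "c"]]                 # arguments of udf1 and udf2
num_args = [len(c) for c in children]          # sent to the worker as numArgs
all_children = [col for cols in children for col in cols]

assert num_args == [1, 2]
assert all_children == ["a", "b", "c"]
```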
@@ -89,19 +92,30 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child: | |
| val context = TaskContext.get() | ||
|
|
||
| // Output iterator for results from Python. | ||
| val outputIterator = new PythonRunner(pyFuncs, bufferSize, reuseWorker, true) | ||
| val outputIterator = new PythonRunner(pyFuncs, bufferSize, reuseWorker, true, numArgs) | ||
| .compute(inputIterator, context.partitionId(), context) | ||
|
|
||
| val unpickle = new Unpickler | ||
| val row = new GenericMutableRow(1) | ||
| val mutableRow = new GenericMutableRow(1) | ||
| val joined = new JoinedRow | ||
| val resultType = if (udfs.length == 1) { | ||
| udfs.head.dataType | ||
| } else { | ||
| StructType(udfs.map(u => StructField("", u.dataType, u.nullable))) | ||
| } | ||
| val resultProj = UnsafeProjection.create(output, output) | ||
|
|
||
| outputIterator.flatMap { pickedResult => | ||
| val unpickledBatch = unpickle.loads(pickedResult) | ||
| unpickledBatch.asInstanceOf[java.util.ArrayList[Any]].asScala | ||
| }.map { result => | ||
| row(0) = EvaluatePython.fromJava(result, udf.dataType) | ||
| val row = if (udfs.length == 1) { | ||
|
Contributor: Rather than evaluating this `udfs.length == 1` branch for every row, could it be lifted out of the loop?
Contributor: If you do this, you could also reduce the scope of the `mutableRow`.
Contributor (Author): Compared to evaluating the Python UDF itself, I think this does not matter; the JIT compiler can predict this branch pretty easily.
Contributor: Fair enough.
||
| // fast path for single UDF | ||
| mutableRow(0) = EvaluatePython.fromJava(result, resultType) | ||
| mutableRow | ||
| } else { | ||
| EvaluatePython.fromJava(result, resultType).asInstanceOf[InternalRow] | ||
| } | ||
| resultProj(joined(queue.poll(), row)) | ||
| } | ||
| } | ||
|
|
||
This should be private[spark].
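For context, a hedged end-to-end usage sketch (column names, UDFs, and the `SQLContext` setup are made up for illustration): after this change, both UDFs in the projection below should be shipped to the Python worker in a single batch via one BatchPythonEvaluation, instead of one round-trip per UDF.

```python
from pyspark import SparkContext
from pyspark.sql import SQLContext
from pyspark.sql.functions import udf
from pyspark.sql.types import IntegerType, StringType

sc = SparkContext(appName="multi-udf-batch-example")
sqlContext = SQLContext(sc)

plus_one = udf(lambda x: x + 1, IntegerType())
shout = udf(lambda s: s.upper(), StringType())

df = sqlContext.createDataFrame([(1, "a"), (2, "b")], ["x", "s"])
# Both Python UDFs should be evaluated by the same Python worker batch.
df.select(plus_one(df.x), shout(df.s)).show()
```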