49 changes: 34 additions & 15 deletions core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -59,7 +59,7 @@ private[spark] class PythonRDD(
val asJavaRDD: JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this)

override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = {
val runner = new PythonRunner(Seq(func), bufferSize, reuse_worker, false)
val runner = PythonRunner(func, bufferSize, reuse_worker)
runner.compute(firstParent.iterator(split, context), split.index, context)
}
}
@@ -77,22 +77,30 @@ private[spark] case class PythonFunction(
broadcastVars: JList[Broadcast[PythonBroadcast]],
accumulator: Accumulator[JList[Array[Byte]]])


object PythonRunner {
Contributor: This should be private[spark].

def apply(func: PythonFunction, bufferSize: Int, reuse_worker: Boolean): PythonRunner = {
new PythonRunner(Seq(Seq(func)), bufferSize, reuse_worker, false, Seq(1))
}
}

/**
* A helper class to run Python UDFs in Spark.
* A helper class to run Python mapPartition/UDFs in Spark.
*/
private[spark] class PythonRunner(
funcs: Seq[PythonFunction],
funcs: Seq[Seq[PythonFunction]],
Contributor: This type is a little strange, so do you mind adding a scaladoc comment to explain what the two levels of nesting correspond to?

bufferSize: Int,
reuse_worker: Boolean,
rowBased: Boolean)
isUDF: Boolean,
Contributor: Similarly, do you mind adding scaladoc for these two new parameters?

numArgs: Seq[Int])
extends Logging {

// All the Python functions should have the same exec, version and envvars.
private val envVars = funcs.head.envVars
private val pythonExec = funcs.head.pythonExec
private val pythonVer = funcs.head.pythonVer
private val envVars = funcs.head.head.envVars
private val pythonExec = funcs.head.head.pythonExec
private val pythonVer = funcs.head.head.pythonVer

private val accumulator = funcs.head.accumulator // TODO: support accumulator in multiple UDF
private val accumulator = funcs.head.head.accumulator // TODO: support accumulator in multiple UDF

def compute(
inputIterator: Iterator[_],
@@ -232,8 +240,8 @@ private[spark] class PythonRunner(

@volatile private var _exception: Exception = null

private val pythonIncludes = funcs.flatMap(_.pythonIncludes.asScala).toSet
private val broadcastVars = funcs.flatMap(_.broadcastVars.asScala)
private val pythonIncludes = funcs.flatMap(_.flatMap(_.pythonIncludes.asScala)).toSet
private val broadcastVars = funcs.flatMap(_.flatMap(_.broadcastVars.asScala))

setDaemon(true)

@@ -284,11 +292,22 @@ private[spark] class PythonRunner(
}
dataOut.flush()
// Serialized command:
dataOut.writeInt(if (rowBased) 1 else 0)
dataOut.writeInt(funcs.length)
funcs.foreach { f =>
dataOut.writeInt(f.command.length)
dataOut.write(f.command)
if (isUDF) {
dataOut.writeInt(1)
dataOut.writeInt(funcs.length)
funcs.zip(numArgs).foreach { case (fs, numArg) =>
Contributor: Since correctness relies on funcs.length == numArgs.length, do you mind adding a require at the start of the constructor to enforce this?

dataOut.writeInt(numArg)
dataOut.writeInt(fs.length)
fs.foreach { f =>
dataOut.writeInt(f.command.length)
dataOut.write(f.command)
}
}
} else {
dataOut.writeInt(0)
val command = funcs.head.head.command
dataOut.writeInt(command.length)
dataOut.write(command)
}
// Data values
PythonRDD.writeIteratorToStream(inputIterator, dataOut)
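Note (illustrative, not part of the PR): to make the new wire format concrete, here is a minimal Python mirror of what the UDF branch above writes; the matching reader appears later in worker.py. The helper names and the udf_groups structure are assumptions of this sketch; the framing follows java.io.DataOutputStream conventions (big-endian 4-byte ints, length-prefixed command bytes).

import io
import struct


def write_int(out, value):
    # DataOutputStream.writeInt emits a big-endian 4-byte signed int
    out.write(struct.pack(">i", value))


def write_udf_section(out, udf_groups):
    """udf_groups: one (num_args, commands) pair per top-level UDF; the inner
    commands list holds the pickled functions of a chained UDF, innermost first."""
    write_int(out, 1)                    # isUDF flag
    write_int(out, len(udf_groups))      # number of UDFs
    for num_args, commands in udf_groups:
        write_int(out, num_args)         # arity of this UDF
        write_int(out, len(commands))    # chained functions composing it
        for cmd in commands:
            write_int(out, len(cmd))     # length-prefixed pickled command
            out.write(cmd)


# e.g. SELECT double(double(c0)), add(c1, c2): two UDFs with arities 1 and 2
buf = io.BytesIO()
write_udf_section(buf, [(1, [b"<double>", b"<double>"]), (2, [b"<add>"])])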
3 changes: 1 addition & 2 deletions python/pyspark/sql/functions.py
@@ -1649,8 +1649,7 @@ def sort_array(col, asc=True):
# ---------------------------- User Defined Function ----------------------------------

def _wrap_function(sc, func, returnType):
ser = AutoBatchedSerializer(PickleSerializer())
command = (func, returnType, ser)
command = (func, returnType)
pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command)
return sc._jvm.PythonFunction(bytearray(pickled_command), env, includes, sc.pythonExec,
sc.pythonVer, broadcast_vars, sc._javaAccumulator)
12 changes: 11 additions & 1 deletion python/pyspark/sql/tests.py
@@ -305,7 +305,7 @@ def test_udf2(self):
[res] = self.sqlCtx.sql("SELECT strlen(a) FROM test WHERE strlen(a) > 1").collect()
self.assertEqual(4, res[0])

def test_chained_python_udf(self):
def test_chained_udf(self):
self.sqlCtx.registerFunction("double", lambda x: x + x, IntegerType())
[row] = self.sqlCtx.sql("SELECT double(1)").collect()
self.assertEqual(row[0], 2)
@@ -314,6 +314,16 @@ def test_chained_python_udf(self):
[row] = self.sqlCtx.sql("SELECT double(double(1) + 1)").collect()
self.assertEqual(row[0], 6)

def test_multiple_udfs(self):
self.sqlCtx.registerFunction("double", lambda x: x * 2, IntegerType())
[row] = self.sqlCtx.sql("SELECT double(1), double(2)").collect()
self.assertEqual(tuple(row), (2, 4))
[row] = self.sqlCtx.sql("SELECT double(double(1)), double(double(2) + 2)").collect()
self.assertEqual(tuple(row), (4, 12))
self.sqlCtx.registerFunction("add", lambda x, y: x + y, IntegerType())
[row] = self.sqlCtx.sql("SELECT double(add(1, 2)), add(double(2), 1)").collect()
self.assertEqual(tuple(row), (6, 5))

def test_udf_with_array_type(self):
d = [Row(l=list(range(3)), d={"key": list(range(5))})]
rdd = self.sc.parallelize(d)
63 changes: 47 additions & 16 deletions python/pyspark/worker.py
@@ -29,7 +29,7 @@
from pyspark.broadcast import Broadcast, _broadcastRegistry
from pyspark.files import SparkFiles
from pyspark.serializers import write_with_length, write_int, read_long, \
write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer
write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, AutoBatchedSerializer
from pyspark import shuffle

pickleSer = PickleSerializer()
@@ -59,7 +59,49 @@ def read_command(serializer, file):

def chain(f, g):
"""chain two function together """
return lambda x: g(f(x))
return lambda *a: g(f(*a))
Contributor: Woah, didn't know that you could do varargs lambdas. Cool!



def wrap_udf(f, return_type):
return lambda *a: return_type.toInternal(f(*a))


def read_single_udf(pickleSer, infile):
num_arg = read_int(infile)
row_func = None
for i in range(read_int(infile)):
f, return_type = read_command(pickleSer, infile)
if row_func is None:
row_func = f
else:
row_func = chain(row_func, f)
# the last returnType will be the return type of UDF
return num_arg, wrap_udf(row_func, return_type)


def read_udfs(pickleSer, infile):
num_udfs = read_int(infile)
udfs = []
Contributor: It looks like udfs holds (something, something, udf) triples. Mind adding a line-comment here to say what the first two components of the tuple correspond to?

offset = 0
for i in range(num_udfs):
num_arg, udf = read_single_udf(pickleSer, infile)
udfs.append((offset, offset + num_arg, udf))
offset += num_arg

if num_udfs == 1:
udf = udfs[0][2]

# fast path for single UDF
def mapper(args):
Contributor: I bet you could even do mapper = udf if you wanted to.
Author: We can't: the input of mapper is a tuple, but udf does not take one.
Contributor: Ah, got it. Makes sense.

return udf(*args)
else:
def mapper(args):
return tuple(udf(*args[start:end]) for start, end, udf in udfs)

func = lambda _, it: map(mapper, it)
ser = AutoBatchedSerializer(PickleSerializer())
# profiling is not supported for UDF
return func, None, ser, ser


def main(infile, outfile):
@@ -107,21 +149,10 @@ def main(infile, outfile):
_broadcastRegistry.pop(bid)

_accumulatorRegistry.clear()
row_based = read_int(infile)
num_commands = read_int(infile)
if row_based:
profiler = None # profiling is not supported for UDF
row_func = None
for i in range(num_commands):
f, returnType, deserializer = read_command(pickleSer, infile)
if row_func is None:
row_func = f
else:
row_func = chain(row_func, f)
serializer = deserializer
func = lambda _, it: map(lambda x: returnType.toInternal(row_func(*x)), it)
is_udf = read_int(infile)
if is_udf:
Contributor: I'd maybe call this is_sql_udf just to make it clearer that this is part of PySpark SQL support, but I don't feel strongly about this.

func, profiler, deserializer, serializer = read_udfs(pickleSer, infile)
else:
assert num_commands == 1
func, profiler, deserializer, serializer = read_command(pickleSer, infile)

init_time = time.time()
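Note (illustrative, not part of the PR): to make the (start, end, udf) triples in read_udfs and the two mapper paths concrete, a small standalone sketch in plain Python, with lambdas standing in for deserialized, already-wrapped UDFs.

# Suppose the query is SELECT double(c0), add(c1, c2): two UDFs with arities 1 and 2.
# read_udfs builds one (start, end, udf) triple per UDF over the flattened argument
# tuple (c0, c1, c2) that the JVM sends for each input row.
udfs = [
    (0, 1, lambda x: x * 2),     # double -> consumes args[0:1]
    (1, 3, lambda x, y: x + y),  # add    -> consumes args[1:3]
]

def mapper(args):
    # multi-UDF path: one output slot per UDF
    return tuple(udf(*args[start:end]) for start, end, udf in udfs)

print(mapper((1, 2, 3)))   # (2, 5)

# With a single UDF there is nothing to slice, hence the fast path that simply
# unpacks the whole argument tuple:
double = udfs[0][2]
print(double(*(4,)))       # 8

# A chained call such as double(double(c0)) still occupies a single UDF slot: its two
# pickled functions are composed by read_single_udf via chain(f, g) == lambda *a: g(f(*a)).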
@@ -426,8 +426,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case logical.RepartitionByExpression(expressions, child, nPartitions) =>
exchange.ShuffleExchange(HashPartitioning(
expressions, nPartitions.getOrElse(numPartitions)), planLater(child)) :: Nil
case e @ python.EvaluatePython(udf, child, _) =>
python.BatchPythonEvaluation(udf, e.output, planLater(child)) :: Nil
case e @ python.EvaluatePython(udfs, child, _) =>
python.BatchPythonEvaluation(udfs, e.output, planLater(child)) :: Nil
case LogicalRDD(output, rdd) => PhysicalRDD(output, rdd, "ExistingRDD") :: Nil
case BroadcastHint(child) => planLater(child) :: Nil
case _ => Nil
@@ -40,7 +40,7 @@ import org.apache.spark.sql.types.{StructField, StructType}
* we drain the queue to find the original input row. Note that if the Python process is way too
* slow, this could lead to the queue growing unbounded and eventually run out of memory.
*/
case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child: SparkPlan)
case class BatchPythonEvaluation(udfs: Seq[PythonUDF], output: Seq[Attribute], child: SparkPlan)
extends SparkPlan {

def children: Seq[SparkPlan] = child :: Nil
@@ -69,11 +69,14 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child:
// combine input with output from Python.
val queue = new java.util.concurrent.ConcurrentLinkedQueue[InternalRow]()

val (pyFuncs, children) = collectFunctions(udf)
val (pyFuncs, children) = udfs.map(collectFunctions).unzip
val numArgs = children.map(_.length)

val pickle = new Pickler
val currentRow = newMutableProjection(children, child.output)()
val fields = children.map(_.dataType)
// flatten all the arguments
val allChildren = children.flatMap(x => x)
Contributor: Quick clarification: if I have a query like select udf(x), udf2(x), udf3(x), udf4(x) from ..., we'll send the x column's value four times to PySpark? I know we have a conceptually similar problem when evaluating multiple aggregates in parallel in JVM Spark SQL, but in that case I think we only project each column once and end up rebinding the references/offsets to reference the single copy.

My hunch is that this extra copy isn't a huge perf issue compared to the slow multiple-Python-UDF evaluation strategy we were using before, so I think it's fine to leave this for now. If it does become a problem, we could optimize later.

val currentRow = newMutableProjection(allChildren, child.output)()
val fields = allChildren.map(_.dataType)
val schema = new StructType(fields.map(t => new StructField("", t, true)).toArray)

// Input iterator to Python: input rows are grouped so we send them in batches to Python.
@@ -89,19 +92,30 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child:
val context = TaskContext.get()

// Output iterator for results from Python.
val outputIterator = new PythonRunner(pyFuncs, bufferSize, reuseWorker, true)
val outputIterator = new PythonRunner(pyFuncs, bufferSize, reuseWorker, true, numArgs)
.compute(inputIterator, context.partitionId(), context)

val unpickle = new Unpickler
val row = new GenericMutableRow(1)
val mutableRow = new GenericMutableRow(1)
val joined = new JoinedRow
val resultType = if (udfs.length == 1) {
udfs.head.dataType
} else {
StructType(udfs.map(u => StructField("", u.dataType, u.nullable)))
}
val resultProj = UnsafeProjection.create(output, output)

outputIterator.flatMap { pickedResult =>
val unpickledBatch = unpickle.loads(pickedResult)
unpickledBatch.asInstanceOf[java.util.ArrayList[Any]].asScala
}.map { result =>
row(0) = EvaluatePython.fromJava(result, udf.dataType)
val row = if (udfs.length == 1) {
Contributor: Rather than evaluating this if condition for every row, could we lift it out of the map and perform it once while building the RDD DAG? i.e. assign the result of line 108 to a variable and have the if be the last return value of this block?
Contributor: If you do this, you could reduce the scope of the mutableRow created up on line 99, too.
Author: Compared to the cost of evaluating the Python UDF itself, I think this does not matter; the JIT compiler can predict this branch pretty easily.
Contributor: Fair enough.

// fast path for single UDF
mutableRow(0) = EvaluatePython.fromJava(result, resultType)
mutableRow
} else {
EvaluatePython.fromJava(result, resultType).asInstanceOf[InternalRow]
}
resultProj(joined(queue.poll(), row))
}
}
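Note (hypothetical sketch, not part of the PR): the batch shapes that reach the unpickling loop above, assuming the worker-side mapper and AutoBatchedSerializer from worker.py; the variable names here are illustrative.

import pickle

single_udf_batch = [2, 4, 8]        # one scalar per input row: the fast path copies
                                    # each value into mutableRow(0) using the UDF's dataType
multi_udf_batch = [(2, 3), (4, 6)]  # one tuple per input row, one slot per UDF: each is
                                    # converted with the StructType built from the UDFs

# The JVM unpickles each payload into a java.util.ArrayList and flatMaps over it;
# round-tripping through pickle here only illustrates the payload shape.
assert pickle.loads(pickle.dumps(multi_udf_batch, protocol=2)) == multi_udf_batch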
@@ -36,24 +36,28 @@ import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

/**
* Evaluates a [[PythonUDF]], appending the result to the end of the input tuple.
* Evaluates a list of [[PythonUDF]], appending the result to the end of the input tuple.
*/
case class EvaluatePython(
udf: PythonUDF,
udfs: Seq[PythonUDF],
child: LogicalPlan,
resultAttribute: AttributeReference)
resultAttribute: Seq[AttributeReference])
extends logical.UnaryNode {

def output: Seq[Attribute] = child.output :+ resultAttribute
def output: Seq[Attribute] = child.output ++ resultAttribute

// References should not include the produced attribute.
override def references: AttributeSet = udf.references
override def references: AttributeSet = AttributeSet(udfs.flatMap(_.references))
}


object EvaluatePython {
def apply(udf: PythonUDF, child: LogicalPlan): EvaluatePython =
new EvaluatePython(udf, child, AttributeReference("pythonUDF", udf.dataType)())
def apply(udfs: Seq[PythonUDF], child: LogicalPlan): EvaluatePython = {
val resultAttrs = udfs.zipWithIndex.map { case (u, i) =>
AttributeReference(s"pythonUDF$i", u.dataType)()
}
new EvaluatePython(udfs, child, resultAttrs)
}

def takeAndServe(df: DataFrame, n: Int): Int = {
registerPicklers()