@@ -23,7 +23,7 @@ import scala.collection.mutable.ArrayBuffer
 import net.razorvine.pickle.{Pickler, Unpickler}
 
 import org.apache.spark.TaskContext
-import org.apache.spark.api.python.{ChainedPythonFunctions, PythonFunction, PythonRunner}
+import org.apache.spark.api.python.{ChainedPythonFunctions, PythonRunner}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
@@ -72,8 +72,6 @@ case class BatchPythonEvaluation(udfs: Seq[PythonUDF], output: Seq[Attribute], c
 
       val (pyFuncs, inputs) = udfs.map(collectFunctions).unzip
 
-      // Most of the inputs are primitives, do not use memo for better performance
-      val pickle = new Pickler(false)
       // flatten all the arguments
       val allInputs = new ArrayBuffer[Expression]
       val dataTypes = new ArrayBuffer[DataType]
@@ -89,21 +87,30 @@
         }.toArray
       }.toArray
       val projection = newMutableProjection(allInputs, child.output)()
+      val schema = StructType(dataTypes.map(dt => StructField("", dt)))
+      val needConversion = dataTypes.exists(EvaluatePython.needConversionInPython)
 
+      // enable memo iff we serialize the row with schema (schema and class should be memorized)
+      val pickle = new Pickler(needConversion)
       // Input iterator to Python: input rows are grouped so we send them in batches to Python.
       // For each row, add it to the queue.
       val inputIterator = iter.grouped(100).map { inputRows =>
         val toBePickled = inputRows.map { inputRow =>
           queue.add(inputRow)
           val row = projection(inputRow)
-          val fields = new Array[Any](row.numFields)
-          var i = 0
-          while (i < row.numFields) {
-            val dt = dataTypes(i)
-            fields(i) = EvaluatePython.toJava(row.get(i, dt), dt)
-            i += 1
+          if (needConversion) {
+            EvaluatePython.toJava(row, schema)
+          } else {
+            // fast path for these types that does not need conversion in Python
+            val fields = new Array[Any](row.numFields)
+            var i = 0
+            while (i < row.numFields) {
+              val dt = dataTypes(i)
+              fields(i) = EvaluatePython.toJava(row.get(i, dt), dt)
+              i += 1
+            }
+            fields
           }
-          fields
         }.toArray
         pickle.dumps(toBePickled)
       }
0 commit comments