diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
index fcf68467460b..b44b13c8de0d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
@@ -30,6 +30,7 @@ import org.apache.spark.sql.execution.arrow.ArrowWriter
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.util.ArrowUtils
+import org.apache.spark.sql.vectorized.ColumnarBatch
 import org.apache.spark.util.Utils
 
 /**
@@ -42,8 +43,8 @@ class ArrowPythonRunner(
     schema: StructType,
     timeZoneId: String,
     conf: Map[String, String])
-  extends BaseArrowPythonRunner[Iterator[InternalRow]](
-    funcs, evalType, argOffsets) {
+  extends BasePythonRunner[Iterator[InternalRow], ColumnarBatch](funcs, evalType, argOffsets)
+  with PythonArrowOutput {
 
   override val bufferSize: Int = SQLConf.get.pandasUDFBufferSize
   require(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CogroupedArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala
similarity index 82%
rename from sql/core/src/main/scala/org/apache/spark/sql/execution/python/CogroupedArrowPythonRunner.scala
rename to sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala
index 8ea9881c575a..25ce16db264a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CogroupedArrowPythonRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala
@@ -17,27 +17,27 @@
 
 package org.apache.spark.sql.execution.python
 
-import java.io._
-import java.net._
+import java.io.DataOutputStream
+import java.net.Socket
 
 import org.apache.arrow.vector.VectorSchemaRoot
 import org.apache.arrow.vector.ipc.ArrowStreamWriter
 
-import org.apache.spark._
-import org.apache.spark.api.python._
+import org.apache.spark.{SparkEnv, TaskContext}
+import org.apache.spark.api.python.{BasePythonRunner, ChainedPythonFunctions, PythonRDD}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.execution.arrow.ArrowWriter
-import org.apache.spark.sql.types._
+import org.apache.spark.sql.types.StructType
 import org.apache.spark.sql.util.ArrowUtils
+import org.apache.spark.sql.vectorized.ColumnarBatch
 import org.apache.spark.util.Utils
 
 
 /**
- * Python UDF Runner for cogrouped udfs. Although the data is exchanged with the python
- * worker via arrow, we cannot use `ArrowPythonRunner` as we need to send more than one
- * dataframe.
+ * Python UDF Runner for cogrouped udfs. It sends Arrow batches from two different DataFrames,
+ * groups them in Python, and receives them back in the JVM as batches of a single DataFrame.
 */
-class CogroupedArrowPythonRunner(
+class CoGroupedArrowPythonRunner(
     funcs: Seq[ChainedPythonFunctions],
     evalType: Int,
     argOffsets: Array[Array[Int]],
@@ -45,8 +45,9 @@ class CogroupedArrowPythonRunner(
     rightSchema: StructType,
     timeZoneId: String,
     conf: Map[String, String])
-  extends BaseArrowPythonRunner[(Iterator[InternalRow], Iterator[InternalRow])](
-    funcs, evalType, argOffsets) {
+  extends BasePythonRunner[
+    (Iterator[InternalRow], Iterator[InternalRow]), ColumnarBatch](funcs, evalType, argOffsets)
+  with PythonArrowOutput {
 
   protected def newWriterThread(
       env: SparkEnv,
@@ -81,11 +82,11 @@ class CogroupedArrowPythonRunner(
         dataOut.writeInt(0)
       }
 
-      def writeGroup(
+      private def writeGroup(
           group: Iterator[InternalRow],
           schema: StructType,
           dataOut: DataOutputStream,
-          name: String) = {
+          name: String): Unit = {
         val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId)
         val allocator = ArrowUtils.rootAllocator.newChildAllocator(
           s"stdout writer for $pythonExec ($name)", 0, Long.MaxValue)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala
index cc83e0cecdc3..b079405bdc2f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala
@@ -17,13 +17,15 @@
 
 package org.apache.spark.sql.execution.python
 
-import org.apache.spark.api.python.PythonEvalType
+import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning}
 import org.apache.spark.sql.execution.{BinaryExecNode, CoGroupedIterator, SparkPlan}
+import org.apache.spark.sql.execution.python.PandasGroupUtils._
 import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.util.ArrowUtils
 
 
 /**
@@ -52,7 +54,14 @@ case class FlatMapCoGroupsInPandasExec(
     output: Seq[Attribute],
     left: SparkPlan,
     right: SparkPlan)
-  extends BasePandasGroupExec(func, output) with BinaryExecNode {
+  extends SparkPlan with BinaryExecNode {
+
+  private val sessionLocalTimeZone = conf.sessionLocalTimeZone
+  private val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf)
+  private val pandasFunction = func.asInstanceOf[PythonUDF].func
+  private val chainedFunc = Seq(ChainedPythonFunctions(Seq(pandasFunction)))
+
+  override def producedAttributes: AttributeSet = AttributeSet(output)
 
   override def outputPartitioning: Partitioning = left.outputPartitioning
 
@@ -81,7 +90,7 @@ case class FlatMapCoGroupsInPandasExec(
         val data = new CoGroupedIterator(leftGrouped, rightGrouped, leftGroup)
           .map { case (_, l, r) => (l, r) }
 
-        val runner = new CogroupedArrowPythonRunner(
+        val runner = new CoGroupedArrowPythonRunner(
           chainedFunc,
           PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF,
           Array(leftArgOffsets ++ rightArgOffsets),
@@ -90,7 +99,7 @@ case class FlatMapCoGroupsInPandasExec(
           sessionLocalTimeZone,
           pythonRunnerConf)
 
-        executePython(data, runner)
+        executePython(data, output, runner)
       }
     }
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala
index 22a0d1e09b12..5032bc81327b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala
@@ -17,13 +17,15 @@
 
 package org.apache.spark.sql.execution.python
 
-import org.apache.spark.api.python.PythonEvalType
+import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning}
 import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
+import org.apache.spark.sql.execution.python.PandasGroupUtils._
 import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.util.ArrowUtils
 
 
 /**
@@ -48,7 +50,14 @@ case class FlatMapGroupsInPandasExec(
     func: Expression,
     output: Seq[Attribute],
     child: SparkPlan)
-  extends BasePandasGroupExec(func, output) with UnaryExecNode {
+  extends SparkPlan with UnaryExecNode {
+
+  private val sessionLocalTimeZone = conf.sessionLocalTimeZone
+  private val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf)
+  private val pandasFunction = func.asInstanceOf[PythonUDF].func
+  private val chainedFunc = Seq(ChainedPythonFunctions(Seq(pandasFunction)))
+
+  override def producedAttributes: AttributeSet = AttributeSet(output)
 
   override def outputPartitioning: Partitioning = child.outputPartitioning
 
@@ -72,7 +81,7 @@ case class FlatMapGroupsInPandasExec(
     inputRDD.mapPartitionsInternal { iter =>
       if (iter.isEmpty) iter else {
         val data = groupAndProject(iter, groupingAttributes, child.output, dedupAttributes)
-          .map{case(_, x) => x}
+          .map { case (_, x) => x }
 
         val runner = new ArrowPythonRunner(
           chainedFunc,
@@ -82,7 +91,7 @@ case class FlatMapGroupsInPandasExec(
           sessionLocalTimeZone,
           pythonRunnerConf)
 
-        executePython(data, runner)
+        executePython(data, output, runner)
       }}
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BasePandasGroupExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PandasGroupUtils.scala
similarity index 85%
rename from sql/core/src/main/scala/org/apache/spark/sql/execution/python/BasePandasGroupExec.scala
rename to sql/core/src/main/scala/org/apache/spark/sql/execution/python/PandasGroupUtils.scala
index 477c288ad121..68ce991a8ae7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BasePandasGroupExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PandasGroupUtils.scala
@@ -21,37 +21,23 @@ import scala.collection.JavaConverters._
 import scala.collection.mutable.ArrayBuffer
 
 import org.apache.spark.TaskContext
-import org.apache.spark.api.python.{BasePythonRunner, ChainedPythonFunctions}
+import org.apache.spark.api.python.BasePythonRunner
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression, PythonUDF, UnsafeProjection}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeProjection}
 import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan}
-import org.apache.spark.sql.util.ArrowUtils
 import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch}
 
 /**
  * Base functionality for plans which execute grouped python udfs.
 */
-abstract class BasePandasGroupExec(
-    func: Expression,
-    output: Seq[Attribute])
-  extends SparkPlan {
-
-  protected val sessionLocalTimeZone = conf.sessionLocalTimeZone
-
-  protected val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf)
-
-  protected val pandasFunction = func.asInstanceOf[PythonUDF].func
-
-  protected val chainedFunc = Seq(ChainedPythonFunctions(Seq(pandasFunction)))
-
-  override def producedAttributes: AttributeSet = AttributeSet(output)
-
+private[python] object PandasGroupUtils {
   /**
    * passes the data to the python runner and coverts the resulting
    * columnarbatch into internal rows.
    */
-  protected def executePython[T](
+  def executePython[T](
       data: Iterator[T],
+      output: Seq[Attribute],
       runner: BasePythonRunner[T, ColumnarBatch]): Iterator[InternalRow] = {
     val context = TaskContext.get()
@@ -71,7 +57,7 @@ abstract class BasePandasGroupExec(
   /**
    * groups according to grouping attributes and then projects into the deduplicated schema
    */
-  protected def groupAndProject(
+  def groupAndProject(
      input: Iterator[InternalRow],
      groupingAttributes: Seq[Attribute],
      inputSchema: Seq[Attribute],
@@ -101,7 +87,7 @@ abstract class BasePandasGroupExec(
    *
    * argOffsets[argOffsets[0]+2 .. ] is the arg offsets for data attributes
    */
-  protected def resolveArgOffsets(
+  def resolveArgOffsets(
      child: SparkPlan,
      groupingAttributes: Seq[Attribute]): (Seq[Attribute], Array[Int]) = {
    val dataAttributes = child.output.drop(groupingAttributes.length)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BaseArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala
similarity index 88%
rename from sql/core/src/main/scala/org/apache/spark/sql/execution/python/BaseArrowPythonRunner.scala
rename to sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala
index 0cee7d2f96c2..bb353062384a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BaseArrowPythonRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala
@@ -16,8 +16,8 @@
  */
 package org.apache.spark.sql.execution.python
 
-import java.io._
-import java.net._
+import java.io.DataInputStream
+import java.net.Socket
 import java.util.concurrent.atomic.AtomicBoolean
 
 import scala.collection.JavaConverters._
@@ -25,22 +25,19 @@ import scala.collection.JavaConverters._
 import org.apache.arrow.vector.VectorSchemaRoot
 import org.apache.arrow.vector.ipc.ArrowStreamReader
 
-import org.apache.spark._
-import org.apache.spark.api.python._
+import org.apache.spark.{SparkEnv, TaskContext}
+import org.apache.spark.api.python.{BasePythonRunner, SpecialLengths}
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.sql.util.ArrowUtils
 import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector}
 
 /**
- * Common functionality for a udf runner that exchanges data with Python worker via Arrow stream.
+ * A trait that can be mixed in with [[BasePythonRunner]]. It implements the deserialization
+ * logic from Python (Arrow) to the JVM (ColumnarBatch).
 */
-abstract class BaseArrowPythonRunner[T](
-    funcs: Seq[ChainedPythonFunctions],
-    evalType: Int,
-    argOffsets: Array[Array[Int]])
-  extends BasePythonRunner[T, ColumnarBatch](funcs, evalType, argOffsets) {
+private[python] trait PythonArrowOutput { self: BasePythonRunner[_, ColumnarBatch] =>
 
-  protected override def newReaderIterator(
+  protected def newReaderIterator(
       stream: DataInputStream,
       writerThread: WriterThread,
       startTime: Long,