diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala index e0e35dfbab96..20c831465e76 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala @@ -47,6 +47,7 @@ private[spark] object PythonEvalType { val SQL_GROUPED_AGG_PANDAS_UDF = 202 val SQL_WINDOW_AGG_PANDAS_UDF = 203 val SQL_SCALAR_PANDAS_ITER_UDF = 204 + val SQL_COGROUPED_MAP_PANDAS_UDF = 205 def toString(pythonEvalType: Int): String = pythonEvalType match { case NON_UDF => "NON_UDF" @@ -56,6 +57,7 @@ private[spark] object PythonEvalType { case SQL_GROUPED_AGG_PANDAS_UDF => "SQL_GROUPED_AGG_PANDAS_UDF" case SQL_WINDOW_AGG_PANDAS_UDF => "SQL_WINDOW_AGG_PANDAS_UDF" case SQL_SCALAR_PANDAS_ITER_UDF => "SQL_SCALAR_PANDAS_ITER_UDF" + case SQL_COGROUPED_MAP_PANDAS_UDF => "SQL_COGROUPED_MAP_PANDAS_UDF" } } diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 395abc841827..dc285fac1c5a 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -74,6 +74,7 @@ class PythonEvalType(object): SQL_GROUPED_AGG_PANDAS_UDF = 202 SQL_WINDOW_AGG_PANDAS_UDF = 203 SQL_SCALAR_PANDAS_ITER_UDF = 204 + SQL_COGROUPED_MAP_PANDAS_UDF = 205 def portable_hash(x): diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index ddca2a71828a..d8925b809173 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -356,6 +356,27 @@ def __repr__(self): return "ArrowStreamPandasSerializer" +class InterleavedArrowReader(object): + + def __init__(self, stream): + import pyarrow as pa + self._schema1 = pa.read_schema(stream) + self._schema2 = pa.read_schema(stream) + self._reader = pa.MessageReader.open_stream(stream) + + def __iter__(self): + return self + + def __next__(self): + import pyarrow as pa + batch1 = pa.read_record_batch(self._reader.read_next_message(), self._schema1) + batch2 = pa.read_record_batch(self._reader.read_next_message(), self._schema2) + return batch1, batch2 + + def next(self): + return self.__next__() + + class ArrowStreamPandasUDFSerializer(ArrowStreamPandasSerializer): """ Serializer used by Python worker to evaluate Pandas UDFs @@ -401,6 +422,22 @@ def __repr__(self): return "ArrowStreamPandasUDFSerializer" +class InterleavedArrowStreamPandasSerializer(ArrowStreamPandasUDFSerializer): + + def __init__(self, timezone, safecheck, assign_cols_by_name): + super(InterleavedArrowStreamPandasSerializer, self).__init__(timezone, safecheck, assign_cols_by_name) + + def load_stream(self, stream): + """ + Deserialize ArrowRecordBatches to an Arrow table and return as a list of pandas.Series. + """ + import pyarrow as pa + reader = InterleavedArrowReader(pa.input_stream(stream)) + for batch1, batch2 in reader: + yield ( [self.arrow_to_pandas(c) for c in pa.Table.from_batches([batch1]).itercolumns()], + [self.arrow_to_pandas(c) for c in pa.Table.from_batches([batch2]).itercolumns()]) + + class BatchedSerializer(Serializer): """ diff --git a/python/pyspark/sql/cogroup.py b/python/pyspark/sql/cogroup.py new file mode 100644 index 000000000000..b758d1b7d8c4 --- /dev/null +++ b/python/pyspark/sql/cogroup.py @@ -0,0 +1,38 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. 
+# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from pyspark.sql.dataframe import DataFrame + + +class CoGroupedData(object): + + def __init__(self, gd1, gd2): + self._gd1 = gd1 + self._gd2 = gd2 + self.sql_ctx = gd1.sql_ctx + + def apply(self, udf): + all_cols = self._extract_cols(self._gd1) + self._extract_cols(self._gd2) + udf_column = udf(*all_cols) + jdf = self._gd1._jgd.flatMapCoGroupsInPandas(self._gd2._jgd, udf_column._jc.expr()) + return DataFrame(jdf, self.sql_ctx) + + @staticmethod + def _extract_cols(gd): + df = gd._df + return [df[col] for col in df.columns] + diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 34f65930553c..802bf714cd56 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2800,6 +2800,8 @@ class PandasUDFType(object): GROUPED_MAP = PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF + COGROUPED_MAP = PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF + GROUPED_AGG = PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF @@ -3178,6 +3180,7 @@ def pandas_udf(f=None, returnType=None, functionType=None): if eval_type not in [PythonEvalType.SQL_SCALAR_PANDAS_UDF, PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, + PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF, PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF]: raise ValueError("Invalid functionType: " "functionType must be one the values from PandasUDFType") diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py index cc1da8e7c1f7..04f42b159837 100644 --- a/python/pyspark/sql/group.py +++ b/python/pyspark/sql/group.py @@ -22,6 +22,7 @@ from pyspark.sql.column import Column, _to_seq from pyspark.sql.dataframe import DataFrame from pyspark.sql.types import * +from pyspark.sql.cogroup import CoGroupedData __all__ = ["GroupedData"] @@ -220,6 +221,9 @@ def pivot(self, pivot_col, values=None): jgd = self._jgd.pivot(pivot_col, values) return GroupedData(jgd, self._df) + def cogroup(self, other): + return CoGroupedData(self, other) + @since(2.3) def apply(self, udf): """ diff --git a/python/pyspark/sql/tests/test_pandas_udf_cogrouped_map.py b/python/pyspark/sql/tests/test_pandas_udf_cogrouped_map.py new file mode 100644 index 000000000000..9104aa2193da --- /dev/null +++ b/python/pyspark/sql/tests/test_pandas_udf_cogrouped_map.py @@ -0,0 +1,101 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# + +import datetime +import unittest +import sys + +from collections import OrderedDict +from decimal import Decimal + +from pyspark.sql import Row +from pyspark.sql.functions import array, explode, col, lit, udf, sum, pandas_udf, PandasUDFType +from pyspark.sql.types import * +from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas, have_pyarrow, \ + pandas_requirement_message, pyarrow_requirement_message +from pyspark.testing.utils import QuietTest + +if have_pandas: + import pandas as pd + from pandas.util.testing import assert_frame_equal + +if have_pyarrow: + import pyarrow as pa + + +""" +Tests below use pd.DataFrame.assign that will infer mixed types (unicode/str) for column names +from kwargs w/ Python 2, so need to set check_column_type=False and avoid this check +""" +if sys.version < '3': + _check_column_type = False +else: + _check_column_type = True + + +@unittest.skipIf( + not have_pandas or not have_pyarrow, + pandas_requirement_message or pyarrow_requirement_message) +class CoGroupedMapPandasUDFTests(ReusedSQLTestCase): + + @property + def data1(self): + return self.spark.range(10).toDF('id') \ + .withColumn("ks", array([lit(i) for i in range(20, 30)])) \ + .withColumn("k", explode(col('ks')))\ + .withColumn("v", col('k') * 10)\ + .drop('ks') + + @property + def data2(self): + return self.spark.range(10).toDF('id') \ + .withColumn("ks", array([lit(i) for i in range(20, 30)])) \ + .withColumn("k", explode(col('ks'))) \ + .withColumn("v2", col('k') * 100) \ + .drop('ks') + + def test_simple(self): + import pandas as pd + + l = self.data1 + r = self.data2 + + @pandas_udf('id long, k int, v int, v2 int', PandasUDFType.COGROUPED_MAP) + def merge_pandas(left, right): + return pd.merge(left, right, how='outer', on=['k', 'id']) + + result = l\ + .groupby('id')\ + .cogroup(r.groupby(r.id))\ + .apply(merge_pandas)\ + .sort(['id', 'k'])\ + .toPandas() + + expected = pd\ + .merge(l.toPandas(), r.toPandas(), how='outer', on=['k', 'id']) + + assert_frame_equal(expected, result, check_column_type=_check_column_type) + +if __name__ == "__main__": + from pyspark.sql.tests.test_pandas_udf_cogrouped_map import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2) + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index ee46bb649d1f..e42f4ac698c8 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -38,7 +38,7 @@ from pyspark.rdd import PythonEvalType from pyspark.serializers import write_with_length, write_int, read_long, read_bool, \ write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, \ - BatchedSerializer, ArrowStreamPandasUDFSerializer + BatchedSerializer, ArrowStreamPandasUDFSerializer, InterleavedArrowStreamPandasSerializer from pyspark.sql.types import to_arrow_type, StructType from pyspark.util import _get_argspec, fail_on_stopiteration from pyspark import shuffle @@ -111,8 +111,25 @@ def verify_result_length(result, length): map(verify_result_type, f(*iterator))) -def wrap_grouped_map_pandas_udf(f, return_type, argspec): +def wrap_cogrouped_map_pandas_udf(f, return_type): + def wrapped(left, right): + import pandas as pd + result = f(pd.concat(left, axis=1), pd.concat(right, axis=1)) + if not isinstance(result, pd.DataFrame): + raise TypeError("Return type of the 
user-defined function should be " + "pandas.DataFrame, but is {}".format(type(result))) + if not len(result.columns) == len(return_type): + raise RuntimeError( + "Number of columns of the returned pandas.DataFrame " + "doesn't match specified schema. " + "Expected: {} Actual: {}".format(len(return_type), len(result.columns))) + return result + + return lambda v: [(wrapped(v[0], v[1]), to_arrow_type(return_type))] + + +def wrap_grouped_map_pandas_udf(f, return_type, argspec): def wrapped(key_series, value_series): import pandas as pd @@ -232,6 +249,8 @@ def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index): elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF: argspec = _get_argspec(chained_func) # signature was lost when wrapping it return arg_offsets, wrap_grouped_map_pandas_udf(func, return_type, argspec) + elif eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF: + return arg_offsets, wrap_cogrouped_map_pandas_udf(func, return_type) elif eval_type == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF: return arg_offsets, wrap_grouped_agg_pandas_udf(func, return_type) elif eval_type == PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF: @@ -246,6 +265,7 @@ def read_udfs(pickleSer, infile, eval_type): runner_conf = {} if eval_type in (PythonEvalType.SQL_SCALAR_PANDAS_UDF, + PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF, PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF, @@ -269,10 +289,13 @@ def read_udfs(pickleSer, infile, eval_type): # Scalar Pandas UDF handles struct type arguments as pandas DataFrames instead of # pandas Series. See SPARK-27240. - df_for_struct = (eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF or + if eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF: + ser = InterleavedArrowStreamPandasSerializer(timezone, safecheck, assign_cols_by_name) + else: + df_for_struct = (eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF or eval_type == PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF) - ser = ArrowStreamPandasUDFSerializer(timezone, safecheck, assign_cols_by_name, - df_for_struct) + ser = ArrowStreamPandasUDFSerializer(timezone, safecheck, assign_cols_by_name, + df_for_struct) else: ser = BatchedSerializer(PickleSerializer(), 100) @@ -343,6 +366,14 @@ def map_batch(batch): arg0 = ["a[%d]" % o for o in arg_offsets[1: split_offset]] arg1 = ["a[%d]" % o for o in arg_offsets[split_offset:]] mapper_str = "lambda a: f([%s], [%s])" % (", ".join(arg0), ", ".join(arg1)) + elif eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF: + # We assume there is only one UDF here because cogrouped map doesn't + # support combining multiple UDFs. 
+ assert num_udfs == 1 + arg_offsets, udf = read_single_udf( + pickleSer, infile, eval_type, runner_conf, udf_index=0) + udfs['f'] = udf + mapper_str = "lambda a: f(a)" else: # Create function like this: # lambda a: (f0(a[0]), f1(a[1], a[2]), f2(a[3])) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 60517f11a249..9dad538686b3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -970,6 +970,10 @@ class Analyzer( // To resolve duplicate expression IDs for Join and Intersect case j @ Join(left, right, _, _, _) if !j.duplicateResolved => j.copy(right = dedupRight(left, right)) + case f @ FlatMapCoGroupsInPandas(leftAttributes, rightAttributes, _, _, left, right) => + val leftAttributes2 = leftAttributes.map(x => resolveExpressionBottomUp(x, left).asInstanceOf[Attribute]) + val rightAttributes2 = rightAttributes.map(x => resolveExpressionBottomUp(x, right).asInstanceOf[Attribute]) + f.copy(leftAttributes=leftAttributes2, rightAttributes=rightAttributes2) case i @ Intersect(left, right, _) if !i.duplicateResolved => i.copy(right = dedupRight(left, right)) case e @ Except(left, right, _) if !e.duplicateResolved => @@ -2269,6 +2273,7 @@ class Analyzer( } } + /** * Removes natural or using joins by calculating output columns based on output from two sides, * Then apply a Project on a normal Join to eliminate natural or using join. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala index 9f67700f1573..16b8f57038e9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala @@ -39,6 +39,18 @@ case class FlatMapGroupsInPandas( override val producedAttributes = AttributeSet(output) } + +case class FlatMapCoGroupsInPandas( + leftAttributes: Seq[Attribute], + rightAttributes: Seq[Attribute], + functionExpr: Expression, + output: Seq[Attribute], + left: LogicalPlan, + right: LogicalPlan) extends BinaryNode { + override val producedAttributes = AttributeSet(output) +} + + trait BaseEvalPython extends UnaryNode { def udfs: Seq[PythonUDF] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index e85636d82a62..0018f6379e8f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -47,8 +47,8 @@ import org.apache.spark.sql.types.{NumericType, StructType} */ @Stable class RelationalGroupedDataset protected[sql]( - df: DataFrame, - groupingExprs: Seq[Expression], + val df: DataFrame, + val groupingExprs: Seq[Expression], groupType: RelationalGroupedDataset.GroupType) { private[this] def toDF(aggExprs: Seq[Expression]): DataFrame = { @@ -523,6 +523,37 @@ class RelationalGroupedDataset protected[sql]( Dataset.ofRows(df.sparkSession, plan) } + private[sql] def flatMapCoGroupsInPandas + (r: RelationalGroupedDataset, expr: PythonUDF): DataFrame = { + require(expr.evalType == 
PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF, + "Must pass a cogrouped map udf") + require(expr.dataType.isInstanceOf[StructType], + s"The returnType of the udf must be a ${StructType.simpleString}") + + val leftGroupingNamedExpressions = groupingExprs.map { + case ne: NamedExpression => ne + case other => Alias(other, other.toString)() + } + + val rightGroupingNamedExpressions = r.groupingExprs.map { + case ne: NamedExpression => ne + case other => Alias(other, other.toString)() + } + + val leftAttributes = leftGroupingNamedExpressions.map(_.toAttribute) + val rightAttributes = rightGroupingNamedExpressions.map(_.toAttribute) + + val leftChild = df.logicalPlan + val rightChild = r.df.logicalPlan + + val left = Project(leftGroupingNamedExpressions ++ leftChild.output, leftChild) + val right = Project(rightGroupingNamedExpressions ++ rightChild.output, rightChild) + + val output = expr.dataType.asInstanceOf[StructType].toAttributes + val plan = FlatMapCoGroupsInPandas(leftAttributes, rightAttributes, expr, output, left, right) + Dataset.ofRows(df.sparkSession, plan) + } + override def toString: String = { val builder = new StringBuilder builder.append("RelationalGroupedDataset: [grouping expressions: [") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 0c4775eb769b..56cd346015e6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -682,6 +682,9 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { f, p, b, is, ot, planLater(child)) :: Nil case logical.FlatMapGroupsInPandas(grouping, func, output, child) => execution.python.FlatMapGroupsInPandasExec(grouping, func, output, planLater(child)) :: Nil + case logical.FlatMapCoGroupsInPandas(leftGroup, rightGroup, func, output, left, right) => + execution.python.FlatMapCoGroupsInPandasExec( + leftGroup, rightGroup, func, output, planLater(left), planLater(right)) :: Nil case logical.MapElements(f, _, _, objAttr, child) => execution.MapElementsExec(f, objAttr, planLater(child)) :: Nil case logical.AppendColumns(f, _, _, in, out, child) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AbstractPandasGroupExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AbstractPandasGroupExec.scala new file mode 100644 index 000000000000..0305f7955778 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AbstractPandasGroupExec.scala @@ -0,0 +1,128 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.execution.python + +import org.apache.spark.TaskContext +import org.apache.spark.api.python.{BasePythonRunner, ChainedPythonFunctions} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, PythonUDF, UnsafeProjection} +import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan} +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.ArrowUtils +import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch} + +import scala.collection.mutable.ArrayBuffer +import scala.collection.JavaConverters._ + +trait AbstractPandasGroupExec extends SparkPlan { + + protected val sessionLocalTimeZone = conf.sessionLocalTimeZone + + protected val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf) + + protected def chainedFunc = Seq( + ChainedPythonFunctions(Seq(func.asInstanceOf[PythonUDF].func))) + + def output: Seq[Attribute] + + def func: Expression + + protected def executePython[T](data: Iterator[T], + runner: BasePythonRunner[T, ColumnarBatch]): Iterator[InternalRow] = { + + val context = TaskContext.get() + val columnarBatchIter = runner.compute(data, context.partitionId(), context) + val unsafeProj = UnsafeProjection.create(output, output) + + columnarBatchIter.flatMap { batch => + // UDF returns a StructType column in ColumnarBatch, select the children here + val structVector = batch.column(0).asInstanceOf[ArrowColumnVector] + val outputVectors = output.indices.map(structVector.getChild) + val flattenedBatch = new ColumnarBatch(outputVectors.toArray) + flattenedBatch.setNumRows(batch.numRows()) + flattenedBatch.rowIterator.asScala + }.map(unsafeProj) + + } + + protected def groupAndDedup( + input: Iterator[InternalRow], groupingAttributes: Seq[Attribute], + inputSchema: Seq[Attribute], dedupSchema: Seq[Attribute]): Iterator[Iterator[InternalRow]] = { + if (groupingAttributes.isEmpty) { + Iterator(input) + } else { + val groupedIter = GroupedIterator(input, groupingAttributes, inputSchema) + val dedupProj = UnsafeProjection.create(dedupSchema, inputSchema) + groupedIter.map { + case (_, groupedRowIter) => groupedRowIter.map(dedupProj) + } + } + } + + protected def createSchema(child: SparkPlan, groupingAttributes: Seq[Attribute]) + : (StructType, Seq[Attribute], Array[Array[Int]]) = { + + // Deduplicate the grouping attributes. + // If a grouping attribute also appears in data attributes, then we don't need to send the + // grouping attribute to Python worker. If a grouping attribute is not in data attributes, + // then we need to send this grouping attribute to python worker. + // + // We use argOffsets to distinguish grouping attributes and data attributes as following: + // + // argOffsets[0] is the length of grouping attributes + // argOffsets[1 .. argOffsets[0]+1] is the arg offsets for grouping attributes + // argOffsets[argOffsets[0]+1 .. ] is the arg offsets for data attributes + + val dataAttributes = child.output.drop(groupingAttributes.length) + val groupingIndicesInData = groupingAttributes.map { attribute => + dataAttributes.indexWhere(attribute.semanticEquals) + } + + val groupingArgOffsets = new ArrayBuffer[Int] + val nonDupGroupingAttributes = new ArrayBuffer[Attribute] + val nonDupGroupingSize = groupingIndicesInData.count(_ == -1) + + // Non duplicate grouping attributes are added to nonDupGroupingAttributes and + // their offsets are 0, 1, 2 ... 
+ // Duplicate grouping attributes are NOT added to nonDupGroupingAttributes and + // their offsets are n + index, where n is the total number of non duplicate grouping + // attributes and index is the index in the data attributes that the grouping attribute + // is a duplicate of. + + groupingAttributes.zip(groupingIndicesInData).foreach { + case (attribute, index) => + if (index == -1) { + groupingArgOffsets += nonDupGroupingAttributes.length + nonDupGroupingAttributes += attribute + } else { + groupingArgOffsets += index + nonDupGroupingSize + } + } + + val dataArgOffsets = nonDupGroupingAttributes.length until + (nonDupGroupingAttributes.length + dataAttributes.length) + + val argOffsets = Array(Array(groupingAttributes.length) ++ groupingArgOffsets ++ dataArgOffsets) + + // Attributes after deduplication + val dedupAttributes = nonDupGroupingAttributes ++ dataAttributes + val dedupSchema = StructType.fromAttributes(dedupAttributes) + (dedupSchema, dedupAttributes, argOffsets) + } + +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala index 5101f7e871af..2ff2517a90da 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala @@ -46,7 +46,7 @@ class ArrowPythonRunner( schema: StructType, timeZoneId: String, conf: Map[String, String]) - extends BasePythonRunner[Iterator[InternalRow], ColumnarBatch]( + extends BaseArrowPythonRunner[Iterator[InternalRow]]( funcs, evalType, argOffsets) { override val bufferSize: Int = SQLConf.get.pandasUDFBufferSize @@ -119,72 +119,4 @@ class ArrowPythonRunner( } } - protected override def newReaderIterator( - stream: DataInputStream, - writerThread: WriterThread, - startTime: Long, - env: SparkEnv, - worker: Socket, - releasedOrClosed: AtomicBoolean, - context: TaskContext): Iterator[ColumnarBatch] = { - new ReaderIterator(stream, writerThread, startTime, env, worker, releasedOrClosed, context) { - - private val allocator = ArrowUtils.rootAllocator.newChildAllocator( - s"stdin reader for $pythonExec", 0, Long.MaxValue) - - private var reader: ArrowStreamReader = _ - private var root: VectorSchemaRoot = _ - private var schema: StructType = _ - private var vectors: Array[ColumnVector] = _ - - context.addTaskCompletionListener[Unit] { _ => - if (reader != null) { - reader.close(false) - } - allocator.close() - } - - private var batchLoaded = true - - protected override def read(): ColumnarBatch = { - if (writerThread.exception.isDefined) { - throw writerThread.exception.get - } - try { - if (reader != null && batchLoaded) { - batchLoaded = reader.loadNextBatch() - if (batchLoaded) { - val batch = new ColumnarBatch(vectors) - batch.setNumRows(root.getRowCount) - batch - } else { - reader.close(false) - allocator.close() - // Reach end of stream. Call `read()` again to read control data. 
- read() - } - } else { - stream.readInt() match { - case SpecialLengths.START_ARROW_STREAM => - reader = new ArrowStreamReader(stream, allocator) - root = reader.getVectorSchemaRoot() - schema = ArrowUtils.fromArrowSchema(root.getSchema()) - vectors = root.getFieldVectors().asScala.map { vector => - new ArrowColumnVector(vector) - }.toArray[ColumnVector] - read() - case SpecialLengths.TIMING_DATA => - handleTimingData() - read() - case SpecialLengths.PYTHON_EXCEPTION_THROWN => - throw handlePythonException() - case SpecialLengths.END_OF_DATA_SECTION => - handleEndOfDataSection() - null - } - } - } catch handleException - } - } - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BaseArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BaseArrowPythonRunner.scala new file mode 100644 index 000000000000..3cba06dcf7d5 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BaseArrowPythonRunner.scala @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.execution.python + +import java.io._ +import java.net._ +import java.util.concurrent.atomic.AtomicBoolean + +import scala.collection.JavaConverters._ + +import org.apache.arrow.vector.VectorSchemaRoot +import org.apache.arrow.vector.ipc.ArrowStreamReader + +import org.apache.spark._ +import org.apache.spark.api.python._ +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.ArrowUtils +import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector} + + +abstract class BaseArrowPythonRunner[T]( + funcs: Seq[ChainedPythonFunctions], + evalType: Int, + argOffsets: Array[Array[Int]]) + extends BasePythonRunner[T, ColumnarBatch]( + funcs, evalType, argOffsets) { + + + protected override def newReaderIterator( + stream: DataInputStream, + writerThread: WriterThread, + startTime: Long, + env: SparkEnv, + worker: Socket, + releasedOrClosed: AtomicBoolean, + context: TaskContext): Iterator[ColumnarBatch] = { + + new ReaderIterator(stream, writerThread, startTime, env, worker, releasedOrClosed, context) { + + private val allocator = ArrowUtils.rootAllocator.newChildAllocator( + s"stdin reader for $pythonExec", 0, Long.MaxValue) + + private var reader: ArrowStreamReader = _ + private var root: VectorSchemaRoot = _ + private var schema: StructType = _ + private var vectors: Array[ColumnVector] = _ + + context.addTaskCompletionListener[Unit] { _ => + if (reader != null) { + reader.close(false) + } + allocator.close() + } + + private var batchLoaded = true + + protected override def read(): ColumnarBatch = { + if (writerThread.exception.isDefined) { + throw writerThread.exception.get + } + try { + if (reader != null && batchLoaded) { + batchLoaded = reader.loadNextBatch() + if (batchLoaded) { + val batch = new ColumnarBatch(vectors) + batch.setNumRows(root.getRowCount) + batch + } else { + reader.close(false) + allocator.close() + // Reach end of stream. Call `read()` again to read control data. + read() + } + } else { + stream.readInt() match { + case SpecialLengths.START_ARROW_STREAM => + reader = new ArrowStreamReader(stream, allocator) + root = reader.getVectorSchemaRoot() + schema = ArrowUtils.fromArrowSchema(root.getSchema()) + vectors = root.getFieldVectors().asScala.map { vector => + new ArrowColumnVector(vector) + }.toArray[ColumnVector] + read() + case SpecialLengths.TIMING_DATA => + handleTimingData() + read() + case SpecialLengths.PYTHON_EXCEPTION_THROWN => + throw handlePythonException() + case SpecialLengths.END_OF_DATA_SECTION => + handleEndOfDataSection() + null + } + } + } catch handleException + } + } + } +} + + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala new file mode 100644 index 000000000000..a70b4763fc0f --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.python + +import org.apache.spark.api.python.PythonEvalType +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution, Partitioning} +import org.apache.spark.sql.execution.{BinaryExecNode, CoGroupedIterator, GroupedIterator, SparkPlan} + +case class FlatMapCoGroupsInPandasExec( + leftGroup: Seq[Attribute], + rightGroup: Seq[Attribute], + func: Expression, + output: Seq[Attribute], + left: SparkPlan, + right: SparkPlan) + extends BinaryExecNode with AbstractPandasGroupExec { + + override def outputPartitioning: Partitioning = left.outputPartitioning + + override def producedAttributes: AttributeSet = AttributeSet(output) + + override def requiredChildDistribution: Seq[Distribution] = { + ClusteredDistribution(leftGroup) :: ClusteredDistribution(rightGroup) :: Nil + } + + override def requiredChildOrdering: Seq[Seq[SortOrder]] = { + leftGroup + .map(SortOrder(_, Ascending)) :: rightGroup.map(SortOrder(_, Ascending)) :: Nil + } + + override protected def doExecute(): RDD[InternalRow] = { + + val (schemaLeft, attrLeft, _) = createSchema(left, leftGroup) + val (schemaRight, attrRight, _) = createSchema(right, rightGroup) + + left.execute().zipPartitions(right.execute()) { (leftData, rightData) => + val leftGrouped = GroupedIterator(leftData, leftGroup, left.output) + val rightGrouped = GroupedIterator(rightData, rightGroup, right.output) + val projLeft = UnsafeProjection.create(attrLeft, left.output) + val projRight = UnsafeProjection.create(attrRight, right.output) + val data = new CoGroupedIterator(leftGrouped, rightGrouped, leftGroup) + .map{case (k, l, r) => (l.map(projLeft), r.map(projRight))} + + val runner = new InterleavedArrowPythonRunner( + chainedFunc, + PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF, + Array(Array.empty), + schemaLeft, + schemaRight, + sessionLocalTimeZone, + pythonRunnerConf) + + executePython(data, runner) + + } + + } + +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala index 267698d1bca5..474bbe04b8d6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala @@ -53,9 +53,7 @@ case class FlatMapGroupsInPandasExec( func: Expression, output: Seq[Attribute], child: SparkPlan) - extends UnaryExecNode { - - private val pandasFunction = func.asInstanceOf[PythonUDF].func + extends UnaryExecNode with AbstractPandasGroupExec { override def outputPartitioning: Partitioning = child.outputPartitioning @@ -75,88 +73,22 @@ case class FlatMapGroupsInPandasExec( override protected def doExecute(): RDD[InternalRow] = { val inputRDD = child.execute() - val chainedFunc = Seq(ChainedPythonFunctions(Seq(pandasFunction))) - val sessionLocalTimeZone = 
conf.sessionLocalTimeZone - val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf) - - // Deduplicate the grouping attributes. - // If a grouping attribute also appears in data attributes, then we don't need to send the - // grouping attribute to Python worker. If a grouping attribute is not in data attributes, - // then we need to send this grouping attribute to python worker. - // - // We use argOffsets to distinguish grouping attributes and data attributes as following: - // - // argOffsets[0] is the length of grouping attributes - // argOffsets[1 .. argOffsets[0]+1] is the arg offsets for grouping attributes - // argOffsets[argOffsets[0]+1 .. ] is the arg offsets for data attributes - - val dataAttributes = child.output.drop(groupingAttributes.length) - val groupingIndicesInData = groupingAttributes.map { attribute => - dataAttributes.indexWhere(attribute.semanticEquals) - } - - val groupingArgOffsets = new ArrayBuffer[Int] - val nonDupGroupingAttributes = new ArrayBuffer[Attribute] - val nonDupGroupingSize = groupingIndicesInData.count(_ == -1) - - // Non duplicate grouping attributes are added to nonDupGroupingAttributes and - // their offsets are 0, 1, 2 ... - // Duplicate grouping attributes are NOT added to nonDupGroupingAttributes and - // their offsets are n + index, where n is the total number of non duplicate grouping - // attributes and index is the index in the data attributes that the grouping attribute - // is a duplicate of. - - groupingAttributes.zip(groupingIndicesInData).foreach { - case (attribute, index) => - if (index == -1) { - groupingArgOffsets += nonDupGroupingAttributes.length - nonDupGroupingAttributes += attribute - } else { - groupingArgOffsets += index + nonDupGroupingSize - } - } - - val dataArgOffsets = nonDupGroupingAttributes.length until - (nonDupGroupingAttributes.length + dataAttributes.length) - - val argOffsets = Array(Array(groupingAttributes.length) ++ groupingArgOffsets ++ dataArgOffsets) - - // Attributes after deduplication - val dedupAttributes = nonDupGroupingAttributes ++ dataAttributes - val dedupSchema = StructType.fromAttributes(dedupAttributes) + val (dedupSchema, dedupAttributes, argOffsets) = createSchema(child, groupingAttributes) // Map grouped rows to ArrowPythonRunner results, Only execute if partition is not empty inputRDD.mapPartitionsInternal { iter => if (iter.isEmpty) iter else { - val grouped = if (groupingAttributes.isEmpty) { - Iterator(iter) - } else { - val groupedIter = GroupedIterator(iter, groupingAttributes, child.output) - val dedupProj = UnsafeProjection.create(dedupAttributes, child.output) - groupedIter.map { - case (_, groupedRowIter) => groupedRowIter.map(dedupProj) - } - } - val context = TaskContext.get() + val data = groupAndDedup(iter, groupingAttributes, child.output, dedupAttributes) - val columnarBatchIter = new ArrowPythonRunner( + val runner = new ArrowPythonRunner( chainedFunc, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, argOffsets, dedupSchema, sessionLocalTimeZone, - pythonRunnerConf).compute(grouped, context.partitionId(), context) - - val unsafeProj = UnsafeProjection.create(output, output) + pythonRunnerConf) - columnarBatchIter.flatMap { batch => - // Grouped Map UDF returns a StructType column in ColumnarBatch, select the children here - val structVector = batch.column(0).asInstanceOf[ArrowColumnVector] - val outputVectors = output.indices.map(structVector.getChild) - val flattenedBatch = new ColumnarBatch(outputVectors.toArray) - flattenedBatch.setNumRows(batch.numRows()) - 
flattenedBatch.rowIterator.asScala - }.map(unsafeProj) + executePython(data, runner) }} } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/InterleavedArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/InterleavedArrowPythonRunner.scala new file mode 100644 index 000000000000..b39885ee47a2 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/InterleavedArrowPythonRunner.scala @@ -0,0 +1,118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.python + +import java.io._ +import java.net._ + +import org.apache.arrow.vector.VectorSchemaRoot + +import org.apache.spark._ +import org.apache.spark.api.python._ +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.execution.arrow.ArrowWriter +import org.apache.spark.sql.types._ +import org.apache.spark.sql.util.ArrowUtils +import org.apache.spark.sql.vectorized.ColumnarBatch +import org.apache.spark.util.Utils + + +class InterleavedArrowPythonRunner( + funcs: Seq[ChainedPythonFunctions], + evalType: Int, + argOffsets: Array[Array[Int]], + leftSchema: StructType, + rightSchema: StructType, + timeZoneId: String, + conf: Map[String, String]) + extends BaseArrowPythonRunner[(Iterator[InternalRow], Iterator[InternalRow])]( + funcs, evalType, argOffsets) { + + protected override def newWriterThread( + env: SparkEnv, + worker: Socket, + inputIterator: Iterator[(Iterator[InternalRow], Iterator[InternalRow])], + partitionIndex: Int, + context: TaskContext): WriterThread = { + new WriterThread(env, worker, inputIterator, partitionIndex, context) { + + protected override def writeCommand(dataOut: DataOutputStream): Unit = { + + // Write config for the worker as a number of key -> value pairs of strings + dataOut.writeInt(conf.size) + for ((k, v) <- conf) { + PythonRDD.writeUTF(k, dataOut) + PythonRDD.writeUTF(v, dataOut) + } + + PythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets) + } + + protected override def writeIteratorToStream(dataOut: DataOutputStream): Unit = { + val leftArrowSchema = ArrowUtils.toArrowSchema(leftSchema, timeZoneId) + val rightArrowSchema = ArrowUtils.toArrowSchema(rightSchema, timeZoneId) + val allocator = ArrowUtils.rootAllocator.newChildAllocator( + s"stdout writer for $pythonExec", 0, Long.MaxValue) + val leftRoot = VectorSchemaRoot.create(leftArrowSchema, allocator) + val rightRoot = VectorSchemaRoot.create(rightArrowSchema, allocator) + + Utils.tryWithSafeFinally { + val leftArrowWriter = ArrowWriter.create(leftRoot) + val rightArrowWriter = ArrowWriter.create(rightRoot) + val writer = InterleavedArrowWriter(leftRoot, rightRoot, dataOut) + writer.start() + + while (inputIterator.hasNext) { + + val (nextLeft, nextRight) = 
inputIterator.next() + + while (nextLeft.hasNext) { + leftArrowWriter.write(nextLeft.next()) + } + while (nextRight.hasNext) { + rightArrowWriter.write(nextRight.next()) + } + leftArrowWriter.finish() + rightArrowWriter.finish() + writer.writeBatch() + leftArrowWriter.reset() + rightArrowWriter.reset() + } + // end writes footer to the output stream and doesn't clean any resources. + // It could throw exception if the output stream is closed, so it should be + // in the try block. + writer.end() + } { + // If we close root and allocator in TaskCompletionListener, there could be a race + // condition where the writer thread keeps writing to the VectorSchemaRoot while + // it's being closed by the TaskCompletion listener. + // Closing root and allocator here is cleaner because root and allocator is owned + // by the writer thread and is only visible to the writer thread. + // + // If the writer thread is interrupted by TaskCompletionListener, it should either + // (1) in the try block, in which case it will get an InterruptedException when + // performing io, and goes into the finally block or (2) in the finally block, + // in which case it will ignore the interruption and close the resources. + leftRoot.close() + rightRoot.close() + allocator.close() + } + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/InterleavedArrowWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/InterleavedArrowWriter.scala new file mode 100644 index 000000000000..eb9f1d4494b9 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/InterleavedArrowWriter.scala @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.python + +import java.io.OutputStream +import java.nio.channels.Channels + +import org.apache.arrow.vector.{VectorSchemaRoot, VectorUnloader} +import org.apache.arrow.vector.ipc.WriteChannel +import org.apache.arrow.vector.ipc.message.{ArrowRecordBatch, MessageSerializer} + + +class InterleavedArrowWriter( leftRoot: VectorSchemaRoot, + rightRoot: VectorSchemaRoot, + out: WriteChannel) extends AutoCloseable{ + + + private var started = false + private val leftUnloader = new VectorUnloader(leftRoot) + private val rightUnloader = new VectorUnloader(rightRoot) + + def start(): Unit = { + this.ensureStarted() + } + + def writeBatch(): Unit = { + this.ensureStarted() + writeRecordBatch(leftUnloader.getRecordBatch) + writeRecordBatch(rightUnloader.getRecordBatch) + } + + private def writeRecordBatch(b: ArrowRecordBatch): Unit = { + try + MessageSerializer.serialize(out, b) + finally + b.close() + } + + private def ensureStarted(): Unit = { + if (!started) { + started = true + MessageSerializer.serialize(out, leftRoot.getSchema) + MessageSerializer.serialize(out, rightRoot.getSchema) + } + } + + def end(): Unit = { + ensureStarted() + ensureEnded() + } + + def ensureEnded(): Unit = { + out.writeIntLittleEndian(0) + } + + def close(): Unit = { + out.close() + } + +} + +object InterleavedArrowWriter{ + + def apply(leftRoot: VectorSchemaRoot, + rightRoot: VectorSchemaRoot, + out: OutputStream): InterleavedArrowWriter = { + new InterleavedArrowWriter(leftRoot, rightRoot, new WriteChannel(Channels.newChannel(out))) + } + +}