diff --git a/connector/connect/src/main/protobuf/spark/connect/base.proto b/connector/connect/src/main/protobuf/spark/connect/base.proto
index b376515bf1af0..5f59ada38b658 100644
--- a/connector/connect/src/main/protobuf/spark/connect/base.proto
+++ b/connector/connect/src/main/protobuf/spark/connect/base.proto
@@ -69,7 +69,7 @@ message Response {

   // Result type
   oneof result_type {
-    ArrowBatch batch = 2;
+    ArrowBatch arrow_batch = 2;
     JSONBatch json_batch = 3;
   }

@@ -80,10 +80,7 @@ message Response {
   // Batch results of metrics.
   message ArrowBatch {
     int64 row_count = 1;
-    int64 uncompressed_bytes = 2;
-    int64 compressed_bytes = 3;
-    bytes data = 4;
-    bytes schema = 5;
+    bytes data = 2;
   }

   // Message type when the result is returned as JSON. This is essentially a bulk wrapper
diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala
index 58fc6237867c0..3b734616b2138 100644
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala
@@ -29,8 +29,9 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
 import org.apache.spark.sql.connect.command.SparkConnectCommandPlanner
 import org.apache.spark.sql.connect.planner.SparkConnectPlanner
-import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.{SparkPlan, SQLExecution}
 import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AdaptiveSparkPlanHelper, QueryStageExec}
+import org.apache.spark.sql.execution.arrow.ArrowConverters

 class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) extends Logging {

@@ -48,19 +49,24 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte
     }
   }

-  def handlePlan(session: SparkSession, request: proto.Request): Unit = {
+  def handlePlan(session: SparkSession, request: Request): Unit = {
     // Extract the plan from the request and convert it to a logical plan
     val planner = new SparkConnectPlanner(request.getPlan.getRoot, session)
-    val rows =
-      Dataset.ofRows(session, planner.transform())
-    processRows(request.getClientId, rows)
+    val dataframe = Dataset.ofRows(session, planner.transform())
+    try {
+      processAsArrowBatches(request.getClientId, dataframe)
+    } catch {
+      case e: Exception =>
+        logWarning(e.getMessage)
+        processAsJsonBatches(request.getClientId, dataframe)
+    }
   }

-  def processRows(clientId: String, rows: DataFrame): Unit = {
+  def processAsJsonBatches(clientId: String, dataframe: DataFrame): Unit = {
     // Only process up to 10MB of data.
     val sb = new StringBuilder
     var rowCount = 0
-    rows.toJSON
+    dataframe.toJSON
       .collect()
       .foreach(row => {
@@ -114,10 +120,93 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte
       responseObserver.onNext(response.build())
     }

-    responseObserver.onNext(sendMetricsToResponse(clientId, rows))
+    responseObserver.onNext(sendMetricsToResponse(clientId, dataframe))
     responseObserver.onCompleted()
   }

+  def processAsArrowBatches(clientId: String, dataframe: DataFrame): Unit = {
+    val spark = dataframe.sparkSession
+    val schema = dataframe.schema
+    // TODO: control the batch size instead of max records
+    val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch
+    val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone
+
+    SQLExecution.withNewExecutionId(dataframe.queryExecution, Some("collectArrow")) {
+      val rows = dataframe.queryExecution.executedPlan.execute()
+      val numPartitions = rows.getNumPartitions
+      var numSent = 0
+
+      if (numPartitions > 0) {
+        type Batch = (Array[Byte], Long)
+
+        val batches = rows.mapPartitionsInternal { iter =>
+          ArrowConverters
+            .toArrowBatchIterator(iter, schema, maxRecordsPerBatch, timeZoneId)
+        }
+
+        val signal = new Object
+        val partitions = collection.mutable.Map.empty[Int, Array[Batch]]
+
+        val processPartition = (iter: Iterator[Batch]) => iter.toArray
+
+        // This callback is executed by the DAGScheduler thread.
+        // After fetching a partition, it inserts the partition into the Map, and then
+        // wakes up the main thread.
+        val resultHandler = (partitionId: Int, partition: Array[Batch]) => {
+          signal.synchronized {
+            partitions(partitionId) = partition
+            signal.notify()
+          }
+          ()
+        }
+
+        spark.sparkContext.runJob(batches, processPartition, resultHandler)
+
+        // The main thread will wait until the 0-th partition is available,
+        // then send it to the client and wait for the next partition.
+        var currentPartitionId = 0
+        while (currentPartitionId < numPartitions) {
+          val partition = signal.synchronized {
+            while (!partitions.contains(currentPartitionId)) {
+              signal.wait()
+            }
+            partitions.remove(currentPartitionId).get
+          }
+
+          partition.foreach { case (bytes, count) =>
+            val response = proto.Response.newBuilder().setClientId(clientId)
+            val batch = proto.Response.ArrowBatch
+              .newBuilder()
+              .setRowCount(count)
+              .setData(ByteString.copyFrom(bytes))
+              .build()
+            response.setArrowBatch(batch)
+            responseObserver.onNext(response.build())
+            numSent += 1
+          }
+
+          currentPartitionId += 1
+        }
+      }
+
+      // Make sure at least 1 batch will be sent.
+      if (numSent == 0) {
+        val bytes = ArrowConverters.createEmptyArrowBatch(schema, timeZoneId)
+        val response = proto.Response.newBuilder().setClientId(clientId)
+        val batch = proto.Response.ArrowBatch
+          .newBuilder()
+          .setRowCount(0L)
+          .setData(ByteString.copyFrom(bytes))
+          .build()
+        response.setArrowBatch(batch)
+        responseObserver.onNext(response.build())
+      }
+
+      responseObserver.onNext(sendMetricsToResponse(clientId, dataframe))
+      responseObserver.onCompleted()
+    }
+  }
+
   def sendMetricsToResponse(clientId: String, rows: DataFrame): Response = {
     // Send a last batch with the metrics
     Response
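# ----------------------------------------------------------------------------
# Illustrative sketch (not part of the patch above): the resultHandler /
# signal.wait() logic in processAsArrowBatches buffers partitions that finish
# out of order and emits them strictly in partition order. The same idea,
# expressed with Python's stdlib threading.Condition and hypothetical names:
import threading


def stream_in_partition_order(num_partitions, send):
    cond = threading.Condition()
    buffered = {}

    def on_partition_done(partition_id, batches):
        # Invoked from scheduler/callback threads as each partition completes.
        with cond:
            buffered[partition_id] = batches
            cond.notify()

    def sender():
        # The sending thread waits for partition 0, streams it, then waits
        # for partition 1, and so on, so the client sees rows in order.
        for pid in range(num_partitions):
            with cond:
                while pid not in buffered:
                    cond.wait()
                batches = buffered.pop(pid)
            for batch in batches:
                send(batch)

    return on_partition_done, sender
# ----------------------------------------------------------------------------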
diff --git a/python/pyspark/sql/connect/client.py b/python/pyspark/sql/connect/client.py
index 2eba9ac11f525..27075ff3cb027 100644
--- a/python/pyspark/sql/connect/client.py
+++ b/python/pyspark/sql/connect/client.py
@@ -377,8 +377,8 @@ def _analyze(self, plan: pb2.Plan) -> AnalyzeResult:
     def _process_batch(self, b: pb2.Response) -> Optional[pandas.DataFrame]:
         import pandas as pd

-        if b.batch is not None and len(b.batch.data) > 0:
-            with pa.ipc.open_stream(b.batch.data) as rd:
+        if b.arrow_batch is not None and len(b.arrow_batch.data) > 0:
+            with pa.ipc.open_stream(b.arrow_batch.data) as rd:
                 return rd.read_pandas()
         elif b.json_batch is not None and len(b.json_batch.data) > 0:
             return pd.read_json(io.BytesIO(b.json_batch.data), lines=True)
@@ -400,6 +400,13 @@ def _execute_and_fetch(self, req: pb2.Request) -> typing.Optional[pandas.DataFra

         if len(result_dfs) > 0:
             df = pd.concat(result_dfs)
+
+            # pd.concat generates non-consecutive index like:
+            # Int64Index([0, 1, 0, 1, 2, 0, 1, 0, 1, 2], dtype='int64')
+            # set it to RangeIndex to be consistent with pyspark
+            n = len(df)
+            df.set_index(pd.RangeIndex(start=0, stop=n, step=1), inplace=True)
+
             # Attach the metrics to the DataFrame attributes.
             if m is not None:
                 df.attrs["metrics"] = self._build_metrics(m)
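# ----------------------------------------------------------------------------
# Illustrative sketch (not part of the patch above), assuming a list of
# ArrowBatch.data payloads: each payload is a self-contained Arrow IPC stream
# (schema message + one record batch + end-of-stream), so every batch can be
# decoded independently with pyarrow and the results concatenated, mirroring
# _process_batch and _execute_and_fetch above.
import pandas as pd
import pyarrow as pa


def arrow_payloads_to_pandas(payloads):
    frames = []
    for data in payloads:
        with pa.ipc.open_stream(data) as reader:
            frames.append(reader.read_pandas())
    # The server always sends at least one (possibly empty) batch, so frames
    # is non-empty here.
    df = pd.concat(frames)
    # pd.concat keeps each batch's own 0..n-1 index; reset to a RangeIndex so
    # the result matches a regular PySpark toPandas() DataFrame.
    df.set_index(pd.RangeIndex(start=0, stop=len(df), step=1), inplace=True)
    return df
# ----------------------------------------------------------------------------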
diff --git a/python/pyspark/sql/connect/proto/base_pb2.py b/python/pyspark/sql/connect/proto/base_pb2.py
index eb9ecc9157f2c..1f577089d1a29 100644
--- a/python/pyspark/sql/connect/proto/base_pb2.py
+++ b/python/pyspark/sql/connect/proto/base_pb2.py
@@ -35,7 +35,7 @@


 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1cspark/connect/commands.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"t\n\x04Plan\x12-\n\x04root\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02 \x01(\x0b\x32\x16.spark.connect.CommandH\x00R\x07\x63ommandB\t\n\x07op_type"\x92\x02\n\x07Request\x12\x1b\n\tclient_id\x18\x01 \x01(\tR\x08\x63lientId\x12\x45\n\x0cuser_context\x18\x02 \x01(\x0b\x32".spark.connect.Request.UserContextR\x0buserContext\x12\'\n\x04plan\x18\x03 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x1az\n\x0bUserContext\x12\x17\n\x07user_id\x18\x01 \x01(\tR\x06userId\x12\x1b\n\tuser_name\x18\x02 \x01(\tR\x08userName\x12\x35\n\nextensions\x18\xe7\x07 \x03(\x0b\x32\x14.google.protobuf.AnyR\nextensions"\xc8\x07\n\x08Response\x12\x1b\n\tclient_id\x18\x01 \x01(\tR\x08\x63lientId\x12:\n\x05\x62\x61tch\x18\x02 \x01(\x0b\x32".spark.connect.Response.ArrowBatchH\x00R\x05\x62\x61tch\x12\x42\n\njson_batch\x18\x03 \x01(\x0b\x32!.spark.connect.Response.JSONBatchH\x00R\tjsonBatch\x12\x39\n\x07metrics\x18\x04 \x01(\x0b\x32\x1f.spark.connect.Response.MetricsR\x07metrics\x1a\xaf\x01\n\nArrowBatch\x12\x1b\n\trow_count\x18\x01 \x01(\x03R\x08rowCount\x12-\n\x12uncompressed_bytes\x18\x02 \x01(\x03R\x11uncompressedBytes\x12)\n\x10\x63ompressed_bytes\x18\x03 \x01(\x03R\x0f\x63ompressedBytes\x12\x12\n\x04\x64\x61ta\x18\x04 \x01(\x0cR\x04\x64\x61ta\x12\x16\n\x06schema\x18\x05 \x01(\x0cR\x06schema\x1a<\n\tJSONBatch\x12\x1b\n\trow_count\x18\x01 \x01(\x03R\x08rowCount\x12\x12\n\x04\x64\x61ta\x18\x02 \x01(\x0cR\x04\x64\x61ta\x1a\xe4\x03\n\x07Metrics\x12\x46\n\x07metrics\x18\x01 \x03(\x0b\x32,.spark.connect.Response.Metrics.MetricObjectR\x07metrics\x1a\xb6\x02\n\x0cMetricObject\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x17\n\x07plan_id\x18\x02 \x01(\x03R\x06planId\x12\x16\n\x06parent\x18\x03 \x01(\x03R\x06parent\x12o\n\x11\x65xecution_metrics\x18\x04 \x03(\x0b\x32\x42.spark.connect.Response.Metrics.MetricObject.ExecutionMetricsEntryR\x10\x65xecutionMetrics\x1ap\n\x15\x45xecutionMetricsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x41\n\x05value\x18\x02 \x01(\x0b\x32+.spark.connect.Response.Metrics.MetricValueR\x05value:\x02\x38\x01\x1aX\n\x0bMetricValue\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x14\n\x05value\x18\x02 \x01(\x03R\x05value\x12\x1f\n\x0bmetric_type\x18\x03 \x01(\tR\nmetricTypeB\r\n\x0bresult_type"\x86\x01\n\x0f\x41nalyzeResponse\x12\x1b\n\tclient_id\x18\x01 \x01(\tR\x08\x63lientId\x12/\n\x06schema\x18\x02 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x06schema\x12%\n\x0e\x65xplain_string\x18\x03 \x01(\tR\rexplainString2\xa2\x01\n\x13SparkConnectService\x12\x42\n\x0b\x45xecutePlan\x12\x16.spark.connect.Request\x1a\x17.spark.connect.Response"\x00\x30\x01\x12G\n\x0b\x41nalyzePlan\x12\x16.spark.connect.Request\x1a\x1e.spark.connect.AnalyzeResponse"\x00\x42"\n\x1eorg.apache.spark.connect.protoP\x01\x62\x06proto3'
+    b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1cspark/connect/commands.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"t\n\x04Plan\x12-\n\x04root\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02 \x01(\x0b\x32\x16.spark.connect.CommandH\x00R\x07\x63ommandB\t\n\x07op_type"\x92\x02\n\x07Request\x12\x1b\n\tclient_id\x18\x01 \x01(\tR\x08\x63lientId\x12\x45\n\x0cuser_context\x18\x02 \x01(\x0b\x32".spark.connect.Request.UserContextR\x0buserContext\x12\'\n\x04plan\x18\x03 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x1az\n\x0bUserContext\x12\x17\n\x07user_id\x18\x01 \x01(\tR\x06userId\x12\x1b\n\tuser_name\x18\x02 \x01(\tR\x08userName\x12\x35\n\nextensions\x18\xe7\x07 \x03(\x0b\x32\x14.google.protobuf.AnyR\nextensions"\xe0\x06\n\x08Response\x12\x1b\n\tclient_id\x18\x01 \x01(\tR\x08\x63lientId\x12\x45\n\x0b\x61rrow_batch\x18\x02 \x01(\x0b\x32".spark.connect.Response.ArrowBatchH\x00R\narrowBatch\x12\x42\n\njson_batch\x18\x03 \x01(\x0b\x32!.spark.connect.Response.JSONBatchH\x00R\tjsonBatch\x12\x39\n\x07metrics\x18\x04 \x01(\x0b\x32\x1f.spark.connect.Response.MetricsR\x07metrics\x1a=\n\nArrowBatch\x12\x1b\n\trow_count\x18\x01 \x01(\x03R\x08rowCount\x12\x12\n\x04\x64\x61ta\x18\x02 \x01(\x0cR\x04\x64\x61ta\x1a<\n\tJSONBatch\x12\x1b\n\trow_count\x18\x01 \x01(\x03R\x08rowCount\x12\x12\n\x04\x64\x61ta\x18\x02 \x01(\x0cR\x04\x64\x61ta\x1a\xe4\x03\n\x07Metrics\x12\x46\n\x07metrics\x18\x01 \x03(\x0b\x32,.spark.connect.Response.Metrics.MetricObjectR\x07metrics\x1a\xb6\x02\n\x0cMetricObject\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x17\n\x07plan_id\x18\x02 \x01(\x03R\x06planId\x12\x16\n\x06parent\x18\x03 \x01(\x03R\x06parent\x12o\n\x11\x65xecution_metrics\x18\x04 \x03(\x0b\x32\x42.spark.connect.Response.Metrics.MetricObject.ExecutionMetricsEntryR\x10\x65xecutionMetrics\x1ap\n\x15\x45xecutionMetricsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x41\n\x05value\x18\x02 \x01(\x0b\x32+.spark.connect.Response.Metrics.MetricValueR\x05value:\x02\x38\x01\x1aX\n\x0bMetricValue\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x14\n\x05value\x18\x02 \x01(\x03R\x05value\x12\x1f\n\x0bmetric_type\x18\x03 \x01(\tR\nmetricTypeB\r\n\x0bresult_type"\x86\x01\n\x0f\x41nalyzeResponse\x12\x1b\n\tclient_id\x18\x01 \x01(\tR\x08\x63lientId\x12/\n\x06schema\x18\x02 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x06schema\x12%\n\x0e\x65xplain_string\x18\x03 \x01(\tR\rexplainString2\xa2\x01\n\x13SparkConnectService\x12\x42\n\x0b\x45xecutePlan\x12\x16.spark.connect.Request\x1a\x17.spark.connect.Response"\x00\x30\x01\x12G\n\x0b\x41nalyzePlan\x12\x16.spark.connect.Request\x1a\x1e.spark.connect.AnalyzeResponse"\x00\x42"\n\x1eorg.apache.spark.connect.protoP\x01\x62\x06proto3'
 )

 _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
@@ -53,21 +53,21 @@
     _REQUEST_USERCONTEXT._serialized_start = 429
     _REQUEST_USERCONTEXT._serialized_end = 551
     _RESPONSE._serialized_start = 554
-    _RESPONSE._serialized_end = 1522
-    _RESPONSE_ARROWBATCH._serialized_start = 783
-    _RESPONSE_ARROWBATCH._serialized_end = 958
-    _RESPONSE_JSONBATCH._serialized_start = 960
-    _RESPONSE_JSONBATCH._serialized_end = 1020
-    _RESPONSE_METRICS._serialized_start = 1023
-    _RESPONSE_METRICS._serialized_end = 1507
-    _RESPONSE_METRICS_METRICOBJECT._serialized_start = 1107
-    _RESPONSE_METRICS_METRICOBJECT._serialized_end = 1417
-    _RESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_start = 1305
-    _RESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_end = 1417
-    _RESPONSE_METRICS_METRICVALUE._serialized_start = 1419
-    _RESPONSE_METRICS_METRICVALUE._serialized_end = 1507
-    _ANALYZERESPONSE._serialized_start = 1525
-    _ANALYZERESPONSE._serialized_end = 1659
-    _SPARKCONNECTSERVICE._serialized_start = 1662
-    _SPARKCONNECTSERVICE._serialized_end = 1824
+    _RESPONSE._serialized_end = 1418
+    _RESPONSE_ARROWBATCH._serialized_start = 793
+    _RESPONSE_ARROWBATCH._serialized_end = 854
+    _RESPONSE_JSONBATCH._serialized_start = 856
+    _RESPONSE_JSONBATCH._serialized_end = 916
+    _RESPONSE_METRICS._serialized_start = 919
+    _RESPONSE_METRICS._serialized_end = 1403
+    _RESPONSE_METRICS_METRICOBJECT._serialized_start = 1003
+    _RESPONSE_METRICS_METRICOBJECT._serialized_end = 1313
+    _RESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_start = 1201
+    _RESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_end = 1313
+    _RESPONSE_METRICS_METRICVALUE._serialized_start = 1315
+    _RESPONSE_METRICS_METRICVALUE._serialized_end = 1403
+    _ANALYZERESPONSE._serialized_start = 1421
+    _ANALYZERESPONSE._serialized_end = 1555
+    _SPARKCONNECTSERVICE._serialized_start = 1558
+    _SPARKCONNECTSERVICE._serialized_end = 1720
 # @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/base_pb2.pyi b/python/pyspark/sql/connect/proto/base_pb2.pyi
index 5ffd7701b440d..bf6d080d9fd97 100644
--- a/python/pyspark/sql/connect/proto/base_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/base_pb2.pyi
@@ -178,38 +178,17 @@ class Response(google.protobuf.message.Message):
         DESCRIPTOR: google.protobuf.descriptor.Descriptor

         ROW_COUNT_FIELD_NUMBER: builtins.int
-        UNCOMPRESSED_BYTES_FIELD_NUMBER: builtins.int
-        COMPRESSED_BYTES_FIELD_NUMBER: builtins.int
         DATA_FIELD_NUMBER: builtins.int
-        SCHEMA_FIELD_NUMBER: builtins.int
         row_count: builtins.int
-        uncompressed_bytes: builtins.int
-        compressed_bytes: builtins.int
         data: builtins.bytes
-        schema: builtins.bytes
         def __init__(
             self,
             *,
             row_count: builtins.int = ...,
-            uncompressed_bytes: builtins.int = ...,
-            compressed_bytes: builtins.int = ...,
             data: builtins.bytes = ...,
-            schema: builtins.bytes = ...,
         ) -> None: ...
         def ClearField(
-            self,
-            field_name: typing_extensions.Literal[
-                "compressed_bytes",
-                b"compressed_bytes",
-                "data",
-                b"data",
-                "row_count",
-                b"row_count",
-                "schema",
-                b"schema",
-                "uncompressed_bytes",
-                b"uncompressed_bytes",
-            ],
+            self, field_name: typing_extensions.Literal["data", b"data", "row_count", b"row_count"]
         ) -> None: ...

     class JSONBatch(google.protobuf.message.Message):
@@ -339,12 +318,12 @@ class Response(google.protobuf.message.Message):
         ) -> None: ...

     CLIENT_ID_FIELD_NUMBER: builtins.int
-    BATCH_FIELD_NUMBER: builtins.int
+    ARROW_BATCH_FIELD_NUMBER: builtins.int
     JSON_BATCH_FIELD_NUMBER: builtins.int
     METRICS_FIELD_NUMBER: builtins.int
     client_id: builtins.str
     @property
-    def batch(self) -> global___Response.ArrowBatch: ...
+    def arrow_batch(self) -> global___Response.ArrowBatch: ...
     @property
     def json_batch(self) -> global___Response.JSONBatch: ...
     @property
@@ -356,15 +335,15 @@ class Response(google.protobuf.message.Message):
         self,
         *,
         client_id: builtins.str = ...,
-        batch: global___Response.ArrowBatch | None = ...,
+        arrow_batch: global___Response.ArrowBatch | None = ...,
         json_batch: global___Response.JSONBatch | None = ...,
         metrics: global___Response.Metrics | None = ...,
     ) -> None: ...
     def HasField(
         self,
         field_name: typing_extensions.Literal[
-            "batch",
-            b"batch",
+            "arrow_batch",
+            b"arrow_batch",
             "json_batch",
             b"json_batch",
             "metrics",
@@ -376,8 +355,8 @@ class Response(google.protobuf.message.Message):
     def ClearField(
         self,
         field_name: typing_extensions.Literal[
-            "batch",
-            b"batch",
+            "arrow_batch",
+            b"arrow_batch",
             "client_id",
             b"client_id",
             "json_batch",
@@ -390,7 +369,7 @@ class Response(google.protobuf.message.Message):
     ) -> None: ...
     def WhichOneof(
         self, oneof_group: typing_extensions.Literal["result_type", b"result_type"]
-    ) -> typing_extensions.Literal["batch", "json_batch"] | None: ...
+    ) -> typing_extensions.Literal["arrow_batch", "json_batch"] | None: ...

 global___Response = Response

diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index a0f046907f73e..38c244bd74bfa 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -197,6 +197,18 @@ def test_range(self):
             .equals(self.spark.range(start=0, end=10, step=3, numPartitions=2).toPandas())
         )

+    def test_empty_dataset(self):
+        # SPARK-41005: Test arrow based collection with empty dataset.
+        self.assertTrue(
+            self.connect.sql("SELECT 1 AS X LIMIT 0")
+            .toPandas()
+            .equals(self.spark.sql("SELECT 1 AS X LIMIT 0").toPandas())
+        )
+        pdf = self.connect.sql("SELECT 1 AS X LIMIT 0").toPandas()
+        self.assertEqual(0, len(pdf))  # empty dataset
+        self.assertEqual(1, len(pdf.columns))  # one column
+        self.assertEqual("X", pdf.columns[0])
+
     def test_simple_datasource_read(self) -> None:
         writeDf = self.df_text
         tmpPath = tempfile.mkdtemp()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
index bded158645cce..a2dce31bc6d30 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
@@ -128,6 +128,92 @@ private[sql] object ArrowConverters extends Logging {
     }
   }

+  private[sql] def toArrowBatchIterator(
+      rowIter: Iterator[InternalRow],
+      schema: StructType,
+      maxRecordsPerBatch: Int,
+      timeZoneId: String): Iterator[(Array[Byte], Long)] = {
+    val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId)
+    val allocator = ArrowUtils.rootAllocator.newChildAllocator(
+      "toArrowBatchIterator", 0, Long.MaxValue)
+
+    val root = VectorSchemaRoot.create(arrowSchema, allocator)
+    val unloader = new VectorUnloader(root)
+    val arrowWriter = ArrowWriter.create(root)
+
+    Option(TaskContext.get).foreach {
+      _.addTaskCompletionListener[Unit] { _ =>
+        root.close()
+        allocator.close()
+      }
+    }
+
+    new Iterator[(Array[Byte], Long)] {
+
+      override def hasNext: Boolean = rowIter.hasNext || {
+        root.close()
+        allocator.close()
+        false
+      }
+
+      override def next(): (Array[Byte], Long) = {
+        val out = new ByteArrayOutputStream()
+        val writeChannel = new WriteChannel(Channels.newChannel(out))
+
+        var rowCount = 0L
+        Utils.tryWithSafeFinally {
+          while (rowIter.hasNext && (maxRecordsPerBatch <= 0 || rowCount < maxRecordsPerBatch)) {
+            val row = rowIter.next()
+            arrowWriter.write(row)
+            rowCount += 1
+          }
+          arrowWriter.finish()
+          val batch = unloader.getRecordBatch()
+
+          MessageSerializer.serialize(writeChannel, arrowSchema)
+          MessageSerializer.serialize(writeChannel, batch)
+          ArrowStreamWriter.writeEndOfStream(writeChannel, IpcOption.DEFAULT)
+
+          batch.close()
+        } {
+          arrowWriter.reset()
+        }
+
+        (out.toByteArray, rowCount)
+      }
+    }
+  }
+
+  private[sql] def createEmptyArrowBatch(
+      schema: StructType,
+      timeZoneId: String): Array[Byte] = {
+    val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId)
+    val allocator = ArrowUtils.rootAllocator.newChildAllocator(
+      "createEmptyArrowBatch", 0, Long.MaxValue)
+
+    val root = VectorSchemaRoot.create(arrowSchema, allocator)
+    val unloader = new VectorUnloader(root)
+    val arrowWriter = ArrowWriter.create(root)
+
+    val out = new ByteArrayOutputStream()
+    val writeChannel = new WriteChannel(Channels.newChannel(out))
+
+    Utils.tryWithSafeFinally {
+      arrowWriter.finish()
+      val batch = unloader.getRecordBatch() // empty batch
+
+      MessageSerializer.serialize(writeChannel, arrowSchema)
+      MessageSerializer.serialize(writeChannel, batch)
+      ArrowStreamWriter.writeEndOfStream(writeChannel, IpcOption.DEFAULT)
+
+      batch.close()
+    } {
+      arrowWriter.reset()
+    }
+
+    out.toByteArray
+  }
+
   /**
    * Maps iterator from serialized ArrowRecordBatches to InternalRows.
    */
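# ----------------------------------------------------------------------------
# Illustrative sketch (not part of the patch above): createEmptyArrowBatch
# writes an IPC stream containing the schema and a single zero-row record
# batch, which is what lets the client recover column names and dtypes for an
# empty result (see test_empty_dataset). An equivalent payload, built and read
# back with pyarrow under a hypothetical single-column schema:
import io

import pyarrow as pa

schema = pa.schema([("X", pa.int32())])  # hypothetical schema

sink = io.BytesIO()
with pa.ipc.new_stream(sink, schema) as writer:
    # One record batch with zero rows, mirroring the empty batch on the server.
    batch = pa.RecordBatch.from_arrays([pa.array([], type=pa.int32())], names=["X"])
    writer.write_batch(batch)
payload = sink.getvalue()

with pa.ipc.open_stream(payload) as reader:
    pdf = reader.read_pandas()

assert len(pdf) == 0
assert list(pdf.columns) == ["X"]
# ----------------------------------------------------------------------------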