[SPARK-41005][CONNECT][PYTHON] Arrow-based collect #38468
@@ -29,8 +29,9 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
 import org.apache.spark.sql.connect.command.SparkConnectCommandPlanner
 import org.apache.spark.sql.connect.planner.SparkConnectPlanner
-import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.{SparkPlan, SQLExecution}
 import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AdaptiveSparkPlanHelper, QueryStageExec}
+import org.apache.spark.sql.execution.arrow.ArrowConverters

 class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) extends Logging {
@@ -48,19 +49,24 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte
     }
   }

-  def handlePlan(session: SparkSession, request: proto.Request): Unit = {
+  def handlePlan(session: SparkSession, request: Request): Unit = {
     // Extract the plan from the request and convert it to a logical plan
     val planner = new SparkConnectPlanner(request.getPlan.getRoot, session)
-    val rows =
-      Dataset.ofRows(session, planner.transform())
-    processRows(request.getClientId, rows)
+    val dataframe = Dataset.ofRows(session, planner.transform())
+    try {
+      processAsArrowBatches(request.getClientId, dataframe)
+    } catch {
+      case e: Exception =>
+        logWarning(e.getMessage)
+        processAsJsonBatches(request.getClientId, dataframe)
+    }
   }

-  def processRows(clientId: String, rows: DataFrame): Unit = {
+  def processAsJsonBatches(clientId: String, dataframe: DataFrame): Unit = {
     // Only process up to 10MB of data.
     val sb = new StringBuilder
     var rowCount = 0
-    rows.toJSON
+    dataframe.toJSON
       .collect()
       .foreach(row => {
@@ -114,10 +120,93 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte
       responseObserver.onNext(response.build())
     }

-    responseObserver.onNext(sendMetricsToResponse(clientId, rows))
+    responseObserver.onNext(sendMetricsToResponse(clientId, dataframe))
     responseObserver.onCompleted()
   }

+  def processAsArrowBatches(clientId: String, dataframe: DataFrame): Unit = {
+    val spark = dataframe.sparkSession
+    val schema = dataframe.schema
+    // TODO: control the batch size instead of max records
+    val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch
+    val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone
+
+    SQLExecution.withNewExecutionId(dataframe.queryExecution, Some("collectArrow")) {
+      val rows = dataframe.queryExecution.executedPlan.execute()
+      val numPartitions = rows.getNumPartitions
+      var numSent = 0
+
+      if (numPartitions > 0) {
+        type Batch = (Array[Byte], Long)
+
+        val batches = rows.mapPartitionsInternal { iter =>
+          ArrowConverters
+            .toArrowBatchIterator(iter, schema, maxRecordsPerBatch, timeZoneId)
+        }
+
+        val signal = new Object
+        val partitions = collection.mutable.Map.empty[Int, Array[Batch]]

Contributor: Do we really need a map? We know the number of partitions, so we can just create an array.

Contributor: Then we don't need a lock.

Contributor: Never mind, this is kind of an async collect.

Contributor (Author): Just changed from array to map ... see #38468 (comment).

Contributor: We should probably add some comments at the beginning to explain the overall workflow.

Contributor: Can we apply the same idea to the JSON batches?

Contributor (Author): I think so, let's optimize it later.
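For illustration only, here is a minimal sketch of the array-based handoff discussed in this thread, written as plain Scala with the job submission abstracted away as a parameter. It assumes the number of partitions is known up front; the monitor is kept because the result handler runs on a scheduler thread while the main thread consumes partitions in order. The object and method names are made up for the sketch.

```scala
object ArrowCollectArraySketch {
  type Batch = (Array[Byte], Long)

  // Streams partitions to `send` in partition order, even though results
  // may be produced out of order by the scheduler.
  def consumeInOrder(numPartitions: Int, send: Batch => Unit)(
      runJob: ((Int, Array[Batch]) => Unit) => Unit): Unit = {
    val signal = new Object
    // One slot per partition instead of a mutable Map keyed by partition id.
    val slots = new Array[Array[Batch]](numPartitions)

    // Invoked by the scheduler as each partition completes (possibly out of order).
    val resultHandler = (partitionId: Int, partition: Array[Batch]) => {
      signal.synchronized {
        slots(partitionId) = partition
        signal.notify()
      }
    }
    runJob(resultHandler)

    // Wait for partition 0, send it, then move on to partition 1, and so on.
    var current = 0
    while (current < numPartitions) {
      val partition = signal.synchronized {
        while (slots(current) == null) signal.wait()
        val p = slots(current)
        slots(current) = null // free the slot once it has been handed off
        p
      }
      partition.foreach(send)
      current += 1
    }
  }
}
```

The PR keeps the Map plus the same monitor, which behaves the same way; the array only saves the per-partition hashing and removal.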
+
+        val processPartition = (iter: Iterator[Batch]) => iter.toArray
+
+        // This callback is executed by the DAGScheduler thread.
+        // After fetching a partition, it inserts the partition into the Map, and then
+        // wakes up the main thread.
+        val resultHandler = (partitionId: Int, partition: Array[Batch]) => {
+          signal.synchronized {
+            partitions(partitionId) = partition
+            signal.notify()
+          }
+          ()
+        }
+
+        spark.sparkContext.runJob(batches, processPartition, resultHandler)

Member: Seems like it had to be (async).
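As a hedged sketch of the asynchronous submission hinted at here: `SparkContext.submitJob` returns a `FutureAction` immediately instead of blocking like `runJob`, so the handler thread could start streaming completed partitions while later ones are still running. The `batches` RDD and `Batch` alias mirror the diff above; the object and method names are illustrative.

```scala
import org.apache.spark.{FutureAction, SparkContext}
import org.apache.spark.rdd.RDD

object AsyncArrowCollectSketch {
  type Batch = (Array[Byte], Long)

  // Submits the job without blocking the caller; `resultHandler` is invoked as each
  // partition finishes, so the caller can drain completed partitions in order while
  // later partitions are still being computed.
  def submitBatches(
      sc: SparkContext,
      batches: RDD[Batch],
      resultHandler: (Int, Array[Batch]) => Unit): FutureAction[Unit] = {
    sc.submitJob(
      batches,
      (iter: Iterator[Batch]) => iter.toArray, // same processPartition as the diff
      0 until batches.getNumPartitions,        // run every partition
      resultHandler,
      ())                                      // no aggregated result is needed
  }
}
```

The caller would then run the same signal/partitions loop as in the diff and consult the returned future for failures once every partition has been sent.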
+
+        // The main thread will wait until the 0-th partition is available,
+        // then send it to the client and wait for the next partition.
+        var currentPartitionId = 0
+        while (currentPartitionId < numPartitions) {
+          val partition = signal.synchronized {
+            while (!partitions.contains(currentPartitionId)) {
+              signal.wait()
+            }
+            partitions.remove(currentPartitionId).get
+          }
+
+          partition.foreach { case (bytes, count) =>
+            val response = proto.Response.newBuilder().setClientId(clientId)
+            val batch = proto.Response.ArrowBatch
+              .newBuilder()
+              .setRowCount(count)
+              .setData(ByteString.copyFrom(bytes))
+              .build()
+            response.setArrowBatch(batch)
+            responseObserver.onNext(response.build())
+
+            numSent += 1
+          }
+
+          currentPartitionId += 1
+        }
+      }
+
+      // Make sure at least 1 batch will be sent.
+      if (numSent == 0) {
+        val bytes = ArrowConverters.createEmptyArrowBatch(schema, timeZoneId)
+        val response = proto.Response.newBuilder().setClientId(clientId)
+        val batch = proto.Response.ArrowBatch
+          .newBuilder()
+          .setRowCount(0L)
+          .setData(ByteString.copyFrom(bytes))
+          .build()
+        response.setArrowBatch(batch)
+        responseObserver.onNext(response.build())
+      }
+
+      responseObserver.onNext(sendMetricsToResponse(clientId, dataframe))
+      responseObserver.onCompleted()
+    }
+  }

   def sendMetricsToResponse(clientId: String, rows: DataFrame): Response = {
     // Send a last batch with the metrics
     Response
@@ -197,6 +197,18 @@ def test_range(self):
             .equals(self.spark.range(start=0, end=10, step=3, numPartitions=2).toPandas())
         )

+    def test_empty_dataset(self):
+        # SPARK-41005: Test arrow based collection with empty dataset.
+        self.assertTrue(
+            self.connect.sql("SELECT 1 AS X LIMIT 0")
+            .toPandas()
+            .equals(self.spark.sql("SELECT 1 AS X LIMIT 0").toPandas())
+        )
+        pdf = self.connect.sql("SELECT 1 AS X LIMIT 0").toPandas()
+        self.assertEqual(0, len(pdf))  # empty dataset
+        self.assertEqual(1, len(pdf.columns))  # one column
+        self.assertEqual("X", pdf.columns[0])
+
     def test_simple_datasource_read(self) -> None:
         writeDf = self.df_text
         tmpPath = tempfile.mkdtemp()

I think you can remove this since we're already handling the max records?

#38468 (comment) suggested controlling the batch size to stay under 4MB.

Interesting. I only knew that gRPC has a hard limit of a 2GB/s transfer rate; I never knew it might not favor overly large messages.

I don't think there is a throughput limit in gRPC itself. The reason for the batching is that protobuf is not suited for this: embedding large binary objects might require the reader to materialize them in memory. Fixing this is an optimization for later.

Another downside of large allocations is that the GC does not really like them. All large allocations (> 1 MB) are generally placed in the old generation immediately, which requires a full GC to clean up.
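To make the size concern concrete, here is a rough, hypothetical sketch of deriving a per-batch record cap from the schema's estimated row width, so each Arrow batch embedded in a protobuf message stays near a target such as 4 MB. `StructType.defaultSize` is only a static estimate, so rows with long strings or arrays can still exceed the cap, and this is not the mechanism the PR itself adopts.

```scala
import org.apache.spark.sql.types.StructType

object ArrowBatchSizeSketch {
  // Target size for a single Arrow batch inside one gRPC/protobuf message.
  val TargetBatchBytes: Long = 4L * 1024 * 1024

  // Derives a record cap from the estimated row width; a non-positive `configuredMax`
  // (arrowMaxRecordsPerBatch disabled) means no extra record limit applies.
  def maxRecordsForSchema(schema: StructType, configuredMax: Int): Int = {
    val estimatedRowBytes = math.max(1, schema.defaultSize)
    val bySize = math.max(1L, TargetBatchBytes / estimatedRowBytes).toInt
    if (configuredMax > 0) math.min(configuredMax, bySize) else bySize
  }
}
```

Keeping individual messages small in this spirit sidesteps both the protobuf materialization issue and the large-allocation GC behavior mentioned above.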