diff --git a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ClientE2ETestSuite.scala b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ClientE2ETestSuite.scala
index 8c336b6fa6d5..450ff8ca6249 100644
--- a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ClientE2ETestSuite.scala
+++ b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ClientE2ETestSuite.scala
@@ -26,12 +26,14 @@ import scala.concurrent.{ExecutionContext, Future}
 import scala.concurrent.duration.{DurationInt, FiniteDuration}
 import scala.jdk.CollectionConverters._

+import io.grpc.{CallOptions, Channel, ClientCall, ClientInterceptor, ForwardingClientCall, ForwardingClientCallListener, MethodDescriptor}
 import org.apache.commons.io.output.TeeOutputStream
 import org.scalactic.TolerantNumerics
 import org.scalatest.PrivateMethodTester

 import org.apache.spark.{SparkArithmeticException, SparkException, SparkUpgradeException}
 import org.apache.spark.SparkBuildInfo.{spark_version => SPARK_VERSION}
+import org.apache.spark.connect.proto
 import org.apache.spark.internal.config.ConfigBuilder
 import org.apache.spark.sql.{functions, AnalysisException, Observation, Row, SaveMode}
 import org.apache.spark.sql.catalyst.analysis.{NamespaceAlreadyExistsException, NoSuchNamespaceException, TableAlreadyExistsException, TempTableAlreadyExistsException}
@@ -41,7 +43,7 @@ import org.apache.spark.sql.catalyst.parser.ParseException
 import org.apache.spark.sql.connect.ConnectConversions._
 import org.apache.spark.sql.connect.client.{RetryPolicy, SparkConnectClient, SparkResult}
 import org.apache.spark.sql.connect.test.{ConnectFunSuite, IntegrationTestUtils, QueryTest, RemoteSparkSession, SQLHelper}
-import org.apache.spark.sql.connect.test.SparkConnectServerUtils.port
+import org.apache.spark.sql.connect.test.SparkConnectServerUtils.{createSparkSession, port}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SqlApiConf
 import org.apache.spark.sql.types._
@@ -1848,6 +1850,161 @@ class ClientE2ETestSuite
       checkAnswer(df, Seq.empty)
     }
   }
+
+  // Helper class to capture Arrow batch chunk information from gRPC responses
+  private class ArrowBatchInterceptor extends ClientInterceptor {
+    case class BatchInfo(
+        batchIndex: Int,
+        rowCount: Long,
+        startOffset: Long,
+        chunks: Seq[ChunkInfo]) {
+      def totalChunks: Int = chunks.length
+    }
+
+    case class ChunkInfo(
+        batchIndex: Int,
+        chunkIndex: Int,
+        numChunksInBatch: Int,
+        rowCount: Long,
+        startOffset: Long,
+        dataSize: Int)
+
+    private val batches: mutable.Buffer[BatchInfo] = mutable.Buffer.empty
+    private var currentBatchIndex: Int = 0
+    private val currentBatchChunks: mutable.Buffer[ChunkInfo] = mutable.Buffer.empty
+
+    override def interceptCall[ReqT, RespT](
+        method: MethodDescriptor[ReqT, RespT],
+        callOptions: CallOptions,
+        next: Channel): ClientCall[ReqT, RespT] = {
+      new ForwardingClientCall.SimpleForwardingClientCall[ReqT, RespT](
+        next.newCall(method, callOptions)) {
+        override def start(
+            responseListener: ClientCall.Listener[RespT],
+            headers: io.grpc.Metadata): Unit = {
+          super.start(
+            new ForwardingClientCallListener.SimpleForwardingClientCallListener[RespT](
+              responseListener) {
+              override def onMessage(message: RespT): Unit = {
+                message match {
+                  case response: proto.ExecutePlanResponse if response.hasArrowBatch =>
+                    val arrowBatch = response.getArrowBatch
+                    // Track chunk information for every chunk
+                    currentBatchChunks += ChunkInfo(
+                      batchIndex = currentBatchIndex,
+                      chunkIndex = arrowBatch.getChunkIndex.toInt,
+                      numChunksInBatch = arrowBatch.getNumChunksInBatch.toInt,
+                      rowCount = arrowBatch.getRowCount,
+                      startOffset = arrowBatch.getStartOffset,
+                      dataSize = arrowBatch.getData.size())
+                    // When we receive the last chunk, create the BatchInfo
+                    if (currentBatchChunks.length == arrowBatch.getNumChunksInBatch) {
+                      batches += BatchInfo(
+                        batchIndex = currentBatchIndex,
+                        rowCount = arrowBatch.getRowCount,
+                        startOffset = arrowBatch.getStartOffset,
+                        chunks = currentBatchChunks.toList)
+                      currentBatchChunks.clear()
+                      currentBatchIndex += 1
+                    }
+                  case _ => // Not an ExecutePlanResponse with ArrowBatch, ignore
+                }
+                super.onMessage(message)
+              }
+            },
+            headers)
+        }
+      }
+    }
+
+    // Get all batch information
+    def getBatchInfos: Seq[BatchInfo] = batches.toSeq
+
+    def clear(): Unit = {
+      currentBatchIndex = 0
+      currentBatchChunks.clear()
+      batches.clear()
+    }
+  }
+
+  test("Arrow batch result chunking") {
+    // This test validates that the client can correctly reassemble chunked Arrow batches
+    // using SequenceInputStream as implemented in SparkResult.processResponses
+
+    // Two cases are tested here:
+    // (a) client preferred chunk size is set: the server should respect it
+    // (b) client preferred chunk size is not set: the server should use its own max chunk size
+    Seq((Some(1024), None), (None, Some(1024))).foreach {
+      case (preferredChunkSizeOpt, maxChunkSizeOpt) =>
+        // Create interceptor to capture chunk information
+        val arrowBatchInterceptor = new ArrowBatchInterceptor()
+
+        try {
+          // Set preferred chunk size if specified and add interceptor
+          preferredChunkSizeOpt match {
+            case Some(size) =>
+              spark = createSparkSession(
+                _.preferredArrowChunkSize(Some(size)).interceptor(arrowBatchInterceptor))
+            case None =>
+              spark = createSparkSession(_.interceptor(arrowBatchInterceptor))
+          }
+          // Set server max chunk size if specified
+          maxChunkSizeOpt.foreach { size =>
+            spark.conf.set("spark.connect.session.resultChunking.maxChunkSize", size.toString)
+          }
+
+          val sqlQuery =
+            "select id, CAST(id + 0.5 AS DOUBLE) as double_val from range(0, 2000, 1, 4)"
+
+          // Execute the query using withResult to access SparkResult object
+          spark.sql(sqlQuery).withResult { result =>
+            // Verify the results are correct and complete
+            assert(result.length == 2000)
+
+            // Get batch information from interceptor
+            val batchInfos = arrowBatchInterceptor.getBatchInfos
+
+            // Assert there are 4 batches (partitions) in total
+            assert(batchInfos.length == 4)
+
+            // Validate chunk information for each batch
+            val maxChunkSize = preferredChunkSizeOpt.orElse(maxChunkSizeOpt).get
+            batchInfos.foreach { batch =>
+              // In this example, the max chunk size is set to a small value,
+              // so each Arrow batch should be split into multiple chunks
+              assert(batch.totalChunks > 5)
+              assert(batch.chunks.nonEmpty)
+              assert(batch.chunks.length == batch.totalChunks)
+              batch.chunks.zipWithIndex.foreach { case (chunk, expectedIndex) =>
+                assert(chunk.chunkIndex == expectedIndex)
+                assert(chunk.numChunksInBatch == batch.totalChunks)
+                assert(chunk.rowCount == batch.rowCount)
+                assert(chunk.startOffset == batch.startOffset)
+                assert(chunk.dataSize > 0)
+                assert(chunk.dataSize <= maxChunkSize)
+              }
+            }
+
+            // Validate data integrity across the range to ensure chunking didn't corrupt anything
+            val rows = result.toArray
+            var expectedId = 0L
+            rows.foreach { row =>
+              assert(row.getLong(0) == expectedId)
+              val expectedDouble = expectedId + 0.5
+              val actualDouble = row.getDouble(1)
+              assert(math.abs(actualDouble - expectedDouble) < 0.001)
+              expectedId += 1
+            }
+          }
+        } finally {
+          // Clean up configurations
+          maxChunkSizeOpt.foreach { _ =>
+            spark.conf.unset("spark.connect.session.resultChunking.maxChunkSize")
+          }
+          arrowBatchInterceptor.clear()
+        }
+    }
+  }
 }

 private[sql] case class ClassData(a: String, b: Int)
diff --git a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/test/RemoteSparkSession.scala b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/test/RemoteSparkSession.scala
index efb6c721876c..6d8d2edcf082 100644
--- a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/test/RemoteSparkSession.scala
+++ b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/test/RemoteSparkSession.scala
@@ -180,26 +180,35 @@ object SparkConnectServerUtils {
         val fileName = e.substring(e.lastIndexOf(File.separatorChar) + 1)
         fileName.endsWith(".jar") &&
         (fileName.startsWith("scalatest") ||
           fileName.startsWith("scalactic") ||
-          (fileName.startsWith("spark-catalyst") && fileName.endsWith("-tests")))
+          (fileName.startsWith("spark-catalyst") && fileName.endsWith("-tests")) ||
+          fileName.startsWith("grpc-"))
       }
       .map(e => Paths.get(e).toUri)
     spark.client.artifactManager.addArtifacts(jars.toImmutableArraySeq)
   }

   def createSparkSession(): SparkSession = {
+    createSparkSession(identity)
+  }
+
+  def createSparkSession(
+      customBuilderFunc: SparkConnectClient.Builder => SparkConnectClient.Builder)
+      : SparkSession = {
     SparkConnectServerUtils.start()
+    var builder = SparkConnectClient
+      .builder()
+      .userId("test")
+      .port(port)
+      .retryPolicy(
+        RetryPolicy
+          .defaultPolicy()
+          .copy(maxRetries = Some(10), maxBackoff = Some(FiniteDuration(30, "s"))))
+
+    builder = customBuilderFunc(builder)
     val spark = SparkSession
       .builder()
-      .client(
-        SparkConnectClient
-          .builder()
-          .userId("test")
-          .port(port)
-          .retryPolicy(RetryPolicy
-            .defaultPolicy()
-            .copy(maxRetries = Some(10), maxBackoff = Some(FiniteDuration(30, "s"))))
-          .build())
+      .client(builder.build())
       .create()

     // Execute an RPC which will get retried until the server is up.
diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
index fa32eba91eb2..e5fd16a7c261 100644
--- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
+++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
@@ -138,6 +138,22 @@ private[sql] class SparkConnectClient(
       .setSessionId(sessionId)
       .setClientType(userAgent)
       .addAllTags(tags.get.toSeq.asJava)
+
+    // Add request option to allow result chunking.
+    if (configuration.allowArrowBatchChunking) {
+      val chunkingOptionsBuilder = proto.ResultChunkingOptions
+        .newBuilder()
+        .setAllowArrowBatchChunking(true)
+      configuration.preferredArrowChunkSize.foreach { size =>
+        chunkingOptionsBuilder.setPreferredArrowChunkSize(size)
+      }
+      request.addRequestOptions(
+        proto.ExecutePlanRequest.RequestOption
+          .newBuilder()
+          .setResultChunkingOptions(chunkingOptionsBuilder.build())
+          .build())
+    }
+
     serverSideSessionId.foreach(session => request.setClientObservedServerSideSessionId(session))
     operationId.foreach { opId =>
       require(
@@ -332,6 +348,16 @@ private[sql] class SparkConnectClient(

   def copy(): SparkConnectClient = configuration.toSparkConnectClient

+  /**
+   * Returns whether arrow batch chunking is allowed.
+   */
+  def allowArrowBatchChunking: Boolean = configuration.allowArrowBatchChunking
+
+  /**
+   * Returns the preferred arrow chunk size in bytes.
+   */
+  def preferredArrowChunkSize: Option[Int] = configuration.preferredArrowChunkSize
+
   /**
    * Add a single artifact to the client session.
    *
@@ -757,6 +783,21 @@ object SparkConnectClient {
       this
     }

+    def allowArrowBatchChunking(allow: Boolean): Builder = {
+      _configuration = _configuration.copy(allowArrowBatchChunking = allow)
+      this
+    }
+
+    def allowArrowBatchChunking: Boolean = _configuration.allowArrowBatchChunking
+
+    def preferredArrowChunkSize(size: Option[Int]): Builder = {
+      size.foreach(s => require(s > 0, "preferredArrowChunkSize must be positive"))
+      _configuration = _configuration.copy(preferredArrowChunkSize = size)
+      this
+    }
+
+    def preferredArrowChunkSize: Option[Int] = _configuration.preferredArrowChunkSize
+
     def build(): SparkConnectClient = _configuration.toSparkConnectClient
   }

@@ -801,7 +842,9 @@ object SparkConnectClient {
       interceptors: List[ClientInterceptor] = List.empty,
       sessionId: Option[String] = None,
       grpcMaxMessageSize: Int = ConnectCommon.CONNECT_GRPC_MAX_MESSAGE_SIZE,
-      grpcMaxRecursionLimit: Int = ConnectCommon.CONNECT_GRPC_MARSHALLER_RECURSION_LIMIT) {
+      grpcMaxRecursionLimit: Int = ConnectCommon.CONNECT_GRPC_MARSHALLER_RECURSION_LIMIT,
+      allowArrowBatchChunking: Boolean = true,
+      preferredArrowChunkSize: Option[Int] = None) {

     private def isLocal = host.equals("localhost")

diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala
index ef55edd10c8a..43265e55a0ca 100644
--- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala
+++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala
@@ -16,18 +16,21 @@
  */
 package org.apache.spark.sql.connect.client

+import java.io.SequenceInputStream
 import java.lang.ref.Cleaner
 import java.util.Objects

 import scala.collection.mutable
 import scala.jdk.CollectionConverters._

+import com.google.protobuf.ByteString
 import org.apache.arrow.memory.BufferAllocator
 import org.apache.arrow.vector.ipc.message.{ArrowMessage, ArrowRecordBatch}
 import org.apache.arrow.vector.types.pojo

 import org.apache.spark.connect.proto
 import org.apache.spark.connect.proto.ExecutePlanResponse.ObservedMetrics
+import org.apache.spark.internal.Logging
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, RowEncoder}
 import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ProductEncoder, UnboundRowEncoder}
@@ -42,7 +45,8 @@ private[sql] class SparkResult[T](
     allocator: BufferAllocator,
     encoder: AgnosticEncoder[T],
     timeZoneId: String)
-    extends AutoCloseable { self =>
+    extends AutoCloseable
+    with Logging { self =>

   case class StageInfo(
       stageId: Long,
@@ -118,6 +122,7 @@ private[sql] class SparkResult[T](
       stopOnFirstNonEmptyResponse: Boolean = false): Boolean = {
     var nonEmpty = false
     var stop = false
+    val arrowBatchChunksToAssemble = mutable.Buffer.empty[ByteString]
     while (!stop && responses.hasNext) {
       val response = responses.next()

@@ -151,55 +156,96 @@ private[sql] class SparkResult[T](
         stop |= stopOnSchema
       }
       if (response.hasArrowBatch) {
-        val ipcStreamBytes = response.getArrowBatch.getData
-        val expectedNumRows = response.getArrowBatch.getRowCount
-        val reader = new MessageIterator(ipcStreamBytes.newInput(), allocator)
-        if (arrowSchema == null) {
-          arrowSchema = reader.schema
-          stop |= stopOnArrowSchema
-        } else if (arrowSchema != reader.schema) {
-          throw new IllegalStateException(
-            s"""Schema Mismatch between expected and received schema:
-               |=== Expected Schema ===
-               |$arrowSchema
-               |=== Received Schema ===
-               |${reader.schema}
-               |""".stripMargin)
-        }
-        if (structType == null) {
-          // If the schema is not available yet, fallback to the arrow schema.
-          structType = ArrowUtils.fromArrowSchema(reader.schema)
-        }
-        if (response.getArrowBatch.hasStartOffset) {
-          val expectedStartOffset = response.getArrowBatch.getStartOffset
-          if (numRecords != expectedStartOffset) {
+        val arrowBatch = response.getArrowBatch
+        logDebug(
+          s"Received arrow batch rows=${arrowBatch.getRowCount} " +
+            s"Number of chunks in batch=${arrowBatch.getNumChunksInBatch} " +
+            s"Chunk index=${arrowBatch.getChunkIndex} " +
+            s"size=${arrowBatch.getData.size()}")
+
+        if (arrowBatchChunksToAssemble.nonEmpty) {
+          // Expect next chunk of the same batch
+          if (arrowBatch.getChunkIndex != arrowBatchChunksToAssemble.size) {
             throw new IllegalStateException(
-              s"Expected arrow batch to start at row offset $numRecords in results, " +
-                s"but received arrow batch starting at offset $expectedStartOffset.")
+              s"Expected chunk index ${arrowBatchChunksToAssemble.size} of the " +
+                s"arrow batch but got ${arrowBatch.getChunkIndex}.")
           }
-        }
-        var numRecordsInBatch = 0
-        val messages = Seq.newBuilder[ArrowMessage]
-        while (reader.hasNext) {
-          val message = reader.next()
-          message match {
-            case batch: ArrowRecordBatch =>
-              numRecordsInBatch += batch.getLength
-            case _ =>
+        } else {
+          // Expect next batch
+          if (arrowBatch.hasStartOffset) {
+            val expectedStartOffset = arrowBatch.getStartOffset
+            if (numRecords != expectedStartOffset) {
+              throw new IllegalStateException(
+                s"Expected arrow batch to start at row offset $numRecords in results, " +
+                  s"but received arrow batch starting at offset $expectedStartOffset.")
+            }
+          }
+          if (arrowBatch.getChunkIndex != 0) {
+            throw new IllegalStateException(
+              s"Expected chunk index 0 of the next arrow batch " +
+                s"but got ${arrowBatch.getChunkIndex}.")
           }
-          messages += message
-        }
-        if (numRecordsInBatch != expectedNumRows) {
-          throw new IllegalStateException(
-            s"Expected $expectedNumRows rows in arrow batch but got $numRecordsInBatch.")
         }
-        // Skip the entire result if it is empty.
-        if (numRecordsInBatch > 0) {
-          numRecords += numRecordsInBatch
-          resultMap.put(nextResultIndex, (reader.bytesRead, messages.result()))
-          nextResultIndex += 1
-          nonEmpty |= true
-          stop |= stopOnFirstNonEmptyResponse
+
+        arrowBatchChunksToAssemble += arrowBatch.getData
+
+        // Assemble the chunks to an arrow batch to process if
+        // (a) chunking is not enabled (numChunksInBatch is not set or is 0,
+        //     in this case, it is the single chunk in the batch)
+        // (b) or the client has received all chunks of the batch.
+        if (!arrowBatch.hasNumChunksInBatch ||
+          arrowBatch.getNumChunksInBatch == 0 ||
+          arrowBatchChunksToAssemble.size == arrowBatch.getNumChunksInBatch) {
+
+          val numChunks = arrowBatchChunksToAssemble.size
+          val inputStreams =
+            arrowBatchChunksToAssemble.map(_.newInput()).iterator.asJavaEnumeration
+          val input = new SequenceInputStream(inputStreams)
+          arrowBatchChunksToAssemble.clear()
+          logDebug(s"Assembling arrow batch from $numChunks chunks.")
+
+          val expectedNumRows = arrowBatch.getRowCount
+          val reader = new MessageIterator(input, allocator)
+          if (arrowSchema == null) {
+            arrowSchema = reader.schema
+            stop |= stopOnArrowSchema
+          } else if (arrowSchema != reader.schema) {
+            throw new IllegalStateException(
+              s"""Schema Mismatch between expected and received schema:
+                 |=== Expected Schema ===
+                 |$arrowSchema
+                 |=== Received Schema ===
+                 |${reader.schema}
+                 |""".stripMargin)
+          }
+          if (structType == null) {
+            // If the schema is not available yet, fallback to the arrow schema.
+            structType = ArrowUtils.fromArrowSchema(reader.schema)
+          }
+
+          var numRecordsInBatch = 0
+          val messages = Seq.newBuilder[ArrowMessage]
+          while (reader.hasNext) {
+            val message = reader.next()
+            message match {
+              case batch: ArrowRecordBatch =>
+                numRecordsInBatch += batch.getLength
+              case _ =>
+            }
+            messages += message
+          }
+          if (numRecordsInBatch != expectedNumRows) {
+            throw new IllegalStateException(
+              s"Expected $expectedNumRows rows in arrow batch but got $numRecordsInBatch.")
+          }
+          // Skip the entire result if it is empty.
+          if (numRecordsInBatch > 0) {
+            numRecords += numRecordsInBatch
+            resultMap.put(nextResultIndex, (reader.bytesRead, messages.result()))
+            nextResultIndex += 1
+            nonEmpty |= true
+            stop |= stopOnFirstNonEmptyResponse
+          }
         }
       }
     }
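Note (not part of the patch): a minimal, hypothetical usage sketch of the client-side options introduced above. Only the builder methods allowArrowBatchChunking and preferredArrowChunkSize, their defaults, and the SparkSession.builder().client(...).create() pattern come from the diff itself; the import paths, connection string, port, preferred size, and query are illustrative assumptions for a Spark 4.x Connect Scala client talking to a locally running server.

  import org.apache.spark.sql.connect.SparkSession
  import org.apache.spark.sql.connect.client.SparkConnectClient

  // Build a client that permits result chunking and asks the server to keep each
  // ExecutePlanResponse arrow batch chunk near 1 MiB. Chunking is already allowed
  // by default (allowArrowBatchChunking = true); it is set explicitly here for clarity.
  val client = SparkConnectClient
    .builder()
    .connectionString("sc://localhost:15002") // assumed local Spark Connect endpoint
    .allowArrowBatchChunking(true)
    .preferredArrowChunkSize(Some(1024 * 1024)) // preferred chunk size in bytes
    .build()

  val spark = SparkSession.builder().client(client).create()

  // Chunked arrow batches are reassembled transparently in SparkResult.processResponses,
  // so results look the same to the caller whether or not the server chunked them.
  spark.sql("select id from range(10)").collect()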