@@ -26,12 +26,14 @@ import scala.concurrent.{ExecutionContext, Future}
import scala.concurrent.duration.{DurationInt, FiniteDuration}
import scala.jdk.CollectionConverters._

import io.grpc.{CallOptions, Channel, ClientCall, ClientInterceptor, ForwardingClientCall, ForwardingClientCallListener, MethodDescriptor}
import org.apache.commons.io.output.TeeOutputStream
import org.scalactic.TolerantNumerics
import org.scalatest.PrivateMethodTester

import org.apache.spark.{SparkArithmeticException, SparkException, SparkUpgradeException}
import org.apache.spark.SparkBuildInfo.{spark_version => SPARK_VERSION}
import org.apache.spark.connect.proto
import org.apache.spark.internal.config.ConfigBuilder
import org.apache.spark.sql.{functions, AnalysisException, Observation, Row, SaveMode}
import org.apache.spark.sql.catalyst.analysis.{NamespaceAlreadyExistsException, NoSuchNamespaceException, TableAlreadyExistsException, TempTableAlreadyExistsException}
@@ -41,7 +43,7 @@ import org.apache.spark.sql.catalyst.parser.ParseException
import org.apache.spark.sql.connect.ConnectConversions._
import org.apache.spark.sql.connect.client.{RetryPolicy, SparkConnectClient, SparkResult}
import org.apache.spark.sql.connect.test.{ConnectFunSuite, IntegrationTestUtils, QueryTest, RemoteSparkSession, SQLHelper}
import org.apache.spark.sql.connect.test.SparkConnectServerUtils.port
import org.apache.spark.sql.connect.test.SparkConnectServerUtils.{createSparkSession, port}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SqlApiConf
import org.apache.spark.sql.types._
@@ -1848,6 +1850,161 @@ class ClientE2ETestSuite
checkAnswer(df, Seq.empty)
}
}

// Helper class to capture Arrow batch chunk information from gRPC responses
private class ArrowBatchInterceptor extends ClientInterceptor {
case class BatchInfo(
batchIndex: Int,
rowCount: Long,
startOffset: Long,
chunks: Seq[ChunkInfo]) {
def totalChunks: Int = chunks.length
}

case class ChunkInfo(
batchIndex: Int,
chunkIndex: Int,
numChunksInBatch: Int,
rowCount: Long,
startOffset: Long,
dataSize: Int)

private val batches: mutable.Buffer[BatchInfo] = mutable.Buffer.empty
private var currentBatchIndex: Int = 0
private val currentBatchChunks: mutable.Buffer[ChunkInfo] = mutable.Buffer.empty

override def interceptCall[ReqT, RespT](
method: MethodDescriptor[ReqT, RespT],
callOptions: CallOptions,
next: Channel): ClientCall[ReqT, RespT] = {
new ForwardingClientCall.SimpleForwardingClientCall[ReqT, RespT](
next.newCall(method, callOptions)) {
override def start(
responseListener: ClientCall.Listener[RespT],
headers: io.grpc.Metadata): Unit = {
super.start(
new ForwardingClientCallListener.SimpleForwardingClientCallListener[RespT](
responseListener) {
override def onMessage(message: RespT): Unit = {
message match {
case response: proto.ExecutePlanResponse if response.hasArrowBatch =>
val arrowBatch = response.getArrowBatch
// Record information about every chunk as it arrives
currentBatchChunks += ChunkInfo(
batchIndex = currentBatchIndex,
chunkIndex = arrowBatch.getChunkIndex.toInt,
numChunksInBatch = arrowBatch.getNumChunksInBatch.toInt,
rowCount = arrowBatch.getRowCount,
startOffset = arrowBatch.getStartOffset,
dataSize = arrowBatch.getData.size())
// When we receive the last chunk, create the BatchInfo
if (currentBatchChunks.length == arrowBatch.getNumChunksInBatch) {
batches += BatchInfo(
batchIndex = currentBatchIndex,
rowCount = arrowBatch.getRowCount,
startOffset = arrowBatch.getStartOffset,
chunks = currentBatchChunks.toList)
currentBatchChunks.clear()
currentBatchIndex += 1
}
case _ => // Not an ExecutePlanResponse with ArrowBatch, ignore
}
super.onMessage(message)
}
},
headers)
}
}
}

// Get all batch information
def getBatchInfos: Seq[BatchInfo] = batches.toSeq

def clear(): Unit = {
currentBatchIndex = 0
currentBatchChunks.clear()
batches.clear()
}
}

test("Arrow batch result chunking") {
// This test validates that the client can correctly reassemble chunked Arrow batches
// using SequenceInputStream as implemented in SparkResult.processResponses

// Two cases are tested here:
// (a) client preferred chunk size is set: the server should respect it
// (b) client preferred chunk size is not set: the server should use its own max chunk size
Seq((Some(1024), None), (None, Some(1024))).foreach {
case (preferredChunkSizeOpt, maxChunkSizeOpt) =>
// Create interceptor to capture chunk information
val arrowBatchInterceptor = new ArrowBatchInterceptor()

try {
// Set preferred chunk size if specified and add interceptor
preferredChunkSizeOpt match {
case Some(size) =>
spark = createSparkSession(
_.preferredArrowChunkSize(Some(size)).interceptor(arrowBatchInterceptor))
case None =>
spark = createSparkSession(_.interceptor(arrowBatchInterceptor))
}
// Set server max chunk size if specified
maxChunkSizeOpt.foreach { size =>
spark.conf.set("spark.connect.session.resultChunking.maxChunkSize", size.toString)
}

val sqlQuery =
"select id, CAST(id + 0.5 AS DOUBLE) as double_val from range(0, 2000, 1, 4)"

// Execute the query using withResult to access SparkResult object
spark.sql(sqlQuery).withResult { result =>
// Verify the results are correct and complete
assert(result.length == 2000)

// Get batch information from interceptor
val batchInfos = arrowBatchInterceptor.getBatchInfos

// Assert there are 4 batches (partitions) in total
assert(batchInfos.length == 4)

// Validate chunk information for each batch
val maxChunkSize = preferredChunkSizeOpt.orElse(maxChunkSizeOpt).get
batchInfos.foreach { batch =>
// In this example, the max chunk size is set to a small value,
// so each Arrow batch should be split into multiple chunks
assert(batch.totalChunks > 5)
assert(batch.chunks.nonEmpty)
assert(batch.chunks.length == batch.totalChunks)
batch.chunks.zipWithIndex.foreach { case (chunk, expectedIndex) =>
assert(chunk.chunkIndex == expectedIndex)
assert(chunk.numChunksInBatch == batch.totalChunks)
assert(chunk.rowCount == batch.rowCount)
assert(chunk.startOffset == batch.startOffset)
assert(chunk.dataSize > 0)
assert(chunk.dataSize <= maxChunkSize)
}
}

// Validate data integrity across the range to ensure chunking didn't corrupt anything
val rows = result.toArray
var expectedId = 0L
rows.foreach { row =>
assert(row.getLong(0) == expectedId)
val expectedDouble = expectedId + 0.5
val actualDouble = row.getDouble(1)
assert(math.abs(actualDouble - expectedDouble) < 0.001)
expectedId += 1
}
}
} finally {
// Clean up configurations
maxChunkSizeOpt.foreach { _ =>
spark.conf.unset("spark.connect.session.resultChunking.maxChunkSize")
}
arrowBatchInterceptor.clear()
}
}
}
}

private[sql] case class ClassData(a: String, b: Int)
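The reassembly the test exercises happens in SparkResult.processResponses, which, per the comment in the test, concatenates the ordered chunk payloads of a batch with a SequenceInputStream before the Arrow reader consumes them. A minimal, self-contained sketch of that idea, with an illustrative helper name rather than the actual SparkResult internals:

import java.io.{ByteArrayInputStream, InputStream, SequenceInputStream}
import scala.jdk.CollectionConverters._

// Hypothetical helper: concatenate the ordered chunk payloads of one Arrow batch
// into a single stream, so the Arrow reader sees one contiguous batch.
def concatenateChunks(chunkPayloads: Seq[Array[Byte]]): InputStream = {
  val streams = java.util.Collections.enumeration(
    chunkPayloads.map(bytes => new ByteArrayInputStream(bytes): InputStream).asJava)
  new SequenceInputStream(streams)
}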
@@ -187,19 +187,27 @@ object SparkConnectServerUtils {
}

def createSparkSession(): SparkSession = {
createSparkSession(identity)
}

def createSparkSession(
customBuilderFunc: SparkConnectClient.Builder => SparkConnectClient.Builder)
: SparkSession = {
SparkConnectServerUtils.start()

var builder = SparkConnectClient
.builder()
.userId("test")
.port(port)
.retryPolicy(
RetryPolicy
.defaultPolicy()
.copy(maxRetries = Some(10), maxBackoff = Some(FiniteDuration(30, "s"))))

builder = customBuilderFunc(builder)
val spark = SparkSession
.builder()
.client(
SparkConnectClient
.builder()
.userId("test")
.port(port)
.retryPolicy(RetryPolicy
.defaultPolicy()
.copy(maxRetries = Some(10), maxBackoff = Some(FiniteDuration(30, "s"))))
.build())
.client(builder.build())
.create()

// Execute an RPC which will get retried until the server is up.
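With the overload above, a test can adjust the client configuration before the session is built. A hedged usage sketch, relying on the allowArrowBatchChunking setter added to SparkConnectClient.Builder later in this diff:

// Illustrative only: build a test session whose client opts out of Arrow batch chunking.
val sparkNoChunking = SparkConnectServerUtils.createSparkSession(
  _.allowArrowBatchChunking(false))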
@@ -138,6 +138,22 @@ private[sql] class SparkConnectClient(
.setSessionId(sessionId)
.setClientType(userAgent)
.addAllTags(tags.get.toSeq.asJava)

// Add request option to allow result chunking.
if (configuration.allowArrowBatchChunking) {
val chunkingOptionsBuilder = proto.ResultChunkingOptions
.newBuilder()
.setAllowArrowBatchChunking(true)
configuration.preferredArrowChunkSize.foreach { size =>
chunkingOptionsBuilder.setPreferredArrowChunkSize(size)
}
request.addRequestOptions(
proto.ExecutePlanRequest.RequestOption
.newBuilder()
.setResultChunkingOptions(chunkingOptionsBuilder.build())
.build())
}

serverSideSessionId.foreach(session => request.setClientObservedServerSideSessionId(session))
operationId.foreach { opId =>
require(
@@ -332,6 +348,16 @@ private[sql] class SparkConnectClient(

def copy(): SparkConnectClient = configuration.toSparkConnectClient

/**
* Returns whether arrow batch chunking is allowed.
*/
def allowArrowBatchChunking: Boolean = configuration.allowArrowBatchChunking

/**
* Returns the preferred arrow chunk size in bytes.
*/
def preferredArrowChunkSize: Option[Int] = configuration.preferredArrowChunkSize

/**
* Add a single artifact to the client session.
*
@@ -757,6 +783,21 @@ object SparkConnectClient {
this
}

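/**
 * Sets whether arrow batch chunking is allowed (enabled by default).
 */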
def allowArrowBatchChunking(allow: Boolean): Builder = {
_configuration = _configuration.copy(allowArrowBatchChunking = allow)
this
}

def allowArrowBatchChunking: Boolean = _configuration.allowArrowBatchChunking

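/**
 * Sets the preferred arrow chunk size in bytes; must be positive if defined.
 */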
def preferredArrowChunkSize(size: Option[Int]): Builder = {
size.foreach(s => require(s > 0, "preferredArrowChunkSize must be positive"))
_configuration = _configuration.copy(preferredArrowChunkSize = size)
this
}

def preferredArrowChunkSize: Option[Int] = _configuration.preferredArrowChunkSize

def build(): SparkConnectClient = _configuration.toSparkConnectClient
}

@@ -801,7 +842,9 @@ object SparkConnectClient {
interceptors: List[ClientInterceptor] = List.empty,
sessionId: Option[String] = None,
grpcMaxMessageSize: Int = ConnectCommon.CONNECT_GRPC_MAX_MESSAGE_SIZE,
grpcMaxRecursionLimit: Int = ConnectCommon.CONNECT_GRPC_MARSHALLER_RECURSION_LIMIT) {
grpcMaxRecursionLimit: Int = ConnectCommon.CONNECT_GRPC_MARSHALLER_RECURSION_LIMIT,
allowArrowBatchChunking: Boolean = true,
preferredArrowChunkSize: Option[Int] = None) {

private def isLocal = host.equals("localhost")
