4 changes: 2 additions & 2 deletions .github/workflows/build_and_test.yml
@@ -243,7 +243,7 @@ jobs:
- name: Install Python packages (Python 3.8)
if: (contains(matrix.modules, 'sql') && !contains(matrix.modules, 'sql-'))
run: |
python3.8 -m pip install 'numpy>=1.20.0' pyarrow pandas scipy unittest-xml-reporting 'grpcio==1.48.1' 'protobuf==4.21.6'
python3.8 -m pip install 'numpy>=1.20.0' pyarrow pandas scipy unittest-xml-reporting 'grpcio==1.48.1' 'protobuf==3.19.4'
python3.8 -m pip list
# Run the tests.
- name: Run tests
@@ -589,7 +589,7 @@ jobs:
# See also https://issues.apache.org/jira/browse/SPARK-38279.
python3.9 -m pip install 'sphinx<3.1.0' mkdocs pydata_sphinx_theme ipython nbsphinx numpydoc 'jinja2<3.0.0' 'markupsafe==2.0.1' 'pyzmq<24.0.0'
python3.9 -m pip install ipython_genutils # See SPARK-38517
python3.9 -m pip install sphinx_plotly_directive 'numpy>=1.20.0' pyarrow pandas 'plotly>=4.8' 'grpcio==1.48.1' 'protobuf==4.21.6' 'mypy-protobuf==3.3.0'
python3.9 -m pip install sphinx_plotly_directive 'numpy>=1.20.0' pyarrow pandas 'plotly>=4.8' 'grpcio==1.48.1' 'protobuf==3.19.4' 'mypy-protobuf==3.3.0'
python3.9 -m pip install 'docutils<0.18.0' # See SPARK-39421
apt-get update -y
apt-get install -y ruby ruby-dev
14 changes: 1 addition & 13 deletions connector/connect/src/main/protobuf/spark/connect/base.proto
@@ -139,11 +139,7 @@ message ExecutePlanRequest {
message ExecutePlanResponse {
string client_id = 1;

// Result type
oneof result_type {
ArrowBatch arrow_batch = 2;
JSONBatch json_batch = 3;
}
ArrowBatch arrow_batch = 2;

// Metrics for the query execution. Typically, this field is only present in the last
// batch of results and then represents the overall state of the query execution.
@@ -155,14 +151,6 @@ message ExecutePlanResponse {
bytes data = 2;
}

// Message type when the result is returned as JSON. This is essentially a bulk wrapper
// for the JSON result of a Spark DataFrame. All rows are returned in the JSON record format
// of `{col -> row}`.
message JSONBatch {
int64 row_count = 1;
bytes data = 2;
}

message Metrics {

repeated MetricObject metrics = 1;
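With the JSONBatch message and the result_type oneof removed, arrow_batch is now a plain field on ExecutePlanResponse. For orientation, a minimal server-side sketch of filling it in — the builder calls mirror the Scala handler change further down in this diff, while the helper name arrowResponse is made up for illustration:

import com.google.protobuf.ByteString
import org.apache.spark.connect.proto

// Wrap one serialized Arrow batch (its bytes plus row count) into a response message.
def arrowResponse(clientId: String, bytes: Array[Byte], rowCount: Long): proto.ExecutePlanResponse = {
  val batch = proto.ExecutePlanResponse.ArrowBatch
    .newBuilder()
    .setRowCount(rowCount)
    .setData(ByteString.copyFrom(bytes))
    .build()
  proto.ExecutePlanResponse
    .newBuilder()
    .setClientId(clientId)
    .setArrowBatch(batch)
    .build()
}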
@@ -18,12 +18,10 @@
package org.apache.spark.sql.connect.service

import scala.collection.JavaConverters._
import scala.util.control.NonFatal

import com.google.protobuf.ByteString
import io.grpc.stub.StreamObserver

import org.apache.spark.SparkException
import org.apache.spark.connect.proto
import org.apache.spark.connect.proto.{ExecutePlanRequest, ExecutePlanResponse}
import org.apache.spark.internal.Logging
@@ -34,7 +32,6 @@ import org.apache.spark.sql.execution.{SparkPlan, SQLExecution}
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AdaptiveSparkPlanHelper, QueryStageExec}
import org.apache.spark.sql.execution.arrow.ArrowConverters
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.ThreadUtils

class SparkConnectStreamHandler(responseObserver: StreamObserver[ExecutePlanResponse])
extends Logging {
@@ -57,75 +54,7 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[ExecutePlanResp
// Extract the plan from the request and convert it to a logical plan
val planner = new SparkConnectPlanner(session)
val dataframe = Dataset.ofRows(session, planner.transformRelation(request.getPlan.getRoot))
try {
processAsArrowBatches(request.getClientId, dataframe)
} catch {
case e: Exception =>
logWarning(e.getMessage)
processAsJsonBatches(request.getClientId, dataframe)
}
}

def processAsJsonBatches(clientId: String, dataframe: DataFrame): Unit = {
// Only process up to 10MB of data.
val sb = new StringBuilder
var rowCount = 0
dataframe.toJSON
.collect()
.foreach(row => {

// There are a few cases to cover here.
// 1. The aggregated buffer size is larger than the MAX_BATCH_SIZE
// -> send the current batch and reset.
// 2. The aggregated buffer size is smaller than the MAX_BATCH_SIZE
// -> append the row to the buffer.
// 3. The row in question is larger than the MAX_BATCH_SIZE
// -> fail the query.

// Case 3. - Fail
if (row.size > MAX_BATCH_SIZE) {
throw SparkException.internalError(
s"Serialized row is larger than MAX_BATCH_SIZE: ${row.size} > ${MAX_BATCH_SIZE}")
}

// Case 1 - Flush and send.
if (sb.size + row.size > MAX_BATCH_SIZE) {
val response = proto.ExecutePlanResponse.newBuilder().setClientId(clientId)
val batch = proto.ExecutePlanResponse.JSONBatch
.newBuilder()
.setData(ByteString.copyFromUtf8(sb.toString()))
.setRowCount(rowCount)
.build()
response.setJsonBatch(batch)
responseObserver.onNext(response.build())
sb.clear()
sb.append(row)
rowCount = 1
} else {
// Case 2 - Append.
// Make sure to put the newline delimiters only between items and not at the end.
if (rowCount > 0) {
sb.append("\n")
}
sb.append(row)
rowCount += 1
}
})

// If the last batch is not empty, send out the data to the client.
if (sb.size > 0) {
val response = proto.ExecutePlanResponse.newBuilder().setClientId(clientId)
val batch = proto.ExecutePlanResponse.JSONBatch
.newBuilder()
.setData(ByteString.copyFromUtf8(sb.toString()))
.setRowCount(rowCount)
.build()
response.setJsonBatch(batch)
responseObserver.onNext(response.build())
}

responseObserver.onNext(sendMetricsToResponse(clientId, dataframe))
responseObserver.onCompleted()
processAsArrowBatches(request.getClientId, dataframe)
}

def processAsArrowBatches(clientId: String, dataframe: DataFrame): Unit = {
@@ -142,83 +71,20 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[ExecutePlanResp
var numSent = 0

if (numPartitions > 0) {
type Batch = (Array[Byte], Long)

val batches = rows.mapPartitionsInternal(
SparkConnectStreamHandler
.rowToArrowConverter(schema, maxRecordsPerBatch, maxBatchSize, timeZoneId))

val signal = new Object
val partitions = collection.mutable.Map.empty[Int, Array[Batch]]
var error: Throwable = null

val processPartition = (iter: Iterator[Batch]) => iter.toArray

// This callback is executed by the DAGScheduler thread.
// After fetching a partition, it inserts the partition into the Map, and then
// wakes up the main thread.
val resultHandler = (partitionId: Int, partition: Array[Batch]) => {
signal.synchronized {
partitions(partitionId) = partition
signal.notify()
}
()
}

val future = spark.sparkContext.submitJob(
rdd = batches,
processPartition = processPartition,
partitions = Seq.range(0, numPartitions),
resultHandler = resultHandler,
resultFunc = () => ())

// Collect errors and propagate them to the main thread.
future.onComplete { result =>
result.failed.foreach { throwable =>
signal.synchronized {
error = throwable
signal.notify()
}
}
}(ThreadUtils.sameThread)

// The main thread will wait until the 0-th partition is available,
// then send it to the client and wait for the next partition.
// Unlike the implementation of [[Dataset#collectAsArrowToPython]], it sends
// the arrow batches in the main thread to avoid the DAGScheduler thread being blocked
// by tasks not related to scheduling. This is particularly important if there are
// multiple users or clients running code at the same time.
var currentPartitionId = 0
while (currentPartitionId < numPartitions) {
val partition = signal.synchronized {
var result = partitions.remove(currentPartitionId)
while (result.isEmpty && error == null) {
signal.wait()
result = partitions.remove(currentPartitionId)
}
error match {
case NonFatal(e) =>
responseObserver.onError(error)
logError("Error while processing query.", e)
return
case fatal: Throwable => throw fatal
case null => result.get
}
}

partition.foreach { case (bytes, count) =>
val response = proto.ExecutePlanResponse.newBuilder().setClientId(clientId)
val batch = proto.ExecutePlanResponse.ArrowBatch
.newBuilder()
.setRowCount(count)
.setData(ByteString.copyFrom(bytes))
.build()
response.setArrowBatch(batch)
responseObserver.onNext(response.build())
numSent += 1
}

currentPartitionId += 1
batches.collect().foreach { case (bytes, count) =>
val response = proto.ExecutePlanResponse.newBuilder().setClientId(clientId)
val batch = proto.ExecutePlanResponse.ArrowBatch
.newBuilder()
.setRowCount(count)
.setData(ByteString.copyFrom(bytes))
.build()
response.setArrowBatch(batch)
responseObserver.onNext(response.build())
numSent += 1
}
}

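The net effect on SparkConnectStreamHandler is a single code path: build the DataFrame and stream Arrow batches, with no JSON fallback and no per-partition submitJob plumbing. A compressed sketch of that path using only the names visible in this diff — the enclosing method is not shown in the hunk, so handlePlan is just an illustrative label:

def handlePlan(session: SparkSession, request: ExecutePlanRequest): Unit = {
  // Extract the plan from the request, convert it to a DataFrame,
  // and reply with Arrow batches only.
  val planner = new SparkConnectPlanner(session)
  val dataframe = Dataset.ofRows(session, planner.transformRelation(request.getPlan.getRoot))
  processAsArrowBatches(request.getClientId, dataframe)
}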
@@ -285,6 +285,8 @@ private class PushBasedFetchHelper(
* 2. There is a failure when fetching remote shuffle chunks.
* 3. There is a failure when processing SuccessFetchResult which is for a shuffle chunk
* (local or remote).
* 4. There is a zero-size buffer when processing SuccessFetchResult for a shuffle chunk
* (local or remote).
*/
def initiateFallbackFetchForPushMergedBlock(
blockId: BlockId,
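Scenario 4 above is the new case: a zero-size buffer for a push-merged shuffle chunk no longer fails the fetch outright, because the original shuffle blocks behind the chunk are still available. A condensed sketch of the decision the ShuffleBlockFetcherIterator change below makes (buf, blockId, address, msg and the surrounding fetch loop are taken from that code):

// Zero-size buffers are never legitimate. For a push-merged shuffle chunk we can
// fall back to the original shuffle blocks; for any other block it is a fetch failure.
if (buf.size == 0) {
  if (blockId.isShuffleChunk) {
    logWarning(msg)
    pushBasedFetchHelper.initiateFallbackFetchForPushMergedBlock(blockId, address)
    result = null // loop again and wait for the fallback fetch results
  } else {
    throwFetchFailedException(blockId, mapIndex, address, new IOException(msg))
  }
}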
@@ -782,7 +782,7 @@ final class ShuffleBlockFetcherIterator(
logDebug("Number of requests in flight " + reqsInFlight)
}

if (buf.size == 0) {
val in = if (buf.size == 0) {
// We will never legitimately receive a zero-size block. All blocks with zero records
// have zero size and all zero-size blocks have no records (and hence should never
// have been requested in the first place). This statement relies on behaviors of the
@@ -798,38 +798,52 @@ final class ShuffleBlockFetcherIterator(
// since the last call.
val msg = s"Received a zero-size buffer for block $blockId from $address " +
s"(expectedApproxSize = $size, isNetworkReqDone=$isNetworkReqDone)"
throwFetchFailedException(blockId, mapIndex, address, new IOException(msg))
}

val in = try {
val bufIn = buf.createInputStream()
if (checksumEnabled) {
val checksum = ShuffleChecksumHelper.getChecksumByAlgorithm(checksumAlgorithm)
checkedIn = new CheckedInputStream(bufIn, checksum)
checkedIn
if (blockId.isShuffleChunk) {
// A zero-size block may come from nodes with hardware failures. For shuffle chunks,
// the original shuffle blocks that belong to that zero-size shuffle chunk are
// available, so we can opt to fall back immediately.
logWarning(msg)
pushBasedFetchHelper.initiateFallbackFetchForPushMergedBlock(blockId, address)
// Set result to null to trigger another iteration of the while loop to get either.
result = null
null
} else {
bufIn
throwFetchFailedException(blockId, mapIndex, address, new IOException(msg))
}
} catch {
// The exception could only be thrown by a local shuffle block
case e: IOException =>
assert(buf.isInstanceOf[FileSegmentManagedBuffer])
e match {
case ce: ClosedByInterruptException =>
logError("Failed to create input stream from local block, " +
ce.getMessage)
case e: IOException => logError("Failed to create input stream from local block", e)
}
buf.release()
if (blockId.isShuffleChunk) {
pushBasedFetchHelper.initiateFallbackFetchForPushMergedBlock(blockId, address)
// Set result to null to trigger another iteration of the while loop to get either.
result = null
null
} else {
try {
val bufIn = buf.createInputStream()
if (checksumEnabled) {
val checksum = ShuffleChecksumHelper.getChecksumByAlgorithm(checksumAlgorithm)
checkedIn = new CheckedInputStream(bufIn, checksum)
checkedIn
} else {
throwFetchFailedException(blockId, mapIndex, address, e)
bufIn
}
} catch {
// The exception could only be thrown by a local shuffle block
case e: IOException =>
assert(buf.isInstanceOf[FileSegmentManagedBuffer])
e match {
case ce: ClosedByInterruptException =>
logError("Failed to create input stream from local block, " +
ce.getMessage)
case e: IOException =>
logError("Failed to create input stream from local block", e)
}
buf.release()
if (blockId.isShuffleChunk) {
pushBasedFetchHelper.initiateFallbackFetchForPushMergedBlock(blockId, address)
// Set result to null to trigger another iteration of the while loop to get
// either.
result = null
null
} else {
throwFetchFailedException(blockId, mapIndex, address, e)
}
}
}

if (in != null) {
try {
input = streamWrapper(blockId, in)
@@ -1814,4 +1814,17 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
intercept[FetchFailedException] { iterator.next() }
}

test("SPARK-40872: fallback to original shuffle block when a push-merged shuffle chunk " +
"is zero-size") {
val blockManager = mock(classOf[BlockManager])
val localDirs = Array("local-dir")
val blocksByAddress = prepareForFallbackToLocalBlocks(
blockManager, Map(SHUFFLE_MERGER_IDENTIFIER -> localDirs))
val zeroSizeBuffer = createMockManagedBuffer(0)
doReturn(Seq({zeroSizeBuffer})).when(blockManager)
.getLocalMergedBlockData(ShuffleMergedBlockId(0, 0, 2), localDirs)
val iterator = createShuffleBlockIteratorWithDefaults(blocksByAddress,
blockManager = Some(blockManager), streamWrapperLimitSize = Some(100))
verifyLocalBlocksFromFallback(iterator)
}
}
4 changes: 2 additions & 2 deletions dev/deps/spark-deps-hadoop-2-hive-2.3
@@ -268,6 +268,6 @@ xml-apis/1.4.01//xml-apis-1.4.01.jar
xmlenc/0.52//xmlenc-0.52.jar
xz/1.9//xz-1.9.jar
zjsonpatch/0.3.0//zjsonpatch-0.3.0.jar
zookeeper-jute/3.6.2//zookeeper-jute-3.6.2.jar
zookeeper/3.6.2//zookeeper-3.6.2.jar
zookeeper-jute/3.6.3//zookeeper-jute-3.6.3.jar
zookeeper/3.6.3//zookeeper-3.6.3.jar
zstd-jni/1.5.2-5//zstd-jni-1.5.2-5.jar
4 changes: 2 additions & 2 deletions dev/deps/spark-deps-hadoop-3-hive-2.3
@@ -253,6 +253,6 @@ wildfly-openssl/1.0.7.Final//wildfly-openssl-1.0.7.Final.jar
xbean-asm9-shaded/4.22//xbean-asm9-shaded-4.22.jar
xz/1.9//xz-1.9.jar
zjsonpatch/0.3.0//zjsonpatch-0.3.0.jar
zookeeper-jute/3.6.2//zookeeper-jute-3.6.2.jar
zookeeper/3.6.2//zookeeper-3.6.2.jar
zookeeper-jute/3.6.3//zookeeper-jute-3.6.3.jar
zookeeper/3.6.3//zookeeper-3.6.3.jar
zstd-jni/1.5.2-5//zstd-jni-1.5.2-5.jar
2 changes: 2 additions & 0 deletions docs/sql-ref-syntax-qry-select.md
@@ -83,6 +83,8 @@ SELECT [ hints , ... ] [ ALL | DISTINCT ] { [ [ named_expression | regex_column_
Specifies a source of input for the query. It can be one of the following:
* Table relation
* [Join relation](sql-ref-syntax-qry-select-join.html)
* [Pivot relation](sql-ref-syntax-qry-select-pivot.html)
* [Unpivot relation](sql-ref-syntax-qry-select-unpivot.html)
* [Table-value function](sql-ref-syntax-qry-select-tvf.html)
* [Inline table](sql-ref-syntax-qry-select-inline-table.html)
* [ [LATERAL](sql-ref-syntax-qry-select-lateral-subquery.html) ] ( Subquery )
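For readers landing on the from_item list above, the two newly linked relation forms sit directly in the FROM clause. A rough illustration through spark.sql, assuming a SparkSession named spark; the table and column names are invented, and the authoritative grammar lives on the linked pivot/unpivot pages:

// PIVOT: turn the distinct quarter values into columns, aggregating amount.
spark.sql(
  """SELECT * FROM sales
    |PIVOT (SUM(amount) FOR quarter IN ('Q1' AS q1, 'Q2' AS q2))""".stripMargin)

// UNPIVOT: fold the q1/q2 columns back into (quarter, amount) rows.
spark.sql(
  """SELECT * FROM sales_wide
    |UNPIVOT (amount FOR quarter IN (q1, q2))""".stripMargin)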
2 changes: 1 addition & 1 deletion pom.xml
@@ -120,7 +120,7 @@
<hadoop.version>3.3.4</hadoop.version>
<protobuf.version>2.5.0</protobuf.version>
<yarn.version>${hadoop.version}</yarn.version>
<zookeeper.version>3.6.2</zookeeper.version>
<zookeeper.version>3.6.3</zookeeper.version>
<curator.version>2.13.0</curator.version>
<hive.group>org.apache.hive</hive.group>
<hive.classifier>core</hive.classifier>