@@ -42,10 +42,7 @@ private[spark] object SparkThreadUtils {
   @throws(classOf[SparkException])
   def awaitResult[T](awaitable: Awaitable[T], atMost: Duration): T = {
     try {
-      // `awaitPermission` is not actually used anywhere so it's safe to pass in null here.
-      // See SPARK-13747.
-      val awaitPermission = null.asInstanceOf[scala.concurrent.CanAwait]
-      awaitable.result(atMost)(awaitPermission)
+      awaitResultNoSparkExceptionConversion(awaitable, atMost)
     } catch {
       case e: SparkFatalException =>
         throw e.throwable
@@ -56,5 +53,12 @@ private[spark] object SparkThreadUtils {
         throw new SparkException("Exception thrown in awaitResult: ", t)
     }
   }
+
+  def awaitResultNoSparkExceptionConversion[T](awaitable: Awaitable[T], atMost: Duration): T = {
+    // `awaitPermission` is not actually used anywhere so it's safe to pass in null here.
+    // See SPARK-13747.
+    val awaitPermission = null.asInstanceOf[scala.concurrent.CanAwait]
+    awaitable.result(atMost)(awaitPermission)
+  }
  // scalastyle:on awaitresult
}
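The split above lets retry code observe the raw failure before any wrapping: awaitResult keeps its SparkException-converting behavior, while the new entry point rethrows the original exception. A minimal sketch of the intended calling pattern (the failed promise and the status check are illustrative, not from this PR; SparkThreadUtils is spark-internal):

    import scala.concurrent.Promise
    import scala.concurrent.duration.Duration
    import io.grpc.{Status, StatusRuntimeException}
    import org.apache.spark.util.SparkThreadUtils

    val promise = Promise[Unit]()
    promise.failure(Status.UNAVAILABLE.asRuntimeException())
    try {
      // Unlike awaitResult, this rethrows the future's failure unwrapped.
      SparkThreadUtils.awaitResultNoSparkExceptionConversion(promise.future, Duration.Inf)
    } catch {
      case e: StatusRuntimeException =>
        // The gRPC status code is still inspectable here, so a retry policy
        // can decide to re-attempt before any SparkException conversion.
        assert(e.getStatus.getCode == Status.Code.UNAVAILABLE)
    }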
@@ -451,6 +451,33 @@ class SparkConnectClientSuite extends ConnectFunSuite with BeforeAndAfterEach {
     assert(countAttempted == 7)
   }
 
+  test("ArtifactManager retries errors") {
+    var attempt = 0
+
+    startDummyServer(0)
+    client = SparkConnectClient
+      .builder()
+      .connectionString(s"sc://localhost:${server.getPort}")
+      .interceptor(new ClientInterceptor {
+        override def interceptCall[ReqT, RespT](
+            methodDescriptor: MethodDescriptor[ReqT, RespT],
+            callOptions: CallOptions,
+            channel: Channel): ClientCall[ReqT, RespT] = {
+          attempt += 1
+          if (attempt <= 3) {
+            throw Status.UNAVAILABLE.withDescription("").asRuntimeException()
+          }
+
+          channel.newCall(methodDescriptor, callOptions)
+        }
+      })
+      .build()
+
+    val session = SparkSession.builder().client(client).create()
+    val artifactFilePath = commonResourcePath.resolve("artifact-tests")
+    session.addArtifact(artifactFilePath.resolve("smallClassFile.class").toString)
+  }
+
   test("SPARK-45871: Client execute iterator.toSeq consumes the reattachable iterator") {
     startDummyServer(0)
     client = SparkConnectClient
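The test relies on UNAVAILABLE being a retryable status: the interceptor fails the first three calls, so a fourth attempt must go through for addArtifact to succeed. A sketch of the kind of predicate a retry policy is assumed to apply (illustrative; the real policy lives behind the client's retry handler):

    import io.grpc.{Status, StatusRuntimeException}

    // Retry only transient transport failures; anything else fails fast.
    def isRetryable(t: Throwable): Boolean = t match {
      case e: StatusRuntimeException =>
        e.getStatus.getCode == Status.Code.UNAVAILABLE
      case _ => false
    }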
@@ -31,10 +31,12 @@ import scala.util.control.NonFatal

 import Artifact._
 import com.google.protobuf.ByteString
+import io.grpc.StatusRuntimeException
 import io.grpc.stub.StreamObserver
 import org.apache.commons.codec.digest.DigestUtils.sha256Hex
 import org.apache.commons.lang3.StringUtils
 
+import org.apache.spark.SparkException
 import org.apache.spark.connect.proto
 import org.apache.spark.connect.proto.AddArtifactsResponse
 import org.apache.spark.connect.proto.AddArtifactsResponse.ArtifactSummary
@@ -63,6 +65,7 @@ class ArtifactManager(
   private val CHUNK_SIZE: Int = 32 * 1024
 
   private[this] val classFinders = new CopyOnWriteArrayList[ClassFinder]
+  private[this] val stubState = stub.stubState
 
   /**
    * Register a [[ClassFinder]] for dynamically generated classes.
@@ -228,6 +231,17 @@ class ArtifactManager(
       return
     }
 
+    try {
+      stubState.retryHandler.retry {
+        addArtifactsImpl(artifacts)
+      }
+    } catch {
+      case ex: StatusRuntimeException =>
+        throw new SparkException(ex.toString, ex.getCause)
+    }
+  }
+
+  private[client] def addArtifactsImpl(artifacts: Iterable[Artifact]): Unit = {
     val promise = Promise[Seq[ArtifactSummary]]()
     val responseHandler = new StreamObserver[proto.AddArtifactsResponse] {
       private val summaries = mutable.Buffer.empty[ArtifactSummary]
@@ -284,7 +298,10 @@
       writeBatch()
     }
     stream.onCompleted()
-    SparkThreadUtils.awaitResult(promise.future, Duration.Inf)
+    // Don't convert to SparkException yet, for the sake of retrying:
+    // retry policies are designed around the underlying gRPC StatusRuntimeExceptions.
+    // Convert to SparkException only if retrying fails.
+    SparkThreadUtils.awaitResultNoSparkExceptionConversion(promise.future, Duration.Inf)
     // TODO(SPARK-42658): Handle responses containing CRC failures.
   }

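The net effect in ArtifactManager: the public entry point now wraps the streaming upload (addArtifactsImpl) in the stub's retry handler and converts to SparkException only once retries are exhausted. A generic sketch of the loop the handler is assumed to implement (illustrative; the real logic lives behind stubState.retryHandler):

    import io.grpc.{Status, StatusRuntimeException}

    def retry[T](maxRetries: Int)(fn: => T): T = {
      var attempt = 0
      while (true) {
        try {
          return fn()
        } catch {
          case e: StatusRuntimeException
              if attempt < maxRetries &&
                e.getStatus.getCode == Status.Code.UNAVAILABLE =>
            attempt += 1
            Thread.sleep(50L * attempt) // simple linear backoff
        }
      }
      throw new IllegalStateException("unreachable")
    }

Had addArtifactsImpl kept the old awaitResult, the StatusRuntimeException would already be wrapped in a SparkException by the time it reached such a loop, and the case guard above would never match.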
@@ -23,7 +23,7 @@ import org.apache.spark.connect.proto.{AddArtifactsRequest, AddArtifactsResponse

 private[client] class CustomSparkConnectStub(
     channel: ManagedChannel,
-    stubState: SparkConnectStubState) {
+    val stubState: SparkConnectStubState) {
 
   private val stub = SparkConnectServiceGrpc.newStub(channel)

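Adding val to the constructor parameter is what makes the field reachable from outside: in Scala, a plain constructor parameter stays private to the class, while a val parameter becomes a public member. A two-line illustration (names are made up):

    class Plain(state: Int)       // new Plain(1).state does not compile
    class Exposed(val state: Int) // new Exposed(1).state == 1

This small change is what lets ArtifactManager read stub.stubState and reuse the stub's shared retry handler.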