diff --git a/common/utils/src/main/scala/org/apache/spark/util/SparkThreadUtils.scala b/common/utils/src/main/scala/org/apache/spark/util/SparkThreadUtils.scala
index a5e4cef1ec1a..8b2807a80dd1 100644
--- a/common/utils/src/main/scala/org/apache/spark/util/SparkThreadUtils.scala
+++ b/common/utils/src/main/scala/org/apache/spark/util/SparkThreadUtils.scala
@@ -42,10 +42,7 @@ private[spark] object SparkThreadUtils {
   @throws(classOf[SparkException])
   def awaitResult[T](awaitable: Awaitable[T], atMost: Duration): T = {
     try {
-      // `awaitPermission` is not actually used anywhere so it's safe to pass in null here.
-      // See SPARK-13747.
-      val awaitPermission = null.asInstanceOf[scala.concurrent.CanAwait]
-      awaitable.result(atMost)(awaitPermission)
+      awaitResultNoSparkExceptionConversion(awaitable, atMost)
     } catch {
       case e: SparkFatalException =>
         throw e.throwable
@@ -56,5 +53,12 @@ private[spark] object SparkThreadUtils {
         throw new SparkException("Exception thrown in awaitResult: ", t)
     }
   }
+
+  def awaitResultNoSparkExceptionConversion[T](awaitable: Awaitable[T], atMost: Duration): T = {
+    // `awaitPermission` is not actually used anywhere so it's safe to pass in null here.
+    // See SPARK-13747.
+    val awaitPermission = null.asInstanceOf[scala.concurrent.CanAwait]
+    awaitable.result(atMost)(awaitPermission)
+  }
   // scalastyle:on awaitresult
 }
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala
index d14caebe5b81..b0c4564130d3 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala
@@ -451,6 +451,33 @@ class SparkConnectClientSuite extends ConnectFunSuite with BeforeAndAfterEach {
     assert(countAttempted == 7)
   }
 
+  test("ArtifactManager retries errors") {
+    var attempt = 0
+
+    startDummyServer(0)
+    client = SparkConnectClient
+      .builder()
+      .connectionString(s"sc://localhost:${server.getPort}")
+      .interceptor(new ClientInterceptor {
+        override def interceptCall[ReqT, RespT](
+            methodDescriptor: MethodDescriptor[ReqT, RespT],
+            callOptions: CallOptions,
+            channel: Channel): ClientCall[ReqT, RespT] = {
+          attempt += 1
+          if (attempt <= 3) {
+            throw Status.UNAVAILABLE.withDescription("").asRuntimeException()
+          }
+
+          channel.newCall(methodDescriptor, callOptions)
+        }
+      })
+      .build()
+
+    val session = SparkSession.builder().client(client).create()
+    val artifactFilePath = commonResourcePath.resolve("artifact-tests")
+    session.addArtifact(artifactFilePath.resolve("smallClassFile.class").toString)
+  }
+
   test("SPARK-45871: Client execute iterator.toSeq consumes the reattachable iterator") {
     startDummyServer(0)
     client = SparkConnectClient
diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ArtifactManager.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ArtifactManager.scala
index 36bc60c7d63a..6eb59bd37574 100644
--- a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ArtifactManager.scala
+++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ArtifactManager.scala
@@ -31,10 +31,12 @@ import scala.util.control.NonFatal
 
 import Artifact._
 import com.google.protobuf.ByteString
+import io.grpc.StatusRuntimeException
 import io.grpc.stub.StreamObserver
 import org.apache.commons.codec.digest.DigestUtils.sha256Hex
 import org.apache.commons.lang3.StringUtils
 
+import org.apache.spark.SparkException
 import org.apache.spark.connect.proto
 import org.apache.spark.connect.proto.AddArtifactsResponse
 import org.apache.spark.connect.proto.AddArtifactsResponse.ArtifactSummary
@@ -63,6 +65,7 @@ class ArtifactManager(
   private val CHUNK_SIZE: Int = 32 * 1024
 
   private[this] val classFinders = new CopyOnWriteArrayList[ClassFinder]
+  private[this] val stubState = stub.stubState
 
   /**
    * Register a [[ClassFinder]] for dynamically generated classes.
@@ -228,6 +231,17 @@ class ArtifactManager(
       return
     }
 
+    try {
+      stubState.retryHandler.retry {
+        addArtifactsImpl(artifacts)
+      }
+    } catch {
+      case ex: StatusRuntimeException =>
+        throw new SparkException(ex.toString, ex.getCause)
+    }
+  }
+
+  private[client] def addArtifactsImpl(artifacts: Iterable[Artifact]): Unit = {
     val promise = Promise[Seq[ArtifactSummary]]()
     val responseHandler = new StreamObserver[proto.AddArtifactsResponse] {
       private val summaries = mutable.Buffer.empty[ArtifactSummary]
@@ -284,7 +298,10 @@ class ArtifactManager(
       writeBatch()
     }
     stream.onCompleted()
-    SparkThreadUtils.awaitResult(promise.future, Duration.Inf)
+    // Don't convert to SparkException yet, for the sake of retrying.
+    // Retry policies are designed around the underlying gRPC StatusRuntimeExceptions.
+    // Convert to SparkException only if retrying fails.
+    SparkThreadUtils.awaitResultNoSparkExceptionConversion(promise.future, Duration.Inf)
     // TODO(SPARK-42658): Handle responses containing CRC failures.
   }
 
diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectStub.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectStub.scala
index 382bc8706955..187c2842a0bc 100644
--- a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectStub.scala
+++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectStub.scala
@@ -23,7 +23,7 @@ import org.apache.spark.connect.proto.{AddArtifactsRequest, AddArtifactsResponse
 
 private[client] class CustomSparkConnectStub(
     channel: ManagedChannel,
-    stubState: SparkConnectStubState) {
+    val stubState: SparkConnectStubState) {
 
   private val stub = SparkConnectServiceGrpc.newStub(channel)
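For reviewers, the control flow this patch introduces — retry on the raw gRPC error, convert to `SparkException` only once retries are exhausted — can be sketched as below. This is a simplified stand-in, not the actual `GrpcRetryHandler` behind `stubState.retryHandler`; the `maxRetries` limit, the `isRetriable` predicate, and the fixed sleep are illustrative assumptions, not the configured retry policy.

```scala
import io.grpc.{Status, StatusRuntimeException}

import org.apache.spark.SparkException

object RetryThenConvertSketch {
  // Illustrative policy: retry UNAVAILABLE a few times with a fixed pause.
  // The real policy is configured via SparkConnectStubState.
  private val maxRetries = 3

  private def isRetriable(e: StatusRuntimeException): Boolean =
    e.getStatus.getCode == Status.Code.UNAVAILABLE

  // Re-runs `fn` while it fails with a retriable gRPC status; once the
  // attempts are used up, the raw StatusRuntimeException propagates.
  private def retry[T](fn: => T, attemptsLeft: Int = maxRetries): T = {
    try {
      fn
    } catch {
      case e: StatusRuntimeException if isRetriable(e) && attemptsLeft > 0 =>
        Thread.sleep(100L) // a real handler would back off progressively
        retry(fn, attemptsLeft - 1)
    }
  }

  // Mirrors the shape of ArtifactManager.addArtifacts above: the
  // SparkException wrapping happens only after retrying has failed for good.
  def uploadWithRetries(upload: => Unit): Unit = {
    try {
      retry(upload)
    } catch {
      case ex: StatusRuntimeException =>
        throw new SparkException(ex.toString, ex.getCause)
    }
  }
}
```

This is also why the patch adds `awaitResultNoSparkExceptionConversion`: the existing `awaitResult` wraps every failure in a `SparkException` immediately, which would hide the gRPC status code from the retry policy. Deferring the conversion keeps the `StatusRuntimeException` visible to a check like `isRetriable` in the sketch.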