diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/artifact/SparkConnectArtifactManager.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/artifact/SparkConnectArtifactManager.scala
index 804c314ce67a..ad551c4b0f54 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/artifact/SparkConnectArtifactManager.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/artifact/SparkConnectArtifactManager.scala
@@ -26,6 +26,7 @@ import javax.ws.rs.core.UriBuilder
 import scala.jdk.CollectionConverters._
 import scala.reflect.ClassTag
 
+import io.grpc.Status
 import org.apache.commons.io.{FilenameUtils, FileUtils}
 import org.apache.hadoop.fs.{LocalFileSystem, Path => FSPath}
 
@@ -125,11 +126,18 @@ class SparkConnectArtifactManager(sessionHolder: SessionHolder) extends Logging
     } else {
       val target = ArtifactUtils.concatenatePaths(artifactPath, remoteRelativePath)
       Files.createDirectories(target.getParent)
-      // Disallow overwriting non-classfile artifacts
+
+      // Disallow overwriting with modified version
       if (Files.exists(target)) {
-        throw new RuntimeException(
-          s"Duplicate Artifact: $remoteRelativePath. " +
+        // makes the query idempotent
+        if (FileUtils.contentEquals(target.toFile, serverLocalStagingPath.toFile)) {
+          return
+        }
+
+        throw Status.ALREADY_EXISTS
+          .withDescription(s"Duplicate Artifact: $remoteRelativePath. " +
             "Artifacts cannot be overwritten.")
+          .asRuntimeException()
       }
       Files.move(serverLocalStagingPath, target)
 
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAddArtifactsHandler.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAddArtifactsHandler.scala
index e424331e7617..d9de2a8094d5 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAddArtifactsHandler.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAddArtifactsHandler.scala
@@ -24,6 +24,7 @@ import scala.collection.mutable
 import scala.util.control.NonFatal
 
 import com.google.common.io.CountingOutputStream
+import io.grpc.StatusRuntimeException
 import io.grpc.stub.StreamObserver
 
 import org.apache.spark.connect.proto
@@ -80,7 +81,7 @@ class SparkConnectAddArtifactsHandler(val responseObserver: StreamObserver[AddAr
   }
 
   override def onError(throwable: Throwable): Unit = {
-    Utils.deleteRecursively(stagingDir.toFile)
+    cleanUpStagedArtifacts()
     responseObserver.onError(throwable)
   }
 
@@ -114,16 +115,20 @@ class SparkConnectAddArtifactsHandler(val responseObserver: StreamObserver[AddAr
   protected def cleanUpStagedArtifacts(): Unit = Utils.deleteRecursively(stagingDir.toFile)
 
   override def onCompleted(): Unit = {
-    val artifactSummaries = flushStagedArtifacts()
-    // Add the artifacts to the session and return the summaries to the client.
-    val builder = proto.AddArtifactsResponse.newBuilder()
-    artifactSummaries.foreach(summary => builder.addArtifacts(summary))
-    // Delete temp dir
-    cleanUpStagedArtifacts()
-
-    // Send the summaries and close
-    responseObserver.onNext(builder.build())
-    responseObserver.onCompleted()
+    try {
+      val artifactSummaries = flushStagedArtifacts()
+      // Add the artifacts to the session and return the summaries to the client.
+      val builder = proto.AddArtifactsResponse.newBuilder()
+      artifactSummaries.foreach(summary => builder.addArtifacts(summary))
+      // Delete temp dir
+      cleanUpStagedArtifacts()
+
+      // Send the summaries and close
+      responseObserver.onNext(builder.build())
+      responseObserver.onCompleted()
+    } catch {
+      case e: StatusRuntimeException => onError(e)
+    }
   }
 
   /**
diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/artifact/ArtifactManagerSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/artifact/ArtifactManagerSuite.scala
index fa3b7d52379c..f34c4f770885 100644
--- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/artifact/ArtifactManagerSuite.scala
+++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/artifact/ArtifactManagerSuite.scala
@@ -20,6 +20,7 @@ import java.nio.charset.StandardCharsets
 import java.nio.file.{Files, Paths}
 import java.util.UUID
 
+import io.grpc.StatusRuntimeException
 import org.apache.commons.io.FileUtils
 
 import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkException, SparkFunSuite}
@@ -150,6 +151,28 @@ class ArtifactManagerSuite extends SharedSparkSession with ResourceHelper {
     }
   }
 
+  test("Add artifact idempotency") {
+    val remotePath = Paths.get("pyfiles/abc.zip")
+
+    withTempPath { path =>
+      Files.write(path.toPath, "test".getBytes(StandardCharsets.UTF_8))
+      artifactManager.addArtifact(remotePath, path.toPath, None)
+    }
+
+    withTempPath { path =>
+      // subsequent call succeeds
+      Files.write(path.toPath, "test".getBytes(StandardCharsets.UTF_8))
+      artifactManager.addArtifact(remotePath, path.toPath, None)
+    }
+
+    withTempPath { path =>
+      Files.write(path.toPath, "updated file".getBytes(StandardCharsets.UTF_8))
+      assertThrows[StatusRuntimeException] {
+        artifactManager.addArtifact(remotePath, path.toPath, None)
+      }
+    }
+  }
+
   test("SPARK-43790: Forward artifact file to cloud storage path") {
     val copyDir = Utils.createTempDir().toPath
     val destFSDir = Utils.createTempDir().toPath
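Taken together, these changes make artifact uploads idempotent at the storage layer and turn a genuine name clash into a gRPC ALREADY_EXISTS status, which SparkConnectAddArtifactsHandler.onCompleted catches and forwards through onError. Below is a minimal standalone sketch of that write pattern, using the same commons-io and grpc-api calls as the patch; the object and method names are illustrative only, not part of the change:

import java.nio.file.{Files, Path}

import io.grpc.Status
import org.apache.commons.io.FileUtils

object IdempotentArtifactWrite {
  // Sketch of the write pattern above, not the Spark implementation itself.
  // - First write, or a retry carrying identical bytes: succeeds; the retry is a no-op.
  // - Same target path but different bytes: rejected with gRPC ALREADY_EXISTS, never overwritten.
  def moveIfAbsentOrIdentical(staged: Path, target: Path): Unit = {
    Files.createDirectories(target.getParent)
    if (Files.exists(target)) {
      if (FileUtils.contentEquals(target.toFile, staged.toFile)) {
        return // identical content: treat the duplicate request as success
      }
      throw Status.ALREADY_EXISTS
        .withDescription(s"Duplicate artifact: $target. Artifacts cannot be overwritten.")
        .asRuntimeException()
    }
    Files.move(staged, target)
  }
}

Because equality is decided by comparing full file contents, a client that retries an AddArtifacts request after a dropped connection no longer fails on the second attempt, while uploading modified content under an existing path still produces the explicit ALREADY_EXISTS error; the handler then cleans up its staging directory and propagates the status to the client.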