Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import javax.ws.rs.core.UriBuilder
import scala.jdk.CollectionConverters._
import scala.reflect.ClassTag

import io.grpc.Status
import org.apache.commons.io.{FilenameUtils, FileUtils}
import org.apache.hadoop.fs.{LocalFileSystem, Path => FSPath}

Expand Down Expand Up @@ -125,11 +126,18 @@ class SparkConnectArtifactManager(sessionHolder: SessionHolder) extends Logging
} else {
val target = ArtifactUtils.concatenatePaths(artifactPath, remoteRelativePath)
Files.createDirectories(target.getParent)
// Disallow overwriting non-classfile artifacts

// Disallow overwriting with modified version
if (Files.exists(target)) {
throw new RuntimeException(
s"Duplicate Artifact: $remoteRelativePath. " +
// makes the query idempotent
if (FileUtils.contentEquals(target.toFile, serverLocalStagingPath.toFile)) {
return
}

throw Status.ALREADY_EXISTS
.withDescription(s"Duplicate Artifact: $remoteRelativePath. " +
"Artifacts cannot be overwritten.")
.asRuntimeException()
}
Files.move(serverLocalStagingPath, target)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import scala.collection.mutable
import scala.util.control.NonFatal

import com.google.common.io.CountingOutputStream
import io.grpc.StatusRuntimeException
import io.grpc.stub.StreamObserver

import org.apache.spark.connect.proto
Expand Down Expand Up @@ -80,7 +81,7 @@ class SparkConnectAddArtifactsHandler(val responseObserver: StreamObserver[AddAr
}

override def onError(throwable: Throwable): Unit = {
Utils.deleteRecursively(stagingDir.toFile)
cleanUpStagedArtifacts()
responseObserver.onError(throwable)
}

Expand Down Expand Up @@ -114,16 +115,20 @@ class SparkConnectAddArtifactsHandler(val responseObserver: StreamObserver[AddAr
protected def cleanUpStagedArtifacts(): Unit = Utils.deleteRecursively(stagingDir.toFile)

override def onCompleted(): Unit = {
val artifactSummaries = flushStagedArtifacts()
// Add the artifacts to the session and return the summaries to the client.
val builder = proto.AddArtifactsResponse.newBuilder()
artifactSummaries.foreach(summary => builder.addArtifacts(summary))
// Delete temp dir
cleanUpStagedArtifacts()

// Send the summaries and close
responseObserver.onNext(builder.build())
responseObserver.onCompleted()
try {
val artifactSummaries = flushStagedArtifacts()
// Add the artifacts to the session and return the summaries to the client.
val builder = proto.AddArtifactsResponse.newBuilder()
artifactSummaries.foreach(summary => builder.addArtifacts(summary))
// Delete temp dir
cleanUpStagedArtifacts()

// Send the summaries and close
responseObserver.onNext(builder.build())
responseObserver.onCompleted()
} catch {
case e: StatusRuntimeException => onError(e)
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import java.nio.charset.StandardCharsets
import java.nio.file.{Files, Paths}
import java.util.UUID

import io.grpc.StatusRuntimeException
import org.apache.commons.io.FileUtils

import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkException, SparkFunSuite}
Expand Down Expand Up @@ -150,6 +151,28 @@ class ArtifactManagerSuite extends SharedSparkSession with ResourceHelper {
}
}

test("Add artifact idempotency") {
val remotePath = Paths.get("pyfiles/abc.zip")

withTempPath { path =>
Files.write(path.toPath, "test".getBytes(StandardCharsets.UTF_8))
artifactManager.addArtifact(remotePath, path.toPath, None)
}

withTempPath { path =>
// subsequent call succeeds
Files.write(path.toPath, "test".getBytes(StandardCharsets.UTF_8))
artifactManager.addArtifact(remotePath, path.toPath, None)
}

withTempPath { path =>
Files.write(path.toPath, "updated file".getBytes(StandardCharsets.UTF_8))
assertThrows[StatusRuntimeException] {
artifactManager.addArtifact(remotePath, path.toPath, None)
}
}
}

test("SPARK-43790: Forward artifact file to cloud storage path") {
val copyDir = Utils.createTempDir().toPath
val destFSDir = Utils.createTempDir().toPath
Expand Down