Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 47 additions & 1 deletion core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package org.apache.spark.deploy
import java.io.{File, IOException}
import java.lang.reflect.{InvocationTargetException, Modifier, UndeclaredThrowableException}
import java.net.URL
import java.nio.file.Files
import java.security.PrivilegedExceptionAction
import java.text.ParseException

Expand All @@ -28,7 +29,8 @@ import scala.collection.mutable.{ArrayBuffer, HashMap, Map}
import scala.util.Properties

import org.apache.commons.lang3.StringUtils
import org.apache.hadoop.fs.Path
import org.apache.hadoop.conf.{Configuration => HadoopConfiguration}
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.hadoop.security.UserGroupInformation
import org.apache.ivy.Ivy
import org.apache.ivy.core.LogOptions
Expand Down Expand Up @@ -308,6 +310,15 @@ object SparkSubmit extends CommandLineUtils {
RPackageUtils.checkAndBuildRPackage(args.jars, printStream, args.verbose)
}

// In client mode, download remote files.
if (deployMode == CLIENT) {
val hadoopConf = new HadoopConfiguration()
args.primaryResource = Option(args.primaryResource).map(downloadFile(_, hadoopConf)).orNull
args.jars = Option(args.jars).map(downloadFileList(_, hadoopConf)).orNull
args.pyFiles = Option(args.pyFiles).map(downloadFileList(_, hadoopConf)).orNull
args.files = Option(args.files).map(downloadFileList(_, hadoopConf)).orNull
}

// Require all python files to be local, so we can add them to the PYTHONPATH
// In YARN cluster mode, python files are distributed as regular files, which can be non-local.
// In Mesos cluster mode, non-local python files are automatically downloaded by Mesos.
Expand Down Expand Up @@ -825,6 +836,41 @@ object SparkSubmit extends CommandLineUtils {
.mkString(",")
if (merged == "") null else merged
}

/**
* Download a list of remote files to temp local files. If the file is local, the original file
* will be returned.
* @param fileList A comma separated file list.
* @return A comma separated local files list.
*/
/**
 * Resolves a comma-separated list of file paths, downloading each remote entry to a
 * temporary local file via `downloadFile`. Local entries are passed through unchanged.
 *
 * @param fileList A comma separated file list (must not be null).
 * @return A comma separated list of local file paths, in the same order as the input.
 */
private[deploy] def downloadFileList(
    fileList: String,
    hadoopConf: HadoopConfiguration): String = {
  require(fileList != null, "fileList cannot be null.")
  val localPaths = fileList.split(",").map { entry => downloadFile(entry, hadoopConf) }
  localPaths.mkString(",")
}

/**
* Download a file from the remote to a local temporary directory. If the input path points to
* a local path, returns it with no operation.
*/
/**
 * Resolves a single path. Paths whose resolved URI scheme is "file" or "local" are
 * returned unchanged; any other scheme is treated as remote and the file is copied
 * into a freshly created temporary directory, returning the local copy's URI string.
 * The remote path component is mirrored under the temp dir, so the original
 * path/filename are preserved in the local copy.
 */
private[deploy] def downloadFile(path: String, hadoopConf: HadoopConfiguration): String = {
  require(path != null, "path cannot be null.")
  val uri = Utils.resolveURI(path)
  val scheme = uri.getScheme
  if (scheme == "file" || scheme == "local") {
    // Already accessible on this machine; no copy needed.
    path
  } else {
    val fs = FileSystem.get(uri, hadoopConf)
    val tempDir = Files.createTempDirectory("tmp").toFile
    val localFile = new File(tempDir, uri.getPath)
    // scalastyle:off println
    printStream.println(s"Downloading ${uri.toString} to ${localFile.getAbsolutePath}.")
    // scalastyle:on println
    fs.copyToLocalFile(new Path(uri), new Path(localFile.getAbsolutePath))
    Utils.resolveURI(localFile.getAbsolutePath).toString
  }
}
}

/** Provides utility functions to be used inside SparkSubmit. */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@ private[deploy] class DriverRunner(
@volatile private[worker] var finalException: Option[Exception] = None

// Timeout to wait for when trying to terminate a driver.
private val DRIVER_TERMINATE_TIMEOUT_MS = 10 * 1000
private val DRIVER_TERMINATE_TIMEOUT_MS =
conf.getTimeAsMs("spark.worker.driverTerminateTimeout", "10s")

// Decoupled for testing
def setClock(_clock: Clock): Unit = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -214,11 +214,12 @@ final class ShuffleBlockFetcherIterator(
}
}

// Shuffle remote blocks to disk when the request is too large.
// TODO: Encryption and compression should be considered.
// Fetch remote shuffle blocks to disk when the request is too large. Since the shuffle data is
// already encrypted and compressed over the wire(w.r.t. the related configs), we can just fetch
// the data and write it to file directly.
if (req.size > maxReqSizeShuffleToMem) {
val shuffleFiles = blockIds.map {
bId => blockManager.diskBlockManager.createTempLocalBlock()._2
val shuffleFiles = blockIds.map { _ =>
blockManager.diskBlockManager.createTempLocalBlock()._2
}.toArray
shuffleFilesSet ++= shuffleFiles
shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray,
Expand Down
4 changes: 3 additions & 1 deletion core/src/main/scala/org/apache/spark/util/Utils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1026,7 +1026,9 @@ private[spark] object Utils extends Logging {
ShutdownHookManager.removeShutdownDeleteDir(file)
}
} finally {
if (!file.delete()) {
if (file.delete()) {
logTrace(s"${file.getAbsolutePath} has been deleted")
} else {
// Delete can also fail if the file simply did not exist
if (file.exists()) {
throw new IOException("Failed to delete: " + file.getAbsolutePath)
Expand Down
95 changes: 93 additions & 2 deletions core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,15 @@
package org.apache.spark.deploy

import java.io._
import java.net.URI
import java.nio.charset.StandardCharsets

import scala.collection.mutable.ArrayBuffer
import scala.io.Source

import com.google.common.io.ByteStreams
import org.apache.commons.io.{FilenameUtils, FileUtils}
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.scalatest.{BeforeAndAfterEach, Matchers}
import org.scalatest.concurrent.Timeouts
Expand Down Expand Up @@ -535,7 +538,7 @@ class SparkSubmitSuite

test("resolves command line argument paths correctly") {
val jars = "/jar1,/jar2" // --jars
val files = "hdfs:/file1,file2" // --files
val files = "local:/file1,file2" // --files
val archives = "file:/archive1,archive2" // --archives
val pyFiles = "py-file1,py-file2" // --py-files

Expand Down Expand Up @@ -587,7 +590,7 @@ class SparkSubmitSuite

test("resolves config paths correctly") {
val jars = "/jar1,/jar2" // spark.jars
val files = "hdfs:/file1,file2" // spark.files / spark.yarn.dist.files
val files = "local:/file1,file2" // spark.files / spark.yarn.dist.files
val archives = "file:/archive1,archive2" // spark.yarn.dist.archives
val pyFiles = "py-file1,py-file2" // spark.submit.pyFiles

Expand Down Expand Up @@ -705,6 +708,87 @@ class SparkSubmitSuite
}
// scalastyle:on println

/**
 * Verifies that `outputPath` is a local copy of `sourcePath`. If the two paths are
 * identical, the file was local and no download happened, so there is nothing to check.
 * Otherwise the output must be a "file" URI whose path ends with the source path
 * (path and filename are preserved by the download) and whose content matches.
 */
private def checkDownloadedFile(sourcePath: String, outputPath: String): Unit = {
  if (sourcePath == outputPath) {
    return
  }

  val sourceUri = new URI(sourcePath)
  val outputUri = new URI(outputPath)
  assert(outputUri.getScheme === "file")

  // The path and filename are preserved.
  assert(outputUri.getPath.endsWith(sourceUri.getPath))
  // Use the explicit-charset overload; readFileToString(File) is deprecated in commons-io,
  // and reading both sides with the same charset makes the comparison deterministic.
  assert(FileUtils.readFileToString(new File(outputUri.getPath), StandardCharsets.UTF_8) ===
    FileUtils.readFileToString(new File(sourceUri.getPath), StandardCharsets.UTF_8))
}

/** Removes the local file referenced by `outputPath` (a URI string), if it exists. */
private def deleteTempOutputFile(outputPath: String): Unit = {
  val candidate = new File(new URI(outputPath).getPath)
  if (candidate.exists) {
    candidate.delete()
  }
}

test("downloadFile - invalid url") {
  // A scheme with no registered Hadoop FileSystem implementation must fail with IOException.
  val hadoopConf = new Configuration()
  intercept[IOException] {
    SparkSubmit.downloadFile("abc:/my/file", hadoopConf)
  }
}

test("downloadFile - file doesn't exist") {
  val hadoopConf = new Configuration()
  // Route the s3a scheme to a local-filesystem stub so no real S3 access happens,
  // and bypass the FileSystem impl cache so the stub is actually instantiated.
  hadoopConf.set("fs.s3a.impl", "org.apache.spark.deploy.TestFileSystem")
  hadoopConf.set("fs.s3a.impl.disable.cache", "true")
  intercept[FileNotFoundException] {
    SparkSubmit.downloadFile("s3a:/no/such/file", hadoopConf)
  }
}

test("downloadFile does not download local file") {
  val hadoopConf = new Configuration()
  // The empty string resolves to a local path, so it is returned unchanged,
  // as is any plain absolute path.
  assert(SparkSubmit.downloadFile("", hadoopConf) === "")
  assert(SparkSubmit.downloadFile("/local/file", hadoopConf) === "/local/file")
}

test("download one file to local") {
  // Create a real local file to serve as the "remote" source.
  val jarFile = File.createTempFile("test", ".jar")
  jarFile.deleteOnExit()
  val content = "hello, world"
  // Use the explicit-charset overload; FileUtils.write(File, CharSequence) is deprecated.
  FileUtils.write(jarFile, content, StandardCharsets.UTF_8)
  val hadoopConf = new Configuration()
  // Set s3a implementation to local file system for testing.
  hadoopConf.set("fs.s3a.impl", "org.apache.spark.deploy.TestFileSystem")
  // Disable file system impl cache to make sure the test file system is picked up.
  hadoopConf.set("fs.s3a.impl.disable.cache", "true")
  val sourcePath = s"s3a://${jarFile.getAbsolutePath}"
  val outputPath = SparkSubmit.downloadFile(sourcePath, hadoopConf)
  checkDownloadedFile(sourcePath, outputPath)
  deleteTempOutputFile(outputPath)
}

test("download list of files to local") {
  // Create a real local file to serve as the "remote" source.
  val jarFile = File.createTempFile("test", ".jar")
  jarFile.deleteOnExit()
  val content = "hello, world"
  // Use the explicit-charset overload; FileUtils.write(File, CharSequence) is deprecated.
  FileUtils.write(jarFile, content, StandardCharsets.UTF_8)
  val hadoopConf = new Configuration()
  // Set s3a implementation to local file system for testing.
  hadoopConf.set("fs.s3a.impl", "org.apache.spark.deploy.TestFileSystem")
  // Disable file system impl cache to make sure the test file system is picked up.
  hadoopConf.set("fs.s3a.impl.disable.cache", "true")
  // Mix a local path with a remote one: only the remote entry should be downloaded.
  val sourcePaths = Seq("/local/file", s"s3a://${jarFile.getAbsolutePath}")
  val outputPaths = SparkSubmit.downloadFileList(sourcePaths.mkString(","), hadoopConf).split(",")

  assert(outputPaths.length === sourcePaths.length)
  sourcePaths.zip(outputPaths).foreach { case (sourcePath, outputPath) =>
    checkDownloadedFile(sourcePath, outputPath)
    deleteTempOutputFile(outputPath)
  }
}

// NOTE: This is an expensive operation in terms of time (10 seconds+). Use sparingly.
private def runSparkSubmit(args: Seq[String]): Unit = {
val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!"))
Expand Down Expand Up @@ -807,3 +891,10 @@ object UserClasspathFirstTest {
}
}
}

/**
 * A LocalFileSystem stub registered as the s3a implementation in the tests above:
 * it strips the scheme from the source path so "remote" reads are served from
 * the local disk.
 */
class TestFileSystem extends org.apache.hadoop.fs.LocalFileSystem {
  override def copyToLocalFile(src: Path, dst: Path): Unit = {
    // Ignore the scheme for testing.
    val schemelessSrc = new Path(src.toUri.getPath)
    super.copyToLocalFile(schemelessSrc, dst)
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
import org.apache.spark.network.shuffle.BlockFetchingListener
import org.apache.spark.network.util.LimitedInputStream
import org.apache.spark.shuffle.FetchFailedException
import org.apache.spark.util.Utils


class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodTester {
Expand Down Expand Up @@ -420,9 +421,10 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
doReturn(localBmId).when(blockManager).blockManagerId

val diskBlockManager = mock(classOf[DiskBlockManager])
val tmpDir = Utils.createTempDir()
doReturn{
var blockId = new TempLocalBlockId(UUID.randomUUID())
(blockId, new File(blockId.name))
val blockId = TempLocalBlockId(UUID.randomUUID())
(blockId, new File(tmpDir, blockId.name))
}.when(diskBlockManager).createTempLocalBlock()
doReturn(diskBlockManager).when(blockManager).diskBlockManager

Expand All @@ -443,34 +445,34 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
}
})

/**
 * Constructs a `ShuffleBlockFetcherIterator` over `blocksByAddress`. The remote fetch
 * requests are issued eagerly during construction, so this call alone triggers the
 * fetches; the iterator instance is deliberately discarded and the surrounding test
 * inspects side effects (the `shuffleFiles` captured by the mocked shuffle client).
 */
def fetchShuffleBlock(blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])]): Unit = {
  // Set `maxBytesInFlight` and `maxReqsInFlight` to `Int.MaxValue`, so that during the
  // construction of `ShuffleBlockFetcherIterator`, all requests to fetch remote shuffle blocks
  // are issued. The `maxReqSizeShuffleToMem` is hard-coded as 200 here.
  new ShuffleBlockFetcherIterator(
    TaskContext.empty(),
    transfer,
    blockManager,
    blocksByAddress,
    (_, in) => in,
    maxBytesInFlight = Int.MaxValue,
    maxReqsInFlight = Int.MaxValue,
    maxReqSizeShuffleToMem = 200,
    detectCorrupt = true)
}

val blocksByAddress1 = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
(remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 100L)).toSeq))
// Set maxReqSizeShuffleToMem to be 200.
val iterator1 = new ShuffleBlockFetcherIterator(
TaskContext.empty(),
transfer,
blockManager,
blocksByAddress1,
(_, in) => in,
Int.MaxValue,
Int.MaxValue,
200,
true)
fetchShuffleBlock(blocksByAddress1)
// `maxReqSizeShuffleToMem` is 200, which is greater than the block size 100, so don't fetch
// shuffle block to disk.
assert(shuffleFiles === null)

val blocksByAddress2 = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
(remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 300L)).toSeq))
// Set maxReqSizeShuffleToMem to be 200.
val iterator2 = new ShuffleBlockFetcherIterator(
TaskContext.empty(),
transfer,
blockManager,
blocksByAddress2,
(_, in) => in,
Int.MaxValue,
Int.MaxValue,
200,
true)
fetchShuffleBlock(blocksByAddress2)
// `maxReqSizeShuffleToMem` is 200, which is smaller than the block size 300, so fetch
// shuffle block to disk.
assert(shuffleFiles != null)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,8 @@ object FunctionRegistry {

// string functions
expression[Ascii]("ascii"),
expression[Chr]("char"),
expression[Chr]("chr"),
expression[Base64]("base64"),
expression[Concat]("concat"),
expression[ConcatWs]("concat_ws"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -232,19 +232,20 @@ case class Ceil(child: Expression) extends UnaryMathExpression(math.ceil, "CEIL"
}

override def inputTypes: Seq[AbstractDataType] =
Seq(TypeCollection(LongType, DoubleType, DecimalType))
Seq(TypeCollection(DoubleType, DecimalType, LongType))

protected override def nullSafeEval(input: Any): Any = child.dataType match {
case LongType => input.asInstanceOf[Long]
case DoubleType => f(input.asInstanceOf[Double]).toLong
case DecimalType.Fixed(precision, scale) => input.asInstanceOf[Decimal].ceil
case DecimalType.Fixed(_, _) => input.asInstanceOf[Decimal].ceil
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
child.dataType match {
case DecimalType.Fixed(_, 0) => defineCodeGen(ctx, ev, c => s"$c")
case DecimalType.Fixed(precision, scale) =>
case DecimalType.Fixed(_, _) =>
defineCodeGen(ctx, ev, c => s"$c.ceil()")
case LongType => defineCodeGen(ctx, ev, c => s"$c")
case _ => defineCodeGen(ctx, ev, c => s"(long)(java.lang.Math.${funcName}($c))")
}
}
Expand Down Expand Up @@ -348,19 +349,20 @@ case class Floor(child: Expression) extends UnaryMathExpression(math.floor, "FLO
}

override def inputTypes: Seq[AbstractDataType] =
Seq(TypeCollection(LongType, DoubleType, DecimalType))
Seq(TypeCollection(DoubleType, DecimalType, LongType))

protected override def nullSafeEval(input: Any): Any = child.dataType match {
case LongType => input.asInstanceOf[Long]
case DoubleType => f(input.asInstanceOf[Double]).toLong
case DecimalType.Fixed(precision, scale) => input.asInstanceOf[Decimal].floor
case DecimalType.Fixed(_, _) => input.asInstanceOf[Decimal].floor
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
child.dataType match {
case DecimalType.Fixed(_, 0) => defineCodeGen(ctx, ev, c => s"$c")
case DecimalType.Fixed(precision, scale) =>
case DecimalType.Fixed(_, _) =>
defineCodeGen(ctx, ev, c => s"$c.floor()")
case LongType => defineCodeGen(ctx, ev, c => s"$c")
case _ => defineCodeGen(ctx, ev, c => s"(long)(java.lang.Math.${funcName}($c))")
}
}
Expand Down
Loading