diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
index 77005aa9040b5..c60a2a1706d5a 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
@@ -20,6 +20,7 @@ package org.apache.spark.deploy
 import java.io.{File, IOException}
 import java.lang.reflect.{InvocationTargetException, Modifier, UndeclaredThrowableException}
 import java.net.URL
+import java.nio.file.Files
 import java.security.PrivilegedExceptionAction
 import java.text.ParseException
 
@@ -28,7 +29,8 @@ import scala.collection.mutable.{ArrayBuffer, HashMap, Map}
 import scala.util.Properties
 
 import org.apache.commons.lang3.StringUtils
-import org.apache.hadoop.fs.Path
+import org.apache.hadoop.conf.{Configuration => HadoopConfiguration}
+import org.apache.hadoop.fs.{FileSystem, Path}
 import org.apache.hadoop.security.UserGroupInformation
 import org.apache.ivy.Ivy
 import org.apache.ivy.core.LogOptions
@@ -308,6 +310,15 @@ object SparkSubmit extends CommandLineUtils {
       RPackageUtils.checkAndBuildRPackage(args.jars, printStream, args.verbose)
     }
 
+    // In client mode, download remote files.
+    if (deployMode == CLIENT) {
+      val hadoopConf = new HadoopConfiguration()
+      args.primaryResource = Option(args.primaryResource).map(downloadFile(_, hadoopConf)).orNull
+      args.jars = Option(args.jars).map(downloadFileList(_, hadoopConf)).orNull
+      args.pyFiles = Option(args.pyFiles).map(downloadFileList(_, hadoopConf)).orNull
+      args.files = Option(args.files).map(downloadFileList(_, hadoopConf)).orNull
+    }
+
     // Require all python files to be local, so we can add them to the PYTHONPATH
     // In YARN cluster mode, python files are distributed as regular files, which can be non-local.
     // In Mesos cluster mode, non-local python files are automatically downloaded by Mesos.
@@ -825,6 +836,41 @@ object SparkSubmit extends CommandLineUtils {
       .mkString(",")
     if (merged == "") null else merged
   }
+
+  /**
+   * Download a list of remote files to temp local files. If the file is local, the original file
+   * will be returned.
+   * @param fileList A comma separated file list.
+   * @return A comma separated local files list.
+   */
+  private[deploy] def downloadFileList(
+      fileList: String,
+      hadoopConf: HadoopConfiguration): String = {
+    require(fileList != null, "fileList cannot be null.")
+    fileList.split(",").map(downloadFile(_, hadoopConf)).mkString(",")
+  }
+
+  /**
+   * Download a file from the remote to a local temporary directory. If the input path points to
+   * a local path, returns it with no operation.
+   */
+  private[deploy] def downloadFile(path: String, hadoopConf: HadoopConfiguration): String = {
+    require(path != null, "path cannot be null.")
+    val uri = Utils.resolveURI(path)
+    uri.getScheme match {
+      case "file" | "local" =>
+        path
+
+      case _ =>
+        val fs = FileSystem.get(uri, hadoopConf)
+        val tmpFile = new File(Files.createTempDirectory("tmp").toFile, uri.getPath)
+        // scalastyle:off println
+        printStream.println(s"Downloading ${uri.toString} to ${tmpFile.getAbsolutePath}.")
+        // scalastyle:on println
+        fs.copyToLocalFile(new Path(uri), new Path(tmpFile.getAbsolutePath))
+        Utils.resolveURI(tmpFile.getAbsolutePath).toString
+    }
+  }
 }
 
 /** Provides utility functions to be used inside SparkSubmit. */
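The client-mode logic added above reduces to three steps: resolve the URI, pass `file:`/`local:` paths through untouched, and copy anything else into a fresh local temp directory via the Hadoop `FileSystem` API. A minimal standalone sketch of that flow, assuming only hadoop-common on the classpath; `DownloadSketch`/`demoDownload` and the sample path are illustrative, and the `null` scheme case stands in for `Utils.resolveURI`'s handling of scheme-less paths:

```scala
import java.io.File
import java.net.URI
import java.nio.file.Files

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, Path}

object DownloadSketch {
  // Copy a remote file into a local temp dir; hand local paths back unchanged.
  def demoDownload(path: String, hadoopConf: Configuration): String = {
    val uri = new URI(path)
    uri.getScheme match {
      case null | "file" | "local" =>
        path // already readable by a client-mode driver
      case _ =>
        val fs = FileSystem.get(uri, hadoopConf)
        // One fresh temp dir per file; nesting the remote path under it keeps
        // the original file name, which the test suite below asserts on.
        val tmpFile = new File(Files.createTempDirectory("tmp").toFile, uri.getPath)
        fs.copyToLocalFile(new Path(uri), new Path(tmpFile.getAbsolutePath))
        tmpFile.toURI.toString
    }
  }

  def main(args: Array[String]): Unit = {
    // A scheme-less path is treated as local and passed through as-is.
    println(demoDownload("/opt/app/app.jar", new Configuration()))
  }
}
```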
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala
index e878c10183f61..58a181128eb4d 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala
@@ -57,7 +57,8 @@ private[deploy] class DriverRunner(
   @volatile private[worker] var finalException: Option[Exception] = None
 
   // Timeout to wait for when trying to terminate a driver.
-  private val DRIVER_TERMINATE_TIMEOUT_MS = 10 * 1000
+  private val DRIVER_TERMINATE_TIMEOUT_MS =
+    conf.getTimeAsMs("spark.worker.driverTerminateTimeout", "10s")
 
   // Decoupled for testing
   def setClock(_clock: Clock): Unit = {
diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
index ee35060926555..bded3a1e4eb54 100644
--- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
@@ -214,11 +214,12 @@ final class ShuffleBlockFetcherIterator(
       }
     }
 
-    // Shuffle remote blocks to disk when the request is too large.
-    // TODO: Encryption and compression should be considered.
+    // Fetch remote shuffle blocks to disk when the request is too large. Since the shuffle data
+    // is already encrypted and compressed over the wire (w.r.t. the related configs), we can
+    // just fetch the data and write it to a file directly.
     if (req.size > maxReqSizeShuffleToMem) {
-      val shuffleFiles = blockIds.map {
-        bId => blockManager.diskBlockManager.createTempLocalBlock()._2
+      val shuffleFiles = blockIds.map { _ =>
+        blockManager.diskBlockManager.createTempLocalBlock()._2
       }.toArray
       shuffleFilesSet ++= shuffleFiles
       shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray,
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index ad39c74a0e232..bbb7999e2a144 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -1026,7 +1026,9 @@ private[spark] object Utils extends Logging {
         ShutdownHookManager.removeShutdownDeleteDir(file)
       }
     } finally {
-      if (!file.delete()) {
+      if (file.delete()) {
+        logTrace(s"${file.getAbsolutePath} has been deleted")
+      } else {
        // Delete can also fail if the file simply did not exist
        if (file.exists()) {
          throw new IOException("Failed to delete: " + file.getAbsolutePath)
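The `DriverRunner` hunk above turns a hard-coded ten-second kill timeout into the `spark.worker.driverTerminateTimeout` setting, read via `SparkConf.getTimeAsMs`, which accepts suffixed durations such as `10s` or `2min`. A small sketch of the lookup; the conf key is the one added by this patch, while the surrounding object is illustrative:

```scala
import org.apache.spark.SparkConf

object DriverTimeoutDemo {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
    // Workers whose drivers run long shutdown hooks can now raise the grace
    // period before the process is forcibly killed; "10s" stays the default.
    conf.set("spark.worker.driverTerminateTimeout", "30s")
    val timeoutMs = conf.getTimeAsMs("spark.worker.driverTerminateTimeout", "10s")
    println(s"driver terminate timeout: $timeoutMs ms") // 30000 ms
  }
}
```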
"/jar1,/jar2" // --jars - val files = "hdfs:/file1,file2" // --files + val files = "local:/file1,file2" // --files val archives = "file:/archive1,archive2" // --archives val pyFiles = "py-file1,py-file2" // --py-files @@ -587,7 +590,7 @@ class SparkSubmitSuite test("resolves config paths correctly") { val jars = "/jar1,/jar2" // spark.jars - val files = "hdfs:/file1,file2" // spark.files / spark.yarn.dist.files + val files = "local:/file1,file2" // spark.files / spark.yarn.dist.files val archives = "file:/archive1,archive2" // spark.yarn.dist.archives val pyFiles = "py-file1,py-file2" // spark.submit.pyFiles @@ -705,6 +708,87 @@ class SparkSubmitSuite } // scalastyle:on println + private def checkDownloadedFile(sourcePath: String, outputPath: String): Unit = { + if (sourcePath == outputPath) { + return + } + + val sourceUri = new URI(sourcePath) + val outputUri = new URI(outputPath) + assert(outputUri.getScheme === "file") + + // The path and filename are preserved. + assert(outputUri.getPath.endsWith(sourceUri.getPath)) + assert(FileUtils.readFileToString(new File(outputUri.getPath)) === + FileUtils.readFileToString(new File(sourceUri.getPath))) + } + + private def deleteTempOutputFile(outputPath: String): Unit = { + val outputFile = new File(new URI(outputPath).getPath) + if (outputFile.exists) { + outputFile.delete() + } + } + + test("downloadFile - invalid url") { + intercept[IOException] { + SparkSubmit.downloadFile("abc:/my/file", new Configuration()) + } + } + + test("downloadFile - file doesn't exist") { + val hadoopConf = new Configuration() + // Set s3a implementation to local file system for testing. + hadoopConf.set("fs.s3a.impl", "org.apache.spark.deploy.TestFileSystem") + // Disable file system impl cache to make sure the test file system is picked up. + hadoopConf.set("fs.s3a.impl.disable.cache", "true") + intercept[FileNotFoundException] { + SparkSubmit.downloadFile("s3a:/no/such/file", hadoopConf) + } + } + + test("downloadFile does not download local file") { + // empty path is considered as local file. + assert(SparkSubmit.downloadFile("", new Configuration()) === "") + assert(SparkSubmit.downloadFile("/local/file", new Configuration()) === "/local/file") + } + + test("download one file to local") { + val jarFile = File.createTempFile("test", ".jar") + jarFile.deleteOnExit() + val content = "hello, world" + FileUtils.write(jarFile, content) + val hadoopConf = new Configuration() + // Set s3a implementation to local file system for testing. + hadoopConf.set("fs.s3a.impl", "org.apache.spark.deploy.TestFileSystem") + // Disable file system impl cache to make sure the test file system is picked up. + hadoopConf.set("fs.s3a.impl.disable.cache", "true") + val sourcePath = s"s3a://${jarFile.getAbsolutePath}" + val outputPath = SparkSubmit.downloadFile(sourcePath, hadoopConf) + checkDownloadedFile(sourcePath, outputPath) + deleteTempOutputFile(outputPath) + } + + test("download list of files to local") { + val jarFile = File.createTempFile("test", ".jar") + jarFile.deleteOnExit() + val content = "hello, world" + FileUtils.write(jarFile, content) + val hadoopConf = new Configuration() + // Set s3a implementation to local file system for testing. + hadoopConf.set("fs.s3a.impl", "org.apache.spark.deploy.TestFileSystem") + // Disable file system impl cache to make sure the test file system is picked up. 
+ hadoopConf.set("fs.s3a.impl.disable.cache", "true") + val sourcePaths = Seq("/local/file", s"s3a://${jarFile.getAbsolutePath}") + val outputPaths = SparkSubmit.downloadFileList(sourcePaths.mkString(","), hadoopConf).split(",") + + assert(outputPaths.length === sourcePaths.length) + sourcePaths.zip(outputPaths).foreach { case (sourcePath, outputPath) => + checkDownloadedFile(sourcePath, outputPath) + deleteTempOutputFile(outputPath) + } + } + // NOTE: This is an expensive operation in terms of time (10 seconds+). Use sparingly. private def runSparkSubmit(args: Seq[String]): Unit = { val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!")) @@ -807,3 +891,10 @@ object UserClasspathFirstTest { } } } + +class TestFileSystem extends org.apache.hadoop.fs.LocalFileSystem { + override def copyToLocalFile(src: Path, dst: Path): Unit = { + // Ignore the scheme for testing. + super.copyToLocalFile(new Path(src.toUri.getPath), dst) + } +} diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index 1f813a909fb8b..559b3faab8fd2 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -36,6 +36,7 @@ import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} import org.apache.spark.network.shuffle.BlockFetchingListener import org.apache.spark.network.util.LimitedInputStream import org.apache.spark.shuffle.FetchFailedException +import org.apache.spark.util.Utils class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodTester { @@ -420,9 +421,10 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT doReturn(localBmId).when(blockManager).blockManagerId val diskBlockManager = mock(classOf[DiskBlockManager]) + val tmpDir = Utils.createTempDir() doReturn{ - var blockId = new TempLocalBlockId(UUID.randomUUID()) - (blockId, new File(blockId.name)) + val blockId = TempLocalBlockId(UUID.randomUUID()) + (blockId, new File(tmpDir, blockId.name)) }.when(diskBlockManager).createTempLocalBlock() doReturn(diskBlockManager).when(blockManager).diskBlockManager @@ -443,34 +445,34 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT } }) + def fetchShuffleBlock(blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])]): Unit = { + // Set `maxBytesInFlight` and `maxReqsInFlight` to `Int.MaxValue`, so that during the + // construction of `ShuffleBlockFetcherIterator`, all requests to fetch remote shuffle blocks + // are issued. The `maxReqSizeShuffleToMem` is hard-coded as 200 here. + new ShuffleBlockFetcherIterator( + TaskContext.empty(), + transfer, + blockManager, + blocksByAddress, + (_, in) => in, + maxBytesInFlight = Int.MaxValue, + maxReqsInFlight = Int.MaxValue, + maxReqSizeShuffleToMem = 200, + detectCorrupt = true) + } + val blocksByAddress1 = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 100L)).toSeq)) - // Set maxReqSizeShuffleToMem to be 200. 
diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
index 1f813a909fb8b..559b3faab8fd2 100644
--- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
@@ -36,6 +36,7 @@ import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
 import org.apache.spark.network.shuffle.BlockFetchingListener
 import org.apache.spark.network.util.LimitedInputStream
 import org.apache.spark.shuffle.FetchFailedException
+import org.apache.spark.util.Utils
 
 class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodTester {
 
@@ -420,9 +421,10 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
     doReturn(localBmId).when(blockManager).blockManagerId
 
     val diskBlockManager = mock(classOf[DiskBlockManager])
+    val tmpDir = Utils.createTempDir()
     doReturn{
-      var blockId = new TempLocalBlockId(UUID.randomUUID())
-      (blockId, new File(blockId.name))
+      val blockId = TempLocalBlockId(UUID.randomUUID())
+      (blockId, new File(tmpDir, blockId.name))
     }.when(diskBlockManager).createTempLocalBlock()
     doReturn(diskBlockManager).when(blockManager).diskBlockManager
 
@@ -443,34 +445,34 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
       }
     })
 
+    def fetchShuffleBlock(blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])]): Unit = {
+      // Set `maxBytesInFlight` and `maxReqsInFlight` to `Int.MaxValue`, so that during the
+      // construction of `ShuffleBlockFetcherIterator`, all requests to fetch remote shuffle blocks
+      // are issued. The `maxReqSizeShuffleToMem` is hard-coded as 200 here.
+      new ShuffleBlockFetcherIterator(
+        TaskContext.empty(),
+        transfer,
+        blockManager,
+        blocksByAddress,
+        (_, in) => in,
+        maxBytesInFlight = Int.MaxValue,
+        maxReqsInFlight = Int.MaxValue,
+        maxReqSizeShuffleToMem = 200,
+        detectCorrupt = true)
+    }
+
     val blocksByAddress1 = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
       (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 100L)).toSeq))
-    // Set maxReqSizeShuffleToMem to be 200.
-    val iterator1 = new ShuffleBlockFetcherIterator(
-      TaskContext.empty(),
-      transfer,
-      blockManager,
-      blocksByAddress1,
-      (_, in) => in,
-      Int.MaxValue,
-      Int.MaxValue,
-      200,
-      true)
+    fetchShuffleBlock(blocksByAddress1)
+    // `maxReqSizeShuffleToMem` is 200, which is greater than the block size 100, so don't fetch
+    // shuffle block to disk.
     assert(shuffleFiles === null)
 
     val blocksByAddress2 = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
      (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 300L)).toSeq))
-    // Set maxReqSizeShuffleToMem to be 200.
-    val iterator2 = new ShuffleBlockFetcherIterator(
-      TaskContext.empty(),
-      transfer,
-      blockManager,
-      blocksByAddress2,
-      (_, in) => in,
-      Int.MaxValue,
-      Int.MaxValue,
-      200,
-      true)
+    fetchShuffleBlock(blocksByAddress2)
+    // `maxReqSizeShuffleToMem` is 200, which is smaller than the block size 300, so fetch
+    // shuffle block to disk.
     assert(shuffleFiles != null)
   }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 7521a7e12432c..a4c7f7a8de223 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -276,6 +276,8 @@ object FunctionRegistry {
 
     // string functions
     expression[Ascii]("ascii"),
+    expression[Chr]("char"),
+    expression[Chr]("chr"),
     expression[Base64]("base64"),
     expression[Concat]("concat"),
     expression[ConcatWs]("concat_ws"),
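With `Chr` registered under both `chr` and `char`, the function becomes callable directly from SQL. A quick usage sketch; the local session setup is illustrative, and the expected results follow the `Chr` evaluation rules and tests added elsewhere in this patch:

```scala
import org.apache.spark.sql.SparkSession

object ChrDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("chr-demo").getOrCreate()

    spark.sql("SELECT chr(65)").show()       // A
    spark.sql("SELECT chr(97 + 256)").show() // a -- inputs wrap modulo 256
    spark.sql("SELECT chr(-9)").show()       // empty string for negative input
    spark.sql("SELECT chr(0)").show()        // the NUL character

    spark.stop()
  }
}
```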
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
index 754b5c4f74e6a..7b64568c69659 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
@@ -232,19 +232,20 @@ case class Ceil(child: Expression) extends UnaryMathExpression(math.ceil, "CEIL"
   }
 
   override def inputTypes: Seq[AbstractDataType] =
-    Seq(TypeCollection(LongType, DoubleType, DecimalType))
+    Seq(TypeCollection(DoubleType, DecimalType, LongType))
 
   protected override def nullSafeEval(input: Any): Any = child.dataType match {
     case LongType => input.asInstanceOf[Long]
     case DoubleType => f(input.asInstanceOf[Double]).toLong
-    case DecimalType.Fixed(precision, scale) => input.asInstanceOf[Decimal].ceil
+    case DecimalType.Fixed(_, _) => input.asInstanceOf[Decimal].ceil
   }
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     child.dataType match {
       case DecimalType.Fixed(_, 0) => defineCodeGen(ctx, ev, c => s"$c")
-      case DecimalType.Fixed(precision, scale) =>
+      case DecimalType.Fixed(_, _) =>
         defineCodeGen(ctx, ev, c => s"$c.ceil()")
+      case LongType => defineCodeGen(ctx, ev, c => s"$c")
       case _ => defineCodeGen(ctx, ev, c => s"(long)(java.lang.Math.${funcName}($c))")
     }
   }
@@ -348,19 +349,20 @@ case class Floor(child: Expression) extends UnaryMathExpression(math.floor, "FLO
   }
 
   override def inputTypes: Seq[AbstractDataType] =
-    Seq(TypeCollection(LongType, DoubleType, DecimalType))
+    Seq(TypeCollection(DoubleType, DecimalType, LongType))
 
   protected override def nullSafeEval(input: Any): Any = child.dataType match {
     case LongType => input.asInstanceOf[Long]
     case DoubleType => f(input.asInstanceOf[Double]).toLong
-    case DecimalType.Fixed(precision, scale) => input.asInstanceOf[Decimal].floor
+    case DecimalType.Fixed(_, _) => input.asInstanceOf[Decimal].floor
   }
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     child.dataType match {
       case DecimalType.Fixed(_, 0) => defineCodeGen(ctx, ev, c => s"$c")
-      case DecimalType.Fixed(precision, scale) =>
+      case DecimalType.Fixed(_, _) =>
         defineCodeGen(ctx, ev, c => s"$c.floor()")
+      case LongType => defineCodeGen(ctx, ev, c => s"$c")
       case _ => defineCodeGen(ctx, ev, c => s"(long)(java.lang.Math.${funcName}($c))")
     }
   }
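The reordered `TypeCollection` makes `DoubleType` the preferred implicit cast for inputs that need one, while the new `LongType` codegen branch makes `ceil`/`floor` an identity on integral input, so codegen no longer routes longs through `java.lang.Math.ceil`/`floor` and a lossy double round-trip. A usage sketch against a local session (session setup is illustrative; the expected values match the unit tests added later in this diff):

```scala
import org.apache.spark.sql.SparkSession

object CeilFloorDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("ceil-floor").getOrCreate()

    // BIGINT input hits the new LongType branch: ceil/floor are identities,
    // so a value above 2^53 is preserved exactly.
    spark.sql("SELECT ceil(12345678901234567), floor(12345678901234567)").show()

    // Fractional input still rounds as before.
    spark.sql("SELECT ceil(3.1415D), floor(3.1415D)").show() // 4, 3

    spark.stop()
  }
}
```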
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
index 5598a146997ca..aba2f5f81f831 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
@@ -1267,6 +1267,51 @@ case class Ascii(child: Expression) extends UnaryExpression with ImplicitCastInp
   }
 }
 
+/**
+ * Returns the ASCII character having the binary equivalent to n.
+ * If n is larger than 256 the result is equivalent to chr(n % 256)
+ */
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+  usage = "_FUNC_(expr) - Returns the ASCII character having the binary equivalent to `expr`. If n is larger than 256 the result is equivalent to chr(n % 256)",
+  extended = """
+    Examples:
+      > SELECT _FUNC_(65);
+       A
+  """)
+// scalastyle:on line.size.limit
+case class Chr(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
+
+  override def dataType: DataType = StringType
+  override def inputTypes: Seq[DataType] = Seq(LongType)
+
+  protected override def nullSafeEval(lon: Any): Any = {
+    val longVal = lon.asInstanceOf[Long]
+    if (longVal < 0) {
+      UTF8String.EMPTY_UTF8
+    } else if ((longVal & 0xFF) == 0) {
+      UTF8String.fromString(Character.MIN_VALUE.toString)
+    } else {
+      UTF8String.fromString((longVal & 0xFF).toChar.toString)
+    }
+  }
+
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    nullSafeCodeGen(ctx, ev, lon => {
+      s"""
+        if ($lon < 0) {
+          ${ev.value} = UTF8String.EMPTY_UTF8;
+        } else if (($lon & 0xFF) == 0) {
+          ${ev.value} = UTF8String.fromString(String.valueOf(Character.MIN_VALUE));
+        } else {
+          char c = (char)($lon & 0xFF);
+          ${ev.value} = UTF8String.fromString(String.valueOf(c));
+        }
+      """
+    })
+  }
+}
+
 /**
  * Converts the argument from binary to a base 64 string.
  */
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index 3b4289767ad0c..7eccca2e85649 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -262,7 +262,7 @@ class AnalysisSuite extends AnalysisTest with ShouldMatchers {
 
     val plan = testRelation2.select('c).orderBy(Floor('a).asc)
     val expected = testRelation2.select(c, a)
-      .orderBy(Floor(Cast(a, LongType, Option(TimeZone.getDefault().getID))).asc).select(c)
+      .orderBy(Floor(Cast(a, DoubleType, Option(TimeZone.getDefault().getID))).asc).select(c)
     checkAnalysis(plan, expected)
   }
 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala
index 8ed7a82b943b6..6af0cde73538b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala
@@ -258,6 +258,16 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
     checkConsistencyBetweenInterpretedAndCodegen(Ceil, DecimalType(25, 3))
     checkConsistencyBetweenInterpretedAndCodegen(Ceil, DecimalType(25, 0))
     checkConsistencyBetweenInterpretedAndCodegen(Ceil, DecimalType(5, 0))
+
+    val doublePi: Double = 3.1415
+    val floatPi: Float = 3.1415f
+    val longLit: Long = 12345678901234567L
+    checkEvaluation(Ceil(doublePi), 4L, EmptyRow)
+    checkEvaluation(Ceil(floatPi.toDouble), 4L, EmptyRow)
+    checkEvaluation(Ceil(longLit), longLit, EmptyRow)
+    checkEvaluation(Ceil(-doublePi), -3L, EmptyRow)
+    checkEvaluation(Ceil(-floatPi.toDouble), -3L, EmptyRow)
+    checkEvaluation(Ceil(-longLit), -longLit, EmptyRow)
   }
 
   test("floor") {
@@ -268,6 +278,16 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
     checkConsistencyBetweenInterpretedAndCodegen(Floor, DecimalType(25, 3))
     checkConsistencyBetweenInterpretedAndCodegen(Floor, DecimalType(25, 0))
     checkConsistencyBetweenInterpretedAndCodegen(Floor, DecimalType(5, 0))
+
+    val doublePi: Double = 3.1415
+    val floatPi: Float = 3.1415f
+    val longLit: Long = 12345678901234567L
+    checkEvaluation(Floor(doublePi), 3L, EmptyRow)
+    checkEvaluation(Floor(floatPi.toDouble), 3L, EmptyRow)
+    checkEvaluation(Floor(longLit), longLit, EmptyRow)
+    checkEvaluation(Floor(-doublePi), -4L, EmptyRow)
+    checkEvaluation(Floor(-floatPi.toDouble), -4L, EmptyRow)
+    checkEvaluation(Floor(-longLit), -longLit, EmptyRow)
   }
 
   test("factorial") {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
index 26978a0482fc7..9ae438d568a90 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
@@ -263,6 +263,19 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
     checkEvaluation(Ascii(Literal.create(null, StringType)), null, create_row("abdef"))
   }
 
+  test("string for ascii") {
+    val a = 'a.long.at(0)
+    checkEvaluation(Chr(Literal(48L)), "0", create_row("abdef"))
+    checkEvaluation(Chr(a), "a", create_row(97L))
+    checkEvaluation(Chr(a), "a", create_row(97L + 256L))
+    checkEvaluation(Chr(a), "", create_row(-9L))
+    checkEvaluation(Chr(a), Character.MIN_VALUE.toString, create_row(0L))
+    checkEvaluation(Chr(a), Character.MIN_VALUE.toString, create_row(256L))
+    checkEvaluation(Chr(a), null, create_row(null))
+    checkEvaluation(Chr(a), 149.toChar.toString, create_row(149L))
+    checkEvaluation(Chr(Literal.create(null, LongType)), null, create_row("abdef"))
+  }
+
   test("base64/unbase64 for string") {
     val a = 'a.string.at(0)
     val b = 'b.binary.at(0)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala
index d993ea6c6cef9..4b52f3e4c49b0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala
@@ -23,7 +23,8 @@ import scala.collection.mutable.ArrayBuffer
 import org.apache.spark.broadcast
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, Expression, SortOrder}
+import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution.{LeafExecNode, SparkPlan, UnaryExecNode}
 import org.apache.spark.sql.internal.SQLConf
@@ -58,6 +59,24 @@ case class ReusedExchangeExec(override val output: Seq[Attribute], child: Exchan
   override protected[sql] def doExecuteBroadcast[T](): broadcast.Broadcast[T] = {
     child.executeBroadcast()
   }
+
+  // `ReusedExchangeExec` can have a distinct set of output attribute ids from its child, so we
+  // need to update the attribute ids in `outputPartitioning` and `outputOrdering`.
+  private lazy val updateAttr: Expression => Expression = {
+    val originalAttrToNewAttr = AttributeMap(child.output.zip(output))
+    e => e.transform {
+      case attr: Attribute => originalAttrToNewAttr.getOrElse(attr, attr)
+    }
+  }
+
+  override def outputPartitioning: Partitioning = child.outputPartitioning match {
+    case h: HashPartitioning => h.copy(expressions = h.expressions.map(updateAttr))
+    case other => other
+  }
+
+  override def outputOrdering: Seq[SortOrder] = {
+    child.outputOrdering.map(updateAttr(_).asInstanceOf[SortOrder])
+  }
 }
 
 /**
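The core of the fix above: `ReusedExchangeExec` exposes fresh output attribute ids, so partitioning and ordering expressions inherited from the child must be rewritten through an old-to-new attribute map. The same remapping pattern in a standalone snippet — this uses catalyst internals, so treat it as an illustration of the mechanism rather than public API:

```scala
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference, Expression}
import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
import org.apache.spark.sql.types.IntegerType

object RemapDemo {
  def main(args: Array[String]): Unit = {
    val childAttr = AttributeReference("i", IntegerType)() // child's exprId
    val newAttr = AttributeReference("i", IntegerType)()   // reused node's fresh exprId
    val mapping = AttributeMap(Seq(childAttr -> newAttr))

    // Rewrite any attribute occurrence through the old -> new map.
    val updateAttr: Expression => Expression =
      _.transform { case a: Attribute => mapping.getOrElse(a, a) }

    val childPartitioning = HashPartitioning(Seq(childAttr), numPartitions = 8)
    val fixed = childPartitioning.copy(expressions = childPartitioning.expressions.map(updateAttr))
    // The partitioning now refers to the reused node's attribute id.
    println(fixed.expressions.head == newAttr) // true
  }
}
```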
diff --git a/sql/core/src/test/resources/sql-tests/inputs/operators.sql b/sql/core/src/test/resources/sql-tests/inputs/operators.sql
index f7167472b05c6..7e3b86b76a34a 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/operators.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/operators.sql
@@ -64,12 +64,9 @@ select cot(-1);
 select ceiling(0);
 select ceiling(1);
 select ceil(1234567890123456);
-select ceil(12345678901234567);
 select ceiling(1234567890123456);
-select ceiling(12345678901234567);
 
 -- floor
 select floor(0);
 select floor(1);
 select floor(1234567890123456);
-select floor(12345678901234567);
diff --git a/sql/core/src/test/resources/sql-tests/results/operators.sql.out b/sql/core/src/test/resources/sql-tests/results/operators.sql.out
index fe52005aa91da..28cfb744193ec 100644
--- a/sql/core/src/test/resources/sql-tests/results/operators.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/operators.sql.out
@@ -1,5 +1,5 @@
 -- Automatically generated by SQLQueryTestSuite
--- Number of queries: 38
+-- Number of queries: 45
 
 
 -- !query 0
@@ -321,7 +321,7 @@ struct
 -- !query 38
 select ceiling(0)
 -- !query 38 schema
-struct
+struct
 -- !query 38 output
 0
 
@@ -329,7 +329,7 @@ struct
 -- !query 39
 select ceiling(1)
 -- !query 39 schema
-struct
+struct
 -- !query 39 output
 1
 
@@ -343,56 +343,32 @@ struct
 
 
 -- !query 41
-select ceil(12345678901234567)
+select ceiling(1234567890123456)
 -- !query 41 schema
-struct
+struct
 -- !query 41 output
-12345678901234567
+1234567890123456
 
 
 -- !query 42
-select ceiling(1234567890123456)
+select floor(0)
 -- !query 42 schema
-struct
+struct
 -- !query 42 output
-1234567890123456
+0
 
 
 -- !query 43
-select ceiling(12345678901234567)
+select floor(1)
 -- !query 43 schema
-struct
+struct
 -- !query 43 output
-12345678901234567
-
-
--- !query 44
-select floor(0)
--- !query 44 schema
-struct
--- !query 44 output
-0
-
-
--- !query 45
-select floor(1)
--- !query 45 schema
-struct
--- !query 45 output
 1
 
 
--- !query 46
+-- !query 44
 select floor(1234567890123456)
--- !query 46 schema
+-- !query 44 schema
 struct
--- !query 46 output
+-- !query 44 output
 1234567890123456
-
-
--- !query 47
-select floor(12345678901234567)
--- !query 47 schema
-struct
--- !query 47 output
-12345678901234567
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 2f52192b54030..9f691cb10f139 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -1855,4 +1855,14 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
       .foldLeft(lit(false))((e, index) => e.or(df.col(df.columns(index)) =!= "string"))
     df.filter(filter).count
   }
+
+  test("SPARK-20897: cached self-join should not fail") {
+    // Force the planner to use sort merge join.
+    withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "0") {
+      val df = Seq(1 -> "a").toDF("i", "j")
+      val df1 = df.as("t1")
+      val df2 = df.as("t2")
+      assert(df1.join(df2, $"t1.i" === $"t2.i").cache().count() == 1)
+    }
+  }
 }
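End to end, the regression test above amounts to this user-visible scenario: with broadcast joins disabled, a self-join reuses one side's shuffle exchange, and caching the joined plan used to fail before the `ReusedExchangeExec` fix. A runnable repro sketch (the local session setup is illustrative):

```scala
import org.apache.spark.sql.SparkSession

object SelfJoinCacheDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("SPARK-20897").getOrCreate()
    import spark.implicits._

    // Force sort merge join so the planner inserts (and can reuse) exchanges.
    spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "0")

    val df = Seq(1 -> "a").toDF("i", "j")
    val joined = df.as("t1").join(df.as("t2"), $"t1.i" === $"t2.i")
    // Before the fix, caching this plan could fail because the reused
    // exchange reported stale attribute ids in its output partitioning.
    println(joined.cache().count()) // 1

    spark.stop()
  }
}
```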