diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 5f9d11475c94b..2fd2d3675661f 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -1757,7 +1757,8 @@ setMethod("toRadians", #' \url{http://docs.oracle.com/javase/tutorial/i18n/format/simpleDateFormat.html}. #' If the string cannot be parsed according to the specified format (or default), #' the value of the column will be null. -#' The default format is 'yyyy-MM-dd'. +#' By default, it follows casting rules to a DateType if the format is omitted +#' (equivalent to \code{cast(df$x, "date")}). #' #' @param x Column to parse. #' @param format string to use to parse x Column to DateType. (optional) @@ -1832,10 +1833,11 @@ setMethod("to_json", signature(x = "Column"), #' \url{http://docs.oracle.com/javase/tutorial/i18n/format/simpleDateFormat.html}. #' If the string cannot be parsed according to the specified format (or default), #' the value of the column will be null. -#' The default format is 'yyyy-MM-dd HH:mm:ss'. +#' By default, it follows casting rules to a TimestampType if the format is omitted +#' (equivalent to \code{cast(df$x, "timestamp")}). #' #' @param x Column to parse. -#' @param format string to use to parse x Column to DateType. (optional) +#' @param format string to use to parse x Column to TimestampType. (optional) #' #' @rdname to_timestamp #' @name to_timestamp diff --git a/R/pkg/inst/tests/testthat/jarTest.R b/R/pkg/inst/tests/testthat/jarTest.R index c9615c8d4faf6..e2241e03b55f8 100644 --- a/R/pkg/inst/tests/testthat/jarTest.R +++ b/R/pkg/inst/tests/testthat/jarTest.R @@ -16,7 +16,7 @@ # library(SparkR) -sc <- sparkR.session() +sc <- sparkR.session(master = "local[1]") helloTest <- SparkR:::callJStatic("sparkrtest.DummyClass", "helloWorld", diff --git a/R/pkg/inst/tests/testthat/packageInAJarTest.R b/R/pkg/inst/tests/testthat/packageInAJarTest.R index 4bc935c79eb0f..ac706261999fb 100644 --- a/R/pkg/inst/tests/testthat/packageInAJarTest.R +++ b/R/pkg/inst/tests/testthat/packageInAJarTest.R @@ -17,7 +17,7 @@ library(SparkR) library(sparkPackageTest) -sparkR.session() +sparkR.session(master = "local[1]") run1 <- myfunc(5L) diff --git a/R/pkg/inst/tests/testthat/test_Serde.R b/R/pkg/inst/tests/testthat/test_Serde.R index 518fb7bd94043..6e160fae1afed 100644 --- a/R/pkg/inst/tests/testthat/test_Serde.R +++ b/R/pkg/inst/tests/testthat/test_Serde.R @@ -17,7 +17,7 @@ context("SerDe functionality") -sparkSession <- sparkR.session(enableHiveSupport = FALSE) +sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) test_that("SerDe of primitive types", { skip_on_cran() diff --git a/R/pkg/inst/tests/testthat/test_binaryFile.R b/R/pkg/inst/tests/testthat/test_binaryFile.R index 63f54e1af02b1..00954fa31b0ee 100644 --- a/R/pkg/inst/tests/testthat/test_binaryFile.R +++ b/R/pkg/inst/tests/testthat/test_binaryFile.R @@ -18,7 +18,7 @@ context("functions on binary files") # JavaSparkContext handle -sparkSession <- sparkR.session(enableHiveSupport = FALSE) +sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession) mockFile <- c("Spark is pretty.", "Spark is awesome.") diff --git a/R/pkg/inst/tests/testthat/test_binary_function.R b/R/pkg/inst/tests/testthat/test_binary_function.R index 25bb2b84266dd..236cb3885445e 100644 --- a/R/pkg/inst/tests/testthat/test_binary_function.R +++ b/R/pkg/inst/tests/testthat/test_binary_function.R @@ -18,7 +18,7 @@ context("binary functions") # JavaSparkContext handle -sparkSession <- sparkR.session(enableHiveSupport = FALSE) +sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession) # Data diff --git a/R/pkg/inst/tests/testthat/test_broadcast.R b/R/pkg/inst/tests/testthat/test_broadcast.R index 504ded4fc8623..254f8f522a708 100644 --- a/R/pkg/inst/tests/testthat/test_broadcast.R +++ b/R/pkg/inst/tests/testthat/test_broadcast.R @@ -18,7 +18,7 @@ context("broadcast variables") # JavaSparkContext handle -sparkSession <- sparkR.session(enableHiveSupport = FALSE) +sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession) # Partitioned data diff --git a/R/pkg/inst/tests/testthat/test_context.R b/R/pkg/inst/tests/testthat/test_context.R index 632a90d68177f..f6d9f5423df02 100644 --- a/R/pkg/inst/tests/testthat/test_context.R +++ b/R/pkg/inst/tests/testthat/test_context.R @@ -60,7 +60,7 @@ test_that("repeatedly starting and stopping SparkR", { skip_on_cran() for (i in 1:4) { - sc <- suppressWarnings(sparkR.init()) + sc <- suppressWarnings(sparkR.init(master = sparkRTestMaster)) rdd <- parallelize(sc, 1:20, 2L) expect_equal(countRDD(rdd), 20) suppressWarnings(sparkR.stop()) @@ -69,7 +69,7 @@ test_that("repeatedly starting and stopping SparkR", { test_that("repeatedly starting and stopping SparkSession", { for (i in 1:4) { - sparkR.session(enableHiveSupport = FALSE) + sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) df <- createDataFrame(data.frame(dummy = 1:i)) expect_equal(count(df), i) sparkR.session.stop() @@ -79,12 +79,12 @@ test_that("repeatedly starting and stopping SparkSession", { test_that("rdd GC across sparkR.stop", { skip_on_cran() - sc <- sparkR.sparkContext() # sc should get id 0 + sc <- sparkR.sparkContext(master = sparkRTestMaster) # sc should get id 0 rdd1 <- parallelize(sc, 1:20, 2L) # rdd1 should get id 1 rdd2 <- parallelize(sc, 1:10, 2L) # rdd2 should get id 2 sparkR.session.stop() - sc <- sparkR.sparkContext() # sc should get id 0 again + sc <- sparkR.sparkContext(master = sparkRTestMaster) # sc should get id 0 again # GC rdd1 before creating rdd3 and rdd2 after rm(rdd1) @@ -104,7 +104,7 @@ test_that("rdd GC across sparkR.stop", { test_that("job group functions can be called", { skip_on_cran() - sc <- sparkR.sparkContext() + sc <- sparkR.sparkContext(master = sparkRTestMaster) setJobGroup("groupId", "job description", TRUE) cancelJobGroup("groupId") clearJobGroup() @@ -118,7 +118,7 @@ test_that("job group functions can be called", { test_that("utility function can be called", { skip_on_cran() - sparkR.sparkContext() + sparkR.sparkContext(master = sparkRTestMaster) setLogLevel("ERROR") sparkR.session.stop() }) @@ -175,7 +175,7 @@ test_that("sparkJars sparkPackages as comma-separated strings", { }) test_that("spark.lapply should perform simple transforms", { - sparkR.sparkContext() + sparkR.sparkContext(master = sparkRTestMaster) doubled <- spark.lapply(1:10, function(x) { 2 * x }) expect_equal(doubled, as.list(2 * 1:10)) sparkR.session.stop() @@ -184,7 +184,7 @@ test_that("spark.lapply should perform simple transforms", { test_that("add and get file to be downloaded with Spark job on every node", { skip_on_cran() - sparkR.sparkContext() + sparkR.sparkContext(master = sparkRTestMaster) # Test add file. path <- tempfile(pattern = "hello", fileext = ".txt") filename <- basename(path) diff --git a/R/pkg/inst/tests/testthat/test_includePackage.R b/R/pkg/inst/tests/testthat/test_includePackage.R index f823ad8e9c985..d7d9eeed1575e 100644 --- a/R/pkg/inst/tests/testthat/test_includePackage.R +++ b/R/pkg/inst/tests/testthat/test_includePackage.R @@ -18,7 +18,7 @@ context("include R packages") # JavaSparkContext handle -sparkSession <- sparkR.session(enableHiveSupport = FALSE) +sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession) # Partitioned data diff --git a/R/pkg/inst/tests/testthat/test_jvm_api.R b/R/pkg/inst/tests/testthat/test_jvm_api.R index 7348c893d0af3..8b3b4f73de170 100644 --- a/R/pkg/inst/tests/testthat/test_jvm_api.R +++ b/R/pkg/inst/tests/testthat/test_jvm_api.R @@ -17,7 +17,7 @@ context("JVM API") -sparkSession <- sparkR.session(enableHiveSupport = FALSE) +sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) test_that("Create and call methods on object", { jarr <- sparkR.newJObject("java.util.ArrayList") diff --git a/R/pkg/inst/tests/testthat/test_mllib_classification.R b/R/pkg/inst/tests/testthat/test_mllib_classification.R index cbc7087182868..f3eaeb381afc4 100644 --- a/R/pkg/inst/tests/testthat/test_mllib_classification.R +++ b/R/pkg/inst/tests/testthat/test_mllib_classification.R @@ -20,7 +20,7 @@ library(testthat) context("MLlib classification algorithms, except for tree-based algorithms") # Tests for MLlib classification algorithms in SparkR -sparkSession <- sparkR.session(enableHiveSupport = FALSE) +sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) absoluteSparkPath <- function(x) { sparkHome <- sparkR.conf("spark.home") diff --git a/R/pkg/inst/tests/testthat/test_mllib_clustering.R b/R/pkg/inst/tests/testthat/test_mllib_clustering.R index 478012e8828cd..df8e5968b27f4 100644 --- a/R/pkg/inst/tests/testthat/test_mllib_clustering.R +++ b/R/pkg/inst/tests/testthat/test_mllib_clustering.R @@ -20,7 +20,7 @@ library(testthat) context("MLlib clustering algorithms") # Tests for MLlib clustering algorithms in SparkR -sparkSession <- sparkR.session(enableHiveSupport = FALSE) +sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) absoluteSparkPath <- function(x) { sparkHome <- sparkR.conf("spark.home") diff --git a/R/pkg/inst/tests/testthat/test_mllib_fpm.R b/R/pkg/inst/tests/testthat/test_mllib_fpm.R index c38f1133897dd..1fa5375f9da31 100644 --- a/R/pkg/inst/tests/testthat/test_mllib_fpm.R +++ b/R/pkg/inst/tests/testthat/test_mllib_fpm.R @@ -20,7 +20,7 @@ library(testthat) context("MLlib frequent pattern mining") # Tests for MLlib frequent pattern mining algorithms in SparkR -sparkSession <- sparkR.session(enableHiveSupport = FALSE) +sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) test_that("spark.fpGrowth", { data <- selectExpr(createDataFrame(data.frame(items = c( diff --git a/R/pkg/inst/tests/testthat/test_mllib_recommendation.R b/R/pkg/inst/tests/testthat/test_mllib_recommendation.R index 6b1040db93050..e3e2b15c71361 100644 --- a/R/pkg/inst/tests/testthat/test_mllib_recommendation.R +++ b/R/pkg/inst/tests/testthat/test_mllib_recommendation.R @@ -20,7 +20,7 @@ library(testthat) context("MLlib recommendation algorithms") # Tests for MLlib recommendation algorithms in SparkR -sparkSession <- sparkR.session(enableHiveSupport = FALSE) +sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) test_that("spark.als", { data <- list(list(0, 0, 4.0), list(0, 1, 2.0), list(1, 1, 3.0), list(1, 2, 4.0), diff --git a/R/pkg/inst/tests/testthat/test_mllib_regression.R b/R/pkg/inst/tests/testthat/test_mllib_regression.R index 58924f952c6bf..44c98be906d81 100644 --- a/R/pkg/inst/tests/testthat/test_mllib_regression.R +++ b/R/pkg/inst/tests/testthat/test_mllib_regression.R @@ -20,7 +20,7 @@ library(testthat) context("MLlib regression algorithms, except for tree-based algorithms") # Tests for MLlib regression algorithms in SparkR -sparkSession <- sparkR.session(enableHiveSupport = FALSE) +sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) test_that("formula of spark.glm", { skip_on_cran() diff --git a/R/pkg/inst/tests/testthat/test_mllib_stat.R b/R/pkg/inst/tests/testthat/test_mllib_stat.R index beb148e7702fd..1600833a5d03a 100644 --- a/R/pkg/inst/tests/testthat/test_mllib_stat.R +++ b/R/pkg/inst/tests/testthat/test_mllib_stat.R @@ -20,7 +20,7 @@ library(testthat) context("MLlib statistics algorithms") # Tests for MLlib statistics algorithms in SparkR -sparkSession <- sparkR.session(enableHiveSupport = FALSE) +sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) test_that("spark.kstest", { data <- data.frame(test = c(0.1, 0.15, 0.2, 0.3, 0.25, -1, -0.5)) diff --git a/R/pkg/inst/tests/testthat/test_mllib_tree.R b/R/pkg/inst/tests/testthat/test_mllib_tree.R index e0802a9b02d13..146bc2878e263 100644 --- a/R/pkg/inst/tests/testthat/test_mllib_tree.R +++ b/R/pkg/inst/tests/testthat/test_mllib_tree.R @@ -20,7 +20,7 @@ library(testthat) context("MLlib tree-based algorithms") # Tests for MLlib tree-based algorithms in SparkR -sparkSession <- sparkR.session(enableHiveSupport = FALSE) +sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) absoluteSparkPath <- function(x) { sparkHome <- sparkR.conf("spark.home") diff --git a/R/pkg/inst/tests/testthat/test_parallelize_collect.R b/R/pkg/inst/tests/testthat/test_parallelize_collect.R index 1f7f387de08ce..52d4c93ed9599 100644 --- a/R/pkg/inst/tests/testthat/test_parallelize_collect.R +++ b/R/pkg/inst/tests/testthat/test_parallelize_collect.R @@ -33,7 +33,7 @@ numPairs <- list(list(1, 1), list(1, 2), list(2, 2), list(2, 3)) strPairs <- list(list(strList, strList), list(strList, strList)) # JavaSparkContext handle -sparkSession <- sparkR.session(enableHiveSupport = FALSE) +sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) jsc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession) # Tests diff --git a/R/pkg/inst/tests/testthat/test_rdd.R b/R/pkg/inst/tests/testthat/test_rdd.R index a3b1631e1d119..fb244e1d49e20 100644 --- a/R/pkg/inst/tests/testthat/test_rdd.R +++ b/R/pkg/inst/tests/testthat/test_rdd.R @@ -18,7 +18,7 @@ context("basic RDD functions") # JavaSparkContext handle -sparkSession <- sparkR.session(enableHiveSupport = FALSE) +sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession) # Data diff --git a/R/pkg/inst/tests/testthat/test_shuffle.R b/R/pkg/inst/tests/testthat/test_shuffle.R index cedf4f100c6c4..18320ea44b389 100644 --- a/R/pkg/inst/tests/testthat/test_shuffle.R +++ b/R/pkg/inst/tests/testthat/test_shuffle.R @@ -18,7 +18,7 @@ context("partitionBy, groupByKey, reduceByKey etc.") # JavaSparkContext handle -sparkSession <- sparkR.session(enableHiveSupport = FALSE) +sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession) # Data diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index 19aa61e9a56c3..0ff2e02e75a98 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -61,7 +61,7 @@ unsetHiveContext <- function() { # Tests for SparkSQL functions in SparkR filesBefore <- list.files(path = sparkRDir, all.files = TRUE) -sparkSession <- sparkR.session() +sparkSession <- sparkR.session(master = sparkRTestMaster) sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession) mockLines <- c("{\"name\":\"Michael\"}", diff --git a/R/pkg/inst/tests/testthat/test_streaming.R b/R/pkg/inst/tests/testthat/test_streaming.R index 91df7ac6f9849..b20b4312fbaae 100644 --- a/R/pkg/inst/tests/testthat/test_streaming.R +++ b/R/pkg/inst/tests/testthat/test_streaming.R @@ -21,7 +21,7 @@ context("Structured Streaming") # Tests for Structured Streaming functions in SparkR -sparkSession <- sparkR.session(enableHiveSupport = FALSE) +sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) jsonSubDir <- file.path("sparkr-test", "json", "") if (.Platform$OS.type == "windows") { diff --git a/R/pkg/inst/tests/testthat/test_take.R b/R/pkg/inst/tests/testthat/test_take.R index e2130eaac78dd..c00723ba31f4c 100644 --- a/R/pkg/inst/tests/testthat/test_take.R +++ b/R/pkg/inst/tests/testthat/test_take.R @@ -30,7 +30,7 @@ strList <- list("Dexter Morgan: Blood. Sometimes it sets my teeth on edge, ", "raising me. But they're both dead now. I didn't kill them. Honest.") # JavaSparkContext handle -sparkSession <- sparkR.session(enableHiveSupport = FALSE) +sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession) test_that("take() gives back the original elements in correct count and order", { diff --git a/R/pkg/inst/tests/testthat/test_textFile.R b/R/pkg/inst/tests/testthat/test_textFile.R index 28b7e8e3183fd..e8a961cb3e870 100644 --- a/R/pkg/inst/tests/testthat/test_textFile.R +++ b/R/pkg/inst/tests/testthat/test_textFile.R @@ -18,7 +18,7 @@ context("the textFile() function") # JavaSparkContext handle -sparkSession <- sparkR.session(enableHiveSupport = FALSE) +sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession) mockFile <- c("Spark is pretty.", "Spark is awesome.") diff --git a/R/pkg/inst/tests/testthat/test_utils.R b/R/pkg/inst/tests/testthat/test_utils.R index 4a01e875405ff..2fc6530d63e54 100644 --- a/R/pkg/inst/tests/testthat/test_utils.R +++ b/R/pkg/inst/tests/testthat/test_utils.R @@ -18,7 +18,7 @@ context("functions in utils.R") # JavaSparkContext handle -sparkSession <- sparkR.session(enableHiveSupport = FALSE) +sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession) test_that("convertJListToRList() gives back (deserializes) the original JLists diff --git a/R/pkg/tests/run-all.R b/R/pkg/tests/run-all.R index 29812f872c784..9c6cba535d118 100644 --- a/R/pkg/tests/run-all.R +++ b/R/pkg/tests/run-all.R @@ -31,4 +31,9 @@ sparkRWhitelistSQLDirs <- c("spark-warehouse", "metastore_db") invisible(lapply(sparkRWhitelistSQLDirs, function(x) { unlink(file.path(sparkRDir, x), recursive = TRUE, force = TRUE)})) +sparkRTestMaster <- "local[1]" +if (identical(Sys.getenv("NOT_CRAN"), "true")) { + sparkRTestMaster <- "" +} + test_package("SparkR") diff --git a/R/pkg/vignettes/sparkr-vignettes.Rmd b/R/pkg/vignettes/sparkr-vignettes.Rmd index 49f4ab8f146a8..13a399165c8b4 100644 --- a/R/pkg/vignettes/sparkr-vignettes.Rmd +++ b/R/pkg/vignettes/sparkr-vignettes.Rmd @@ -46,8 +46,9 @@ We use default settings in which it runs in local mode. It auto downloads Spark ```{r, include=FALSE} install.spark() +sparkR.session(master = "local[1]") ``` -```{r, message=FALSE, results="hide"} +```{r, eval=FALSE} sparkR.session() ``` diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala index 8cd1d1c96aa0a..01d8973e1bb06 100644 --- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala +++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala @@ -110,10 +110,10 @@ private[spark] class TaskContextImpl( /** Marks the task as completed and triggers the completion listeners. */ @GuardedBy("this") - private[spark] def markTaskCompleted(): Unit = synchronized { + private[spark] def markTaskCompleted(error: Option[Throwable]): Unit = synchronized { if (completed) return completed = true - invokeListeners(onCompleteCallbacks, "TaskCompletionListener", None) { + invokeListeners(onCompleteCallbacks, "TaskCompletionListener", error) { _.onTaskCompletion(this) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index 5c337b992c840..7767ef1803a06 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -115,26 +115,33 @@ private[spark] abstract class Task[T]( case t: Throwable => e.addSuppressed(t) } + context.markTaskCompleted(Some(e)) throw e } finally { - // Call the task completion callbacks. - context.markTaskCompleted() try { - Utils.tryLogNonFatalError { - // Release memory used by this thread for unrolling blocks - SparkEnv.get.blockManager.memoryStore.releaseUnrollMemoryForThisTask(MemoryMode.ON_HEAP) - SparkEnv.get.blockManager.memoryStore.releaseUnrollMemoryForThisTask(MemoryMode.OFF_HEAP) - // Notify any tasks waiting for execution memory to be freed to wake up and try to - // acquire memory again. This makes impossible the scenario where a task sleeps forever - // because there are no other tasks left to notify it. Since this is safe to do but may - // not be strictly necessary, we should revisit whether we can remove this in the future. - val memoryManager = SparkEnv.get.memoryManager - memoryManager.synchronized { memoryManager.notifyAll() } - } + // Call the task completion callbacks. If "markTaskCompleted" is called twice, the second + // one is no-op. + context.markTaskCompleted(None) } finally { - // Though we unset the ThreadLocal here, the context member variable itself is still queried - // directly in the TaskRunner to check for FetchFailedExceptions. - TaskContext.unset() + try { + Utils.tryLogNonFatalError { + // Release memory used by this thread for unrolling blocks + SparkEnv.get.blockManager.memoryStore.releaseUnrollMemoryForThisTask(MemoryMode.ON_HEAP) + SparkEnv.get.blockManager.memoryStore.releaseUnrollMemoryForThisTask( + MemoryMode.OFF_HEAP) + // Notify any tasks waiting for execution memory to be freed to wake up and try to + // acquire memory again. This makes impossible the scenario where a task sleeps forever + // because there are no other tasks left to notify it. Since this is safe to do but may + // not be strictly necessary, we should revisit whether we can remove this in the + // future. + val memoryManager = SparkEnv.get.memoryManager + memoryManager.synchronized { memoryManager.notifyAll() } + } + } finally { + // Though we unset the ThreadLocal here, the context member variable itself is still + // queried directly in the TaskRunner to check for FetchFailedExceptions. + TaskContext.unset() + } } } } diff --git a/core/src/main/scala/org/apache/spark/util/taskListeners.scala b/core/src/main/scala/org/apache/spark/util/taskListeners.scala index 1be31e88ab68e..51feccfb8342a 100644 --- a/core/src/main/scala/org/apache/spark/util/taskListeners.scala +++ b/core/src/main/scala/org/apache/spark/util/taskListeners.scala @@ -55,14 +55,16 @@ class TaskCompletionListenerException( extends RuntimeException { override def getMessage: String = { - if (errorMessages.size == 1) { - errorMessages.head - } else { - errorMessages.zipWithIndex.map { case (msg, i) => s"Exception $i: $msg" }.mkString("\n") - } + - previousError.map { e => + val listenerErrorMessage = + if (errorMessages.size == 1) { + errorMessages.head + } else { + errorMessages.zipWithIndex.map { case (msg, i) => s"Exception $i: $msg" }.mkString("\n") + } + val previousErrorMessage = previousError.map { e => "\n\nPrevious exception in task: " + e.getMessage + "\n" + e.getStackTrace.mkString("\t", "\n\t", "") }.getOrElse("") + listenerErrorMessage + previousErrorMessage } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala index b22da565d86e7..992d3396d203f 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala @@ -100,7 +100,7 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark context.addTaskCompletionListener(_ => throw new Exception("blah")) intercept[TaskCompletionListenerException] { - context.markTaskCompleted() + context.markTaskCompleted(None) } verify(listener, times(1)).onTaskCompletion(any()) @@ -231,10 +231,10 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark test("immediately call a completion listener if the context is completed") { var invocations = 0 val context = TaskContext.empty() - context.markTaskCompleted() + context.markTaskCompleted(None) context.addTaskCompletionListener(_ => invocations += 1) assert(invocations == 1) - context.markTaskCompleted() + context.markTaskCompleted(None) assert(invocations == 1) } @@ -254,6 +254,36 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark assert(lastError == error) assert(invocations == 1) } + + test("TaskCompletionListenerException.getMessage should include previousError") { + val listenerErrorMessage = "exception in listener" + val taskErrorMessage = "exception in task" + val e = new TaskCompletionListenerException( + Seq(listenerErrorMessage), + Some(new RuntimeException(taskErrorMessage))) + assert(e.getMessage.contains(listenerErrorMessage) && e.getMessage.contains(taskErrorMessage)) + } + + test("all TaskCompletionListeners should be called even if some fail or a task") { + val context = TaskContext.empty() + val listener = mock(classOf[TaskCompletionListener]) + context.addTaskCompletionListener(_ => throw new Exception("exception in listener1")) + context.addTaskCompletionListener(listener) + context.addTaskCompletionListener(_ => throw new Exception("exception in listener3")) + + val e = intercept[TaskCompletionListenerException] { + context.markTaskCompleted(Some(new Exception("exception in task"))) + } + + // Make sure listener 2 was called. + verify(listener, times(1)).onTaskCompletion(any()) + + // also need to check failure in TaskCompletionListener does not mask earlier exception + assert(e.getMessage.contains("exception in listener1")) + assert(e.getMessage.contains("exception in listener3")) + assert(e.getMessage.contains("exception in task")) + } + } private object TaskContextSuite { diff --git a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala index bbfd6df3b6990..7859b0bba2b48 100644 --- a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala @@ -19,8 +19,6 @@ package org.apache.spark.storage import java.io.{File, FileWriter} -import scala.language.reflectiveCalls - import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach} import org.apache.spark.{SparkConf, SparkFunSuite} diff --git a/core/src/test/scala/org/apache/spark/storage/PartiallySerializedBlockSuite.scala b/core/src/test/scala/org/apache/spark/storage/PartiallySerializedBlockSuite.scala index 3050f9a250235..535105379963a 100644 --- a/core/src/test/scala/org/apache/spark/storage/PartiallySerializedBlockSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/PartiallySerializedBlockSuite.scala @@ -145,7 +145,7 @@ class PartiallySerializedBlockSuite try { TaskContext.setTaskContext(TaskContext.empty()) val partiallySerializedBlock = partiallyUnroll((1 to 10).iterator, 2) - TaskContext.get().asInstanceOf[TaskContextImpl].markTaskCompleted() + TaskContext.get().asInstanceOf[TaskContextImpl].markTaskCompleted(None) Mockito.verify(partiallySerializedBlock.getUnrolledChunkedByteBuffer).dispose() Mockito.verifyNoMoreInteractions(memoryStore) } finally { diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index e56e440380a54..9900d1edc4cb0 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -192,7 +192,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT // Complete the task; then the 2nd block buffer should be exhausted verify(blocks(ShuffleBlockId(0, 1, 0)), times(0)).release() - taskContext.markTaskCompleted() + taskContext.markTaskCompleted(None) verify(blocks(ShuffleBlockId(0, 1, 0)), times(1)).release() // The 3rd block should not be retained because the iterator is already in zombie state diff --git a/core/src/test/scala/org/apache/spark/util/random/XORShiftRandomSuite.scala b/core/src/test/scala/org/apache/spark/util/random/XORShiftRandomSuite.scala index 83eba3690e289..df3483830ca9c 100644 --- a/core/src/test/scala/org/apache/spark/util/random/XORShiftRandomSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/random/XORShiftRandomSuite.scala @@ -17,8 +17,6 @@ package org.apache.spark.util.random -import scala.language.reflectiveCalls - import org.apache.commons.math3.stat.inference.ChiSquareTest import org.scalatest.Matchers @@ -27,26 +25,22 @@ import org.apache.spark.util.Utils.times class XORShiftRandomSuite extends SparkFunSuite with Matchers { - private def fixture = new { - val seed = 1L - val xorRand = new XORShiftRandom(seed) - val hundMil = 1e8.toInt - } - /* * This test is based on a chi-squared test for randomness. */ test ("XORShift generates valid random numbers") { - val f = fixture + val xorRand = new XORShiftRandom(1L) val numBins = 10 // create 10 bins val numRows = 5 // create 5 rows val bins = Array.ofDim[Long](numRows, numBins) // populate bins based on modulus of the random number for each row - for (r <- 0 to numRows-1) { - times(f.hundMil) {bins(r)(math.abs(f.xorRand.nextInt) % numBins) += 1} + for (r <- 0 until numRows) { + times(100000000) { + bins(r)(math.abs(xorRand.nextInt) % numBins) += 1 + } } /* diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala index f736ceed4436f..b03701e4915d0 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala @@ -21,7 +21,6 @@ package org.apache.spark.examples.ml import java.util.Locale import scala.collection.mutable -import scala.language.reflectiveCalls import scopt.OptionParser diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala index ed598d0d7dfae..3bd8ff54c2238 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala @@ -21,7 +21,6 @@ package org.apache.spark.examples.ml import java.util.Locale import scala.collection.mutable -import scala.language.reflectiveCalls import scopt.OptionParser diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionExample.scala index 31ba18033519a..6903a1c298ced 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionExample.scala @@ -18,8 +18,6 @@ // scalastyle:off println package org.apache.spark.examples.ml -import scala.language.reflectiveCalls - import scopt.OptionParser import org.apache.spark.examples.mllib.AbstractParams diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionExample.scala index c67b53899ce4a..bd6cc8cff2348 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionExample.scala @@ -19,7 +19,6 @@ package org.apache.spark.examples.ml import scala.collection.mutable -import scala.language.reflectiveCalls import scopt.OptionParser diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala index 8fd46c37e2987..a735c218c0d26 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala @@ -21,7 +21,6 @@ package org.apache.spark.examples.ml import java.util.Locale import scala.collection.mutable -import scala.language.reflectiveCalls import scopt.OptionParser diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index 99321bcc7cf98..b2dc4fcb61964 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -59,6 +59,29 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha @Since("1.6.0") def getHandleInvalid: String = $(handleInvalid) + /** + * Param for how to order labels of string column. The first label after ordering is assigned + * an index of 0. + * Options are: + * - 'frequencyDesc': descending order by label frequency (most frequent label assigned 0) + * - 'frequencyAsc': ascending order by label frequency (least frequent label assigned 0) + * - 'alphabetDesc': descending alphabetical order + * - 'alphabetAsc': ascending alphabetical order + * Default is 'frequencyDesc'. + * + * @group param + */ + @Since("2.3.0") + final val stringOrderType: Param[String] = new Param(this, "stringOrderType", + "how to order labels of string column. " + + "The first label after ordering is assigned an index of 0. " + + s"Supported options: ${StringIndexer.supportedStringOrderType.mkString(", ")}.", + ParamValidators.inArray(StringIndexer.supportedStringOrderType)) + + /** @group getParam */ + @Since("2.3.0") + def getStringOrderType: String = $(stringOrderType) + /** Validates and transforms the input schema. */ protected def validateAndTransformSchema(schema: StructType): StructType = { val inputColName = $(inputCol) @@ -79,8 +102,9 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha /** * A label indexer that maps a string column of labels to an ML column of label indices. * If the input column is numeric, we cast it to string and index the string values. - * The indices are in [0, numLabels), ordered by label frequencies. - * So the most frequent label gets index 0. + * The indices are in [0, numLabels). By default, this is ordered by label frequencies + * so the most frequent label gets index 0. The ordering behavior is controlled by + * setting `stringOrderType`. * * @see `IndexToString` for the inverse transformation */ @@ -96,6 +120,11 @@ class StringIndexer @Since("1.4.0") ( @Since("1.6.0") def setHandleInvalid(value: String): this.type = set(handleInvalid, value) + /** @group setParam */ + @Since("2.3.0") + def setStringOrderType(value: String): this.type = set(stringOrderType, value) + setDefault(stringOrderType, StringIndexer.frequencyDesc) + /** @group setParam */ @Since("1.4.0") def setInputCol(value: String): this.type = set(inputCol, value) @@ -107,11 +136,17 @@ class StringIndexer @Since("1.4.0") ( @Since("2.0.0") override def fit(dataset: Dataset[_]): StringIndexerModel = { transformSchema(dataset.schema, logging = true) - val counts = dataset.na.drop(Array($(inputCol))).select(col($(inputCol)).cast(StringType)) - .rdd - .map(_.getString(0)) - .countByValue() - val labels = counts.toSeq.sortBy(-_._2).map(_._1).toArray + val values = dataset.na.drop(Array($(inputCol))) + .select(col($(inputCol)).cast(StringType)) + .rdd.map(_.getString(0)) + val labels = $(stringOrderType) match { + case StringIndexer.frequencyDesc => values.countByValue().toSeq.sortBy(-_._2) + .map(_._1).toArray + case StringIndexer.frequencyAsc => values.countByValue().toSeq.sortBy(_._2) + .map(_._1).toArray + case StringIndexer.alphabetDesc => values.distinct.collect.sortWith(_ > _) + case StringIndexer.alphabetAsc => values.distinct.collect.sortWith(_ < _) + } copyValues(new StringIndexerModel(uid, labels).setParent(this)) } @@ -131,6 +166,12 @@ object StringIndexer extends DefaultParamsReadable[StringIndexer] { private[feature] val KEEP_INVALID: String = "keep" private[feature] val supportedHandleInvalids: Array[String] = Array(SKIP_INVALID, ERROR_INVALID, KEEP_INVALID) + private[feature] val frequencyDesc: String = "frequencyDesc" + private[feature] val frequencyAsc: String = "frequencyAsc" + private[feature] val alphabetDesc: String = "alphabetDesc" + private[feature] val alphabetAsc: String = "alphabetAsc" + private[feature] val supportedStringOrderType: Array[String] = + Array(frequencyDesc, frequencyAsc, alphabetDesc, alphabetAsc) @Since("1.6.0") override def load(path: String): StringIndexer = super.load(path) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala index 5634d4210f478..806a92760c8b6 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala @@ -291,4 +291,27 @@ class StringIndexerSuite NominalAttribute.decodeStructField(transformed.schema("labelIndex"), preserveName = true) assert(attrs.name.nonEmpty && attrs.name.get === "labelIndex") } + + test("StringIndexer order types") { + val data = Seq((0, "b"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "b")) + val df = data.toDF("id", "label") + val indexer = new StringIndexer() + .setInputCol("label") + .setOutputCol("labelIndex") + + val expected = Seq(Set((0, 0.0), (1, 0.0), (2, 2.0), (3, 1.0), (4, 1.0), (5, 0.0)), + Set((0, 2.0), (1, 2.0), (2, 0.0), (3, 1.0), (4, 1.0), (5, 2.0)), + Set((0, 1.0), (1, 1.0), (2, 0.0), (3, 2.0), (4, 2.0), (5, 1.0)), + Set((0, 1.0), (1, 1.0), (2, 2.0), (3, 0.0), (4, 0.0), (5, 1.0))) + + var idx = 0 + for (orderType <- StringIndexer.supportedStringOrderType) { + val transformed = indexer.setStringOrderType(orderType).fit(df).transform(df) + val output = transformed.select("id", "labelIndex").rdd.map { r => + (r.getInt(0), r.getDouble(1)) + }.collect().toSet + assert(output === expected(idx)) + idx += 1 + } + } } diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 8b3487c3f1083..d9b86aff63fa0 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -144,12 +144,6 @@ def _(): 'measured in radians.', } -_functions_2_2 = { - 'to_date': 'Converts a string date into a DateType using the (optionally) specified format.', - 'to_timestamp': 'Converts a string timestamp into a timestamp type using the ' + - '(optionally) specified format.', -} - # math functions that take two arguments as input _binary_mathfunctions = { 'atan2': 'Returns the angle theta from the conversion of rectangular coordinates (x, y) to' + @@ -987,9 +981,10 @@ def months_between(date1, date2): def to_date(col, format=None): """Converts a :class:`Column` of :class:`pyspark.sql.types.StringType` or :class:`pyspark.sql.types.TimestampType` into :class:`pyspark.sql.types.DateType` - using the optionally specified format. Default format is 'yyyy-MM-dd'. - Specify formats according to + using the optionally specified format. Specify formats according to `SimpleDateFormats `_. + By default, it follows casting rules to :class:`pyspark.sql.types.DateType` if the format + is omitted (equivalent to ``col.cast("date")``). >>> df = spark.createDataFrame([('1997-02-28 10:30:00',)], ['t']) >>> df.select(to_date(df.t).alias('date')).collect() @@ -1011,9 +1006,10 @@ def to_date(col, format=None): def to_timestamp(col, format=None): """Converts a :class:`Column` of :class:`pyspark.sql.types.StringType` or :class:`pyspark.sql.types.TimestampType` into :class:`pyspark.sql.types.DateType` - using the optionally specified format. Default format is 'yyyy-MM-dd HH:mm:ss'. Specify - formats according to + using the optionally specified format. Specify formats according to `SimpleDateFormats `_. + By default, it follows casting rules to :class:`pyspark.sql.types.TimestampType` if the format + is omitted (equivalent to ``col.cast("timestamp")``). >>> df = spark.createDataFrame([('1997-02-28 10:30:00',)], ['t']) >>> df.select(to_timestamp(df.t).alias('dt')).collect() diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 87075003e5516..acea9113ee858 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -2524,7 +2524,7 @@ def test_datetime_functions(self): from datetime import date, datetime df = self.spark.range(1).selectExpr("'2017-01-22' as dateCol") parse_result = df.select(functions.to_date(functions.col("dateCol"))).first() - self.assertEquals(date(2017, 1, 22), parse_result['to_date(dateCol)']) + self.assertEquals(date(2017, 1, 22), parse_result['to_date(`dateCol`)']) @unittest.skipIf(sys.version_info < (3, 3), "Unittest < 3.3 doesn't support mocking") def test_unbounded_frames(self): diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index ed5450b494ccd..f99ce244bf436 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -334,7 +334,7 @@ queryOrganization (DISTRIBUTE BY distributeBy+=expression (',' distributeBy+=expression)*)? (SORT BY sort+=sortItem (',' sort+=sortItem)*)? windows? - (LIMIT limit=expression)? + (LIMIT (ALL | limit=expression))? ; multiInsertQueryBody @@ -549,7 +549,7 @@ valueExpression : primaryExpression #valueExpressionDefault | operator=(MINUS | PLUS | TILDE) valueExpression #arithmeticUnary | left=valueExpression operator=(ASTERISK | SLASH | PERCENT | DIV) right=valueExpression #arithmeticBinary - | left=valueExpression operator=(PLUS | MINUS) right=valueExpression #arithmeticBinary + | left=valueExpression operator=(PLUS | MINUS | CONCAT_PIPE) right=valueExpression #arithmeticBinary | left=valueExpression operator=AMPERSAND right=valueExpression #arithmeticBinary | left=valueExpression operator=HAT right=valueExpression #arithmeticBinary | left=valueExpression operator=PIPE right=valueExpression #arithmeticBinary @@ -590,7 +590,7 @@ comparisonOperator ; arithmeticOperator - : PLUS | MINUS | ASTERISK | SLASH | PERCENT | DIV | TILDE | AMPERSAND | PIPE | HAT + : PLUS | MINUS | ASTERISK | SLASH | PERCENT | DIV | TILDE | AMPERSAND | PIPE | CONCAT_PIPE | HAT ; predicateOperator @@ -869,6 +869,7 @@ DIV: 'DIV'; TILDE: '~'; AMPERSAND: '&'; PIPE: '|'; +CONCAT_PIPE: '||'; HAT: '^'; PERCENTLIT: 'PERCENT'; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 82710a2a183ab..6d1d019cc4743 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -836,8 +836,16 @@ trait ScalaReflection { def getConstructorParameters(tpe: Type): Seq[(String, Type)] = { val formalTypeArgs = tpe.typeSymbol.asClass.typeParams val TypeRef(_, _, actualTypeArgs) = tpe - constructParams(tpe).map { p => - p.name.toString -> p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs) + val params = constructParams(tpe) + // if there are type variables to fill in, do the substitution (SomeClass[T] -> SomeClass[Int]) + if (actualTypeArgs.nonEmpty) { + params.map { p => + p.name.toString -> p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs) + } + } else { + params.map { p => + p.name.toString -> p.typeSignature + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index c5c2a5b236672..7538a6477f495 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1003,18 +1003,32 @@ class Analyzer( */ object ResolveAggAliasInGroupBy extends Rule[LogicalPlan] { + // This is a strict check though, we put this to apply the rule only if the expression is not + // resolvable by child. + private def notResolvableByChild(attrName: String, child: LogicalPlan): Boolean = { + !child.output.exists(a => resolver(a.name, attrName)) + } + + private def mayResolveAttrByAggregateExprs( + exprs: Seq[Expression], aggs: Seq[NamedExpression], child: LogicalPlan): Seq[Expression] = { + exprs.map { _.transform { + case u: UnresolvedAttribute if notResolvableByChild(u.name, child) => + aggs.find(ne => resolver(ne.name, u.name)).getOrElse(u) + }} + } + override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case agg @ Aggregate(groups, aggs, child) + if conf.groupByAliases && child.resolved && aggs.forall(_.resolved) && + groups.exists(!_.resolved) => + agg.copy(groupingExpressions = mayResolveAttrByAggregateExprs(groups, aggs, child)) + + case gs @ GroupingSets(selectedGroups, groups, child, aggs) if conf.groupByAliases && child.resolved && aggs.forall(_.resolved) && groups.exists(_.isInstanceOf[UnresolvedAttribute]) => - // This is a strict check though, we put this to apply the rule only in alias expressions - def notResolvableByChild(attrName: String): Boolean = - !child.output.exists(a => resolver(a.name, attrName)) - agg.copy(groupingExpressions = groups.map { - case u: UnresolvedAttribute if notResolvableByChild(u.name) => - aggs.find(ne => resolver(ne.name, u.name)).getOrElse(u) - case e => e - }) + gs.copy( + selectedGroupByExprs = selectedGroups.map(mayResolveAttrByAggregateExprs(_, aggs, child)), + groupByExprs = mayResolveAttrByAggregateExprs(groups, aggs, child)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index 18e514681e811..f6653d384fe1d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -73,7 +73,7 @@ class SessionCatalog( functionRegistry, conf, new Configuration(), - CatalystSqlParser, + new CatalystSqlParser(conf), DummyFunctionResourceLoader) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index a98cd33f2780c..de4c94d12abdc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -1146,44 +1146,21 @@ case class ToUTCTimestamp(left: Expression, right: Expression) } /** - * Returns the date part of a timestamp or string. + * Parses a column to a date based on the given format. */ @ExpressionDescription( - usage = "_FUNC_(expr) - Extracts the date part of the date or timestamp expression `expr`.", + usage = """ + _FUNC_(date_str[, fmt]) - Parses the `date_str` expression with the `fmt` expression to + a date. Returns null with invalid input. By default, it follows casting rules to a date if + the `fmt` is omitted. + """, extended = """ Examples: > SELECT _FUNC_('2009-07-30 04:17:52'); 2009-07-30 - """) -case class ToDate(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { - - // Implicit casting of spark will accept string in both date and timestamp format, as - // well as TimestampType. - override def inputTypes: Seq[AbstractDataType] = Seq(DateType) - - override def dataType: DataType = DateType - - override def eval(input: InternalRow): Any = child.eval(input) - - override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - defineCodeGen(ctx, ev, d => d) - } - - override def prettyName: String = "to_date" -} - -/** - * Parses a column to a date based on the given format. - */ -// scalastyle:off line.size.limit -@ExpressionDescription( - usage = "_FUNC_(date_str, fmt) - Parses the `left` expression with the `fmt` expression. Returns null with invalid input.", - extended = """ - Examples: > SELECT _FUNC_('2016-12-31', 'yyyy-MM-dd'); 2016-12-31 """) -// scalastyle:on line.size.limit case class ParseToDate(left: Expression, format: Option[Expression], child: Expression) extends RuntimeReplaceable { @@ -1194,13 +1171,13 @@ case class ParseToDate(left: Expression, format: Option[Expression], child: Expr def this(left: Expression) = { // backwards compatability - this(left, Option(null), ToDate(left)) + this(left, None, Cast(left, DateType)) } override def flatArguments: Iterator[Any] = Iterator(left, format) override def sql: String = { if (format.isDefined) { - s"$prettyName(${left.sql}, ${format.get.sql}" + s"$prettyName(${left.sql}, ${format.get.sql})" } else { s"$prettyName(${left.sql})" } @@ -1212,24 +1189,36 @@ case class ParseToDate(left: Expression, format: Option[Expression], child: Expr /** * Parses a column to a timestamp based on the supplied format. */ -// scalastyle:off line.size.limit @ExpressionDescription( - usage = "_FUNC_(timestamp, fmt) - Parses the `left` expression with the `format` expression to a timestamp. Returns null with invalid input.", + usage = """ + _FUNC_(timestamp[, fmt]) - Parses the `timestamp` expression with the `fmt` expression to + a timestamp. Returns null with invalid input. By default, it follows casting rules to + a timestamp if the `fmt` is omitted. + """, extended = """ Examples: + > SELECT _FUNC_('2016-12-31 00:12:00'); + 2016-12-31 00:12:00 > SELECT _FUNC_('2016-12-31', 'yyyy-MM-dd'); - 2016-12-31 00:00:00.0 + 2016-12-31 00:00:00 """) -// scalastyle:on line.size.limit -case class ParseToTimestamp(left: Expression, format: Expression, child: Expression) +case class ParseToTimestamp(left: Expression, format: Option[Expression], child: Expression) extends RuntimeReplaceable { def this(left: Expression, format: Expression) = { - this(left, format, Cast(UnixTimestamp(left, format), TimestampType)) + this(left, Option(format), Cast(UnixTimestamp(left, format), TimestampType)) } + def this(left: Expression) = this(left, None, Cast(left, TimestampType)) + override def flatArguments: Iterator[Any] = Iterator(left, format) - override def sql: String = s"$prettyName(${left.sql}, ${format.sql})" + override def sql: String = { + if (format.isDefined) { + s"$prettyName(${left.sql}, ${format.get.sql})" + } else { + s"$prettyName(${left.sql})" + } + } override def prettyName: String = "to_timestamp" override def dataType: DataType = TimestampType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala index c4d47ab2084fd..de1a46dc47805 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala @@ -1023,10 +1023,10 @@ abstract class RoundBase(child: Expression, scale: Expression, // not overriding since _scale is a constant int at runtime def nullSafeEval(input1: Any): Any = { - child.dataType match { - case _: DecimalType => + dataType match { + case DecimalType.Fixed(_, s) => val decimal = input1.asInstanceOf[Decimal] - decimal.toPrecision(decimal.precision, _scale, mode).orNull + decimal.toPrecision(decimal.precision, s, mode).orNull case ByteType => BigDecimal(input1.asInstanceOf[Byte]).setScale(_scale, mode).toByte case ShortType => @@ -1055,10 +1055,10 @@ abstract class RoundBase(child: Expression, scale: Expression, override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val ce = child.genCode(ctx) - val evaluationCode = child.dataType match { - case _: DecimalType => + val evaluationCode = dataType match { + case DecimalType.Fixed(_, s) => s""" - if (${ce.value}.changePrecision(${ce.value}.precision(), ${_scale}, + if (${ce.value}.changePrecision(${ce.value}.precision(), ${s}, java.math.BigDecimal.${modeStr})) { ${ev.value} = ${ce.value}; } else { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index 3fa84589e3c68..aa5a1b5448c6d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -86,6 +86,13 @@ abstract class StringRegexExpression extends BinaryExpression escape character, the following character is matched literally. It is invalid to escape any other character. + Since Spark 2.0, string literals are unescaped in our SQL parser. For example, in order + to match "\abc", the pattern should be "\\abc". + + When SQL config 'spark.sql.parser.escapedStringLiterals' is enabled, it fallbacks + to Spark 1.6 behavior regarding string literal parsing. For example, if the config is + enabled, the pattern to match "\abc" should be "\abc". + Examples: > SELECT '%SystemDrive%\Users\John' _FUNC_ '\%SystemDrive\%\\Users%' true @@ -144,7 +151,31 @@ case class Like(left: Expression, right: Expression) extends StringRegexExpressi } @ExpressionDescription( - usage = "str _FUNC_ regexp - Returns true if `str` matches `regexp`, or false otherwise.") + usage = "str _FUNC_ regexp - Returns true if `str` matches `regexp`, or false otherwise.", + extended = """ + Arguments: + str - a string expression + regexp - a string expression. The pattern string should be a Java regular expression. + + Since Spark 2.0, string literals (including regex patterns) are unescaped in our SQL parser. + For example, to match "\abc", a regular expression for `regexp` can be "^\\abc$". + + There is a SQL config 'spark.sql.parser.escapedStringLiterals' that can be used to fallback + to the Spark 1.6 behavior regarding string literal parsing. For example, if the config is + enabled, the `regexp` that can match "\abc" is "^\abc$". + + Examples: + When spark.sql.parser.escapedStringLiterals is disabled (default). + > SELECT '%SystemDrive%\Users\John' _FUNC_ '%SystemDrive%\\Users.*' + true + + When spark.sql.parser.escapedStringLiterals is enabled. + > SELECT '%SystemDrive%\Users\John' _FUNC_ '%SystemDrive%\Users.*' + true + + See also: + Use LIKE to match with simple string pattern. +""") case class RLike(left: Expression, right: Expression) extends StringRegexExpression { override def escape(v: String): String = v diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 046ea65d454a1..740422bfc7a42 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -36,6 +36,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.{First, Last} import org.apache.spark.sql.catalyst.parser.SqlBaseParser._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.CalendarInterval import org.apache.spark.util.random.RandomSampler @@ -44,9 +45,11 @@ import org.apache.spark.util.random.RandomSampler * The AstBuilder converts an ANTLR4 ParseTree into a catalyst Expression, LogicalPlan or * TableIdentifier. */ -class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { +class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging { import ParserUtils._ + def this() = this(new SQLConf()) + protected def typedVisit[T](ctx: ParseTree): T = { ctx.accept(this).asInstanceOf[T] } @@ -278,6 +281,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { val withWindow = withOrder.optionalMap(windows)(withWindows) // LIMIT + // - LIMIT ALL is the same as omitting the LIMIT clause withWindow.optional(limit) { Limit(typedVisit(limit), withWindow) } @@ -1007,6 +1011,8 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { Add(left, right) case SqlBaseParser.MINUS => Subtract(left, right) + case SqlBaseParser.CONCAT_PIPE => + Concat(left :: right :: Nil) case SqlBaseParser.AMPERSAND => BitwiseAnd(left, right) case SqlBaseParser.HAT => @@ -1423,7 +1429,11 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { * Special characters can be escaped by using Hive/C-style escaping. */ private def createString(ctx: StringLiteralContext): String = { - ctx.STRING().asScala.map(string).mkString + if (conf.escapedStringLiterals) { + ctx.STRING().asScala.map(stringWithoutUnescape).mkString + } else { + ctx.STRING().asScala.map(string).mkString + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala index dcccbd0ed8d6b..8e2e973485e1c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.trees.Origin +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{DataType, StructType} /** @@ -121,8 +122,13 @@ abstract class AbstractSqlParser extends ParserInterface with Logging { /** * Concrete SQL parser for Catalyst-only SQL statements. */ +class CatalystSqlParser(conf: SQLConf) extends AbstractSqlParser { + val astBuilder = new AstBuilder(conf) +} + +/** For test-only. */ object CatalystSqlParser extends AbstractSqlParser { - val astBuilder = new AstBuilder + val astBuilder = new AstBuilder(new SQLConf()) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala index 6fbc33fad735c..77fdaa8255aa6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala @@ -68,6 +68,12 @@ object ParserUtils { /** Convert a string node into a string. */ def string(node: TerminalNode): String = unescapeSQLString(node.getText) + /** Convert a string node into a string without unescaping. */ + def stringWithoutUnescape(node: TerminalNode): String = { + // STRING parser rule forces that the input always has quotes at the starting and ending. + node.getText.slice(1, node.getText.size - 1) + } + /** Get the origin (line and position) of the token. */ def position(token: Token): Origin = { val opt = Option(token) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 2fb65bd435507..51faa333307b3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -423,7 +423,7 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT lazy val allAttributes: AttributeSeq = children.flatMap(_.output) } -object QueryPlan { +object QueryPlan extends PredicateHelper { /** * Normalize the exprIds in the given expression, by updating the exprId in `AttributeReference` * with its referenced ordinal from input attributes. It's similar to `BindReferences` but we @@ -442,4 +442,17 @@ object QueryPlan { } }.canonicalized.asInstanceOf[T] } + + /** + * Composes the given predicates into a conjunctive predicate, which is normalized and reordered. + * Then returns a new sequence of predicates by splitting the conjunctive predicate. + */ + def normalizePredicates(predicates: Seq[Expression], output: AttributeSeq): Seq[Expression] = { + if (predicates.nonEmpty) { + val normalized = normalizeExprId(predicates.reduce(And), output) + splitConjunctivePredicates(normalized) + } else { + Nil + } + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index f663d7b8a8f7b..2c19265bedc5d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -704,7 +704,7 @@ case class Expand( * We will transform GROUPING SETS into logical plan Aggregate(.., Expand) in Analyzer * * @param selectedGroupByExprs A sequence of selected GroupBy expressions, all exprs should - * exists in groupByExprs. + * exist in groupByExprs. * @param groupByExprs The Group By expressions candidates. * @param child Child operator * @param aggregations The Aggregation expressions, those non selected group by expressions diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index 6c1592fd8881d..0a3bb514ad214 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -427,7 +427,7 @@ object DateTimeUtils { * The return type is [[Option]] in order to distinguish between 0 and null. The following * formats are allowed: * - * `yyyy`, + * `yyyy` * `yyyy-[m]m` * `yyyy-[m]m-[d]d` * `yyyy-[m]m-[d]d ` diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index b24419a41edb0..b97adf7221d18 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -196,6 +196,14 @@ object SQLConf { .booleanConf .createWithDefault(true) + val ESCAPED_STRING_LITERALS = buildConf("spark.sql.parser.escapedStringLiterals") + .internal() + .doc("When true, string literals (including regex patterns) remain escaped in our SQL " + + "parser. The default is false since Spark 2.0. Setting it to true can restore the behavior " + + "prior to Spark 2.0.") + .booleanConf + .createWithDefault(false) + val PARQUET_SCHEMA_MERGING_ENABLED = buildConf("spark.sql.parquet.mergeSchema") .doc("When true, the Parquet data source merges schemas collected from all data files, " + "otherwise the schema is picked from the summary file or a random data file " + @@ -917,6 +925,8 @@ class SQLConf extends Serializable with Logging { def constraintPropagationEnabled: Boolean = getConf(CONSTRAINT_PROPAGATION_ENABLED) + def escapedStringLiterals: Boolean = getConf(ESCAPED_STRING_LITERALS) + /** * Returns the [[Resolver]] for the current configuration, which can be used to determine if two * identifiers are equal. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala index ca89bf7db0b4f..d3bac0a4d2773 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala @@ -495,14 +495,6 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { NextDay(Literal(Date.valueOf("2015-07-23")), Literal.create(null, StringType)), null) } - test("function to_date") { - checkEvaluation( - ToDate(Literal(Date.valueOf("2015-07-22"))), - DateTimeUtils.fromJavaDate(Date.valueOf("2015-07-22"))) - checkEvaluation(ToDate(Literal.create(null, DateType)), null) - checkConsistencyBetweenInterpretedAndCodegen(ToDate, DateType) - } - test("function trunc") { def testTrunc(input: Date, fmt: String, expected: Date): Unit = { checkEvaluation(TruncDate(Literal.create(input, DateType), Literal.create(fmt, StringType)), diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala index 6b5bfac94645c..1555dd1cf58d4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala @@ -546,15 +546,14 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val bdResults: Seq[BigDecimal] = Seq(BigDecimal(3.0), BigDecimal(3.1), BigDecimal(3.14), BigDecimal(3.142), BigDecimal(3.1416), BigDecimal(3.14159), BigDecimal(3.141593), BigDecimal(3.1415927)) - // round_scale > current_scale would result in precision increase - // and not allowed by o.a.s.s.types.Decimal.changePrecision, therefore null + (0 to 7).foreach { i => checkEvaluation(Round(bdPi, i), bdResults(i), EmptyRow) checkEvaluation(BRound(bdPi, i), bdResults(i), EmptyRow) } (8 to 10).foreach { scale => - checkEvaluation(Round(bdPi, scale), null, EmptyRow) - checkEvaluation(BRound(bdPi, scale), null, EmptyRow) + checkEvaluation(Round(bdPi, scale), bdPi, EmptyRow) + checkEvaluation(BRound(bdPi, scale), bdPi, EmptyRow) } DataTypeTestUtils.numericTypes.foreach { dataType => diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala index eb68eb9851b85..8bc2010cabece 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, _} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.{First, Last} import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.CalendarInterval @@ -39,12 +40,17 @@ class ExpressionParserSuite extends PlanTest { import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ - def assertEqual(sqlCommand: String, e: Expression): Unit = { - compareExpressions(parseExpression(sqlCommand), e) + val defaultParser = CatalystSqlParser + + def assertEqual( + sqlCommand: String, + e: Expression, + parser: ParserInterface = defaultParser): Unit = { + compareExpressions(parser.parseExpression(sqlCommand), e) } def intercept(sqlCommand: String, messages: String*): Unit = { - val e = intercept[ParseException](parseExpression(sqlCommand)) + val e = intercept[ParseException](defaultParser.parseExpression(sqlCommand)) messages.foreach { message => assert(e.message.contains(message)) } @@ -101,7 +107,7 @@ class ExpressionParserSuite extends PlanTest { test("long binary logical expressions") { def testVeryBinaryExpression(op: String, clazz: Class[_]): Unit = { val sql = (1 to 1000).map(x => s"$x == $x").mkString(op) - val e = parseExpression(sql) + val e = defaultParser.parseExpression(sql) assert(e.collect { case _: EqualTo => true }.size === 1000) assert(e.collect { case x if clazz.isInstance(x) => true }.size === 999) } @@ -160,6 +166,15 @@ class ExpressionParserSuite extends PlanTest { assertEqual("a not regexp 'pattern%'", !('a rlike "pattern%")) } + test("like expressions with ESCAPED_STRING_LITERALS = true") { + val conf = new SQLConf() + conf.setConfString(SQLConf.ESCAPED_STRING_LITERALS.key, "true") + val parser = new CatalystSqlParser(conf) + assertEqual("a rlike '^\\x20[\\x20-\\x23]+$'", 'a rlike "^\\x20[\\x20-\\x23]+$", parser) + assertEqual("a rlike 'pattern\\\\'", 'a rlike "pattern\\\\", parser) + assertEqual("a rlike 'pattern\\t\\n'", 'a rlike "pattern\\t\\n", parser) + } + test("is null expressions") { assertEqual("a is null", 'a.isNull) assertEqual("a is not null", 'a.isNotNull) @@ -418,38 +433,79 @@ class ExpressionParserSuite extends PlanTest { } test("strings") { - // Single Strings. - assertEqual("\"hello\"", "hello") - assertEqual("'hello'", "hello") - - // Multi-Strings. - assertEqual("\"hello\" 'world'", "helloworld") - assertEqual("'hello' \" \" 'world'", "hello world") - - // 'LIKE' string literals. Notice that an escaped '%' is the same as an escaped '\' and a - // regular '%'; to get the correct result you need to add another escaped '\'. - // TODO figure out if we shouldn't change the ParseUtils.unescapeSQLString method? - assertEqual("'pattern%'", "pattern%") - assertEqual("'no-pattern\\%'", "no-pattern\\%") - assertEqual("'pattern\\\\%'", "pattern\\%") - assertEqual("'pattern\\\\\\%'", "pattern\\\\%") - - // Escaped characters. - // See: http://dev.mysql.com/doc/refman/5.7/en/string-literals.html - assertEqual("'\\0'", "\u0000") // ASCII NUL (X'00') - assertEqual("'\\''", "\'") // Single quote - assertEqual("'\\\"'", "\"") // Double quote - assertEqual("'\\b'", "\b") // Backspace - assertEqual("'\\n'", "\n") // Newline - assertEqual("'\\r'", "\r") // Carriage return - assertEqual("'\\t'", "\t") // Tab character - assertEqual("'\\Z'", "\u001A") // ASCII 26 - CTRL + Z (EOF on windows) - - // Octals - assertEqual("'\\110\\145\\154\\154\\157\\041'", "Hello!") - - // Unicode - assertEqual("'\\u0057\\u006F\\u0072\\u006C\\u0064\\u0020\\u003A\\u0029'", "World :)") + Seq(true, false).foreach { escape => + val conf = new SQLConf() + conf.setConfString(SQLConf.ESCAPED_STRING_LITERALS.key, escape.toString) + val parser = new CatalystSqlParser(conf) + + // tests that have same result whatever the conf is + // Single Strings. + assertEqual("\"hello\"", "hello", parser) + assertEqual("'hello'", "hello", parser) + + // Multi-Strings. + assertEqual("\"hello\" 'world'", "helloworld", parser) + assertEqual("'hello' \" \" 'world'", "hello world", parser) + + // 'LIKE' string literals. Notice that an escaped '%' is the same as an escaped '\' and a + // regular '%'; to get the correct result you need to add another escaped '\'. + // TODO figure out if we shouldn't change the ParseUtils.unescapeSQLString method? + assertEqual("'pattern%'", "pattern%", parser) + assertEqual("'no-pattern\\%'", "no-pattern\\%", parser) + + // tests that have different result regarding the conf + if (escape) { + // When SQLConf.ESCAPED_STRING_LITERALS is enabled, string literal parsing fallbacks to + // Spark 1.6 behavior. + + // 'LIKE' string literals. + assertEqual("'pattern\\\\%'", "pattern\\\\%", parser) + assertEqual("'pattern\\\\\\%'", "pattern\\\\\\%", parser) + + // Escaped characters. + assertEqual("'\0'", "\u0000", parser) // ASCII NUL (X'00') + + // Note: Single quote follows 1.6 parsing behavior when ESCAPED_STRING_LITERALS is enabled. + val e = intercept[ParseException](parser.parseExpression("'\''")) + assert(e.message.contains("extraneous input '''")) + + assertEqual("'\"'", "\"", parser) // Double quote + assertEqual("'\b'", "\b", parser) // Backspace + assertEqual("'\n'", "\n", parser) // Newline + assertEqual("'\r'", "\r", parser) // Carriage return + assertEqual("'\t'", "\t", parser) // Tab character + + // Octals + assertEqual("'\110\145\154\154\157\041'", "Hello!", parser) + // Unicode + assertEqual("'\u0057\u006F\u0072\u006C\u0064\u0020\u003A\u0029'", "World :)", parser) + } else { + // Default behavior + + // 'LIKE' string literals. + assertEqual("'pattern\\\\%'", "pattern\\%", parser) + assertEqual("'pattern\\\\\\%'", "pattern\\\\%", parser) + + // Escaped characters. + // See: http://dev.mysql.com/doc/refman/5.7/en/string-literals.html + assertEqual("'\\0'", "\u0000", parser) // ASCII NUL (X'00') + assertEqual("'\\''", "\'", parser) // Single quote + assertEqual("'\\\"'", "\"", parser) // Double quote + assertEqual("'\\b'", "\b", parser) // Backspace + assertEqual("'\\n'", "\n", parser) // Newline + assertEqual("'\\r'", "\r", parser) // Carriage return + assertEqual("'\\t'", "\t", parser) // Tab character + assertEqual("'\\Z'", "\u001A", parser) // ASCII 26 - CTRL + Z (EOF on windows) + + // Octals + assertEqual("'\\110\\145\\154\\154\\157\\041'", "Hello!", parser) + + // Unicode + assertEqual("'\\u0057\\u006F\\u0072\\u006C\\u0064\\u0020\\u003A\\u0029'", "World :)", + parser) + } + + } } test("intervals") { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index 866fa98533218..74fc23a52a141 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -519,8 +519,8 @@ case class FileSourceScanExec( relation, output.map(QueryPlan.normalizeExprId(_, output)), requiredSchema, - partitionFilters.map(QueryPlan.normalizeExprId(_, output)), - dataFilters.map(QueryPlan.normalizeExprId(_, output)), + QueryPlan.normalizePredicates(partitionFilters, output), + QueryPlan.normalizePredicates(dataFilters, output), None) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index 20dacf88504f1..c2c52894860b5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -52,7 +52,7 @@ class SparkSqlParser(conf: SQLConf) extends AbstractSqlParser { /** * Builder that converts an ANTLR ParseTree into a LogicalPlan/Expression/TableIdentifier. */ -class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { +class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { import org.apache.spark.sql.catalyst.parser.ParserUtils._ /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala index e42df5dd61c70..5e79232a2043b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala @@ -120,7 +120,7 @@ case class FlatMapGroupsWithStateExec( val filteredIter = watermarkPredicateForData match { case Some(predicate) if timeoutConf == EventTimeTimeout => iter.filter(row => !predicate.eval(row)) - case None => + case _ => iter } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index b9769f781b237..5edf03666ac22 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -2683,13 +2683,12 @@ object functions { def unix_timestamp(s: Column, p: String): Column = withExpr { UnixTimestamp(s.expr, Literal(p)) } /** - * Convert time string to a Unix timestamp (in seconds). - * Uses the pattern "yyyy-MM-dd HH:mm:ss" and will return null on failure. + * Convert time string to a Unix timestamp (in seconds) by casting rules to `TimestampType`. * @group datetime_funcs * @since 2.2.0 */ def to_timestamp(s: Column): Column = withExpr { - new ParseToTimestamp(s.expr, Literal("yyyy-MM-dd HH:mm:ss")) + new ParseToTimestamp(s.expr) } /** @@ -2704,15 +2703,15 @@ object functions { } /** - * Converts the column into DateType. + * Converts the column into `DateType` by casting rules to `DateType`. * * @group datetime_funcs * @since 1.5.0 */ - def to_date(e: Column): Column = withExpr { ToDate(e.expr) } + def to_date(e: Column): Column = withExpr { new ParseToDate(e.expr) } /** - * Converts the column into a DateType with a specified format + * Converts the column into a `DateType` with a specified format * (see [http://docs.oracle.com/javase/tutorial/i18n/format/simpleDateFormat.html]) * return null if fail. * diff --git a/sql/core/src/test/resources/sql-tests/inputs/arithmetic.sql b/sql/core/src/test/resources/sql-tests/inputs/arithmetic.sql deleted file mode 100644 index f62b10ca0037b..0000000000000 --- a/sql/core/src/test/resources/sql-tests/inputs/arithmetic.sql +++ /dev/null @@ -1,34 +0,0 @@ - --- unary minus and plus -select -100; -select +230; -select -5.2; -select +6.8e0; -select -key, +key from testdata where key = 2; -select -(key + 1), - key + 1, +(key + 5) from testdata where key = 1; -select -max(key), +max(key) from testdata; -select - (-10); -select + (-key) from testdata where key = 32; -select - (+max(key)) from testdata; -select - - 3; -select - + 20; -select + + 100; -select - - max(key) from testdata; -select + - key from testdata where key = 33; - --- div -select 5 / 2; -select 5 / 0; -select 5 / null; -select null / 5; -select 5 div 2; -select 5 div 0; -select 5 div null; -select null div 5; - --- other arithmetics -select 1 + 2; -select 1 - 2; -select 2 * 5; -select 5 % 3; -select pmod(-7, 3); diff --git a/sql/core/src/test/resources/sql-tests/inputs/datetime.sql b/sql/core/src/test/resources/sql-tests/inputs/datetime.sql index 3fd1c37e71795..e957f693a983f 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/datetime.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/datetime.sql @@ -2,3 +2,7 @@ -- [SPARK-16836] current_date and current_timestamp literals select current_date = current_date(), current_timestamp = current_timestamp(); + +select to_date(null), to_date('2016-12-31'), to_date('2016-12-31', 'yyyy-MM-dd'); + +select to_timestamp(null), to_timestamp('2016-12-31 00:12:00'), to_timestamp('2016-12-31', 'yyyy-MM-dd'); diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-analytics.sql b/sql/core/src/test/resources/sql-tests/inputs/group-analytics.sql index f8135389a9e5a..8aff4cb524199 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/group-analytics.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/group-analytics.sql @@ -54,4 +54,9 @@ SELECT course, year, GROUPING_ID(course, year) FROM courseSales GROUP BY CUBE(co ORDER BY GROUPING(course), GROUPING(year), course, year; SELECT course, year FROM courseSales GROUP BY course, year ORDER BY GROUPING(course); SELECT course, year FROM courseSales GROUP BY course, year ORDER BY GROUPING_ID(course); -SELECT course, year FROM courseSales GROUP BY CUBE(course, year) ORDER BY grouping__id; \ No newline at end of file +SELECT course, year FROM courseSales GROUP BY CUBE(course, year) ORDER BY grouping__id; + +-- Aliases in SELECT could be used in ROLLUP/CUBE/GROUPING SETS +SELECT a + b AS k1, b AS k2, SUM(a - b) FROM testData GROUP BY CUBE(k1, k2); +SELECT a + b AS k, b, SUM(a - b) FROM testData GROUP BY ROLLUP(k, b); +SELECT a + b, b AS k, SUM(a - b) FROM testData GROUP BY a + b, k GROUPING SETS(k) diff --git a/sql/core/src/test/resources/sql-tests/inputs/limit.sql b/sql/core/src/test/resources/sql-tests/inputs/limit.sql index 2ea35f7f3a5c8..f21912a042716 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/limit.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/limit.sql @@ -1,23 +1,27 @@ -- limit on various data types -select * from testdata limit 2; -select * from arraydata limit 2; -select * from mapdata limit 2; +SELECT * FROM testdata LIMIT 2; +SELECT * FROM arraydata LIMIT 2; +SELECT * FROM mapdata LIMIT 2; -- foldable non-literal in limit -select * from testdata limit 2 + 1; +SELECT * FROM testdata LIMIT 2 + 1; -select * from testdata limit CAST(1 AS int); +SELECT * FROM testdata LIMIT CAST(1 AS int); -- limit must be non-negative -select * from testdata limit -1; +SELECT * FROM testdata LIMIT -1; +SELECT * FROM testData TABLESAMPLE (-1 ROWS); -- limit must be foldable -select * from testdata limit key > 3; +SELECT * FROM testdata LIMIT key > 3; -- limit must be integer -select * from testdata limit true; -select * from testdata limit 'a'; +SELECT * FROM testdata LIMIT true; +SELECT * FROM testdata LIMIT 'a'; -- limit within a subquery -select * from (select * from range(10) limit 5) where id > 3; +SELECT * FROM (SELECT * FROM range(10) LIMIT 5) WHERE id > 3; + +-- limit ALL +SELECT * FROM testdata WHERE key < 3 LIMIT ALL; diff --git a/sql/core/src/test/resources/sql-tests/inputs/operators.sql b/sql/core/src/test/resources/sql-tests/inputs/operators.sql new file mode 100644 index 0000000000000..6339d69ca6473 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/operators.sql @@ -0,0 +1,55 @@ + +-- unary minus and plus +select -100; +select +230; +select -5.2; +select +6.8e0; +select -key, +key from testdata where key = 2; +select -(key + 1), - key + 1, +(key + 5) from testdata where key = 1; +select -max(key), +max(key) from testdata; +select - (-10); +select + (-key) from testdata where key = 32; +select - (+max(key)) from testdata; +select - - 3; +select - + 20; +select + + 100; +select - - max(key) from testdata; +select + - key from testdata where key = 33; + +-- div +select 5 / 2; +select 5 / 0; +select 5 / null; +select null / 5; +select 5 div 2; +select 5 div 0; +select 5 div null; +select null div 5; + +-- other arithmetics +select 1 + 2; +select 1 - 2; +select 2 * 5; +select 5 % 3; +select pmod(-7, 3); + +-- check operator precedence. +-- We follow Oracle operator precedence in the table below that lists the levels of precedence +-- among SQL operators from high to low: +------------------------------------------------------------------------------------------ +-- Operator Operation +------------------------------------------------------------------------------------------ +-- +, - identity, negation +-- *, / multiplication, division +-- +, -, || addition, subtraction, concatenation +-- =, !=, <, >, <=, >=, IS NULL, LIKE, BETWEEN, IN comparison +-- NOT exponentiation, logical negation +-- AND conjunction +-- OR disjunction +------------------------------------------------------------------------------------------ +explain select 'a' || 1 + 2; +explain select 1 - 2 || 'b'; +explain select 2 * 4 + 3 || 'b'; +explain select 3 + 1 || 'a' || 4 / 2; +explain select 1 == 1 OR 'a' || 'b' == 'ab'; +explain select 'a' || 'c' == 'ac' AND 2 == 3; diff --git a/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql index f21981ef7b72a..7005cafe35cab 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql @@ -1,3 +1,6 @@ -- Argument number exception select concat_ws(); select format_string(); + +-- A pipe operator for string concatenation +select 'a' || 'b' || 'c'; diff --git a/sql/core/src/test/resources/sql-tests/results/datetime.sql.out b/sql/core/src/test/resources/sql-tests/results/datetime.sql.out index 032e4258500fb..13e1e48b038ad 100644 --- a/sql/core/src/test/resources/sql-tests/results/datetime.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/datetime.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 1 +-- Number of queries: 3 -- !query 0 @@ -8,3 +8,19 @@ select current_date = current_date(), current_timestamp = current_timestamp() struct<(current_date() = current_date()):boolean,(current_timestamp() = current_timestamp()):boolean> -- !query 0 output true true + + +-- !query 1 +select to_date(null), to_date('2016-12-31'), to_date('2016-12-31', 'yyyy-MM-dd') +-- !query 1 schema +struct +-- !query 1 output +NULL 2016-12-31 2016-12-31 + + +-- !query 2 +select to_timestamp(null), to_timestamp('2016-12-31 00:12:00'), to_timestamp('2016-12-31', 'yyyy-MM-dd') +-- !query 2 schema +struct +-- !query 2 output +NULL 2016-12-31 00:12:00 2016-12-31 00:00:00 diff --git a/sql/core/src/test/resources/sql-tests/results/group-analytics.sql.out b/sql/core/src/test/resources/sql-tests/results/group-analytics.sql.out index 825e8f5488c8b..ce7a16a4d0c81 100644 --- a/sql/core/src/test/resources/sql-tests/results/group-analytics.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/group-analytics.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 26 +-- Number of queries: 29 -- !query 0 @@ -328,3 +328,50 @@ struct<> -- !query 25 output org.apache.spark.sql.AnalysisException grouping__id is deprecated; use grouping_id() instead; + + +-- !query 26 +SELECT a + b AS k1, b AS k2, SUM(a - b) FROM testData GROUP BY CUBE(k1, k2) +-- !query 26 schema +struct +-- !query 26 output +2 1 0 +2 NULL 0 +3 1 1 +3 2 -1 +3 NULL 0 +4 1 2 +4 2 0 +4 NULL 2 +5 2 1 +5 NULL 1 +NULL 1 3 +NULL 2 0 +NULL NULL 3 + + +-- !query 27 +SELECT a + b AS k, b, SUM(a - b) FROM testData GROUP BY ROLLUP(k, b) +-- !query 27 schema +struct +-- !query 27 output +2 1 0 +2 NULL 0 +3 1 1 +3 2 -1 +3 NULL 0 +4 1 2 +4 2 0 +4 NULL 2 +5 2 1 +5 NULL 1 +NULL NULL 3 + + +-- !query 28 +SELECT a + b, b AS k, SUM(a - b) FROM testData GROUP BY a + b, k GROUPING SETS(k) +-- !query 28 schema +struct<(a + b):int,k:int,sum((a - b)):bigint> +-- !query 28 output +NULL 1 3 +NULL 2 0 diff --git a/sql/core/src/test/resources/sql-tests/results/limit.sql.out b/sql/core/src/test/resources/sql-tests/results/limit.sql.out index cb4e4d04810d0..146abe6cbd058 100644 --- a/sql/core/src/test/resources/sql-tests/results/limit.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/limit.sql.out @@ -1,9 +1,9 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 10 +-- Number of queries: 12 -- !query 0 -select * from testdata limit 2 +SELECT * FROM testdata LIMIT 2 -- !query 0 schema struct -- !query 0 output @@ -12,7 +12,7 @@ struct -- !query 1 -select * from arraydata limit 2 +SELECT * FROM arraydata LIMIT 2 -- !query 1 schema struct,nestedarraycol:array>> -- !query 1 output @@ -21,7 +21,7 @@ struct,nestedarraycol:array>> -- !query 2 -select * from mapdata limit 2 +SELECT * FROM mapdata LIMIT 2 -- !query 2 schema struct> -- !query 2 output @@ -30,7 +30,7 @@ struct> -- !query 3 -select * from testdata limit 2 + 1 +SELECT * FROM testdata LIMIT 2 + 1 -- !query 3 schema struct -- !query 3 output @@ -40,7 +40,7 @@ struct -- !query 4 -select * from testdata limit CAST(1 AS int) +SELECT * FROM testdata LIMIT CAST(1 AS int) -- !query 4 schema struct -- !query 4 output @@ -48,7 +48,7 @@ struct -- !query 5 -select * from testdata limit -1 +SELECT * FROM testdata LIMIT -1 -- !query 5 schema struct<> -- !query 5 output @@ -57,35 +57,53 @@ The limit expression must be equal to or greater than 0, but got -1; -- !query 6 -select * from testdata limit key > 3 +SELECT * FROM testData TABLESAMPLE (-1 ROWS) -- !query 6 schema struct<> -- !query 6 output org.apache.spark.sql.AnalysisException -The limit expression must evaluate to a constant value, but got (testdata.`key` > 3); +The limit expression must be equal to or greater than 0, but got -1; -- !query 7 -select * from testdata limit true +SELECT * FROM testdata LIMIT key > 3 -- !query 7 schema struct<> -- !query 7 output org.apache.spark.sql.AnalysisException -The limit expression must be integer type, but got boolean; +The limit expression must evaluate to a constant value, but got (testdata.`key` > 3); -- !query 8 -select * from testdata limit 'a' +SELECT * FROM testdata LIMIT true -- !query 8 schema struct<> -- !query 8 output org.apache.spark.sql.AnalysisException -The limit expression must be integer type, but got string; +The limit expression must be integer type, but got boolean; -- !query 9 -select * from (select * from range(10) limit 5) where id > 3 +SELECT * FROM testdata LIMIT 'a' -- !query 9 schema -struct +struct<> -- !query 9 output +org.apache.spark.sql.AnalysisException +The limit expression must be integer type, but got string; + + +-- !query 10 +SELECT * FROM (SELECT * FROM range(10) LIMIT 5) WHERE id > 3 +-- !query 10 schema +struct +-- !query 10 output 4 + + +-- !query 11 +SELECT * FROM testdata WHERE key < 3 LIMIT ALL +-- !query 11 schema +struct +-- !query 11 output +1 1 +2 2 diff --git a/sql/core/src/test/resources/sql-tests/results/arithmetic.sql.out b/sql/core/src/test/resources/sql-tests/results/operators.sql.out similarity index 70% rename from sql/core/src/test/resources/sql-tests/results/arithmetic.sql.out rename to sql/core/src/test/resources/sql-tests/results/operators.sql.out index ce42c016a7100..e0236f41187ec 100644 --- a/sql/core/src/test/resources/sql-tests/results/arithmetic.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/operators.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 28 +-- Number of queries: 34 -- !query 0 @@ -224,3 +224,63 @@ select pmod(-7, 3) struct -- !query 27 output 2 + + +-- !query 28 +explain select 'a' || 1 + 2 +-- !query 28 schema +struct +-- !query 28 output +== Physical Plan == +*Project [null AS (CAST(concat(a, CAST(1 AS STRING)) AS DOUBLE) + CAST(2 AS DOUBLE))#x] ++- Scan OneRowRelation[] + + +-- !query 29 +explain select 1 - 2 || 'b' +-- !query 29 schema +struct +-- !query 29 output +== Physical Plan == +*Project [-1b AS concat(CAST((1 - 2) AS STRING), b)#x] ++- Scan OneRowRelation[] + + +-- !query 30 +explain select 2 * 4 + 3 || 'b' +-- !query 30 schema +struct +-- !query 30 output +== Physical Plan == +*Project [11b AS concat(CAST(((2 * 4) + 3) AS STRING), b)#x] ++- Scan OneRowRelation[] + + +-- !query 31 +explain select 3 + 1 || 'a' || 4 / 2 +-- !query 31 schema +struct +-- !query 31 output +== Physical Plan == +*Project [4a2.0 AS concat(concat(CAST((3 + 1) AS STRING), a), CAST((CAST(4 AS DOUBLE) / CAST(2 AS DOUBLE)) AS STRING))#x] ++- Scan OneRowRelation[] + + +-- !query 32 +explain select 1 == 1 OR 'a' || 'b' == 'ab' +-- !query 32 schema +struct +-- !query 32 output +== Physical Plan == +*Project [true AS ((1 = 1) OR (concat(a, b) = ab))#x] ++- Scan OneRowRelation[] + + +-- !query 33 +explain select 'a' || 'c' == 'ac' AND 2 == 3 +-- !query 33 schema +struct +-- !query 33 output +== Physical Plan == +*Project [false AS ((concat(a, c) = ac) AND (2 = 3))#x] ++- Scan OneRowRelation[] diff --git a/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out index 6961e9b65922f..8ee075118e109 100644 --- a/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 2 +-- Number of queries: 3 -- !query 0 @@ -18,3 +18,11 @@ struct<> -- !query 1 output org.apache.spark.sql.AnalysisException requirement failed: format_string() should take at least 1 argument; line 1 pos 7 + + +-- !query 2 +select 'a' || 'b' || 'c' +-- !query 2 schema +struct +-- !query 2 output +abc diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 5b5cd28ad0c99..8eb381b91f46d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.execution.{LogicalRDD, RDDScanExec, SortExec} import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchange} import org.apache.spark.sql.execution.streaming.MemoryStream import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ @@ -1168,6 +1169,18 @@ class DatasetSuite extends QueryTest with SharedSQLContext { val ds = Seq(WithMapInOption(Some(Map(1 -> 1)))).toDS() checkDataset(ds, WithMapInOption(Some(Map(1 -> 1)))) } + + test("SPARK-20399: do not unescaped regex pattern when ESCAPED_STRING_LITERALS is enabled") { + withSQLConf(SQLConf.ESCAPED_STRING_LITERALS.key -> "true") { + val data = Seq("\u0020\u0021\u0023", "abc") + val df = data.toDF() + val rlike1 = df.filter("value rlike '^\\x20[\\x20-\\x23]+$'") + val rlike2 = df.filter($"value".rlike("^\\x20[\\x20-\\x23]+$")) + val rlike3 = df.filter("value rlike '^\\\\x20[\\\\x20-\\\\x23]+$'") + checkAnswer(rlike1, rlike2) + assert(rlike3.count() == 0) + } + } } case class WithImmutableMap(id: String, map_test: scala.collection.immutable.Map[Long, String]) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala index 2acda3f007326..3a8694839bb24 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala @@ -387,7 +387,7 @@ class DateFunctionsSuite extends QueryTest with SharedSQLContext { df.selectExpr("to_date(s)"), Seq(Row(Date.valueOf("2015-07-22")), Row(Date.valueOf("2014-12-31")), Row(null))) - // Now with format + // now with format checkAnswer( df.select(to_date(col("t"), "yyyy-MM-dd")), Seq(Row(Date.valueOf("2015-07-22")), Row(Date.valueOf("2014-12-31")), @@ -400,7 +400,7 @@ class DateFunctionsSuite extends QueryTest with SharedSQLContext { df.select(to_date(col("s"), "yyyy-MM-dd")), Seq(Row(Date.valueOf("2015-07-22")), Row(Date.valueOf("2014-12-31")), Row(null))) - // now switch format + // now switch format checkAnswer( df.select(to_date(col("s"), "yyyy-dd-MM")), Seq(Row(null), Row(null), Row(Date.valueOf("2014-12-31")))) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathFunctionsSuite.scala index 328c5395ec91e..c2d08a06569bf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MathFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MathFunctionsSuite.scala @@ -231,6 +231,19 @@ class MathFunctionsSuite extends QueryTest with SharedSQLContext { Seq(Row(BigDecimal("0E3"), BigDecimal("0E2"), BigDecimal("0E1"), BigDecimal(3), BigDecimal("3.1"), BigDecimal("3.14"), BigDecimal("3.142"))) ) + + val bdPi: BigDecimal = BigDecimal(31415925L, 7) + checkAnswer( + sql(s"SELECT round($bdPi, 7), round($bdPi, 8), round($bdPi, 9), round($bdPi, 10), " + + s"round($bdPi, 100), round($bdPi, 6), round(null, 8)"), + Seq(Row(bdPi, bdPi, bdPi, bdPi, bdPi, BigDecimal("3.141593"), null)) + ) + + checkAnswer( + sql(s"SELECT bround($bdPi, 7), bround($bdPi, 8), bround($bdPi, 9), bround($bdPi, 10), " + + s"bround($bdPi, 100), bround($bdPi, 6), bround(null, 8)"), + Seq(Row(bdPi, bdPi, bdPi, bdPi, bdPi, BigDecimal("3.141592"), null)) + ) } test("round/bround with data frame from a local Seq of Product") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index cd14d24370bad..b525c9e80ba42 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -523,14 +523,6 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { sortTest() } - test("negative in LIMIT or TABLESAMPLE") { - val expected = "The limit expression must be equal to or greater than 0, but got -1" - var e = intercept[AnalysisException] { - sql("SELECT * FROM testData TABLESAMPLE (-1 rows)") - }.getMessage - assert(e.contains(expected)) - } - test("CTE feature") { checkAnswer( sql("with q1 as (select * from testData limit 10) select * from q1"), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala index 05637821f71f1..afccbe5cc6d19 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala @@ -16,39 +16,36 @@ */ package org.apache.spark.sql.execution -import java.util.Locale - -import scala.language.reflectiveCalls - import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OneRowRelation} import org.apache.spark.sql.test.SharedSQLContext class QueryExecutionSuite extends SharedSQLContext { test("toString() exception/error handling") { - val badRule = new SparkStrategy { - var mode: String = "" - override def apply(plan: LogicalPlan): Seq[SparkPlan] = - mode.toLowerCase(Locale.ROOT) match { - case "exception" => throw new AnalysisException(mode) - case "error" => throw new Error(mode) - case _ => Nil - } - } - spark.experimental.extraStrategies = badRule :: Nil + spark.experimental.extraStrategies = Seq( + new SparkStrategy { + override def apply(plan: LogicalPlan): Seq[SparkPlan] = Nil + }) def qe: QueryExecution = new QueryExecution(spark, OneRowRelation) // Nothing! - badRule.mode = "" assert(qe.toString.contains("OneRowRelation")) // Throw an AnalysisException - this should be captured. - badRule.mode = "exception" + spark.experimental.extraStrategies = Seq( + new SparkStrategy { + override def apply(plan: LogicalPlan): Seq[SparkPlan] = + throw new AnalysisException("exception") + }) assert(qe.toString.contains("org.apache.spark.sql.AnalysisException")) // Throw an Error - this should not be captured. - badRule.mode = "error" + spark.experimental.extraStrategies = Seq( + new SparkStrategy { + override def apply(plan: LogicalPlan): Seq[SparkPlan] = + throw new Error("error") + }) val error = intercept[Error](qe.toString) assert(error.getMessage.contains("error")) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SameResultSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SameResultSuite.scala new file mode 100644 index 0000000000000..25e4ca060ae02 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SameResultSuite.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.sql.{DataFrame, QueryTest} +import org.apache.spark.sql.test.SharedSQLContext + +/** + * Tests for the sameResult function for [[SparkPlan]]s. + */ +class SameResultSuite extends QueryTest with SharedSQLContext { + + test("FileSourceScanExec: different orders of data filters and partition filters") { + withTempPath { path => + val tmpDir = path.getCanonicalPath + spark.range(10) + .selectExpr("id as a", "id + 1 as b", "id + 2 as c", "id + 3 as d") + .write + .partitionBy("a", "b") + .parquet(tmpDir) + val df = spark.read.parquet(tmpDir) + // partition filters: a > 1 AND b < 9 + // data filters: c > 1 AND d < 9 + val plan1 = getFileSourceScanExec(df.where("a > 1 AND b < 9 AND c > 1 AND d < 9")) + val plan2 = getFileSourceScanExec(df.where("b < 9 AND a > 1 AND d < 9 AND c > 1")) + assert(plan1.sameResult(plan2)) + } + } + + private def getFileSourceScanExec(df: DataFrame): FileSourceScanExec = { + df.queryExecution.sparkPlan.find(_.isInstanceOf[FileSourceScanExec]).get + .asInstanceOf[FileSourceScanExec] + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala index 908b955abbf07..b32fb90e10072 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala @@ -19,9 +19,9 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.SaveMode import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} -import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAlias, UnresolvedAttribute, UnresolvedRelation, UnresolvedStar} import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, CatalogTable, CatalogTableType} -import org.apache.spark.sql.catalyst.expressions.{Ascending, SortOrder} +import org.apache.spark.sql.catalyst.expressions.{Ascending, Concat, SortOrder} import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, RepartitionByExpression, Sort} @@ -290,4 +290,15 @@ class SparkSqlParserSuite extends PlanTest { basePlan, numPartitions = newConf.numShufflePartitions))) } + + test("pipeline concatenation") { + val concat = Concat( + Concat(UnresolvedAttribute("a") :: UnresolvedAttribute("b") :: Nil) :: + UnresolvedAttribute("c") :: + Nil + ) + assertEqual( + "SELECT a || b || c FROM t", + Project(UnresolvedAlias(concat) :: Nil, UnresolvedRelation(TableIdentifier("t")))) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala index 85aa7dbe9ed86..89cfba6c559d6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala @@ -589,7 +589,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf ) } - test("flatMapGroupsWithState - streaming with event time timeout") { + test("flatMapGroupsWithState - streaming with event time timeout + watermark") { // Function to maintain the max event time // Returns the max event time in the state, or -1 if the state was removed by timeout val stateFunc = ( @@ -761,6 +761,44 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf assert(e.getMessage === "The output mode of function should be append or update") } + def testWithTimeout(timeoutConf: GroupStateTimeout): Unit = { + test("SPARK-20714: watermark does not fail query when timeout = " + timeoutConf) { + // Function to maintain running count up to 2, and then remove the count + // Returns the data and the count (-1 if count reached beyond 2 and state was just removed) + val stateFunc = + (key: String, values: Iterator[(String, Long)], state: GroupState[RunningCount]) => { + if (state.hasTimedOut) { + state.remove() + Iterator((key, "-1")) + } else { + val count = state.getOption.map(_.count).getOrElse(0L) + values.size + state.update(RunningCount(count)) + state.setTimeoutDuration("10 seconds") + Iterator((key, count.toString)) + } + } + + val clock = new StreamManualClock + val inputData = MemoryStream[(String, Long)] + val result = + inputData.toDF().toDF("key", "time") + .selectExpr("key", "cast(time as timestamp) as timestamp") + .withWatermark("timestamp", "10 second") + .as[(String, Long)] + .groupByKey(x => x._1) + .flatMapGroupsWithState(Update, ProcessingTimeTimeout)(stateFunc) + + testStream(result, Update)( + StartStream(ProcessingTime("1 second"), triggerClock = clock), + AddData(inputData, ("a", 1L)), + AdvanceManualClock(1 * 1000), + CheckLastBatch(("a", "1")) + ) + } + } + testWithTimeout(NoTimeout) + testWithTimeout(ProcessingTimeTimeout) + def testStateUpdateWithData( testName: String, stateUpdates: GroupState[Int] => Unit, diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index 387ec4f967233..74e15a5777916 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -22,7 +22,6 @@ import java.util.Locale import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer -import scala.language.reflectiveCalls import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala index e95f9ea480431..b8aa067cdb903 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala @@ -22,7 +22,6 @@ import java.lang.reflect.InvocationTargetException import java.net.{URL, URLClassLoader} import java.util -import scala.language.reflectiveCalls import scala.util.Try import org.apache.commons.io.{FileUtils, IOUtils} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala index 666548d1a490b..e191071efbf18 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala @@ -206,7 +206,7 @@ case class HiveTableScanExec( HiveTableScanExec( requestedAttributes.map(QueryPlan.normalizeExprId(_, input)), relation.canonicalized.asInstanceOf[CatalogRelation], - partitionPruningPred.map(QueryPlan.normalizeExprId(_, input)))(sparkSession) + QueryPlan.normalizePredicates(partitionPruningPred, input))(sparkSession) } override def otherCopyArgs: Seq[AnyRef] = Seq(sparkSession) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index 3682dc850790e..3facf9f67be9f 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.hive.execution -import java.io.IOException +import java.io.{File, IOException} import java.net.URI import java.text.SimpleDateFormat import java.util.{Date, Locale, Random} @@ -97,12 +97,24 @@ case class InsertIntoHiveTable( val inputPathUri: URI = inputPath.toUri val inputPathName: String = inputPathUri.getPath val fs: FileSystem = inputPath.getFileSystem(hadoopConf) - val stagingPathName: String = + var stagingPathName: String = if (inputPathName.indexOf(stagingDir) == -1) { new Path(inputPathName, stagingDir).toString } else { inputPathName.substring(0, inputPathName.indexOf(stagingDir) + stagingDir.length) } + + // SPARK-20594: This is a walk-around fix to resolve a Hive bug. Hive requires that the + // staging directory needs to avoid being deleted when users set hive.exec.stagingdir + // under the table directory. + if (FileUtils.isSubDir(new Path(stagingPathName), inputPath, fs) && + !stagingPathName.stripPrefix(inputPathName).stripPrefix(File.separator).startsWith(".")) { + logDebug(s"The staging dir '$stagingPathName' should be a child directory starts " + + "with '.' to avoid being deleted if we set hive.exec.stagingdir under the table " + + "directory.") + stagingPathName = new Path(inputPathName, ".hive-staging").toString + } + val dir: Path = fs.makeQualified( new Path(stagingPathName + "_" + executionId + "-" + TaskRunner.getTaskRunnerID)) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala index d6999af84eac0..2c724f8388693 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala @@ -494,4 +494,15 @@ class InsertIntoHiveTableSuite extends QueryTest with TestHiveSingleton with Bef spark.table("t").write.insertInto(tableName) } } + + test("SPARK-20594: hive.exec.stagingdir was deleted by Hive") { + // Set hive.exec.stagingdir under the table directory without start with ".". + withSQLConf("hive.exec.stagingdir" -> "./test") { + withTable("test_table") { + sql("CREATE TABLE test_table (key int)") + sql("INSERT OVERWRITE TABLE test_table SELECT 1") + checkAnswer(sql("SELECT * FROM test_table"), Row(1)) + } + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala index 90e037e292790..ae64cb3210b53 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala @@ -164,16 +164,30 @@ class HiveTableScanSuite extends HiveComparisonTest with SQLTestUtils with TestH |PARTITION (p1='a',p2='c',p3='c',p4='d',p5='e') |SELECT v.id """.stripMargin) - val plan = sql( - s""" - |SELECT * FROM $table - """.stripMargin).queryExecution.sparkPlan - val scan = plan.collectFirst { - case p: HiveTableScanExec => p - }.get + val scan = getHiveTableScanExec(s"SELECT * FROM $table") val numDataCols = scan.relation.dataCols.length scan.rawPartitions.foreach(p => assert(p.getCols.size == numDataCols)) } } } + + test("HiveTableScanExec canonicalization for different orders of partition filters") { + val table = "hive_tbl_part" + withTable(table) { + sql( + s""" + |CREATE TABLE $table (id int) + |PARTITIONED BY (a int, b int) + """.stripMargin) + val scan1 = getHiveTableScanExec(s"SELECT * FROM $table WHERE a = 1 AND b = 2") + val scan2 = getHiveTableScanExec(s"SELECT * FROM $table WHERE b = 2 AND a = 1") + assert(scan1.sameResult(scan2)) + } + } + + private def getHiveTableScanExec(query: String): HiveTableScanExec = { + sql(query).queryExecution.sparkPlan.collectFirst { + case p: HiveTableScanExec => p + }.get + } } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/receiver/BlockGeneratorSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/receiver/BlockGeneratorSuite.scala index b70383ecde4d8..4f41b9d0a0b3c 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/receiver/BlockGeneratorSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/receiver/BlockGeneratorSuite.scala @@ -21,7 +21,6 @@ import java.util.concurrent.ConcurrentLinkedQueue import scala.collection.JavaConverters._ import scala.collection.mutable -import scala.language.reflectiveCalls import org.scalatest.BeforeAndAfter import org.scalatest.Matchers._ @@ -202,21 +201,17 @@ class BlockGeneratorSuite extends SparkFunSuite with BeforeAndAfter { test("block push errors are reported") { val listener = new TestBlockGeneratorListener { - @volatile var errorReported = false override def onPushBlock( blockId: StreamBlockId, arrayBuffer: mutable.ArrayBuffer[_]): Unit = { throw new SparkException("test") } - override def onError(message: String, throwable: Throwable): Unit = { - errorReported = true - } } blockGenerator = new BlockGenerator(listener, 0, conf) blockGenerator.start() - assert(listener.errorReported === false) + assert(listener.onErrorCalled === false) blockGenerator.addData(1) eventually(timeout(1 second), interval(10 milliseconds)) { - assert(listener.errorReported === true) + assert(listener.onErrorCalled === true) } blockGenerator.stop() } @@ -243,12 +238,15 @@ class BlockGeneratorSuite extends SparkFunSuite with BeforeAndAfter { @volatile var onGenerateBlockCalled = false @volatile var onAddDataCalled = false @volatile var onPushBlockCalled = false + @volatile var onErrorCalled = false override def onPushBlock(blockId: StreamBlockId, arrayBuffer: mutable.ArrayBuffer[_]): Unit = { pushedData.addAll(arrayBuffer.asJava) onPushBlockCalled = true } - override def onError(message: String, throwable: Throwable): Unit = {} + override def onError(message: String, throwable: Throwable): Unit = { + onErrorCalled = true + } override def onGenerateBlock(blockId: StreamBlockId): Unit = { onGenerateBlockCalled = true }