diff --git a/.travis.yml b/.travis.yml index d94872db6437a..d7e9f8c0290e8 100644 --- a/.travis.yml +++ b/.travis.yml @@ -28,7 +28,6 @@ dist: trusty # 2. Choose language and target JDKs for parallel builds. language: java jdk: - - oraclejdk7 - oraclejdk8 # 3. Setup cache directory for SBT and Maven. diff --git a/R/WINDOWS.md b/R/WINDOWS.md index 1afcbfcabe85f..cb2eebb9ffe6e 100644 --- a/R/WINDOWS.md +++ b/R/WINDOWS.md @@ -38,6 +38,6 @@ To run the SparkR unit tests on Windows, the following steps are required —ass ``` R -e "install.packages('testthat', repos='http://cran.us.r-project.org')" - .\bin\spark-submit2.cmd --conf spark.hadoop.fs.default.name="file:///" R\pkg\tests\run-all.R + .\bin\spark-submit2.cmd --conf spark.hadoop.fs.defaultFS="file:///" R\pkg\tests\run-all.R ``` diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index cc4cfa3423ced..e33d0d8e29d49 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -280,7 +280,7 @@ setMethod("dtypes", #' Column Names of SparkDataFrame #' -#' Return all column names as a list. +#' Return a vector of column names. #' #' @param x a SparkDataFrame. #' @@ -338,7 +338,7 @@ setMethod("colnames", }) #' @param value a character vector. Must have the same length as the number -#' of columns in the SparkDataFrame. +#' of columns to be renamed. #' @rdname columns #' @aliases colnames<-,SparkDataFrame-method #' @name colnames<- diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R index e771a057e2444..8354f705f6dea 100644 --- a/R/pkg/R/SQLContext.R +++ b/R/pkg/R/SQLContext.R @@ -332,8 +332,10 @@ setMethod("toDF", signature(x = "RDD"), #' Create a SparkDataFrame from a JSON file. #' -#' Loads a JSON file (\href{http://jsonlines.org/}{JSON Lines text format or newline-delimited JSON} -#' ), returning the result as a SparkDataFrame +#' Loads a JSON file, returning the result as a SparkDataFrame +#' By default, (\href{http://jsonlines.org/}{JSON Lines text format or newline-delimited JSON} +#' ) is supported. For JSON (one record per file), set a named property \code{wholeFile} to +#' \code{TRUE}. #' It goes through the entire dataset once to determine the schema. #' #' @param path Path of file to read. A vector of multiple paths is allowed. 
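For readers following along from the Scala side, here is a minimal sketch of the behavior the updated `read.json` documentation above describes. It assumes the `wholeFile` named property introduced for SparkR maps onto a `DataFrameReader` option of the same name; the file paths are placeholders.

```scala
// Minimal sketch, assuming the "wholeFile" option name from this patch is
// recognized by the JSON data source on the Scala side as well.
import org.apache.spark.sql.SparkSession

object WholeFileJsonExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("wholeFile-json-sketch").getOrCreate()

    // Default behavior: JSON Lines, i.e. one JSON record per line.
    val jsonLines = spark.read.json("path/to/file.json")

    // One JSON record spanning the whole file, mirroring read.json(path, wholeFile = TRUE) in R.
    val wholeFile = spark.read.option("wholeFile", "true").json("path/to/file.json")

    jsonLines.printSchema()
    wholeFile.printSchema()
    spark.stop()
  }
}
```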
@@ -346,6 +348,7 @@ setMethod("toDF", signature(x = "RDD"), #' sparkR.session() #' path <- "path/to/file.json" #' df <- read.json(path) +#' df <- read.json(path, wholeFile = TRUE) #' df <- jsonFile(path) #' } #' @name read.json @@ -778,6 +781,7 @@ dropTempView <- function(viewName) { #' @return SparkDataFrame #' @rdname read.df #' @name read.df +#' @seealso \link{read.json} #' @export #' @examples #'\dontrun{ @@ -785,7 +789,7 @@ dropTempView <- function(viewName) { #' df1 <- read.df("path/to/file.json", source = "json") #' schema <- structType(structField("name", "string"), #' structField("info", "map")) -#' df2 <- read.df(mapTypeJsonPath, "json", schema) +#' df2 <- read.df(mapTypeJsonPath, "json", schema, wholeFile = TRUE) #' df3 <- loadDF("data/test_table", "parquet", mergeSchema = "true") #' } #' @name read.df diff --git a/R/pkg/R/mllib_classification.R b/R/pkg/R/mllib_classification.R index 05bb95266173a..4db9cc30fb0c1 100644 --- a/R/pkg/R/mllib_classification.R +++ b/R/pkg/R/mllib_classification.R @@ -75,9 +75,9 @@ setClass("NaiveBayesModel", representation(jobj = "jobj")) #' @examples #' \dontrun{ #' sparkR.session() -#' df <- createDataFrame(iris) -#' training <- df[df$Species %in% c("versicolor", "virginica"), ] -#' model <- spark.svmLinear(training, Species ~ ., regParam = 0.5) +#' t <- as.data.frame(Titanic) +#' training <- createDataFrame(t) +#' model <- spark.svmLinear(training, Survived ~ ., regParam = 0.5) #' summary <- summary(model) #' #' # fitted values on training data @@ -220,9 +220,9 @@ function(object, path, overwrite = FALSE) { #' \dontrun{ #' sparkR.session() #' # binary logistic regression -#' df <- createDataFrame(iris) -#' training <- df[df$Species %in% c("versicolor", "virginica"), ] -#' model <- spark.logit(training, Species ~ ., regParam = 0.5) +#' t <- as.data.frame(Titanic) +#' training <- createDataFrame(t) +#' model <- spark.logit(training, Survived ~ ., regParam = 0.5) #' summary <- summary(model) #' #' # fitted values on training data @@ -239,8 +239,7 @@ function(object, path, overwrite = FALSE) { #' #' # multinomial logistic regression #' -#' df <- createDataFrame(iris) -#' model <- spark.logit(df, Species ~ ., regParam = 0.5) +#' model <- spark.logit(training, Class ~ ., regParam = 0.5) #' summary <- summary(model) #' #' } diff --git a/R/pkg/R/mllib_clustering.R b/R/pkg/R/mllib_clustering.R index 8823f90775960..0ebdb5a273088 100644 --- a/R/pkg/R/mllib_clustering.R +++ b/R/pkg/R/mllib_clustering.R @@ -72,8 +72,9 @@ setClass("LDAModel", representation(jobj = "jobj")) #' @examples #' \dontrun{ #' sparkR.session() -#' df <- createDataFrame(iris) -#' model <- spark.bisectingKmeans(df, Sepal_Length ~ Sepal_Width, k = 4) +#' t <- as.data.frame(Titanic) +#' df <- createDataFrame(t) +#' model <- spark.bisectingKmeans(df, Class ~ Survived, k = 4) #' summary(model) #' #' # get fitted result from a bisecting k-means model @@ -82,7 +83,7 @@ setClass("LDAModel", representation(jobj = "jobj")) #' #' # fitted values on training data #' fitted <- predict(model, df) -#' head(select(fitted, "Sepal_Length", "prediction")) +#' head(select(fitted, "Class", "prediction")) #' #' # save fitted model to input path #' path <- "path/to/model" @@ -338,14 +339,14 @@ setMethod("write.ml", signature(object = "GaussianMixtureModel", path = "charact #' @examples #' \dontrun{ #' sparkR.session() -#' data(iris) -#' df <- createDataFrame(iris) -#' model <- spark.kmeans(df, Sepal_Length ~ Sepal_Width, k = 4, initMode = "random") +#' t <- as.data.frame(Titanic) +#' df <- createDataFrame(t) 
+#' model <- spark.kmeans(df, Class ~ Survived, k = 4, initMode = "random") #' summary(model) #' #' # fitted values on training data #' fitted <- predict(model, df) -#' head(select(fitted, "Sepal_Length", "prediction")) +#' head(select(fitted, "Class", "prediction")) #' #' # save fitted model to input path #' path <- "path/to/model" diff --git a/R/pkg/R/mllib_regression.R b/R/pkg/R/mllib_regression.R index ac0578c4ab259..648d363f1a255 100644 --- a/R/pkg/R/mllib_regression.R +++ b/R/pkg/R/mllib_regression.R @@ -68,14 +68,14 @@ setClass("IsotonicRegressionModel", representation(jobj = "jobj")) #' @examples #' \dontrun{ #' sparkR.session() -#' data(iris) -#' df <- createDataFrame(iris) -#' model <- spark.glm(df, Sepal_Length ~ Sepal_Width, family = "gaussian") +#' t <- as.data.frame(Titanic) +#' df <- createDataFrame(t) +#' model <- spark.glm(df, Freq ~ Sex + Age, family = "gaussian") #' summary(model) #' #' # fitted values on training data #' fitted <- predict(model, df) -#' head(select(fitted, "Sepal_Length", "prediction")) +#' head(select(fitted, "Freq", "prediction")) #' #' # save fitted model to input path #' path <- "path/to/model" @@ -137,9 +137,9 @@ setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"), #' @examples #' \dontrun{ #' sparkR.session() -#' data(iris) -#' df <- createDataFrame(iris) -#' model <- glm(Sepal_Length ~ Sepal_Width, df, family = "gaussian") +#' t <- as.data.frame(Titanic) +#' df <- createDataFrame(t) +#' model <- glm(Freq ~ Sex + Age, df, family = "gaussian") #' summary(model) #' } #' @note glm since 1.5.0 diff --git a/R/pkg/R/mllib_tree.R b/R/pkg/R/mllib_tree.R index 0d53fad061809..40a806c41bad0 100644 --- a/R/pkg/R/mllib_tree.R +++ b/R/pkg/R/mllib_tree.R @@ -143,14 +143,15 @@ print.summary.treeEnsemble <- function(x) { #' #' # fit a Gradient Boosted Tree Classification Model #' # label must be binary - Only binary classification is supported for GBT. 
-#' df <- createDataFrame(iris[iris$Species != "virginica", ]) -#' model <- spark.gbt(df, Species ~ Petal_Length + Petal_Width, "classification") +#' t <- as.data.frame(Titanic) +#' df <- createDataFrame(t) +#' model <- spark.gbt(df, Survived ~ Age + Freq, "classification") #' #' # numeric label is also supported -#' iris2 <- iris[iris$Species != "virginica", ] -#' iris2$NumericSpecies <- ifelse(iris2$Species == "setosa", 0, 1) -#' df <- createDataFrame(iris2) -#' model <- spark.gbt(df, NumericSpecies ~ ., type = "classification") +#' t2 <- as.data.frame(Titanic) +#' t2$NumericGender <- ifelse(t2$Sex == "Male", 0, 1) +#' df <- createDataFrame(t2) +#' model <- spark.gbt(df, NumericGender ~ ., type = "classification") #' } #' @note spark.gbt since 2.1.0 setMethod("spark.gbt", signature(data = "SparkDataFrame", formula = "formula"), @@ -351,8 +352,9 @@ setMethod("write.ml", signature(object = "GBTClassificationModel", path = "chara #' summary(savedModel) #' #' # fit a Random Forest Classification Model -#' df <- createDataFrame(iris) -#' model <- spark.randomForest(df, Species ~ Petal_Length + Petal_Width, "classification") +#' t <- as.data.frame(Titanic) +#' df <- createDataFrame(t) +#' model <- spark.randomForest(df, Survived ~ Freq + Age, "classification") #' } #' @note spark.randomForest since 2.1.0 setMethod("spark.randomForest", signature(data = "SparkDataFrame", formula = "formula"), diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index ce0f5a198a259..1dd8c5ce6cb32 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -898,6 +898,12 @@ test_that("names() colnames() set the column names", { expect_equal(names(z)[3], "c") names(z)[3] <- "c2" expect_equal(names(z)[3], "c2") + + # Test subset assignment + colnames(df)[1] <- "col5" + expect_equal(colnames(df)[1], "col5") + names(df)[2] <- "col6" + expect_equal(names(df)[2], "col6") }) test_that("head() and first() return the correct data", { diff --git a/R/pkg/vignettes/sparkr-vignettes.Rmd b/R/pkg/vignettes/sparkr-vignettes.Rmd index bc8bc3c26c116..43c255cff3028 100644 --- a/R/pkg/vignettes/sparkr-vignettes.Rmd +++ b/R/pkg/vignettes/sparkr-vignettes.Rmd @@ -565,11 +565,10 @@ We use a simple example to demonstrate `spark.logit` usage. In general, there ar and 3). Obtain the coefficient matrix of the fitted model using `summary` and use the model for prediction with `predict`. Binomial logistic regression -```{r, warning=FALSE} -df <- createDataFrame(iris) -# Create a DataFrame containing two classes -training <- df[df$Species %in% c("versicolor", "virginica"), ] -model <- spark.logit(training, Species ~ ., regParam = 0.00042) +```{r} +t <- as.data.frame(Titanic) +training <- createDataFrame(t) +model <- spark.logit(training, Survived ~ ., regParam = 0.04741301) summary(model) ``` @@ -579,10 +578,11 @@ fitted <- predict(model, training) ``` Multinomial logistic regression against three classes -```{r, warning=FALSE} -df <- createDataFrame(iris) +```{r} +t <- as.data.frame(Titanic) +training <- createDataFrame(t) # Note in this case, Spark infers it is multinomial logistic regression, so family = "multinomial" is optional. -model <- spark.logit(df, Species ~ ., regParam = 0.056) +model <- spark.logit(training, Class ~ ., regParam = 0.07815179) summary(model) ``` @@ -609,11 +609,12 @@ MLPC employs backpropagation for learning the model. 
We use the logistic loss fu `spark.mlp` requires at least two columns in `data`: one named `"label"` and the other one `"features"`. The `"features"` column should be in libSVM-format. -We use iris data set to show how to use `spark.mlp` in classification. -```{r, warning=FALSE} -df <- createDataFrame(iris) +We use Titanic data set to show how to use `spark.mlp` in classification. +```{r} +t <- as.data.frame(Titanic) +training <- createDataFrame(t) # fit a Multilayer Perceptron Classification Model -model <- spark.mlp(df, Species ~ ., blockSize = 128, layers = c(4, 3), solver = "l-bfgs", maxIter = 100, tol = 0.5, stepSize = 1, seed = 1, initialWeights = c(0, 0, 0, 0, 0, 5, 5, 5, 5, 5, 9, 9, 9, 9, 9)) +model <- spark.mlp(training, Survived ~ Age + Sex, blockSize = 128, layers = c(2, 3), solver = "l-bfgs", maxIter = 100, tol = 0.5, stepSize = 1, seed = 1, initialWeights = c( 0, 0, 0, 5, 5, 5, 9, 9, 9)) ``` To avoid lengthy display, we only present partial results of the model summary. You can check the full result from your sparkR shell. @@ -630,7 +631,7 @@ options(ops) ``` ```{r} # make predictions use the fitted model -predictions <- predict(model, df) +predictions <- predict(model, training) head(select(predictions, predictions$prediction)) ``` @@ -769,12 +770,13 @@ predictions <- predict(rfModel, df) `spark.bisectingKmeans` is a kind of [hierarchical clustering](https://en.wikipedia.org/wiki/Hierarchical_clustering) using a divisive (or "top-down") approach: all observations start in one cluster, and splits are performed recursively as one moves down the hierarchy. -```{r, warning=FALSE} -df <- createDataFrame(iris) -model <- spark.bisectingKmeans(df, Sepal_Length ~ Sepal_Width, k = 4) +```{r} +t <- as.data.frame(Titanic) +training <- createDataFrame(t) +model <- spark.bisectingKmeans(training, Class ~ Survived, k = 4) summary(model) -fitted <- predict(model, df) -head(select(fitted, "Sepal_Length", "prediction")) +fitted <- predict(model, training) +head(select(fitted, "Class", "prediction")) ``` #### Gaussian Mixture Model @@ -912,9 +914,10 @@ testSummary ### Model Persistence The following example shows how to save/load an ML model by SparkR. 
-```{r, warning=FALSE} -irisDF <- createDataFrame(iris) -gaussianGLM <- spark.glm(irisDF, Sepal_Length ~ Sepal_Width + Species, family = "gaussian") +```{r} +t <- as.data.frame(Titanic) +training <- createDataFrame(t) +gaussianGLM <- spark.glm(training, Freq ~ Sex + Age, family = "gaussian") # Save and then load a fitted MLlib model modelPath <- tempfile(pattern = "ml", fileext = ".tmp") @@ -925,7 +928,7 @@ gaussianGLM2 <- read.ml(modelPath) summary(gaussianGLM2) # Check model prediction -gaussianPredictions <- predict(gaussianGLM2, irisDF) +gaussianPredictions <- predict(gaussianGLM2, training) head(gaussianPredictions) unlink(modelPath) diff --git a/R/run-tests.sh b/R/run-tests.sh index 5e4dafaf76f3d..742a2c5ed76da 100755 --- a/R/run-tests.sh +++ b/R/run-tests.sh @@ -23,7 +23,7 @@ FAILED=0 LOGFILE=$FWDIR/unit-tests.out rm -f $LOGFILE -SPARK_TESTING=1 $FWDIR/../bin/spark-submit --driver-java-options "-Dlog4j.configuration=file:$FWDIR/log4j.properties" --conf spark.hadoop.fs.default.name="file:///" $FWDIR/pkg/tests/run-all.R 2>&1 | tee -a $LOGFILE +SPARK_TESTING=1 $FWDIR/../bin/spark-submit --driver-java-options "-Dlog4j.configuration=file:$FWDIR/log4j.properties" --conf spark.hadoop.fs.defaultFS="file:///" $FWDIR/pkg/tests/run-all.R 2>&1 | tee -a $LOGFILE FAILED=$((PIPESTATUS[0]||$FAILED)) NUM_TEST_WARNING="$(grep -c -e 'Warnings ----------------' $LOGFILE)" diff --git a/appveyor.yml b/appveyor.yml index 6bc66c0ea54dc..5adf1b4bedb44 100644 --- a/appveyor.yml +++ b/appveyor.yml @@ -46,7 +46,7 @@ build_script: - cmd: mvn -DskipTests -Psparkr -Phive -Phive-thriftserver package test_script: - - cmd: .\bin\spark-submit2.cmd --conf spark.hadoop.fs.default.name="file:///" R\pkg\tests\run-all.R + - cmd: .\bin\spark-submit2.cmd --conf spark.hadoop.fs.defaultFS="file:///" R\pkg\tests\run-all.R notifications: - provider: Email diff --git a/common/network-common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java b/common/network-common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java index 4477c9a935f21..09fc80d12d510 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java @@ -26,7 +26,6 @@ import static org.junit.Assert.assertEquals; import static org.mockito.Mockito.*; -import org.apache.spark.network.buffer.ManagedBuffer; import org.apache.spark.network.buffer.NioManagedBuffer; import org.apache.spark.network.client.ChunkReceivedCallback; import org.apache.spark.network.client.RpcResponseCallback; diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index 87b9e8eb445aa..10a7cb1d06659 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -153,7 +153,8 @@ public void writeTo(ByteBuffer buffer) { * * Unlike getBytes this will not create a copy the array if this is a slice. 
*/ - public @Nonnull ByteBuffer getByteBuffer() { + @Nonnull + public ByteBuffer getByteBuffer() { if (base instanceof byte[] && offset >= BYTE_ARRAY_OFFSET) { final byte[] bytes = (byte[]) base; diff --git a/core/src/main/java/org/apache/spark/SparkFirehoseListener.java b/core/src/main/java/org/apache/spark/SparkFirehoseListener.java index 9fe97b4d9c20c..140c52fd12f94 100644 --- a/core/src/main/java/org/apache/spark/SparkFirehoseListener.java +++ b/core/src/main/java/org/apache/spark/SparkFirehoseListener.java @@ -30,116 +30,117 @@ */ public class SparkFirehoseListener implements SparkListenerInterface { - public void onEvent(SparkListenerEvent event) { } - - @Override - public final void onStageCompleted(SparkListenerStageCompleted stageCompleted) { - onEvent(stageCompleted); - } - - @Override - public final void onStageSubmitted(SparkListenerStageSubmitted stageSubmitted) { - onEvent(stageSubmitted); - } - - @Override - public final void onTaskStart(SparkListenerTaskStart taskStart) { - onEvent(taskStart); - } - - @Override - public final void onTaskGettingResult(SparkListenerTaskGettingResult taskGettingResult) { - onEvent(taskGettingResult); - } - - @Override - public final void onTaskEnd(SparkListenerTaskEnd taskEnd) { - onEvent(taskEnd); - } - - @Override - public final void onJobStart(SparkListenerJobStart jobStart) { - onEvent(jobStart); - } - - @Override - public final void onJobEnd(SparkListenerJobEnd jobEnd) { - onEvent(jobEnd); - } - - @Override - public final void onEnvironmentUpdate(SparkListenerEnvironmentUpdate environmentUpdate) { - onEvent(environmentUpdate); - } - - @Override - public final void onBlockManagerAdded(SparkListenerBlockManagerAdded blockManagerAdded) { - onEvent(blockManagerAdded); - } - - @Override - public final void onBlockManagerRemoved(SparkListenerBlockManagerRemoved blockManagerRemoved) { - onEvent(blockManagerRemoved); - } - - @Override - public final void onUnpersistRDD(SparkListenerUnpersistRDD unpersistRDD) { - onEvent(unpersistRDD); - } - - @Override - public final void onApplicationStart(SparkListenerApplicationStart applicationStart) { - onEvent(applicationStart); - } - - @Override - public final void onApplicationEnd(SparkListenerApplicationEnd applicationEnd) { - onEvent(applicationEnd); - } - - @Override - public final void onExecutorMetricsUpdate( - SparkListenerExecutorMetricsUpdate executorMetricsUpdate) { - onEvent(executorMetricsUpdate); - } - - @Override - public final void onExecutorAdded(SparkListenerExecutorAdded executorAdded) { - onEvent(executorAdded); - } - - @Override - public final void onExecutorRemoved(SparkListenerExecutorRemoved executorRemoved) { - onEvent(executorRemoved); - } - - @Override - public final void onExecutorBlacklisted(SparkListenerExecutorBlacklisted executorBlacklisted) { - onEvent(executorBlacklisted); - } - - @Override - public final void onExecutorUnblacklisted(SparkListenerExecutorUnblacklisted executorUnblacklisted) { - onEvent(executorUnblacklisted); - } - - @Override - public final void onNodeBlacklisted(SparkListenerNodeBlacklisted nodeBlacklisted) { - onEvent(nodeBlacklisted); - } - - @Override - public final void onNodeUnblacklisted(SparkListenerNodeUnblacklisted nodeUnblacklisted) { - onEvent(nodeUnblacklisted); - } - - @Override - public void onBlockUpdated(SparkListenerBlockUpdated blockUpdated) { - onEvent(blockUpdated); - } - - @Override - public void onOtherEvent(SparkListenerEvent event) { - onEvent(event); - } + public void onEvent(SparkListenerEvent event) { } + + @Override + public 
final void onStageCompleted(SparkListenerStageCompleted stageCompleted) { + onEvent(stageCompleted); + } + + @Override + public final void onStageSubmitted(SparkListenerStageSubmitted stageSubmitted) { + onEvent(stageSubmitted); + } + + @Override + public final void onTaskStart(SparkListenerTaskStart taskStart) { + onEvent(taskStart); + } + + @Override + public final void onTaskGettingResult(SparkListenerTaskGettingResult taskGettingResult) { + onEvent(taskGettingResult); + } + + @Override + public final void onTaskEnd(SparkListenerTaskEnd taskEnd) { + onEvent(taskEnd); + } + + @Override + public final void onJobStart(SparkListenerJobStart jobStart) { + onEvent(jobStart); + } + + @Override + public final void onJobEnd(SparkListenerJobEnd jobEnd) { + onEvent(jobEnd); + } + + @Override + public final void onEnvironmentUpdate(SparkListenerEnvironmentUpdate environmentUpdate) { + onEvent(environmentUpdate); + } + + @Override + public final void onBlockManagerAdded(SparkListenerBlockManagerAdded blockManagerAdded) { + onEvent(blockManagerAdded); + } + + @Override + public final void onBlockManagerRemoved(SparkListenerBlockManagerRemoved blockManagerRemoved) { + onEvent(blockManagerRemoved); + } + + @Override + public final void onUnpersistRDD(SparkListenerUnpersistRDD unpersistRDD) { + onEvent(unpersistRDD); + } + + @Override + public final void onApplicationStart(SparkListenerApplicationStart applicationStart) { + onEvent(applicationStart); + } + + @Override + public final void onApplicationEnd(SparkListenerApplicationEnd applicationEnd) { + onEvent(applicationEnd); + } + + @Override + public final void onExecutorMetricsUpdate( + SparkListenerExecutorMetricsUpdate executorMetricsUpdate) { + onEvent(executorMetricsUpdate); + } + + @Override + public final void onExecutorAdded(SparkListenerExecutorAdded executorAdded) { + onEvent(executorAdded); + } + + @Override + public final void onExecutorRemoved(SparkListenerExecutorRemoved executorRemoved) { + onEvent(executorRemoved); + } + + @Override + public final void onExecutorBlacklisted(SparkListenerExecutorBlacklisted executorBlacklisted) { + onEvent(executorBlacklisted); + } + + @Override + public final void onExecutorUnblacklisted( + SparkListenerExecutorUnblacklisted executorUnblacklisted) { + onEvent(executorUnblacklisted); + } + + @Override + public final void onNodeBlacklisted(SparkListenerNodeBlacklisted nodeBlacklisted) { + onEvent(nodeBlacklisted); + } + + @Override + public final void onNodeUnblacklisted(SparkListenerNodeUnblacklisted nodeUnblacklisted) { + onEvent(nodeUnblacklisted); + } + + @Override + public void onBlockUpdated(SparkListenerBlockUpdated blockUpdated) { + onEvent(blockUpdated); + } + + @Override + public void onOtherEvent(SparkListenerEvent event) { + onEvent(event); + } } diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java index 29aca04a3d11b..f312fa2b2ddd7 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -161,7 +161,9 @@ private UnsafeExternalSorter( // Register a cleanup task with TaskContext to ensure that memory is guaranteed to be freed at // the end of the task. This is necessary to avoid memory leaks in when the downstream operator // does not fully consume the sorter's output (e.g. sort followed by limit). 
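The comment above is the key invariant behind this purely stylistic reformatting: resources must be released even when a downstream operator stops consuming the sorter's output early. A hedged, self-contained sketch of the same registration pattern follows, using a hypothetical `MyBuffers` resource holder rather than the sorter's internals.

```scala
// Sketch of the registration pattern described in the comment above; `MyBuffers`
// is a stand-in for whatever per-task resource needs guaranteed cleanup.
import org.apache.spark.TaskContext
import org.apache.spark.util.TaskCompletionListener

class MyBuffers {
  def free(): Unit = println("buffers freed")
}

object CleanupOnTaskCompletion {
  def register(buffers: MyBuffers): Unit = {
    val ctx = TaskContext.get() // null when not running inside a Spark task
    if (ctx != null) {
      ctx.addTaskCompletionListener(new TaskCompletionListener {
        // Runs at the end of the task even if downstream code abandoned the iterator early.
        override def onTaskCompletion(context: TaskContext): Unit = buffers.free()
      })
    }
  }
}
```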
- taskContext.addTaskCompletionListener(context -> { cleanupResources(); }); + taskContext.addTaskCompletionListener(context -> { + cleanupResources(); + }); } /** diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala index 0fd777ed12829..f0867ecb16ea3 100644 --- a/core/src/main/scala/org/apache/spark/TaskContext.scala +++ b/core/src/main/scala/org/apache/spark/TaskContext.scala @@ -24,6 +24,7 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.executor.TaskMetrics import org.apache.spark.memory.TaskMemoryManager import org.apache.spark.metrics.source.Source +import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.util.{AccumulatorV2, TaskCompletionListener, TaskFailureListener} @@ -190,4 +191,10 @@ abstract class TaskContext extends Serializable { */ private[spark] def registerAccumulator(a: AccumulatorV2[_, _]): Unit + /** + * Record that this task has failed due to a fetch failure from a remote host. This allows + * fetch-failure handling to get triggered by the driver, regardless of intervening user-code. + */ + private[spark] def setFetchFailed(fetchFailed: FetchFailedException): Unit + } diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala index c904e083911cd..dc0d12878550a 100644 --- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala +++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala @@ -26,6 +26,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.memory.TaskMemoryManager import org.apache.spark.metrics.MetricsSystem import org.apache.spark.metrics.source.Source +import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.util._ private[spark] class TaskContextImpl( @@ -56,6 +57,10 @@ private[spark] class TaskContextImpl( // Whether the task has failed. @volatile private var failed: Boolean = false + // If there was a fetch failure in the task, we store it here, to make sure user-code doesn't + // hide the exception. See SPARK-19276 + @volatile private var _fetchFailedException: Option[FetchFailedException] = None + override def addTaskCompletionListener(listener: TaskCompletionListener): this.type = { onCompleteCallbacks += listener this @@ -126,4 +131,10 @@ private[spark] class TaskContextImpl( taskMetrics.registerAccumulator(a) } + private[spark] override def setFetchFailed(fetchFailed: FetchFailedException): Unit = { + this._fetchFailedException = Option(fetchFailed) + } + + private[spark] def fetchFailed: Option[FetchFailedException] = _fetchFailedException + } diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala index 941e2d13fb28e..f475ce87540aa 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala @@ -82,17 +82,20 @@ class SparkHadoopUtil extends Logging { // the behavior of the old implementation of this code, for backwards compatibility. 
if (conf != null) { // Explicitly check for S3 environment variables - if (System.getenv("AWS_ACCESS_KEY_ID") != null && - System.getenv("AWS_SECRET_ACCESS_KEY") != null) { - val keyId = System.getenv("AWS_ACCESS_KEY_ID") - val accessKey = System.getenv("AWS_SECRET_ACCESS_KEY") - + val keyId = System.getenv("AWS_ACCESS_KEY_ID") + val accessKey = System.getenv("AWS_SECRET_ACCESS_KEY") + if (keyId != null && accessKey != null) { hadoopConf.set("fs.s3.awsAccessKeyId", keyId) hadoopConf.set("fs.s3n.awsAccessKeyId", keyId) hadoopConf.set("fs.s3a.access.key", keyId) hadoopConf.set("fs.s3.awsSecretAccessKey", accessKey) hadoopConf.set("fs.s3n.awsSecretAccessKey", accessKey) hadoopConf.set("fs.s3a.secret.key", accessKey) + + val sessionToken = System.getenv("AWS_SESSION_TOKEN") + if (sessionToken != null) { + hadoopConf.set("fs.s3a.session.token", sessionToken) + } } // Copy any "spark.hadoop.foo=bar" system properties into conf as "foo=bar" conf.getAll.foreach { case (key, value) => diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 5ffdedd1658ab..1e50eb6635651 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -665,7 +665,8 @@ object SparkSubmit extends CommandLineUtils { if (verbose) { printStream.println(s"Main class:\n$childMainClass") printStream.println(s"Arguments:\n${childArgs.mkString("\n")}") - printStream.println(s"System properties:\n${sysProps.mkString("\n")}") + // sysProps may contain sensitive information, so redact before printing + printStream.println(s"System properties:\n${Utils.redact(sysProps).mkString("\n")}") printStream.println(s"Classpath elements:\n${childClasspath.mkString("\n")}") printStream.println("\n") } diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index dee77343d806d..0614d80b60e1c 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -84,9 +84,15 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S // scalastyle:off println if (verbose) SparkSubmit.printStream.println(s"Using properties file: $propertiesFile") Option(propertiesFile).foreach { filename => - Utils.getPropertiesFromFile(filename).foreach { case (k, v) => + val properties = Utils.getPropertiesFromFile(filename) + properties.foreach { case (k, v) => defaultProperties(k) = v - if (verbose) SparkSubmit.printStream.println(s"Adding default property: $k=$v") + } + // Property files may contain sensitive information, so redact before printing + if (verbose) { + Utils.redact(properties).foreach { case (k, v) => + SparkSubmit.printStream.println(s"Adding default property: $k=$v") + } } } // scalastyle:on println @@ -318,7 +324,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S | |Spark properties used, including those specified through | --conf and those from the properties file $propertiesFile: - |${sparkProperties.mkString(" ", "\n ", "\n")} + |${Utils.redact(sparkProperties).mkString(" ", "\n ", "\n")} """.stripMargin } diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 975a6e4eeb33a..790c1ae942474 100644 --- 
a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -18,6 +18,7 @@ package org.apache.spark.executor import java.io.{File, NotSerializableException} +import java.lang.Thread.UncaughtExceptionHandler import java.lang.management.ManagementFactory import java.net.{URI, URL} import java.nio.ByteBuffer @@ -52,7 +53,8 @@ private[spark] class Executor( executorHostname: String, env: SparkEnv, userClassPath: Seq[URL] = Nil, - isLocal: Boolean = false) + isLocal: Boolean = false, + uncaughtExceptionHandler: UncaughtExceptionHandler = SparkUncaughtExceptionHandler) extends Logging { logInfo(s"Starting executor ID $executorId on host $executorHostname") @@ -78,7 +80,7 @@ private[spark] class Executor( // Setup an uncaught exception handler for non-local mode. // Make any thread terminations due to uncaught exceptions kill the entire // executor process to avoid surprising stalls. - Thread.setDefaultUncaughtExceptionHandler(SparkUncaughtExceptionHandler) + Thread.setDefaultUncaughtExceptionHandler(uncaughtExceptionHandler) } // Start worker thread pool @@ -342,6 +344,14 @@ private[spark] class Executor( } } } + task.context.fetchFailed.foreach { fetchFailure => + // uh-oh. it appears the user code has caught the fetch-failure without throwing any + // other exceptions. Its *possible* this is what the user meant to do (though highly + // unlikely). So we will log an error and keep going. + logError(s"TID ${taskId} completed successfully though internally it encountered " + + s"unrecoverable fetch failures! Most likely this means user code is incorrectly " + + s"swallowing Spark's internal ${classOf[FetchFailedException]}", fetchFailure) + } val taskFinish = System.currentTimeMillis() val taskFinishCpu = if (threadMXBean.isCurrentThreadCpuTimeSupported) { threadMXBean.getCurrentThreadCpuTime @@ -402,8 +412,17 @@ private[spark] class Executor( execBackend.statusUpdate(taskId, TaskState.FINISHED, serializedResult) } catch { - case ffe: FetchFailedException => - val reason = ffe.toTaskFailedReason + case t: Throwable if hasFetchFailure && !Utils.isFatalError(t) => + val reason = task.context.fetchFailed.get.toTaskFailedReason + if (!t.isInstanceOf[FetchFailedException]) { + // there was a fetch failure in the task, but some user code wrapped that exception + // and threw something else. Regardless, we treat it as a fetch failure. + val fetchFailedCls = classOf[FetchFailedException].getName + logWarning(s"TID ${taskId} encountered a ${fetchFailedCls} and " + + s"failed, but the ${fetchFailedCls} was hidden by another " + + s"exception. Spark is handling this like a fetch failure and ignoring the " + + s"other exception: $t") + } setTaskFinishedAndClearInterruptStatus() execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason)) @@ -455,13 +474,17 @@ private[spark] class Executor( // Don't forcibly exit unless the exception was inherently fatal, to avoid // stopping other tasks unnecessarily. 
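The new `uncaughtExceptionHandler` constructor parameter above exists mainly so tests can observe fatal-error handling without terminating the JVM. Below is a sketch of the kind of recording handler a test might inject; the class name is hypothetical, and since the `Executor` constructor is `private[spark]`, this is illustrative only.

```scala
// Hypothetical recording handler for tests: it captures fatal errors instead of
// exiting the JVM, so a suite can assert on what the Executor reported.
import java.lang.Thread.UncaughtExceptionHandler
import java.util.concurrent.ConcurrentLinkedQueue

class RecordingExceptionHandler extends UncaughtExceptionHandler {
  val errors = new ConcurrentLinkedQueue[Throwable]()

  override def uncaughtException(thread: Thread, throwable: Throwable): Unit = {
    errors.add(throwable)
  }
}
```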
if (Utils.isFatalError(t)) { - SparkUncaughtExceptionHandler.uncaughtException(t) + uncaughtExceptionHandler.uncaughtException(Thread.currentThread(), t) } } finally { runningTasks.remove(taskId) } } + + private def hasFetchFailure: Boolean = { + task != null && task.context != null && task.context.fetchFailed.isDefined + } } /** diff --git a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala index 2c1b5636888a8..22e26799138ba 100644 --- a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala +++ b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala @@ -113,11 +113,11 @@ class HadoopMapReduceCommitProtocol(jobId: String, path: String) val taskAttemptId = new TaskAttemptID(taskId, 0) // Set up the configuration object - jobContext.getConfiguration.set("mapred.job.id", jobId.toString) - jobContext.getConfiguration.set("mapred.tip.id", taskAttemptId.getTaskID.toString) - jobContext.getConfiguration.set("mapred.task.id", taskAttemptId.toString) - jobContext.getConfiguration.setBoolean("mapred.task.is.map", true) - jobContext.getConfiguration.setInt("mapred.task.partition", 0) + jobContext.getConfiguration.set("mapreduce.job.id", jobId.toString) + jobContext.getConfiguration.set("mapreduce.task.id", taskAttemptId.getTaskID.toString) + jobContext.getConfiguration.set("mapreduce.task.attempt.id", taskAttemptId.toString) + jobContext.getConfiguration.setBoolean("mapreduce.task.ismap", true) + jobContext.getConfiguration.setInt("mapreduce.task.partition", 0) val taskAttemptContext = new TaskAttemptContextImpl(jobContext.getConfiguration, taskAttemptId) committer = setupCommitter(taskAttemptContext) diff --git a/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopMapReduceWriter.scala b/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopMapReduceWriter.scala index 1e0a1e605cfbb..659ad5d0bad8c 100644 --- a/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopMapReduceWriter.scala +++ b/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopMapReduceWriter.scala @@ -79,7 +79,7 @@ object SparkHadoopMapReduceWriter extends Logging { val committer = FileCommitProtocol.instantiate( className = classOf[HadoopMapReduceCommitProtocol].getName, jobId = stageId.toString, - outputPath = conf.value.get("mapred.output.dir"), + outputPath = conf.value.get("mapreduce.output.fileoutputformat.outputdir"), isAppend = false).asInstanceOf[HadoopMapReduceCommitProtocol] committer.setupJob(jobContext) diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index 5fa6a7ed315f4..4bf8ecc383542 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -365,11 +365,11 @@ private[spark] object HadoopRDD extends Logging { val jobID = new JobID(jobTrackerId, jobId) val taId = new TaskAttemptID(new TaskID(jobID, TaskType.MAP, splitId), attemptId) - conf.set("mapred.tip.id", taId.getTaskID.toString) - conf.set("mapred.task.id", taId.toString) - conf.setBoolean("mapred.task.is.map", true) - conf.setInt("mapred.task.partition", splitId) - conf.set("mapred.job.id", jobID.toString) + conf.set("mapreduce.task.id", taId.getTaskID.toString) + conf.set("mapreduce.task.attempt.id", taId.toString) + conf.setBoolean("mapreduce.task.ismap", true) + 
conf.setInt("mapreduce.task.partition", splitId) + conf.set("mapreduce.job.id", jobID.toString) } /** diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index 567a3183e224c..52ce03ff8cde9 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -998,7 +998,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) job.setOutputValueClass(valueClass) job.setOutputFormatClass(outputFormatClass) val jobConfiguration = job.getConfiguration - jobConfiguration.set("mapred.output.dir", path) + jobConfiguration.set("mapreduce.output.fileoutputformat.outputdir", path) saveAsNewAPIHadoopDataset(jobConfiguration) } @@ -1039,10 +1039,11 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) conf.setOutputFormat(outputFormatClass) for (c <- codec) { hadoopConf.setCompressMapOutput(true) - hadoopConf.set("mapred.output.compress", "true") + hadoopConf.set("mapreduce.output.fileoutputformat.compress", "true") hadoopConf.setMapOutputCompressorClass(c) - hadoopConf.set("mapred.output.compression.codec", c.getCanonicalName) - hadoopConf.set("mapred.output.compression.type", CompressionType.BLOCK.toString) + hadoopConf.set("mapreduce.output.fileoutputformat.compress.codec", c.getCanonicalName) + hadoopConf.set("mapreduce.output.fileoutputformat.compress.type", + CompressionType.BLOCK.toString) } // Use configured output committer if already set diff --git a/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala b/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala index 08d220b40b6f3..83d87b548a430 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala @@ -48,25 +48,29 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean) private type StageId = Int private type PartitionId = Int private type TaskAttemptNumber = Int - private val NO_AUTHORIZED_COMMITTER: TaskAttemptNumber = -1 + private case class StageState(numPartitions: Int) { + val authorizedCommitters = Array.fill[TaskAttemptNumber](numPartitions)(NO_AUTHORIZED_COMMITTER) + val failures = mutable.Map[PartitionId, mutable.Set[TaskAttemptNumber]]() + } /** - * Map from active stages's id => partition id => task attempt with exclusive lock on committing - * output for that partition. + * Map from active stages's id => authorized task attempts for each partition id, which hold an + * exclusive lock on committing task output for that partition, as well as any known failed + * attempts in the stage. * * Entries are added to the top-level map when stages start and are removed they finish * (either successfully or unsuccessfully). * * Access to this map should be guarded by synchronizing on the OutputCommitCoordinator instance. */ - private val authorizedCommittersByStage = mutable.Map[StageId, Array[TaskAttemptNumber]]() + private val stageStates = mutable.Map[StageId, StageState]() /** * Returns whether the OutputCommitCoordinator's internal data structures are all empty. */ def isEmpty: Boolean = { - authorizedCommittersByStage.isEmpty + stageStates.isEmpty } /** @@ -105,19 +109,13 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean) * @param maxPartitionId the maximum partition id that could appear in this stage's tasks (i.e. 
* the maximum possible value of `context.partitionId`). */ - private[scheduler] def stageStart( - stage: StageId, - maxPartitionId: Int): Unit = { - val arr = new Array[TaskAttemptNumber](maxPartitionId + 1) - java.util.Arrays.fill(arr, NO_AUTHORIZED_COMMITTER) - synchronized { - authorizedCommittersByStage(stage) = arr - } + private[scheduler] def stageStart(stage: StageId, maxPartitionId: Int): Unit = synchronized { + stageStates(stage) = new StageState(maxPartitionId + 1) } // Called by DAGScheduler private[scheduler] def stageEnd(stage: StageId): Unit = synchronized { - authorizedCommittersByStage.remove(stage) + stageStates.remove(stage) } // Called by DAGScheduler @@ -126,7 +124,7 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean) partition: PartitionId, attemptNumber: TaskAttemptNumber, reason: TaskEndReason): Unit = synchronized { - val authorizedCommitters = authorizedCommittersByStage.getOrElse(stage, { + val stageState = stageStates.getOrElse(stage, { logDebug(s"Ignoring task completion for completed stage") return }) @@ -137,10 +135,12 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean) logInfo(s"Task was denied committing, stage: $stage, partition: $partition, " + s"attempt: $attemptNumber") case otherReason => - if (authorizedCommitters(partition) == attemptNumber) { + // Mark the attempt as failed to blacklist from future commit protocol + stageState.failures.getOrElseUpdate(partition, mutable.Set()) += attemptNumber + if (stageState.authorizedCommitters(partition) == attemptNumber) { logDebug(s"Authorized committer (attemptNumber=$attemptNumber, stage=$stage, " + s"partition=$partition) failed; clearing lock") - authorizedCommitters(partition) = NO_AUTHORIZED_COMMITTER + stageState.authorizedCommitters(partition) = NO_AUTHORIZED_COMMITTER } } } @@ -149,7 +149,7 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean) if (isDriver) { coordinatorRef.foreach(_ send StopCoordinator) coordinatorRef = None - authorizedCommittersByStage.clear() + stageStates.clear() } } @@ -158,13 +158,17 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean) stage: StageId, partition: PartitionId, attemptNumber: TaskAttemptNumber): Boolean = synchronized { - authorizedCommittersByStage.get(stage) match { - case Some(authorizedCommitters) => - authorizedCommitters(partition) match { + stageStates.get(stage) match { + case Some(state) if attemptFailed(state, partition, attemptNumber) => + logInfo(s"Denying attemptNumber=$attemptNumber to commit for stage=$stage," + + s" partition=$partition as task attempt $attemptNumber has already failed.") + false + case Some(state) => + state.authorizedCommitters(partition) match { case NO_AUTHORIZED_COMMITTER => logDebug(s"Authorizing attemptNumber=$attemptNumber to commit for stage=$stage, " + s"partition=$partition") - authorizedCommitters(partition) = attemptNumber + state.authorizedCommitters(partition) = attemptNumber true case existingCommitter => // Coordinator should be idempotent when receiving AskPermissionToCommit. 
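To make the new commit-authorization rules concrete without touching the coordinator's `private[scheduler]` API, here is a standalone sketch that mirrors the `StageState` logic above: one authorized committer per partition, idempotent re-asks from that committer, and a permanent denial for any attempt that has already failed.

```scala
// Standalone sketch of the authorization rules; it mirrors StageState above but is
// not the real OutputCommitCoordinator.
import scala.collection.mutable

object CommitAuthorizationSketch {
  private val NO_COMMITTER = -1

  class StageState(numPartitions: Int) {
    val authorized = Array.fill(numPartitions)(NO_COMMITTER)
    val failures = mutable.Map[Int, mutable.Set[Int]]()

    def attemptFailed(partition: Int, attempt: Int): Unit =
      failures.getOrElseUpdate(partition, mutable.Set()) += attempt

    def canCommit(partition: Int, attempt: Int): Boolean = {
      if (failures.get(partition).exists(_.contains(attempt))) {
        false                                   // a failed attempt is never authorized
      } else if (authorized(partition) == NO_COMMITTER) {
        authorized(partition) = attempt         // first asker takes the lock
        true
      } else {
        authorized(partition) == attempt        // idempotent for the current committer
      }
    }
  }

  def main(args: Array[String]): Unit = {
    val state = new StageState(numPartitions = 2)
    assert(state.canCommit(partition = 0, attempt = 0))   // authorized
    state.attemptFailed(partition = 0, attempt = 1)
    assert(!state.canCommit(partition = 0, attempt = 1))  // denied: attempt already failed
    assert(state.canCommit(partition = 0, attempt = 0))   // repeat ask is still allowed
  }
}
```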
@@ -181,11 +185,18 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean) } } case None => - logDebug(s"Stage $stage has completed, so not allowing attempt number $attemptNumber of" + - s"partition $partition to commit") + logDebug(s"Stage $stage has completed, so not allowing" + + s" attempt number $attemptNumber of partition $partition to commit") false } } + + private def attemptFailed( + stageState: StageState, + partition: PartitionId, + attempt: TaskAttemptNumber): Boolean = synchronized { + stageState.failures.get(partition).exists(_.contains(attempt)) + } } private[spark] object OutputCommitCoordinator { diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index 7b726d5659e91..70213722aae4f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -17,19 +17,14 @@ package org.apache.spark.scheduler -import java.io.{DataInputStream, DataOutputStream} import java.nio.ByteBuffer import java.util.Properties -import scala.collection.mutable -import scala.collection.mutable.HashMap - import org.apache.spark._ import org.apache.spark.executor.TaskMetrics import org.apache.spark.internal.config.APP_CALLER_CONTEXT import org.apache.spark.memory.{MemoryMode, TaskMemoryManager} import org.apache.spark.metrics.MetricsSystem -import org.apache.spark.serializer.SerializerInstance import org.apache.spark.util._ /** @@ -137,6 +132,8 @@ private[spark] abstract class Task[T]( memoryManager.synchronized { memoryManager.notifyAll() } } } finally { + // Though we unset the ThreadLocal here, the context member variable itself is still queried + // directly in the TaskRunner to check for FetchFailedExceptions. TaskContext.unset() } } @@ -156,7 +153,7 @@ private[spark] abstract class Task[T]( var epoch: Long = -1 // Task context, to be initialized in run(). - @transient protected var context: TaskContextImpl = _ + @transient var context: TaskContextImpl = _ // The actual Thread on which the task is running, if any. Initialized in run(). @volatile @transient private var taskThread: Thread = _ diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 3b25513bea057..19ebaf817e24e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -874,7 +874,8 @@ private[spark] class TaskSetManager( // and we are not using an external shuffle server which could serve the shuffle outputs. // The reason is the next stage wouldn't be able to fetch the data from this dead executor // so we would need to rerun these tasks on other executors. - if (tasks(0).isInstanceOf[ShuffleMapTask] && !env.blockManager.externalShuffleServiceEnabled) { + if (tasks(0).isInstanceOf[ShuffleMapTask] && !env.blockManager.externalShuffleServiceEnabled + && !isZombie) { for ((tid, info) <- taskInfos if info.executorId == execId) { val index = taskInfos(tid).index if (successful(index)) { @@ -906,8 +907,6 @@ private[spark] class TaskSetManager( * Check for tasks to be speculated and return true if there are any. This is called periodically * by the TaskScheduler. * - * TODO: To make this scale to large jobs, we need to maintain a list of running tasks, so that - * we don't scan the whole task set. It might also help to make this sorted by launch time. 
*/ override def checkSpeculatableTasks(minTimeToSpeculation: Int): Boolean = { // Can't speculate if we only have one task, and no need to speculate if the task set is a @@ -927,7 +926,8 @@ private[spark] class TaskSetManager( // TODO: Threshold should also look at standard deviation of task durations and have a lower // bound based on that. logDebug("Task length threshold for speculation: " + threshold) - for ((tid, info) <- taskInfos) { + for (tid <- runningTasksSet) { + val info = taskInfos(tid) val index = info.index if (!successful(index) && copiesRunning(index) == 1 && info.timeRunning(time) > threshold && !speculatableTasks.contains(index)) { diff --git a/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala b/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala index 498c12e196ce0..265a8acfa8d61 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala @@ -17,7 +17,7 @@ package org.apache.spark.shuffle -import org.apache.spark.{FetchFailed, TaskFailedReason} +import org.apache.spark.{FetchFailed, TaskContext, TaskFailedReason} import org.apache.spark.storage.BlockManagerId import org.apache.spark.util.Utils @@ -26,6 +26,11 @@ import org.apache.spark.util.Utils * back to DAGScheduler (through TaskEndReason) so we'd resubmit the previous stage. * * Note that bmAddress can be null. + * + * To prevent user code from hiding this fetch failure, in the constructor we call + * [[TaskContext.setFetchFailed()]]. This means that you *must* throw this exception immediately + * after creating it -- you cannot create it, check some condition, and then decide to ignore it + * (or risk triggering any other exceptions). See SPARK-19276. */ private[spark] class FetchFailedException( bmAddress: BlockManagerId, @@ -45,6 +50,12 @@ private[spark] class FetchFailedException( this(bmAddress, shuffleId, mapId, reduceId, cause.getMessage, cause) } + // SPARK-19276. We set the fetch failure in the task context, so that even if there is user-code + // which intercepts this exception (possibly wrapping it), the Executor can still tell there was + // a fetch failure, and send the correct error msg back to the driver. We wrap with an Option + // because the TaskContext is not defined in some test cases. + Option(TaskContext.get()).map(_.setFetchFailed(this)) + def toTaskFailedReason: TaskFailedReason = FetchFailed(bmAddress, shuffleId, mapId, reduceId, Utils.exceptionString(this)) } diff --git a/core/src/main/scala/org/apache/spark/util/ListenerBus.scala b/core/src/main/scala/org/apache/spark/util/ListenerBus.scala index 79fc2e94599c7..fa5ad4e8d81e1 100644 --- a/core/src/main/scala/org/apache/spark/util/ListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/util/ListenerBus.scala @@ -52,7 +52,7 @@ private[spark] trait ListenerBus[L <: AnyRef, E] extends Logging { * Post the event to all registered listeners. The `postToAll` caller should guarantee calling * `postToAll` in the same thread for all events. */ - final def postToAll(event: E): Unit = { + def postToAll(event: E): Unit = { // JavaConverters can create a JIterableWrapper if we use asScala. // However, this method will be called frequently. To avoid the wrapper cost, here we use // Java Iterator directly. 
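The new scaladoc on `FetchFailedException` above warns that the exception must be thrown immediately after construction, because the constructor now records itself in the current `TaskContext`. A sketch of the only safe usage pattern follows; since the class is `private[spark]`, the helper is assumed to live inside the Spark package tree, and the object name is hypothetical.

```scala
// Sketch only: FetchFailedException is private[spark], so this helper is assumed to
// live inside the Spark package tree. Creating the exception already records the
// failure in the current TaskContext, so construct-and-throw must be one step.
package org.apache.spark.shuffle

import org.apache.spark.storage.BlockManagerId

object FetchFailureSketch {
  def failFetch(
      address: BlockManagerId,
      shuffleId: Int,
      mapId: Int,
      reduceId: Int,
      cause: Throwable): Nothing = {
    // Never build the exception, test a condition, and then drop it: the task would
    // still be reported to the driver as a fetch failure (SPARK-19276).
    throw new FetchFailedException(address, shuffleId, mapId, reduceId, cause)
  }
}
```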
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 10e5233679562..1af34e3da231f 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -39,6 +39,7 @@ import scala.io.Source import scala.reflect.ClassTag import scala.util.Try import scala.util.control.{ControlThrowable, NonFatal} +import scala.util.matching.Regex import _root_.io.netty.channel.unix.Errors.NativeIoException import com.google.common.cache.{CacheBuilder, CacheLoader, LoadingCache} @@ -2588,13 +2589,31 @@ private[spark] object Utils extends Logging { def redact(conf: SparkConf, kvs: Seq[(String, String)]): Seq[(String, String)] = { val redactionPattern = conf.get(SECRET_REDACTION_PATTERN).r + redact(redactionPattern, kvs) + } + + private def redact(redactionPattern: Regex, kvs: Seq[(String, String)]): Seq[(String, String)] = { kvs.map { kv => redactionPattern.findFirstIn(kv._1) - .map { ignore => (kv._1, REDACTION_REPLACEMENT_TEXT) } + .map { _ => (kv._1, REDACTION_REPLACEMENT_TEXT) } .getOrElse(kv) } } + /** + * Looks up the redaction regex from within the key value pairs and uses it to redact the rest + * of the key value pairs. No care is taken to make sure the redaction property itself is not + * redacted. So theoretically, the property itself could be configured to redact its own value + * when printing. + */ + def redact(kvs: Map[String, String]): Seq[(String, String)] = { + val redactionPattern = kvs.getOrElse( + SECRET_REDACTION_PATTERN.key, + SECRET_REDACTION_PATTERN.defaultValueString + ).r + redact(redactionPattern, kvs.toArray) + } + } private[util] object CallerContext extends Logging { diff --git a/core/src/test/java/test/org/apache/spark/JavaAPISuite.java b/core/src/test/java/test/org/apache/spark/JavaAPISuite.java index 512149127d72f..01b5fb7b46684 100644 --- a/core/src/test/java/test/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/java/test/org/apache/spark/JavaAPISuite.java @@ -358,7 +358,7 @@ public void groupByOnPairRDD() { // Regression test for SPARK-4459 JavaRDD rdd = sc.parallelize(Arrays.asList(1, 1, 2, 3, 5, 8, 13)); Function, Boolean> areOdd = - x -> (x._1() % 2 == 0) && (x._2() % 2 == 0); + x -> (x._1() % 2 == 0) && (x._2() % 2 == 0); JavaPairRDD pairRDD = rdd.zip(rdd); JavaPairRDD>> oddsAndEvens = pairRDD.groupBy(areOdd); assertEquals(2, oddsAndEvens.count()); @@ -528,14 +528,14 @@ public void aggregateByKey() { new Tuple2<>(5, 3)), 2); Map> sets = pairs.aggregateByKey(new HashSet(), - (a, b) -> { - a.add(b); - return a; - }, - (a, b) -> { - a.addAll(b); - return a; - }).collectAsMap(); + (a, b) -> { + a.add(b); + return a; + }, + (a, b) -> { + a.addAll(b); + return a; + }).collectAsMap(); assertEquals(3, sets.size()); assertEquals(new HashSet<>(Arrays.asList(1)), sets.get(1)); assertEquals(new HashSet<>(Arrays.asList(2)), sets.get(3)); @@ -666,8 +666,8 @@ public void javaDoubleRDDHistoGram() { assertArrayEquals(expected_counts, histogram); // SPARK-5744 assertArrayEquals( - new long[] {0}, - sc.parallelizeDoubles(new ArrayList<>(0), 1).histogram(new double[]{0.0, 1.0})); + new long[] {0}, + sc.parallelizeDoubles(new ArrayList<>(0), 1).histogram(new double[]{0.0, 1.0})); } private static class DoubleComparator implements Comparator, Serializable { @@ -807,7 +807,7 @@ public void mapsFromPairsToPairs() { // Regression test for SPARK-668: JavaPairRDD swapped = pairRDD.flatMapToPair( - item -> Collections.singletonList(item.swap()).iterator()); 
+ item -> Collections.singletonList(item.swap()).iterator()); swapped.collect(); // There was never a bug here, but it's worth testing: @@ -845,11 +845,13 @@ public void mapPartitionsWithIndex() { public void getNumPartitions(){ JavaRDD rdd1 = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8), 3); JavaDoubleRDD rdd2 = sc.parallelizeDoubles(Arrays.asList(1.0, 2.0, 3.0, 4.0), 2); - JavaPairRDD rdd3 = sc.parallelizePairs(Arrays.asList( - new Tuple2<>("a", 1), - new Tuple2<>("aa", 2), - new Tuple2<>("aaa", 3) - ), 2); + JavaPairRDD rdd3 = sc.parallelizePairs( + Arrays.asList( + new Tuple2<>("a", 1), + new Tuple2<>("aa", 2), + new Tuple2<>("aaa", 3) + ), + 2); assertEquals(3, rdd1.getNumPartitions()); assertEquals(2, rdd2.getNumPartitions()); assertEquals(2, rdd3.getNumPartitions()); @@ -977,7 +979,7 @@ public void sequenceFile() { JavaPairRDD rdd = sc.parallelizePairs(pairs); rdd.mapToPair(pair -> new Tuple2<>(new IntWritable(pair._1()), new Text(pair._2()))) - .saveAsHadoopFile(outputDir, IntWritable.class, Text.class, SequenceFileOutputFormat.class); + .saveAsHadoopFile(outputDir, IntWritable.class, Text.class, SequenceFileOutputFormat.class); // Try reading the output back as an object file JavaPairRDD readRDD = sc.sequenceFile(outputDir, IntWritable.class, @@ -1068,11 +1070,11 @@ public void writeWithNewAPIHadoopFile() { JavaPairRDD rdd = sc.parallelizePairs(pairs); rdd.mapToPair(pair -> new Tuple2<>(new IntWritable(pair._1()), new Text(pair._2()))) - .saveAsNewAPIHadoopFile(outputDir, IntWritable.class, Text.class, + .saveAsNewAPIHadoopFile(outputDir, IntWritable.class, Text.class, org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat.class); JavaPairRDD output = - sc.sequenceFile(outputDir, IntWritable.class, Text.class); + sc.sequenceFile(outputDir, IntWritable.class, Text.class); assertEquals(pairs.toString(), output.map(Tuple2::toString).collect().toString()); } @@ -1088,11 +1090,11 @@ public void readWithNewAPIHadoopFile() throws IOException { JavaPairRDD rdd = sc.parallelizePairs(pairs); rdd.mapToPair(pair -> new Tuple2<>(new IntWritable(pair._1()), new Text(pair._2()))) - .saveAsHadoopFile(outputDir, IntWritable.class, Text.class, SequenceFileOutputFormat.class); + .saveAsHadoopFile(outputDir, IntWritable.class, Text.class, SequenceFileOutputFormat.class); JavaPairRDD output = sc.newAPIHadoopFile(outputDir, - org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat.class, - IntWritable.class, Text.class, Job.getInstance().getConfiguration()); + org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat.class, + IntWritable.class, Text.class, Job.getInstance().getConfiguration()); assertEquals(pairs.toString(), output.map(Tuple2::toString).collect().toString()); } @@ -1135,10 +1137,10 @@ public void hadoopFile() { JavaPairRDD rdd = sc.parallelizePairs(pairs); rdd.mapToPair(pair -> new Tuple2<>(new IntWritable(pair._1()), new Text(pair._2()))) - .saveAsHadoopFile(outputDir, IntWritable.class, Text.class, SequenceFileOutputFormat.class); + .saveAsHadoopFile(outputDir, IntWritable.class, Text.class, SequenceFileOutputFormat.class); JavaPairRDD output = sc.hadoopFile(outputDir, - SequenceFileInputFormat.class, IntWritable.class, Text.class); + SequenceFileInputFormat.class, IntWritable.class, Text.class); assertEquals(pairs.toString(), output.map(Tuple2::toString).collect().toString()); } @@ -1154,10 +1156,11 @@ public void hadoopFileCompressed() { JavaPairRDD rdd = sc.parallelizePairs(pairs); rdd.mapToPair(pair -> new Tuple2<>(new IntWritable(pair._1()), new 
Text(pair._2()))) - .saveAsHadoopFile(outputDir, IntWritable.class, Text.class, SequenceFileOutputFormat.class, DefaultCodec.class); + .saveAsHadoopFile(outputDir, IntWritable.class, Text.class, + SequenceFileOutputFormat.class, DefaultCodec.class); JavaPairRDD output = sc.hadoopFile(outputDir, - SequenceFileInputFormat.class, IntWritable.class, Text.class); + SequenceFileInputFormat.class, IntWritable.class, Text.class); assertEquals(pairs.toString(), output.map(Tuple2::toString).collect().toString()); } @@ -1263,23 +1266,23 @@ public void combineByKey() { Function2 mergeValueFunction = (v1, v2) -> v1 + v2; JavaPairRDD combinedRDD = originalRDD.keyBy(keyFunction) - .combineByKey(createCombinerFunction, mergeValueFunction, mergeValueFunction); + .combineByKey(createCombinerFunction, mergeValueFunction, mergeValueFunction); Map results = combinedRDD.collectAsMap(); ImmutableMap expected = ImmutableMap.of(0, 9, 1, 5, 2, 7); assertEquals(expected, results); Partitioner defaultPartitioner = Partitioner.defaultPartitioner( - combinedRDD.rdd(), - JavaConverters.collectionAsScalaIterableConverter( - Collections.>emptyList()).asScala().toSeq()); + combinedRDD.rdd(), + JavaConverters.collectionAsScalaIterableConverter( + Collections.>emptyList()).asScala().toSeq()); combinedRDD = originalRDD.keyBy(keyFunction) - .combineByKey( - createCombinerFunction, - mergeValueFunction, - mergeValueFunction, - defaultPartitioner, - false, - new KryoSerializer(new SparkConf())); + .combineByKey( + createCombinerFunction, + mergeValueFunction, + mergeValueFunction, + defaultPartitioner, + false, + new KryoSerializer(new SparkConf())); results = combinedRDD.collectAsMap(); assertEquals(expected, results); } @@ -1291,11 +1294,10 @@ public void mapOnPairRDD() { JavaPairRDD rdd2 = rdd1.mapToPair(i -> new Tuple2<>(i, i % 2)); JavaPairRDD rdd3 = rdd2.mapToPair(in -> new Tuple2<>(in._2(), in._1())); assertEquals(Arrays.asList( - new Tuple2<>(1, 1), - new Tuple2<>(0, 2), - new Tuple2<>(1, 3), - new Tuple2<>(0, 4)), rdd3.collect()); - + new Tuple2<>(1, 1), + new Tuple2<>(0, 2), + new Tuple2<>(1, 3), + new Tuple2<>(0, 4)), rdd3.collect()); } @SuppressWarnings("unchecked") @@ -1312,16 +1314,18 @@ public void collectPartitions() { assertEquals(Arrays.asList(3, 4), parts[0]); assertEquals(Arrays.asList(5, 6, 7), parts[1]); - assertEquals(Arrays.asList(new Tuple2<>(1, 1), - new Tuple2<>(2, 0)), - rdd2.collectPartitions(new int[] {0})[0]); + assertEquals( + Arrays.asList(new Tuple2<>(1, 1), new Tuple2<>(2, 0)), + rdd2.collectPartitions(new int[] {0})[0]); List>[] parts2 = rdd2.collectPartitions(new int[] {1, 2}); assertEquals(Arrays.asList(new Tuple2<>(3, 1), new Tuple2<>(4, 0)), parts2[0]); - assertEquals(Arrays.asList(new Tuple2<>(5, 1), - new Tuple2<>(6, 0), - new Tuple2<>(7, 1)), - parts2[1]); + assertEquals( + Arrays.asList( + new Tuple2<>(5, 1), + new Tuple2<>(6, 0), + new Tuple2<>(7, 1)), + parts2[1]); } @Test @@ -1352,7 +1356,6 @@ public void countApproxDistinctByKey() { double error = Math.abs((resCount - count) / count); assertTrue(error < 0.1); } - } @Test @@ -1531,8 +1534,8 @@ public void testRegisterKryoClasses() { SparkConf conf = new SparkConf(); conf.registerKryoClasses(new Class[]{ Class1.class, Class2.class }); assertEquals( - Class1.class.getName() + "," + Class2.class.getName(), - conf.get("spark.kryo.classesToRegister")); + Class1.class.getName() + "," + Class2.class.getName(), + conf.get("spark.kryo.classesToRegister")); } @Test diff --git a/core/src/test/scala/org/apache/spark/FileSuite.scala 
b/core/src/test/scala/org/apache/spark/FileSuite.scala index a2d3177c5c711..5be0121db58ae 100644 --- a/core/src/test/scala/org/apache/spark/FileSuite.scala +++ b/core/src/test/scala/org/apache/spark/FileSuite.scala @@ -401,7 +401,7 @@ class FileSuite extends SparkFunSuite with LocalSparkContext { job.setOutputKeyClass(classOf[String]) job.setOutputValueClass(classOf[String]) job.set("mapred.output.format.class", classOf[TextOutputFormat[String, String]].getName) - job.set("mapred.output.dir", tempDir.getPath + "/outputDataset_old") + job.set("mapreduce.output.fileoutputformat.outputdir", tempDir.getPath + "/outputDataset_old") randomRDD.saveAsHadoopDataset(job) assert(new File(tempDir.getPath + "/outputDataset_old/part-00000").exists() === true) } @@ -415,7 +415,8 @@ class FileSuite extends SparkFunSuite with LocalSparkContext { job.setOutputValueClass(classOf[String]) job.setOutputFormatClass(classOf[NewTextOutputFormat[String, String]]) val jobConfig = job.getConfiguration - jobConfig.set("mapred.output.dir", tempDir.getPath + "/outputDataset_new") + jobConfig.set("mapreduce.output.fileoutputformat.outputdir", + tempDir.getPath + "/outputDataset_new") randomRDD.saveAsNewAPIHadoopDataset(jobConfig) assert(new File(tempDir.getPath + "/outputDataset_new/part-r-00000").exists() === true) } diff --git a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala index b743ff5376c49..8150fff2d018d 100644 --- a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala +++ b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.executor import java.io.{Externalizable, ObjectInput, ObjectOutput} +import java.lang.Thread.UncaughtExceptionHandler import java.nio.ByteBuffer import java.util.Properties import java.util.concurrent.{CountDownLatch, TimeUnit} @@ -27,7 +28,7 @@ import scala.concurrent.duration._ import org.mockito.ArgumentCaptor import org.mockito.Matchers.{any, eq => meq} -import org.mockito.Mockito.{inOrder, when} +import org.mockito.Mockito.{inOrder, verify, when} import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer import org.scalatest.concurrent.Eventually @@ -37,9 +38,12 @@ import org.apache.spark._ import org.apache.spark.TaskState.TaskState import org.apache.spark.memory.MemoryManager import org.apache.spark.metrics.MetricsSystem +import org.apache.spark.rdd.RDD import org.apache.spark.rpc.RpcEnv -import org.apache.spark.scheduler.{FakeTask, TaskDescription} +import org.apache.spark.scheduler.{FakeTask, ResultTask, TaskDescription} import org.apache.spark.serializer.JavaSerializer +import org.apache.spark.shuffle.FetchFailedException +import org.apache.spark.storage.BlockManagerId class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSugar with Eventually { @@ -123,6 +127,75 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug } } + test("SPARK-19276: Handle FetchFailedExceptions that are hidden by user exceptions") { + val conf = new SparkConf().setMaster("local").setAppName("executor suite test") + sc = new SparkContext(conf) + val serializer = SparkEnv.get.closureSerializer.newInstance() + val resultFunc = (context: TaskContext, itr: Iterator[Int]) => itr.size + + // Submit a job where a fetch failure is thrown, but user code has a try/catch which hides + // the fetch failure. 
The executor should still tell the driver that the task failed due to a + // fetch failure, not a generic exception from user code. + val inputRDD = new FetchFailureThrowingRDD(sc) + val secondRDD = new FetchFailureHidingRDD(sc, inputRDD, throwOOM = false) + val taskBinary = sc.broadcast(serializer.serialize((secondRDD, resultFunc)).array()) + val serializedTaskMetrics = serializer.serialize(TaskMetrics.registered).array() + val task = new ResultTask( + stageId = 1, + stageAttemptId = 0, + taskBinary = taskBinary, + partition = secondRDD.partitions(0), + locs = Seq(), + outputId = 0, + localProperties = new Properties(), + serializedTaskMetrics = serializedTaskMetrics + ) + + val serTask = serializer.serialize(task) + val taskDescription = createFakeTaskDescription(serTask) + + val failReason = runTaskAndGetFailReason(taskDescription) + assert(failReason.isInstanceOf[FetchFailed]) + } + + test("SPARK-19276: OOMs correctly handled with a FetchFailure") { + // when there is a fatal error like an OOM, we don't do normal fetch failure handling, since it + // may be a false positive. And we should call the uncaught exception handler. + val conf = new SparkConf().setMaster("local").setAppName("executor suite test") + sc = new SparkContext(conf) + val serializer = SparkEnv.get.closureSerializer.newInstance() + val resultFunc = (context: TaskContext, itr: Iterator[Int]) => itr.size + + // Submit a job where a fetch failure is thrown, but then there is an OOM. We should treat + // the fetch failure as a false positive, and just do normal OOM handling. + val inputRDD = new FetchFailureThrowingRDD(sc) + val secondRDD = new FetchFailureHidingRDD(sc, inputRDD, throwOOM = true) + val taskBinary = sc.broadcast(serializer.serialize((secondRDD, resultFunc)).array()) + val serializedTaskMetrics = serializer.serialize(TaskMetrics.registered).array() + val task = new ResultTask( + stageId = 1, + stageAttemptId = 0, + taskBinary = taskBinary, + partition = secondRDD.partitions(0), + locs = Seq(), + outputId = 0, + localProperties = new Properties(), + serializedTaskMetrics = serializedTaskMetrics + ) + + val serTask = serializer.serialize(task) + val taskDescription = createFakeTaskDescription(serTask) + + val (failReason, uncaughtExceptionHandler) = + runTaskGetFailReasonAndExceptionHandler(taskDescription) + // make sure the task failure just looks like a OOM, not a fetch failure + assert(failReason.isInstanceOf[ExceptionFailure]) + val exceptionCaptor = ArgumentCaptor.forClass(classOf[Throwable]) + verify(uncaughtExceptionHandler).uncaughtException(any(), exceptionCaptor.capture()) + assert(exceptionCaptor.getAllValues.size === 1) + assert(exceptionCaptor.getAllValues.get(0).isInstanceOf[OutOfMemoryError]) + } + test("Gracefully handle error in task deserialization") { val conf = new SparkConf val serializer = new JavaSerializer(conf) @@ -169,13 +242,20 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug } private def runTaskAndGetFailReason(taskDescription: TaskDescription): TaskFailedReason = { + runTaskGetFailReasonAndExceptionHandler(taskDescription)._1 + } + + private def runTaskGetFailReasonAndExceptionHandler( + taskDescription: TaskDescription): (TaskFailedReason, UncaughtExceptionHandler) = { val mockBackend = mock[ExecutorBackend] + val mockUncaughtExceptionHandler = mock[UncaughtExceptionHandler] var executor: Executor = null try { - executor = new Executor("id", "localhost", SparkEnv.get, userClassPath = Nil, isLocal = true) + executor = new Executor("id", 
"localhost", SparkEnv.get, userClassPath = Nil, isLocal = true, + uncaughtExceptionHandler = mockUncaughtExceptionHandler) // the task will be launched in a dedicated worker thread executor.launchTask(mockBackend, taskDescription) - eventually(timeout(5 seconds), interval(10 milliseconds)) { + eventually(timeout(5.seconds), interval(10.milliseconds)) { assert(executor.numRunningTasks === 0) } } finally { @@ -193,7 +273,56 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug assert(statusCaptor.getAllValues().get(0).remaining() === 0) // second update is more interesting val failureData = statusCaptor.getAllValues.get(1) - SparkEnv.get.closureSerializer.newInstance().deserialize[TaskFailedReason](failureData) + val failReason = + SparkEnv.get.closureSerializer.newInstance().deserialize[TaskFailedReason](failureData) + (failReason, mockUncaughtExceptionHandler) + } +} + +class FetchFailureThrowingRDD(sc: SparkContext) extends RDD[Int](sc, Nil) { + override def compute(split: Partition, context: TaskContext): Iterator[Int] = { + new Iterator[Int] { + override def hasNext: Boolean = true + override def next(): Int = { + throw new FetchFailedException( + bmAddress = BlockManagerId("1", "hostA", 1234), + shuffleId = 0, + mapId = 0, + reduceId = 0, + message = "fake fetch failure" + ) + } + } + } + override protected def getPartitions: Array[Partition] = { + Array(new SimplePartition) + } +} + +class SimplePartition extends Partition { + override def index: Int = 0 +} + +class FetchFailureHidingRDD( + sc: SparkContext, + val input: FetchFailureThrowingRDD, + throwOOM: Boolean) extends RDD[Int](input) { + override def compute(split: Partition, context: TaskContext): Iterator[Int] = { + val inItr = input.compute(split, context) + try { + Iterator(inItr.size) + } catch { + case t: Throwable => + if (throwOOM) { + throw new OutOfMemoryError("OOM while handling another exception") + } else { + throw new RuntimeException("User Exception that hides the original exception", t) + } + } + } + + override protected def getPartitions: Array[Partition] = { + Array(new SimplePartition) } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala index 0c362b881d912..83ed12752074d 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala @@ -195,6 +195,17 @@ class OutputCommitCoordinatorSuite extends SparkFunSuite with BeforeAndAfter { sc.runJob(rdd, OutputCommitFunctions(tempDir.getAbsolutePath).callCanCommitMultipleTimes _, 0 until rdd.partitions.size) } + + test("SPARK-19631: Do not allow failed attempts to be authorized for committing") { + val stage: Int = 1 + val partition: Int = 1 + val failedAttempt: Int = 0 + outputCommitCoordinator.stageStart(stage, maxPartitionId = 1) + outputCommitCoordinator.taskCompleted(stage, partition, attemptNumber = failedAttempt, + reason = ExecutorLostFailure("0", exitCausedByApp = true, None)) + assert(!outputCommitCoordinator.canCommit(stage, partition, failedAttempt)) + assert(outputCommitCoordinator.canCommit(stage, partition, failedAttempt + 1)) + } } /** diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index d03a0c990a02b..2c2cda9f318eb 100644 --- 
a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.scheduler -import java.util.Random +import java.util.{Properties, Random} import scala.collection.mutable import scala.collection.mutable.ArrayBuffer @@ -28,6 +28,7 @@ import org.mockito.Mockito.{mock, never, spy, verify, when} import org.apache.spark._ import org.apache.spark.internal.config import org.apache.spark.internal.Logging +import org.apache.spark.serializer.SerializerInstance import org.apache.spark.storage.BlockManagerId import org.apache.spark.util.{AccumulatorV2, ManualClock} @@ -664,6 +665,67 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg assert(thrown2.getMessage().contains("bigger than spark.driver.maxResultSize")) } + test("[SPARK-13931] taskSetManager should not send Resubmitted tasks after being a zombie") { + val conf = new SparkConf().set("spark.speculation", "true") + sc = new SparkContext("local", "test", conf) + + val sched = new FakeTaskScheduler(sc, ("execA", "host1"), ("execB", "host2")) + sched.initialize(new FakeSchedulerBackend() { + override def killTask(taskId: Long, executorId: String, interruptThread: Boolean): Unit = {} + }) + + // Keep track of the number of tasks that are resubmitted, + // so that the test can check that no tasks were resubmitted. + var resubmittedTasks = 0 + val dagScheduler = new FakeDAGScheduler(sc, sched) { + override def taskEnded( + task: Task[_], + reason: TaskEndReason, + result: Any, + accumUpdates: Seq[AccumulatorV2[_, _]], + taskInfo: TaskInfo): Unit = { + super.taskEnded(task, reason, result, accumUpdates, taskInfo) + reason match { + case Resubmitted => resubmittedTasks += 1 + case _ => + } + } + } + sched.setDAGScheduler(dagScheduler) + + val singleTask = new ShuffleMapTask(0, 0, null, new Partition { + override def index: Int = 0 + }, Seq(TaskLocation("host1", "execA")), new Properties, null) + val taskSet = new TaskSet(Array(singleTask), 0, 0, 0, null) + val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES) + + // Offer host1, which should be accepted as a PROCESS_LOCAL location + // by the one task in the task set + val task1 = manager.resourceOffer("execA", "host1", TaskLocality.PROCESS_LOCAL).get + + // Mark the task as available for speculation, and then offer another resource, + // which should be used to launch a speculative copy of the task. + manager.speculatableTasks += singleTask.partitionId + val task2 = manager.resourceOffer("execB", "host2", TaskLocality.ANY).get + + assert(manager.runningTasks === 2) + assert(manager.isZombie === false) + + val directTaskResult = new DirectTaskResult[String](null, Seq()) { + override def value(resultSer: SerializerInstance): String = "" + } + // Complete one copy of the task, which should result in the task set manager + // being marked as a zombie, because at least one copy of its only task has completed. 
+ manager.handleSuccessfulTask(task1.taskId, directTaskResult) + assert(manager.isZombie === true) + assert(resubmittedTasks === 0) + assert(manager.runningTasks === 1) + + manager.executorLost("execB", "host2", new SlaveLost()) + assert(manager.runningTasks === 0) + assert(resubmittedTasks === 0) + } + test("speculative and noPref task should be scheduled after node-local") { sc = new SparkContext("local", "test") sched = new FakeTaskScheduler( diff --git a/docs/hardware-provisioning.md b/docs/hardware-provisioning.md index bb6f616b18a24..896f9302ef300 100644 --- a/docs/hardware-provisioning.md +++ b/docs/hardware-provisioning.md @@ -15,8 +15,8 @@ possible**. We recommend the following: * If at all possible, run Spark on the same nodes as HDFS. The simplest way is to set up a Spark [standalone mode cluster](spark-standalone.html) on the same nodes, and configure Spark and Hadoop's memory and CPU usage to avoid interference (for Hadoop, the relevant options are -`mapred.child.java.opts` for the per-task memory and `mapred.tasktracker.map.tasks.maximum` -and `mapred.tasktracker.reduce.tasks.maximum` for number of tasks). Alternatively, you can run +`mapred.child.java.opts` for the per-task memory and `mapreduce.tasktracker.map.tasks.maximum` +and `mapreduce.tasktracker.reduce.tasks.maximum` for number of tasks). Alternatively, you can run Hadoop and Spark on a common cluster manager like [Mesos](running-on-mesos.html) or [Hadoop YARN](running-on-yarn.html). diff --git a/docs/ml-classification-regression.md b/docs/ml-classification-regression.md index 37862f82c3386..ab6f587e09ef2 100644 --- a/docs/ml-classification-regression.md +++ b/docs/ml-classification-regression.md @@ -629,6 +629,11 @@ others. Continuous Inverse*, Idenity, Log + + Tweedie + Zero-inflated continuous + Power link function + * Canonical Link diff --git a/docs/ml-collaborative-filtering.md b/docs/ml-collaborative-filtering.md index cfe835172ab45..58f2d4b531e70 100644 --- a/docs/ml-collaborative-filtering.md +++ b/docs/ml-collaborative-filtering.md @@ -59,6 +59,34 @@ This approach is named "ALS-WR" and discussed in the paper It makes `regParam` less dependent on the scale of the dataset, so we can apply the best parameter learned from a sampled subset to the full dataset and expect similar performance. +### Cold-start strategy + +When making predictions using an `ALSModel`, it is common to encounter users and/or items in the +test dataset that were not present during training the model. This typically occurs in two +scenarios: + +1. In production, for new users or items that have no rating history and on which the model has not +been trained (this is the "cold start problem"). +2. During cross-validation, the data is split between training and evaluation sets. When using +simple random splits as in Spark's `CrossValidator` or `TrainValidationSplit`, it is actually +very common to encounter users and/or items in the evaluation set that are not in the training set + +By default, Spark assigns `NaN` predictions during `ALSModel.transform` when a user and/or item +factor is not present in the model. This can be useful in a production system, since it indicates +a new user or item, and so the system can make a decision on some fallback to use as the prediction. + +However, this is undesirable during cross-validation, since any `NaN` predicted values will result +in `NaN` results for the evaluation metric (for example when using `RegressionEvaluator`). +This makes model selection impossible. 
+ +Spark allows users to set the `coldStartStrategy` parameter +to "drop" in order to drop any rows in the `DataFrame` of predictions that contain `NaN` values. +The evaluation metric will then be computed over the non-`NaN` data and will be valid. +Usage of this parameter is illustrated in the example below. + +**Note:** currently, the supported cold start strategies are "nan" (the default behavior mentioned +above) and "drop". Further strategies may be supported in the future. + **Examples**
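Editor's note: a minimal Scala sketch of the workflow the documentation above describes — evaluating an ALS model with the cold-start strategy set to "drop" — is shown here for reference. It mirrors the `setColdStartStrategy("drop")` call this patch adds to the Scala and Java examples; the ratings file path and column names are hypothetical.

```scala
import org.apache.spark.ml.evaluation.RegressionEvaluator
import org.apache.spark.ml.recommendation.ALS
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().appName("ALSColdStartSketch").getOrCreate()

// Hypothetical ratings data with userId, movieId and rating columns.
val ratings = spark.read
  .option("header", "true")
  .option("inferSchema", "true")
  .csv("path/to/ratings.csv")
val Array(training, test) = ratings.randomSplit(Array(0.8, 0.2))

val als = new ALS()
  .setMaxIter(5)
  .setRegParam(0.01)
  .setUserCol("userId")
  .setItemCol("movieId")
  .setRatingCol("rating")
val model = als.fit(training)

// Drop prediction rows whose user or item was unseen during training,
// so the RMSE below stays well defined.
model.setColdStartStrategy("drop")
val predictions = model.transform(test)

val rmse = new RegressionEvaluator()
  .setMetricName("rmse")
  .setLabelCol("rating")
  .setPredictionCol("prediction")
  .evaluate(predictions)
println(s"Root-mean-square error = $rmse")
```

With the default "nan" strategy, the same evaluation would return `NaN` whenever the test split contains users or items that are absent from the training split.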
diff --git a/docs/ml-pipeline.md b/docs/ml-pipeline.md index 7cbb14654e9d1..aa92c0a37c0f4 100644 --- a/docs/ml-pipeline.md +++ b/docs/ml-pipeline.md @@ -132,7 +132,7 @@ The `Pipeline.fit()` method is called on the original `DataFrame`, which has raw The `Tokenizer.transform()` method splits the raw text documents into words, adding a new column with words to the `DataFrame`. The `HashingTF.transform()` method converts the words column into feature vectors, adding a new column with those vectors to the `DataFrame`. Now, since `LogisticRegression` is an `Estimator`, the `Pipeline` first calls `LogisticRegression.fit()` to produce a `LogisticRegressionModel`. -If the `Pipeline` had more stages, it would call the `LogisticRegressionModel`'s `transform()` +If the `Pipeline` had more `Estimator`s, it would call the `LogisticRegressionModel`'s `transform()` method on the `DataFrame` before passing the `DataFrame` to the next stage. A `Pipeline` is an `Estimator`. diff --git a/docs/quick-start.md b/docs/quick-start.md index 04ac27876252e..aa4319a23325c 100644 --- a/docs/quick-start.md +++ b/docs/quick-start.md @@ -260,7 +260,7 @@ object which contains information about our application. Our application depends on the Spark API, so we'll also include an sbt configuration file, -`simple.sbt`, which explains that Spark is a dependency. This file also adds a repository that +`build.sbt`, which explains that Spark is a dependency. This file also adds a repository that Spark depends on: {% highlight scala %} @@ -273,7 +273,7 @@ scalaVersion := "{{site.SCALA_VERSION}}" libraryDependencies += "org.apache.spark" %% "spark-core" % "{{site.SPARK_VERSION}}" {% endhighlight %} -For sbt to work correctly, we'll need to layout `SimpleApp.scala` and `simple.sbt` +For sbt to work correctly, we'll need to layout `SimpleApp.scala` and `build.sbt` according to the typical directory structure. Once that is in place, we can create a JAR package containing the application's code, then use the `spark-submit` script to run our program. @@ -281,7 +281,7 @@ containing the application's code, then use the `spark-submit` script to run our # Your directory layout should look like this $ find . . -./simple.sbt +./build.sbt ./src ./src/main ./src/main/scala diff --git a/docs/security.md b/docs/security.md index a4796767832b9..9eda42888637f 100644 --- a/docs/security.md +++ b/docs/security.md @@ -12,7 +12,7 @@ Spark currently supports authentication via a shared secret. Authentication can ## Web UI The Spark UI can be secured by using [javax servlet filters](http://docs.oracle.com/javaee/6/api/javax/servlet/Filter.html) via the `spark.ui.filters` setting -and by using [https/SSL](http://en.wikipedia.org/wiki/HTTPS) via the `spark.ui.https.enabled` setting. +and by using [https/SSL](http://en.wikipedia.org/wiki/HTTPS) via [SSL settings](security.html#ssl-configuration). ### Authentication diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 2dd1ab6ef3de1..b077575155eb0 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -386,8 +386,8 @@ For example: The [built-in DataFrames functions](api/scala/index.html#org.apache.spark.sql.functions$) provide common aggregations such as `count()`, `countDistinct()`, `avg()`, `max()`, `min()`, etc. 
-While those functions are designed for DataFrames, Spark SQL also has type-safe versions for some of them in -[Scala](api/scala/index.html#org.apache.spark.sql.expressions.scalalang.typed$) and +While those functions are designed for DataFrames, Spark SQL also has type-safe versions for some of them in +[Scala](api/scala/index.html#org.apache.spark.sql.expressions.scalalang.typed$) and [Java](api/java/org/apache/spark/sql/expressions/javalang/typed.html) to work with strongly typed Datasets. Moreover, users are not limited to the predefined aggregate functions and can create their own. @@ -397,7 +397,7 @@ Moreover, users are not limited to the predefined aggregate functions and can cr
-Users have to extend the [UserDefinedAggregateFunction](api/scala/index.html#org.apache.spark.sql.expressions.UserDefinedAggregateFunction) +Users have to extend the [UserDefinedAggregateFunction](api/scala/index.html#org.apache.spark.sql.expressions.UserDefinedAggregateFunction) abstract class to implement a custom untyped aggregate function. For example, a user-defined average can look like: @@ -888,8 +888,9 @@ or a JSON file. Note that the file that is offered as _a json file_ is not a typical JSON file. Each line must contain a separate, self-contained valid JSON object. For more information, please see -[JSON Lines text format, also called newline-delimited JSON](http://jsonlines.org/). As a -consequence, a regular multi-line JSON file will most often fail. +[JSON Lines text format, also called newline-delimited JSON](http://jsonlines.org/). + +For a regular multi-line JSON file, set the `wholeFile` option to `true`. {% include_example json_dataset scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala %}
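Editor's note: as a hedged illustration of the `wholeFile` option introduced in the documentation change above (the file path is made up), a regular multi-line JSON file can be read in Scala as follows:

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().appName("WholeFileJsonSketch").getOrCreate()

// Hypothetical path to a file holding regular (multi-line) JSON rather than JSON Lines.
val people = spark.read
  .option("wholeFile", "true") // parse records that span multiple lines
  .json("path/to/multiline_people.json")

people.printSchema()
people.show()
```

Without the option, the reader continues to expect JSON Lines input, i.e. one self-contained JSON object per line.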
@@ -901,8 +902,9 @@ or a JSON file. Note that the file that is offered as _a json file_ is not a typical JSON file. Each line must contain a separate, self-contained valid JSON object. For more information, please see -[JSON Lines text format, also called newline-delimited JSON](http://jsonlines.org/). As a -consequence, a regular multi-line JSON file will most often fail. +[JSON Lines text format, also called newline-delimited JSON](http://jsonlines.org/). + +For a regular multi-line JSON file, set the `wholeFile` option to `true`. {% include_example json_dataset java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java %}
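Editor's note: this patch also switches the in-memory JSON examples from `RDD[String]` to `Dataset[String]`. A minimal Scala sketch of that pattern, with a made-up record, assuming the `json(Dataset[String])` overload used by the updated examples:

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().appName("JsonDatasetSketch").getOrCreate()
import spark.implicits._ // provides the Encoder[String] needed by createDataset

// One JSON object per string, held in a Dataset[String] rather than an RDD[String].
val jsonStrings = spark.createDataset(
  """{"name":"Alice","address":{"city":"Springfield","state":"Ohio"}}""" :: Nil)

val people = spark.read.json(jsonStrings)
people.show()
```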
@@ -913,8 +915,9 @@ This conversion can be done using `SparkSession.read.json` on a JSON file. Note that the file that is offered as _a json file_ is not a typical JSON file. Each line must contain a separate, self-contained valid JSON object. For more information, please see -[JSON Lines text format, also called newline-delimited JSON](http://jsonlines.org/). As a -consequence, a regular multi-line JSON file will most often fail. +[JSON Lines text format, also called newline-delimited JSON](http://jsonlines.org/). + +For a regular multi-line JSON file, set the `wholeFile` parameter to `True`. {% include_example json_dataset python/sql/datasource.py %} @@ -926,8 +929,9 @@ files is a JSON object. Note that the file that is offered as _a json file_ is not a typical JSON file. Each line must contain a separate, self-contained valid JSON object. For more information, please see -[JSON Lines text format, also called newline-delimited JSON](http://jsonlines.org/). As a -consequence, a regular multi-line JSON file will most often fail. +[JSON Lines text format, also called newline-delimited JSON](http://jsonlines.org/). + +For a regular multi-line JSON file, set a named parameter `wholeFile` to `TRUE`. {% include_example json_dataset r/RSparkSQLExample.R %} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaALSExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaALSExample.java index 33ba668b32fc2..81970b7c81f40 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaALSExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaALSExample.java @@ -103,6 +103,8 @@ public static void main(String[] args) { ALSModel model = als.fit(training); // Evaluate the model by computing the RMSE on the test data + // Note we set cold start strategy to 'drop' to ensure we don't get NaN evaluation metrics + model.setColdStartStrategy("drop"); Dataset predictions = model.transform(test); RegressionEvaluator evaluator = new RegressionEvaluator() diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaBucketedRandomProjectionLSHExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaBucketedRandomProjectionLSHExample.java index 4594e3462b2a5..ff917b720c8b6 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaBucketedRandomProjectionLSHExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaBucketedRandomProjectionLSHExample.java @@ -42,7 +42,7 @@ /** * An example demonstrating BucketedRandomProjectionLSH. * Run with: - * bin/run-example org.apache.spark.examples.ml.JavaBucketedRandomProjectionLSHExample + * bin/run-example ml.JavaBucketedRandomProjectionLSHExample */ public class JavaBucketedRandomProjectionLSHExample { public static void main(String[] args) { diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaMinHashLSHExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaMinHashLSHExample.java index 0aace46939257..e164598e3ef87 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaMinHashLSHExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaMinHashLSHExample.java @@ -42,7 +42,7 @@ /** * An example demonstrating MinHashLSH. 
* Run with: - * bin/run-example org.apache.spark.examples.ml.JavaMinHashLSHExample + * bin/run-example ml.JavaMinHashLSHExample */ public class JavaMinHashLSHExample { public static void main(String[] args) { diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaTokenizerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaTokenizerExample.java index 3f809eba7fffb..a0979aa2d24e4 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaTokenizerExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaTokenizerExample.java @@ -27,7 +27,6 @@ import org.apache.spark.ml.feature.RegexTokenizer; import org.apache.spark.ml.feature.Tokenizer; -import org.apache.spark.sql.api.java.UDF1; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; @@ -69,7 +68,8 @@ public static void main(String[] args) { .setOutputCol("words") .setPattern("\\W"); // alternatively .setPattern("\\w+").setGaps(false); - spark.udf().register("countTokens", (WrappedArray words) -> words.size(), DataTypes.IntegerType); + spark.udf().register( + "countTokens", (WrappedArray words) -> words.size(), DataTypes.IntegerType); Dataset tokenized = tokenizer.transform(sentenceDataFrame); tokenized.select("sentence", "words") diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaRankingMetricsExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaRankingMetricsExample.java index bd49f059b29fd..dc9970d885274 100644 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaRankingMetricsExample.java +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaRankingMetricsExample.java @@ -118,7 +118,9 @@ public static void main(String[] args) { new Tuple2<>(new Tuple2<>(r.user(), r.product()), r.rating()))); JavaRDD> ratesAndPreds = JavaPairRDD.fromJavaRDD(ratings.map(r -> - new Tuple2, Object>(new Tuple2<>(r.user(), r.product()), r.rating()) + new Tuple2, Object>( + new Tuple2<>(r.user(), r.product()), + r.rating()) )).join(predictions).values(); // Create regression metrics object diff --git a/examples/src/main/java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java b/examples/src/main/java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java index adb96dd8bf00c..82bb284ea3e58 100644 --- a/examples/src/main/java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java +++ b/examples/src/main/java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java @@ -25,8 +25,6 @@ import java.util.Properties; // $example on:basic_parquet_example$ -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.MapFunction; import org.apache.spark.sql.Encoders; // $example on:schema_merging$ @@ -217,12 +215,11 @@ private static void runJsonDatasetExample(SparkSession spark) { // +------+ // Alternatively, a DataFrame can be created for a JSON dataset represented by - // an RDD[String] storing one JSON object per string. + // an Dataset[String] storing one JSON object per string. 
List jsonData = Arrays.asList( "{\"name\":\"Yin\",\"address\":{\"city\":\"Columbus\",\"state\":\"Ohio\"}}"); - JavaRDD anotherPeopleRDD = - new JavaSparkContext(spark.sparkContext()).parallelize(jsonData); - Dataset anotherPeople = spark.read().json(anotherPeopleRDD); + Dataset anotherPeopleDataset = spark.createDataset(jsonData, Encoders.STRING()); + Dataset anotherPeople = spark.read().json(anotherPeopleDataset); anotherPeople.show(); // +---------------+----+ // | address|name| diff --git a/examples/src/main/python/ml/als_example.py b/examples/src/main/python/ml/als_example.py index 1a979ff5b5be2..2e7214ed56f98 100644 --- a/examples/src/main/python/ml/als_example.py +++ b/examples/src/main/python/ml/als_example.py @@ -44,7 +44,9 @@ (training, test) = ratings.randomSplit([0.8, 0.2]) # Build the recommendation model using ALS on the training data - als = ALS(maxIter=5, regParam=0.01, userCol="userId", itemCol="movieId", ratingCol="rating") + # Note we set cold start strategy to 'drop' to ensure we don't get NaN evaluation metrics + als = ALS(maxIter=5, regParam=0.01, userCol="userId", itemCol="movieId", ratingCol="rating", + coldStartStrategy="drop") model = als.fit(training) # Evaluate the model by computing the RMSE on the test data diff --git a/examples/src/main/r/ml/bisectingKmeans.R b/examples/src/main/r/ml/bisectingKmeans.R index 5fb5bfb0fa5a3..b3eaa6dd86d7d 100644 --- a/examples/src/main/r/ml/bisectingKmeans.R +++ b/examples/src/main/r/ml/bisectingKmeans.R @@ -25,20 +25,21 @@ library(SparkR) sparkR.session(appName = "SparkR-ML-bisectingKmeans-example") # $example on$ -irisDF <- createDataFrame(iris) +t <- as.data.frame(Titanic) +training <- createDataFrame(t) # Fit bisecting k-means model with four centers -model <- spark.bisectingKmeans(df, Sepal_Length ~ Sepal_Width, k = 4) +model <- spark.bisectingKmeans(training, Class ~ Survived, k = 4) # get fitted result from a bisecting k-means model fitted.model <- fitted(model, "centers") # Model summary -summary(fitted.model) +head(summary(fitted.model)) # fitted values on training data -fitted <- predict(model, df) -head(select(fitted, "Sepal_Length", "prediction")) +fitted <- predict(model, training) +head(select(fitted, "Class", "prediction")) # $example off$ sparkR.session.stop() diff --git a/examples/src/main/r/ml/glm.R b/examples/src/main/r/ml/glm.R index e41af97751d3f..ee13910382c58 100644 --- a/examples/src/main/r/ml/glm.R +++ b/examples/src/main/r/ml/glm.R @@ -25,11 +25,12 @@ library(SparkR) sparkR.session(appName = "SparkR-ML-glm-example") # $example on$ -irisDF <- suppressWarnings(createDataFrame(iris)) +training <- read.df("data/mllib/sample_multiclass_classification_data.txt", source = "libsvm") # Fit a generalized linear model of family "gaussian" with spark.glm -gaussianDF <- irisDF -gaussianTestDF <- irisDF -gaussianGLM <- spark.glm(gaussianDF, Sepal_Length ~ Sepal_Width + Species, family = "gaussian") +df_list <- randomSplit(training, c(7,3), 2) +gaussianDF <- df_list[[1]] +gaussianTestDF <- df_list[[2]] +gaussianGLM <- spark.glm(gaussianDF, label ~ features, family = "gaussian") # Model summary summary(gaussianGLM) @@ -39,14 +40,15 @@ gaussianPredictions <- predict(gaussianGLM, gaussianTestDF) head(gaussianPredictions) # Fit a generalized linear model with glm (R-compliant) -gaussianGLM2 <- glm(Sepal_Length ~ Sepal_Width + Species, gaussianDF, family = "gaussian") +gaussianGLM2 <- glm(label ~ features, gaussianDF, family = "gaussian") summary(gaussianGLM2) # Fit a generalized linear model of family "binomial" with 
spark.glm -# Note: Filter out "setosa" from label column (two labels left) to match "binomial" family. -binomialDF <- filter(irisDF, irisDF$Species != "setosa") -binomialTestDF <- binomialDF -binomialGLM <- spark.glm(binomialDF, Species ~ Sepal_Length + Sepal_Width, family = "binomial") +training2 <- read.df("data/mllib/sample_binary_classification_data.txt", source = "libsvm") +df_list2 <- randomSplit(training2, c(7,3), 2) +binomialDF <- df_list2[[1]] +binomialTestDF <- df_list2[[2]] +binomialGLM <- spark.glm(binomialDF, label ~ features, family = "binomial") # Model summary summary(binomialGLM) diff --git a/examples/src/main/r/ml/kmeans.R b/examples/src/main/r/ml/kmeans.R index 288e2f9724e00..824df20644fa1 100644 --- a/examples/src/main/r/ml/kmeans.R +++ b/examples/src/main/r/ml/kmeans.R @@ -26,10 +26,12 @@ sparkR.session(appName = "SparkR-ML-kmeans-example") # $example on$ # Fit a k-means model with spark.kmeans -irisDF <- suppressWarnings(createDataFrame(iris)) -kmeansDF <- irisDF -kmeansTestDF <- irisDF -kmeansModel <- spark.kmeans(kmeansDF, ~ Sepal_Length + Sepal_Width + Petal_Length + Petal_Width, +t <- as.data.frame(Titanic) +training <- createDataFrame(t) +df_list <- randomSplit(training, c(7,3), 2) +kmeansDF <- df_list[[1]] +kmeansTestDF <- df_list[[2]] +kmeansModel <- spark.kmeans(kmeansDF, ~ Class + Sex + Age + Freq, k = 3) # Model summary diff --git a/examples/src/main/r/ml/ml.R b/examples/src/main/r/ml/ml.R index b96819418bad3..41b7867f64e36 100644 --- a/examples/src/main/r/ml/ml.R +++ b/examples/src/main/r/ml/ml.R @@ -26,11 +26,12 @@ sparkR.session(appName = "SparkR-ML-example") ############################ model read/write ############################################## # $example on:read_write$ -irisDF <- suppressWarnings(createDataFrame(iris)) +training <- read.df("data/mllib/sample_multiclass_classification_data.txt", source = "libsvm") # Fit a generalized linear model of family "gaussian" with spark.glm -gaussianDF <- irisDF -gaussianTestDF <- irisDF -gaussianGLM <- spark.glm(gaussianDF, Sepal_Length ~ Sepal_Width + Species, family = "gaussian") +df_list <- randomSplit(training, c(7,3), 2) +gaussianDF <- df_list[[1]] +gaussianTestDF <- df_list[[2]] +gaussianGLM <- spark.glm(gaussianDF, label ~ features, family = "gaussian") # Save and then load a fitted MLlib model modelPath <- tempfile(pattern = "ml", fileext = ".tmp") diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/ALSExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/ALSExample.scala index bb5d163608494..868f49b16f218 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/ALSExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/ALSExample.scala @@ -65,6 +65,8 @@ object ALSExample { val model = als.fit(training) // Evaluate the model by computing the RMSE on the test data + // Note we set cold start strategy to 'drop' to ensure we don't get NaN evaluation metrics + model.setColdStartStrategy("drop") val predictions = model.transform(test) val evaluator = new RegressionEvaluator() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/BucketedRandomProjectionLSHExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/BucketedRandomProjectionLSHExample.scala index 654535c264a35..16da4fa887aaf 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/BucketedRandomProjectionLSHExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/BucketedRandomProjectionLSHExample.scala @@ -28,7 +28,7 @@ import 
org.apache.spark.sql.SparkSession /** * An example demonstrating BucketedRandomProjectionLSH. * Run with: - * bin/run-example org.apache.spark.examples.ml.BucketedRandomProjectionLSHExample + * bin/run-example ml.BucketedRandomProjectionLSHExample */ object BucketedRandomProjectionLSHExample { def main(args: Array[String]): Unit = { diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/MinHashLSHExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/MinHashLSHExample.scala index 6c1e22268ad2c..b94ab9b8bedc1 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/MinHashLSHExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/MinHashLSHExample.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.SparkSession /** * An example demonstrating MinHashLSH. * Run with: - * bin/run-example org.apache.spark.examples.ml.MinHashLSHExample + * bin/run-example ml.MinHashLSHExample */ object MinHashLSHExample { def main(args: Array[String]): Unit = { diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala b/examples/src/main/scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala index 66f7cb1b53f48..381e69cda841c 100644 --- a/examples/src/main/scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala @@ -111,6 +111,10 @@ object SQLDataSourceExample { private def runJsonDatasetExample(spark: SparkSession): Unit = { // $example on:json_dataset$ + // Primitive types (Int, String, etc) and Product types (case classes) encoders are + // supported by importing this when creating a Dataset. + import spark.implicits._ + // A JSON dataset is pointed to by path. // The path can be either a single text file or a directory storing text files val path = "examples/src/main/resources/people.json" @@ -135,10 +139,10 @@ object SQLDataSourceExample { // +------+ // Alternatively, a DataFrame can be created for a JSON dataset represented by - // an RDD[String] storing one JSON object per string - val otherPeopleRDD = spark.sparkContext.makeRDD( + // an Dataset[String] storing one JSON object per string + val otherPeopleDataset = spark.createDataset( """{"name":"Yin","address":{"city":"Columbus","state":"Ohio"}}""" :: Nil) - val otherPeople = spark.read.json(otherPeopleRDD) + val otherPeople = spark.read.json(otherPeopleDataset) otherPeople.show() // +---------------+----+ // | address|name| diff --git a/external/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java b/external/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java index d1274a687fc70..626bde48e1a86 100644 --- a/external/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java +++ b/external/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java @@ -23,7 +23,6 @@ import java.util.List; import java.util.regex.Pattern; -import com.amazonaws.regions.RegionUtils; import org.apache.spark.SparkConf; import org.apache.spark.api.java.function.FlatMapFunction; import org.apache.spark.api.java.function.Function2; diff --git a/external/kinesis-asl/src/test/java/org/apache/spark/streaming/kinesis/JavaKinesisStreamSuite.java b/external/kinesis-asl/src/test/java/org/apache/spark/streaming/kinesis/JavaKinesisStreamSuite.java index 26b1fda2ff511..b37b087467926 100644 --- 
a/external/kinesis-asl/src/test/java/org/apache/spark/streaming/kinesis/JavaKinesisStreamSuite.java +++ b/external/kinesis-asl/src/test/java/org/apache/spark/streaming/kinesis/JavaKinesisStreamSuite.java @@ -17,7 +17,6 @@ package org.apache.spark.streaming.kinesis; -import com.amazonaws.regions.RegionUtils; import com.amazonaws.services.kinesis.model.Record; import org.junit.Test; @@ -91,7 +90,8 @@ public void testCustomHandlerAwsStsCreds() { JavaDStream kinesisStream = KinesisUtils.createStream(ssc, "testApp", "mySparkStream", "https://kinesis.us-west-2.amazonaws.com", "us-west-2", InitialPositionInStream.LATEST, new Duration(2000), StorageLevel.MEMORY_AND_DISK_2(), handler, String.class, - "fakeAccessKey", "fakeSecretKey", "fakeSTSRoleArn", "fakeSTSSessionName", "fakeSTSExternalId"); + "fakeAccessKey", "fakeSecretKey", "fakeSTSRoleArn", "fakeSTSSessionName", + "fakeSTSExternalId"); ssc.stop(); } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala index bf6e76d7ac44e..f76b14eeeb542 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala @@ -440,19 +440,14 @@ private class LinearSVCAggregator( private val numFeatures: Int = bcFeaturesStd.value.length private val numFeaturesPlusIntercept: Int = if (fitIntercept) numFeatures + 1 else numFeatures - private val coefficients: Vector = bcCoefficients.value private var weightSum: Double = 0.0 private var lossSum: Double = 0.0 - require(numFeaturesPlusIntercept == coefficients.size, s"Dimension mismatch. Coefficients " + - s"length ${coefficients.size}, FeaturesStd length ${numFeatures}, fitIntercept: $fitIntercept") - - private val coefficientsArray = coefficients match { - case dv: DenseVector => dv.values - case _ => - throw new IllegalArgumentException( - s"coefficients only supports dense vector but got type ${coefficients.getClass}.") + @transient private lazy val coefficientsArray = bcCoefficients.value match { + case DenseVector(values) => values + case _ => throw new IllegalArgumentException(s"coefficients only supports dense vector" + + s" but got type ${bcCoefficients.value.getClass}.") } - private val gradientSumArray = Array.fill[Double](coefficientsArray.length)(0) + private lazy val gradientSumArray = new Array[Double](numFeaturesPlusIntercept) /** * Add a new training instance to this LinearSVCAggregator, and update the loss and gradient @@ -463,6 +458,9 @@ private class LinearSVCAggregator( */ def add(instance: Instance): this.type = { instance match { case Instance(label, weight, features) => + require(weight >= 0.0, s"instance weight, $weight has to be >= 0.0") + require(numFeatures == features.size, s"Dimensions mismatch when adding new instance." 
+ + s" Expecting $numFeatures but got ${features.size}.") if (weight == 0.0) return this val localFeaturesStd = bcFeaturesStd.value val localCoefficients = coefficientsArray @@ -530,18 +528,15 @@ private class LinearSVCAggregator( this } - def loss: Double = { - if (weightSum != 0) { - lossSum / weightSum - } else 0.0 - } + def loss: Double = if (weightSum != 0) lossSum / weightSum else 0.0 def gradient: Vector = { if (weightSum != 0) { val result = Vectors.dense(gradientSumArray.clone()) scal(1.0 / weightSum, result) result - } else Vectors.dense(Array.fill[Double](coefficientsArray.length)(0)) + } else { + Vectors.dense(new Array[Double](numFeaturesPlusIntercept)) + } } - } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 892e00fa6041a..1a78187d4f8e3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -1431,7 +1431,12 @@ private class LogisticAggregator( private var weightSum = 0.0 private var lossSum = 0.0 - private val gradientSumArray = Array.fill[Double](coefficientSize)(0.0D) + @transient private lazy val coefficientsArray: Array[Double] = bcCoefficients.value match { + case DenseVector(values) => values + case _ => throw new IllegalArgumentException(s"coefficients only supports dense vector but " + + s"got type ${bcCoefficients.value.getClass}.)") + } + private lazy val gradientSumArray = new Array[Double](coefficientSize) if (multinomial && numClasses <= 2) { logInfo(s"Multinomial logistic regression for binary classification yields separate " + @@ -1447,7 +1452,7 @@ private class LogisticAggregator( label: Double): Unit = { val localFeaturesStd = bcFeaturesStd.value - val localCoefficients = bcCoefficients.value + val localCoefficients = coefficientsArray val localGradientArray = gradientSumArray val margin = - { var sum = 0.0 @@ -1491,7 +1496,7 @@ private class LogisticAggregator( logistic regression without pivoting. 
*/ val localFeaturesStd = bcFeaturesStd.value - val localCoefficients = bcCoefficients.value + val localCoefficients = coefficientsArray val localGradientArray = gradientSumArray // marginOfLabel is margins(label) in the formula diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala index ea2dc6cfd8d31..a9c1a7ba0bc8a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala @@ -580,10 +580,10 @@ private class ExpectationAggregator( private val k: Int = bcWeights.value.length private var totalCnt: Long = 0L private var newLogLikelihood: Double = 0.0 - private val newWeights: Array[Double] = new Array[Double](k) - private val newMeans: Array[DenseVector] = Array.fill(k)( + private lazy val newWeights: Array[Double] = new Array[Double](k) + private lazy val newMeans: Array[DenseVector] = Array.fill(k)( new DenseVector(Array.fill[Double](numFeatures)(0.0))) - private val newCovs: Array[DenseVector] = Array.fill(k)( + private lazy val newCovs: Array[DenseVector] = Array.fill(k)( new DenseVector(Array.fill[Double](numFeatures * (numFeatures + 1) / 2)(0.0))) @transient private lazy val oldGaussians = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala new file mode 100644 index 0000000000000..417968d9b817d --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala @@ -0,0 +1,339 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.fpm + +import scala.collection.mutable.ArrayBuffer +import scala.reflect.ClassTag + +import org.apache.hadoop.fs.Path + +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasPredictionCol} +import org.apache.spark.ml.util._ +import org.apache.spark.mllib.fpm.{AssociationRules => MLlibAssociationRules, + FPGrowth => MLlibFPGrowth} +import org.apache.spark.mllib.fpm.FPGrowth.FreqItemset +import org.apache.spark.sql._ +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types._ + +/** + * Common params for FPGrowth and FPGrowthModel + */ +private[fpm] trait FPGrowthParams extends Params with HasFeaturesCol with HasPredictionCol { + + /** + * Minimal support level of the frequent pattern. [0.0, 1.0]. 
Any pattern that appears + * more than (minSupport * size-of-the-dataset) times will be output + * Default: 0.3 + * @group param + */ + @Since("2.2.0") + val minSupport: DoubleParam = new DoubleParam(this, "minSupport", + "the minimal support level of a frequent pattern", + ParamValidators.inRange(0.0, 1.0)) + setDefault(minSupport -> 0.3) + + /** @group getParam */ + @Since("2.2.0") + def getMinSupport: Double = $(minSupport) + + /** + * Number of partitions (>=1) used by parallel FP-growth. By default the param is not set, and + * partition number of the input dataset is used. + * @group expertParam + */ + @Since("2.2.0") + val numPartitions: IntParam = new IntParam(this, "numPartitions", + "Number of partitions used by parallel FP-growth", ParamValidators.gtEq[Int](1)) + + /** @group expertGetParam */ + @Since("2.2.0") + def getNumPartitions: Int = $(numPartitions) + + /** + * Minimal confidence for generating Association Rule. + * Note that minConfidence has no effect during fitting. + * Default: 0.8 + * @group param + */ + @Since("2.2.0") + val minConfidence: DoubleParam = new DoubleParam(this, "minConfidence", + "minimal confidence for generating Association Rule", + ParamValidators.inRange(0.0, 1.0)) + setDefault(minConfidence -> 0.8) + + /** @group getParam */ + @Since("2.2.0") + def getMinConfidence: Double = $(minConfidence) + + /** + * Validates and transforms the input schema. + * @param schema input schema + * @return output schema + */ + @Since("2.2.0") + protected def validateAndTransformSchema(schema: StructType): StructType = { + val inputType = schema($(featuresCol)).dataType + require(inputType.isInstanceOf[ArrayType], + s"The input column must be ArrayType, but got $inputType.") + SchemaUtils.appendColumn(schema, $(predictionCol), schema($(featuresCol)).dataType) + } +} + +/** + * :: Experimental :: + * A parallel FP-growth algorithm to mine frequent itemsets. The algorithm is described in + * Li et al., PFP: Parallel FP-Growth for Query + * Recommendation. PFP distributes computation in such a way that each worker executes an + * independent group of mining tasks. The FP-Growth algorithm is described in + * Han et al., Mining frequent patterns without + * candidate generation. Note null values in the feature column are ignored during fit(). 
+ * + * @see + * Association rule learning (Wikipedia) + */ +@Since("2.2.0") +@Experimental +class FPGrowth @Since("2.2.0") ( + @Since("2.2.0") override val uid: String) + extends Estimator[FPGrowthModel] with FPGrowthParams with DefaultParamsWritable { + + @Since("2.2.0") + def this() = this(Identifiable.randomUID("fpgrowth")) + + /** @group setParam */ + @Since("2.2.0") + def setMinSupport(value: Double): this.type = set(minSupport, value) + + /** @group expertSetParam */ + @Since("2.2.0") + def setNumPartitions(value: Int): this.type = set(numPartitions, value) + + /** @group setParam */ + @Since("2.2.0") + def setMinConfidence(value: Double): this.type = set(minConfidence, value) + + /** @group setParam */ + @Since("2.2.0") + def setFeaturesCol(value: String): this.type = set(featuresCol, value) + + /** @group setParam */ + @Since("2.2.0") + def setPredictionCol(value: String): this.type = set(predictionCol, value) + + @Since("2.2.0") + override def fit(dataset: Dataset[_]): FPGrowthModel = { + transformSchema(dataset.schema, logging = true) + genericFit(dataset) + } + + private def genericFit[T: ClassTag](dataset: Dataset[_]): FPGrowthModel = { + val data = dataset.select($(featuresCol)) + val items = data.where(col($(featuresCol)).isNotNull).rdd.map(r => r.getSeq[T](0).toArray) + val mllibFP = new MLlibFPGrowth().setMinSupport($(minSupport)) + if (isSet(numPartitions)) { + mllibFP.setNumPartitions($(numPartitions)) + } + val parentModel = mllibFP.run(items) + val rows = parentModel.freqItemsets.map(f => Row(f.items, f.freq)) + + val schema = StructType(Seq( + StructField("items", dataset.schema($(featuresCol)).dataType, nullable = false), + StructField("freq", LongType, nullable = false))) + val frequentItems = dataset.sparkSession.createDataFrame(rows, schema) + copyValues(new FPGrowthModel(uid, frequentItems)).setParent(this) + } + + @Since("2.2.0") + override def transformSchema(schema: StructType): StructType = { + validateAndTransformSchema(schema) + } + + @Since("2.2.0") + override def copy(extra: ParamMap): FPGrowth = defaultCopy(extra) +} + + +@Since("2.2.0") +object FPGrowth extends DefaultParamsReadable[FPGrowth] { + + @Since("2.2.0") + override def load(path: String): FPGrowth = super.load(path) +} + +/** + * :: Experimental :: + * Model fitted by FPGrowth. + * + * @param freqItemsets frequent items in the format of DataFrame("items"[Seq], "freq"[Long]) + */ +@Since("2.2.0") +@Experimental +class FPGrowthModel private[ml] ( + @Since("2.2.0") override val uid: String, + @transient val freqItemsets: DataFrame) + extends Model[FPGrowthModel] with FPGrowthParams with MLWritable { + + /** @group setParam */ + @Since("2.2.0") + def setMinConfidence(value: Double): this.type = set(minConfidence, value) + + /** @group setParam */ + @Since("2.2.0") + def setFeaturesCol(value: String): this.type = set(featuresCol, value) + + /** @group setParam */ + @Since("2.2.0") + def setPredictionCol(value: String): this.type = set(predictionCol, value) + + /** + * Get association rules fitted by AssociationRules using the minConfidence. Returns a dataframe + * with three fields, "antecedent", "consequent" and "confidence", where "antecedent" and + * "consequent" are Array[T] and "confidence" is Double. + */ + @Since("2.2.0") + @transient lazy val associationRules: DataFrame = { + AssociationRules.getAssociationRulesFromFP(freqItemsets, "items", "freq", $(minConfidence)) + } + + /** + * The transform method first generates the association rules according to the frequent itemsets. 
+   * Then for each association rule, it will examine the input items against antecedents and
+   * summarize the consequents as prediction. The prediction column has the same data type as the
+   * input column (Array[T]) and will not contain items already present in the input column. Null
+   * values in the feature column are treated as empty sets.
+   * WARNING: internally it collects the association rules to the driver and broadcasts them for
+   * efficiency. This may put pressure on driver memory for a large set of association rules.
+   */
+  @Since("2.2.0")
+  override def transform(dataset: Dataset[_]): DataFrame = {
+    transformSchema(dataset.schema, logging = true)
+    genericTransform(dataset)
+  }
+
+  private def genericTransform(dataset: Dataset[_]): DataFrame = {
+    val rules: Array[(Seq[Any], Seq[Any])] = associationRules.select("antecedent", "consequent")
+      .rdd.map(r => (r.getSeq(0), r.getSeq(1)))
+      .collect().asInstanceOf[Array[(Seq[Any], Seq[Any])]]
+    val brRules = dataset.sparkSession.sparkContext.broadcast(rules)
+
+    val dt = dataset.schema($(featuresCol)).dataType
+    // For each rule, examine the input items and summarize the consequents
+    val predictUDF = udf((items: Seq[_]) => {
+      if (items != null) {
+        val itemset = items.toSet
+        brRules.value.flatMap(rule =>
+          if (rule._1.forall(item => itemset.contains(item))) {
+            rule._2.filter(item => !itemset.contains(item))
+          } else {
+            Seq.empty
+          }).distinct
+      } else {
+        Seq.empty
+      }
+    }, dt)
+    dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
+  }
+
+  @Since("2.2.0")
+  override def transformSchema(schema: StructType): StructType = {
+    validateAndTransformSchema(schema)
+  }
+
+  @Since("2.2.0")
+  override def copy(extra: ParamMap): FPGrowthModel = {
+    val copied = new FPGrowthModel(uid, freqItemsets)
+    copyValues(copied, extra).setParent(this.parent)
+  }
+
+  @Since("2.2.0")
+  override def write: MLWriter = new FPGrowthModel.FPGrowthModelWriter(this)
+}
+
+@Since("2.2.0")
+object FPGrowthModel extends MLReadable[FPGrowthModel] {
+
+  @Since("2.2.0")
+  override def read: MLReader[FPGrowthModel] = new FPGrowthModelReader
+
+  @Since("2.2.0")
+  override def load(path: String): FPGrowthModel = super.load(path)
+
+  /** [[MLWriter]] instance for [[FPGrowthModel]] */
+  private[FPGrowthModel]
+  class FPGrowthModelWriter(instance: FPGrowthModel) extends MLWriter {
+
+    override protected def saveImpl(path: String): Unit = {
+      DefaultParamsWriter.saveMetadata(instance, path, sc)
+      val dataPath = new Path(path, "data").toString
+      instance.freqItemsets.write.parquet(dataPath)
+    }
+  }
+
+  private class FPGrowthModelReader extends MLReader[FPGrowthModel] {
+
+    /** Checked against metadata when loading model */
+    private val className = classOf[FPGrowthModel].getName
+
+    override def load(path: String): FPGrowthModel = {
+      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+      val dataPath = new Path(path, "data").toString
+      val frequentItems = sparkSession.read.parquet(dataPath)
+      val model = new FPGrowthModel(metadata.uid, frequentItems)
+      DefaultParamsReader.getAndSetParams(model, metadata)
+      model
+    }
+  }
+}
+
+private[fpm] object AssociationRules {
+
+  /**
+   * Computes the association rules with confidence above minConfidence.
+   * @param dataset DataFrame("items", "freq") containing frequent itemsets obtained from
+   *                algorithms like [[FPGrowth]].
+ * @param itemsCol column name for frequent itemsets + * @param freqCol column name for frequent itemsets count + * @param minConfidence minimum confidence for the result association rules + * @return a DataFrame("antecedent", "consequent", "confidence") containing the association + * rules. + */ + def getAssociationRulesFromFP[T: ClassTag]( + dataset: Dataset[_], + itemsCol: String, + freqCol: String, + minConfidence: Double): DataFrame = { + + val freqItemSetRdd = dataset.select(itemsCol, freqCol).rdd + .map(row => new FreqItemset(row.getSeq[T](0).toArray, row.getLong(1))) + val rows = new MLlibAssociationRules() + .setMinConfidence(minConfidence) + .run(freqItemSetRdd) + .map(r => Row(r.antecedent, r.consequent, r.confidence)) + + val dt = dataset.schema(itemsCol).dataType + val schema = StructType(Seq( + StructField("antecedent", dt, nullable = false), + StructField("consequent", dt, nullable = false), + StructField("confidence", DoubleType, nullable = false))) + val rules = dataset.sparkSession.createDataFrame(rows, schema) + rules + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index 97c8655298609..799e881fad74a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -80,16 +80,47 @@ private[recommendation] trait ALSModelParams extends Params with HasPredictionCo /** * Attempts to safely cast a user/item id to an Int. Throws an exception if the value is - * out of integer range. + * out of integer range or contains a fractional part. */ - protected val checkedCast = udf { (n: Double) => - if (n > Int.MaxValue || n < Int.MinValue) { - throw new IllegalArgumentException(s"ALS only supports values in Integer range for columns " + - s"${$(userCol)} and ${$(itemCol)}. Value $n was out of Integer range.") - } else { - n.toInt + protected[recommendation] val checkedCast = udf { (n: Any) => + n match { + case v: Int => v // Avoid unnecessary casting + case v: Number => + val intV = v.intValue + // Checks if number within Int range and has no fractional part. + if (v.doubleValue == intV) { + intV + } else { + throw new IllegalArgumentException(s"ALS only supports values in Integer range " + + s"and without fractional part for columns ${$(userCol)} and ${$(itemCol)}. " + + s"Value $n was either out of Integer range or contained a fractional part that " + + s"could not be converted.") + } + case _ => throw new IllegalArgumentException(s"ALS only supports values in Integer range " + + s"for columns ${$(userCol)} and ${$(itemCol)}. Value $n was not numeric.") } } + + /** + * Param for strategy for dealing with unknown or new users/items at prediction time. + * This may be useful in cross-validation or production scenarios, for handling user/item ids + * the model has not seen in the training data. + * Supported values: + * - "nan": predicted value for unknown ids will be NaN. + * - "drop": rows in the input DataFrame containing unknown ids will be dropped from + * the output DataFrame containing predictions. + * Default: "nan". + * @group expertParam + */ + val coldStartStrategy = new Param[String](this, "coldStartStrategy", + "strategy for dealing with unknown or new users/items at prediction time. This may be " + + "useful in cross-validation or production scenarios, for handling user/item ids the model " + + "has not seen in the training data. 
Supported values: " + + s"${ALSModel.supportedColdStartStrategies.mkString(",")}.", + (s: String) => ALSModel.supportedColdStartStrategies.contains(s.toLowerCase)) + + /** @group expertGetParam */ + def getColdStartStrategy: String = $(coldStartStrategy).toLowerCase } /** @@ -203,7 +234,8 @@ private[recommendation] trait ALSParams extends ALSModelParams with HasMaxIter w setDefault(rank -> 10, maxIter -> 10, regParam -> 0.1, numUserBlocks -> 10, numItemBlocks -> 10, implicitPrefs -> false, alpha -> 1.0, userCol -> "user", itemCol -> "item", ratingCol -> "rating", nonnegative -> false, checkpointInterval -> 10, - intermediateStorageLevel -> "MEMORY_AND_DISK", finalStorageLevel -> "MEMORY_AND_DISK") + intermediateStorageLevel -> "MEMORY_AND_DISK", finalStorageLevel -> "MEMORY_AND_DISK", + coldStartStrategy -> "nan") /** * Validates and transforms the input schema. @@ -248,6 +280,10 @@ class ALSModel private[ml] ( @Since("1.3.0") def setPredictionCol(value: String): this.type = set(predictionCol, value) + /** @group expertSetParam */ + @Since("2.2.0") + def setColdStartStrategy(value: String): this.type = set(coldStartStrategy, value) + @Since("2.0.0") override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema) @@ -260,13 +296,19 @@ class ALSModel private[ml] ( Float.NaN } } - dataset + val predictions = dataset .join(userFactors, - checkedCast(dataset($(userCol)).cast(DoubleType)) === userFactors("id"), "left") + checkedCast(dataset($(userCol))) === userFactors("id"), "left") .join(itemFactors, - checkedCast(dataset($(itemCol)).cast(DoubleType)) === itemFactors("id"), "left") + checkedCast(dataset($(itemCol))) === itemFactors("id"), "left") .select(dataset("*"), predict(userFactors("features"), itemFactors("features")).as($(predictionCol))) + getColdStartStrategy match { + case ALSModel.Drop => + predictions.na.drop("all", Seq($(predictionCol))) + case ALSModel.NaN => + predictions + } } @Since("1.3.0") @@ -290,6 +332,10 @@ class ALSModel private[ml] ( @Since("1.6.0") object ALSModel extends MLReadable[ALSModel] { + private val NaN = "nan" + private val Drop = "drop" + private[recommendation] final val supportedColdStartStrategies = Array(NaN, Drop) + @Since("1.6.0") override def read: MLReader[ALSModel] = new ALSModelReader @@ -432,6 +478,10 @@ class ALS(@Since("1.4.0") override val uid: String) extends Estimator[ALSModel] @Since("2.0.0") def setFinalStorageLevel(value: String): this.type = set(finalStorageLevel, value) + /** @group expertSetParam */ + @Since("2.2.0") + def setColdStartStrategy(value: String): this.type = set(coldStartStrategy, value) + /** * Sets both numUserBlocks and numItemBlocks to the specific value. 
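In practice the new param matters most when evaluating on held-out data: with the default "nan" strategy a single unseen user or item turns a metric such as RMSE into NaN, while "drop" removes those rows before the metric is computed. A minimal sketch, not part of the patch, assuming a spark-shell style session where `spark.implicits._` is available and the default user/item/rating column names set above are used:

```scala
import org.apache.spark.ml.recommendation.ALS
import spark.implicits._   // assumes an active SparkSession named `spark`

val training = Seq((0, 0, 4.0f), (0, 1, 2.0f), (1, 1, 3.0f), (1, 2, 5.0f))
  .toDF("user", "item", "rating")
val test = Seq((0, 2), (2, 0)).toDF("user", "item")   // user 2 never appears in training

val model = new ALS()
  .setRank(4)
  .setMaxIter(5)
  .setColdStartStrategy("drop")   // unseen ids are filtered out instead of predicting NaN
  .fit(training)

model.transform(test).show()   // only the (0, 2) row survives; no NaN reaches an evaluator
```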
* @@ -451,8 +501,7 @@ class ALS(@Since("1.4.0") override val uid: String) extends Estimator[ALSModel] val r = if ($(ratingCol) != "") col($(ratingCol)).cast(FloatType) else lit(1.0f) val ratings = dataset - .select(checkedCast(col($(userCol)).cast(DoubleType)), - checkedCast(col($(itemCol)).cast(DoubleType)), r) + .select(checkedCast(col($(userCol))), checkedCast(col($(itemCol))), r) .rdd .map { row => Rating(row.getInt(0), row.getInt(1), row.getFloat(2)) @@ -671,7 +720,7 @@ object ALS extends DefaultParamsReadable[ALS] with Logging { numUserBlocks: Int = 10, numItemBlocks: Int = 10, maxIter: Int = 10, - regParam: Double = 1.0, + regParam: Double = 0.1, implicitPrefs: Boolean = false, alpha: Double = 1.0, nonnegative: Boolean = false, diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala index 2f78dd30b3af7..094853b6f4802 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala @@ -106,7 +106,7 @@ private[regression] trait AFTSurvivalRegressionParams extends Params fitting: Boolean): StructType = { SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT) if (fitting) { - SchemaUtils.checkColumnType(schema, $(censorCol), DoubleType) + SchemaUtils.checkNumericType(schema, $(censorCol)) SchemaUtils.checkNumericType(schema, $(labelCol)) } if (hasQuantilesCol) { @@ -200,8 +200,8 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S * and put it in an RDD with strong types. */ protected[ml] def extractAFTPoints(dataset: Dataset[_]): RDD[AFTPoint] = { - dataset.select(col($(featuresCol)), col($(labelCol)).cast(DoubleType), col($(censorCol))) - .rdd.map { + dataset.select(col($(featuresCol)), col($(labelCol)).cast(DoubleType), + col($(censorCol)).cast(DoubleType)).rdd.map { case Row(features: Vector, label: Double, censor: Double) => AFTPoint(features, label, censor) } @@ -526,7 +526,7 @@ private class AFTAggregator( private var totalCnt: Long = 0L private var lossSum = 0.0 // Here we optimize loss function over log(sigma), intercept and coefficients - private val gradientSumArray = Array.ofDim[Double](length) + private lazy val gradientSumArray = Array.ofDim[Double](length) def count: Long = totalCnt def loss: Double = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala index fdeadaf274971..110764dc074f7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala @@ -109,6 +109,8 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam * Param for the index in the power link function. Only applicable for the Tweedie family. * Note that link power 0, 1, -1 or 0.5 corresponds to the Log, Identity, Inverse or Sqrt * link, respectively. + * When not set, this value defaults to 1 - [[variancePower]], which matches the R "statmod" + * package. 
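The sentence added to the linkPower doc only documents existing behaviour; as a hedged illustration of what it means for callers (the tweedie family setters below are assumed to be the ones shipping with this release, not something introduced by this patch):

```scala
import org.apache.spark.ml.regression.GeneralizedLinearRegression

// With family "tweedie" and no explicit linkPower, the effective link power is
// 1 - variancePower (here 1 - 1.5 = -0.5), matching R's statmod default.
val glrDefaultLink = new GeneralizedLinearRegression()
  .setFamily("tweedie")
  .setVariancePower(1.5)

// Setting linkPower explicitly overrides that default, e.g. 0.0 for the log link.
val glrLogLink = new GeneralizedLinearRegression()
  .setFamily("tweedie")
  .setVariancePower(1.5)
  .setLinkPower(0.0)
```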
* * @group param */ diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala index a6c29433d7303..529f66eadbcff 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala @@ -49,7 +49,7 @@ private[regression] trait IsotonicRegressionBase extends Params with HasFeatures */ final val isotonic: BooleanParam = new BooleanParam(this, "isotonic", - "whether the output sequence should be isotonic/increasing (true) or" + + "whether the output sequence should be isotonic/increasing (true) or " + "antitonic/decreasing (false)") /** @group getParam */ diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 2de7e81d8d41e..45df1d9be647d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -959,7 +959,7 @@ private class LeastSquaresAggregator( @transient private lazy val effectiveCoefficientsVector = effectiveCoefAndOffset._1 @transient private lazy val offset = effectiveCoefAndOffset._2 - private val gradientSumArray = Array.ofDim[Double](dim) + private lazy val gradientSumArray = Array.ofDim[Double](dim) /** * Add a new training instance to this LeastSquaresAggregator, and update the loss and gradient diff --git a/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java b/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java index 0f71deb9ea528..d2fe6bb2ca718 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java @@ -33,7 +33,8 @@ public class JavaDecisionTreeSuite extends SharedSparkSession { - private static int validatePrediction(List validationData, DecisionTreeModel model) { + private static int validatePrediction( + List validationData, DecisionTreeModel model) { int numCorrect = 0; for (LabeledPoint point : validationData) { Double prediction = model.predict(point.features()); diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala index ee2aefee7a6db..a165d8a9345cf 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala @@ -23,7 +23,7 @@ import breeze.linalg.{DenseVector => BDV} import org.apache.spark.SparkFunSuite import org.apache.spark.ml.classification.LinearSVCSuite._ -import org.apache.spark.ml.feature.LabeledPoint +import org.apache.spark.ml.feature.{Instance, LabeledPoint} import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} @@ -123,6 +123,21 @@ class LinearSVCSuite extends SparkFunSuite with MLlibTestSparkContext with Defau assert(model2.intercept !== 0.0) } + test("sparse coefficients in SVCAggregator") { + val bcCoefficients = spark.sparkContext.broadcast(Vectors.sparse(2, Array(0), Array(1.0))) + val bcFeaturesStd = spark.sparkContext.broadcast(Array(1.0)) + val agg = new LinearSVCAggregator(bcCoefficients, bcFeaturesStd, true) + val thrown 
= withClue("LinearSVCAggregator cannot handle sparse coefficients") { + intercept[IllegalArgumentException] { + agg.add(Instance(1.0, 1.0, Vectors.dense(1.0))) + } + } + assert(thrown.getMessage.contains("coefficients only supports dense")) + + bcCoefficients.destroy(blocking = false) + bcFeaturesStd.destroy(blocking = false) + } + test("linearSVC with sample weights") { def modelEquals(m1: LinearSVCModel, m2: LinearSVCModel): Unit = { assert(m1.coefficients ~== m2.coefficients absTol 0.05) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index 43547a4aafcb9..d89a958eed45a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -456,6 +456,32 @@ class LogisticRegressionSuite assert(blrModel.intercept !== 0.0) } + test("sparse coefficients in LogisticAggregator") { + val bcCoefficientsBinary = spark.sparkContext.broadcast(Vectors.sparse(2, Array(0), Array(1.0))) + val bcFeaturesStd = spark.sparkContext.broadcast(Array(1.0)) + val binaryAgg = new LogisticAggregator(bcCoefficientsBinary, bcFeaturesStd, 2, + fitIntercept = true, multinomial = false) + val thrownBinary = withClue("binary logistic aggregator cannot handle sparse coefficients") { + intercept[IllegalArgumentException] { + binaryAgg.add(Instance(1.0, 1.0, Vectors.dense(1.0))) + } + } + assert(thrownBinary.getMessage.contains("coefficients only supports dense")) + + val bcCoefficientsMulti = spark.sparkContext.broadcast(Vectors.sparse(6, Array(0), Array(1.0))) + val multinomialAgg = new LogisticAggregator(bcCoefficientsMulti, bcFeaturesStd, 3, + fitIntercept = true, multinomial = true) + val thrown = withClue("multinomial logistic aggregator cannot handle sparse coefficients") { + intercept[IllegalArgumentException] { + multinomialAgg.add(Instance(1.0, 1.0, Vectors.dense(1.0))) + } + } + assert(thrown.getMessage.contains("coefficients only supports dense")) + bcCoefficientsBinary.destroy(blocking = false) + bcFeaturesStd.destroy(blocking = false) + bcCoefficientsMulti.destroy(blocking = false) + } + test("overflow prediction for multiclass") { val model = new LogisticRegressionModel("mLogReg", Matrices.dense(3, 2, Array(0.0, 0.0, 0.0, 1.0, 2.0, 3.0)), diff --git a/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala new file mode 100644 index 0000000000000..74c7461401905 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala @@ -0,0 +1,130 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.ml.fpm + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession} +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types._ + +class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + + @transient var dataset: Dataset[_] = _ + + override def beforeAll(): Unit = { + super.beforeAll() + dataset = FPGrowthSuite.getFPGrowthData(spark) + } + + test("FPGrowth fit and transform with different data types") { + Array(IntegerType, StringType, ShortType, LongType, ByteType).foreach { dt => + val data = dataset.withColumn("features", col("features").cast(ArrayType(dt))) + val model = new FPGrowth().setMinSupport(0.5).fit(data) + val generatedRules = model.setMinConfidence(0.5).associationRules + val expectedRules = spark.createDataFrame(Seq( + (Array("2"), Array("1"), 1.0), + (Array("1"), Array("2"), 0.75) + )).toDF("antecedent", "consequent", "confidence") + .withColumn("antecedent", col("antecedent").cast(ArrayType(dt))) + .withColumn("consequent", col("consequent").cast(ArrayType(dt))) + assert(expectedRules.sort("antecedent").rdd.collect().sameElements( + generatedRules.sort("antecedent").rdd.collect())) + + val transformed = model.transform(data) + val expectedTransformed = spark.createDataFrame(Seq( + (0, Array("1", "2"), Array.emptyIntArray), + (0, Array("1", "2"), Array.emptyIntArray), + (0, Array("1", "2"), Array.emptyIntArray), + (0, Array("1", "3"), Array(2)) + )).toDF("id", "features", "prediction") + .withColumn("features", col("features").cast(ArrayType(dt))) + .withColumn("prediction", col("prediction").cast(ArrayType(dt))) + assert(expectedTransformed.collect().toSet.equals( + transformed.collect().toSet)) + } + } + + test("FPGrowth getFreqItems") { + val model = new FPGrowth().setMinSupport(0.7).fit(dataset) + val expectedFreq = spark.createDataFrame(Seq( + (Array("1"), 4L), + (Array("2"), 3L), + (Array("1", "2"), 3L), + (Array("2", "1"), 3L) // duplicate as the items sequence is not guaranteed + )).toDF("items", "expectedFreq") + val freqItems = model.freqItemsets + + val checkDF = freqItems.join(expectedFreq, "items") + assert(checkDF.count() == 3 && checkDF.filter(col("freq") === col("expectedFreq")).count() == 3) + } + + test("FPGrowth getFreqItems with Null") { + val df = spark.createDataFrame(Seq( + (1, Array("1", "2", "3", "5")), + (2, Array("1", "2", "3", "4")), + (3, null.asInstanceOf[Array[String]]) + )).toDF("id", "features") + val model = new FPGrowth().setMinSupport(0.7).fit(dataset) + val prediction = model.transform(df) + assert(prediction.select("prediction").where("id=3").first().getSeq[String](0).isEmpty) + } + + test("FPGrowth parameter check") { + val fpGrowth = new FPGrowth().setMinSupport(0.4567) + val model = fpGrowth.fit(dataset) + .setMinConfidence(0.5678) + assert(fpGrowth.getMinSupport === 0.4567) + assert(model.getMinConfidence === 0.5678) + } + + test("read/write") { + def checkModelData(model: FPGrowthModel, model2: FPGrowthModel): Unit = { + assert(model.freqItemsets.sort("items").collect() === + model2.freqItemsets.sort("items").collect()) + } + val fPGrowth = new FPGrowth() + testEstimatorAndModelReadWrite( + fPGrowth, dataset, FPGrowthSuite.allParamSettings, checkModelData) + } + +} + +object FPGrowthSuite { + + def getFPGrowthData(spark: SparkSession): DataFrame = { + spark.createDataFrame(Seq( + 
(0, Array("1", "2")), + (0, Array("1", "2")), + (0, Array("1", "2")), + (0, Array("1", "3")) + )).toDF("id", "features") + } + + /** + * Mapping from all Params to valid settings which differ from the defaults. + * This is useful for tests which need to exercise all Params, such as save/load. + * This excludes input columns to simplify some tests. + */ + val allParamSettings: Map[String, Any] = Map( + "minSupport" -> 0.321, + "minConfidence" -> 0.456, + "numPartitions" -> 5, + "predictionCol" -> "myPrediction" + ) +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala index b923bacce23ca..c8228dd004374 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala @@ -40,7 +40,8 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD import org.apache.spark.scheduler.{SparkListener, SparkListenerStageCompleted} import org.apache.spark.sql.{DataFrame, Row, SparkSession} -import org.apache.spark.sql.types.{FloatType, IntegerType} +import org.apache.spark.sql.functions.lit +import org.apache.spark.sql.types._ import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils @@ -205,6 +206,70 @@ class ALSSuite assert(decompressed.toSet === expected) } + test("CheckedCast") { + val checkedCast = new ALS().checkedCast + val df = spark.range(1) + + withClue("Valid Integer Ids") { + df.select(checkedCast(lit(123))).collect() + } + + withClue("Valid Long Ids") { + df.select(checkedCast(lit(1231L))).collect() + } + + withClue("Valid Decimal Ids") { + df.select(checkedCast(lit(123).cast(DecimalType(15, 2)))).collect() + } + + withClue("Valid Double Ids") { + df.select(checkedCast(lit(123.0))).collect() + } + + val msg = "either out of Integer range or contained a fractional part" + withClue("Invalid Long: out of range") { + val e: SparkException = intercept[SparkException] { + df.select(checkedCast(lit(1231000000000L))).collect() + } + assert(e.getMessage.contains(msg)) + } + + withClue("Invalid Decimal: out of range") { + val e: SparkException = intercept[SparkException] { + df.select(checkedCast(lit(1231000000000.0).cast(DecimalType(15, 2)))).collect() + } + assert(e.getMessage.contains(msg)) + } + + withClue("Invalid Decimal: fractional part") { + val e: SparkException = intercept[SparkException] { + df.select(checkedCast(lit(123.1).cast(DecimalType(15, 2)))).collect() + } + assert(e.getMessage.contains(msg)) + } + + withClue("Invalid Double: out of range") { + val e: SparkException = intercept[SparkException] { + df.select(checkedCast(lit(1231000000000.0))).collect() + } + assert(e.getMessage.contains(msg)) + } + + withClue("Invalid Double: fractional part") { + val e: SparkException = intercept[SparkException] { + df.select(checkedCast(lit(123.1))).collect() + } + assert(e.getMessage.contains(msg)) + } + + withClue("Invalid Type") { + val e: SparkException = intercept[SparkException] { + df.select(checkedCast(lit("123.1"))).collect() + } + assert(e.getMessage.contains("was not numeric")) + } + } + /** * Generates an explicit feedback dataset for testing ALS. 
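Beyond the unit test above, the reworked cast changes what callers can pass as id columns: any numeric type is accepted as long as the values are integral and fit in Int range, while fractional values that the old code silently truncated now fail fast with the message the suite asserts on. A hypothetical spark-shell sketch, not part of the patch:

```scala
import org.apache.spark.ml.recommendation.ALS
import spark.implicits._   // assumes an active SparkSession named `spark`

// Long (or Decimal/Double) ids are fine as long as they are integral and fit in Int.
val ratings = Seq((123L, 7L, 5.0f), (124L, 8L, 3.0f), (123L, 8L, 4.0f))
  .toDF("user", "item", "rating")
val model = new ALS().setRank(2).setMaxIter(1).fit(ratings)

// A fractional id such as 123.5, or one beyond Int range, now fails fast with
// "... either out of Integer range or contained a fractional part ..." instead of
// being silently truncated to an Int as before this change.
```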
* @param numUsers number of users @@ -498,8 +563,8 @@ class ALSSuite (ex, act) => ex.userFactors.first().getSeq[Float](1) === act.userFactors.first.getSeq[Float](1) } { (ex, act, _) => - ex.transform(_: DataFrame).select("prediction").first.getFloat(0) ~== - act.transform(_: DataFrame).select("prediction").first.getFloat(0) absTol 1e-6 + ex.transform(_: DataFrame).select("prediction").first.getDouble(0) ~== + act.transform(_: DataFrame).select("prediction").first.getDouble(0) absTol 1e-6 } } // check user/item ids falling outside of Int range @@ -510,34 +575,35 @@ class ALSSuite (0, big, small, 0, big, small, 2.0), (1, 1L, 1d, 0, 0L, 0d, 5.0) ).toDF("user", "user_big", "user_small", "item", "item_big", "item_small", "rating") + val msg = "either out of Integer range or contained a fractional part" withClue("fit should fail when ids exceed integer range. ") { assert(intercept[SparkException] { als.fit(df.select(df("user_big").as("user"), df("item"), df("rating"))) - }.getCause.getMessage.contains("was out of Integer range")) + }.getCause.getMessage.contains(msg)) assert(intercept[SparkException] { als.fit(df.select(df("user_small").as("user"), df("item"), df("rating"))) - }.getCause.getMessage.contains("was out of Integer range")) + }.getCause.getMessage.contains(msg)) assert(intercept[SparkException] { als.fit(df.select(df("item_big").as("item"), df("user"), df("rating"))) - }.getCause.getMessage.contains("was out of Integer range")) + }.getCause.getMessage.contains(msg)) assert(intercept[SparkException] { als.fit(df.select(df("item_small").as("item"), df("user"), df("rating"))) - }.getCause.getMessage.contains("was out of Integer range")) + }.getCause.getMessage.contains(msg)) } withClue("transform should fail when ids exceed integer range. ") { val model = als.fit(df) assert(intercept[SparkException] { model.transform(df.select(df("user_big").as("user"), df("item"))).first - }.getMessage.contains("was out of Integer range")) + }.getMessage.contains(msg)) assert(intercept[SparkException] { model.transform(df.select(df("user_small").as("user"), df("item"))).first - }.getMessage.contains("was out of Integer range")) + }.getMessage.contains(msg)) assert(intercept[SparkException] { model.transform(df.select(df("item_big").as("item"), df("user"))).first - }.getMessage.contains("was out of Integer range")) + }.getMessage.contains(msg)) assert(intercept[SparkException] { model.transform(df.select(df("item_small").as("item"), df("user"))).first - }.getMessage.contains("was out of Integer range")) + }.getMessage.contains(msg)) } } @@ -547,6 +613,53 @@ class ALSSuite ALS.train(ratings) } } + + test("ALS cold start user/item prediction strategy") { + val spark = this.spark + import spark.implicits._ + import org.apache.spark.sql.functions._ + + val (ratings, _) = genExplicitTestData(numUsers = 4, numItems = 4, rank = 1) + val data = ratings.toDF + val knownUser = data.select(max("user")).as[Int].first() + val unknownUser = knownUser + 10 + val knownItem = data.select(max("item")).as[Int].first() + val unknownItem = knownItem + 20 + val test = Seq( + (unknownUser, unknownItem), + (knownUser, unknownItem), + (unknownUser, knownItem), + (knownUser, knownItem) + ).toDF("user", "item") + + val als = new ALS().setMaxIter(1).setRank(1) + // default is 'nan' + val defaultModel = als.fit(data) + val defaultPredictions = defaultModel.transform(test).select("prediction").as[Float].collect() + assert(defaultPredictions.length == 4) + assert(defaultPredictions.slice(0, 3).forall(_.isNaN)) + 
assert(!defaultPredictions.last.isNaN) + + // check 'drop' strategy should filter out rows with unknown users/items + val dropPredictions = defaultModel + .setColdStartStrategy("drop") + .transform(test) + .select("prediction").as[Float].collect() + assert(dropPredictions.length == 1) + assert(!dropPredictions.head.isNaN) + assert(dropPredictions.head ~== defaultPredictions.last relTol 1e-14) + } + + test("case insensitive cold start param value") { + val spark = this.spark + import spark.implicits._ + val (ratings, _) = genExplicitTestData(numUsers = 2, numItems = 2, rank = 1) + val data = ratings.toDF + val model = new ALS().fit(data) + Seq("nan", "NaN", "Nan", "drop", "DROP", "Drop").foreach { s => + model.setColdStartStrategy(s).transform(data) + } + } } class ALSCleanerSuite extends SparkFunSuite { diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala index 0fdfdf37cf38d..3cd4b0ac308ef 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala @@ -27,6 +27,8 @@ import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.random.{ExponentialGenerator, WeibullGenerator} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.functions.{col, lit} +import org.apache.spark.sql.types._ class AFTSurvivalRegressionSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { @@ -352,7 +354,7 @@ class AFTSurvivalRegressionSuite } } - test("should support all NumericType labels") { + test("should support all NumericType labels, and not support other types") { val aft = new AFTSurvivalRegression().setMaxIter(1) MLTestingUtils.checkNumericTypes[AFTSurvivalRegressionModel, AFTSurvivalRegression]( aft, spark, isClassification = false) { (expected, actual) => @@ -361,6 +363,36 @@ class AFTSurvivalRegressionSuite } } + test("should support all NumericType censors, and not support other types") { + val df = spark.createDataFrame(Seq( + (0, Vectors.dense(0)), + (1, Vectors.dense(1)), + (2, Vectors.dense(2)), + (3, Vectors.dense(3)), + (4, Vectors.dense(4)) + )).toDF("label", "features") + .withColumn("censor", lit(0.0)) + val aft = new AFTSurvivalRegression().setMaxIter(1) + val expected = aft.fit(df) + + val types = Seq(ShortType, LongType, IntegerType, FloatType, ByteType, DecimalType(10, 0)) + types.foreach { t => + val actual = aft.fit(df.select(col("label"), col("features"), + col("censor").cast(t))) + assert(expected.intercept === actual.intercept) + assert(expected.coefficients === actual.coefficients) + } + + val dfWithStringCensors = spark.createDataFrame(Seq( + (0, Vectors.dense(0, 2, 3), "0") + )).toDF("label", "features", "censor") + val thrown = intercept[IllegalArgumentException] { + aft.fit(dfWithStringCensors) + } + assert(thrown.getMessage.contains( + "Column censor must be of type NumericType but was actually of type StringType")) + } + test("numerical stability of standardization") { val trainer = new AFTSurvivalRegression() val model1 = trainer.fit(datasetUnivariate) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 511686fb4f37f..56b8c0b95e8a4 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -55,6 +55,9 @@ object MimaExcludes { // [SPARK-14272][ML] Add 
logLikelihood in GaussianMixtureSummary ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.clustering.GaussianMixtureSummary.this"), + // [SPARK-19267] Fetch Failure handling robust to user error handling + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.TaskContext.setFetchFailed"), + // [SPARK-19069] [CORE] Expose task 'status' and 'duration' in spark history server REST API. ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.TaskData.this"), ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.TaskData.$default$10"), diff --git a/python/pyspark/__init__.py b/python/pyspark/__init__.py index 9331e74eede59..14c51a306e1c2 100644 --- a/python/pyspark/__init__.py +++ b/python/pyspark/__init__.py @@ -93,13 +93,15 @@ def keyword_only(func): """ A decorator that forces keyword arguments in the wrapped method and saves actual input keyword arguments in `_input_kwargs`. + + .. note:: Should only be used to wrap a method where first arg is `self` """ @wraps(func) - def wrapper(*args, **kwargs): - if len(args) > 1: + def wrapper(self, *args, **kwargs): + if len(args) > 0: raise TypeError("Method %s forces keyword arguments." % func.__name__) - wrapper._input_kwargs = kwargs - return func(*args, **kwargs) + self._input_kwargs = kwargs + return func(self, **kwargs) return wrapper diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index ac40fceaf8e96..b4fc357e42d71 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -124,7 +124,7 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred "org.apache.spark.ml.classification.LinearSVC", self.uid) self._setDefault(maxIter=100, regParam=0.0, tol=1e-6, fitIntercept=True, standardization=True, threshold=0.0, aggregationDepth=2) - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -140,7 +140,7 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre aggregationDepth=2): Sets params for Linear SVM Classifier. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) def _create_model(self, java_model): @@ -266,7 +266,7 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred self._java_obj = self._new_java_obj( "org.apache.spark.ml.classification.LogisticRegression", self.uid) self._setDefault(maxIter=100, regParam=0.0, tol=1E-6, threshold=0.5, family="auto") - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) self._checkThresholdConsistency() @@ -286,7 +286,7 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre Sets params for logistic regression. If the threshold and thresholds Params are both set, they must be equivalent. 
""" - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs self._set(**kwargs) self._checkThresholdConsistency() return self @@ -760,7 +760,7 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini") - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -778,7 +778,7 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre seed=None) Sets params for the DecisionTreeClassifier. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) def _create_model(self, java_model): @@ -890,7 +890,7 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini", numTrees=20, featureSubsetStrategy="auto", subsamplingRate=1.0) - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -908,7 +908,7 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre impurity="gini", numTrees=20, featureSubsetStrategy="auto", subsamplingRate=1.0) Sets params for linear classification. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) def _create_model(self, java_model): @@ -1031,7 +1031,7 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, lossType="logistic", maxIter=20, stepSize=0.1, subsamplingRate=1.0) - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -1047,7 +1047,7 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre lossType="logistic", maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0) Sets params for Gradient Boosted Tree Classification. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) def _create_model(self, java_model): @@ -1174,7 +1174,7 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred self._java_obj = self._new_java_obj( "org.apache.spark.ml.classification.NaiveBayes", self.uid) self._setDefault(smoothing=1.0, modelType="multinomial") - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -1188,7 +1188,7 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre modelType="multinomial", thresholds=None, weightCol=None) Sets params for Naive Bayes. 
""" - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) def _create_model(self, java_model): @@ -1329,7 +1329,7 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred self._java_obj = self._new_java_obj( "org.apache.spark.ml.classification.MultilayerPerceptronClassifier", self.uid) self._setDefault(maxIter=100, tol=1E-4, blockSize=128, stepSize=0.03, solver="l-bfgs") - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -1343,7 +1343,7 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre solver="l-bfgs", initialWeights=None) Sets params for MultilayerPerceptronClassifier. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) def _create_model(self, java_model): @@ -1519,7 +1519,7 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred classifier=None) """ super(OneVsRest, self).__init__() - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self._set(**kwargs) @keyword_only @@ -1529,7 +1529,7 @@ def setParams(self, featuresCol=None, labelCol=None, predictionCol=None, classif setParams(self, featuresCol=None, labelCol=None, predictionCol=None, classifier=None): Sets params for OneVsRest. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) def _fit(self, dataset): diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py index c6c1a0033190e..88ac7e275e386 100644 --- a/python/pyspark/ml/clustering.py +++ b/python/pyspark/ml/clustering.py @@ -224,7 +224,7 @@ def __init__(self, featuresCol="features", predictionCol="prediction", k=2, self._java_obj = self._new_java_obj("org.apache.spark.ml.clustering.GaussianMixture", self.uid) self._setDefault(k=2, tol=0.01, maxIter=100) - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) def _create_model(self, java_model): @@ -240,7 +240,7 @@ def setParams(self, featuresCol="features", predictionCol="prediction", k=2, Sets params for GaussianMixture. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) @since("2.0.0") @@ -414,7 +414,7 @@ def __init__(self, featuresCol="features", predictionCol="prediction", k=2, super(KMeans, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.clustering.KMeans", self.uid) self._setDefault(k=2, initMode="k-means||", initSteps=2, tol=1e-4, maxIter=20) - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) def _create_model(self, java_model): @@ -430,7 +430,7 @@ def setParams(self, featuresCol="features", predictionCol="prediction", k=2, Sets params for KMeans. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) @since("1.5.0") @@ -591,7 +591,7 @@ def __init__(self, featuresCol="features", predictionCol="prediction", maxIter=2 self._java_obj = self._new_java_obj("org.apache.spark.ml.clustering.BisectingKMeans", self.uid) self._setDefault(maxIter=20, k=4, minDivisibleClusterSize=1.0) - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -603,7 +603,7 @@ def setParams(self, featuresCol="features", predictionCol="prediction", maxIter= seed=None, k=4, minDivisibleClusterSize=1.0) Sets params for BisectingKMeans. 
""" - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) @since("2.0.0") @@ -916,7 +916,7 @@ def __init__(self, featuresCol="features", maxIter=20, seed=None, checkpointInte k=10, optimizer="online", learningOffset=1024.0, learningDecay=0.51, subsamplingRate=0.05, optimizeDocConcentration=True, topicDistributionCol="topicDistribution", keepLastCheckpoint=True) - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) def _create_model(self, java_model): @@ -941,7 +941,7 @@ def setParams(self, featuresCol="features", maxIter=20, seed=None, checkpointInt Sets params for LDA. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) @since("2.0.0") diff --git a/python/pyspark/ml/evaluation.py b/python/pyspark/ml/evaluation.py index 7aa16fa5b90f2..7cb8d62f212cb 100644 --- a/python/pyspark/ml/evaluation.py +++ b/python/pyspark/ml/evaluation.py @@ -148,7 +148,7 @@ def __init__(self, rawPredictionCol="rawPrediction", labelCol="label", "org.apache.spark.ml.evaluation.BinaryClassificationEvaluator", self.uid) self._setDefault(rawPredictionCol="rawPrediction", labelCol="label", metricName="areaUnderROC") - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self._set(**kwargs) @since("1.4.0") @@ -174,7 +174,7 @@ def setParams(self, rawPredictionCol="rawPrediction", labelCol="label", metricName="areaUnderROC") Sets params for binary classification evaluator. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) @@ -226,7 +226,7 @@ def __init__(self, predictionCol="prediction", labelCol="label", "org.apache.spark.ml.evaluation.RegressionEvaluator", self.uid) self._setDefault(predictionCol="prediction", labelCol="label", metricName="rmse") - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self._set(**kwargs) @since("1.4.0") @@ -252,7 +252,7 @@ def setParams(self, predictionCol="prediction", labelCol="label", metricName="rmse") Sets params for regression evaluator. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) @@ -299,7 +299,7 @@ def __init__(self, predictionCol="prediction", labelCol="label", "org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator", self.uid) self._setDefault(predictionCol="prediction", labelCol="label", metricName="f1") - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self._set(**kwargs) @since("1.5.0") @@ -325,7 +325,7 @@ def setParams(self, predictionCol="prediction", labelCol="label", metricName="f1") Sets params for multiclass classification evaluator. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) if __name__ == "__main__": diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index c2eafbefcdec1..92f8549e9cb9e 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -94,7 +94,7 @@ def __init__(self, threshold=0.0, inputCol=None, outputCol=None): super(Binarizer, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.Binarizer", self.uid) self._setDefault(threshold=0.0) - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -104,7 +104,7 @@ def setParams(self, threshold=0.0, inputCol=None, outputCol=None): setParams(self, threshold=0.0, inputCol=None, outputCol=None) Sets params for this Binarizer. 
""" - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) @since("1.4.0") @@ -258,14 +258,14 @@ class BucketedRandomProjectionLSH(JavaEstimator, LSHParams, HasInputCol, HasOutp def __init__(self, inputCol=None, outputCol=None, seed=None, numHashTables=1, bucketLength=None): """ - __init__(self, inputCol=None, outputCol=None, seed=None, numHashTables=1, + __init__(self, inputCol=None, outputCol=None, seed=None, numHashTables=1, \ bucketLength=None) """ super(BucketedRandomProjectionLSH, self).__init__() self._java_obj = \ self._new_java_obj("org.apache.spark.ml.feature.BucketedRandomProjectionLSH", self.uid) self._setDefault(numHashTables=1) - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -277,7 +277,7 @@ def setParams(self, inputCol=None, outputCol=None, seed=None, numHashTables=1, bucketLength=None) Sets params for this BucketedRandomProjectionLSH. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) @since("2.2.0") @@ -370,7 +370,7 @@ def __init__(self, splits=None, inputCol=None, outputCol=None, handleInvalid="er super(Bucketizer, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.Bucketizer", self.uid) self._setDefault(handleInvalid="error") - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -380,7 +380,7 @@ def setParams(self, splits=None, inputCol=None, outputCol=None, handleInvalid="e setParams(self, splits=None, inputCol=None, outputCol=None, handleInvalid="error") Sets params for this Bucketizer. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) @since("1.4.0") @@ -484,7 +484,7 @@ def __init__(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, binary=False, inputC self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.CountVectorizer", self.uid) self._setDefault(minTF=1.0, minDF=1.0, vocabSize=1 << 18, binary=False) - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -496,7 +496,7 @@ def setParams(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, binary=False, input outputCol=None) Set the params for the CountVectorizer """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) @since("1.6.0") @@ -616,7 +616,7 @@ def __init__(self, inverse=False, inputCol=None, outputCol=None): super(DCT, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.DCT", self.uid) self._setDefault(inverse=False) - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -626,7 +626,7 @@ def setParams(self, inverse=False, inputCol=None, outputCol=None): setParams(self, inverse=False, inputCol=None, outputCol=None) Sets params for this DCT. 
""" - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) @since("1.6.0") @@ -680,7 +680,7 @@ def __init__(self, scalingVec=None, inputCol=None, outputCol=None): super(ElementwiseProduct, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.ElementwiseProduct", self.uid) - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -690,7 +690,7 @@ def setParams(self, scalingVec=None, inputCol=None, outputCol=None): setParams(self, scalingVec=None, inputCol=None, outputCol=None) Sets params for this ElementwiseProduct. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) @since("2.0.0") @@ -750,7 +750,7 @@ def __init__(self, numFeatures=1 << 18, binary=False, inputCol=None, outputCol=N super(HashingTF, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.HashingTF", self.uid) self._setDefault(numFeatures=1 << 18, binary=False) - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -760,7 +760,7 @@ def setParams(self, numFeatures=1 << 18, binary=False, inputCol=None, outputCol= setParams(self, numFeatures=1 << 18, binary=False, inputCol=None, outputCol=None) Sets params for this HashingTF. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) @since("2.0.0") @@ -823,7 +823,7 @@ def __init__(self, minDocFreq=0, inputCol=None, outputCol=None): super(IDF, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.IDF", self.uid) self._setDefault(minDocFreq=0) - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -833,7 +833,7 @@ def setParams(self, minDocFreq=0, inputCol=None, outputCol=None): setParams(self, minDocFreq=0, inputCol=None, outputCol=None) Sets params for this IDF. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) @since("1.4.0") @@ -913,7 +913,7 @@ def __init__(self, inputCol=None, outputCol=None): super(MaxAbsScaler, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.MaxAbsScaler", self.uid) self._setDefault() - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -923,7 +923,7 @@ def setParams(self, inputCol=None, outputCol=None): setParams(self, inputCol=None, outputCol=None) Sets params for this MaxAbsScaler. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) def _create_model(self, java_model): @@ -1011,7 +1011,7 @@ def __init__(self, inputCol=None, outputCol=None, seed=None, numHashTables=1): super(MinHashLSH, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.MinHashLSH", self.uid) self._setDefault(numHashTables=1) - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -1021,7 +1021,7 @@ def setParams(self, inputCol=None, outputCol=None, seed=None, numHashTables=1): setParams(self, inputCol=None, outputCol=None, seed=None, numHashTables=1) Sets params for this MinHashLSH. 
""" - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) def _create_model(self, java_model): @@ -1106,7 +1106,7 @@ def __init__(self, min=0.0, max=1.0, inputCol=None, outputCol=None): super(MinMaxScaler, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.MinMaxScaler", self.uid) self._setDefault(min=0.0, max=1.0) - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -1116,7 +1116,7 @@ def setParams(self, min=0.0, max=1.0, inputCol=None, outputCol=None): setParams(self, min=0.0, max=1.0, inputCol=None, outputCol=None) Sets params for this MinMaxScaler. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) @since("1.6.0") @@ -1224,7 +1224,7 @@ def __init__(self, n=2, inputCol=None, outputCol=None): super(NGram, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.NGram", self.uid) self._setDefault(n=2) - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -1234,7 +1234,7 @@ def setParams(self, n=2, inputCol=None, outputCol=None): setParams(self, n=2, inputCol=None, outputCol=None) Sets params for this NGram. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) @since("1.5.0") @@ -1288,7 +1288,7 @@ def __init__(self, p=2.0, inputCol=None, outputCol=None): super(Normalizer, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.Normalizer", self.uid) self._setDefault(p=2.0) - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -1298,7 +1298,7 @@ def setParams(self, p=2.0, inputCol=None, outputCol=None): setParams(self, p=2.0, inputCol=None, outputCol=None) Sets params for this Normalizer. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) @since("1.4.0") @@ -1363,12 +1363,12 @@ class OneHotEncoder(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, @keyword_only def __init__(self, dropLast=True, inputCol=None, outputCol=None): """ - __init__(self, includeFirst=True, inputCol=None, outputCol=None) + __init__(self, dropLast=True, inputCol=None, outputCol=None) """ super(OneHotEncoder, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.OneHotEncoder", self.uid) self._setDefault(dropLast=True) - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -1378,7 +1378,7 @@ def setParams(self, dropLast=True, inputCol=None, outputCol=None): setParams(self, dropLast=True, inputCol=None, outputCol=None) Sets params for this OneHotEncoder. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) @since("1.4.0") @@ -1434,7 +1434,7 @@ def __init__(self, degree=2, inputCol=None, outputCol=None): self._java_obj = self._new_java_obj( "org.apache.spark.ml.feature.PolynomialExpansion", self.uid) self._setDefault(degree=2) - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -1444,7 +1444,7 @@ def setParams(self, degree=2, inputCol=None, outputCol=None): setParams(self, degree=2, inputCol=None, outputCol=None) Sets params for this PolynomialExpansion. 
""" - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) @since("1.4.0") @@ -1540,7 +1540,7 @@ def __init__(self, numBuckets=2, inputCol=None, outputCol=None, relativeError=0. self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.QuantileDiscretizer", self.uid) self._setDefault(numBuckets=2, relativeError=0.001, handleInvalid="error") - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -1552,7 +1552,7 @@ def setParams(self, numBuckets=2, inputCol=None, outputCol=None, relativeError=0 handleInvalid="error") Set the params for the QuantileDiscretizer """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) @since("2.0.0") @@ -1665,7 +1665,7 @@ def __init__(self, minTokenLength=1, gaps=True, pattern="\\s+", inputCol=None, super(RegexTokenizer, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.RegexTokenizer", self.uid) self._setDefault(minTokenLength=1, gaps=True, pattern="\\s+", toLowercase=True) - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -1677,7 +1677,7 @@ def setParams(self, minTokenLength=1, gaps=True, pattern="\\s+", inputCol=None, outputCol=None, toLowercase=True) Sets params for this RegexTokenizer. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) @since("1.4.0") @@ -1768,7 +1768,7 @@ def __init__(self, statement=None): """ super(SQLTransformer, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.SQLTransformer", self.uid) - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -1778,7 +1778,7 @@ def setParams(self, statement=None): setParams(self, statement=None) Sets params for this SQLTransformer. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) @since("1.6.0") @@ -1847,7 +1847,7 @@ def __init__(self, withMean=False, withStd=True, inputCol=None, outputCol=None): super(StandardScaler, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.StandardScaler", self.uid) self._setDefault(withMean=False, withStd=True) - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -1857,7 +1857,7 @@ def setParams(self, withMean=False, withStd=True, inputCol=None, outputCol=None) setParams(self, withMean=False, withStd=True, inputCol=None, outputCol=None) Sets params for this StandardScaler. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) @since("1.4.0") @@ -1963,7 +1963,7 @@ def __init__(self, inputCol=None, outputCol=None, handleInvalid="error"): super(StringIndexer, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.StringIndexer", self.uid) self._setDefault(handleInvalid="error") - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -1973,7 +1973,7 @@ def setParams(self, inputCol=None, outputCol=None, handleInvalid="error"): setParams(self, inputCol=None, outputCol=None, handleInvalid="error") Sets params for this StringIndexer. 
""" - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) def _create_model(self, java_model): @@ -2021,7 +2021,7 @@ def __init__(self, inputCol=None, outputCol=None, labels=None): super(IndexToString, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.IndexToString", self.uid) - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -2031,7 +2031,7 @@ def setParams(self, inputCol=None, outputCol=None, labels=None): setParams(self, inputCol=None, outputCol=None, labels=None) Sets params for this IndexToString. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) @since("1.6.0") @@ -2085,7 +2085,7 @@ def __init__(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive= self.uid) self._setDefault(stopWords=StopWordsRemover.loadDefaultStopWords("english"), caseSensitive=False) - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -2095,7 +2095,7 @@ def setParams(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive setParams(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive=false) Sets params for this StopWordRemover. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) @since("1.6.0") @@ -2178,7 +2178,7 @@ def __init__(self, inputCol=None, outputCol=None): """ super(Tokenizer, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.Tokenizer", self.uid) - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -2188,7 +2188,7 @@ def setParams(self, inputCol=None, outputCol=None): setParams(self, inputCol=None, outputCol=None) Sets params for this Tokenizer. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) @@ -2222,7 +2222,7 @@ def __init__(self, inputCols=None, outputCol=None): """ super(VectorAssembler, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.VectorAssembler", self.uid) - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -2232,7 +2232,7 @@ def setParams(self, inputCols=None, outputCol=None): setParams(self, inputCols=None, outputCol=None) Sets params for this VectorAssembler. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) @@ -2320,7 +2320,7 @@ def __init__(self, maxCategories=20, inputCol=None, outputCol=None): super(VectorIndexer, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.VectorIndexer", self.uid) self._setDefault(maxCategories=20) - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -2330,7 +2330,7 @@ def setParams(self, maxCategories=20, inputCol=None, outputCol=None): setParams(self, maxCategories=20, inputCol=None, outputCol=None) Sets params for this VectorIndexer. 
""" - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) @since("1.4.0") @@ -2435,7 +2435,7 @@ def __init__(self, inputCol=None, outputCol=None, indices=None, names=None): super(VectorSlicer, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.VectorSlicer", self.uid) self._setDefault(indices=[], names=[]) - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -2445,7 +2445,7 @@ def setParams(self, inputCol=None, outputCol=None, indices=None, names=None): setParams(self, inputCol=None, outputCol=None, indices=None, names=None): Sets params for this VectorSlicer. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) @since("1.6.0") @@ -2558,7 +2558,7 @@ def __init__(self, vectorSize=100, minCount=5, numPartitions=1, stepSize=0.025, self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.Word2Vec", self.uid) self._setDefault(vectorSize=100, minCount=5, numPartitions=1, stepSize=0.025, maxIter=1, windowSize=5, maxSentenceLength=1000) - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -2570,7 +2570,7 @@ def setParams(self, vectorSize=100, minCount=5, numPartitions=1, stepSize=0.025, inputCol=None, outputCol=None, windowSize=5, maxSentenceLength=1000) Sets params for this Word2Vec. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) @since("1.4.0") @@ -2718,7 +2718,7 @@ def __init__(self, k=None, inputCol=None, outputCol=None): """ super(PCA, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.PCA", self.uid) - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -2728,7 +2728,7 @@ def setParams(self, k=None, inputCol=None, outputCol=None): setParams(self, k=None, inputCol=None, outputCol=None) Set params for this PCA. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) @since("1.5.0") @@ -2858,7 +2858,7 @@ def __init__(self, formula=None, featuresCol="features", labelCol="label", super(RFormula, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.RFormula", self.uid) self._setDefault(forceIndexLabel=False) - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -2870,7 +2870,7 @@ def setParams(self, formula=None, featuresCol="features", labelCol="label", forceIndexLabel=False) Sets params for RFormula. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) @since("1.5.0") @@ -3017,7 +3017,7 @@ def __init__(self, numTopFeatures=50, featuresCol="features", outputCol=None, self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.ChiSqSelector", self.uid) self._setDefault(numTopFeatures=50, selectorType="numTopFeatures", percentile=0.1, fpr=0.05, fdr=0.05, fwe=0.05) - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -3031,7 +3031,7 @@ def setParams(self, numTopFeatures=50, featuresCol="features", outputCol=None, fdr=0.05, fwe=0.05) Sets params for this ChiSqSelector. 
""" - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) @since("2.1.0") diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py index a78e3b49fbcfc..4aac6a4466b54 100644 --- a/python/pyspark/ml/pipeline.py +++ b/python/pyspark/ml/pipeline.py @@ -58,7 +58,7 @@ def __init__(self, stages=None): __init__(self, stages=None) """ super(Pipeline, self).__init__() - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @since("1.3.0") @@ -85,7 +85,7 @@ def setParams(self, stages=None): setParams(self, stages=None) Sets params for Pipeline. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) def _fit(self, dataset): diff --git a/python/pyspark/ml/recommendation.py b/python/pyspark/ml/recommendation.py index e28d38bd19f80..8bc899a0788bb 100644 --- a/python/pyspark/ml/recommendation.py +++ b/python/pyspark/ml/recommendation.py @@ -125,19 +125,25 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha finalStorageLevel = Param(Params._dummy(), "finalStorageLevel", "StorageLevel for ALS model factors.", typeConverter=TypeConverters.toString) + coldStartStrategy = Param(Params._dummy(), "coldStartStrategy", "strategy for dealing with " + + "unknown or new users/items at prediction time. This may be useful " + + "in cross-validation or production scenarios, for handling " + + "user/item ids the model has not seen in the training data. " + + "Supported values: 'nan', 'drop'.", + typeConverter=TypeConverters.toString) @keyword_only def __init__(self, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemBlocks=10, implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item", seed=None, ratingCol="rating", nonnegative=False, checkpointInterval=10, intermediateStorageLevel="MEMORY_AND_DISK", - finalStorageLevel="MEMORY_AND_DISK"): + finalStorageLevel="MEMORY_AND_DISK", coldStartStrategy="nan"): """ __init__(self, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemBlocks=10, \ implicitPrefs=false, alpha=1.0, userCol="user", itemCol="item", seed=None, \ ratingCol="rating", nonnegative=false, checkpointInterval=10, \ intermediateStorageLevel="MEMORY_AND_DISK", \ - finalStorageLevel="MEMORY_AND_DISK") + finalStorageLevel="MEMORY_AND_DISK", coldStartStrategy="nan") """ super(ALS, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.recommendation.ALS", self.uid) @@ -145,8 +151,8 @@ def __init__(self, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemB implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item", ratingCol="rating", nonnegative=False, checkpointInterval=10, intermediateStorageLevel="MEMORY_AND_DISK", - finalStorageLevel="MEMORY_AND_DISK") - kwargs = self.__init__._input_kwargs + finalStorageLevel="MEMORY_AND_DISK", coldStartStrategy="nan") + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -155,16 +161,16 @@ def setParams(self, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItem implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item", seed=None, ratingCol="rating", nonnegative=False, checkpointInterval=10, intermediateStorageLevel="MEMORY_AND_DISK", - finalStorageLevel="MEMORY_AND_DISK"): + finalStorageLevel="MEMORY_AND_DISK", coldStartStrategy="nan"): """ setParams(self, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemBlocks=10, \ implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item", seed=None, \ 
ratingCol="rating", nonnegative=False, checkpointInterval=10, \ intermediateStorageLevel="MEMORY_AND_DISK", \ - finalStorageLevel="MEMORY_AND_DISK") + finalStorageLevel="MEMORY_AND_DISK", coldStartStrategy="nan") Sets params for ALS. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) def _create_model(self, java_model): @@ -332,6 +338,20 @@ def getFinalStorageLevel(self): """ return self.getOrDefault(self.finalStorageLevel) + @since("2.2.0") + def setColdStartStrategy(self, value): + """ + Sets the value of :py:attr:`coldStartStrategy`. + """ + return self._set(coldStartStrategy=value) + + @since("2.2.0") + def getColdStartStrategy(self): + """ + Gets the value of coldStartStrategy or its default value. + """ + return self.getOrDefault(self.coldStartStrategy) + class ALSModel(JavaModel, JavaMLWritable, JavaMLReadable): """ diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index b42e807069802..b199bf282e4f2 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -108,7 +108,7 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred self._java_obj = self._new_java_obj( "org.apache.spark.ml.regression.LinearRegression", self.uid) self._setDefault(maxIter=100, regParam=0.0, tol=1e-6) - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -122,7 +122,7 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre standardization=True, solver="auto", weightCol=None, aggregationDepth=2) Sets params for linear regression. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) def _create_model(self, java_model): @@ -464,7 +464,7 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred self._java_obj = self._new_java_obj( "org.apache.spark.ml.regression.IsotonicRegression", self.uid) self._setDefault(isotonic=True, featureIndex=0) - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -475,7 +475,7 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre weightCol=None, isotonic=True, featureIndex=0): Set the params for IsotonicRegression. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) def _create_model(self, java_model): @@ -704,7 +704,7 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="variance") - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -720,7 +720,7 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre impurity="variance", seed=None, varianceCol=None) Sets params for the DecisionTreeRegressor. 
""" - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) def _create_model(self, java_model): @@ -895,7 +895,7 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="variance", subsamplingRate=1.0, numTrees=20, featureSubsetStrategy="auto") - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -913,7 +913,7 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre featureSubsetStrategy="auto") Sets params for linear regression. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) def _create_model(self, java_model): @@ -1022,7 +1022,7 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0, checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1, impurity="variance") - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -1040,7 +1040,7 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre impurity="variance") Sets params for Gradient Boosted Tree Regression. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) def _create_model(self, java_model): @@ -1171,7 +1171,7 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred self._setDefault(censorCol="censor", quantileProbabilities=[0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99], maxIter=100, tol=1E-6) - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -1186,7 +1186,7 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre quantileProbabilities=[0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99], \ quantilesCol=None, aggregationDepth=2): """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) def _create_model(self, java_model): @@ -1366,7 +1366,7 @@ def __init__(self, labelCol="label", featuresCol="features", predictionCol="pred self._java_obj = self._new_java_obj( "org.apache.spark.ml.regression.GeneralizedLinearRegression", self.uid) self._setDefault(family="gaussian", maxIter=25, tol=1e-6, regParam=0.0, solver="irls") - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -1380,7 +1380,7 @@ def setParams(self, labelCol="label", featuresCol="features", predictionCol="pre regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=None) Sets params for generalized linear regression. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) def _create_model(self, java_model): diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 293c6c0b0f36a..352416055791e 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -250,7 +250,7 @@ class TestParams(HasMaxIter, HasInputCol, HasSeed): def __init__(self, seed=None): super(TestParams, self).__init__() self._setDefault(maxIter=10) - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -259,7 +259,7 @@ def setParams(self, seed=None): setParams(self, seed=None) Sets params for this test. 
""" - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) @@ -271,7 +271,7 @@ class OtherTestParams(HasMaxIter, HasInputCol, HasSeed): def __init__(self, seed=None): super(OtherTestParams, self).__init__() self._setDefault(maxIter=10) - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -280,7 +280,7 @@ def setParams(self, seed=None): setParams(self, seed=None) Sets params for this test. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py index 2dcc99cef8aa2..ffeb4459e1aac 100644 --- a/python/pyspark/ml/tuning.py +++ b/python/pyspark/ml/tuning.py @@ -186,7 +186,7 @@ def __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numF """ super(CrossValidator, self).__init__() self._setDefault(numFolds=3) - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self._set(**kwargs) @keyword_only @@ -198,7 +198,7 @@ def setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, num seed=None): Sets params for cross validator. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) @since("1.4.0") @@ -346,7 +346,7 @@ def __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, trai """ super(TrainValidationSplit, self).__init__() self._setDefault(trainRatio=0.75) - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self._set(**kwargs) @since("2.0.0") @@ -358,7 +358,7 @@ def setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, tra seed=None): Sets params for the train validation split. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) @since("2.0.0") diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index b5e5b18bcbefa..45fb9b7591529 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -163,8 +163,8 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, """ Loads a JSON file and returns the results as a :class:`DataFrame`. - Both JSON (one record per file) and `JSON Lines `_ - (newline-delimited JSON) are supported and can be selected with the `wholeFile` parameter. + `JSON Lines `_(newline-delimited JSON) is supported by default. + For JSON (one record per file), set the `wholeFile` parameter to ``true``. If the ``schema`` parameter is not specified, this function goes through the input once to determine the input schema. @@ -308,7 +308,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non ignoreTrailingWhiteSpace=None, nullValue=None, nanValue=None, positiveInf=None, negativeInf=None, dateFormat=None, timestampFormat=None, maxColumns=None, maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None, timeZone=None, - columnNameOfCorruptRecord=None): + columnNameOfCorruptRecord=None, wholeFile=None): """Loads a CSV file and returns the result as a :class:`DataFrame`. This function will go through the input once to determine the input schema if @@ -385,6 +385,8 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non ``spark.sql.columnNameOfCorruptRecord``. If None is set, it uses the value specified in ``spark.sql.columnNameOfCorruptRecord``. + :param wholeFile: parse records, which may span multiple lines. 
If None is + set, it uses the default value, ``false``. >>> df = spark.read.csv('python/test_support/sql/ages.csv') >>> df.dtypes @@ -398,7 +400,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non dateFormat=dateFormat, timestampFormat=timestampFormat, maxColumns=maxColumns, maxCharsPerColumn=maxCharsPerColumn, maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode, timeZone=timeZone, - columnNameOfCorruptRecord=columnNameOfCorruptRecord) + columnNameOfCorruptRecord=columnNameOfCorruptRecord, wholeFile=wholeFile) if isinstance(path, basestring): path = [path] return self._df(self._jreader.csv(self._spark._sc._jvm.PythonUtils.toSeq(path))) diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index bd19fd4e385b4..625fb9ba385af 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -433,8 +433,8 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, """ Loads a JSON file stream and returns the results as a :class:`DataFrame`. - Both JSON (one record per file) and `JSON Lines `_ - (newline-delimited JSON) are supported and can be selected with the `wholeFile` parameter. + `JSON Lines `_(newline-delimited JSON) is supported by default. + For JSON (one record per file), set the `wholeFile` parameter to ``true``. If the ``schema`` parameter is not specified, this function goes through the input once to determine the input schema. @@ -562,7 +562,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non ignoreTrailingWhiteSpace=None, nullValue=None, nanValue=None, positiveInf=None, negativeInf=None, dateFormat=None, timestampFormat=None, maxColumns=None, maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None, timeZone=None, - columnNameOfCorruptRecord=None): + columnNameOfCorruptRecord=None, wholeFile=None): """Loads a CSV file stream and returns the result as a :class:`DataFrame`. This function will go through the input once to determine the input schema if @@ -637,6 +637,8 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non ``spark.sql.columnNameOfCorruptRecord``. If None is set, it uses the value specified in ``spark.sql.columnNameOfCorruptRecord``. + :param wholeFile: parse one record, which may span multiple lines. If None is + set, it uses the default value, ``false``. 
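As an illustration of the `wholeFile` option documented above for the JSON and CSV readers, multi-line records can be loaded as shown below. This is a usage sketch rather than part of the patch; `spark` is an assumed active SparkSession, and the paths are the test fixtures referenced elsewhere in this diff.

```
# Each file contains records that span multiple lines, so wholeFile=True is required.
people = spark.read.json("python/test_support/sql/people_array.json", wholeFile=True)
ages = spark.read.csv("python/test_support/sql/ages_newlines.csv", wholeFile=True)
```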
>>> csv_sdf = spark.readStream.csv(tempfile.mkdtemp(), schema = sdf_schema) >>> csv_sdf.isStreaming @@ -652,7 +654,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non dateFormat=dateFormat, timestampFormat=timestampFormat, maxColumns=maxColumns, maxCharsPerColumn=maxCharsPerColumn, maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode, timeZone=timeZone, - columnNameOfCorruptRecord=columnNameOfCorruptRecord) + columnNameOfCorruptRecord=columnNameOfCorruptRecord, wholeFile=wholeFile) if isinstance(path, basestring): return self._df(self._jreader.csv(path)) else: diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index fd083e4868cd6..e943f8da3db14 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -437,12 +437,19 @@ def test_udf_with_order_by_and_limit(self): self.assertEqual(res.collect(), [Row(id=0, copy=0)]) def test_wholefile_json(self): - from pyspark.sql.types import StringType people1 = self.spark.read.json("python/test_support/sql/people.json") people_array = self.spark.read.json("python/test_support/sql/people_array.json", wholeFile=True) self.assertEqual(people1.collect(), people_array.collect()) + def test_wholefile_csv(self): + ages_newlines = self.spark.read.csv( + "python/test_support/sql/ages_newlines.csv", wholeFile=True) + expected = [Row(_c0=u'Joe', _c1=u'20', _c2=u'Hi,\nI am Jeo'), + Row(_c0=u'Tom', _c1=u'30', _c2=u'My name is Tom'), + Row(_c0=u'Hyukjin', _c1=u'25', _c2=u'I am Hyukjin\n\nI love Spark!')] + self.assertEqual(ages_newlines.collect(), expected) + def test_udf_with_input_file_name(self): from pyspark.sql.functions import udf, input_file_name from pyspark.sql.types import StringType diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index e908b1e739bb3..c6c87a9ea5555 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -58,6 +58,7 @@ from StringIO import StringIO +from pyspark import keyword_only from pyspark.conf import SparkConf from pyspark.context import SparkContext from pyspark.rdd import RDD @@ -1347,7 +1348,7 @@ def test_oldhadoop(self): self.assertEqual(ints, ei) hellopath = os.path.join(SPARK_HOME, "python/test_support/hello/hello.txt") - oldconf = {"mapred.input.dir": hellopath} + oldconf = {"mapreduce.input.fileinputformat.inputdir": hellopath} hello = self.sc.hadoopRDD("org.apache.hadoop.mapred.TextInputFormat", "org.apache.hadoop.io.LongWritable", "org.apache.hadoop.io.Text", @@ -1366,7 +1367,7 @@ def test_newhadoop(self): self.assertEqual(ints, ei) hellopath = os.path.join(SPARK_HOME, "python/test_support/hello/hello.txt") - newconf = {"mapred.input.dir": hellopath} + newconf = {"mapreduce.input.fileinputformat.inputdir": hellopath} hello = self.sc.newAPIHadoopRDD("org.apache.hadoop.mapreduce.lib.input.TextInputFormat", "org.apache.hadoop.io.LongWritable", "org.apache.hadoop.io.Text", @@ -1515,12 +1516,12 @@ def test_oldhadoop(self): conf = { "mapred.output.format.class": "org.apache.hadoop.mapred.SequenceFileOutputFormat", - "mapred.output.key.class": "org.apache.hadoop.io.IntWritable", - "mapred.output.value.class": "org.apache.hadoop.io.MapWritable", - "mapred.output.dir": basepath + "/olddataset/" + "mapreduce.job.output.key.class": "org.apache.hadoop.io.IntWritable", + "mapreduce.job.output.value.class": "org.apache.hadoop.io.MapWritable", + "mapreduce.output.fileoutputformat.outputdir": basepath + "/olddataset/" } self.sc.parallelize(dict_data).saveAsHadoopDataset(conf) - input_conf = {"mapred.input.dir": basepath + 
"/olddataset/"} + input_conf = {"mapreduce.input.fileinputformat.inputdir": basepath + "/olddataset/"} result = self.sc.hadoopRDD( "org.apache.hadoop.mapred.SequenceFileInputFormat", "org.apache.hadoop.io.IntWritable", @@ -1547,14 +1548,14 @@ def test_newhadoop(self): self.assertEqual(result, data) conf = { - "mapreduce.outputformat.class": + "mapreduce.job.outputformat.class": "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat", - "mapred.output.key.class": "org.apache.hadoop.io.IntWritable", - "mapred.output.value.class": "org.apache.hadoop.io.Text", - "mapred.output.dir": basepath + "/newdataset/" + "mapreduce.job.output.key.class": "org.apache.hadoop.io.IntWritable", + "mapreduce.job.output.value.class": "org.apache.hadoop.io.Text", + "mapreduce.output.fileoutputformat.outputdir": basepath + "/newdataset/" } self.sc.parallelize(data).saveAsNewAPIHadoopDataset(conf) - input_conf = {"mapred.input.dir": basepath + "/newdataset/"} + input_conf = {"mapreduce.input.fileinputformat.inputdir": basepath + "/newdataset/"} new_dataset = sorted(self.sc.newAPIHadoopRDD( "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat", "org.apache.hadoop.io.IntWritable", @@ -1584,16 +1585,16 @@ def test_newhadoop_with_array(self): self.assertEqual(result, array_data) conf = { - "mapreduce.outputformat.class": + "mapreduce.job.outputformat.class": "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat", - "mapred.output.key.class": "org.apache.hadoop.io.IntWritable", - "mapred.output.value.class": "org.apache.spark.api.python.DoubleArrayWritable", - "mapred.output.dir": basepath + "/newdataset/" + "mapreduce.job.output.key.class": "org.apache.hadoop.io.IntWritable", + "mapreduce.job.output.value.class": "org.apache.spark.api.python.DoubleArrayWritable", + "mapreduce.output.fileoutputformat.outputdir": basepath + "/newdataset/" } self.sc.parallelize(array_data).saveAsNewAPIHadoopDataset( conf, valueConverter="org.apache.spark.api.python.DoubleArrayToWritableConverter") - input_conf = {"mapred.input.dir": basepath + "/newdataset/"} + input_conf = {"mapreduce.input.fileinputformat.inputdir": basepath + "/newdataset/"} new_dataset = sorted(self.sc.newAPIHadoopRDD( "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat", "org.apache.hadoop.io.IntWritable", @@ -1663,18 +1664,19 @@ def test_reserialization(self): conf4 = { "mapred.output.format.class": "org.apache.hadoop.mapred.SequenceFileOutputFormat", - "mapred.output.key.class": "org.apache.hadoop.io.IntWritable", - "mapred.output.value.class": "org.apache.hadoop.io.IntWritable", - "mapred.output.dir": basepath + "/reserialize/dataset"} + "mapreduce.job.output.key.class": "org.apache.hadoop.io.IntWritable", + "mapreduce.job.output.value.class": "org.apache.hadoop.io.IntWritable", + "mapreduce.output.fileoutputformat.outputdir": basepath + "/reserialize/dataset"} rdd.saveAsHadoopDataset(conf4) result4 = sorted(self.sc.sequenceFile(basepath + "/reserialize/dataset").collect()) self.assertEqual(result4, data) - conf5 = {"mapreduce.outputformat.class": + conf5 = {"mapreduce.job.outputformat.class": "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat", - "mapred.output.key.class": "org.apache.hadoop.io.IntWritable", - "mapred.output.value.class": "org.apache.hadoop.io.IntWritable", - "mapred.output.dir": basepath + "/reserialize/newdataset"} + "mapreduce.job.output.key.class": "org.apache.hadoop.io.IntWritable", + "mapreduce.job.output.value.class": "org.apache.hadoop.io.IntWritable", + 
"mapreduce.output.fileoutputformat.outputdir": basepath + "/reserialize/newdataset" + } rdd.saveAsNewAPIHadoopDataset(conf5) result5 = sorted(self.sc.sequenceFile(basepath + "/reserialize/newdataset").collect()) self.assertEqual(result5, data) @@ -2160,6 +2162,44 @@ def test_memory_conf(self): sc.stop() +class KeywordOnlyTests(unittest.TestCase): + class Wrapped(object): + @keyword_only + def set(self, x=None, y=None): + if "x" in self._input_kwargs: + self._x = self._input_kwargs["x"] + if "y" in self._input_kwargs: + self._y = self._input_kwargs["y"] + return x, y + + def test_keywords(self): + w = self.Wrapped() + x, y = w.set(y=1) + self.assertEqual(y, 1) + self.assertEqual(y, w._y) + self.assertIsNone(x) + self.assertFalse(hasattr(w, "_x")) + + def test_non_keywords(self): + w = self.Wrapped() + self.assertRaises(TypeError, lambda: w.set(0, y=1)) + + def test_kwarg_ownership(self): + # test _input_kwargs is owned by each class instance and not a shared static variable + class Setter(object): + @keyword_only + def set(self, x=None, other=None, other_x=None): + if "other" in self._input_kwargs: + self._input_kwargs["other"].set(x=self._input_kwargs["other_x"]) + self._x = self._input_kwargs["x"] + + a = Setter() + b = Setter() + a.set(x=1, other=b, other_x=2) + self.assertEqual(a._x, 1) + self.assertEqual(b._x, 2) + + @unittest.skipIf(not _have_scipy, "SciPy not installed") class SciPyTests(PySparkTestCase): diff --git a/python/test_support/sql/ages_newlines.csv b/python/test_support/sql/ages_newlines.csv new file mode 100644 index 0000000000000..d19f6731625fa --- /dev/null +++ b/python/test_support/sql/ages_newlines.csv @@ -0,0 +1,6 @@ +Joe,20,"Hi, +I am Jeo" +Tom,30,"My name is Tom" +Hyukjin,25,"I am Hyukjin + +I love Spark!" diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala index 792ade8f0bdbd..38b082ac01197 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala @@ -25,7 +25,7 @@ import org.apache.spark.deploy.mesos.ui.MesosClusterUI import org.apache.spark.deploy.rest.mesos.MesosRestServer import org.apache.spark.internal.Logging import org.apache.spark.scheduler.cluster.mesos._ -import org.apache.spark.util.{CommandLineUtils, ShutdownHookManager, Utils} +import org.apache.spark.util.{CommandLineUtils, ShutdownHookManager, SparkUncaughtExceptionHandler, Utils} /* * A dispatcher that is responsible for managing and launching drivers, and is intended to be @@ -97,6 +97,7 @@ private[mesos] object MesosClusterDispatcher with CommandLineUtils { override def main(args: Array[String]) { + Thread.setDefaultUncaughtExceptionHandler(SparkUncaughtExceptionHandler) Utils.initDaemon(log) val conf = new SparkConf val dispatcherArgs = new MesosClusterDispatcherArguments(args, conf) diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala index f555072c3842a..f69c223ab9b6d 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala +++ 
b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala @@ -54,14 +54,17 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( with org.apache.mesos.Scheduler with MesosSchedulerUtils { - val MAX_SLAVE_FAILURES = 2 // Blacklist a slave after this many failures + // Blacklist a slave after this many failures + private val MAX_SLAVE_FAILURES = 2 - // Maximum number of cores to acquire (TODO: we'll need more flexible controls here) - val maxCores = conf.get("spark.cores.max", Int.MaxValue.toString).toInt + private val maxCoresOption = conf.getOption("spark.cores.max").map(_.toInt) - val useFetcherCache = conf.getBoolean("spark.mesos.fetcherCache.enable", false) + // Maximum number of cores to acquire + private val maxCores = maxCoresOption.getOrElse(Int.MaxValue) - val maxGpus = conf.getInt("spark.mesos.gpus.max", 0) + private val useFetcherCache = conf.getBoolean("spark.mesos.fetcherCache.enable", false) + + private val maxGpus = conf.getInt("spark.mesos.gpus.max", 0) private[this] val shutdownTimeoutMS = conf.getTimeAsMs("spark.mesos.coarse.shutdownTimeout", "10s") @@ -75,10 +78,10 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( private val shuffleServiceEnabled = conf.getBoolean("spark.shuffle.service.enabled", false) // Cores we have acquired with each Mesos task ID - val coresByTaskId = new mutable.HashMap[String, Int] - val gpusByTaskId = new mutable.HashMap[String, Int] - var totalCoresAcquired = 0 - var totalGpusAcquired = 0 + private val coresByTaskId = new mutable.HashMap[String, Int] + private val gpusByTaskId = new mutable.HashMap[String, Int] + private var totalCoresAcquired = 0 + private var totalGpusAcquired = 0 // SlaveID -> Slave // This map accumulates entries for the duration of the job. 
Slaves are never deleted, because @@ -108,7 +111,7 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( // may lead to deadlocks since the superclass might also try to lock private val stateLock = new ReentrantLock - val extraCoresPerExecutor = conf.getInt("spark.mesos.extra.cores", 0) + private val extraCoresPerExecutor = conf.getInt("spark.mesos.extra.cores", 0) // Offer constraints private val slaveOfferConstraints = @@ -139,7 +142,7 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( securityManager.isAuthenticationEnabled()) } - var nextMesosTaskId = 0 + private var nextMesosTaskId = 0 @volatile var appId: String = _ @@ -256,7 +259,7 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( } override def sufficientResourcesRegistered(): Boolean = { - totalCoresAcquired >= maxCores * minRegisteredRatio + totalCoreCount.get >= maxCoresOption.getOrElse(0) * minRegisteredRatio } override def disconnected(d: org.apache.mesos.SchedulerDriver) {} diff --git a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala index cdb3b68489654..78346e9744957 100644 --- a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala +++ b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala @@ -20,9 +20,7 @@ package org.apache.spark.scheduler.cluster.mesos import java.util.concurrent.TimeUnit import scala.collection.JavaConverters._ -import scala.collection.mutable.ArrayBuffer import scala.concurrent.duration._ -import scala.concurrent.Promise import scala.reflect.ClassTag import org.apache.mesos.{Protos, Scheduler, SchedulerDriver} @@ -37,8 +35,8 @@ import org.scalatest.BeforeAndAfter import org.apache.spark.{LocalSparkContext, SecurityManager, SparkConf, SparkContext, SparkFunSuite} import org.apache.spark.internal.config._ import org.apache.spark.network.shuffle.mesos.MesosExternalShuffleClient -import org.apache.spark.rpc.RpcEndpointRef -import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.RemoveExecutor +import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef} +import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.{RegisterExecutor, RemoveExecutor} import org.apache.spark.scheduler.TaskSchedulerImpl import org.apache.spark.scheduler.cluster.mesos.Utils._ @@ -304,25 +302,29 @@ class MesosCoarseGrainedSchedulerBackendSuite extends SparkFunSuite } test("weburi is set in created scheduler driver") { - setBackend() + initializeSparkConf() + sc = new SparkContext(sparkConf) + val taskScheduler = mock[TaskSchedulerImpl] when(taskScheduler.sc).thenReturn(sc) + val driver = mock[SchedulerDriver] when(driver.start()).thenReturn(Protos.Status.DRIVER_RUNNING) + val securityManager = mock[SecurityManager] val backend = new MesosCoarseGrainedSchedulerBackend( - taskScheduler, sc, "master", securityManager) { + taskScheduler, sc, "master", securityManager) { override protected def createSchedulerDriver( - masterUrl: String, - scheduler: Scheduler, - sparkUser: String, - appName: String, - conf: SparkConf, - webuiUrl: Option[String] = None, - checkpoint: Option[Boolean] = None, - failoverTimeout: Option[Double] = None, - frameworkId: Option[String] = None): SchedulerDriver = { + masterUrl: String, + scheduler: Scheduler, + 
sparkUser: String, + appName: String, + conf: SparkConf, + webuiUrl: Option[String] = None, + checkpoint: Option[Boolean] = None, + failoverTimeout: Option[Double] = None, + frameworkId: Option[String] = None): SchedulerDriver = { markRegistered() assert(webuiUrl.isDefined) assert(webuiUrl.get.equals("http://webui")) @@ -419,37 +421,11 @@ class MesosCoarseGrainedSchedulerBackendSuite extends SparkFunSuite assert(!dockerInfo.getForcePullImage) } - test("Do not call removeExecutor() after backend is stopped") { - setBackend() - - // launches a task on a valid offer - val offers = List(Resources(backend.executorMemory(sc), 1)) - offerResources(offers) - verifyTaskLaunched(driver, "o1") - - // launches a thread simulating status update - val statusUpdateThread = new Thread { - override def run(): Unit = { - while (!stopCalled) { - Thread.sleep(100) - } - - val status = createTaskStatus("0", "s1", TaskState.TASK_FINISHED) - backend.statusUpdate(driver, status) - } - }.start - - backend.stop() - // Any method of the backend involving sending messages to the driver endpoint should not - // be called after the backend is stopped. - verify(driverEndpoint, never()).askSync(isA(classOf[RemoveExecutor]))(any[ClassTag[_]]) - } - test("mesos supports spark.executor.uri") { val url = "spark.spark.spark.com" setBackend(Map( "spark.executor.uri" -> url - ), false) + ), null) val (mem, cpu) = (backend.executorMemory(sc), 4) @@ -465,7 +441,7 @@ class MesosCoarseGrainedSchedulerBackendSuite extends SparkFunSuite setBackend(Map( "spark.mesos.fetcherCache.enable" -> "true", "spark.executor.uri" -> url - ), false) + ), null) val offers = List(Resources(backend.executorMemory(sc), 1)) offerResources(offers) val launchedTasks = verifyTaskLaunched(driver, "o1") @@ -479,7 +455,7 @@ class MesosCoarseGrainedSchedulerBackendSuite extends SparkFunSuite setBackend(Map( "spark.mesos.fetcherCache.enable" -> "false", "spark.executor.uri" -> url - ), false) + ), null) val offers = List(Resources(backend.executorMemory(sc), 1)) offerResources(offers) val launchedTasks = verifyTaskLaunched(driver, "o1") @@ -504,8 +480,31 @@ class MesosCoarseGrainedSchedulerBackendSuite extends SparkFunSuite assert(networkInfos.get(0).getName == "test-network-name") } + test("supports spark.scheduler.minRegisteredResourcesRatio") { + val expectedCores = 1 + setBackend(Map( + "spark.cores.max" -> expectedCores.toString, + "spark.scheduler.minRegisteredResourcesRatio" -> "1.0")) + + val offers = List(Resources(backend.executorMemory(sc), expectedCores)) + offerResources(offers) + val launchedTasks = verifyTaskLaunched(driver, "o1") + assert(!backend.isReady) + + registerMockExecutor(launchedTasks(0).getTaskId.getValue, "s1", expectedCores) + assert(backend.isReady) + } + private case class Resources(mem: Int, cpus: Int, gpus: Int = 0) + private def registerMockExecutor(executorId: String, slaveId: String, cores: Integer) = { + val mockEndpointRef = mock[RpcEndpointRef] + val mockAddress = mock[RpcAddress] + val message = RegisterExecutor(executorId, mockEndpointRef, slaveId, cores, Map.empty) + + backend.driverEndpoint.askSync[Boolean](message) + } + private def verifyDeclinedOffer(driver: SchedulerDriver, offerId: OfferID, filter: Boolean = false): Unit = { @@ -534,8 +533,7 @@ class MesosCoarseGrainedSchedulerBackendSuite extends SparkFunSuite private def createSchedulerBackend( taskScheduler: TaskSchedulerImpl, driver: SchedulerDriver, - shuffleClient: MesosExternalShuffleClient, - endpoint: RpcEndpointRef): MesosCoarseGrainedSchedulerBackend = 
{ + shuffleClient: MesosExternalShuffleClient) = { val securityManager = mock[SecurityManager] val backend = new MesosCoarseGrainedSchedulerBackend( @@ -553,9 +551,6 @@ class MesosCoarseGrainedSchedulerBackendSuite extends SparkFunSuite override protected def getShuffleClient(): MesosExternalShuffleClient = shuffleClient - override protected def createDriverEndpointRef( - properties: ArrayBuffer[(String, String)]): RpcEndpointRef = endpoint - // override to avoid race condition with the driver thread on `mesosDriver` override def startScheduler(newDriver: SchedulerDriver): Unit = { mesosDriver = newDriver @@ -571,31 +566,35 @@ class MesosCoarseGrainedSchedulerBackendSuite extends SparkFunSuite backend } - private def setBackend(sparkConfVars: Map[String, String] = null, - setHome: Boolean = true) { + private def initializeSparkConf( + sparkConfVars: Map[String, String] = null, + home: String = "/path"): Unit = { sparkConf = (new SparkConf) .setMaster("local[*]") .setAppName("test-mesos-dynamic-alloc") .set("spark.mesos.driver.webui.url", "http://webui") - if (setHome) { - sparkConf.setSparkHome("/path") + if (home != null) { + sparkConf.setSparkHome(home) } if (sparkConfVars != null) { sparkConf.setAll(sparkConfVars) } + } + private def setBackend(sparkConfVars: Map[String, String] = null, home: String = "/path") { + initializeSparkConf(sparkConfVars, home) sc = new SparkContext(sparkConf) driver = mock[SchedulerDriver] when(driver.start()).thenReturn(Protos.Status.DRIVER_RUNNING) + taskScheduler = mock[TaskSchedulerImpl] when(taskScheduler.sc).thenReturn(sc) + externalShuffleClient = mock[MesosExternalShuffleClient] - driverEndpoint = mock[RpcEndpointRef] - when(driverEndpoint.ask(any())(any())).thenReturn(Promise().future) - backend = createSchedulerBackend(taskScheduler, driver, externalShuffleClient, driverEndpoint) + backend = createSchedulerBackend(taskScheduler, driver, externalShuffleClient) } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java index afea4676893ed..791e8d80e6cba 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java @@ -117,7 +117,7 @@ public void setNullShort(int ordinal) { public void setNullInt(int ordinal) { setNullBit(ordinal); // put zero into the corresponding field when set null - Platform.putInt(holder.buffer, getElementOffset(ordinal, 4), (int)0); + Platform.putInt(holder.buffer, getElementOffset(ordinal, 4), 0); } public void setNullLong(int ordinal) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index c477cb48d0b07..6d569b612de7d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -146,7 +146,7 @@ class Analyzer( GlobalAggregates :: ResolveAggregateFunctions :: TimeWindowing :: - ResolveInlineTables :: + ResolveInlineTables(conf) :: TypeCoercion.typeCoercionRules ++ extendedResolutionRules : _*), Batch("Post-Hoc Resolution", Once, postHocResolutionRules: _*), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 36ab8b8527b44..7529f9028498c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -18,10 +18,8 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.catalog.SimpleCatalogRelation import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression -import org.apache.spark.sql.catalyst.plans.UsingJoin import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.types._ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala index 7323197b10f6e..d5b3ea8c37c66 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala @@ -19,8 +19,8 @@ package org.apache.spark.sql.catalyst.analysis import scala.util.control.NonFatal -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.Cast +import org.apache.spark.sql.catalyst.{CatalystConf, InternalRow} +import org.apache.spark.sql.catalyst.expressions.{Cast, TimeZoneAwareExpression} import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.types.{StructField, StructType} @@ -28,7 +28,7 @@ import org.apache.spark.sql.types.{StructField, StructType} /** * An analyzer rule that replaces [[UnresolvedInlineTable]] with [[LocalRelation]]. 
*/ -object ResolveInlineTables extends Rule[LogicalPlan] { +case class ResolveInlineTables(conf: CatalystConf) extends Rule[LogicalPlan] { override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case table: UnresolvedInlineTable if table.expressionsResolved => validateInputDimension(table) @@ -95,11 +95,15 @@ object ResolveInlineTables extends Rule[LogicalPlan] { InternalRow.fromSeq(row.zipWithIndex.map { case (e, ci) => val targetType = fields(ci).dataType try { - if (e.dataType.sameType(targetType)) { - e.eval() + val castedExpr = if (e.dataType.sameType(targetType)) { + e } else { - Cast(e, targetType).eval() + Cast(e, targetType) } + castedExpr.transform { + case e: TimeZoneAwareExpression if e.timeZoneId.isEmpty => + e.withTimeZone(conf.sessionLocalTimeZone) + }.eval() } catch { case NonFatal(ex) => table.failAnalysis(s"failed to evaluate expression ${e.sql}: ${ex.getMessage}") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala index a3a4ab37ea714..31eded4deba7d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala @@ -244,11 +244,13 @@ abstract class ExternalCatalog { * @param db database name * @param table table name * @param predicates partition-pruning predicates + * @param defaultTimeZoneId default timezone id to parse partition values of TimestampType */ def listPartitionsByFilter( db: String, table: String, - predicates: Seq[Expression]): Seq[CatalogTablePartition] + predicates: Seq[Expression], + defaultTimeZoneId: String): Seq[CatalogTablePartition] // -------------------------------------------------------------------------- // Functions diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala index 6bb2b2d4ff72e..340e8451f14ee 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala @@ -544,7 +544,8 @@ class InMemoryCatalog( override def listPartitionsByFilter( db: String, table: String, - predicates: Seq[Expression]): Seq[CatalogTablePartition] = { + predicates: Seq[Expression], + defaultTimeZoneId: String): Seq[CatalogTablePartition] = { // TODO: Provide an implementation throw new UnsupportedOperationException( "listPartitionsByFilter is not implemented for InMemoryCatalog") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index 0230626a6644e..f6412e42c13d5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -592,7 +592,12 @@ class SessionCatalog( child = parser.parsePlan(viewText)) SubqueryAlias(table, child, Some(name.copy(table = table, database = Some(db)))) } else { - SubqueryAlias(table, SimpleCatalogRelation(metadata), None) + val tableRelation = CatalogRelation( + metadata, + // we assume all the columns are nullable. 
+ metadata.dataSchema.asNullable.toAttributes, + metadata.partitionSchema.asNullable.toAttributes) + SubqueryAlias(table, tableRelation, None) } } else { SubqueryAlias(table, tempTables(table), None) @@ -836,7 +841,7 @@ class SessionCatalog( val table = formatTableName(tableName.table) requireDbExists(db) requireTableExists(TableIdentifier(table, Option(db))) - externalCatalog.listPartitionsByFilter(db, table, predicates) + externalCatalog.listPartitionsByFilter(db, table, predicates, conf.sessionLocalTimeZone) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala index 2b3b575b4c06e..887caf07d1481 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala @@ -19,14 +19,15 @@ package org.apache.spark.sql.catalyst.catalog import java.util.Date -import scala.collection.mutable +import com.google.common.base.Objects import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow, TableIdentifier} +import org.apache.spark.sql.catalyst.{CatalystConf, FunctionIdentifier, InternalRow, TableIdentifier} +import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, Cast, Literal} import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} import org.apache.spark.sql.catalyst.util.quoteIdentifier -import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types.StructType @@ -112,11 +113,11 @@ case class CatalogTablePartition( /** * Given the partition schema, returns a row with that schema holding the partition values. */ - def toRow(partitionSchema: StructType): InternalRow = { + def toRow(partitionSchema: StructType, defaultTimeZondId: String): InternalRow = { + val caseInsensitiveProperties = CaseInsensitiveMap(storage.properties) + val timeZoneId = caseInsensitiveProperties.getOrElse("timeZone", defaultTimeZondId) InternalRow.fromSeq(partitionSchema.map { field => - // TODO: use correct timezone for partition values. - Cast(Literal(spec(field.name)), field.dataType, - Option(DateTimeUtils.defaultTimeZone().getID)).eval() + Cast(Literal(spec(field.name)), field.dataType, Option(timeZoneId)).eval() }) } } @@ -349,36 +350,43 @@ object CatalogTypes { /** - * An interface that is implemented by logical plans to return the underlying catalog table. - * If we can in the future consolidate SimpleCatalogRelation and MetastoreRelation, we should - * probably remove this interface. + * A [[LogicalPlan]] that represents a table. */ -trait CatalogRelation { - def catalogTable: CatalogTable - def output: Seq[Attribute] -} +case class CatalogRelation( + tableMeta: CatalogTable, + dataCols: Seq[Attribute], + partitionCols: Seq[Attribute]) extends LeafNode with MultiInstanceRelation { + assert(tableMeta.identifier.database.isDefined) + assert(tableMeta.partitionSchema.sameType(partitionCols.toStructType)) + assert(tableMeta.dataSchema.sameType(dataCols.toStructType)) + + // The partition column should always appear after data columns. 
+ override def output: Seq[Attribute] = dataCols ++ partitionCols + + def isPartitioned: Boolean = partitionCols.nonEmpty + + override def equals(relation: Any): Boolean = relation match { + case other: CatalogRelation => tableMeta == other.tableMeta && output == other.output + case _ => false + } + override def hashCode(): Int = { + Objects.hashCode(tableMeta.identifier, output) + } -/** - * A [[LogicalPlan]] that wraps [[CatalogTable]]. - * - * Note that in the future we should consolidate this and HiveCatalogRelation. - */ -case class SimpleCatalogRelation( - metadata: CatalogTable) - extends LeafNode with CatalogRelation { - - override def catalogTable: CatalogTable = metadata - - override lazy val resolved: Boolean = false - - override val output: Seq[Attribute] = { - val (partCols, dataCols) = metadata.schema.toAttributes - // Since data can be dumped in randomly with no validation, everything is nullable. - .map(_.withNullability(true).withQualifier(Some(metadata.identifier.table))) - .partition { a => - metadata.partitionColumnNames.contains(a.name) - } - dataCols ++ partCols + /** Only compare table identifier. */ + override lazy val cleanArgs: Seq[Any] = Seq(tableMeta.identifier) + + override def computeStats(conf: CatalystConf): Statistics = { + // For data source tables, we will create a `LogicalRelation` and won't call this method, for + // hive serde tables, we will always generate a statistics. + // TODO: unify the table stats generation. + tableMeta.stats.map(_.toPlanStats(output)).getOrElse { + throw new IllegalStateException("table stats must be specified.") + } } + + override def newInstance(): LogicalPlan = copy( + dataCols = dataCols.map(_.newInstance()), + partitionCols = partitionCols.map(_.newInstance())) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala index 1504a522798b0..9f4a0f2b7017a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala @@ -28,7 +28,7 @@ object AttributeMap { } } -class AttributeMap[A](baseMap: Map[ExprId, (Attribute, A)]) +class AttributeMap[A](val baseMap: Map[ExprId, (Attribute, A)]) extends Map[Attribute, A] with Serializable { override def get(k: Attribute): Option[A] = baseMap.get(k.exprId).map(_._2) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/identifiers.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/identifiers.scala index 26697e9867b35..a3cc4529b5456 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/identifiers.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/identifiers.scala @@ -63,7 +63,9 @@ case class TableIdentifier(table: String, database: Option[String]) } /** A fully qualified identifier for a table (i.e., database.tableName) */ -case class QualifiedTableName(database: String, name: String) +case class QualifiedTableName(database: String, name: String) { + override def toString: String = s"$database.$name" +} object TableIdentifier { def apply(tableName: String): TableIdentifier = new TableIdentifier(tableName) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 4f593c894acd2..21d1cd5932620 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -457,7 +457,7 @@ object FoldablePropagation extends Rule[LogicalPlan] { // join is not always picked from its children, but can also be null. // TODO(cloud-fan): It seems more reasonable to use new attributes as the output attributes // of outer join. - case j @ Join(_, _, Inner, _) => + case j @ Join(_, _, Inner, _) if !stop => j.transformExpressions(replaceFoldable) // We can fold the projections an expand holds. However expand changes the output columns diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala index 37f29ba68a206..0c928832d7d22 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala @@ -17,9 +17,7 @@ package org.apache.spark.sql.catalyst.plans.logical.statsEstimation -import java.sql.{Date, Timestamp} - -import scala.collection.immutable.{HashSet, Map} +import scala.collection.immutable.HashSet import scala.collection.mutable import org.apache.spark.internal.Logging @@ -31,15 +29,16 @@ import org.apache.spark.sql.types._ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Logging { + private val childStats = plan.child.stats(catalystConf) + /** - * We use a mutable colStats because we need to update the corresponding ColumnStat - * for a column after we apply a predicate condition. For example, column c has - * [min, max] value as [0, 100]. In a range condition such as (c > 40 AND c <= 50), - * we need to set the column's [min, max] value to [40, 100] after we evaluate the - * first condition c > 40. We need to set the column's [min, max] value to [40, 50] + * We will update the corresponding ColumnStats for a column after we apply a predicate condition. + * For example, column c has [min, max] value as [0, 100]. In a range condition such as + * (c > 40 AND c <= 50), we need to set the column's [min, max] value to [40, 100] after we + * evaluate the first condition c > 40. We need to set the column's [min, max] value to [40, 50] * after we evaluate the second condition c <= 50. */ - private var mutableColStats: mutable.Map[ExprId, ColumnStat] = mutable.Map.empty + private val colStatsMap = new ColumnStatsMap /** * Returns an option of Statistics for a Filter logical plan node. @@ -51,12 +50,10 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo * @return Option[Statistics] When there is no statistics collected, it returns None. */ def estimate: Option[Statistics] = { - // We first copy child node's statistics and then modify it based on filter selectivity. 
- val stats: Statistics = plan.child.stats(catalystConf) - if (stats.rowCount.isEmpty) return None + if (childStats.rowCount.isEmpty) return None // save a mutable copy of colStats so that we can later change it recursively - mutableColStats = mutable.Map(stats.attributeStats.map(kv => (kv._1.exprId, kv._2)).toSeq: _*) + colStatsMap.setInitValues(childStats.attributeStats) // estimate selectivity of this filter predicate val filterSelectivity: Double = calculateFilterSelectivity(plan.condition) match { @@ -65,22 +62,14 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo case None => 1.0 } - // attributeStats has mapping Attribute-to-ColumnStat. - // mutableColStats has mapping ExprId-to-ColumnStat. - // We use an ExprId-to-Attribute map to facilitate the mapping Attribute-to-ColumnStat - val expridToAttrMap: Map[ExprId, Attribute] = - stats.attributeStats.map(kv => (kv._1.exprId, kv._1)) - // copy mutableColStats contents to an immutable AttributeMap. - val mutableAttributeStats: mutable.Map[Attribute, ColumnStat] = - mutableColStats.map(kv => expridToAttrMap(kv._1) -> kv._2) - val newColStats = AttributeMap(mutableAttributeStats.toSeq) + val newColStats = colStatsMap.toColumnStats val filteredRowCount: BigInt = - EstimationUtils.ceil(BigDecimal(stats.rowCount.get) * filterSelectivity) - val filteredSizeInBytes = + EstimationUtils.ceil(BigDecimal(childStats.rowCount.get) * filterSelectivity) + val filteredSizeInBytes: BigInt = EstimationUtils.getOutputSize(plan.output, filteredRowCount, newColStats) - Some(stats.copy(sizeInBytes = filteredSizeInBytes, rowCount = Some(filteredRowCount), + Some(childStats.copy(sizeInBytes = filteredSizeInBytes, rowCount = Some(filteredRowCount), attributeStats = newColStats)) } @@ -95,15 +84,16 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo * @param condition the compound logical expression * @param update a boolean flag to specify if we need to update ColumnStat of a column * for subsequent conditions - * @return a double value to show the percentage of rows meeting a given condition. + * @return an optional double value to show the percentage of rows meeting a given condition. * It returns None if the condition is not supported. */ def calculateFilterSelectivity(condition: Expression, update: Boolean = true): Option[Double] = { - condition match { case And(cond1, cond2) => - (calculateFilterSelectivity(cond1, update), calculateFilterSelectivity(cond2, update)) - match { + // For ease of debugging, we compute percent1 and percent2 in 2 statements. + val percent1 = calculateFilterSelectivity(cond1, update) + val percent2 = calculateFilterSelectivity(cond2, update) + (percent1, percent2) match { case (Some(p1), Some(p2)) => Some(p1 * p2) case (Some(p1), None) => Some(p1) case (None, Some(p2)) => Some(p2) @@ -127,8 +117,7 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo case None => None } - case _ => - calculateSingleCondition(condition, update) + case _ => calculateSingleCondition(condition, update) } } @@ -140,7 +129,7 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo * @param condition a single logical expression * @param update a boolean flag to specify if we need to update ColumnStat of a column * for subsequent conditions - * @return Option[Double] value to show the percentage of rows meeting a given condition. + * @return an optional double value to show the percentage of rows meeting a given condition. 
* It returns None if the condition is not supported. */ def calculateSingleCondition(condition: Expression, update: Boolean): Option[Double] = { @@ -148,33 +137,33 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo // For evaluateBinary method, we assume the literal on the right side of an operator. // So we will change the order if not. - // EqualTo does not care about the order - case op @ EqualTo(ar: AttributeReference, l: Literal) => - evaluateBinary(op, ar, l, update) - case op @ EqualTo(l: Literal, ar: AttributeReference) => - evaluateBinary(op, ar, l, update) + // EqualTo/EqualNullSafe does not care about the order + case op @ Equality(ar: Attribute, l: Literal) => + evaluateEquality(ar, l, update) + case op @ Equality(l: Literal, ar: Attribute) => + evaluateEquality(ar, l, update) - case op @ LessThan(ar: AttributeReference, l: Literal) => + case op @ LessThan(ar: Attribute, l: Literal) => evaluateBinary(op, ar, l, update) - case op @ LessThan(l: Literal, ar: AttributeReference) => + case op @ LessThan(l: Literal, ar: Attribute) => evaluateBinary(GreaterThan(ar, l), ar, l, update) - case op @ LessThanOrEqual(ar: AttributeReference, l: Literal) => + case op @ LessThanOrEqual(ar: Attribute, l: Literal) => evaluateBinary(op, ar, l, update) - case op @ LessThanOrEqual(l: Literal, ar: AttributeReference) => + case op @ LessThanOrEqual(l: Literal, ar: Attribute) => evaluateBinary(GreaterThanOrEqual(ar, l), ar, l, update) - case op @ GreaterThan(ar: AttributeReference, l: Literal) => + case op @ GreaterThan(ar: Attribute, l: Literal) => evaluateBinary(op, ar, l, update) - case op @ GreaterThan(l: Literal, ar: AttributeReference) => + case op @ GreaterThan(l: Literal, ar: Attribute) => evaluateBinary(LessThan(ar, l), ar, l, update) - case op @ GreaterThanOrEqual(ar: AttributeReference, l: Literal) => + case op @ GreaterThanOrEqual(ar: Attribute, l: Literal) => evaluateBinary(op, ar, l, update) - case op @ GreaterThanOrEqual(l: Literal, ar: AttributeReference) => + case op @ GreaterThanOrEqual(l: Literal, ar: Attribute) => evaluateBinary(LessThanOrEqual(ar, l), ar, l, update) - case In(ar: AttributeReference, expList) + case In(ar: Attribute, expList) if expList.forall(e => e.isInstanceOf[Literal]) => // Expression [In (value, seq[Literal])] will be replaced with optimized version // [InSet (value, HashSet[Literal])] in Optimizer, but only for list.size > 10. @@ -182,14 +171,14 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo val hSet = expList.map(e => e.eval()) evaluateInSet(ar, HashSet() ++ hSet, update) - case InSet(ar: AttributeReference, set) => + case InSet(ar: Attribute, set) => evaluateInSet(ar, set, update) - case IsNull(ar: AttributeReference) => - evaluateIsNull(ar, isNull = true, update) + case IsNull(ar: Attribute) => + evaluateNullCheck(ar, isNull = true, update) - case IsNotNull(ar: AttributeReference) => - evaluateIsNull(ar, isNull = false, update) + case IsNotNull(ar: Attribute) => + evaluateNullCheck(ar, isNull = false, update) case _ => // TODO: it's difficult to support string operators without advanced statistics. @@ -203,44 +192,43 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo /** * Returns a percentage of rows meeting "IS NULL" or "IS NOT NULL" condition. * - * @param attrRef an AttributeReference (or a column) + * @param attr an Attribute (or a column) * @param isNull set to true for "IS NULL" condition. 
set to false for "IS NOT NULL" condition * @param update a boolean flag to specify if we need to update ColumnStat of a given column * for subsequent conditions * @return an optional double value to show the percentage of rows meeting a given condition * It returns None if no statistics collected for a given column. */ - def evaluateIsNull( - attrRef: AttributeReference, + def evaluateNullCheck( + attr: Attribute, isNull: Boolean, - update: Boolean) - : Option[Double] = { - if (!mutableColStats.contains(attrRef.exprId)) { - logDebug("[CBO] No statistics for " + attrRef) + update: Boolean): Option[Double] = { + if (!colStatsMap.contains(attr)) { + logDebug("[CBO] No statistics for " + attr) return None } - val aColStat = mutableColStats(attrRef.exprId) - val rowCountValue = plan.child.stats(catalystConf).rowCount.get - val nullPercent: BigDecimal = - if (rowCountValue == 0) 0.0 - else BigDecimal(aColStat.nullCount) / BigDecimal(rowCountValue) + val colStat = colStatsMap(attr) + val rowCountValue = childStats.rowCount.get + val nullPercent: BigDecimal = if (rowCountValue == 0) { + 0 + } else { + BigDecimal(colStat.nullCount) / BigDecimal(rowCountValue) + } if (update) { - val newStats = - if (isNull) aColStat.copy(distinctCount = 0, min = None, max = None) - else aColStat.copy(nullCount = 0) - - mutableColStats += (attrRef.exprId -> newStats) + val newStats = if (isNull) { + colStat.copy(distinctCount = 0, min = None, max = None) + } else { + colStat.copy(nullCount = 0) + } + colStatsMap(attr) = newStats } - val percent = - if (isNull) { - nullPercent.toDouble - } - else { - /** ISNOTNULL(column) */ - 1.0 - nullPercent.toDouble - } + val percent = if (isNull) { + nullPercent.toDouble + } else { + 1.0 - nullPercent.toDouble + } Some(percent) } @@ -249,7 +237,7 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo * Returns a percentage of rows meeting a binary comparison expression. * * @param op a binary comparison operator uch as =, <, <=, >, >= - * @param attrRef an AttributeReference (or a column) + * @param attr an Attribute (or a column) * @param literal a literal value (or constant) * @param update a boolean flag to specify if we need to update ColumnStat of a given column * for subsequent conditions @@ -258,27 +246,20 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo */ def evaluateBinary( op: BinaryComparison, - attrRef: AttributeReference, + attr: Attribute, literal: Literal, - update: Boolean) - : Option[Double] = { - if (!mutableColStats.contains(attrRef.exprId)) { - logDebug("[CBO] No statistics for " + attrRef) - return None - } - - op match { - case EqualTo(l, r) => evaluateEqualTo(attrRef, literal, update) + update: Boolean): Option[Double] = { + attr.dataType match { + case _: NumericType | DateType | TimestampType => + evaluateBinaryForNumeric(op, attr, literal, update) + case StringType | BinaryType => + // TODO: It is difficult to support other binary comparisons for String/Binary + // type without min/max and advanced statistics like histogram. + logDebug("[CBO] No range comparison statistics for String/Binary type " + attr) + None case _ => - attrRef.dataType match { - case _: NumericType | DateType | TimestampType => - evaluateBinaryForNumeric(op, attrRef, literal, update) - case StringType | BinaryType => - // TODO: It is difficult to support other binary comparisons for String/Binary - // type without min/max and advanced statistics like histogram. 
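Backing up to `evaluateNullCheck` above, the following is a minimal numeric sketch of the selectivity it returns, assuming a 1000-row child whose column has 40 nulls (the counts are made up for illustration):

```
// IS NULL selects the null fraction; IS NOT NULL selects the complement.
val rowCount  = BigDecimal(1000)
val nullCount = BigDecimal(40)
val nullPercent = if (rowCount == 0) BigDecimal(0) else nullCount / rowCount

val isNullSelectivity    = nullPercent.toDouble        // 0.04
val isNotNullSelectivity = 1.0 - nullPercent.toDouble  // 0.96
```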
- logDebug("[CBO] No range comparison statistics for String/Binary type " + attrRef) - None - } + // TODO: support boolean type. + None } } @@ -297,6 +278,8 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo Some(DateTimeUtils.toJavaDate(litValue.toString.toInt)) case TimestampType => Some(DateTimeUtils.toJavaTimestamp(litValue.toString.toLong)) + case _: DecimalType => + Some(litValue.asInstanceOf[Decimal].toJavaBigDecimal) case StringType | BinaryType => None case _ => @@ -308,37 +291,36 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo * Returns a percentage of rows meeting an equality (=) expression. * This method evaluates the equality predicate for all data types. * - * @param attrRef an AttributeReference (or a column) + * @param attr an Attribute (or a column) * @param literal a literal value (or constant) * @param update a boolean flag to specify if we need to update ColumnStat of a given column * for subsequent conditions * @return an optional double value to show the percentage of rows meeting a given condition */ - def evaluateEqualTo( - attrRef: AttributeReference, + def evaluateEquality( + attr: Attribute, literal: Literal, - update: Boolean) - : Option[Double] = { - - val aColStat = mutableColStats(attrRef.exprId) - val ndv = aColStat.distinctCount + update: Boolean): Option[Double] = { + if (!colStatsMap.contains(attr)) { + logDebug("[CBO] No statistics for " + attr) + return None + } + val colStat = colStatsMap(attr) + val ndv = colStat.distinctCount // decide if the value is in [min, max] of the column. // We currently don't store min/max for binary/string type. // Hence, we assume it is in boundary for binary/string type. - val statsRange = Range(aColStat.min, aColStat.max, attrRef.dataType) - val inBoundary: Boolean = Range.rangeContainsLiteral(statsRange, literal) - - if (inBoundary) { - + val statsRange = Range(colStat.min, colStat.max, attr.dataType) + if (statsRange.contains(literal)) { if (update) { // We update ColumnStat structure after apply this equality predicate. // Set distinctCount to 1. Set nullCount to 0. // Need to save new min/max using the external type value of the literal - val newValue = convertBoundValue(attrRef.dataType, literal.value) - val newStats = aColStat.copy(distinctCount = 1, min = newValue, + val newValue = convertBoundValue(attr.dataType, literal.value) + val newStats = colStat.copy(distinctCount = 1, min = newValue, max = newValue, nullCount = 0) - mutableColStats += (attrRef.exprId -> newStats) + colStatsMap(attr) = newStats } Some(1.0 / ndv.toDouble) @@ -352,7 +334,7 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo * Returns a percentage of rows meeting "IN" operator expression. * This method evaluates the equality predicate for all data types. 
* - * @param attrRef an AttributeReference (or a column) + * @param attr an Attribute (or a column) * @param hSet a set of literal values * @param update a boolean flag to specify if we need to update ColumnStat of a given column * for subsequent conditions @@ -361,57 +343,52 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo */ def evaluateInSet( - attrRef: AttributeReference, + attr: Attribute, hSet: Set[Any], - update: Boolean) - : Option[Double] = { - if (!mutableColStats.contains(attrRef.exprId)) { - logDebug("[CBO] No statistics for " + attrRef) + update: Boolean): Option[Double] = { + if (!colStatsMap.contains(attr)) { + logDebug("[CBO] No statistics for " + attr) return None } - val aColStat = mutableColStats(attrRef.exprId) - val ndv = aColStat.distinctCount - val aType = attrRef.dataType - var newNdv: Long = 0 + val colStat = colStatsMap(attr) + val ndv = colStat.distinctCount + val dataType = attr.dataType + var newNdv = ndv // use [min, max] to filter the original hSet - aType match { - case _: NumericType | DateType | TimestampType => - val statsRange = - Range(aColStat.min, aColStat.max, aType).asInstanceOf[NumericRange] - - // To facilitate finding the min and max values in hSet, we map hSet values to BigDecimal. - // Using hSetBigdec, we can find the min and max values quickly in the ordered hSetBigdec. - val hSetBigdec = hSet.map(e => BigDecimal(e.toString)) - val validQuerySet = hSetBigdec.filter(e => e >= statsRange.min && e <= statsRange.max) - // We use hSetBigdecToAnyMap to help us find the original hSet value. - val hSetBigdecToAnyMap: Map[BigDecimal, Any] = - hSet.map(e => BigDecimal(e.toString) -> e).toMap + dataType match { + case _: NumericType | BooleanType | DateType | TimestampType => + val statsRange = Range(colStat.min, colStat.max, dataType).asInstanceOf[NumericRange] + val validQuerySet = hSet.filter { v => + v != null && statsRange.contains(Literal(v, dataType)) + } if (validQuerySet.isEmpty) { return Some(0.0) } // Need to save new min/max using the external type value of the literal - val newMax = convertBoundValue(attrRef.dataType, hSetBigdecToAnyMap(validQuerySet.max)) - val newMin = convertBoundValue(attrRef.dataType, hSetBigdecToAnyMap(validQuerySet.min)) + val newMax = convertBoundValue( + attr.dataType, validQuerySet.maxBy(v => BigDecimal(v.toString))) + val newMin = convertBoundValue( + attr.dataType, validQuerySet.minBy(v => BigDecimal(v.toString))) // newNdv should not be greater than the old ndv. For example, column has only 2 values // 1 and 6. The predicate column IN (1, 2, 3, 4, 5). validQuerySet.size is 5. 
- newNdv = math.min(validQuerySet.size.toLong, ndv.longValue()) + newNdv = ndv.min(BigInt(validQuerySet.size)) if (update) { - val newStats = aColStat.copy(distinctCount = newNdv, min = newMin, + val newStats = colStat.copy(distinctCount = newNdv, min = newMin, max = newMax, nullCount = 0) - mutableColStats += (attrRef.exprId -> newStats) + colStatsMap(attr) = newStats } // We assume the whole set since there is no min/max information for String/Binary type case StringType | BinaryType => - newNdv = math.min(hSet.size.toLong, ndv.longValue()) + newNdv = ndv.min(BigInt(hSet.size)) if (update) { - val newStats = aColStat.copy(distinctCount = newNdv, nullCount = 0) - mutableColStats += (attrRef.exprId -> newStats) + val newStats = colStat.copy(distinctCount = newNdv, nullCount = 0) + colStatsMap(attr) = newStats } } @@ -425,7 +402,7 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo * This method evaluate expression for Numeric columns only. * * @param op a binary comparison operator uch as =, <, <=, >, >= - * @param attrRef an AttributeReference (or a column) + * @param attr an Attribute (or a column) * @param literal a literal value (or constant) * @param update a boolean flag to specify if we need to update ColumnStat of a given column * for subsequent conditions @@ -433,16 +410,14 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo */ def evaluateBinaryForNumeric( op: BinaryComparison, - attrRef: AttributeReference, + attr: Attribute, literal: Literal, - update: Boolean) - : Option[Double] = { + update: Boolean): Option[Double] = { var percent = 1.0 - val aColStat = mutableColStats(attrRef.exprId) - val ndv = aColStat.distinctCount + val colStat = colStatsMap(attr) val statsRange = - Range(aColStat.min, aColStat.max, attrRef.dataType).asInstanceOf[NumericRange] + Range(colStat.min, colStat.max, attr.dataType).asInstanceOf[NumericRange] // determine the overlapping degree between predicate range and column's range val literalValueBD = BigDecimal(literal.value.toString) @@ -463,33 +438,37 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo percent = 1.0 } else { // this is partial overlap case - var newMax = aColStat.max - var newMin = aColStat.min - var newNdv = ndv - val literalToDouble = literalValueBD.toDouble - val maxToDouble = BigDecimal(statsRange.max).toDouble - val minToDouble = BigDecimal(statsRange.min).toDouble + val literalDouble = literalValueBD.toDouble + val maxDouble = BigDecimal(statsRange.max).toDouble + val minDouble = BigDecimal(statsRange.min).toDouble // Without advanced statistics like histogram, we assume uniform data distribution. // We just prorate the adjusted range over the initial range to compute filter selectivity. // For ease of computation, we convert all relevant numeric values to Double. 
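// Worked example with illustrative numbers: for a column with [min, max] = [1, 10] and the
// predicate c < 3, the prorated selectivity is (3 - 1) / (10 - 1), roughly 0.22, so a 10-row
// child is estimated at ceil(10 * 0.22) = 3 rows, matching the "cint < 3" test case below.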
percent = op match { case _: LessThan => - (literalToDouble - minToDouble) / (maxToDouble - minToDouble) + (literalDouble - minDouble) / (maxDouble - minDouble) case _: LessThanOrEqual => - if (literalValueBD == BigDecimal(statsRange.min)) 1.0 / ndv.toDouble - else (literalToDouble - minToDouble) / (maxToDouble - minToDouble) + if (literalValueBD == BigDecimal(statsRange.min)) { + 1.0 / colStat.distinctCount.toDouble + } else { + (literalDouble - minDouble) / (maxDouble - minDouble) + } case _: GreaterThan => - (maxToDouble - literalToDouble) / (maxToDouble - minToDouble) + (maxDouble - literalDouble) / (maxDouble - minDouble) case _: GreaterThanOrEqual => - if (literalValueBD == BigDecimal(statsRange.max)) 1.0 / ndv.toDouble - else (maxToDouble - literalToDouble) / (maxToDouble - minToDouble) + if (literalValueBD == BigDecimal(statsRange.max)) { + 1.0 / colStat.distinctCount.toDouble + } else { + (maxDouble - literalDouble) / (maxDouble - minDouble) + } } - // Need to save new min/max using the external type value of the literal - val newValue = convertBoundValue(attrRef.dataType, literal.value) - if (update) { + // Need to save new min/max using the external type value of the literal + val newValue = convertBoundValue(attr.dataType, literal.value) + var newMax = colStat.max + var newMin = colStat.min op match { case _: GreaterThan => newMin = newValue case _: GreaterThanOrEqual => newMin = newValue @@ -497,11 +476,11 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo case _: LessThanOrEqual => newMax = newValue } - newNdv = math.max(math.round(ndv.toDouble * percent), 1) - val newStats = aColStat.copy(distinctCount = newNdv, min = newMin, + val newNdv = math.max(math.round(colStat.distinctCount.toDouble * percent), 1) + val newStats = colStat.copy(distinctCount = newNdv, min = newMin, max = newMax, nullCount = 0) - mutableColStats += (attrRef.exprId -> newStats) + colStatsMap(attr) = newStats } } @@ -509,3 +488,20 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo } } + +class ColumnStatsMap { + private val baseMap: mutable.Map[ExprId, (Attribute, ColumnStat)] = mutable.HashMap.empty + + def setInitValues(colStats: AttributeMap[ColumnStat]): Unit = { + baseMap.clear() + baseMap ++= colStats.baseMap + } + + def contains(a: Attribute): Boolean = baseMap.contains(a.exprId) + + def apply(a: Attribute): ColumnStat = baseMap(a.exprId)._2 + + def update(a: Attribute, stats: ColumnStat): Unit = baseMap.update(a.exprId, a -> stats) + + def toColumnStats: AttributeMap[ColumnStat] = AttributeMap(baseMap.values.toSeq) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala index 982a5a8bb89be..9782c0bb0a939 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala @@ -59,7 +59,7 @@ case class InnerOuterEstimation(conf: CatalystConf, join: Join) extends Logging case _ if !rowCountsExist(conf, join.left, join.right) => None - case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) => + case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, _, _, _) => // 1. 
Compute join selectivity val joinKeyPairs = extractJoinKeysWithColStats(leftKeys, rightKeys) val selectivity = joinSelectivity(joinKeyPairs) @@ -94,9 +94,9 @@ case class InnerOuterEstimation(conf: CatalystConf, join: Join) extends Logging val outputStats: Seq[(Attribute, ColumnStat)] = if (outputRows == 0) { // The output is empty, we don't need to keep column stats. Nil - } else if (innerJoinedRows == 0) { + } else if (selectivity == 0) { joinType match { - // For outer joins, if the inner join part is empty, the number of output rows is the + // For outer joins, if the join selectivity is 0, the number of output rows is the // same as that of the outer side. And column stats of join keys from the outer side // keep unchanged, while column stats of join keys from the other side should be updated // based on added null values. @@ -116,6 +116,9 @@ case class InnerOuterEstimation(conf: CatalystConf, join: Join) extends Logging } case _ => Nil } + } else if (selectivity == 1) { + // Cartesian product, just propagate the original column stats + inputAttrStats.toSeq } else { val joinKeyStats = getIntersectedStats(joinKeyPairs) join.joinType match { @@ -138,8 +141,7 @@ case class InnerOuterEstimation(conf: CatalystConf, join: Join) extends Logging Some(Statistics( sizeInBytes = getOutputSize(join.output, outputRows, outputAttrStats), rowCount = Some(outputRows), - attributeStats = outputAttrStats, - isBroadcastable = false)) + attributeStats = outputAttrStats)) case _ => // When there is no equi-join condition, we do estimation like cartesian product. @@ -150,8 +152,7 @@ case class InnerOuterEstimation(conf: CatalystConf, join: Join) extends Logging Some(Statistics( sizeInBytes = getOutputSize(join.output, outputRows, inputAttrStats), rowCount = Some(outputRows), - attributeStats = inputAttrStats, - isBroadcastable = false)) + attributeStats = inputAttrStats)) } // scalastyle:off @@ -189,8 +190,7 @@ case class InnerOuterEstimation(conf: CatalystConf, join: Join) extends Logging } if (ndvDenom < 0) { - // There isn't join keys or column stats for any of the join key pairs, we do estimation like - // cartesian product. + // We can't find any join key pairs with column stats, estimate it as cartesian join. 1 } else if (ndvDenom == 0) { // One of the join key pairs is disjoint, thus the two sides of join is disjoint. @@ -202,9 +202,6 @@ case class InnerOuterEstimation(conf: CatalystConf, join: Join) extends Logging /** * Propagate or update column stats for output attributes. - * 1. For cartesian product, all values are preserved, so there's no need to change column stats. - * 2. For other cases, a) update max/min of join keys based on their intersected range. b) update - * distinct count of other attributes based on output rows after join. 
*/ private def updateAttrStats( outputRows: BigInt, @@ -214,35 +211,38 @@ case class InnerOuterEstimation(conf: CatalystConf, join: Join) extends Logging val outputAttrStats = new ArrayBuffer[(Attribute, ColumnStat)]() val leftRows = leftStats.rowCount.get val rightRows = rightStats.rowCount.get - if (outputRows == leftRows * rightRows) { - // Cartesian product, just propagate the original column stats - attributes.foreach(a => outputAttrStats += a -> oldAttrStats(a)) - } else { - val leftRatio = - if (leftRows != 0) BigDecimal(outputRows) / BigDecimal(leftRows) else BigDecimal(0) - val rightRatio = - if (rightRows != 0) BigDecimal(outputRows) / BigDecimal(rightRows) else BigDecimal(0) - attributes.foreach { a => - // check if this attribute is a join key - if (joinKeyStats.contains(a)) { - outputAttrStats += a -> joinKeyStats(a) + + attributes.foreach { a => + // check if this attribute is a join key + if (joinKeyStats.contains(a)) { + outputAttrStats += a -> joinKeyStats(a) + } else { + val leftRatio = if (leftRows != 0) { + BigDecimal(outputRows) / BigDecimal(leftRows) + } else { + BigDecimal(0) + } + val rightRatio = if (rightRows != 0) { + BigDecimal(outputRows) / BigDecimal(rightRows) } else { - val oldColStat = oldAttrStats(a) - val oldNdv = oldColStat.distinctCount - // We only change (scale down) the number of distinct values if the number of rows - // decreases after join, because join won't produce new values even if the number of - // rows increases. - val newNdv = if (join.left.outputSet.contains(a) && leftRatio < 1) { - ceil(BigDecimal(oldNdv) * leftRatio) - } else if (join.right.outputSet.contains(a) && rightRatio < 1) { - ceil(BigDecimal(oldNdv) * rightRatio) - } else { - oldNdv - } - // TODO: support nullCount updates for specific outer joins - outputAttrStats += a -> oldColStat.copy(distinctCount = newNdv) + BigDecimal(0) } + val oldColStat = oldAttrStats(a) + val oldNdv = oldColStat.distinctCount + // We only change (scale down) the number of distinct values if the number of rows + // decreases after join, because join won't produce new values even if the number of + // rows increases. 
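// Worked example with illustrative numbers: if the left child has 1000 rows, the join outputs
// 250 rows, and a left-side attribute has oldNdv = 100, then leftRatio = 0.25 and the scaled
// distinct count is ceil(100 * 0.25) = 25, as computed below.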
+ val newNdv = if (join.left.outputSet.contains(a) && leftRatio < 1) { + ceil(BigDecimal(oldNdv) * leftRatio) + } else if (join.right.outputSet.contains(a) && rightRatio < 1) { + ceil(BigDecimal(oldNdv) * rightRatio) + } else { + oldNdv + } + // TODO: support nullCount updates for specific outer joins + outputAttrStats += a -> oldColStat.copy(distinctCount = newNdv) } + } outputAttrStats } @@ -263,12 +263,14 @@ case class InnerOuterEstimation(conf: CatalystConf, join: Join) extends Logging // Update intersected column stats assert(leftKey.dataType.sameType(rightKey.dataType)) - val minNdv = leftKeyStats.distinctCount.min(rightKeyStats.distinctCount) + val newNdv = leftKeyStats.distinctCount.min(rightKeyStats.distinctCount) val (newMin, newMax) = Range.intersect(lRange, rRange, leftKey.dataType) - intersectedStats.put(leftKey, - leftKeyStats.copy(distinctCount = minNdv, min = newMin, max = newMax, nullCount = 0)) - intersectedStats.put(rightKey, - rightKeyStats.copy(distinctCount = minNdv, min = newMin, max = newMax, nullCount = 0)) + val newMaxLen = math.min(leftKeyStats.maxLen, rightKeyStats.maxLen) + val newAvgLen = (leftKeyStats.avgLen + rightKeyStats.avgLen) / 2 + val newStats = ColumnStat(newNdv, newMin, newMax, 0, newAvgLen, newMaxLen) + + intersectedStats.put(leftKey, newStats) + intersectedStats.put(rightKey, newStats) } AttributeMap(intersectedStats.toSeq) } @@ -298,8 +300,7 @@ case class LeftSemiAntiEstimation(conf: CatalystConf, join: Join) { Some(Statistics( sizeInBytes = getOutputSize(join.output, outputRows, leftStats.attributeStats), rowCount = Some(outputRows), - attributeStats = leftStats.attributeStats, - isBroadcastable = false)) + attributeStats = leftStats.attributeStats)) } else { None } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/Range.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/Range.scala index 455711453272d..3d13967cb62a4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/Range.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/Range.scala @@ -26,19 +26,33 @@ import org.apache.spark.sql.types.{BooleanType, DateType, TimestampType, _} /** Value range of a column. */ -trait Range +trait Range { + def contains(l: Literal): Boolean +} /** For simplicity we use decimal to unify operations of numeric ranges. */ -case class NumericRange(min: JDecimal, max: JDecimal) extends Range +case class NumericRange(min: JDecimal, max: JDecimal) extends Range { + override def contains(l: Literal): Boolean = { + val decimal = l.dataType match { + case BooleanType => if (l.value.asInstanceOf[Boolean]) new JDecimal(1) else new JDecimal(0) + case _ => new JDecimal(l.value.toString) + } + min.compareTo(decimal) <= 0 && max.compareTo(decimal) >= 0 + } +} /** * This version of Spark does not have min/max for binary/string types, we define their default * behaviors by this class. */ -class DefaultRange extends Range +class DefaultRange extends Range { + override def contains(l: Literal): Boolean = true +} /** This is for columns with only null values. 
*/ -class NullRange extends Range +class NullRange extends Range { + override def contains(l: Literal): Boolean = false +} object Range { def apply(min: Option[Any], max: Option[Any], dataType: DataType): Range = dataType match { @@ -58,20 +72,6 @@ object Range { n1.min.compareTo(n2.max) <= 0 && n1.max.compareTo(n2.min) >= 0 } - def rangeContainsLiteral(r: Range, lit: Literal): Boolean = r match { - case _: DefaultRange => true - case _: NullRange => false - case n: NumericRange => - val literalValue = if (lit.dataType.isInstanceOf[BooleanType]) { - if (lit.value.asInstanceOf[Boolean]) new JDecimal(1) else new JDecimal(0) - } else { - assert(lit.dataType.isInstanceOf[NumericType] || lit.dataType.isInstanceOf[DateType] || - lit.dataType.isInstanceOf[TimestampType]) - new JDecimal(lit.value.toString) - } - n.min.compareTo(literalValue) <= 0 && n.max.compareTo(literalValue) >= 0 - } - /** * Intersected results of two ranges. This is only for two overlapped ranges. * The outputs are the intersected min/max values. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTablesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTablesSuite.scala index 920c6ea50f4ba..f45a826869842 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTablesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTablesSuite.scala @@ -20,68 +20,67 @@ package org.apache.spark.sql.catalyst.analysis import org.scalatest.BeforeAndAfter import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.expressions.{Literal, Rand} +import org.apache.spark.sql.catalyst.expressions.{Cast, Literal, Rand} import org.apache.spark.sql.catalyst.expressions.aggregate.Count -import org.apache.spark.sql.catalyst.plans.PlanTest -import org.apache.spark.sql.types.{LongType, NullType} +import org.apache.spark.sql.types.{LongType, NullType, TimestampType} /** * Unit tests for [[ResolveInlineTables]]. Note that there are also test cases defined in * end-to-end tests (in sql/core module) for verifying the correct error messages are shown * in negative cases. 
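Returning to the `contains` methods added to `Range` above, here is a minimal standalone sketch of the numeric case, using simplified stand-in types (not the actual catalyst classes) and the same boolean-to-decimal mapping:

```
import java.math.{BigDecimal => JDecimal}

// Booleans map to 0/1; other values are compared through their decimal representation.
final case class SimpleNumericRange(min: JDecimal, max: JDecimal) {
  def contains(value: Any): Boolean = {
    val decimal = value match {
      case b: Boolean => if (b) new JDecimal(1) else new JDecimal(0)
      case other      => new JDecimal(other.toString)
    }
    min.compareTo(decimal) <= 0 && max.compareTo(decimal) >= 0
  }
}

// SimpleNumericRange(new JDecimal(1), new JDecimal(10)).contains(3)     // true
// SimpleNumericRange(new JDecimal(0), new JDecimal(1)).contains(false)  // true
```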
*/ -class ResolveInlineTablesSuite extends PlanTest with BeforeAndAfter { +class ResolveInlineTablesSuite extends AnalysisTest with BeforeAndAfter { private def lit(v: Any): Literal = Literal(v) test("validate inputs are foldable") { - ResolveInlineTables.validateInputEvaluable( + ResolveInlineTables(conf).validateInputEvaluable( UnresolvedInlineTable(Seq("c1", "c2"), Seq(Seq(lit(1))))) // nondeterministic (rand) should not work intercept[AnalysisException] { - ResolveInlineTables.validateInputEvaluable( + ResolveInlineTables(conf).validateInputEvaluable( UnresolvedInlineTable(Seq("c1"), Seq(Seq(Rand(1))))) } // aggregate should not work intercept[AnalysisException] { - ResolveInlineTables.validateInputEvaluable( + ResolveInlineTables(conf).validateInputEvaluable( UnresolvedInlineTable(Seq("c1"), Seq(Seq(Count(lit(1)))))) } // unresolved attribute should not work intercept[AnalysisException] { - ResolveInlineTables.validateInputEvaluable( + ResolveInlineTables(conf).validateInputEvaluable( UnresolvedInlineTable(Seq("c1"), Seq(Seq(UnresolvedAttribute("A"))))) } } test("validate input dimensions") { - ResolveInlineTables.validateInputDimension( + ResolveInlineTables(conf).validateInputDimension( UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(lit(2))))) // num alias != data dimension intercept[AnalysisException] { - ResolveInlineTables.validateInputDimension( + ResolveInlineTables(conf).validateInputDimension( UnresolvedInlineTable(Seq("c1", "c2"), Seq(Seq(lit(1)), Seq(lit(2))))) } // num alias == data dimension, but data themselves are inconsistent intercept[AnalysisException] { - ResolveInlineTables.validateInputDimension( + ResolveInlineTables(conf).validateInputDimension( UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(lit(21), lit(22))))) } } test("do not fire the rule if not all expressions are resolved") { val table = UnresolvedInlineTable(Seq("c1", "c2"), Seq(Seq(UnresolvedAttribute("A")))) - assert(ResolveInlineTables(table) == table) + assert(ResolveInlineTables(conf)(table) == table) } test("convert") { val table = UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(lit(2L)))) - val converted = ResolveInlineTables.convert(table) + val converted = ResolveInlineTables(conf).convert(table) assert(converted.output.map(_.dataType) == Seq(LongType)) assert(converted.data.size == 2) @@ -89,13 +88,24 @@ class ResolveInlineTablesSuite extends PlanTest with BeforeAndAfter { assert(converted.data(1).getLong(0) == 2L) } + test("convert TimeZoneAwareExpression") { + val table = UnresolvedInlineTable(Seq("c1"), + Seq(Seq(Cast(lit("1991-12-06 00:00:00.0"), TimestampType)))) + val converted = ResolveInlineTables(conf).convert(table) + val correct = Cast(lit("1991-12-06 00:00:00.0"), TimestampType) + .withTimeZone(conf.sessionLocalTimeZone).eval().asInstanceOf[Long] + assert(converted.output.map(_.dataType) == Seq(TimestampType)) + assert(converted.data.size == 1) + assert(converted.data(0).getLong(0) == correct) + } + test("nullability inference in convert") { val table1 = UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(lit(2L)))) - val converted1 = ResolveInlineTables.convert(table1) + val converted1 = ResolveInlineTables(conf).convert(table1) assert(!converted1.schema.fields(0).nullable) val table2 = UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(Literal(null, NullType)))) - val converted2 = ResolveInlineTables.convert(table2) + val converted2 = ResolveInlineTables(conf).convert(table2) assert(converted2.schema.fields(0).nullable) } } diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala index 44434324d3770..a755231962be2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala @@ -433,15 +433,15 @@ class SessionCatalogSuite extends PlanTest { sessionCatalog.createTempView("tbl1", tempTable1, overrideIfExists = false) sessionCatalog.setCurrentDatabase("db2") // If we explicitly specify the database, we'll look up the relation in that database - assert(sessionCatalog.lookupRelation(TableIdentifier("tbl1", Some("db2"))) - == SubqueryAlias("tbl1", SimpleCatalogRelation(metastoreTable1), None)) + assert(sessionCatalog.lookupRelation(TableIdentifier("tbl1", Some("db2"))).children.head + .asInstanceOf[CatalogRelation].tableMeta == metastoreTable1) // Otherwise, we'll first look up a temporary table with the same name assert(sessionCatalog.lookupRelation(TableIdentifier("tbl1")) == SubqueryAlias("tbl1", tempTable1, None)) // Then, if that does not exist, look up the relation in the current database sessionCatalog.dropTable(TableIdentifier("tbl1"), ignoreIfNotExists = false, purge = false) - assert(sessionCatalog.lookupRelation(TableIdentifier("tbl1")) - == SubqueryAlias("tbl1", SimpleCatalogRelation(metastoreTable1), None)) + assert(sessionCatalog.lookupRelation(TableIdentifier("tbl1")).children.head + .asInstanceOf[CatalogRelation].tableMeta == metastoreTable1) } test("look up view relation") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FoldablePropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FoldablePropagationSuite.scala index 82756f545a8c7..d128315b68869 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FoldablePropagationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FoldablePropagationSuite.scala @@ -130,6 +130,20 @@ class FoldablePropagationSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + test("Propagate in inner join") { + val ta = testRelation.select('a, Literal(1).as('tag)) + .union(testRelation.select('a, Literal(2).as('tag))) + .subquery('ta) + val tb = testRelation.select('a, Literal(1).as('tag)) + .union(testRelation.select('a, Literal(2).as('tag))) + .subquery('tb) + val query = ta.join(tb, Inner, + Some("ta.a".attr === "tb.a".attr && "ta.tag".attr === "tb.tag".attr)) + val optimized = Optimize.execute(query.analyze) + val correctAnswer = query.analyze + comparePlans(optimized, correctAnswer) + } + test("Propagate in expand") { val c1 = Literal(1).as('a) val c2 = Literal(2).as('b) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala index f5e306f9e504d..8be74ced7bb71 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala @@ -17,12 +17,11 @@ package org.apache.spark.sql.catalyst.statsEstimation -import java.sql.{Date, Timestamp} +import java.sql.Date import org.apache.spark.sql.catalyst.expressions._ import 
org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils._ -import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ /** @@ -38,6 +37,11 @@ class FilterEstimationSuite extends StatsEstimationTestBase { val childColStatInt = ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), nullCount = 0, avgLen = 4, maxLen = 4) + // only 2 values + val arBool = AttributeReference("cbool", BooleanType)() + val childColStatBool = ColumnStat(distinctCount = 2, min = Some(false), max = Some(true), + nullCount = 0, avgLen = 1, maxLen = 1) + // Second column cdate has 10 values from 2017-01-01 through 2017-01-10. val dMin = Date.valueOf("2017-01-01") val dMax = Date.valueOf("2017-01-10") @@ -45,14 +49,6 @@ class FilterEstimationSuite extends StatsEstimationTestBase { val childColStatDate = ColumnStat(distinctCount = 10, min = Some(dMin), max = Some(dMax), nullCount = 0, avgLen = 4, maxLen = 4) - // Third column ctimestamp has 10 values from "2017-01-01 01:00:00" through - // "2017-01-01 10:00:00" for 10 distinct timestamps (or hours). - val tsMin = Timestamp.valueOf("2017-01-01 01:00:00") - val tsMax = Timestamp.valueOf("2017-01-01 10:00:00") - val arTimestamp = AttributeReference("ctimestamp", TimestampType)() - val childColStatTimestamp = ColumnStat(distinctCount = 10, min = Some(tsMin), max = Some(tsMax), - nullCount = 0, avgLen = 8, maxLen = 8) - // Fourth column cdecimal has 4 values from 0.20 through 0.80 at increment of 0.20. val decMin = new java.math.BigDecimal("0.200000000000000000") val decMax = new java.math.BigDecimal("0.800000000000000000") @@ -77,8 +73,16 @@ class FilterEstimationSuite extends StatsEstimationTestBase { Filter(EqualTo(arInt, Literal(2)), childStatsTestPlan(Seq(arInt), 10L)), ColumnStat(distinctCount = 1, min = Some(2), max = Some(2), nullCount = 0, avgLen = 4, maxLen = 4), - Some(1L) - ) + 1) + } + + test("cint <=> 2") { + validateEstimatedStats( + arInt, + Filter(EqualNullSafe(arInt, Literal(2)), childStatsTestPlan(Seq(arInt), 10L)), + ColumnStat(distinctCount = 1, min = Some(2), max = Some(2), + nullCount = 0, avgLen = 4, maxLen = 4), + 1) } test("cint = 0") { @@ -88,8 +92,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { Filter(EqualTo(arInt, Literal(0)), childStatsTestPlan(Seq(arInt), 10L)), ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), nullCount = 0, avgLen = 4, maxLen = 4), - Some(0L) - ) + 0) } test("cint < 3") { @@ -98,8 +101,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { Filter(LessThan(arInt, Literal(3)), childStatsTestPlan(Seq(arInt), 10L)), ColumnStat(distinctCount = 2, min = Some(1), max = Some(3), nullCount = 0, avgLen = 4, maxLen = 4), - Some(3L) - ) + 3) } test("cint < 0") { @@ -109,8 +111,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { Filter(LessThan(arInt, Literal(0)), childStatsTestPlan(Seq(arInt), 10L)), ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), nullCount = 0, avgLen = 4, maxLen = 4), - Some(0L) - ) + 0) } test("cint <= 3") { @@ -119,8 +120,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { Filter(LessThanOrEqual(arInt, Literal(3)), childStatsTestPlan(Seq(arInt), 10L)), ColumnStat(distinctCount = 2, min = Some(1), max = Some(3), nullCount = 0, avgLen = 4, maxLen = 4), - Some(3L) - ) + 3) } test("cint > 6") { @@ -129,8 +129,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { Filter(GreaterThan(arInt, Literal(6)), 
childStatsTestPlan(Seq(arInt), 10L)), ColumnStat(distinctCount = 4, min = Some(6), max = Some(10), nullCount = 0, avgLen = 4, maxLen = 4), - Some(5L) - ) + 5) } test("cint > 10") { @@ -140,8 +139,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { Filter(GreaterThan(arInt, Literal(10)), childStatsTestPlan(Seq(arInt), 10L)), ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), nullCount = 0, avgLen = 4, maxLen = 4), - Some(0L) - ) + 0) } test("cint >= 6") { @@ -150,8 +148,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { Filter(GreaterThanOrEqual(arInt, Literal(6)), childStatsTestPlan(Seq(arInt), 10L)), ColumnStat(distinctCount = 4, min = Some(6), max = Some(10), nullCount = 0, avgLen = 4, maxLen = 4), - Some(5L) - ) + 5) } test("cint IS NULL") { @@ -160,8 +157,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { Filter(IsNull(arInt), childStatsTestPlan(Seq(arInt), 10L)), ColumnStat(distinctCount = 0, min = None, max = None, nullCount = 0, avgLen = 4, maxLen = 4), - Some(0L) - ) + 0) } test("cint IS NOT NULL") { @@ -170,8 +166,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { Filter(IsNotNull(arInt), childStatsTestPlan(Seq(arInt), 10L)), ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), nullCount = 0, avgLen = 4, maxLen = 4), - Some(10L) - ) + 10) } test("cint > 3 AND cint <= 6") { @@ -181,8 +176,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { Filter(condition, childStatsTestPlan(Seq(arInt), 10L)), ColumnStat(distinctCount = 3, min = Some(3), max = Some(6), nullCount = 0, avgLen = 4, maxLen = 4), - Some(4L) - ) + 4) } test("cint = 3 OR cint = 6") { @@ -192,8 +186,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { Filter(condition, childStatsTestPlan(Seq(arInt), 10L)), ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), nullCount = 0, avgLen = 4, maxLen = 4), - Some(2L) - ) + 2) } test("cint IN (3, 4, 5)") { @@ -202,8 +195,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { Filter(InSet(arInt, Set(3, 4, 5)), childStatsTestPlan(Seq(arInt), 10L)), ColumnStat(distinctCount = 3, min = Some(3), max = Some(5), nullCount = 0, avgLen = 4, maxLen = 4), - Some(3L) - ) + 3) } test("cint NOT IN (3, 4, 5)") { @@ -212,8 +204,26 @@ class FilterEstimationSuite extends StatsEstimationTestBase { Filter(Not(InSet(arInt, Set(3, 4, 5))), childStatsTestPlan(Seq(arInt), 10L)), ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), nullCount = 0, avgLen = 4, maxLen = 4), - Some(7L) - ) + 7) + } + + test("cbool = true") { + validateEstimatedStats( + arBool, + Filter(EqualTo(arBool, Literal(true)), childStatsTestPlan(Seq(arBool), 10L)), + ColumnStat(distinctCount = 1, min = Some(true), max = Some(true), + nullCount = 0, avgLen = 1, maxLen = 1), + 5) + } + + test("cbool > false") { + // bool comparison is not supported yet, so stats remain same. 
+ validateEstimatedStats( + arBool, + Filter(GreaterThan(arBool, Literal(false)), childStatsTestPlan(Seq(arBool), 10L)), + ColumnStat(distinctCount = 2, min = Some(false), max = Some(true), + nullCount = 0, avgLen = 1, maxLen = 1), + 10) } test("cdate = cast('2017-01-02' AS DATE)") { @@ -224,8 +234,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { childStatsTestPlan(Seq(arDate), 10L)), ColumnStat(distinctCount = 1, min = Some(d20170102), max = Some(d20170102), nullCount = 0, avgLen = 4, maxLen = 4), - Some(1L) - ) + 1) } test("cdate < cast('2017-01-03' AS DATE)") { @@ -236,8 +245,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { childStatsTestPlan(Seq(arDate), 10L)), ColumnStat(distinctCount = 2, min = Some(dMin), max = Some(d20170103), nullCount = 0, avgLen = 4, maxLen = 4), - Some(3L) - ) + 3) } test("""cdate IN ( cast('2017-01-03' AS DATE), @@ -251,32 +259,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { childStatsTestPlan(Seq(arDate), 10L)), ColumnStat(distinctCount = 3, min = Some(d20170103), max = Some(d20170105), nullCount = 0, avgLen = 4, maxLen = 4), - Some(3L) - ) - } - - test("ctimestamp = cast('2017-01-01 02:00:00' AS TIMESTAMP)") { - val ts2017010102 = Timestamp.valueOf("2017-01-01 02:00:00") - validateEstimatedStats( - arTimestamp, - Filter(EqualTo(arTimestamp, Literal(ts2017010102)), - childStatsTestPlan(Seq(arTimestamp), 10L)), - ColumnStat(distinctCount = 1, min = Some(ts2017010102), max = Some(ts2017010102), - nullCount = 0, avgLen = 8, maxLen = 8), - Some(1L) - ) - } - - test("ctimestamp < cast('2017-01-01 03:00:00' AS TIMESTAMP)") { - val ts2017010103 = Timestamp.valueOf("2017-01-01 03:00:00") - validateEstimatedStats( - arTimestamp, - Filter(LessThan(arTimestamp, Literal(ts2017010103)), - childStatsTestPlan(Seq(arTimestamp), 10L)), - ColumnStat(distinctCount = 2, min = Some(tsMin), max = Some(ts2017010103), - nullCount = 0, avgLen = 8, maxLen = 8), - Some(3L) - ) + 3) } test("cdecimal = 0.400000000000000000") { @@ -287,20 +270,18 @@ class FilterEstimationSuite extends StatsEstimationTestBase { childStatsTestPlan(Seq(arDecimal), 4L)), ColumnStat(distinctCount = 1, min = Some(dec_0_40), max = Some(dec_0_40), nullCount = 0, avgLen = 8, maxLen = 8), - Some(1L) - ) + 1) } test("cdecimal < 0.60 ") { val dec_0_60 = new java.math.BigDecimal("0.600000000000000000") validateEstimatedStats( arDecimal, - Filter(LessThan(arDecimal, Literal(dec_0_60, DecimalType(12, 2))), + Filter(LessThan(arDecimal, Literal(dec_0_60)), childStatsTestPlan(Seq(arDecimal), 4L)), ColumnStat(distinctCount = 3, min = Some(decMin), max = Some(dec_0_60), nullCount = 0, avgLen = 8, maxLen = 8), - Some(3L) - ) + 3) } test("cdouble < 3.0") { @@ -309,8 +290,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { Filter(LessThan(arDouble, Literal(3.0)), childStatsTestPlan(Seq(arDouble), 10L)), ColumnStat(distinctCount = 2, min = Some(1.0), max = Some(3.0), nullCount = 0, avgLen = 8, maxLen = 8), - Some(3L) - ) + 3) } test("cstring = 'A2'") { @@ -319,8 +299,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { Filter(EqualTo(arString, Literal("A2")), childStatsTestPlan(Seq(arString), 10L)), ColumnStat(distinctCount = 1, min = None, max = None, nullCount = 0, avgLen = 2, maxLen = 2), - Some(1L) - ) + 1) } // There is no min/max statistics for String type. We estimate 10 rows returned. 
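A rough sketch of the arithmetic behind the two string test cases around here, assuming 10 rows and distinctCount = 10 for cstring:

```
// Equality uses 1 / ndv; range predicates on strings have no min/max stats,
// so their selectivity stays at 1.0.
val rowCount = 10.0
val ndv      = 10.0

val equalityRows = math.ceil(rowCount * (1.0 / ndv)).toLong  // 1, as in "cstring = 'A2'"
val rangeRows    = math.ceil(rowCount * 1.0).toLong          // 10, as in "cstring < 'A2'"
```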
@@ -330,8 +309,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { Filter(LessThan(arString, Literal("A2")), childStatsTestPlan(Seq(arString), 10L)), ColumnStat(distinctCount = 10, min = None, max = None, nullCount = 0, avgLen = 2, maxLen = 2), - Some(10L) - ) + 10) } // This is a corner test case. We want to test if we can handle the case when the number of @@ -351,8 +329,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { Filter(InSet(arInt, Set(1, 2, 3, 4, 5)), cornerChildStatsTestplan), ColumnStat(distinctCount = 2, min = Some(1), max = Some(5), nullCount = 0, avgLen = 4, maxLen = 4), - Some(2L) - ) + 2) } private def childStatsTestPlan(outList: Seq[Attribute], tableRowCount: BigInt): StatsTestPlan = { @@ -361,8 +338,8 @@ class FilterEstimationSuite extends StatsEstimationTestBase { rowCount = tableRowCount, attributeStats = AttributeMap(Seq( arInt -> childColStatInt, + arBool -> childColStatBool, arDate -> childColStatDate, - arTimestamp -> childColStatTimestamp, arDecimal -> childColStatDecimal, arDouble -> childColStatDouble, arString -> childColStatString @@ -374,46 +351,31 @@ class FilterEstimationSuite extends StatsEstimationTestBase { ar: AttributeReference, filterNode: Filter, expectedColStats: ColumnStat, - rowCount: Option[BigInt] = None) - : Unit = { + rowCount: Int): Unit = { - val expectedRowCount: BigInt = rowCount.getOrElse(0L) val expectedAttrStats = toAttributeMap(Seq(ar.name -> expectedColStats), filterNode) - val expectedSizeInBytes = getOutputSize(filterNode.output, expectedRowCount, expectedAttrStats) + val expectedSizeInBytes = getOutputSize(filterNode.output, rowCount, expectedAttrStats) val filteredStats = filterNode.stats(conf) assert(filteredStats.sizeInBytes == expectedSizeInBytes) - assert(filteredStats.rowCount == rowCount) - ar.dataType match { - case DecimalType() => - // Due to the internal transformation for DecimalType within engine, the new min/max - // in ColumnStat may have a different structure even it contains the right values. - // We convert them to Java BigDecimal values so that we can compare the entire object. - val generatedColumnStats = filteredStats.attributeStats(ar) - val newMax = new java.math.BigDecimal(generatedColumnStats.max.getOrElse(0).toString) - val newMin = new java.math.BigDecimal(generatedColumnStats.min.getOrElse(0).toString) - val outputColStats = generatedColumnStats.copy(min = Some(newMin), max = Some(newMax)) - assert(outputColStats == expectedColStats) - case _ => - // For all other SQL types, we compare the entire object directly. - assert(filteredStats.attributeStats(ar) == expectedColStats) - } + assert(filteredStats.rowCount.get == rowCount) + assert(filteredStats.attributeStats(ar) == expectedColStats) // If the filter has a binary operator (including those nested inside // AND/OR/NOT), swap the sides of the attribte and the literal, reverse the // operator, and then check again. 
val rewrittenFilter = filterNode transformExpressionsDown { - case op @ EqualTo(ar: AttributeReference, l: Literal) => + case EqualTo(ar: AttributeReference, l: Literal) => EqualTo(l, ar) - case op @ LessThan(ar: AttributeReference, l: Literal) => + case LessThan(ar: AttributeReference, l: Literal) => GreaterThan(l, ar) - case op @ LessThanOrEqual(ar: AttributeReference, l: Literal) => + case LessThanOrEqual(ar: AttributeReference, l: Literal) => GreaterThanOrEqual(l, ar) - case op @ GreaterThan(ar: AttributeReference, l: Literal) => + case GreaterThan(ar: AttributeReference, l: Literal) => LessThan(l, ar) - case op @ GreaterThanOrEqual(ar: AttributeReference, l: Literal) => + case GreaterThanOrEqual(ar: AttributeReference, l: Literal) => LessThanOrEqual(l, ar) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 59baf6e567721..41470ae6aae19 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -263,8 +263,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { /** * Loads a JSON file and returns the results as a `DataFrame`. * - * Both JSON (one record per file) and JSON Lines - * (newline-delimited JSON) are supported and can be selected with the `wholeFile` option. + * JSON Lines (newline-delimited JSON) is supported by + * default. For JSON (one record per file), set the `wholeFile` option to true. * * This function goes through the input once to determine the input schema. If you know the * schema in advance, use the version that specifies the schema to avoid the extra scan. @@ -463,6 +463,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { *
   * <li>`columnNameOfCorruptRecord` (default is the value specified in
   * `spark.sql.columnNameOfCorruptRecord`): allows renaming the new field having malformed string
   * created by `PERMISSIVE` mode. This overrides `spark.sql.columnNameOfCorruptRecord`.</li>
+  * <li>`wholeFile` (default `false`): parse one record, which may span multiple lines.</li>
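// Illustrative spark-shell sketch of the `wholeFile` option documented above (not from this
// patch); the paths are placeholders.
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().appName("wholeFile-example").getOrCreate()

// Default behaviour: line-based parsing, one record per line.
val lineBased = spark.read.option("header", true).csv("/data/records.csv")

// wholeFile = true: each file is parsed as a whole, so quoted CSV fields (or a single
// pretty-printed JSON record) may span multiple lines.
val multiLine = spark.read
  .option("header", true)
  .option("wholeFile", true)
  .csv("/data/records-with-embedded-newlines.csv")
multiLine.printSchema()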
  • * * @since 2.0.0 */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 393925161fc7b..49e85dc7b13f6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -349,8 +349,8 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { // Get all input data source or hive relations of the query. val srcRelations = df.logicalPlan.collect { case LogicalRelation(src: BaseRelation, _, _) => src - case relation: CatalogRelation if DDLUtils.isHiveTable(relation.catalogTable) => - relation.catalogTable.identifier + case relation: CatalogRelation if DDLUtils.isHiveTable(relation.tableMeta) => + relation.tableMeta.identifier } val tableRelation = df.sparkSession.table(tableIdentWithDB).queryExecution.analyzed @@ -360,8 +360,8 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { throw new AnalysisException( s"Cannot overwrite table $tableName that is also being read from") // check hive table relation when overwrite mode - case relation: CatalogRelation if DDLUtils.isHiveTable(relation.catalogTable) - && srcRelations.contains(relation.catalogTable.identifier) => + case relation: CatalogRelation if DDLUtils.isHiveTable(relation.tableMeta) + && srcRelations.contains(relation.tableMeta.identifier) => throw new AnalysisException( s"Cannot overwrite table $tableName that is also being read from") case _ => // OK diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala index e56c33e4b512f..a4c5bf756cd5a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala @@ -47,12 +47,14 @@ private[sql] object SQLUtils extends Logging { jsc: JavaSparkContext, sparkConfigMap: JMap[Object, Object], enableHiveSupport: Boolean): SparkSession = { - val spark = if (SparkSession.hiveClassesArePresent && enableHiveSupport) { + val spark = if (SparkSession.hiveClassesArePresent && enableHiveSupport + && jsc.sc.conf.get(CATALOG_IMPLEMENTATION.key, "hive").toLowerCase == "hive") { SparkSession.builder().sparkContext(withHiveExternalCatalog(jsc.sc)).getOrCreate() } else { if (enableHiveSupport) { logWarning("SparkR: enableHiveSupport is requested for SparkSession but " + - "Spark is not built with Hive; falling back to without Hive support.") + s"Spark is not built with Hive or ${CATALOG_IMPLEMENTATION.key} is not set to 'hive', " + + "falling back to without Hive support.") } SparkSession.builder().sparkContext(jsc.sc).getOrCreate() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala index 4ca1347008575..80138510dc9ee 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala @@ -168,15 +168,16 @@ class CacheManager extends Logging { (fs, path.makeQualified(fs.getUri, fs.getWorkingDirectory)) } - cachedData.foreach { - case data if data.plan.find(lookupAndRefresh(_, fs, qualifiedPath)).isDefined => - val dataIndex = cachedData.indexWhere(cd => data.plan.sameResult(cd.plan)) - if (dataIndex >= 0) { - data.cachedRepresentation.cachedColumnBuffers.unpersist(blocking = true) - cachedData.remove(dataIndex) - } - 
sparkSession.sharedState.cacheManager.cacheQuery(Dataset.ofRows(sparkSession, data.plan)) - case _ => // Do Nothing + cachedData.filter { + case data if data.plan.find(lookupAndRefresh(_, fs, qualifiedPath)).isDefined => true + case _ => false + }.foreach { data => + val dataIndex = cachedData.indexWhere(cd => data.plan.sameResult(cd.plan)) + if (dataIndex >= 0) { + data.cachedRepresentation.cachedColumnBuffers.unpersist(blocking = true) + cachedData.remove(dataIndex) + } + sparkSession.sharedState.cacheManager.cacheQuery(Dataset.ofRows(sparkSession, data.plan)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala index b8ac070e3a959..aa578f4d23133 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation} import org.apache.spark.sql.internal.SQLConf @@ -102,12 +102,14 @@ case class OptimizeMetadataOnlyQuery( LocalRelation(partAttrs, partitionData.map(_.values)) case relation: CatalogRelation => - val partAttrs = getPartitionAttrs(relation.catalogTable.partitionColumnNames, relation) - val partitionData = catalog.listPartitions(relation.catalogTable.identifier).map { p => + val partAttrs = getPartitionAttrs(relation.tableMeta.partitionColumnNames, relation) + val caseInsensitiveProperties = + CaseInsensitiveMap(relation.tableMeta.storage.properties) + val timeZoneId = caseInsensitiveProperties.get("timeZone") + .getOrElse(conf.sessionLocalTimeZone) + val partitionData = catalog.listPartitions(relation.tableMeta.identifier).map { p => InternalRow.fromSeq(partAttrs.map { attr => - // TODO: use correct timezone for partition values. 
- Cast(Literal(p.spec(attr.name)), attr.dataType, - Option(DateTimeUtils.defaultTimeZone().getID)).eval() + Cast(Literal(p.spec(attr.name)), attr.dataType, Option(timeZoneId)).eval() }) } LocalRelation(partAttrs, partitionData) @@ -135,8 +137,8 @@ case class OptimizeMetadataOnlyQuery( val partAttrs = getPartitionAttrs(fsRelation.partitionSchema.map(_.name), l) Some(AttributeSet(partAttrs), l) - case relation: CatalogRelation if relation.catalogTable.partitionColumnNames.nonEmpty => - val partAttrs = getPartitionAttrs(relation.catalogTable.partitionColumnNames, relation) + case relation: CatalogRelation if relation.tableMeta.partitionColumnNames.nonEmpty => + val partAttrs = getPartitionAttrs(relation.tableMeta.partitionColumnNames, relation) Some(AttributeSet(partAttrs), relation) case p @ Project(projectList, child) if projectList.forall(_.deterministic) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala index d024a3673d4ba..b89014ed8ef54 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala @@ -17,15 +17,12 @@ package org.apache.spark.sql.execution.command -import org.apache.spark.internal.Logging import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases -import org.apache.spark.sql.catalyst.catalog.{CatalogRelation, CatalogStatistics, CatalogTable} +import org.apache.spark.sql.catalyst.catalog.{CatalogStatistics, CatalogTableType} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.execution.datasources.LogicalRelation /** @@ -40,60 +37,40 @@ case class AnalyzeColumnCommand( val sessionState = sparkSession.sessionState val db = tableIdent.database.getOrElse(sessionState.catalog.getCurrentDatabase) val tableIdentWithDB = TableIdentifier(tableIdent.table, Some(db)) - val relation = - EliminateSubqueryAliases(sparkSession.table(tableIdentWithDB).queryExecution.analyzed) - - // Compute total size - val (catalogTable: CatalogTable, sizeInBytes: Long) = relation match { - case catalogRel: CatalogRelation => - // This is a Hive serde format table - (catalogRel.catalogTable, - AnalyzeTableCommand.calculateTotalSize(sessionState, catalogRel.catalogTable)) - - case logicalRel: LogicalRelation if logicalRel.catalogTable.isDefined => - // This is a data source format table - (logicalRel.catalogTable.get, - AnalyzeTableCommand.calculateTotalSize(sessionState, logicalRel.catalogTable.get)) - - case otherRelation => - throw new AnalysisException("ANALYZE TABLE is not supported for " + - s"${otherRelation.nodeName}.") + val tableMeta = sessionState.catalog.getTableMetadata(tableIdentWithDB) + if (tableMeta.tableType == CatalogTableType.VIEW) { + throw new AnalysisException("ANALYZE TABLE is not supported on views.") } + val sizeInBytes = AnalyzeTableCommand.calculateTotalSize(sessionState, tableMeta) // Compute stats for each column - val (rowCount, newColStats) = - AnalyzeColumnCommand.computeColumnStats(sparkSession, tableIdent.table, relation, columnNames) + val (rowCount, newColStats) = computeColumnStats(sparkSession, tableIdentWithDB, columnNames) // We also update 
table-level stats in order to keep them consistent with column-level stats. val statistics = CatalogStatistics( sizeInBytes = sizeInBytes, rowCount = Some(rowCount), // Newly computed column stats should override the existing ones. - colStats = catalogTable.stats.map(_.colStats).getOrElse(Map.empty) ++ newColStats) + colStats = tableMeta.stats.map(_.colStats).getOrElse(Map.empty) ++ newColStats) - sessionState.catalog.alterTable(catalogTable.copy(stats = Some(statistics))) + sessionState.catalog.alterTable(tableMeta.copy(stats = Some(statistics))) // Refresh the cached data source table in the catalog. sessionState.catalog.refreshTable(tableIdentWithDB) Seq.empty[Row] } -} - -object AnalyzeColumnCommand extends Logging { /** * Compute stats for the given columns. * @return (row count, map from column name to ColumnStats) - * - * This is visible for testing. */ - def computeColumnStats( + private def computeColumnStats( sparkSession: SparkSession, - tableName: String, - relation: LogicalPlan, + tableIdent: TableIdentifier, columnNames: Seq[String]): (Long, Map[String, ColumnStat]) = { + val relation = sparkSession.table(tableIdent).logicalPlan // Resolve the column names and dedup using AttributeSet val resolver = sparkSession.sessionState.conf.resolver val attributesToAnalyze = AttributeSet(columnNames.map { col => @@ -105,7 +82,7 @@ object AnalyzeColumnCommand extends Logging { attributesToAnalyze.foreach { attr => if (!ColumnStat.supportsType(attr.dataType)) { throw new AnalysisException( - s"Column ${attr.name} in table $tableName is of type ${attr.dataType}, " + + s"Column ${attr.name} in table $tableIdent is of type ${attr.dataType}, " + "and Spark does not support statistics collection on this column type.") } } @@ -116,7 +93,7 @@ object AnalyzeColumnCommand extends Logging { // The layout of each struct follows the layout of the ColumnStats. 
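// Illustrative spark-shell sketch (not from this patch) of the commands that drive these code
// paths, assuming an active `spark` session; the table and column names are placeholders.
spark.sql("ANALYZE TABLE sales COMPUTE STATISTICS")                        // AnalyzeTableCommand, with a scan
spark.sql("ANALYZE TABLE sales COMPUTE STATISTICS NOSCAN")                 // AnalyzeTableCommand, size only
spark.sql("ANALYZE TABLE sales COMPUTE STATISTICS FOR COLUMNS price, qty") // AnalyzeColumnCommand
// With this change both commands fetch the table metadata straight from the session catalog,
// so running them against a view now fails with "ANALYZE TABLE is not supported on views."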
val ndvMaxErr = sparkSession.sessionState.conf.ndvMaxError val expressions = Count(Literal(1)).toAggregateExpression() +: - attributesToAnalyze.map(ColumnStat.statExprs(_, ndvMaxErr)) + attributesToAnalyze.map(ColumnStat.statExprs(_, ndvMaxErr)) val namedExpressions = expressions.map(e => Alias(e, e.toString)()) val statsRow = Dataset.ofRows(sparkSession, Aggregate(Nil, namedExpressions, relation)).head() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala index 30b6cc7617cb3..d2ea0cdf61aa6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala @@ -22,11 +22,9 @@ import scala.util.control.NonFatal import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.spark.internal.Logging -import org.apache.spark.sql.{AnalysisException, Dataset, Row, SparkSession} +import org.apache.spark.sql.{AnalysisException, Row, SparkSession} import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases -import org.apache.spark.sql.catalyst.catalog.{CatalogRelation, CatalogStatistics, CatalogTable} -import org.apache.spark.sql.execution.datasources.LogicalRelation +import org.apache.spark.sql.catalyst.catalog.{CatalogStatistics, CatalogTable, CatalogTableType} import org.apache.spark.sql.internal.SessionState @@ -41,53 +39,39 @@ case class AnalyzeTableCommand( val sessionState = sparkSession.sessionState val db = tableIdent.database.getOrElse(sessionState.catalog.getCurrentDatabase) val tableIdentWithDB = TableIdentifier(tableIdent.table, Some(db)) - val relation = - EliminateSubqueryAliases(sparkSession.table(tableIdentWithDB).queryExecution.analyzed) - - relation match { - case relation: CatalogRelation => - updateTableStats(relation.catalogTable, - AnalyzeTableCommand.calculateTotalSize(sessionState, relation.catalogTable)) - - // data source tables have been converted into LogicalRelations - case logicalRel: LogicalRelation if logicalRel.catalogTable.isDefined => - updateTableStats(logicalRel.catalogTable.get, - AnalyzeTableCommand.calculateTotalSize(sessionState, logicalRel.catalogTable.get)) - - case otherRelation => - throw new AnalysisException("ANALYZE TABLE is not supported for " + - s"${otherRelation.nodeName}.") + val tableMeta = sessionState.catalog.getTableMetadata(tableIdentWithDB) + if (tableMeta.tableType == CatalogTableType.VIEW) { + throw new AnalysisException("ANALYZE TABLE is not supported on views.") } + val newTotalSize = AnalyzeTableCommand.calculateTotalSize(sessionState, tableMeta) - def updateTableStats(catalogTable: CatalogTable, newTotalSize: Long): Unit = { - val oldTotalSize = catalogTable.stats.map(_.sizeInBytes.toLong).getOrElse(0L) - val oldRowCount = catalogTable.stats.flatMap(_.rowCount.map(_.toLong)).getOrElse(-1L) - var newStats: Option[CatalogStatistics] = None - if (newTotalSize > 0 && newTotalSize != oldTotalSize) { - newStats = Some(CatalogStatistics(sizeInBytes = newTotalSize)) - } - // We only set rowCount when noscan is false, because otherwise: - // 1. when total size is not changed, we don't need to alter the table; - // 2. when total size is changed, `oldRowCount` becomes invalid. - // This is to make sure that we only record the right statistics. 
- if (!noscan) { - val newRowCount = Dataset.ofRows(sparkSession, relation).count() - if (newRowCount >= 0 && newRowCount != oldRowCount) { - newStats = if (newStats.isDefined) { - newStats.map(_.copy(rowCount = Some(BigInt(newRowCount)))) - } else { - Some(CatalogStatistics( - sizeInBytes = oldTotalSize, rowCount = Some(BigInt(newRowCount)))) - } + val oldTotalSize = tableMeta.stats.map(_.sizeInBytes.toLong).getOrElse(0L) + val oldRowCount = tableMeta.stats.flatMap(_.rowCount.map(_.toLong)).getOrElse(-1L) + var newStats: Option[CatalogStatistics] = None + if (newTotalSize > 0 && newTotalSize != oldTotalSize) { + newStats = Some(CatalogStatistics(sizeInBytes = newTotalSize)) + } + // We only set rowCount when noscan is false, because otherwise: + // 1. when total size is not changed, we don't need to alter the table; + // 2. when total size is changed, `oldRowCount` becomes invalid. + // This is to make sure that we only record the right statistics. + if (!noscan) { + val newRowCount = sparkSession.table(tableIdentWithDB).count() + if (newRowCount >= 0 && newRowCount != oldRowCount) { + newStats = if (newStats.isDefined) { + newStats.map(_.copy(rowCount = Some(BigInt(newRowCount)))) + } else { + Some(CatalogStatistics( + sizeInBytes = oldTotalSize, rowCount = Some(BigInt(newRowCount)))) } } - // Update the metastore if the above statistics of the table are different from those - // recorded in the metastore. - if (newStats.isDefined) { - sessionState.catalog.alterTable(catalogTable.copy(stats = newStats)) - // Refresh the cached data source table in the catalog. - sessionState.catalog.refreshTable(tableIdentWithDB) - } + } + // Update the metastore if the above statistics of the table are different from those + // recorded in the metastore. + if (newStats.isDefined) { + sessionState.catalog.alterTable(tableMeta.copy(stats = newStats)) + // Refresh the cached data source table in the catalog. 
+ sessionState.catalog.refreshTable(tableIdentWithDB) } Seq.empty[Row] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala index 5abd579476504..d835b521166a8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala @@ -141,7 +141,7 @@ case class CreateDataSourceTableAsSelectCommand( } saveDataIntoTable( - sparkSession, table, table.storage.locationUri, query, mode, tableExists = true) + sparkSession, table, table.storage.locationUri, query, SaveMode.Append, tableExists = true) } else { assert(table.schema.isEmpty) @@ -151,7 +151,7 @@ case class CreateDataSourceTableAsSelectCommand( table.storage.locationUri } val result = saveDataIntoTable( - sparkSession, table, tableLocation, query, mode, tableExists = false) + sparkSession, table, tableLocation, query, SaveMode.Overwrite, tableExists = false) val newTable = table.copy( storage = table.storage.copy(locationUri = tableLocation), // We will use the schema of resolved.relation as the schema of the table (instead of diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala index 49407b44d7b8a..3e80916104bd9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala @@ -269,8 +269,8 @@ case class LoadDataCommand( } else { // Follow Hive's behavior: // If no schema or authority is provided with non-local inpath, - // we will use hadoop configuration "fs.default.name". - val defaultFSConf = sparkSession.sessionState.newHadoopConf().get("fs.default.name") + // we will use hadoop configuration "fs.defaultFS". 
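// Illustrative spark-shell sketch (not from this patch): "fs.default.name" is the deprecated
// Hadoop 1.x key and "fs.defaultFS" is its replacement; the URI and table name below are
// placeholders.
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder()
  .config("spark.hadoop.fs.defaultFS", "hdfs://namenode:8020")
  .getOrCreate()
// A schemeless, non-local LOAD DATA path is then resolved against that default filesystem.
spark.sql("LOAD DATA INPATH '/staging/events.csv' INTO TABLE events")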
+ val defaultFSConf = sparkSession.sessionState.newHadoopConf().get("fs.defaultFS") val defaultFS = if (defaultFSConf == null) { new URI("") } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CatalogFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CatalogFileIndex.scala index 1235a4b12f1d0..2068811661fec 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CatalogFileIndex.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CatalogFileIndex.scala @@ -72,7 +72,8 @@ class CatalogFileIndex( val path = new Path(p.location) val fs = path.getFileSystem(hadoopConf) PartitionPath( - p.toRow(partitionSchema), path.makeQualified(fs.getUri, fs.getWorkingDirectory)) + p.toRow(partitionSchema, sparkSession.sessionState.conf.sessionLocalTimeZone), + path.makeQualified(fs.getUri, fs.getWorkingDirectory)) } val partitionSpec = PartitionSpec(partitionSchema, partitions) new PrunedInMemoryFileIndex( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CodecStreams.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CodecStreams.scala index 0762d1b7daaea..54549f698aca5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CodecStreams.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CodecStreams.scala @@ -27,6 +27,8 @@ import org.apache.hadoop.mapreduce.JobContext import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat import org.apache.hadoop.util.ReflectionUtils +import org.apache.spark.TaskContext + object CodecStreams { private def getDecompressionCodec(config: Configuration, file: Path): Option[CompressionCodec] = { val compressionCodecs = new CompressionCodecFactory(config) @@ -42,6 +44,16 @@ object CodecStreams { .getOrElse(inputStream) } + /** + * Creates an input stream from the string path and add a closure for the input stream to be + * closed on task completion. + */ + def createInputStreamWithCloseResource(config: Configuration, path: String): InputStream = { + val inputStream = createInputStream(config, new Path(path)) + Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => inputStream.close())) + inputStream + } + private def getCompressionCodec( context: JobContext, file: Option[Path] = None): Option[CompressionCodec] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index d510581f90e69..4947dfda6fc7e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -106,10 +106,13 @@ case class DataSource( * be any further inference in any triggers. * * @param format the file format object for this DataSource + * @param fileStatusCache the shared cache for file statuses to speed up listing * @return A pair of the data schema (excluding partition columns) and the schema of the partition * columns. 
*/ - private def getOrInferFileFormatSchema(format: FileFormat): (StructType, StructType) = { + private def getOrInferFileFormatSchema( + format: FileFormat, + fileStatusCache: FileStatusCache = NoopCache): (StructType, StructType) = { // the operations below are expensive therefore try not to do them if we don't need to, e.g., // in streaming mode, we have already inferred and registered partition columns, we will // never have to materialize the lazy val below @@ -122,7 +125,7 @@ case class DataSource( val qualified = hdfsPath.makeQualified(fs.getUri, fs.getWorkingDirectory) SparkHadoopUtil.get.globPathIfNecessary(qualified) }.toArray - new InMemoryFileIndex(sparkSession, globbedPaths, options, None) + new InMemoryFileIndex(sparkSession, globbedPaths, options, None, fileStatusCache) } val partitionSchema = if (partitionColumns.isEmpty) { // Try to infer partitioning, because no DataSource in the read path provides the partitioning @@ -279,28 +282,6 @@ case class DataSource( } } - /** - * Returns true if there is a single path that has a metadata log indicating which files should - * be read. - */ - def hasMetadata(path: Seq[String]): Boolean = { - path match { - case Seq(singlePath) => - try { - val hdfsPath = new Path(singlePath) - val fs = hdfsPath.getFileSystem(sparkSession.sessionState.newHadoopConf()) - val metadataPath = new Path(hdfsPath, FileStreamSink.metadataDir) - val res = fs.exists(metadataPath) - res - } catch { - case NonFatal(e) => - logWarning(s"Error while looking for metadata directory.") - false - } - case _ => false - } - } - /** * Create a resolved [[BaseRelation]] that can be used to read data from or write data into this * [[DataSource]] @@ -331,7 +312,9 @@ case class DataSource( // We are reading from the results of a streaming query. Load files from the metadata log // instead of listing them using HDFS APIs. 
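// Illustrative spark-shell sketch (not from this patch), assuming an active `spark` session:
// a directory produced by a streaming file sink carries a _spark_metadata log, and a later
// batch read of that directory takes the branch below, listing files from the log instead of
// the filesystem. Paths and the input schema are placeholders.
val streamingQuery = spark.readStream
  .schema(inputSchema)                               // assumed to be defined by the caller
  .json("/data/in")
  .writeStream
  .format("parquet")
  .option("checkpointLocation", "/tmp/checkpoints/events")
  .start("/data/out")

val batchView = spark.read.parquet("/data/out")      // backed by /data/out/_spark_metadata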
case (format: FileFormat, _) - if hasMetadata(caseInsensitiveOptions.get("path").toSeq ++ paths) => + if FileStreamSink.hasMetadata( + caseInsensitiveOptions.get("path").toSeq ++ paths, + sparkSession.sessionState.newHadoopConf()) => val basePath = new Path((caseInsensitiveOptions.get("path").toSeq ++ paths).head) val fileCatalog = new MetadataLogFileIndex(sparkSession, basePath) val dataSchema = userSpecifiedSchema.orElse { @@ -374,7 +357,8 @@ case class DataSource( globPath }.toArray - val (dataSchema, partitionSchema) = getOrInferFileFormatSchema(format) + val fileStatusCache = FileStatusCache.getOrCreate(sparkSession) + val (dataSchema, partitionSchema) = getOrInferFileFormatSchema(format, fileStatusCache) val fileCatalog = if (sparkSession.sqlContext.conf.manageFilesourcePartitions && catalogTable.isDefined && catalogTable.get.tracksPartitionsInCatalog) { @@ -384,7 +368,8 @@ case class DataSource( catalogTable.get, catalogTable.get.stats.map(_.sizeInBytes.toLong).getOrElse(defaultTableSize)) } else { - new InMemoryFileIndex(sparkSession, globbedPaths, options, Some(partitionSchema)) + new InMemoryFileIndex( + sparkSession, globbedPaths, options, Some(partitionSchema), fileStatusCache) } HadoopFsRelation( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index f4292320e4bfe..f694a0d6d724b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.{CatalystConf, CatalystTypeConverters, InternalRow, QualifiedTableName, TableIdentifier} import org.apache.spark.sql.catalyst.CatalystTypeConverters.convertToScala import org.apache.spark.sql.catalyst.analysis._ -import org.apache.spark.sql.catalyst.catalog.{CatalogTable, SimpleCatalogRelation} +import org.apache.spark.sql.catalyst.catalog.CatalogRelation import org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning.PhysicalOperation @@ -208,16 +208,17 @@ case class DataSourceAnalysis(conf: CatalystConf) extends Rule[LogicalPlan] { /** - * Replaces [[SimpleCatalogRelation]] with data source table if its table provider is not hive. + * Replaces [[CatalogRelation]] with data source table if its table provider is not hive. */ class FindDataSourceTable(sparkSession: SparkSession) extends Rule[LogicalPlan] { - private def readDataSourceTable(table: CatalogTable): LogicalPlan = { + private def readDataSourceTable(r: CatalogRelation): LogicalPlan = { + val table = r.tableMeta val qualifiedTableName = QualifiedTableName(table.database, table.identifier.table) val cache = sparkSession.sessionState.catalog.tableRelationCache val withHiveSupport = sparkSession.sparkContext.conf.get(StaticSQLConf.CATALOG_IMPLEMENTATION) == "hive" - cache.get(qualifiedTableName, new Callable[LogicalPlan]() { + val plan = cache.get(qualifiedTableName, new Callable[LogicalPlan]() { override def call(): LogicalPlan = { val pathOption = table.storage.locationUri.map("path" -> _) val dataSource = @@ -233,19 +234,25 @@ class FindDataSourceTable(sparkSession: SparkSession) extends Rule[LogicalPlan] // TODO: improve `InMemoryCatalog` and remove this limitation. 
catalogTable = if (withHiveSupport) Some(table) else None) - LogicalRelation(dataSource.resolveRelation(checkFilesExist = false), + LogicalRelation( + dataSource.resolveRelation(checkFilesExist = false), catalogTable = Some(table)) } - }) + }).asInstanceOf[LogicalRelation] + + // It's possible that the table schema is empty and need to be inferred at runtime. We should + // not specify expected outputs for this case. + val expectedOutputs = if (r.output.isEmpty) None else Some(r.output) + plan.copy(expectedOutputAttributes = expectedOutputs) } override def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case i @ InsertIntoTable(s: SimpleCatalogRelation, _, _, _, _) - if DDLUtils.isDatasourceTable(s.metadata) => - i.copy(table = readDataSourceTable(s.metadata)) + case i @ InsertIntoTable(r: CatalogRelation, _, _, _, _) + if DDLUtils.isDatasourceTable(r.tableMeta) => + i.copy(table = readDataSourceTable(r)) - case s: SimpleCatalogRelation if DDLUtils.isDatasourceTable(s.metadata) => - readDataSourceTable(s.metadata) + case r: CatalogRelation if DDLUtils.isDatasourceTable(r.tableMeta) => + readDataSourceTable(r) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala index 644358493e2eb..950e5ca0d6210 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala @@ -37,7 +37,7 @@ import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} import org.apache.spark.sql.execution.{QueryExecution, SortExec, SQLExecution} import org.apache.spark.sql.types.{StringType, StructType} import org.apache.spark.util.{SerializableConfiguration, Utils} @@ -68,7 +68,8 @@ object FileFormatWriter extends Logging { val bucketIdExpression: Option[Expression], val path: String, val customPartitionLocations: Map[TablePartitionSpec, String], - val maxRecordsPerFile: Long) + val maxRecordsPerFile: Long, + val timeZoneId: String) extends Serializable { assert(AttributeSet(allColumns) == AttributeSet(partitionColumns ++ dataColumns), @@ -122,9 +123,11 @@ object FileFormatWriter extends Logging { spec => spec.sortColumnNames.map(c => dataColumns.find(_.name == c).get) } + val caseInsensitiveOptions = CaseInsensitiveMap(options) + // Note: prepareWrite has side effect. It sets "job". 
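// Illustrative spark-shell sketch (not from this patch): write-side options now pass through
// CaseInsensitiveMap, so option-key casing no longer matters; `df`, the partition column and
// the path are placeholders.
df.write
  .option("MAXRECORDSPERFILE", 1000000L)   // matched case-insensitively to maxRecordsPerFile
  .option("timezone", "UTC")               // picked up for rendering partition values
  .partitionBy("event_date")
  .parquet("/data/out/events")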
val outputWriterFactory = - fileFormat.prepareWrite(sparkSession, job, options, dataColumns.toStructType) + fileFormat.prepareWrite(sparkSession, job, caseInsensitiveOptions, dataColumns.toStructType) val description = new WriteJobDescription( uuid = UUID.randomUUID().toString, @@ -136,8 +139,10 @@ object FileFormatWriter extends Logging { bucketIdExpression = bucketIdExpression, path = outputSpec.outputPath, customPartitionLocations = outputSpec.customPartitionLocations, - maxRecordsPerFile = options.get("maxRecordsPerFile").map(_.toLong) - .getOrElse(sparkSession.sessionState.conf.maxRecordsPerFile) + maxRecordsPerFile = caseInsensitiveOptions.get("maxRecordsPerFile").map(_.toLong) + .getOrElse(sparkSession.sessionState.conf.maxRecordsPerFile), + timeZoneId = caseInsensitiveOptions.get("timeZone") + .getOrElse(sparkSession.sessionState.conf.sessionLocalTimeZone) ) // We should first sort by partition columns, then bucket id, and finally sorting columns. @@ -210,11 +215,11 @@ object FileFormatWriter extends Logging { val taskAttemptContext: TaskAttemptContext = { // Set up the configuration object val hadoopConf = description.serializableHadoopConf.value - hadoopConf.set("mapred.job.id", jobId.toString) - hadoopConf.set("mapred.tip.id", taskAttemptId.getTaskID.toString) - hadoopConf.set("mapred.task.id", taskAttemptId.toString) - hadoopConf.setBoolean("mapred.task.is.map", true) - hadoopConf.setInt("mapred.task.partition", 0) + hadoopConf.set("mapreduce.job.id", jobId.toString) + hadoopConf.set("mapreduce.task.id", taskAttemptId.getTaskID.toString) + hadoopConf.set("mapreduce.task.attempt.id", taskAttemptId.toString) + hadoopConf.setBoolean("mapreduce.task.ismap", true) + hadoopConf.setInt("mapreduce.task.partition", 0) new TaskAttemptContextImpl(hadoopConf, taskAttemptId) } @@ -330,11 +335,10 @@ object FileFormatWriter extends Logging { /** Expressions that given partition columns build a path string like: col1=val/col2=val/... */ private def partitionPathExpression: Seq[Expression] = { desc.partitionColumns.zipWithIndex.flatMap { case (c, i) => - // TODO: use correct timezone for partition values. 
val escaped = ScalaUDF( ExternalCatalogUtils.escapePathName _, StringType, - Seq(Cast(c, StringType, Option(DateTimeUtils.defaultTimeZone().getID))), + Seq(Cast(c, StringType, Option(desc.timeZoneId))), Seq(StringType)) val str = If(IsNull(c), Literal(ExternalCatalogUtils.DEFAULT_PARTITION_NAME), escaped) val partitionName = Literal(c.name + "=") :: str :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala index 7531f0ae02e75..ee4d0863d9771 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala @@ -66,8 +66,8 @@ class InMemoryFileIndex( } override def refresh(): Unit = { - refresh0() fileStatusCache.invalidateAll() + refresh0() } private def refresh0(): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala index 652bcc8331936..19b51d4d9530a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala @@ -147,7 +147,10 @@ case class InsertIntoHadoopFsRelationCommand( refreshFunction = refreshPartitionsCallback, options = options) + // refresh cached files in FileIndex fileIndex.foreach(_.refresh()) + // refresh data cache if table is cached + sparkSession.catalog.refreshByPath(outputPath.toString) } else { logInfo("Skipping insertion into a relation that already exists.") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala index 75f87a5503b8c..c8097a7fabc2e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala @@ -30,7 +30,7 @@ import org.apache.spark.metrics.source.HiveCatalogMetrics import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.{expressions, InternalRow} import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} import org.apache.spark.sql.types.{StringType, StructType} import org.apache.spark.util.SerializableConfiguration @@ -125,22 +125,27 @@ abstract class PartitioningAwareFileIndex( val leafDirs = leafDirToChildrenFiles.filter { case (_, files) => files.exists(f => isDataPath(f.getPath)) }.keys.toSeq + + val caseInsensitiveOptions = CaseInsensitiveMap(parameters) + val timeZoneId = caseInsensitiveOptions.get("timeZone") + .getOrElse(sparkSession.sessionState.conf.sessionLocalTimeZone) + userPartitionSchema match { case Some(userProvidedSchema) if userProvidedSchema.nonEmpty => val spec = PartitioningUtils.parsePartitions( leafDirs, typeInference = false, - basePaths = basePaths) + basePaths = basePaths, + timeZoneId = timeZoneId) // Without auto inference, all of value in the `row` should be null or in StringType, // we need to cast into the data type that user 
specified. def castPartitionValuesToUserSchema(row: InternalRow) = { InternalRow((0 until row.numFields).map { i => - // TODO: use correct timezone for partition values. Cast( Literal.create(row.getUTF8String(i), StringType), userProvidedSchema.fields(i).dataType, - Option(DateTimeUtils.defaultTimeZone().getID)).eval() + Option(timeZoneId)).eval() }: _*) } @@ -151,7 +156,8 @@ abstract class PartitioningAwareFileIndex( PartitioningUtils.parsePartitions( leafDirs, typeInference = sparkSession.sessionState.conf.partitionColumnTypeInferenceEnabled, - basePaths = basePaths) + basePaths = basePaths, + timeZoneId = timeZoneId) } } @@ -300,7 +306,7 @@ object PartitioningAwareFileIndex extends Logging { sparkSession: SparkSession): Seq[(Path, Seq[FileStatus])] = { // Short-circuits parallel listing when serial listing is likely to be faster. - if (paths.size < sparkSession.sessionState.conf.parallelPartitionDiscoveryThreshold) { + if (paths.size <= sparkSession.sessionState.conf.parallelPartitionDiscoveryThreshold) { return paths.map { path => (path, listLeafFiles(path, hadoopConf, filter, Some(sparkSession))) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala index bad59961ace12..09876bbc2f85d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.datasources import java.lang.{Double => JDouble, Long => JLong} import java.math.{BigDecimal => JBigDecimal} -import java.sql.{Date => JDate, Timestamp => JTimestamp} +import java.util.TimeZone import scala.collection.mutable.ArrayBuffer import scala.util.Try @@ -31,7 +31,9 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.Resolver import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.{Cast, Literal} +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String // TODO: We should tighten up visibility of the classes here once we clean up Hive coupling. @@ -91,10 +93,19 @@ object PartitioningUtils { private[datasources] def parsePartitions( paths: Seq[Path], typeInference: Boolean, - basePaths: Set[Path]): PartitionSpec = { + basePaths: Set[Path], + timeZoneId: String): PartitionSpec = { + parsePartitions(paths, typeInference, basePaths, TimeZone.getTimeZone(timeZoneId)) + } + + private[datasources] def parsePartitions( + paths: Seq[Path], + typeInference: Boolean, + basePaths: Set[Path], + timeZone: TimeZone): PartitionSpec = { // First, we need to parse every partition's path and see if we can find partition values. 
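// Illustrative spark-shell sketch (not from this patch), assuming an active `spark` session:
// timestamp-typed partition values are now parsed in the session time zone, or in a per-read
// "timeZone" option, instead of the JVM default zone. The path and zone are placeholders.
spark.conf.set("spark.sql.session.timeZone", "America/Los_Angeles")
val events = spark.read
  .option("timeZone", "America/Los_Angeles")   // overrides the session zone for this read
  .parquet("/data/events")                     // e.g. partition directories like ts=2017-02-28 10%3A00%3A00
events.printSchema()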
val (partitionValues, optDiscoveredBasePaths) = paths.map { path => - parsePartition(path, typeInference, basePaths) + parsePartition(path, typeInference, basePaths, timeZone) }.unzip // We create pairs of (path -> path's partition value) here @@ -173,7 +184,8 @@ object PartitioningUtils { private[datasources] def parsePartition( path: Path, typeInference: Boolean, - basePaths: Set[Path]): (Option[PartitionValues], Option[Path]) = { + basePaths: Set[Path], + timeZone: TimeZone): (Option[PartitionValues], Option[Path]) = { val columns = ArrayBuffer.empty[(String, Literal)] // Old Hadoop versions don't have `Path.isRoot` var finished = path.getParent == null @@ -194,7 +206,7 @@ object PartitioningUtils { // Let's say currentPath is a path of "/table/a=1/", currentPath.getName will give us a=1. // Once we get the string, we try to parse it and find the partition column and value. val maybeColumn = - parsePartitionColumn(currentPath.getName, typeInference) + parsePartitionColumn(currentPath.getName, typeInference, timeZone) maybeColumn.foreach(columns += _) // Now, we determine if we should stop. @@ -226,7 +238,8 @@ object PartitioningUtils { private def parsePartitionColumn( columnSpec: String, - typeInference: Boolean): Option[(String, Literal)] = { + typeInference: Boolean, + timeZone: TimeZone): Option[(String, Literal)] = { val equalSignIndex = columnSpec.indexOf('=') if (equalSignIndex == -1) { None @@ -237,7 +250,7 @@ object PartitioningUtils { val rawColumnValue = columnSpec.drop(equalSignIndex + 1) assert(rawColumnValue.nonEmpty, s"Empty partition column value in '$columnSpec'") - val literal = inferPartitionColumnValue(rawColumnValue, typeInference) + val literal = inferPartitionColumnValue(rawColumnValue, typeInference, timeZone) Some(columnName -> literal) } } @@ -370,7 +383,8 @@ object PartitioningUtils { */ private[datasources] def inferPartitionColumnValue( raw: String, - typeInference: Boolean): Literal = { + typeInference: Boolean, + timeZone: TimeZone): Literal = { val decimalTry = Try { // `BigDecimal` conversion can fail when the `field` is not a form of number. val bigDecimal = new JBigDecimal(raw) @@ -390,8 +404,16 @@ object PartitioningUtils { // Then falls back to fractional types .orElse(Try(Literal.create(JDouble.parseDouble(raw), DoubleType))) // Then falls back to date/timestamp types - .orElse(Try(Literal(JDate.valueOf(raw)))) - .orElse(Try(Literal(JTimestamp.valueOf(unescapePathName(raw))))) + .orElse(Try( + Literal.create( + DateTimeUtils.getThreadLocalTimestampFormat(timeZone) + .parse(unescapePathName(raw)).getTime * 1000L, + TimestampType))) + .orElse(Try( + Literal.create( + DateTimeUtils.millisToDays( + DateTimeUtils.getThreadLocalDateFormat.parse(raw).getTime), + DateType))) // Then falls back to string .getOrElse { if (raw == DEFAULT_PARTITION_NAME) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala new file mode 100644 index 0000000000000..73e6abc6dad37 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala @@ -0,0 +1,239 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.csv + +import java.io.InputStream +import java.nio.charset.{Charset, StandardCharsets} + +import com.univocity.parsers.csv.{CsvParser, CsvParserSettings} +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileStatus, Path} +import org.apache.hadoop.io.{LongWritable, Text} +import org.apache.hadoop.mapred.TextInputFormat +import org.apache.hadoop.mapreduce.Job +import org.apache.hadoop.mapreduce.lib.input.FileInputFormat + +import org.apache.spark.TaskContext +import org.apache.spark.input.{PortableDataStream, StreamInputFormat} +import org.apache.spark.rdd.{BinaryFileRDD, RDD} +import org.apache.spark.sql.{Dataset, Encoders, SparkSession} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.execution.datasources._ +import org.apache.spark.sql.execution.datasources.text.TextFileFormat +import org.apache.spark.sql.types.StructType + +/** + * Common functions for parsing CSV files + */ +abstract class CSVDataSource extends Serializable { + def isSplitable: Boolean + + /** + * Parse a [[PartitionedFile]] into [[InternalRow]] instances. + */ + def readFile( + conf: Configuration, + file: PartitionedFile, + parser: UnivocityParser, + parsedOptions: CSVOptions): Iterator[InternalRow] + + /** + * Infers the schema from `inputPaths` files. + */ + def infer( + sparkSession: SparkSession, + inputPaths: Seq[FileStatus], + parsedOptions: CSVOptions): Option[StructType] + + /** + * Generates a header from the given row which is null-safe and duplicate-safe. + */ + protected def makeSafeHeader( + row: Array[String], + caseSensitive: Boolean, + options: CSVOptions): Array[String] = { + if (options.headerFlag) { + val duplicates = { + val headerNames = row.filter(_ != null) + .map(name => if (caseSensitive) name else name.toLowerCase) + headerNames.diff(headerNames.distinct).distinct + } + + row.zipWithIndex.map { case (value, index) => + if (value == null || value.isEmpty || value == options.nullValue) { + // When there are empty strings or the values set in `nullValue`, put the + // index as the suffix. + s"_c$index" + } else if (!caseSensitive && duplicates.contains(value.toLowerCase)) { + // When there are case-insensitive duplicates, put the index as the suffix. + s"$value$index" + } else if (duplicates.contains(value)) { + // When there are duplicates, put the index as the suffix. + s"$value$index" + } else { + value + } + } + } else { + row.zipWithIndex.map { case (_, index) => + // Uses default column names, "_c#" where # is its position of fields + // when header option is disabled. 
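+  // Worked example of the header sanitisation above (illustration, not from this patch):
+  // with case-insensitive analysis, a raw header row of
+  //   Array("a", "A", null, "a")
+  // becomes
+  //   Array("a0", "A1", "_c2", "a3")
+  // because null/empty names and duplicates get the column index appended; with the header
+  // option disabled, every column simply becomes "_c0", "_c1", ...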
+ s"_c$index" + } + } + } +} + +object CSVDataSource { + def apply(options: CSVOptions): CSVDataSource = { + if (options.wholeFile) { + WholeFileCSVDataSource + } else { + TextInputCSVDataSource + } + } +} + +object TextInputCSVDataSource extends CSVDataSource { + override val isSplitable: Boolean = true + + override def readFile( + conf: Configuration, + file: PartitionedFile, + parser: UnivocityParser, + parsedOptions: CSVOptions): Iterator[InternalRow] = { + val lines = { + val linesReader = new HadoopFileLinesReader(file, conf) + Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => linesReader.close())) + linesReader.map { line => + new String(line.getBytes, 0, line.getLength, parsedOptions.charset) + } + } + + val shouldDropHeader = parsedOptions.headerFlag && file.start == 0 + UnivocityParser.parseIterator(lines, shouldDropHeader, parser) + } + + override def infer( + sparkSession: SparkSession, + inputPaths: Seq[FileStatus], + parsedOptions: CSVOptions): Option[StructType] = { + val csv: Dataset[String] = createBaseDataset(sparkSession, inputPaths, parsedOptions) + val firstLine: String = CSVUtils.filterCommentAndEmpty(csv, parsedOptions).first() + val firstRow = new CsvParser(parsedOptions.asParserSettings).parseLine(firstLine) + val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis + val header = makeSafeHeader(firstRow, caseSensitive, parsedOptions) + val tokenRDD = csv.rdd.mapPartitions { iter => + val filteredLines = CSVUtils.filterCommentAndEmpty(iter, parsedOptions) + val linesWithoutHeader = + CSVUtils.filterHeaderLine(filteredLines, firstLine, parsedOptions) + val parser = new CsvParser(parsedOptions.asParserSettings) + linesWithoutHeader.map(parser.parseLine) + } + + Some(CSVInferSchema.infer(tokenRDD, header, parsedOptions)) + } + + private def createBaseDataset( + sparkSession: SparkSession, + inputPaths: Seq[FileStatus], + options: CSVOptions): Dataset[String] = { + val paths = inputPaths.map(_.getPath.toString) + if (Charset.forName(options.charset) == StandardCharsets.UTF_8) { + sparkSession.baseRelationToDataFrame( + DataSource.apply( + sparkSession, + paths = paths, + className = classOf[TextFileFormat].getName + ).resolveRelation(checkFilesExist = false)) + .select("value").as[String](Encoders.STRING) + } else { + val charset = options.charset + val rdd = sparkSession.sparkContext + .hadoopFile[LongWritable, Text, TextInputFormat](paths.mkString(",")) + .mapPartitions(_.map(pair => new String(pair._2.getBytes, 0, pair._2.getLength, charset))) + sparkSession.createDataset(rdd)(Encoders.STRING) + } + } +} + +object WholeFileCSVDataSource extends CSVDataSource { + override val isSplitable: Boolean = false + + override def readFile( + conf: Configuration, + file: PartitionedFile, + parser: UnivocityParser, + parsedOptions: CSVOptions): Iterator[InternalRow] = { + UnivocityParser.parseStream( + CodecStreams.createInputStreamWithCloseResource(conf, file.filePath), + parsedOptions.headerFlag, + parser) + } + + override def infer( + sparkSession: SparkSession, + inputPaths: Seq[FileStatus], + parsedOptions: CSVOptions): Option[StructType] = { + val csv: RDD[PortableDataStream] = createBaseRdd(sparkSession, inputPaths, parsedOptions) + val maybeFirstRow: Option[Array[String]] = csv.flatMap { lines => + UnivocityParser.tokenizeStream( + CodecStreams.createInputStreamWithCloseResource(lines.getConfiguration, lines.getPath()), + false, + new CsvParser(parsedOptions.asParserSettings)) + }.take(1).headOption + + if (maybeFirstRow.isDefined) { + 
val firstRow = maybeFirstRow.get + val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis + val header = makeSafeHeader(firstRow, caseSensitive, parsedOptions) + val tokenRDD = csv.flatMap { lines => + UnivocityParser.tokenizeStream( + CodecStreams.createInputStreamWithCloseResource(lines.getConfiguration, lines.getPath()), + parsedOptions.headerFlag, + new CsvParser(parsedOptions.asParserSettings)) + } + Some(CSVInferSchema.infer(tokenRDD, header, parsedOptions)) + } else { + // If the first row could not be read, just return the empty schema. + Some(StructType(Nil)) + } + } + + private def createBaseRdd( + sparkSession: SparkSession, + inputPaths: Seq[FileStatus], + options: CSVOptions): RDD[PortableDataStream] = { + val paths = inputPaths.map(_.getPath) + val name = paths.mkString(",") + val job = Job.getInstance(sparkSession.sessionState.newHadoopConf()) + FileInputFormat.setInputPaths(job, paths: _*) + val conf = job.getConfiguration + + val rdd = new BinaryFileRDD( + sparkSession.sparkContext, + classOf[StreamInputFormat], + classOf[String], + classOf[PortableDataStream], + conf, + sparkSession.sparkContext.defaultMinPartitions) + + // Only returns `PortableDataStream`s without paths. + rdd.setName(s"CSVFile: $name").values + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala index 59f2919edfe2e..29c41455279e6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala @@ -17,21 +17,15 @@ package org.apache.spark.sql.execution.datasources.csv -import java.nio.charset.{Charset, StandardCharsets} - import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, Path} -import org.apache.hadoop.io.{LongWritable, Text} -import org.apache.hadoop.mapred.TextInputFormat import org.apache.hadoop.mapreduce._ -import org.apache.spark.TaskContext import org.apache.spark.internal.Logging -import org.apache.spark.sql.{AnalysisException, Dataset, Encoders, SparkSession} +import org.apache.spark.sql.{AnalysisException, SparkSession} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.CompressionCodecs import org.apache.spark.sql.execution.datasources._ -import org.apache.spark.sql.execution.datasources.text.TextFileFormat import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ import org.apache.spark.util.SerializableConfiguration @@ -43,11 +37,15 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister { override def shortName(): String = "csv" - override def toString: String = "CSV" - - override def hashCode(): Int = getClass.hashCode() - - override def equals(other: Any): Boolean = other.isInstanceOf[CSVFileFormat] + override def isSplitable( + sparkSession: SparkSession, + options: Map[String, String], + path: Path): Boolean = { + val parsedOptions = + new CSVOptions(options, sparkSession.sessionState.conf.sessionLocalTimeZone) + val csvDataSource = CSVDataSource(parsedOptions) + csvDataSource.isSplitable && super.isSplitable(sparkSession, options, path) + } override def inferSchema( sparkSession: SparkSession, @@ -55,11 +53,10 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister { files: Seq[FileStatus]): Option[StructType] = { require(files.nonEmpty, "Cannot infer schema 
from an empty set of files") - val csvOptions = new CSVOptions(options, sparkSession.sessionState.conf.sessionLocalTimeZone) - val paths = files.map(_.getPath.toString) - val lines: Dataset[String] = createBaseDataset(sparkSession, csvOptions, paths) - val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis - Some(CSVInferSchema.infer(lines, caseSensitive, csvOptions)) + val parsedOptions = + new CSVOptions(options, sparkSession.sessionState.conf.sessionLocalTimeZone) + + CSVDataSource(parsedOptions).infer(sparkSession, files, parsedOptions) } override def prepareWrite( @@ -115,49 +112,17 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister { } (file: PartitionedFile) => { - val lines = { - val conf = broadcastedHadoopConf.value.value - val linesReader = new HadoopFileLinesReader(file, conf) - Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => linesReader.close())) - linesReader.map { line => - new String(line.getBytes, 0, line.getLength, parsedOptions.charset) - } - } - - val linesWithoutHeader = if (parsedOptions.headerFlag && file.start == 0) { - // Note that if there are only comments in the first block, the header would probably - // be not dropped. - CSVUtils.dropHeaderLine(lines, parsedOptions) - } else { - lines - } - - val filteredLines = CSVUtils.filterCommentAndEmpty(linesWithoutHeader, parsedOptions) + val conf = broadcastedHadoopConf.value.value val parser = new UnivocityParser(dataSchema, requiredSchema, parsedOptions) - filteredLines.flatMap(parser.parse) + CSVDataSource(parsedOptions).readFile(conf, file, parser, parsedOptions) } } - private def createBaseDataset( - sparkSession: SparkSession, - options: CSVOptions, - inputPaths: Seq[String]): Dataset[String] = { - if (Charset.forName(options.charset) == StandardCharsets.UTF_8) { - sparkSession.baseRelationToDataFrame( - DataSource.apply( - sparkSession, - paths = inputPaths, - className = classOf[TextFileFormat].getName - ).resolveRelation(checkFilesExist = false)) - .select("value").as[String](Encoders.STRING) - } else { - val charset = options.charset - val rdd = sparkSession.sparkContext - .hadoopFile[LongWritable, Text, TextInputFormat](inputPaths.mkString(",")) - .mapPartitions(_.map(pair => new String(pair._2.getBytes, 0, pair._2.getLength, charset))) - sparkSession.createDataset(rdd)(Encoders.STRING) - } - } + override def toString: String = "CSV" + + override def hashCode(): Int = getClass.hashCode() + + override def equals(other: Any): Boolean = other.isInstanceOf[CSVFileFormat] } private[csv] class CsvOutputWriter( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala index 3fa30fe2401e1..b64d71bb4eef2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala @@ -21,11 +21,9 @@ import java.math.BigDecimal import scala.util.control.Exception._ -import com.univocity.parsers.csv.CsvParser - +import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.analysis.TypeCoercion import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.Dataset import org.apache.spark.sql.types._ private[csv] object CSVInferSchema { @@ -37,24 +35,13 @@ private[csv] object CSVInferSchema { * 3. 
Replace any null types with string type */ def infer( - csv: Dataset[String], - caseSensitive: Boolean, + tokenRDD: RDD[Array[String]], + header: Array[String], options: CSVOptions): StructType = { - val firstLine: String = CSVUtils.filterCommentAndEmpty(csv, options).first() - val firstRow = new CsvParser(options.asParserSettings).parseLine(firstLine) - val header = makeSafeHeader(firstRow, caseSensitive, options) - val fields = if (options.inferSchemaFlag) { - val tokenRdd = csv.rdd.mapPartitions { iter => - val filteredLines = CSVUtils.filterCommentAndEmpty(iter, options) - val linesWithoutHeader = CSVUtils.filterHeaderLine(filteredLines, firstLine, options) - val parser = new CsvParser(options.asParserSettings) - linesWithoutHeader.map(parser.parseLine) - } - val startType: Array[DataType] = Array.fill[DataType](header.length)(NullType) val rootTypes: Array[DataType] = - tokenRdd.aggregate(startType)(inferRowType(options), mergeRowTypes) + tokenRDD.aggregate(startType)(inferRowType(options), mergeRowTypes) header.zip(rootTypes).map { case (thisHeader, rootType) => val dType = rootType match { @@ -71,44 +58,6 @@ private[csv] object CSVInferSchema { StructType(fields) } - /** - * Generates a header from the given row which is null-safe and duplicate-safe. - */ - private def makeSafeHeader( - row: Array[String], - caseSensitive: Boolean, - options: CSVOptions): Array[String] = { - if (options.headerFlag) { - val duplicates = { - val headerNames = row.filter(_ != null) - .map(name => if (caseSensitive) name else name.toLowerCase) - headerNames.diff(headerNames.distinct).distinct - } - - row.zipWithIndex.map { case (value, index) => - if (value == null || value.isEmpty || value == options.nullValue) { - // When there are empty strings or the values set in `nullValue`, put the - // index as the suffix. - s"_c$index" - } else if (!caseSensitive && duplicates.contains(value.toLowerCase)) { - // When there are case-insensitive duplicates, put the index as the suffix. - s"$value$index" - } else if (duplicates.contains(value)) { - // When there are duplicates, put the index as the suffix. - s"$value$index" - } else { - value - } - } - } else { - row.zipWithIndex.map { case (_, index) => - // Uses default column names, "_c#" where # is its position of fields - // when header option is disabled. 
- s"_c$index" - } - } - } - private def inferRowType(options: CSVOptions) (rowSoFar: Array[DataType], next: Array[String]): Array[DataType] = { var i = 0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala index 1caeec7c63945..50503385ad6d1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala @@ -130,6 +130,8 @@ private[csv] class CSVOptions( FastDateFormat.getInstance( parameters.getOrElse("timestampFormat", "yyyy-MM-dd'T'HH:mm:ss.SSSZZ"), timeZone, Locale.US) + val wholeFile = parameters.get("wholeFile").map(_.toBoolean).getOrElse(false) + val maxColumns = getInt("maxColumns", 20480) val maxCharsPerColumn = getInt("maxCharsPerColumn", -1) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala index eb471651db2e3..3b3b87e4354d6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution.datasources.csv +import java.io.InputStream import java.math.BigDecimal import java.text.NumberFormat import java.util.Locale @@ -36,7 +37,7 @@ import org.apache.spark.unsafe.types.UTF8String private[csv] class UnivocityParser( schema: StructType, requiredSchema: StructType, - options: CSVOptions) extends Logging { + private val options: CSVOptions) extends Logging { require(requiredSchema.toSet.subsetOf(schema.toSet), "requiredSchema should be the subset of schema.") @@ -53,36 +54,77 @@ private[csv] class UnivocityParser( private val dataSchema = StructType(schema.filter(_.name != options.columnNameOfCorruptRecord)) - private val valueConverters = - dataSchema.map(f => makeConverter(f.name, f.dataType, f.nullable, options)).toArray - - private val parser = new CsvParser(options.asParserSettings) + private val tokenizer = new CsvParser(options.asParserSettings) private var numMalformedRecords = 0 private val row = new GenericInternalRow(requiredSchema.length) - // This parser loads an `indexArr._1`-th position value in input tokens, - // then put the value in `row(indexArr._2)`. - private val indexArr: Array[(Int, Int)] = { - val fields = if (options.dropMalformed) { - // If `dropMalformed` is enabled, then it needs to parse all the values - // so that we can decide which row is malformed. - requiredSchema ++ schema.filterNot(requiredSchema.contains(_)) - } else { - requiredSchema - } - // TODO: Revisit this; we need to clean up code here for readability. - // See an URL below for related discussions: - // https://github.com/apache/spark/pull/16928#discussion_r102636720 - val fieldsWithIndexes = fields.zipWithIndex - corruptFieldIndex.map { case corrFieldIndex => - fieldsWithIndexes.filter { case (_, i) => i != corrFieldIndex } - }.getOrElse { - fieldsWithIndexes - }.map { case (f, i) => - (dataSchema.indexOf(f), i) - }.toArray + // In `PERMISSIVE` parse mode, we should be able to put the raw malformed row into the field + // specified in `columnNameOfCorruptRecord`. The raw input is retrieved by this method. 
+ private def getCurrentInput(): String = tokenizer.getContext.currentParsedContent().stripLineEnd + + // This parser loads an `tokenIndexArr`-th position value in input tokens, + // then put the value in `row(rowIndexArr)`. + // + // For example, let's say there is CSV data as below: + // + // a,b,c + // 1,2,A + // + // Also, let's say `columnNameOfCorruptRecord` is set to "_unparsed", `header` is `true` + // by user and the user selects "c", "b", "_unparsed" and "a" fields. In this case, we need + // to map those values below: + // + // required schema - ["c", "b", "_unparsed", "a"] + // CSV data schema - ["a", "b", "c"] + // required CSV data schema - ["c", "b", "a"] + // + // with the input tokens, + // + // input tokens - [1, 2, "A"] + // + // Each input token is placed in each output row's position by mapping these. In this case, + // + // output row - ["A", 2, null, 1] + // + // In more details, + // - `valueConverters`, input tokens - CSV data schema + // `valueConverters` keeps the positions of input token indices (by its index) to each + // value's converter (by its value) in an order of CSV data schema. In this case, + // [string->int, string->int, string->string]. + // + // - `tokenIndexArr`, input tokens - required CSV data schema + // `tokenIndexArr` keeps the positions of input token indices (by its index) to reordered + // fields given the required CSV data schema (by its value). In this case, [2, 1, 0]. + // + // - `rowIndexArr`, input tokens - required schema + // `rowIndexArr` keeps the positions of input token indices (by its index) to reordered + // field indices given the required schema (by its value). In this case, [0, 1, 3]. + private val valueConverters: Array[ValueConverter] = + dataSchema.map(f => makeConverter(f.name, f.dataType, f.nullable, options)).toArray + + // Only used to create both `tokenIndexArr` and `rowIndexArr`. This variable means + // the fields that we should try to convert. + private val reorderedFields = if (options.dropMalformed) { + // If `dropMalformed` is enabled, then it needs to parse all the values + // so that we can decide which row is malformed. + requiredSchema ++ schema.filterNot(requiredSchema.contains(_)) + } else { + requiredSchema + } + + private val tokenIndexArr: Array[Int] = { + reorderedFields + .filter(_.name != options.columnNameOfCorruptRecord) + .map(f => dataSchema.indexOf(f)).toArray + } + + private val rowIndexArr: Array[Int] = if (corruptFieldIndex.isDefined) { + val corrFieldIndex = corruptFieldIndex.get + reorderedFields.indices.filter(_ != corrFieldIndex).toArray + } else { + reorderedFields.indices.toArray } /** @@ -188,21 +230,23 @@ private[csv] class UnivocityParser( } /** - * Parses a single CSV record (in the form of an array of strings in which - * each element represents a column) and turns it into either one resulting row or no row (if the + * Parses a single CSV string and turns it into either one resulting row or no row (if the * the record is malformed). 
*/ - def parse(input: String): Option[InternalRow] = { - convertWithParseMode(input) { tokens => + def parse(input: String): Option[InternalRow] = convert(tokenizer.parseLine(input)) + + private def convert(tokens: Array[String]): Option[InternalRow] = { + convertWithParseMode(tokens) { tokens => var i: Int = 0 - while (i < indexArr.length) { - val (pos, rowIdx) = indexArr(i) + while (i < tokenIndexArr.length) { // It anyway needs to try to parse since it decides if this row is malformed // or not after trying to cast in `DROPMALFORMED` mode even if the casted // value is not stored in the row. - val value = valueConverters(pos).apply(tokens(pos)) + val from = tokenIndexArr(i) + val to = rowIndexArr(i) + val value = valueConverters(from).apply(tokens(from)) if (i < requiredSchema.length) { - row(rowIdx) = value + row(to) = value } i += 1 } @@ -211,8 +255,7 @@ private[csv] class UnivocityParser( } private def convertWithParseMode( - input: String)(convert: Array[String] => InternalRow): Option[InternalRow] = { - val tokens = parser.parseLine(input) + tokens: Array[String])(convert: Array[String] => InternalRow): Option[InternalRow] = { if (options.dropMalformed && dataSchema.length != tokens.length) { if (numMalformedRecords < options.maxMalformedLogPerPartition) { logWarning(s"Dropping malformed line: ${tokens.mkString(options.delimiter.toString)}") @@ -251,7 +294,7 @@ private[csv] class UnivocityParser( } catch { case NonFatal(e) if options.permissive => val row = new GenericInternalRow(requiredSchema.length) - corruptFieldIndex.foreach(row(_) = UTF8String.fromString(input)) + corruptFieldIndex.foreach(row(_) = UTF8String.fromString(getCurrentInput())) Some(row) case NonFatal(e) if options.dropMalformed => if (numMalformedRecords < options.maxMalformedLogPerPartition) { @@ -269,3 +312,75 @@ private[csv] class UnivocityParser( } } } + +private[csv] object UnivocityParser { + + /** + * Parses a stream that contains CSV strings and turns it into an iterator of tokens. + */ + def tokenizeStream( + inputStream: InputStream, + shouldDropHeader: Boolean, + tokenizer: CsvParser): Iterator[Array[String]] = { + convertStream(inputStream, shouldDropHeader, tokenizer)(tokens => tokens) + } + + /** + * Parses a stream that contains CSV strings and turns it into an iterator of rows. + */ + def parseStream( + inputStream: InputStream, + shouldDropHeader: Boolean, + parser: UnivocityParser): Iterator[InternalRow] = { + val tokenizer = parser.tokenizer + convertStream(inputStream, shouldDropHeader, tokenizer) { tokens => + parser.convert(tokens) + }.flatten + } + + private def convertStream[T]( + inputStream: InputStream, + shouldDropHeader: Boolean, + tokenizer: CsvParser)(convert: Array[String] => T) = new Iterator[T] { + tokenizer.beginParsing(inputStream) + private var nextRecord = { + if (shouldDropHeader) { + tokenizer.parseNext() + } + tokenizer.parseNext() + } + + override def hasNext: Boolean = nextRecord != null + + override def next(): T = { + if (!hasNext) { + throw new NoSuchElementException("End of stream") + } + val curRecord = convert(nextRecord) + nextRecord = tokenizer.parseNext() + curRecord + } + } + + /** + * Parses an iterator that contains CSV strings and turns it into an iterator of rows. 
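Editorial aside: the companion object now exposes a stream-based entry point (`parseStream`, for whole-file input) and a line-based one (`parseIterator`). A rough sketch, not part of this patch, of how a caller could dispatch between them — the real dispatch lives in the `CSVDataSource` classes referenced earlier, which are not included in this excerpt; imports and resource cleanup (task-completion listeners) are omitted:

    def readFileSketch(
        conf: Configuration,
        file: PartitionedFile,
        parser: UnivocityParser,
        options: CSVOptions): Iterator[InternalRow] = {
      if (options.wholeFile) {
        // A record may span multiple lines, so hand the raw (decompressed) stream to univocity.
        val stream = CodecStreams.createInputStreamWithCloseResource(conf, file.filePath)
        UnivocityParser.parseStream(stream, shouldDropHeader = options.headerFlag, parser)
      } else {
        // Line-split input: decode each line, drop the header only in the first split,
        // and let parseIterator filter comments/empty lines before parsing.
        val lines = new HadoopFileLinesReader(file, conf).map { line =>
          new String(line.getBytes, 0, line.getLength, options.charset)
        }
        UnivocityParser.parseIterator(
          lines, shouldDropHeader = options.headerFlag && file.start == 0, parser)
      }
    }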
+ */ + def parseIterator( + lines: Iterator[String], + shouldDropHeader: Boolean, + parser: UnivocityParser): Iterator[InternalRow] = { + val options = parser.options + + val linesWithoutHeader = if (shouldDropHeader) { + // Note that if there are only comments in the first block, the header would probably + // be not dropped. + CSVUtils.dropHeaderLine(lines, options) + } else { + lines + } + + val filteredLines: Iterator[String] = + CSVUtils.filterCommentAndEmpty(linesWithoutHeader, options) + filteredLines.flatMap(line => parser.parse(line)) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala index 3e984effcb8d8..18843bfc307b3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala @@ -17,14 +17,12 @@ package org.apache.spark.sql.execution.datasources.json -import java.io.InputStream - import scala.reflect.ClassTag import com.fasterxml.jackson.core.{JsonFactory, JsonParser} import com.google.common.io.ByteStreams import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FileStatus, Path} +import org.apache.hadoop.fs.FileStatus import org.apache.hadoop.io.{LongWritable, Text} import org.apache.hadoop.mapreduce.Job import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat, TextInputFormat} @@ -186,16 +184,10 @@ object WholeFileJsonDataSource extends JsonDataSource[PortableDataStream] { } } - private def createInputStream(config: Configuration, path: String): InputStream = { - val inputStream = CodecStreams.createInputStream(config, new Path(path)) - Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => inputStream.close())) - inputStream - } - override def createParser(jsonFactory: JsonFactory, record: PortableDataStream): JsonParser = { CreateJacksonParser.inputStream( jsonFactory, - createInputStream(record.getConfiguration, record.getPath())) + CodecStreams.createInputStreamWithCloseResource(record.getConfiguration, record.getPath())) } override def readFile( @@ -203,13 +195,15 @@ object WholeFileJsonDataSource extends JsonDataSource[PortableDataStream] { file: PartitionedFile, parser: JacksonParser): Iterator[InternalRow] = { def partitionedFileString(ignored: Any): UTF8String = { - Utils.tryWithResource(createInputStream(conf, file.filePath)) { inputStream => + Utils.tryWithResource { + CodecStreams.createInputStreamWithCloseResource(conf, file.filePath) + } { inputStream => UTF8String.fromBytes(ByteStreams.toByteArray(inputStream)) } } parser.parse( - createInputStream(conf, file.filePath), + CodecStreams.createInputStreamWithCloseResource(conf, file.filePath), CreateJacksonParser.inputStream, partitionedFileString).toIterator } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala index e7a59d4ad4dd2..4d781b96abace 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala @@ -379,7 +379,7 @@ case class PreprocessTableInsertion(conf: SQLConf) extends Rule[LogicalPlan] { case i @ InsertIntoTable(table, _, query, _, _) if table.resolved && query.resolved => table match { case relation: CatalogRelation => - val metadata = 
relation.catalogTable + val metadata = relation.tableMeta preprocess(i, metadata.identifier.quotedString, metadata.partitionColumnNames) case LogicalRelation(h: HadoopFsRelation, _, catalogTable) => val tblName = catalogTable.map(_.identifier.quotedString).getOrElse("unknown") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala index 0dbe2a71ed3bc..07ec4e9429e42 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala @@ -17,6 +17,9 @@ package org.apache.spark.sql.execution.streaming +import scala.util.control.NonFatal + +import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.spark.internal.Logging @@ -25,9 +28,31 @@ import org.apache.spark.sql.{DataFrame, SparkSession} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.datasources.{FileFormat, FileFormatWriter} -object FileStreamSink { +object FileStreamSink extends Logging { // The name of the subdirectory that is used to store metadata about which files are valid. val metadataDir = "_spark_metadata" + + /** + * Returns true if there is a single path that has a metadata log indicating which files should + * be read. + */ + def hasMetadata(path: Seq[String], hadoopConf: Configuration): Boolean = { + path match { + case Seq(singlePath) => + try { + val hdfsPath = new Path(singlePath) + val fs = hdfsPath.getFileSystem(hadoopConf) + val metadataPath = new Path(hdfsPath, metadataDir) + val res = fs.exists(metadataPath) + res + } catch { + case NonFatal(e) => + logWarning(s"Error while looking for metadata directory.") + false + } + case _ => false + } + } } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala index 39c0b4979687b..6a7263ca45d85 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.streaming import scala.collection.JavaConverters._ -import org.apache.hadoop.fs.Path +import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging @@ -43,8 +43,10 @@ class FileStreamSource( private val sourceOptions = new FileStreamOptions(options) + private val hadoopConf = sparkSession.sessionState.newHadoopConf() + private val qualifiedBasePath: Path = { - val fs = new Path(path).getFileSystem(sparkSession.sessionState.newHadoopConf()) + val fs = new Path(path).getFileSystem(hadoopConf) fs.makeQualified(new Path(path)) // can contains glob patterns } @@ -157,14 +159,65 @@ class FileStreamSource( checkFilesExist = false))) } + /** + * If the source has a metadata log indicating which files should be read, then we should use it. 
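Editorial aside (not part of the patch): the practical payoff is that chaining two streaming queries through a file sink becomes reliable — the downstream source lists files via the sink's `_spark_metadata` log instead of a raw directory listing. A hedged sketch; `upstream`, `stage1Schema` and the paths are placeholders:

    val out = "/data/etl/stage1"
    upstream.writeStream
      .format("parquet")
      .option("checkpointLocation", "/data/etl/stage1-ckpt")
      .start(out)                        // the sink maintains out + "/_spark_metadata"

    val downstream = spark.readStream
      .schema(stage1Schema)              // streaming file sources typically need an explicit schema
      .parquet(out)                      // FileStreamSource uses the metadata log when it is present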
+ * Only when user gives a non-glob path that will we figure out whether the source has some + * metadata log + * + * None means we don't know at the moment + * Some(true) means we know for sure the source DOES have metadata + * Some(false) means we know for sure the source DOSE NOT have metadata + */ + @volatile private[sql] var sourceHasMetadata: Option[Boolean] = + if (SparkHadoopUtil.get.isGlobPath(new Path(path))) Some(false) else None + + private def allFilesUsingInMemoryFileIndex() = { + val globbedPaths = SparkHadoopUtil.get.globPathIfNecessary(qualifiedBasePath) + val fileIndex = new InMemoryFileIndex(sparkSession, globbedPaths, options, Some(new StructType)) + fileIndex.allFiles() + } + + private def allFilesUsingMetadataLogFileIndex() = { + // Note if `sourceHasMetadata` holds, then `qualifiedBasePath` is guaranteed to be a + // non-glob path + new MetadataLogFileIndex(sparkSession, qualifiedBasePath).allFiles() + } + /** * Returns a list of files found, sorted by their timestamp. */ private def fetchAllFiles(): Seq[(String, Long)] = { val startTime = System.nanoTime - val globbedPaths = SparkHadoopUtil.get.globPathIfNecessary(qualifiedBasePath) - val catalog = new InMemoryFileIndex(sparkSession, globbedPaths, options, Some(new StructType)) - val files = catalog.allFiles().sortBy(_.getModificationTime)(fileSortOrder).map { status => + + var allFiles: Seq[FileStatus] = null + sourceHasMetadata match { + case None => + if (FileStreamSink.hasMetadata(Seq(path), hadoopConf)) { + sourceHasMetadata = Some(true) + allFiles = allFilesUsingMetadataLogFileIndex() + } else { + allFiles = allFilesUsingInMemoryFileIndex() + if (allFiles.isEmpty) { + // we still cannot decide + } else { + // decide what to use for future rounds + // double check whether source has metadata, preventing the extreme corner case that + // metadata log and data files are only generated after the previous + // `FileStreamSink.hasMetadata` check + if (FileStreamSink.hasMetadata(Seq(path), hadoopConf)) { + sourceHasMetadata = Some(true) + allFiles = allFilesUsingMetadataLogFileIndex() + } else { + sourceHasMetadata = Some(false) + // `allFiles` have already been fetched using InMemoryFileIndex in this round + } + } + } + case Some(true) => allFiles = allFilesUsingMetadataLogFileIndex() + case Some(false) => allFiles = allFilesUsingInMemoryFileIndex() + } + + val files = allFiles.sortBy(_.getModificationTime)(fileSortOrder).map { status => (status.getPath.toUri.toString, status.getModificationTime) } val endTime = System.nanoTime diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index 4bd6431cbe110..70912d13ae458 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution.streaming +import java.io.{InterruptedIOException, IOException} import java.util.UUID import java.util.concurrent.{CountDownLatch, TimeUnit} import java.util.concurrent.atomic.AtomicReference @@ -37,6 +38,12 @@ import org.apache.spark.sql.execution.command.StreamingExplainCommand import org.apache.spark.sql.streaming._ import org.apache.spark.util.{Clock, UninterruptibleThread, Utils} +/** States for [[StreamExecution]]'s lifecycle. 
*/ +trait State +case object INITIALIZING extends State +case object ACTIVE extends State +case object TERMINATED extends State + /** * Manages the execution of a streaming Spark SQL query that is occurring in a separate thread. * Unlike a standard query, a streaming query executes repeatedly each time new data arrives at any @@ -298,7 +305,14 @@ class StreamExecution( // `stop()` is already called. Let `finally` finish the cleanup. } } catch { - case _: InterruptedException if state.get == TERMINATED => // interrupted by stop() + case _: InterruptedException | _: InterruptedIOException if state.get == TERMINATED => + // interrupted by stop() + updateStatusMessage("Stopped") + case e: IOException if e.getMessage != null + && e.getMessage.startsWith(classOf[InterruptedException].getName) + && state.get == TERMINATED => + // This is a workaround for HADOOP-12074: `Shell.runCommand` converts `InterruptedException` + // to `new IOException(ie.toString())` before Hadoop 2.8. updateStatusMessage("Stopped") case e: Throwable => streamDeathCause = new StreamingQueryException( @@ -321,6 +335,7 @@ class StreamExecution( initializationLatch.countDown() try { + stopSources() state.set(TERMINATED) currentStatus = status.copy(isTriggerActive = false, isDataAvailable = false) @@ -558,6 +573,18 @@ class StreamExecution( sparkSession.streams.postListenerEvent(event) } + /** Stops all streaming sources safely. */ + private def stopSources(): Unit = { + uniqueSources.foreach { source => + try { + source.stop() + } catch { + case NonFatal(e) => + logWarning(s"Failed to stop streaming source: $source. Resources may have leaked.", e) + } + } + } + /** * Signals to the thread executing micro-batches that it should stop running after the next * batch. This method blocks until the thread stops running. @@ -570,7 +597,6 @@ class StreamExecution( microBatchThread.interrupt() microBatchThread.join() } - uniqueSources.foreach(_.stop()) logInfo(s"Query $prettyIdString was stopped") } @@ -709,10 +735,6 @@ class StreamExecution( } } - trait State - case object INITIALIZING extends State - case object ACTIVE extends State - case object TERMINATED extends State } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingQueryListenerBus.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingQueryListenerBus.scala index a2153d27e9fef..4207013c3f75d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingQueryListenerBus.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingQueryListenerBus.scala @@ -75,6 +75,19 @@ class StreamingQueryListenerBus(sparkListenerBus: LiveListenerBus) } } + /** + * Override the parent `postToAll` to remove the query id from `activeQueryRunIds` after all + * the listeners process `QueryTerminatedEvent`. 
(SPARK-19594) + */ + override def postToAll(event: Event): Unit = { + super.postToAll(event) + event match { + case t: QueryTerminatedEvent => + activeQueryRunIds.synchronized { activeQueryRunIds -= t.runId } + case _ => + } + } + override def onOtherEvent(event: SparkListenerEvent): Unit = { event match { case e: StreamingQueryListener.Event => @@ -112,7 +125,6 @@ class StreamingQueryListenerBus(sparkListenerBus: LiveListenerBus) case queryTerminated: QueryTerminatedEvent => if (shouldReport(queryTerminated.runId)) { listener.onQueryTerminated(queryTerminated) - activeQueryRunIds.synchronized { activeQueryRunIds -= queryTerminated.runId } } case _ => } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/socket.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/socket.scala index 900d92bc0d959..58bff27a05bf3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/socket.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/socket.scala @@ -46,8 +46,8 @@ object TextSocketSource { * support for fault recovery and keeping all of the text read in memory forever. */ class TextSocketSource(host: String, port: Int, includeTimestamp: Boolean, sqlContext: SQLContext) - extends Source with Logging -{ + extends Source with Logging { + @GuardedBy("this") private var socket: Socket = null @@ -168,6 +168,8 @@ class TextSocketSource(host: String, port: Int, includeTimestamp: Boolean, sqlCo socket = null } } + + override def toString: String = s"TextSocketSource[host: $host, port: $port]" } class TextSocketSourceProvider extends StreamSourceProvider with DataSourceRegister with Logging { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala index 61eb601a18c32..ab1204a750fac 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala @@ -274,7 +274,18 @@ private[state] class HDFSBackedStateStoreProvider( private def commitUpdates(newVersion: Long, map: MapType, tempDeltaFile: Path): Path = { synchronized { val finalDeltaFile = deltaFile(newVersion) - if (!fs.rename(tempDeltaFile, finalDeltaFile)) { + + // scalastyle:off + // Renaming a file atop an existing one fails on HDFS + // (http://hadoop.apache.org/docs/stable/hadoop-project-dist/hadoop-common/filesystem/filesystem.html). + // Hence we should either skip the rename step or delete the target file. Because deleting the + // target file will break speculation, skipping the rename step is the only choice. It's still + // semantically correct because Structured Streaming requires rerunning a batch should + // generate the same output. 
(SPARK-19677) + // scalastyle:on + if (fs.exists(finalDeltaFile)) { + fs.delete(tempDeltaFile, true) + } else if (!fs.rename(tempDeltaFile, finalDeltaFile)) { throw new IOException(s"Failed to rename $tempDeltaFile to $finalDeltaFile") } loadedMaps.put(newVersion, map) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index dc0f130406932..461dfe3a66e1b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -402,11 +402,13 @@ object SQLConf { val PARALLEL_PARTITION_DISCOVERY_THRESHOLD = buildConf("spark.sql.sources.parallelPartitionDiscovery.threshold") - .doc("The maximum number of files allowed for listing files at driver side. If the number " + - "of detected files exceeds this value during partition discovery, it tries to list the " + + .doc("The maximum number of paths allowed for listing files at driver side. If the number " + + "of detected paths exceeds this value during partition discovery, it tries to list the " + "files with another Spark distributed job. This applies to Parquet, ORC, CSV, JSON and " + "LibSVM data sources.") .intConf + .checkValue(parallel => parallel >= 0, "The maximum number of paths allowed for listing " + + "files at driver side must not be negative") .createWithDefault(32) val PARALLEL_PARTITION_DISCOVERY_PARALLELISM = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index f78e73f319de7..aed8074a64d5b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -143,8 +143,8 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo /** * Loads a JSON file stream and returns the results as a `DataFrame`. * - * Both JSON (one record per file) and JSON Lines - * (newline-delimited JSON) are supported and can be selected with the `wholeFile` option. + * JSON Lines (newline-delimited JSON) is supported by + * default. For JSON (one record per file), set the `wholeFile` option to true. * * This function goes through the input once to determine the input schema. If you know the * schema in advance, use the version that specifies the schema to avoid the extra scan. @@ -261,6 +261,7 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo *
  • `columnNameOfCorruptRecord` (default is the value specified in `spark.sql.columnNameOfCorruptRecord`): allows renaming the new field that holds the malformed string created by `PERMISSIVE` mode. This overrides `spark.sql.columnNameOfCorruptRecord`.
  • + `wholeFile` (default `false`): parse one record, which may span multiple lines.
  • * * * @since 2.0.0 diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java index bf8ff61eae39e..eb4d76c6ab032 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java @@ -31,6 +31,7 @@ import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Encoders; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; import org.apache.spark.sql.SparkSession; @@ -146,13 +147,13 @@ public void dataFrameRDDOperations() { @Test public void applySchemaToJSON() { - JavaRDD jsonRDD = jsc.parallelize(Arrays.asList( + Dataset jsonDS = spark.createDataset(Arrays.asList( "{\"string\":\"this is a simple string.\", \"integer\":10, \"long\":21474836470, " + "\"bigInteger\":92233720368547758070, \"double\":1.7976931348623157E308, " + "\"boolean\":true, \"null\":null}", "{\"string\":\"this is another simple string.\", \"integer\":11, \"long\":21474836469, " + "\"bigInteger\":92233720368547758069, \"double\":1.7976931348623157E305, " + - "\"boolean\":false, \"null\":null}")); + "\"boolean\":false, \"null\":null}"), Encoders.STRING()); List fields = new ArrayList<>(7); fields.add(DataTypes.createStructField("bigInteger", DataTypes.createDecimalType(20, 0), true)); @@ -183,14 +184,14 @@ public void applySchemaToJSON() { null, "this is another simple string.")); - Dataset df1 = spark.read().json(jsonRDD); + Dataset df1 = spark.read().json(jsonDS); StructType actualSchema1 = df1.schema(); Assert.assertEquals(expectedSchema, actualSchema1); df1.createOrReplaceTempView("jsonTable1"); List actual1 = spark.sql("select * from jsonTable1").collectAsList(); Assert.assertEquals(expectedResult, actual1); - Dataset df2 = spark.read().schema(expectedSchema).json(jsonRDD); + Dataset df2 = spark.read().schema(expectedSchema).json(jsonDS); StructType actualSchema2 = df2.schema(); Assert.assertEquals(expectedSchema, actualSchema2); df2.createOrReplaceTempView("jsonTable2"); diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index a8f814bfae530..be8d95d0d9124 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -414,4 +414,13 @@ public void testBeanWithoutGetter() { Assert.assertEquals(df.schema().length(), 0); Assert.assertEquals(df.collectAsList().size(), 1); } + + @Test + public void testJsonRDDToDataFrame() { + // This is a test for the deprecated API in SPARK-15615. 
+ JavaRDD rdd = jsc.parallelize(Arrays.asList("{\"a\": 2}")); + Dataset df = spark.read().json(rdd); + Assert.assertEquals(1L, df.count()); + Assert.assertEquals(2L, df.collectAsList().get(0).getLong(0)); + } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetAggregatorSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetAggregatorSuite.java index d3769a74b9789..539976d5af469 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetAggregatorSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetAggregatorSuite.java @@ -88,7 +88,7 @@ public Encoder outputEncoder() { @Test public void testTypedAggregationAverage() { KeyValueGroupedDataset> grouped = generateGroupedDataset(); - Dataset> agged = grouped.agg(typed.avg(value -> (double)(value._2() * 2))); + Dataset> agged = grouped.agg(typed.avg(value -> value._2() * 2.0)); Assert.assertEquals( Arrays.asList(new Tuple2<>("a", 3.0), new Tuple2<>("b", 6.0)), agged.collectAsList()); diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index 4581c6ebe9ef8..e3b0e37ccab05 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -110,7 +110,8 @@ public void testCommonOperation() { Assert.assertEquals(Arrays.asList("hello"), filtered.collectAsList()); - Dataset mapped = ds.map((MapFunction) v -> v.length(), Encoders.INT()); + Dataset mapped = + ds.map((MapFunction) String::length, Encoders.INT()); Assert.assertEquals(Arrays.asList(5, 5), mapped.collectAsList()); Dataset parMapped = ds.mapPartitions((MapPartitionsFunction) it -> { @@ -157,17 +158,17 @@ public void testReduce() { public void testGroupBy() { List data = Arrays.asList("a", "foo", "bar"); Dataset ds = spark.createDataset(data, Encoders.STRING()); - KeyValueGroupedDataset grouped = ds.groupByKey( - (MapFunction) v -> v.length(), - Encoders.INT()); + KeyValueGroupedDataset grouped = + ds.groupByKey((MapFunction) String::length, Encoders.INT()); - Dataset mapped = grouped.mapGroups((MapGroupsFunction) (key, values) -> { - StringBuilder sb = new StringBuilder(key.toString()); - while (values.hasNext()) { - sb.append(values.next()); - } - return sb.toString(); - }, Encoders.STRING()); + Dataset mapped = grouped.mapGroups( + (MapGroupsFunction) (key, values) -> { + StringBuilder sb = new StringBuilder(key.toString()); + while (values.hasNext()) { + sb.append(values.next()); + } + return sb.toString(); + }, Encoders.STRING()); Assert.assertEquals(asSet("1a", "3foobar"), toSet(mapped.collectAsList())); @@ -209,7 +210,8 @@ public void testGroupBy() { Assert.assertEquals(asSet("1a", "3foobar"), toSet(flatMapped2.collectAsList())); - Dataset> reduced = grouped.reduceGroups((ReduceFunction) (v1, v2) -> v1 + v2); + Dataset> reduced = + grouped.reduceGroups((ReduceFunction) (v1, v2) -> v1 + v2); Assert.assertEquals( asSet(tuple2(1, "a"), tuple2(3, "foobar")), diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaSaveLoadSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaSaveLoadSuite.java index 6941c86dfcd4b..127d272579a62 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaSaveLoadSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaSaveLoadSuite.java @@ -29,8 +29,6 @@ import org.junit.Before; import org.junit.Test; -import org.apache.spark.api.java.JavaRDD; 
-import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.sql.*; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.StructField; @@ -40,7 +38,6 @@ public class JavaSaveLoadSuite { private transient SparkSession spark; - private transient JavaSparkContext jsc; File path; Dataset df; @@ -58,7 +55,6 @@ public void setUp() throws IOException { .master("local[*]") .appName("testing") .getOrCreate(); - jsc = new JavaSparkContext(spark.sparkContext()); path = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "datasource").getCanonicalFile(); @@ -70,8 +66,8 @@ public void setUp() throws IOException { for (int i = 0; i < 10; i++) { jsonObjects.add("{\"a\":" + i + ", \"b\":\"str" + i + "\"}"); } - JavaRDD rdd = jsc.parallelize(jsonObjects); - df = spark.read().json(rdd); + Dataset ds = spark.createDataset(jsonObjects, Encoders.STRING()); + df = spark.read().json(ds); df.createOrReplaceTempView("jsonTable"); } diff --git a/sql/core/src/test/resources/sql-tests/inputs/columnresolution-negative.sql b/sql/core/src/test/resources/sql-tests/inputs/columnresolution-negative.sql new file mode 100644 index 0000000000000..1caa45c66749d --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/columnresolution-negative.sql @@ -0,0 +1,36 @@ +-- Negative testcases for column resolution +CREATE DATABASE mydb1; +USE mydb1; +CREATE TABLE t1 USING parquet AS SELECT 1 AS i1; + +CREATE DATABASE mydb2; +USE mydb2; +CREATE TABLE t1 USING parquet AS SELECT 20 AS i1; + +-- Negative tests: column resolution scenarios with ambiguous cases in join queries +SET spark.sql.crossJoin.enabled = true; +USE mydb1; +SELECT i1 FROM t1, mydb1.t1; +SELECT t1.i1 FROM t1, mydb1.t1; +SELECT mydb1.t1.i1 FROM t1, mydb1.t1; +SELECT i1 FROM t1, mydb2.t1; +SELECT t1.i1 FROM t1, mydb2.t1; +USE mydb2; +SELECT i1 FROM t1, mydb1.t1; +SELECT t1.i1 FROM t1, mydb1.t1; +SELECT i1 FROM t1, mydb2.t1; +SELECT t1.i1 FROM t1, mydb2.t1; +SELECT db1.t1.i1 FROM t1, mydb2.t1; +SET spark.sql.crossJoin.enabled = false; + +-- Negative tests +USE mydb1; +SELECT mydb1.t1 FROM t1; +SELECT t1.x.y.* FROM t1; +SELECT t1 FROM mydb1.t1; +USE mydb2; +SELECT mydb1.t1.i1 FROM t1; + +-- reset +DROP DATABASE mydb1 CASCADE; +DROP DATABASE mydb2 CASCADE; diff --git a/sql/core/src/test/resources/sql-tests/inputs/columnresolution-views.sql b/sql/core/src/test/resources/sql-tests/inputs/columnresolution-views.sql new file mode 100644 index 0000000000000..d3f928751757c --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/columnresolution-views.sql @@ -0,0 +1,25 @@ +-- Tests for qualified column names for the view code-path +-- Test scenario with Temporary view +CREATE OR REPLACE TEMPORARY VIEW view1 AS SELECT 2 AS i1; +SELECT view1.* FROM view1; +SELECT * FROM view1; +SELECT view1.i1 FROM view1; +SELECT i1 FROM view1; +SELECT a.i1 FROM view1 AS a; +SELECT i1 FROM view1 AS a; +-- cleanup +DROP VIEW view1; + +-- Test scenario with Global Temp view +CREATE OR REPLACE GLOBAL TEMPORARY VIEW view1 as SELECT 1 as i1; +SELECT * FROM global_temp.view1; +-- TODO: Support this scenario +SELECT global_temp.view1.* FROM global_temp.view1; +SELECT i1 FROM global_temp.view1; +-- TODO: Support this scenario +SELECT global_temp.view1.i1 FROM global_temp.view1; +SELECT view1.i1 FROM global_temp.view1; +SELECT a.i1 FROM global_temp.view1 AS a; +SELECT i1 FROM global_temp.view1 AS a; +-- cleanup +DROP VIEW global_temp.view1; diff --git a/sql/core/src/test/resources/sql-tests/inputs/columnresolution.sql 
b/sql/core/src/test/resources/sql-tests/inputs/columnresolution.sql new file mode 100644 index 0000000000000..79e90ad3de91d --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/columnresolution.sql @@ -0,0 +1,88 @@ +-- Tests covering different scenarios with qualified column names +-- Scenario: column resolution scenarios with datasource table +CREATE DATABASE mydb1; +USE mydb1; +CREATE TABLE t1 USING parquet AS SELECT 1 AS i1; + +CREATE DATABASE mydb2; +USE mydb2; +CREATE TABLE t1 USING parquet AS SELECT 20 AS i1; + +USE mydb1; +SELECT i1 FROM t1; +SELECT i1 FROM mydb1.t1; +SELECT t1.i1 FROM t1; +SELECT t1.i1 FROM mydb1.t1; + +-- TODO: Support this scenario +SELECT mydb1.t1.i1 FROM t1; +-- TODO: Support this scenario +SELECT mydb1.t1.i1 FROM mydb1.t1; + +USE mydb2; +SELECT i1 FROM t1; +SELECT i1 FROM mydb1.t1; +SELECT t1.i1 FROM t1; +SELECT t1.i1 FROM mydb1.t1; +-- TODO: Support this scenario +SELECT mydb1.t1.i1 FROM mydb1.t1; + +-- Scenario: resolve fully qualified table name in star expansion +USE mydb1; +SELECT t1.* FROM t1; +SELECT mydb1.t1.* FROM mydb1.t1; +SELECT t1.* FROM mydb1.t1; +USE mydb2; +SELECT t1.* FROM t1; +-- TODO: Support this scenario +SELECT mydb1.t1.* FROM mydb1.t1; +SELECT t1.* FROM mydb1.t1; +SELECT a.* FROM mydb1.t1 AS a; + +-- Scenario: resolve in case of subquery + +USE mydb1; +CREATE TABLE t3 USING parquet AS SELECT * FROM VALUES (4,1), (3,1) AS t3(c1, c2); +CREATE TABLE t4 USING parquet AS SELECT * FROM VALUES (4,1), (2,1) AS t4(c2, c3); + +SELECT * FROM t3 WHERE c1 IN (SELECT c2 FROM t4 WHERE t4.c3 = t3.c2); + +-- TODO: Support this scenario +SELECT * FROM mydb1.t3 WHERE c1 IN + (SELECT mydb1.t4.c2 FROM mydb1.t4 WHERE mydb1.t4.c3 = mydb1.t3.c2); + +-- Scenario: column resolution scenarios in join queries +SET spark.sql.crossJoin.enabled = true; + +-- TODO: Support this scenario +SELECT mydb1.t1.i1 FROM t1, mydb2.t1; + +-- TODO: Support this scenario +SELECT mydb1.t1.i1 FROM mydb1.t1, mydb2.t1; + +USE mydb2; +-- TODO: Support this scenario +SELECT mydb1.t1.i1 FROM t1, mydb1.t1; +SET spark.sql.crossJoin.enabled = false; + +-- Scenario: Table with struct column +USE mydb1; +CREATE TABLE t5(i1 INT, t5 STRUCT) USING parquet; +INSERT INTO t5 VALUES(1, (2, 3)); +SELECT t5.i1 FROM t5; +SELECT t5.t5.i1 FROM t5; +SELECT t5.t5.i1 FROM mydb1.t5; +SELECT t5.i1 FROM mydb1.t5; +SELECT t5.* FROM mydb1.t5; +SELECT t5.t5.* FROM mydb1.t5; +-- TODO: Support this scenario +SELECT mydb1.t5.t5.i1 FROM mydb1.t5; +-- TODO: Support this scenario +SELECT mydb1.t5.t5.i2 FROM mydb1.t5; +-- TODO: Support this scenario +SELECT mydb1.t5.* FROM mydb1.t5; + +-- Cleanup and Reset +USE default; +DROP DATABASE mydb1 CASCADE; +DROP DATABASE mydb2 CASCADE; diff --git a/sql/core/src/test/resources/sql-tests/inputs/inline-table.sql b/sql/core/src/test/resources/sql-tests/inputs/inline-table.sql index 5107fa4d55537..b3ec956cd178e 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/inline-table.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/inline-table.sql @@ -46,3 +46,6 @@ select * from values ("one", random_not_exist_func(1)), ("two", 2) as data(a, b) -- error reporting: aggregate expression select * from values ("one", count(1)), ("two", 2) as data(a, b); + +-- string to timestamp +select * from values (timestamp('1991-12-06 00:00:00.0'), array(timestamp('1991-12-06 01:00:00.0'), timestamp('1991-12-06 12:00:00.0'))) as data(a, b); diff --git a/sql/core/src/test/resources/sql-tests/inputs/inner-join.sql b/sql/core/src/test/resources/sql-tests/inputs/inner-join.sql new file 
mode 100644 index 0000000000000..38739cb950582 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/inner-join.sql @@ -0,0 +1,17 @@ +CREATE TEMPORARY VIEW t1 AS SELECT * FROM VALUES (1) AS GROUPING(a); +CREATE TEMPORARY VIEW t2 AS SELECT * FROM VALUES (1) AS GROUPING(a); +CREATE TEMPORARY VIEW t3 AS SELECT * FROM VALUES (1), (1) AS GROUPING(a); +CREATE TEMPORARY VIEW t4 AS SELECT * FROM VALUES (1), (1) AS GROUPING(a); + +CREATE TEMPORARY VIEW ta AS +SELECT a, 'a' AS tag FROM t1 +UNION ALL +SELECT a, 'b' AS tag FROM t2; + +CREATE TEMPORARY VIEW tb AS +SELECT a, 'a' AS tag FROM t3 +UNION ALL +SELECT a, 'b' AS tag FROM t4; + +-- SPARK-19766 Constant alias columns in INNER JOIN should not be folded by FoldablePropagation rule +SELECT tb.* FROM ta INNER JOIN tb ON ta.a = tb.a AND ta.tag = tb.tag; diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-joins.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-joins.sql index b10c41929cdaf..880175fd7add0 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-joins.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-joins.sql @@ -79,7 +79,7 @@ GROUP BY t1a, t3a, t3b, t3c -ORDER BY t1a DESC; +ORDER BY t1a DESC, t3b DESC; -- TC 01.03 SELECT Count(DISTINCT(t1a)) diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-set-operations.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-set-operations.sql index 6b9e8bf2f362d..5c371d2305ac8 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-set-operations.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-set-operations.sql @@ -287,7 +287,7 @@ WHERE t1a IN (SELECT t3a WHERE t1b > 6) AS t5) GROUP BY t1a, t1b, t1c, t1d HAVING t1c IS NOT NULL AND t1b IS NOT NULL -ORDER BY t1c DESC; +ORDER BY t1c DESC, t1a DESC; -- TC 01.08 SELECT t1a, @@ -351,7 +351,7 @@ WHERE t1b IN FROM t1 WHERE t1b > 6) AS t4 WHERE t2b = t1b) -ORDER BY t1c DESC NULLS last; +ORDER BY t1c DESC NULLS last, t1a DESC; -- TC 01.11 SELECT * @@ -468,5 +468,5 @@ HAVING t1b NOT IN EXCEPT SELECT t3b FROM t3) -ORDER BY t1c DESC NULLS LAST; +ORDER BY t1c DESC NULLS LAST, t1i; diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-joins.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-joins.sql index 505366b7acd43..e09b91f18de0a 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-joins.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-joins.sql @@ -85,7 +85,7 @@ AND t1b != t3b AND t1d = t2d GROUP BY t1a, t1b, t1c, t3a, t3b, t3c HAVING count(distinct(t3a)) >= 1 -ORDER BY t1a; +ORDER BY t1a, t3b; -- TC 01.03 SELECT t1a, diff --git a/sql/core/src/test/resources/sql-tests/results/columnresolution-negative.sql.out b/sql/core/src/test/resources/sql-tests/results/columnresolution-negative.sql.out new file mode 100644 index 0000000000000..60bd8e9cc99db --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/columnresolution-negative.sql.out @@ -0,0 +1,240 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 28 + + +-- !query 0 +CREATE DATABASE mydb1 +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +USE mydb1 +-- !query 1 schema +struct<> +-- !query 1 output + + + +-- !query 2 +CREATE TABLE t1 USING parquet AS SELECT 1 AS i1 +-- !query 2 schema +struct<> +-- !query 
2 output + + + +-- !query 3 +CREATE DATABASE mydb2 +-- !query 3 schema +struct<> +-- !query 3 output + + + +-- !query 4 +USE mydb2 +-- !query 4 schema +struct<> +-- !query 4 output + + + +-- !query 5 +CREATE TABLE t1 USING parquet AS SELECT 20 AS i1 +-- !query 5 schema +struct<> +-- !query 5 output + + + +-- !query 6 +SET spark.sql.crossJoin.enabled = true +-- !query 6 schema +struct +-- !query 6 output +spark.sql.crossJoin.enabled true + + +-- !query 7 +USE mydb1 +-- !query 7 schema +struct<> +-- !query 7 output + + + +-- !query 8 +SELECT i1 FROM t1, mydb1.t1 +-- !query 8 schema +struct<> +-- !query 8 output +org.apache.spark.sql.AnalysisException +Reference 'i1' is ambiguous, could be: i1#x, i1#x.; line 1 pos 7 + + +-- !query 9 +SELECT t1.i1 FROM t1, mydb1.t1 +-- !query 9 schema +struct<> +-- !query 9 output +org.apache.spark.sql.AnalysisException +Reference 't1.i1' is ambiguous, could be: i1#x, i1#x.; line 1 pos 7 + + +-- !query 10 +SELECT mydb1.t1.i1 FROM t1, mydb1.t1 +-- !query 10 schema +struct<> +-- !query 10 output +org.apache.spark.sql.AnalysisException +cannot resolve '`mydb1.t1.i1`' given input columns: [i1, i1]; line 1 pos 7 + + +-- !query 11 +SELECT i1 FROM t1, mydb2.t1 +-- !query 11 schema +struct<> +-- !query 11 output +org.apache.spark.sql.AnalysisException +Reference 'i1' is ambiguous, could be: i1#x, i1#x.; line 1 pos 7 + + +-- !query 12 +SELECT t1.i1 FROM t1, mydb2.t1 +-- !query 12 schema +struct<> +-- !query 12 output +org.apache.spark.sql.AnalysisException +Reference 't1.i1' is ambiguous, could be: i1#x, i1#x.; line 1 pos 7 + + +-- !query 13 +USE mydb2 +-- !query 13 schema +struct<> +-- !query 13 output + + + +-- !query 14 +SELECT i1 FROM t1, mydb1.t1 +-- !query 14 schema +struct<> +-- !query 14 output +org.apache.spark.sql.AnalysisException +Reference 'i1' is ambiguous, could be: i1#x, i1#x.; line 1 pos 7 + + +-- !query 15 +SELECT t1.i1 FROM t1, mydb1.t1 +-- !query 15 schema +struct<> +-- !query 15 output +org.apache.spark.sql.AnalysisException +Reference 't1.i1' is ambiguous, could be: i1#x, i1#x.; line 1 pos 7 + + +-- !query 16 +SELECT i1 FROM t1, mydb2.t1 +-- !query 16 schema +struct<> +-- !query 16 output +org.apache.spark.sql.AnalysisException +Reference 'i1' is ambiguous, could be: i1#x, i1#x.; line 1 pos 7 + + +-- !query 17 +SELECT t1.i1 FROM t1, mydb2.t1 +-- !query 17 schema +struct<> +-- !query 17 output +org.apache.spark.sql.AnalysisException +Reference 't1.i1' is ambiguous, could be: i1#x, i1#x.; line 1 pos 7 + + +-- !query 18 +SELECT db1.t1.i1 FROM t1, mydb2.t1 +-- !query 18 schema +struct<> +-- !query 18 output +org.apache.spark.sql.AnalysisException +cannot resolve '`db1.t1.i1`' given input columns: [i1, i1]; line 1 pos 7 + + +-- !query 19 +SET spark.sql.crossJoin.enabled = false +-- !query 19 schema +struct +-- !query 19 output +spark.sql.crossJoin.enabled false + + +-- !query 20 +USE mydb1 +-- !query 20 schema +struct<> +-- !query 20 output + + + +-- !query 21 +SELECT mydb1.t1 FROM t1 +-- !query 21 schema +struct<> +-- !query 21 output +org.apache.spark.sql.AnalysisException +cannot resolve '`mydb1.t1`' given input columns: [i1]; line 1 pos 7 + + +-- !query 22 +SELECT t1.x.y.* FROM t1 +-- !query 22 schema +struct<> +-- !query 22 output +org.apache.spark.sql.AnalysisException +cannot resolve 't1.x.y.*' give input columns 'i1'; + + +-- !query 23 +SELECT t1 FROM mydb1.t1 +-- !query 23 schema +struct<> +-- !query 23 output +org.apache.spark.sql.AnalysisException +cannot resolve '`t1`' given input columns: [i1]; line 1 pos 7 + + +-- !query 24 +USE mydb2 +-- 
!query 24 schema +struct<> +-- !query 24 output + + + +-- !query 25 +SELECT mydb1.t1.i1 FROM t1 +-- !query 25 schema +struct<> +-- !query 25 output +org.apache.spark.sql.AnalysisException +cannot resolve '`mydb1.t1.i1`' given input columns: [i1]; line 1 pos 7 + + +-- !query 26 +DROP DATABASE mydb1 CASCADE +-- !query 26 schema +struct<> +-- !query 26 output + + + +-- !query 27 +DROP DATABASE mydb2 CASCADE +-- !query 27 schema +struct<> +-- !query 27 output + diff --git a/sql/core/src/test/resources/sql-tests/results/columnresolution-views.sql.out b/sql/core/src/test/resources/sql-tests/results/columnresolution-views.sql.out new file mode 100644 index 0000000000000..616421d6f2b28 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/columnresolution-views.sql.out @@ -0,0 +1,140 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 17 + + +-- !query 0 +CREATE OR REPLACE TEMPORARY VIEW view1 AS SELECT 2 AS i1 +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +SELECT view1.* FROM view1 +-- !query 1 schema +struct +-- !query 1 output +2 + + +-- !query 2 +SELECT * FROM view1 +-- !query 2 schema +struct +-- !query 2 output +2 + + +-- !query 3 +SELECT view1.i1 FROM view1 +-- !query 3 schema +struct +-- !query 3 output +2 + + +-- !query 4 +SELECT i1 FROM view1 +-- !query 4 schema +struct +-- !query 4 output +2 + + +-- !query 5 +SELECT a.i1 FROM view1 AS a +-- !query 5 schema +struct +-- !query 5 output +2 + + +-- !query 6 +SELECT i1 FROM view1 AS a +-- !query 6 schema +struct +-- !query 6 output +2 + + +-- !query 7 +DROP VIEW view1 +-- !query 7 schema +struct<> +-- !query 7 output + + + +-- !query 8 +CREATE OR REPLACE GLOBAL TEMPORARY VIEW view1 as SELECT 1 as i1 +-- !query 8 schema +struct<> +-- !query 8 output + + + +-- !query 9 +SELECT * FROM global_temp.view1 +-- !query 9 schema +struct +-- !query 9 output +1 + + +-- !query 10 +SELECT global_temp.view1.* FROM global_temp.view1 +-- !query 10 schema +struct<> +-- !query 10 output +org.apache.spark.sql.AnalysisException +cannot resolve 'global_temp.view1.*' give input columns 'i1'; + + +-- !query 11 +SELECT i1 FROM global_temp.view1 +-- !query 11 schema +struct +-- !query 11 output +1 + + +-- !query 12 +SELECT global_temp.view1.i1 FROM global_temp.view1 +-- !query 12 schema +struct<> +-- !query 12 output +org.apache.spark.sql.AnalysisException +cannot resolve '`global_temp.view1.i1`' given input columns: [i1]; line 1 pos 7 + + +-- !query 13 +SELECT view1.i1 FROM global_temp.view1 +-- !query 13 schema +struct +-- !query 13 output +1 + + +-- !query 14 +SELECT a.i1 FROM global_temp.view1 AS a +-- !query 14 schema +struct +-- !query 14 output +1 + + +-- !query 15 +SELECT i1 FROM global_temp.view1 AS a +-- !query 15 schema +struct +-- !query 15 output +1 + + +-- !query 16 +DROP VIEW global_temp.view1 +-- !query 16 schema +struct<> +-- !query 16 output + diff --git a/sql/core/src/test/resources/sql-tests/results/columnresolution.sql.out b/sql/core/src/test/resources/sql-tests/results/columnresolution.sql.out new file mode 100644 index 0000000000000..764cad0e3943c --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/columnresolution.sql.out @@ -0,0 +1,447 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 54 + + +-- !query 0 +CREATE DATABASE mydb1 +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +USE mydb1 +-- !query 1 schema +struct<> +-- !query 1 output + + + +-- !query 2 +CREATE TABLE t1 USING parquet AS SELECT 1 AS i1 +-- !query 2 schema 
+struct<> +-- !query 2 output + + + +-- !query 3 +CREATE DATABASE mydb2 +-- !query 3 schema +struct<> +-- !query 3 output + + + +-- !query 4 +USE mydb2 +-- !query 4 schema +struct<> +-- !query 4 output + + + +-- !query 5 +CREATE TABLE t1 USING parquet AS SELECT 20 AS i1 +-- !query 5 schema +struct<> +-- !query 5 output + + + +-- !query 6 +USE mydb1 +-- !query 6 schema +struct<> +-- !query 6 output + + + +-- !query 7 +SELECT i1 FROM t1 +-- !query 7 schema +struct +-- !query 7 output +1 + + +-- !query 8 +SELECT i1 FROM mydb1.t1 +-- !query 8 schema +struct +-- !query 8 output +1 + + +-- !query 9 +SELECT t1.i1 FROM t1 +-- !query 9 schema +struct +-- !query 9 output +1 + + +-- !query 10 +SELECT t1.i1 FROM mydb1.t1 +-- !query 10 schema +struct +-- !query 10 output +1 + + +-- !query 11 +SELECT mydb1.t1.i1 FROM t1 +-- !query 11 schema +struct<> +-- !query 11 output +org.apache.spark.sql.AnalysisException +cannot resolve '`mydb1.t1.i1`' given input columns: [i1]; line 1 pos 7 + + +-- !query 12 +SELECT mydb1.t1.i1 FROM mydb1.t1 +-- !query 12 schema +struct<> +-- !query 12 output +org.apache.spark.sql.AnalysisException +cannot resolve '`mydb1.t1.i1`' given input columns: [i1]; line 1 pos 7 + + +-- !query 13 +USE mydb2 +-- !query 13 schema +struct<> +-- !query 13 output + + + +-- !query 14 +SELECT i1 FROM t1 +-- !query 14 schema +struct +-- !query 14 output +20 + + +-- !query 15 +SELECT i1 FROM mydb1.t1 +-- !query 15 schema +struct +-- !query 15 output +1 + + +-- !query 16 +SELECT t1.i1 FROM t1 +-- !query 16 schema +struct +-- !query 16 output +20 + + +-- !query 17 +SELECT t1.i1 FROM mydb1.t1 +-- !query 17 schema +struct +-- !query 17 output +1 + + +-- !query 18 +SELECT mydb1.t1.i1 FROM mydb1.t1 +-- !query 18 schema +struct<> +-- !query 18 output +org.apache.spark.sql.AnalysisException +cannot resolve '`mydb1.t1.i1`' given input columns: [i1]; line 1 pos 7 + + +-- !query 19 +USE mydb1 +-- !query 19 schema +struct<> +-- !query 19 output + + + +-- !query 20 +SELECT t1.* FROM t1 +-- !query 20 schema +struct +-- !query 20 output +1 + + +-- !query 21 +SELECT mydb1.t1.* FROM mydb1.t1 +-- !query 21 schema +struct<> +-- !query 21 output +org.apache.spark.sql.AnalysisException +cannot resolve 'mydb1.t1.*' give input columns 'i1'; + + +-- !query 22 +SELECT t1.* FROM mydb1.t1 +-- !query 22 schema +struct +-- !query 22 output +1 + + +-- !query 23 +USE mydb2 +-- !query 23 schema +struct<> +-- !query 23 output + + + +-- !query 24 +SELECT t1.* FROM t1 +-- !query 24 schema +struct +-- !query 24 output +20 + + +-- !query 25 +SELECT mydb1.t1.* FROM mydb1.t1 +-- !query 25 schema +struct<> +-- !query 25 output +org.apache.spark.sql.AnalysisException +cannot resolve 'mydb1.t1.*' give input columns 'i1'; + + +-- !query 26 +SELECT t1.* FROM mydb1.t1 +-- !query 26 schema +struct +-- !query 26 output +1 + + +-- !query 27 +SELECT a.* FROM mydb1.t1 AS a +-- !query 27 schema +struct +-- !query 27 output +1 + + +-- !query 28 +USE mydb1 +-- !query 28 schema +struct<> +-- !query 28 output + + + +-- !query 29 +CREATE TABLE t3 USING parquet AS SELECT * FROM VALUES (4,1), (3,1) AS t3(c1, c2) +-- !query 29 schema +struct<> +-- !query 29 output + + + +-- !query 30 +CREATE TABLE t4 USING parquet AS SELECT * FROM VALUES (4,1), (2,1) AS t4(c2, c3) +-- !query 30 schema +struct<> +-- !query 30 output + + + +-- !query 31 +SELECT * FROM t3 WHERE c1 IN (SELECT c2 FROM t4 WHERE t4.c3 = t3.c2) +-- !query 31 schema +struct +-- !query 31 output +4 1 + + +-- !query 32 +SELECT * FROM mydb1.t3 WHERE c1 IN + (SELECT mydb1.t4.c2 FROM mydb1.t4 WHERE 
mydb1.t4.c3 = mydb1.t3.c2) +-- !query 32 schema +struct<> +-- !query 32 output +org.apache.spark.sql.AnalysisException +cannot resolve '`mydb1.t4.c3`' given input columns: [c2, c3]; line 2 pos 42 + + +-- !query 33 +SET spark.sql.crossJoin.enabled = true +-- !query 33 schema +struct +-- !query 33 output +spark.sql.crossJoin.enabled true + + +-- !query 34 +SELECT mydb1.t1.i1 FROM t1, mydb2.t1 +-- !query 34 schema +struct<> +-- !query 34 output +org.apache.spark.sql.AnalysisException +cannot resolve '`mydb1.t1.i1`' given input columns: [i1, i1]; line 1 pos 7 + + +-- !query 35 +SELECT mydb1.t1.i1 FROM mydb1.t1, mydb2.t1 +-- !query 35 schema +struct<> +-- !query 35 output +org.apache.spark.sql.AnalysisException +cannot resolve '`mydb1.t1.i1`' given input columns: [i1, i1]; line 1 pos 7 + + +-- !query 36 +USE mydb2 +-- !query 36 schema +struct<> +-- !query 36 output + + + +-- !query 37 +SELECT mydb1.t1.i1 FROM t1, mydb1.t1 +-- !query 37 schema +struct<> +-- !query 37 output +org.apache.spark.sql.AnalysisException +cannot resolve '`mydb1.t1.i1`' given input columns: [i1, i1]; line 1 pos 7 + + +-- !query 38 +SET spark.sql.crossJoin.enabled = false +-- !query 38 schema +struct +-- !query 38 output +spark.sql.crossJoin.enabled false + + +-- !query 39 +USE mydb1 +-- !query 39 schema +struct<> +-- !query 39 output + + + +-- !query 40 +CREATE TABLE t5(i1 INT, t5 STRUCT) USING parquet +-- !query 40 schema +struct<> +-- !query 40 output + + + +-- !query 41 +INSERT INTO t5 VALUES(1, (2, 3)) +-- !query 41 schema +struct<> +-- !query 41 output + + + +-- !query 42 +SELECT t5.i1 FROM t5 +-- !query 42 schema +struct +-- !query 42 output +1 + + +-- !query 43 +SELECT t5.t5.i1 FROM t5 +-- !query 43 schema +struct +-- !query 43 output +2 + + +-- !query 44 +SELECT t5.t5.i1 FROM mydb1.t5 +-- !query 44 schema +struct +-- !query 44 output +2 + + +-- !query 45 +SELECT t5.i1 FROM mydb1.t5 +-- !query 45 schema +struct +-- !query 45 output +1 + + +-- !query 46 +SELECT t5.* FROM mydb1.t5 +-- !query 46 schema +struct> +-- !query 46 output +1 {"i1":2,"i2":3} + + +-- !query 47 +SELECT t5.t5.* FROM mydb1.t5 +-- !query 47 schema +struct +-- !query 47 output +2 3 + + +-- !query 48 +SELECT mydb1.t5.t5.i1 FROM mydb1.t5 +-- !query 48 schema +struct<> +-- !query 48 output +org.apache.spark.sql.AnalysisException +cannot resolve '`mydb1.t5.t5.i1`' given input columns: [i1, t5]; line 1 pos 7 + + +-- !query 49 +SELECT mydb1.t5.t5.i2 FROM mydb1.t5 +-- !query 49 schema +struct<> +-- !query 49 output +org.apache.spark.sql.AnalysisException +cannot resolve '`mydb1.t5.t5.i2`' given input columns: [i1, t5]; line 1 pos 7 + + +-- !query 50 +SELECT mydb1.t5.* FROM mydb1.t5 +-- !query 50 schema +struct<> +-- !query 50 output +org.apache.spark.sql.AnalysisException +cannot resolve 'mydb1.t5.*' give input columns 'i1, t5'; + + +-- !query 51 +USE default +-- !query 51 schema +struct<> +-- !query 51 output + + + +-- !query 52 +DROP DATABASE mydb1 CASCADE +-- !query 52 schema +struct<> +-- !query 52 output + + + +-- !query 53 +DROP DATABASE mydb2 CASCADE +-- !query 53 schema +struct<> +-- !query 53 output + diff --git a/sql/core/src/test/resources/sql-tests/results/inline-table.sql.out b/sql/core/src/test/resources/sql-tests/results/inline-table.sql.out index de6f01b8de772..4e80f0bda5513 100644 --- a/sql/core/src/test/resources/sql-tests/results/inline-table.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/inline-table.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 16 +-- Number of queries: 17 
-- !query 0 @@ -143,3 +143,11 @@ struct<> -- !query 15 output org.apache.spark.sql.AnalysisException cannot evaluate expression count(1) in inline table definition; line 1 pos 29 + + +-- !query 16 +select * from values (timestamp('1991-12-06 00:00:00.0'), array(timestamp('1991-12-06 01:00:00.0'), timestamp('1991-12-06 12:00:00.0'))) as data(a, b) +-- !query 16 schema +struct> +-- !query 16 output +1991-12-06 00:00:00 [1991-12-06 01:00:00.0,1991-12-06 12:00:00.0] diff --git a/sql/core/src/test/resources/sql-tests/results/inner-join.sql.out b/sql/core/src/test/resources/sql-tests/results/inner-join.sql.out new file mode 100644 index 0000000000000..8d56ebe9fd3b4 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/inner-join.sql.out @@ -0,0 +1,67 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 7 + + +-- !query 0 +CREATE TEMPORARY VIEW t1 AS SELECT * FROM VALUES (1) AS GROUPING(a) +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +CREATE TEMPORARY VIEW t2 AS SELECT * FROM VALUES (1) AS GROUPING(a) +-- !query 1 schema +struct<> +-- !query 1 output + + + +-- !query 2 +CREATE TEMPORARY VIEW t3 AS SELECT * FROM VALUES (1), (1) AS GROUPING(a) +-- !query 2 schema +struct<> +-- !query 2 output + + + +-- !query 3 +CREATE TEMPORARY VIEW t4 AS SELECT * FROM VALUES (1), (1) AS GROUPING(a) +-- !query 3 schema +struct<> +-- !query 3 output + + + +-- !query 4 +CREATE TEMPORARY VIEW ta AS +SELECT a, 'a' AS tag FROM t1 +UNION ALL +SELECT a, 'b' AS tag FROM t2 +-- !query 4 schema +struct<> +-- !query 4 output + + + +-- !query 5 +CREATE TEMPORARY VIEW tb AS +SELECT a, 'a' AS tag FROM t3 +UNION ALL +SELECT a, 'b' AS tag FROM t4 +-- !query 5 schema +struct<> +-- !query 5 output + + + +-- !query 6 +SELECT tb.* FROM ta INNER JOIN tb ON ta.a = tb.a AND ta.tag = tb.tag +-- !query 6 schema +struct +-- !query 6 output +1 a +1 a +1 b +1 b diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/in-joins.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/in-joins.sql.out index 7258bcfc6ab72..ab6a11a2b7efa 100644 --- a/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/in-joins.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/in-joins.sql.out @@ -102,7 +102,7 @@ GROUP BY t1a, t3a, t3b, t3c -ORDER BY t1a DESC +ORDER BY t1a DESC, t3b DESC -- !query 4 schema struct -- !query 4 output diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/in-set-operations.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/in-set-operations.sql.out index 878bc755ef5fc..e06f9206d3401 100644 --- a/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/in-set-operations.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/in-set-operations.sql.out @@ -353,7 +353,7 @@ WHERE t1a IN (SELECT t3a WHERE t1b > 6) AS t5) GROUP BY t1a, t1b, t1c, t1d HAVING t1c IS NOT NULL AND t1b IS NOT NULL -ORDER BY t1c DESC +ORDER BY t1c DESC, t1a DESC -- !query 9 schema struct -- !query 9 output @@ -445,7 +445,7 @@ WHERE t1b IN FROM t1 WHERE t1b > 6) AS t4 WHERE t2b = t1b) -ORDER BY t1c DESC NULLS last +ORDER BY t1c DESC NULLS last, t1a DESC -- !query 12 schema struct -- !query 12 output @@ -580,16 +580,16 @@ HAVING t1b NOT IN EXCEPT SELECT t3b FROM t3) -ORDER BY t1c DESC NULLS LAST +ORDER BY t1c DESC NULLS LAST, t1i -- !query 15 schema struct -- !query 15 output -1 8 16 2014-05-05 1 8 16 2014-05-04 +1 8 16 2014-05-05 1 16 
12 2014-06-04 1 16 12 2014-07-04 1 6 8 2014-04-04 +1 10 NULL 2014-05-04 1 10 NULL 2014-08-04 1 10 NULL 2014-09-04 1 10 NULL 2015-05-04 -1 10 NULL 2014-05-04 diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/not-in-joins.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/not-in-joins.sql.out index db01fa455735c..bae5d00cc8632 100644 --- a/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/not-in-joins.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/not-in-joins.sql.out @@ -112,12 +112,12 @@ AND t1b != t3b AND t1d = t2d GROUP BY t1a, t1b, t1c, t3a, t3b, t3c HAVING count(distinct(t3a)) >= 1 -ORDER BY t1a +ORDER BY t1a, t3b -- !query 4 schema struct -- !query 4 output -val1c 8 16 1 10 12 val1c 8 16 1 6 12 +val1c 8 16 1 10 12 val1c 8 16 1 17 16 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index 1af1a3652971c..2a0e088437fda 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -634,4 +634,20 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext assert(getNumInMemoryRelations(cachedPlan2) == 4) } } + + test("refreshByPath should refresh all cached plans with the specified path") { + withTempDir { dir => + val path = dir.getCanonicalPath() + + spark.range(10).write.mode("overwrite").parquet(path) + spark.read.parquet(path).cache() + spark.read.parquet(path).filter($"id" > 4).cache() + assert(spark.read.parquet(path).filter($"id" > 4).count() == 5) + + spark.range(20).write.mode("overwrite").parquet(path) + spark.catalog.refreshByPath(path) + assert(spark.read.parquet(path).count() == 20) + assert(spark.read.parquet(path).filter($"id" > 4).count() == 15) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 5e65436079db2..19c2d5532d088 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -914,15 +914,13 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } test("SPARK-7551: support backticks for DataFrame attribute resolution") { - val df = spark.read.json(sparkContext.makeRDD( - """{"a.b": {"c": {"d..e": {"f": 1}}}}""" :: Nil)) + val df = spark.read.json(Seq("""{"a.b": {"c": {"d..e": {"f": 1}}}}""").toDS()) checkAnswer( df.select(df("`a.b`.c.`d..e`.`f`")), Row(1) ) - val df2 = spark.read.json(sparkContext.makeRDD( - """{"a b": {"c": {"d e": {"f": 1}}}}""" :: Nil)) + val df2 = spark.read.json(Seq("""{"a b": {"c": {"d e": {"f": 1}}}}""").toDS()) checkAnswer( df2.select(df2("`a b`.c.d e.f")), Row(1) @@ -1110,8 +1108,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } test("SPARK-9323: DataFrame.orderBy should support nested column name") { - val df = spark.read.json(sparkContext.makeRDD( - """{"a": {"b": 1}}""" :: Nil)) + val df = spark.read.json(Seq("""{"a": {"b": 1}}""").toDS()) checkAnswer(df.orderBy("a.b"), Row(Row(1))) } @@ -1164,8 +1161,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } test("SPARK-10316: respect non-deterministic expressions in PhysicalOperation") { - val input = spark.read.json(spark.sparkContext.makeRDD( - (1 to 10).map(i => s"""{"id": $i}"""))) + val input = spark.read.json((1 to 
10).map(i => s"""{"id": $i}""").toDS()) val df = input.select($"id", rand(0).as('r)) df.as("a").join(df.filter($"r" < 0.5).as("b"), $"a.id" === $"b.id").collect().foreach { row => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 03cdfccdda555..468ea0551298e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -211,8 +211,8 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } test("grouping on nested fields") { - spark.read.json(sparkContext.parallelize( - """{"nested": {"attribute": 1}, "value": 2}""" :: Nil)) + spark.read + .json(Seq("""{"nested": {"attribute": 1}, "value": 2}""").toDS()) .createOrReplaceTempView("rows") checkAnswer( @@ -229,9 +229,8 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } test("SPARK-6201 IN type conversion") { - spark.read.json( - sparkContext.parallelize( - Seq("{\"a\": \"1\"}}", "{\"a\": \"2\"}}", "{\"a\": \"3\"}}"))) + spark.read + .json(Seq("{\"a\": \"1\"}}", "{\"a\": \"2\"}}", "{\"a\": \"3\"}}").toDS()) .createOrReplaceTempView("d") checkAnswer( @@ -240,9 +239,8 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } test("SPARK-11226 Skip empty line in json file") { - spark.read.json( - sparkContext.parallelize( - Seq("{\"a\": \"1\"}}", "{\"a\": \"2\"}}", "{\"a\": \"3\"}}", ""))) + spark.read + .json(Seq("{\"a\": \"1\"}}", "{\"a\": \"2\"}}", "{\"a\": \"3\"}}", "").toDS()) .createOrReplaceTempView("d") checkAnswer( @@ -1214,8 +1212,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } test("SPARK-3483 Special chars in column names") { - val data = sparkContext.parallelize( - Seq("""{"key?number1": "value1", "key.number2": "value2"}""")) + val data = Seq("""{"key?number1": "value1", "key.number2": "value2"}""").toDS() spark.read.json(data).createOrReplaceTempView("records") sql("SELECT `key?number1`, `key.number2` FROM records") } @@ -1257,13 +1254,13 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } test("SPARK-4322 Grouping field with struct field as sub expression") { - spark.read.json(sparkContext.makeRDD("""{"a": {"b": [{"c": 1}]}}""" :: Nil)) + spark.read.json(Seq("""{"a": {"b": [{"c": 1}]}}""").toDS()) .createOrReplaceTempView("data") checkAnswer(sql("SELECT a.b[0].c FROM data GROUP BY a.b[0].c"), Row(1)) spark.catalog.dropTempView("data") - spark.read.json( - sparkContext.makeRDD("""{"a": {"b": 1}}""" :: Nil)).createOrReplaceTempView("data") + spark.read.json(Seq("""{"a": {"b": 1}}""").toDS()) + .createOrReplaceTempView("data") checkAnswer(sql("SELECT a.b + 1 FROM data GROUP BY a.b + 1"), Row(2)) spark.catalog.dropTempView("data") } @@ -1311,8 +1308,8 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } test("SPARK-6145: ORDER BY test for nested fields") { - spark.read.json(sparkContext.makeRDD( - """{"a": {"b": 1, "a": {"a": 1}}, "c": [{"d": 1}]}""" :: Nil)) + spark.read + .json(Seq("""{"a": {"b": 1, "a": {"a": 1}}, "c": [{"d": 1}]}""").toDS()) .createOrReplaceTempView("nestedOrder") checkAnswer(sql("SELECT 1 FROM nestedOrder ORDER BY a.b"), Row(1)) @@ -1325,7 +1322,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("SPARK-6145: special cases") { spark.read - .json(sparkContext.makeRDD("""{"a": {"b": [1]}, "b": [{"a": 1}], "_c0": {"a": 1}}""" :: Nil)) + .json(Seq("""{"a": {"b": [1]}, "b": [{"a": 1}], "_c0": {"a": 1}}""").toDS()) 
.createOrReplaceTempView("t") checkAnswer(sql("SELECT a.b[0] FROM t ORDER BY _c0.a"), Row(1)) @@ -1333,8 +1330,8 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } test("SPARK-6898: complete support for special chars in column names") { - spark.read.json(sparkContext.makeRDD( - """{"a": {"c.b": 1}, "b.$q": [{"a@!.q": 1}], "q.w": {"w.i&": [1]}}""" :: Nil)) + spark.read + .json(Seq("""{"a": {"c.b": 1}, "b.$q": [{"a@!.q": 1}], "q.w": {"w.i&": [1]}}""").toDS()) .createOrReplaceTempView("t") checkAnswer(sql("SELECT a.`c.b`, `b.$q`[0].`a@!.q`, `q.w`.`w.i&`[0] FROM t"), Row(1, 1, 1)) @@ -1437,8 +1434,9 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("SPARK-7067: order by queries for complex ExtractValue chain") { withTempView("t") { - spark.read.json(sparkContext.makeRDD( - """{"a": {"b": [{"c": 1}]}, "b": [{"d": 1}]}""" :: Nil)).createOrReplaceTempView("t") + spark.read + .json(Seq("""{"a": {"b": [{"c": 1}]}, "b": [{"d": 1}]}""").toDS()) + .createOrReplaceTempView("t") checkAnswer(sql("SELECT a.b FROM t ORDER BY b[0].d"), Row(Seq(Row(1)))) } } @@ -2109,8 +2107,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { |"a": [{"count": 3}], "b": [{"e": "test", "count": 1}]}}}' | """.stripMargin - val rdd = sparkContext.parallelize(Array(json)) - spark.read.json(rdd).write.mode("overwrite").parquet(dir.toString) + spark.read.json(Seq(json).toDS()).write.mode("overwrite").parquet(dir.toString) spark.read.parquet(dir.toString).collect() } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala index 91aecca537fb2..68ababcd11027 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala @@ -98,7 +98,9 @@ class SQLQueryTestSuite extends QueryTest with SharedSQLContext { /** List of test cases to ignore, in lower cases. */ private val blackList = Set( - "blacklist.sql" // Do NOT remove this one. It is here to test the blacklist functionality. + "blacklist.sql", // Do NOT remove this one. It is here to test the blacklist functionality. + ".DS_Store" // A meta-file that may be created on Mac by Finder App. + // We should ignore this file from processing. ) // Create all the test cases. @@ -121,7 +123,7 @@ class SQLQueryTestSuite extends QueryTest with SharedSQLContext { } private def createScalaTestCase(testCase: TestCase): Unit = { - if (blackList.contains(testCase.name.toLowerCase)) { + if (blackList.exists(t => testCase.name.toLowerCase.contains(t.toLowerCase))) { // Create a test case to ignore this case. ignore(testCase.name) { /* Do nothing */ } } else { @@ -226,12 +228,12 @@ class SQLQueryTestSuite extends QueryTest with SharedSQLContext { if (isSorted(df.queryExecution.analyzed)) (schema, answer) else (schema, answer.sorted) } catch { - case a: AnalysisException if a.plan.nonEmpty => + case a: AnalysisException => // Do not output the logical plan tree which contains expression IDs. // Also implement a crude way of masking expression IDs in the error message // with a generic pattern "###". - (StructType(Seq.empty), - Seq(a.getClass.getName, a.getSimpleMessage.replaceAll("#\\d+", "#x"))) + val msg = if (a.plan.nonEmpty) a.getSimpleMessage else a.getMessage + (StructType(Seq.empty), Seq(a.getClass.getName, msg.replaceAll("#\\d+", "#x"))) case NonFatal(e) => // If there is an exception, put the exception class followed by the message. 
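The masking step above rewrites expression IDs such as `value#42` into a stable `#x` placeholder so that golden files do not change from run to run. A quick illustration of the regex (plain Scala, no Spark session needed):

```scala
// Illustration of the expression-ID masking applied to AnalysisException messages.
val msg = "cannot resolve '`a`' given input columns: [value#42, key#7]"
val masked = msg.replaceAll("#\\d+", "#x")
println(masked)  // cannot resolve '`a`' given input columns: [value#x, key#x]
```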
(StructType(Seq.empty), Seq(e.getClass.getName, e.getMessage)) @@ -241,7 +243,9 @@ class SQLQueryTestSuite extends QueryTest with SharedSQLContext { private def listTestCases(): Seq[TestCase] = { listFilesRecursively(new File(inputFilePath)).map { file => val resultFile = file.getAbsolutePath.replace(inputFilePath, goldenFilePath) + ".out" - TestCase(file.getName, file.getAbsolutePath, resultFile) + val absPath = file.getAbsolutePath + val testCaseName = absPath.stripPrefix(inputFilePath).stripPrefix(File.separator) + TestCase(testCaseName, absPath, resultFile) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala index b38bbd8e7eef2..bbb31dbc8f3de 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala @@ -306,7 +306,7 @@ abstract class StatisticsCollectionTestBase extends QueryTest with SQLTestUtils // Analyze only one column. sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS FOR COLUMNS c1") val (relation, catalogTable) = spark.table(tableName).queryExecution.analyzed.collect { - case catalogRel: CatalogRelation => (catalogRel, catalogRel.catalogTable) + case catalogRel: CatalogRelation => (catalogRel, catalogRel.tableMeta) case logicalRel: LogicalRelation => (logicalRel, logicalRel.catalogTable.get) }.head val emptyColStat = ColumnStat(0, None, None, 0, 4, 4) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala index c7a77daacab7e..b096a6db8517f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala @@ -221,8 +221,7 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetT StructField("vec", new UDT.MyDenseVectorUDT, false) )) - val stringRDD = sparkContext.parallelize(data) - val jsonRDD = spark.read.schema(schema).json(stringRDD) + val jsonRDD = spark.read.schema(schema).json(data.toDS()) checkAnswer( jsonRDD, Row(1, new UDT.MyDenseVector(Array(1.1, 2.2, 3.3, 4.4))) :: @@ -242,8 +241,7 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetT StructField("vec", new UDT.MyDenseVectorUDT, false) )) - val stringRDD = sparkContext.parallelize(data) - val jsonDataset = spark.read.schema(schema).json(stringRDD) + val jsonDataset = spark.read.schema(schema).json(data.toDS()) .as[(Int, UDT.MyDenseVector)] checkDataset( jsonDataset, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/WideSchemaBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/WideSchemaBenchmark.scala index d2704b3d3f371..a42891e55a18a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/WideSchemaBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/WideSchemaBenchmark.scala @@ -140,7 +140,7 @@ class WideSchemaBenchmark extends SparkFunSuite with BeforeAndAfterEach { } datum += "}" datum = s"""{"a": {"b": {"c": $datum, "d": $datum}, "e": $datum}}""" - val df = sparkSession.read.json(sparkSession.range(numRows).map(_ => datum).rdd).cache() + val df = sparkSession.read.json(sparkSession.range(numRows).map(_ => datum)).cache() df.count() // force caching addCases(benchmark, df, s"$width wide x $numRows rows", 
"a.b.c.value_1") } @@ -157,7 +157,7 @@ class WideSchemaBenchmark extends SparkFunSuite with BeforeAndAfterEach { datum = "{\"value\": " + datum + "}" selector = selector + ".value" } - val df = sparkSession.read.json(sparkSession.range(numRows).map(_ => datum).rdd).cache() + val df = sparkSession.read.json(sparkSession.range(numRows).map(_ => datum)).cache() df.count() // force caching addCases(benchmark, df, s"$depth deep x $numRows rows", selector) } @@ -180,7 +180,7 @@ class WideSchemaBenchmark extends SparkFunSuite with BeforeAndAfterEach { } // TODO(ekl) seems like the json parsing is actually the majority of the time, perhaps // we should benchmark that too separately. - val df = sparkSession.read.json(sparkSession.range(numRows).map(_ => datum).rdd).cache() + val df = sparkSession.read.json(sparkSession.range(numRows).map(_ => datum)).cache() df.count() // force caching addCases(benchmark, df, s"$numNodes x $depth deep x $numRows rows", selector) } @@ -200,7 +200,7 @@ class WideSchemaBenchmark extends SparkFunSuite with BeforeAndAfterEach { } } datum += "]}" - val df = sparkSession.read.json(sparkSession.range(numRows).map(_ => datum).rdd).cache() + val df = sparkSession.read.json(sparkSession.range(numRows).map(_ => datum)).cache() df.count() // force caching addCases(benchmark, df, s"$width wide x $numRows rows", "value[0]") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index b44f20e367f0a..8b8cd0fdf4db2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -1836,18 +1836,17 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { test("insert data to a data source table which has a not existed location should succeed") { withTable("t") { withTempDir { dir => - val path = dir.toURI.toString.stripSuffix("/") spark.sql( s""" |CREATE TABLE t(a string, b int) |USING parquet - |OPTIONS(path "$path") + |OPTIONS(path "$dir") """.stripMargin) val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) - assert(table.location == path) + assert(table.location == dir.getAbsolutePath) dir.delete - val tableLocFile = new File(table.location.stripPrefix("file:")) + val tableLocFile = new File(table.location) assert(!tableLocFile.exists) spark.sql("INSERT INTO TABLE t SELECT 'c', 1") assert(tableLocFile.exists) @@ -1878,16 +1877,15 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { test("insert into a data source table with no existed partition location should succeed") { withTable("t") { withTempDir { dir => - val path = dir.toURI.toString.stripSuffix("/") spark.sql( s""" |CREATE TABLE t(a int, b int, c int, d int) |USING parquet |PARTITIONED BY(a, b) - |LOCATION "$path" + |LOCATION "$dir" """.stripMargin) val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) - assert(table.location == path) + assert(table.location == dir.getAbsolutePath) spark.sql("INSERT INTO TABLE t PARTITION(a=1, b=2) SELECT 3, 4") checkAnswer(spark.table("t"), Row(3, 4, 1, 2) :: Nil) @@ -1906,15 +1904,14 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { test("read data from a data source table which has a not existed location should succeed") { withTable("t") { withTempDir { dir => - val path = dir.toURI.toString.stripSuffix("/") spark.sql( 
s""" |CREATE TABLE t(a string, b int) |USING parquet - |OPTIONS(path "$path") + |OPTIONS(path "$dir") """.stripMargin) val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) - assert(table.location == path) + assert(table.location == dir.getAbsolutePath) dir.delete() checkAnswer(spark.table("t"), Nil) @@ -1939,7 +1936,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { |CREATE TABLE t(a int, b int, c int, d int) |USING parquet |PARTITIONED BY(a, b) - |LOCATION "${dir.toURI}" + |LOCATION "$dir" """.stripMargin) spark.sql("INSERT INTO TABLE t PARTITION(a=1, b=2) SELECT 3, 4") checkAnswer(spark.table("t"), Row(3, 4, 1, 2) :: Nil) @@ -1952,4 +1949,51 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { } } } + + Seq(true, false).foreach { shouldDelete => + val tcName = if (shouldDelete) "non-existent" else "existed" + test(s"CTAS for external data source table with a $tcName location") { + withTable("t", "t1") { + withTempDir { + dir => + if (shouldDelete) { + dir.delete() + } + spark.sql( + s""" + |CREATE TABLE t + |USING parquet + |LOCATION '$dir' + |AS SELECT 3 as a, 4 as b, 1 as c, 2 as d + """.stripMargin) + val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) + assert(table.location == dir.getAbsolutePath) + + checkAnswer(spark.table("t"), Row(3, 4, 1, 2)) + } + // partition table + withTempDir { + dir => + if (shouldDelete) { + dir.delete() + } + spark.sql( + s""" + |CREATE TABLE t1 + |USING parquet + |PARTITIONED BY(a, b) + |LOCATION '$dir' + |AS SELECT 3 as a, 4 as b, 1 as c, 2 as d + """.stripMargin) + val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t1")) + assert(table.location == dir.getAbsolutePath) + + val partDir = new File(dir, "a=3") + assert(partDir.exists()) + + checkAnswer(spark.table("t1"), Row(1, 2, 3, 4)) + } + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala index 2b4c9f3ed3274..7ea4064927576 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala @@ -27,6 +27,7 @@ import org.apache.hadoop.fs.{FileStatus, Path, RawLocalFileSystem} import org.apache.spark.metrics.source.HiveCatalogMetrics import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext class FileIndexSuite extends SharedSQLContext { @@ -178,6 +179,47 @@ class FileIndexSuite extends SharedSQLContext { assert(catalog2.allFiles().nonEmpty) } } + + test("InMemoryFileIndex with empty rootPaths when PARALLEL_PARTITION_DISCOVERY_THRESHOLD" + + "is a nonpositive number") { + withSQLConf(SQLConf.PARALLEL_PARTITION_DISCOVERY_THRESHOLD.key -> "0") { + new InMemoryFileIndex(spark, Seq.empty, Map.empty, None) + } + + val e = intercept[IllegalArgumentException] { + withSQLConf(SQLConf.PARALLEL_PARTITION_DISCOVERY_THRESHOLD.key -> "-1") { + new InMemoryFileIndex(spark, Seq.empty, Map.empty, None) + } + }.getMessage + assert(e.contains("The maximum number of paths allowed for listing files at " + + "driver side must not be negative")) + } + + test("refresh for InMemoryFileIndex with FileStatusCache") { + withTempDir { dir => + val fileStatusCache = FileStatusCache.getOrCreate(spark) + val dirPath = new 
Path(dir.getAbsolutePath) + val fs = dirPath.getFileSystem(spark.sessionState.newHadoopConf()) + val catalog = + new InMemoryFileIndex(spark, Seq(dirPath), Map.empty, None, fileStatusCache) { + def leafFilePaths: Seq[Path] = leafFiles.keys.toSeq + def leafDirPaths: Seq[Path] = leafDirToChildrenFiles.keys.toSeq + } + + val file = new File(dir, "text.txt") + stringToFile(file, "text") + assert(catalog.leafDirPaths.isEmpty) + assert(catalog.leafFilePaths.isEmpty) + + catalog.refresh() + + assert(catalog.leafFilePaths.size == 1) + assert(catalog.leafFilePaths.head == fs.makeQualified(new Path(file.getAbsolutePath))) + + assert(catalog.leafDirPaths.size == 1) + assert(catalog.leafDirPaths.head == fs.makeQualified(dirPath)) + } + } } class FakeParentPathFileSystem extends RawLocalFileSystem { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index 371d4311baa3b..56071803f685f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -24,11 +24,12 @@ import java.text.SimpleDateFormat import java.util.Locale import org.apache.commons.lang3.time.FastDateFormat -import org.apache.hadoop.io.SequenceFile.CompressionType import org.apache.hadoop.io.compress.GzipCodec +import org.apache.hadoop.io.SequenceFile.CompressionType import org.apache.spark.SparkException import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row, UDT} +import org.apache.spark.sql.functions.{col, regexp_replace} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} import org.apache.spark.sql.types._ @@ -243,12 +244,15 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("test for DROPMALFORMED parsing mode") { - val cars = spark.read - .format("csv") - .options(Map("header" -> "true", "mode" -> "dropmalformed")) - .load(testFile(carsFile)) + Seq(false, true).foreach { wholeFile => + val cars = spark.read + .format("csv") + .option("wholeFile", wholeFile) + .options(Map("header" -> "true", "mode" -> "dropmalformed")) + .load(testFile(carsFile)) - assert(cars.select("year").collect().size === 2) + assert(cars.select("year").collect().size === 2) + } } test("test for blank column names on read and select columns") { @@ -263,14 +267,17 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("test for FAILFAST parsing mode") { - val exception = intercept[SparkException]{ - spark.read - .format("csv") - .options(Map("header" -> "true", "mode" -> "failfast")) - .load(testFile(carsFile)).collect() - } + Seq(false, true).foreach { wholeFile => + val exception = intercept[SparkException] { + spark.read + .format("csv") + .option("wholeFile", wholeFile) + .options(Map("header" -> "true", "mode" -> "failfast")) + .load(testFile(carsFile)).collect() + } - assert(exception.getMessage.contains("Malformed line in FAILFAST mode: 2015,Chevy,Volt")) + assert(exception.getMessage.contains("Malformed line in FAILFAST mode: 2015,Chevy,Volt")) + } } test("test for tokens more than the fields in the schema") { @@ -735,10 +742,11 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { .save(iso8601timestampsPath) // This will load back the timestamps as string. 
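The `Seq(false, true).foreach { wholeFile => ... }` loops above re-run the CSV parsing-mode tests with and without the (2.2-era) `wholeFile` reader option. Outside the suite the option is passed the same way; a sketch only, assuming a `SparkSession` named `spark` and a hypothetical input path:

```scala
// Sketch: read a CSV file with the per-file parser and drop malformed rows.
val cars = spark.read
  .format("csv")
  .option("wholeFile", "true")
  .option("header", "true")
  .option("mode", "DROPMALFORMED")
  .load("/path/to/cars.csv")  // hypothetical path
cars.select("year").show()
```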
+ val stringSchema = StructType(StructField("date", StringType, true) :: Nil) val iso8601Timestamps = spark.read .format("csv") + .schema(stringSchema) .option("header", "true") - .option("inferSchema", "false") .load(iso8601timestampsPath) val iso8501 = FastDateFormat.getInstance("yyyy-MM-dd'T'HH:mm:ss.SSSZZ", Locale.US) @@ -768,10 +776,11 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { .save(iso8601datesPath) // This will load back the dates as string. + val stringSchema = StructType(StructField("date", StringType, true) :: Nil) val iso8601dates = spark.read .format("csv") + .schema(stringSchema) .option("header", "true") - .option("inferSchema", "false") .load(iso8601datesPath) val iso8501 = FastDateFormat.getInstance("yyyy-MM-dd", Locale.US) @@ -826,10 +835,11 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { .save(datesWithFormatPath) // This will load back the dates as string. + val stringSchema = StructType(StructField("date", StringType, true) :: Nil) val stringDatesWithFormat = spark.read .format("csv") + .schema(stringSchema) .option("header", "true") - .option("inferSchema", "false") .load(datesWithFormatPath) val expectedStringDatesWithFormat = Seq( Row("2015/08/26"), @@ -857,10 +867,11 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { .save(timestampsWithFormatPath) // This will load back the timestamps as string. + val stringSchema = StructType(StructField("date", StringType, true) :: Nil) val stringTimestampsWithFormat = spark.read .format("csv") + .schema(stringSchema) .option("header", "true") - .option("inferSchema", "false") .load(timestampsWithFormatPath) val expectedStringTimestampsWithFormat = Seq( Row("2015/08/26 18:00"), @@ -889,10 +900,11 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { .save(timestampsWithFormatPath) // This will load back the timestamps as string. 
+ val stringSchema = StructType(StructField("date", StringType, true) :: Nil) val stringTimestampsWithFormat = spark.read .format("csv") + .schema(stringSchema) .option("header", "true") - .option("inferSchema", "false") .load(timestampsWithFormatPath) val expectedStringTimestampsWithFormat = Seq( Row("2015/08/27 01:00"), @@ -961,56 +973,121 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("SPARK-18699 put malformed records in a `columnNameOfCorruptRecord` field") { - val schema = new StructType().add("a", IntegerType).add("b", TimestampType) - val df1 = spark - .read - .option("mode", "PERMISSIVE") - .schema(schema) - .csv(testFile(valueMalformedFile)) - checkAnswer(df1, - Row(null, null) :: - Row(1, java.sql.Date.valueOf("1983-08-04")) :: - Nil) - - // If `schema` has `columnNameOfCorruptRecord`, it should handle corrupt records - val columnNameOfCorruptRecord = "_unparsed" - val schemaWithCorrField1 = schema.add(columnNameOfCorruptRecord, StringType) - val df2 = spark - .read - .option("mode", "PERMISSIVE") - .option("columnNameOfCorruptRecord", columnNameOfCorruptRecord) - .schema(schemaWithCorrField1) - .csv(testFile(valueMalformedFile)) - checkAnswer(df2, - Row(null, null, "0,2013-111-11 12:13:14") :: - Row(1, java.sql.Date.valueOf("1983-08-04"), null) :: - Nil) - - // We put a `columnNameOfCorruptRecord` field in the middle of a schema - val schemaWithCorrField2 = new StructType() - .add("a", IntegerType) - .add(columnNameOfCorruptRecord, StringType) - .add("b", TimestampType) - val df3 = spark - .read - .option("mode", "PERMISSIVE") - .option("columnNameOfCorruptRecord", columnNameOfCorruptRecord) - .schema(schemaWithCorrField2) - .csv(testFile(valueMalformedFile)) - checkAnswer(df3, - Row(null, "0,2013-111-11 12:13:14", null) :: - Row(1, null, java.sql.Date.valueOf("1983-08-04")) :: - Nil) - - val errMsg = intercept[AnalysisException] { - spark + Seq(false, true).foreach { wholeFile => + val schema = new StructType().add("a", IntegerType).add("b", TimestampType) + val df1 = spark + .read + .option("mode", "PERMISSIVE") + .option("wholeFile", wholeFile) + .schema(schema) + .csv(testFile(valueMalformedFile)) + checkAnswer(df1, + Row(null, null) :: + Row(1, java.sql.Date.valueOf("1983-08-04")) :: + Nil) + + // If `schema` has `columnNameOfCorruptRecord`, it should handle corrupt records + val columnNameOfCorruptRecord = "_unparsed" + val schemaWithCorrField1 = schema.add(columnNameOfCorruptRecord, StringType) + val df2 = spark + .read + .option("mode", "PERMISSIVE") + .option("columnNameOfCorruptRecord", columnNameOfCorruptRecord) + .option("wholeFile", wholeFile) + .schema(schemaWithCorrField1) + .csv(testFile(valueMalformedFile)) + checkAnswer(df2, + Row(null, null, "0,2013-111-11 12:13:14") :: + Row(1, java.sql.Date.valueOf("1983-08-04"), null) :: + Nil) + + // We put a `columnNameOfCorruptRecord` field in the middle of a schema + val schemaWithCorrField2 = new StructType() + .add("a", IntegerType) + .add(columnNameOfCorruptRecord, StringType) + .add("b", TimestampType) + val df3 = spark .read .option("mode", "PERMISSIVE") .option("columnNameOfCorruptRecord", columnNameOfCorruptRecord) - .schema(schema.add(columnNameOfCorruptRecord, IntegerType)) + .option("wholeFile", wholeFile) + .schema(schemaWithCorrField2) .csv(testFile(valueMalformedFile)) - .collect - }.getMessage - assert(errMsg.startsWith("The field for corrupt records must be string type and nullable")) + checkAnswer(df3, + Row(null, "0,2013-111-11 12:13:14", null) :: + Row(1, null, 
java.sql.Date.valueOf("1983-08-04")) :: + Nil) + + val errMsg = intercept[AnalysisException] { + spark + .read + .option("mode", "PERMISSIVE") + .option("columnNameOfCorruptRecord", columnNameOfCorruptRecord) + .option("wholeFile", wholeFile) + .schema(schema.add(columnNameOfCorruptRecord, IntegerType)) + .csv(testFile(valueMalformedFile)) + .collect + }.getMessage + assert(errMsg.startsWith("The field for corrupt records must be string type and nullable")) + } + } + + test("SPARK-19610: Parse normal multi-line CSV files") { + val primitiveFieldAndType = Seq( + """" + |string","integer + | + | + |","long + | + |","bigInteger",double,boolean,null""".stripMargin, + """"this is a + |simple + |string."," + | + |10"," + |21474836470","92233720368547758070"," + | + |1.7976931348623157E308",true,""".stripMargin) + + withTempPath { path => + primitiveFieldAndType.toDF("value").coalesce(1).write.text(path.getAbsolutePath) + + val df = spark.read + .option("header", true) + .option("wholeFile", true) + .csv(path.getAbsolutePath) + + // Check if headers have new lines in the names. + val actualFields = df.schema.fieldNames.toSeq + val expectedFields = + Seq("\nstring", "integer\n\n\n", "long\n\n", "bigInteger", "double", "boolean", "null") + assert(actualFields === expectedFields) + + // Check if the rows have new lines in the values. + val expected = Row( + "this is a\nsimple\nstring.", + "\n\n10", + "\n21474836470", + "92233720368547758070", + "\n\n1.7976931348623157E308", + "true", + null) + checkAnswer(df, expected) + } + } + + test("Empty file produces empty dataframe with empty schema - wholeFile option") { + withTempPath { path => + path.createNewFile() + + val df = spark.read.format("csv") + .option("header", true) + .option("wholeFile", true) + .load(path.getAbsolutePath) + + assert(df.schema === spark.emptyDataFrame.schema) + checkAnswer(df, spark.emptyDataFrame) + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonParsingOptionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonParsingOptionsSuite.scala index 0b72da5f3759c..6e2b4f0df595f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonParsingOptionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonParsingOptionsSuite.scala @@ -25,19 +25,18 @@ import org.apache.spark.sql.test.SharedSQLContext * Test cases for various [[JSONOptions]]. 
*/ class JsonParsingOptionsSuite extends QueryTest with SharedSQLContext { + import testImplicits._ test("allowComments off") { val str = """{'name': /* hello */ 'Reynold Xin'}""" - val rdd = spark.sparkContext.parallelize(Seq(str)) - val df = spark.read.json(rdd) + val df = spark.read.json(Seq(str).toDS()) assert(df.schema.head.name == "_corrupt_record") } test("allowComments on") { val str = """{'name': /* hello */ 'Reynold Xin'}""" - val rdd = spark.sparkContext.parallelize(Seq(str)) - val df = spark.read.option("allowComments", "true").json(rdd) + val df = spark.read.option("allowComments", "true").json(Seq(str).toDS()) assert(df.schema.head.name == "name") assert(df.first().getString(0) == "Reynold Xin") @@ -45,16 +44,14 @@ class JsonParsingOptionsSuite extends QueryTest with SharedSQLContext { test("allowSingleQuotes off") { val str = """{'name': 'Reynold Xin'}""" - val rdd = spark.sparkContext.parallelize(Seq(str)) - val df = spark.read.option("allowSingleQuotes", "false").json(rdd) + val df = spark.read.option("allowSingleQuotes", "false").json(Seq(str).toDS()) assert(df.schema.head.name == "_corrupt_record") } test("allowSingleQuotes on") { val str = """{'name': 'Reynold Xin'}""" - val rdd = spark.sparkContext.parallelize(Seq(str)) - val df = spark.read.json(rdd) + val df = spark.read.json(Seq(str).toDS()) assert(df.schema.head.name == "name") assert(df.first().getString(0) == "Reynold Xin") @@ -62,16 +59,14 @@ class JsonParsingOptionsSuite extends QueryTest with SharedSQLContext { test("allowUnquotedFieldNames off") { val str = """{name: 'Reynold Xin'}""" - val rdd = spark.sparkContext.parallelize(Seq(str)) - val df = spark.read.json(rdd) + val df = spark.read.json(Seq(str).toDS()) assert(df.schema.head.name == "_corrupt_record") } test("allowUnquotedFieldNames on") { val str = """{name: 'Reynold Xin'}""" - val rdd = spark.sparkContext.parallelize(Seq(str)) - val df = spark.read.option("allowUnquotedFieldNames", "true").json(rdd) + val df = spark.read.option("allowUnquotedFieldNames", "true").json(Seq(str).toDS()) assert(df.schema.head.name == "name") assert(df.first().getString(0) == "Reynold Xin") @@ -79,16 +74,14 @@ class JsonParsingOptionsSuite extends QueryTest with SharedSQLContext { test("allowNumericLeadingZeros off") { val str = """{"age": 0018}""" - val rdd = spark.sparkContext.parallelize(Seq(str)) - val df = spark.read.json(rdd) + val df = spark.read.json(Seq(str).toDS()) assert(df.schema.head.name == "_corrupt_record") } test("allowNumericLeadingZeros on") { val str = """{"age": 0018}""" - val rdd = spark.sparkContext.parallelize(Seq(str)) - val df = spark.read.option("allowNumericLeadingZeros", "true").json(rdd) + val df = spark.read.option("allowNumericLeadingZeros", "true").json(Seq(str).toDS()) assert(df.schema.head.name == "age") assert(df.first().getLong(0) == 18) @@ -98,16 +91,14 @@ class JsonParsingOptionsSuite extends QueryTest with SharedSQLContext { // JsonParser.Feature.ALLOW_NON_NUMERIC_NUMBERS. 
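These tests now feed JSON literals through `spark.read.json(Dataset[String])` rather than the RDD-based overload. A minimal sketch of the pattern (assuming `spark` is a 2.2-era `SparkSession` declared as a `val`):

```scala
import spark.implicits._

// The Dataset[String] JSON entry point used throughout this suite.
val df = spark.read
  .option("allowComments", "true")
  .json(Seq("""{'name': /* hello */ 'Reynold Xin'}""").toDS())

df.printSchema()                                 // root |-- name: string
assert(df.first().getString(0) == "Reynold Xin")
```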
ignore("allowNonNumericNumbers off") { val str = """{"age": NaN}""" - val rdd = spark.sparkContext.parallelize(Seq(str)) - val df = spark.read.json(rdd) + val df = spark.read.json(Seq(str).toDS()) assert(df.schema.head.name == "_corrupt_record") } ignore("allowNonNumericNumbers on") { val str = """{"age": NaN}""" - val rdd = spark.sparkContext.parallelize(Seq(str)) - val df = spark.read.option("allowNonNumericNumbers", "true").json(rdd) + val df = spark.read.option("allowNonNumericNumbers", "true").json(Seq(str).toDS()) assert(df.schema.head.name == "age") assert(df.first().getDouble(0).isNaN) @@ -115,16 +106,14 @@ class JsonParsingOptionsSuite extends QueryTest with SharedSQLContext { test("allowBackslashEscapingAnyCharacter off") { val str = """{"name": "Cazen Lee", "price": "\$10"}""" - val rdd = spark.sparkContext.parallelize(Seq(str)) - val df = spark.read.option("allowBackslashEscapingAnyCharacter", "false").json(rdd) + val df = spark.read.option("allowBackslashEscapingAnyCharacter", "false").json(Seq(str).toDS()) assert(df.schema.head.name == "_corrupt_record") } test("allowBackslashEscapingAnyCharacter on") { val str = """{"name": "Cazen Lee", "price": "\$10"}""" - val rdd = spark.sparkContext.parallelize(Seq(str)) - val df = spark.read.option("allowBackslashEscapingAnyCharacter", "true").json(rdd) + val df = spark.read.option("allowBackslashEscapingAnyCharacter", "true").json(Seq(str).toDS()) assert(df.schema.head.name == "name") assert(df.schema.last.name == "price") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 0e01be2410409..0aaf148dac258 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -590,7 +590,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { val dir = Utils.createTempDir() dir.delete() val path = dir.getCanonicalPath - primitiveFieldAndType.map(record => record.replaceAll("\n", " ")).saveAsTextFile(path) + primitiveFieldAndType.map(record => record.replaceAll("\n", " ")).write.text(path) val jsonDF = spark.read.json(path) val expectedSchema = StructType( @@ -622,7 +622,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { val dir = Utils.createTempDir() dir.delete() val path = dir.getCanonicalPath - primitiveFieldAndType.map(record => record.replaceAll("\n", " ")).saveAsTextFile(path) + primitiveFieldAndType.map(record => record.replaceAll("\n", " ")).write.text(path) val jsonDF = spark.read.option("primitivesAsString", "true").json(path) val expectedSchema = StructType( @@ -777,9 +777,9 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } test("Find compatible types even if inferred DecimalType is not capable of other IntegralType") { - val mixedIntegerAndDoubleRecords = sparkContext.parallelize( - """{"a": 3, "b": 1.1}""" :: - s"""{"a": 3.1, "b": 0.${"0" * 38}1}""" :: Nil) + val mixedIntegerAndDoubleRecords = Seq( + """{"a": 3, "b": 1.1}""", + s"""{"a": 3.1, "b": 0.${"0" * 38}1}""").toDS() val jsonDF = spark.read .option("prefersDecimal", "true") .json(mixedIntegerAndDoubleRecords) @@ -828,7 +828,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { val mergedJsonDF = spark.read .option("prefersDecimal", "true") - .json(floatingValueRecords ++ bigIntegerRecords) + 
.json(floatingValueRecords.union(bigIntegerRecords)) val expectedMergedSchema = StructType( StructField("a", DoubleType, true) :: @@ -846,7 +846,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { val dir = Utils.createTempDir() dir.delete() val path = dir.toURI.toString - primitiveFieldAndType.map(record => record.replaceAll("\n", " ")).saveAsTextFile(path) + primitiveFieldAndType.map(record => record.replaceAll("\n", " ")).write.text(path) sql( s""" @@ -873,7 +873,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { val dir = Utils.createTempDir() dir.delete() val path = dir.getCanonicalPath - primitiveFieldAndType.map(record => record.replaceAll("\n", " ")).saveAsTextFile(path) + primitiveFieldAndType.map(record => record.replaceAll("\n", " ")).write.text(path) val schema = StructType( StructField("bigInteger", DecimalType.SYSTEM_DEFAULT, true) :: @@ -1263,7 +1263,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { assert(result2(3) === "{\"f1\":{\"f11\":4,\"f12\":true},\"f2\":{\"D4\":2147483644}}") val jsonDF = spark.read.json(primitiveFieldAndType) - val primTable = spark.read.json(jsonDF.toJSON.rdd) + val primTable = spark.read.json(jsonDF.toJSON) primTable.createOrReplaceTempView("primitiveTable") checkAnswer( sql("select * from primitiveTable"), @@ -1276,7 +1276,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { ) val complexJsonDF = spark.read.json(complexFieldAndType1) - val compTable = spark.read.json(complexJsonDF.toJSON.rdd) + val compTable = spark.read.json(complexJsonDF.toJSON) compTable.createOrReplaceTempView("complexTable") // Access elements of a primitive array. checkAnswer( @@ -1364,10 +1364,10 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { }) } - test("SPARK-6245 JsonRDD.inferSchema on empty RDD") { + test("SPARK-6245 JsonInferSchema.infer on empty RDD") { // This is really a test that it doesn't throw an exception val emptySchema = JsonInferSchema.infer( - empty, + empty.rdd, new JSONOptions(Map.empty[String, String], "GMT"), CreateJacksonParser.string) assert(StructType(Seq()) === emptySchema) @@ -1394,7 +1394,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { test("SPARK-8093 Erase empty structs") { val emptySchema = JsonInferSchema.infer( - emptyRecords, + emptyRecords.rdd, new JSONOptions(Map.empty[String, String], "GMT"), CreateJacksonParser.string) assert(StructType(Seq()) === emptySchema) @@ -1592,7 +1592,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { val dir = Utils.createTempDir() dir.delete() val path = dir.getCanonicalPath - arrayAndStructRecords.map(record => record.replaceAll("\n", " ")).saveAsTextFile(path) + arrayAndStructRecords.map(record => record.replaceAll("\n", " ")).write.text(path) val schema = StructType( @@ -1609,7 +1609,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { val dir = Utils.createTempDir() dir.delete() val path = dir.getCanonicalPath - primitiveFieldAndType.map(record => record.replaceAll("\n", " ")).saveAsTextFile(path) + primitiveFieldAndType.map(record => record.replaceAll("\n", " ")).write.text(path) val jsonDF = spark.read.json(path) val jsonDir = new File(dir, "json").getCanonicalPath @@ -1645,7 +1645,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { dir.delete() val path = dir.getCanonicalPath - primitiveFieldAndType.map(record => record.replaceAll("\n", 
" ")).saveAsTextFile(path) + primitiveFieldAndType.map(record => record.replaceAll("\n", " ")).write.text(path) val jsonDF = spark.read.json(path) val jsonDir = new File(dir, "json").getCanonicalPath @@ -1693,8 +1693,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { val json = s""" |{"a": [{$nested}], "b": [{$nested}]} """.stripMargin - val rdd = spark.sparkContext.makeRDD(Seq(json)) - val df = spark.read.json(rdd) + val df = spark.read.json(Seq(json).toDS()) assert(df.schema.size === 2) df.collect() } @@ -1794,8 +1793,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } test("SPARK-18433: Improve DataSource option keys to be more case-insensitive") { - val records = sparkContext - .parallelize("""{"a": 3, "b": 1.1}""" :: """{"a": 3.1, "b": 0.000001}""" :: Nil) + val records = Seq("""{"a": 3, "b": 1.1}""", """{"a": 3.1, "b": 0.000001}""").toDS() val schema = StructType( StructField("a", DecimalType(21, 1), true) :: diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala index a400940db924a..13084ba4a7f04 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala @@ -17,14 +17,13 @@ package org.apache.spark.sql.execution.datasources.json -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.{Dataset, Encoders, SparkSession} private[json] trait TestJsonData { protected def spark: SparkSession - def primitiveFieldAndType: RDD[String] = - spark.sparkContext.parallelize( + def primitiveFieldAndType: Dataset[String] = + spark.createDataset(spark.sparkContext.parallelize( """{"string":"this is a simple string.", "integer":10, "long":21474836470, @@ -32,10 +31,10 @@ private[json] trait TestJsonData { "double":1.7976931348623157E308, "boolean":true, "null":null - }""" :: Nil) + }""" :: Nil))(Encoders.STRING) - def primitiveFieldValueTypeConflict: RDD[String] = - spark.sparkContext.parallelize( + def primitiveFieldValueTypeConflict: Dataset[String] = + spark.createDataset(spark.sparkContext.parallelize( """{"num_num_1":11, "num_num_2":null, "num_num_3": 1.1, "num_bool":true, "num_str":13.1, "str_bool":"str1"}""" :: """{"num_num_1":null, "num_num_2":21474836470.9, "num_num_3": null, @@ -44,16 +43,17 @@ private[json] trait TestJsonData { "num_bool":false, "num_str":"str1", "str_bool":false}""" :: """{"num_num_1":21474836570, "num_num_2":1.1, "num_num_3": 21474836470, "num_bool":null, "num_str":92233720368547758070, "str_bool":null}""" :: Nil) + )(Encoders.STRING) - def jsonNullStruct: RDD[String] = - spark.sparkContext.parallelize( + def jsonNullStruct: Dataset[String] = + spark.createDataset(spark.sparkContext.parallelize( """{"nullstr":"","ip":"27.31.100.29","headers":{"Host":"1.abc.com","Charset":"UTF-8"}}""" :: """{"nullstr":"","ip":"27.31.100.29","headers":{}}""" :: """{"nullstr":"","ip":"27.31.100.29","headers":""}""" :: - """{"nullstr":null,"ip":"27.31.100.29","headers":null}""" :: Nil) + """{"nullstr":null,"ip":"27.31.100.29","headers":null}""" :: Nil))(Encoders.STRING) - def complexFieldValueTypeConflict: RDD[String] = - spark.sparkContext.parallelize( + def complexFieldValueTypeConflict: Dataset[String] = + spark.createDataset(spark.sparkContext.parallelize( """{"num_struct":11, "str_array":[1, 2, 3], 
"array":[], "struct_array":[], "struct": {}}""" :: """{"num_struct":{"field":false}, "str_array":null, @@ -62,24 +62,25 @@ private[json] trait TestJsonData { "array":[4, 5, 6], "struct_array":[7, 8, 9], "struct": {"field":null}}""" :: """{"num_struct":{}, "str_array":["str1", "str2", 33], "array":[7], "struct_array":{"field": true}, "struct": {"field": "str"}}""" :: Nil) + )(Encoders.STRING) - def arrayElementTypeConflict: RDD[String] = - spark.sparkContext.parallelize( + def arrayElementTypeConflict: Dataset[String] = + spark.createDataset(spark.sparkContext.parallelize( """{"array1": [1, 1.1, true, null, [], {}, [2,3,4], {"field":"str"}], "array2": [{"field":214748364700}, {"field":1}]}""" :: """{"array3": [{"field":"str"}, {"field":1}]}""" :: - """{"array3": [1, 2, 3]}""" :: Nil) + """{"array3": [1, 2, 3]}""" :: Nil))(Encoders.STRING) - def missingFields: RDD[String] = - spark.sparkContext.parallelize( + def missingFields: Dataset[String] = + spark.createDataset(spark.sparkContext.parallelize( """{"a":true}""" :: """{"b":21474836470}""" :: """{"c":[33, 44]}""" :: """{"d":{"field":true}}""" :: - """{"e":"str"}""" :: Nil) + """{"e":"str"}""" :: Nil))(Encoders.STRING) - def complexFieldAndType1: RDD[String] = - spark.sparkContext.parallelize( + def complexFieldAndType1: Dataset[String] = + spark.createDataset(spark.sparkContext.parallelize( """{"struct":{"field1": true, "field2": 92233720368547758070}, "structWithArrayFields":{"field1":[4, 5, 6], "field2":["str1", "str2"]}, "arrayOfString":["str1", "str2"], @@ -92,10 +93,10 @@ private[json] trait TestJsonData { "arrayOfStruct":[{"field1": true, "field2": "str1"}, {"field1": false}, {"field3": null}], "arrayOfArray1":[[1, 2, 3], ["str1", "str2"]], "arrayOfArray2":[[1, 2, 3], [1.1, 2.1, 3.1]] - }""" :: Nil) + }""" :: Nil))(Encoders.STRING) - def complexFieldAndType2: RDD[String] = - spark.sparkContext.parallelize( + def complexFieldAndType2: Dataset[String] = + spark.createDataset(spark.sparkContext.parallelize( """{"arrayOfStruct":[{"field1": true, "field2": "str1"}, {"field1": false}, {"field3": null}], "complexArrayOfStruct": [ { @@ -146,89 +147,90 @@ private[json] trait TestJsonData { {"inner3": [[{"inner4": 2}]]} ] ]] - }""" :: Nil) + }""" :: Nil))(Encoders.STRING) - def mapType1: RDD[String] = - spark.sparkContext.parallelize( + def mapType1: Dataset[String] = + spark.createDataset(spark.sparkContext.parallelize( """{"map": {"a": 1}}""" :: """{"map": {"b": 2}}""" :: """{"map": {"c": 3}}""" :: """{"map": {"c": 1, "d": 4}}""" :: - """{"map": {"e": null}}""" :: Nil) + """{"map": {"e": null}}""" :: Nil))(Encoders.STRING) - def mapType2: RDD[String] = - spark.sparkContext.parallelize( + def mapType2: Dataset[String] = + spark.createDataset(spark.sparkContext.parallelize( """{"map": {"a": {"field1": [1, 2, 3, null]}}}""" :: """{"map": {"b": {"field2": 2}}}""" :: """{"map": {"c": {"field1": [], "field2": 4}}}""" :: """{"map": {"c": {"field2": 3}, "d": {"field1": [null]}}}""" :: """{"map": {"e": null}}""" :: - """{"map": {"f": {"field1": null}}}""" :: Nil) + """{"map": {"f": {"field1": null}}}""" :: Nil))(Encoders.STRING) - def nullsInArrays: RDD[String] = - spark.sparkContext.parallelize( + def nullsInArrays: Dataset[String] = + spark.createDataset(spark.sparkContext.parallelize( """{"field1":[[null], [[["Test"]]]]}""" :: """{"field2":[null, [{"Test":1}]]}""" :: """{"field3":[[null], [{"Test":"2"}]]}""" :: - """{"field4":[[null, [1,2,3]]]}""" :: Nil) + """{"field4":[[null, [1,2,3]]]}""" :: Nil))(Encoders.STRING) - def jsonArray: RDD[String] = 
- spark.sparkContext.parallelize( + def jsonArray: Dataset[String] = + spark.createDataset(spark.sparkContext.parallelize( """[{"a":"str_a_1"}]""" :: """[{"a":"str_a_2"}, {"b":"str_b_3"}]""" :: """{"b":"str_b_4", "a":"str_a_4", "c":"str_c_4"}""" :: - """[]""" :: Nil) + """[]""" :: Nil))(Encoders.STRING) - def corruptRecords: RDD[String] = - spark.sparkContext.parallelize( + def corruptRecords: Dataset[String] = + spark.createDataset(spark.sparkContext.parallelize( """{""" :: """""" :: """{"a":1, b:2}""" :: """{"a":{, b:3}""" :: """{"b":"str_b_4", "a":"str_a_4", "c":"str_c_4"}""" :: - """]""" :: Nil) + """]""" :: Nil))(Encoders.STRING) - def additionalCorruptRecords: RDD[String] = - spark.sparkContext.parallelize( + def additionalCorruptRecords: Dataset[String] = + spark.createDataset(spark.sparkContext.parallelize( """{"dummy":"test"}""" :: """[1,2,3]""" :: """":"test", "a":1}""" :: """42""" :: - """ ","ian":"test"}""" :: Nil) + """ ","ian":"test"}""" :: Nil))(Encoders.STRING) - def emptyRecords: RDD[String] = - spark.sparkContext.parallelize( + def emptyRecords: Dataset[String] = + spark.createDataset(spark.sparkContext.parallelize( """{""" :: """""" :: """{"a": {}}""" :: """{"a": {"b": {}}}""" :: """{"b": [{"c": {}}]}""" :: - """]""" :: Nil) + """]""" :: Nil))(Encoders.STRING) - def timestampAsLong: RDD[String] = - spark.sparkContext.parallelize( - """{"ts":1451732645}""" :: Nil) + def timestampAsLong: Dataset[String] = + spark.createDataset(spark.sparkContext.parallelize( + """{"ts":1451732645}""" :: Nil))(Encoders.STRING) - def arrayAndStructRecords: RDD[String] = - spark.sparkContext.parallelize( + def arrayAndStructRecords: Dataset[String] = + spark.createDataset(spark.sparkContext.parallelize( """{"a": {"b": 1}}""" :: - """{"a": []}""" :: Nil) + """{"a": []}""" :: Nil))(Encoders.STRING) - def floatingValueRecords: RDD[String] = - spark.sparkContext.parallelize( - s"""{"a": 0.${"0" * 38}1, "b": 0.01}""" :: Nil) + def floatingValueRecords: Dataset[String] = + spark.createDataset(spark.sparkContext.parallelize( + s"""{"a": 0.${"0" * 38}1, "b": 0.01}""" :: Nil))(Encoders.STRING) - def bigIntegerRecords: RDD[String] = - spark.sparkContext.parallelize( - s"""{"a": 1${"0" * 38}, "b": 92233720368547758070}""" :: Nil) + def bigIntegerRecords: Dataset[String] = + spark.createDataset(spark.sparkContext.parallelize( + s"""{"a": 1${"0" * 38}, "b": 92233720368547758070}""" :: Nil))(Encoders.STRING) - def datesRecords: RDD[String] = - spark.sparkContext.parallelize( + def datesRecords: Dataset[String] = + spark.createDataset(spark.sparkContext.parallelize( """{"date": "26/08/2015 18:00"}""" :: """{"date": "27/10/2014 18:30"}""" :: - """{"date": "28/01/2016 20:00"}""" :: Nil) + """{"date": "28/01/2016 20:00"}""" :: Nil))(Encoders.STRING) - lazy val singleRow: RDD[String] = spark.sparkContext.parallelize("""{"a":123}""" :: Nil) + lazy val singleRow: Dataset[String] = + spark.createDataset(spark.sparkContext.parallelize("""{"a":123}""" :: Nil))(Encoders.STRING) - def empty: RDD[String] = spark.sparkContext.parallelize(Seq[String]()) + def empty: Dataset[String] = spark.emptyDataset(Encoders.STRING) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala index 420cff878fa0d..88cb8a0bad21e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.datasources.parquet import java.io.File import java.math.BigInteger import java.sql.{Date, Timestamp} +import java.util.{Calendar, TimeZone} import scala.collection.mutable.ArrayBuffer @@ -51,9 +52,12 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha val defaultPartitionName = ExternalCatalogUtils.DEFAULT_PARTITION_NAME + val timeZone = TimeZone.getDefault() + val timeZoneId = timeZone.getID + test("column type inference") { - def check(raw: String, literal: Literal): Unit = { - assert(inferPartitionColumnValue(raw, true) === literal) + def check(raw: String, literal: Literal, timeZone: TimeZone = timeZone): Unit = { + assert(inferPartitionColumnValue(raw, true, timeZone) === literal) } check("10", Literal.create(10, IntegerType)) @@ -66,6 +70,14 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha check("1990-02-24", Literal.create(Date.valueOf("1990-02-24"), DateType)) check("1990-02-24 12:00:30", Literal.create(Timestamp.valueOf("1990-02-24 12:00:30"), TimestampType)) + + val c = Calendar.getInstance(TimeZone.getTimeZone("GMT")) + c.set(1990, 1, 24, 12, 0, 30) + c.set(Calendar.MILLISECOND, 0) + check("1990-02-24 12:00:30", + Literal.create(new Timestamp(c.getTimeInMillis), TimestampType), + TimeZone.getTimeZone("GMT")) + check(defaultPartitionName, Literal.create(null, NullType)) } @@ -77,7 +89,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha "hdfs://host:9000/path/a=10.5/b=hello") var exception = intercept[AssertionError] { - parsePartitions(paths.map(new Path(_)), true, Set.empty[Path]) + parsePartitions(paths.map(new Path(_)), true, Set.empty[Path], timeZoneId) } assert(exception.getMessage().contains("Conflicting directory structures detected")) @@ -90,7 +102,8 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha parsePartitions( paths.map(new Path(_)), true, - Set(new Path("hdfs://host:9000/path/"))) + Set(new Path("hdfs://host:9000/path/")), + timeZoneId) // Valid paths = Seq( @@ -102,7 +115,8 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha parsePartitions( paths.map(new Path(_)), true, - Set(new Path("hdfs://host:9000/path/something=true/table"))) + Set(new Path("hdfs://host:9000/path/something=true/table")), + timeZoneId) // Valid paths = Seq( @@ -114,7 +128,8 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha parsePartitions( paths.map(new Path(_)), true, - Set(new Path("hdfs://host:9000/path/table=true"))) + Set(new Path("hdfs://host:9000/path/table=true")), + timeZoneId) // Invalid paths = Seq( @@ -126,7 +141,8 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha parsePartitions( paths.map(new Path(_)), true, - Set(new Path("hdfs://host:9000/path/"))) + Set(new Path("hdfs://host:9000/path/")), + timeZoneId) } assert(exception.getMessage().contains("Conflicting directory structures detected")) @@ -145,20 +161,21 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha parsePartitions( paths.map(new Path(_)), true, - Set(new Path("hdfs://host:9000/tmp/tables/"))) + Set(new Path("hdfs://host:9000/tmp/tables/")), + timeZoneId) } assert(exception.getMessage().contains("Conflicting directory structures detected")) } test("parse partition") { def check(path: 
String, expected: Option[PartitionValues]): Unit = { - val actual = parsePartition(new Path(path), true, Set.empty[Path])._1 + val actual = parsePartition(new Path(path), true, Set.empty[Path], timeZone)._1 assert(expected === actual) } def checkThrows[T <: Throwable: Manifest](path: String, expected: String): Unit = { val message = intercept[T] { - parsePartition(new Path(path), true, Set.empty[Path]) + parsePartition(new Path(path), true, Set.empty[Path], timeZone) }.getMessage assert(message.contains(expected)) @@ -201,7 +218,8 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha val partitionSpec1: Option[PartitionValues] = parsePartition( path = new Path("file://path/a=10"), typeInference = true, - basePaths = Set(new Path("file://path/a=10")))._1 + basePaths = Set(new Path("file://path/a=10")), + timeZone = timeZone)._1 assert(partitionSpec1.isEmpty) @@ -209,7 +227,8 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha val partitionSpec2: Option[PartitionValues] = parsePartition( path = new Path("file://path/a=10"), typeInference = true, - basePaths = Set(new Path("file://path")))._1 + basePaths = Set(new Path("file://path")), + timeZone = timeZone)._1 assert(partitionSpec2 == Option(PartitionValues( @@ -226,7 +245,8 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha parsePartitions( paths.map(new Path(_)), true, - rootPaths) + rootPaths, + timeZoneId) assert(actualSpec === spec) } @@ -307,7 +327,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha test("parse partitions with type inference disabled") { def check(paths: Seq[String], spec: PartitionSpec): Unit = { val actualSpec = - parsePartitions(paths.map(new Path(_)), false, Set.empty[Path]) + parsePartitions(paths.map(new Path(_)), false, Set.empty[Path], timeZoneId) assert(actualSpec === spec) } @@ -686,6 +706,13 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha val fields = schema.map(f => Column(f.name).cast(f.dataType)) checkAnswer(spark.read.load(dir.toString).select(fields: _*), row) } + + withTempPath { dir => + df.write.option("timeZone", "GMT") + .format("parquet").partitionBy(partitionColumns.map(_.name): _*).save(dir.toString) + val fields = schema.map(f => Column(f.name).cast(f.dataType)) + checkAnswer(spark.read.option("timeZone", "GMT").load(dir.toString).select(fields: _*), row) + } } test("Various inferred partition value types") { @@ -720,6 +747,13 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha val fields = schema.map(f => Column(f.name)) checkAnswer(spark.read.load(dir.toString).select(fields: _*), row) } + + withTempPath { dir => + df.write.option("timeZone", "GMT") + .format("parquet").partitionBy(partitionColumns.map(_.name): _*).save(dir.toString) + val fields = schema.map(f => Column(f.name)) + checkAnswer(spark.read.option("timeZone", "GMT").load(dir.toString).select(fields: _*), row) + } } test("SPARK-8037: Ignores files whose name starts with dot") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala index d7d7176c48a3a..200e356c72fd7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala @@ 
-77,8 +77,6 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext val df = spark.read.parquet(path).cache() assert(df.count() == 1000) spark.range(10).write.mode("overwrite").parquet(path) - assert(df.count() == 1000) - spark.catalog.refreshByPath(path) assert(df.count() == 10) assert(spark.read.parquet(path).count() == 10) } @@ -91,8 +89,6 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext val df = spark.read.parquet(path).cache() assert(df.count() == 1000) spark.range(10).write.mode("append").parquet(path) - assert(df.count() == 1000) - spark.catalog.refreshByPath(path) assert(df.count() == 1010) assert(spark.read.parquet(path).count() == 1010) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala index 6b38b6a097213..e848f74e3159f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala @@ -20,9 +20,11 @@ package org.apache.spark.sql.execution.streaming.state import java.io.{File, IOException} import java.net.URI +import scala.collection.JavaConverters._ import scala.collection.mutable import scala.util.Random +import org.apache.commons.io.FileUtils import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, Path, RawLocalFileSystem} import org.scalatest.{BeforeAndAfter, PrivateMethodTester} @@ -210,13 +212,6 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth assert(store1.commit() === 2) assert(rowsToSet(store1.iterator()) === Set("a" -> 1, "b" -> 1)) assert(getDataFromFiles(provider) === Set("a" -> 1, "b" -> 1)) - - // Overwrite the version with other data - val store2 = provider.getStore(1) - put(store2, "c", 1) - assert(store2.commit() === 2) - assert(rowsToSet(store2.iterator()) === Set("a" -> 1, "c" -> 1)) - assert(getDataFromFiles(provider) === Set("a" -> 1, "c" -> 1)) } test("snapshotting") { @@ -292,6 +287,20 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth assert(getDataFromFiles(provider, 19) === Set("a" -> 19)) } + test("SPARK-19677: Committing a delta file atop an existing one should not fail on HDFS") { + val conf = new Configuration() + conf.set("fs.fake.impl", classOf[RenameLikeHDFSFileSystem].getName) + conf.set("fs.default.name", "fake:///") + + val provider = newStoreProvider(hadoopConf = conf) + provider.getStore(0).commit() + provider.getStore(0).commit() + + // Verify we don't leak temp files + val tempFiles = FileUtils.listFiles(new File(provider.id.checkpointLocation), + null, true).asScala.filter(_.getName.startsWith("temp-")) + assert(tempFiles.isEmpty) + } test("corrupted file handling") { val provider = newStoreProvider(minDeltasForSnapshot = 5) @@ -681,6 +690,21 @@ private[state] object StateStoreSuite { } } +/** + * Fake FileSystem that simulates HDFS rename semantic, i.e. renaming a file atop an existing + * one should return false. 
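// Sketch, not part of the patch: the rename semantics RenameLikeHDFSFileSystem
// emulates and what a commit routine has to tolerate because of them. On HDFS,
// FileSystem.rename(src, dst) returns false instead of overwriting when dst already
// exists. The names commitDelta, tempDeltaFile and finalDeltaFile are hypothetical.
import java.io.IOException
import org.apache.hadoop.fs.{FileSystem, Path}

def commitDelta(fs: FileSystem, tempDeltaFile: Path, finalDeltaFile: Path): Unit = {
  if (!fs.rename(tempDeltaFile, finalDeltaFile)) {
    if (fs.exists(finalDeltaFile)) {
      // The destination was already written (for example by a retried task), so the
      // commit can be treated as done; remove the temp file so it does not leak,
      // which is what the new test checks for.
      fs.delete(tempDeltaFile, false)
    } else {
      throw new IOException(s"Failed to rename $tempDeltaFile to $finalDeltaFile")
    }
  }
}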
+ * See hadoop.apache.org/docs/stable/hadoop-project-dist/hadoop-common/filesystem/filesystem.html + */ +class RenameLikeHDFSFileSystem extends RawLocalFileSystem { + override def rename(src: Path, dst: Path): Boolean = { + if (exists(dst)) { + return false + } else { + return super.rename(src, dst) + } + } +} + /** * Fake FileSystem to test that the StateStore throws an exception while committing the * delta file, when `fs.rename` returns `false`. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala index 4a42f8ea79cf3..916a01ee0ca8e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala @@ -33,14 +33,15 @@ class CreateTableAsSelectSuite extends DataSourceTest with SharedSQLContext with BeforeAndAfterEach { + import testImplicits._ protected override lazy val sql = spark.sql _ private var path: File = null override def beforeAll(): Unit = { super.beforeAll() - val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}""")) - spark.read.json(rdd).createOrReplaceTempView("jt") + val ds = (1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}""").toDS() + spark.read.json(ds).createOrReplaceTempView("jt") } override def afterAll(): Unit = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala index 4fc2f81f540bc..2eae66dda88de 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala @@ -24,14 +24,16 @@ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.util.Utils class InsertSuite extends DataSourceTest with SharedSQLContext { + import testImplicits._ + protected override lazy val sql = spark.sql _ private var path: File = null override def beforeAll(): Unit = { super.beforeAll() path = Utils.createTempDir() - val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str$i"}""")) - spark.read.json(rdd).createOrReplaceTempView("jt") + val ds = (1 to 10).map(i => s"""{"a":$i, "b":"str$i"}""").toDS() + spark.read.json(ds).createOrReplaceTempView("jt") sql( s""" |CREATE TEMPORARY VIEW jsonTable (a int, b string) @@ -129,7 +131,7 @@ class InsertSuite extends DataSourceTest with SharedSQLContext { // Writing the table to less part files. val rdd1 = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str$i"}"""), 5) - spark.read.json(rdd1).createOrReplaceTempView("jt1") + spark.read.json(rdd1.toDS()).createOrReplaceTempView("jt1") sql( s""" |INSERT OVERWRITE TABLE jsonTable SELECT a, b FROM jt1 @@ -141,7 +143,7 @@ class InsertSuite extends DataSourceTest with SharedSQLContext { // Writing the table to more part files. val rdd2 = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str$i"}"""), 10) - spark.read.json(rdd2).createOrReplaceTempView("jt2") + spark.read.json(rdd2.toDS()).createOrReplaceTempView("jt2") sql( s""" |INSERT OVERWRITE TABLE jsonTable SELECT a, b FROM jt2 @@ -279,15 +281,15 @@ class InsertSuite extends DataSourceTest with SharedSQLContext { """.stripMargin) // jsonTable should be recached.
assertCached(sql("SELECT * FROM jsonTable")) - // TODO we need to invalidate the cached data in InsertIntoHadoopFsRelation -// // The cached data is the new data. -// checkAnswer( -// sql("SELECT a, b FROM jsonTable"), -// sql("SELECT a * 2, b FROM jt").collect()) -// -// // Verify uncaching -// spark.catalog.uncacheTable("jsonTable") -// assertCached(sql("SELECT * FROM jsonTable"), 0) + + // The cached data is the new data. + checkAnswer( + sql("SELECT a, b FROM jsonTable"), + sql("SELECT a * 2, b FROM jt").collect()) + + // Verify uncaching + spark.catalog.uncacheTable("jsonTable") + assertCached(sql("SELECT * FROM jsonTable"), 0) } test("it's not allowed to insert into a relation that is not an InsertableRelation") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala index bf7fabe33266b..f251290583c5e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala @@ -18,11 +18,13 @@ package org.apache.spark.sql.sources import java.io.File +import java.sql.Timestamp import org.apache.hadoop.mapreduce.TaskAttemptContext import org.apache.spark.internal.Logging import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils import org.apache.spark.sql.execution.datasources.SQLHadoopMapReduceCommitProtocol import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf @@ -124,6 +126,39 @@ class PartitionedWriteSuite extends QueryTest with SharedSQLContext { } } + test("timeZone setting in dynamic partition writes") { + def checkPartitionValues(file: File, expected: String): Unit = { + val dir = file.getParentFile() + val value = ExternalCatalogUtils.unescapePathName( + dir.getName.substring(dir.getName.indexOf("=") + 1)) + assert(value == expected) + } + val ts = Timestamp.valueOf("2016-12-01 00:00:00") + val df = Seq((1, ts)).toDF("i", "ts") + withTempPath { f => + df.write.partitionBy("ts").parquet(f.getAbsolutePath) + val files = recursiveList(f).filter(_.getAbsolutePath.endsWith("parquet")) + assert(files.length == 1) + checkPartitionValues(files.head, "2016-12-01 00:00:00") + } + withTempPath { f => + df.write.option("timeZone", "GMT").partitionBy("ts").parquet(f.getAbsolutePath) + val files = recursiveList(f).filter(_.getAbsolutePath.endsWith("parquet")) + assert(files.length == 1) + // use timeZone option "GMT" to format partition value. + checkPartitionValues(files.head, "2016-12-01 08:00:00") + } + withTempPath { f => + withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> "GMT") { + df.write.partitionBy("ts").parquet(f.getAbsolutePath) + val files = recursiveList(f).filter(_.getAbsolutePath.endsWith("parquet")) + assert(files.length == 1) + // if there isn't timeZone option, then use session local timezone. + checkPartitionValues(files.head, "2016-12-01 08:00:00") + } + } + } + /** Lists files recursively. 
*/ private def recursiveList(f: File): Array[File] = { require(f.isDirectory) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala index b1756c27fae0a..773d34dfaf9a8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala @@ -28,6 +28,8 @@ import org.apache.spark.sql.types._ import org.apache.spark.util.Utils class SaveLoadSuite extends DataSourceTest with SharedSQLContext with BeforeAndAfter { + import testImplicits._ + protected override lazy val sql = spark.sql _ private var originalDefaultSource: String = null private var path: File = null @@ -40,8 +42,8 @@ class SaveLoadSuite extends DataSourceTest with SharedSQLContext with BeforeAndA path = Utils.createTempDir() path.delete() - val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}""")) - df = spark.read.json(rdd) + val ds = (1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}""").toDS() + df = spark.read.json(ds) df.createOrReplaceTempView("jsonTable") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala index 5110d89c85b15..1586850c77fca 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala @@ -52,10 +52,7 @@ abstract class FileStreamSourceTest query.nonEmpty, "Cannot add data when there is no query for finding the active file stream source") - val sources = query.get.logicalPlan.collect { - case StreamingExecutionRelation(source, _) if source.isInstanceOf[FileStreamSource] => - source.asInstanceOf[FileStreamSource] - } + val sources = getSourcesFromStreamingQuery(query.get) if (sources.isEmpty) { throw new Exception( "Could not find file source in the StreamExecution logical plan to add data to") @@ -134,6 +131,14 @@ abstract class FileStreamSourceTest }.head } + protected def getSourcesFromStreamingQuery(query: StreamExecution): Seq[FileStreamSource] = { + query.logicalPlan.collect { + case StreamingExecutionRelation(source, _) if source.isInstanceOf[FileStreamSource] => + source.asInstanceOf[FileStreamSource] + } + } + + protected def withTempDirs(body: (File, File) => Unit) { val src = Utils.createTempDir(namePrefix = "streaming.src") val tmp = Utils.createTempDir(namePrefix = "streaming.tmp") @@ -388,9 +393,7 @@ class FileStreamSourceSuite extends FileStreamSourceTest { CheckAnswer("a", "b", "c", "d"), AssertOnQuery("seen files should contain only one entry") { streamExecution => - val source = streamExecution.logicalPlan.collect { case e: StreamingExecutionRelation => - e.source.asInstanceOf[FileStreamSource] - }.head + val source = getSourcesFromStreamingQuery(streamExecution).head assert(source.seenFiles.size == 1) true } @@ -662,6 +665,101 @@ class FileStreamSourceSuite extends FileStreamSourceTest { } } + test("read data from outputs of another streaming query") { + withSQLConf(SQLConf.FILE_SINK_LOG_COMPACT_INTERVAL.key -> "3") { + withTempDirs { case (outputDir, checkpointDir) => + // q1 is a streaming query that reads from memory and writes to text files + val q1Source = MemoryStream[String] + val q1 = + q1Source + .toDF() + .writeStream + .option("checkpointLocation", checkpointDir.getCanonicalPath) + .format("text") + 
.start(outputDir.getCanonicalPath) + + // q2 is a streaming query that reads q1's text outputs + val q2 = + createFileStream("text", outputDir.getCanonicalPath).filter($"value" contains "keep") + + def q1AddData(data: String*): StreamAction = + Execute { _ => + q1Source.addData(data) + q1.processAllAvailable() + } + def q2ProcessAllAvailable(): StreamAction = Execute { q2 => q2.processAllAvailable() } + + testStream(q2)( + // batch 0 + q1AddData("drop1", "keep2"), + q2ProcessAllAvailable(), + CheckAnswer("keep2"), + + // batch 1 + Assert { + // create a text file that won't be on q1's sink log + // thus even if its content contains "keep", it should NOT appear in q2's answer + val shouldNotKeep = new File(outputDir, "should_not_keep.txt") + stringToFile(shouldNotKeep, "should_not_keep!!!") + shouldNotKeep.exists() + }, + q1AddData("keep3"), + q2ProcessAllAvailable(), + CheckAnswer("keep2", "keep3"), + + // batch 2: check that things work well when the sink log gets compacted + q1AddData("keep4"), + Assert { + // compact interval is 3, so file "2.compact" should exist + new File(outputDir, s"${FileStreamSink.metadataDir}/2.compact").exists() + }, + q2ProcessAllAvailable(), + CheckAnswer("keep2", "keep3", "keep4"), + + Execute { _ => q1.stop() } + ) + } + } + } + + test("start before another streaming query, and read its output") { + withTempDirs { case (outputDir, checkpointDir) => + // q1 is a streaming query that reads from memory and writes to text files + val q1Source = MemoryStream[String] + // define q1, but don't start it for now + val q1Write = + q1Source + .toDF() + .writeStream + .option("checkpointLocation", checkpointDir.getCanonicalPath) + .format("text") + var q1: StreamingQuery = null + + val q2 = createFileStream("text", outputDir.getCanonicalPath).filter($"value" contains "keep") + + testStream(q2)( + AssertOnQuery { q2 => + val fileSource = getSourcesFromStreamingQuery(q2).head + // q1 has not started yet, verify that q2 doesn't know whether q1 has metadata + fileSource.sourceHasMetadata === None + }, + Execute { _ => + q1 = q1Write.start(outputDir.getCanonicalPath) + q1Source.addData("drop1", "keep2") + q1.processAllAvailable() + }, + AssertOnQuery { q2 => + q2.processAllAvailable() + val fileSource = getSourcesFromStreamingQuery(q2).head + // q1 has started, verify that q2 knows q1 has metadata by now + fileSource.sourceHasMetadata === Some(true) + }, + CheckAnswer("keep2"), + Execute { _ => q1.stop() } + ) + } + } + test("when schema inference is turned on, should read partition data") { def createFile(content: String, src: File, tmp: File): Unit = { val tempFile = Utils.tempFileWith(new File(tmp, "text")) @@ -755,10 +853,7 @@ class FileStreamSourceSuite extends FileStreamSourceTest { .streamingQuery q.processAllAvailable() val memorySink = q.sink.asInstanceOf[MemorySink] - val fileSource = q.logicalPlan.collect { - case StreamingExecutionRelation(source, _) if source.isInstanceOf[FileStreamSource] => - source.asInstanceOf[FileStreamSource] - }.head + val fileSource = getSourcesFromStreamingQuery(q).head /** Check the data read in the last batch */ def checkLastBatchData(data: Int*): Unit = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala index f44cfada29e2b..6dfcd8baba20e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala @@ -17,6 +17,9 @@ 
package org.apache.spark.sql.streaming +import java.io.{InterruptedIOException, IOException} +import java.util.concurrent.{CountDownLatch, TimeoutException, TimeUnit} + import scala.reflect.ClassTag import scala.util.control.ControlThrowable @@ -350,13 +353,45 @@ class StreamSuite extends StreamTest { } } } -} -/** - * A fake StreamSourceProvider thats creates a fake Source that cannot be reused. - */ -class FakeDefaultSource extends StreamSourceProvider { + test("handle IOException when the streaming thread is interrupted (pre Hadoop 2.8)") { + // This test uses a fake source to throw the same IOException as pre Hadoop 2.8 when the + // streaming thread is interrupted. We should handle it properly by not failing the query. + ThrowingIOExceptionLikeHadoop12074.createSourceLatch = new CountDownLatch(1) + val query = spark + .readStream + .format(classOf[ThrowingIOExceptionLikeHadoop12074].getName) + .load() + .writeStream + .format("console") + .start() + assert(ThrowingIOExceptionLikeHadoop12074.createSourceLatch + .await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS), + "ThrowingIOExceptionLikeHadoop12074.createSource wasn't called before timeout") + query.stop() + assert(query.exception.isEmpty) + } + test("handle InterruptedIOException when the streaming thread is interrupted (Hadoop 2.8+)") { + // This test uses a fake source to throw the same InterruptedIOException as Hadoop 2.8+ when the + // streaming thread is interrupted. We should handle it properly by not failing the query. + ThrowingInterruptedIOException.createSourceLatch = new CountDownLatch(1) + val query = spark + .readStream + .format(classOf[ThrowingInterruptedIOException].getName) + .load() + .writeStream + .format("console") + .start() + assert(ThrowingInterruptedIOException.createSourceLatch + .await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS), + "ThrowingInterruptedIOException.createSource wasn't called before timeout") + query.stop() + assert(query.exception.isEmpty) + } +} + +abstract class FakeSource extends StreamSourceProvider { private val fakeSchema = StructType(StructField("a", IntegerType) :: Nil) override def sourceSchema( @@ -364,6 +399,10 @@ class FakeDefaultSource extends StreamSourceProvider { schema: Option[StructType], providerName: String, parameters: Map[String, String]): (String, StructType) = ("fakeSource", fakeSchema) +} + +/** A fake StreamSourceProvider that creates a fake Source that cannot be reused. */ +class FakeDefaultSource extends FakeSource { override def createSource( spark: SQLContext, @@ -395,3 +434,63 @@ class FakeDefaultSource extends StreamSourceProvider { } } } + +/** A fake source that throws the same IOException like pre Hadoop 2.8 when it's interrupted. */ +class ThrowingIOExceptionLikeHadoop12074 extends FakeSource { + import ThrowingIOExceptionLikeHadoop12074._ + + override def createSource( + spark: SQLContext, + metadataPath: String, + schema: Option[StructType], + providerName: String, + parameters: Map[String, String]): Source = { + createSourceLatch.countDown() + try { + Thread.sleep(30000) + throw new TimeoutException("sleep was not interrupted in 30 seconds") + } catch { + case ie: InterruptedException => + throw new IOException(ie.toString) + } + } +} + +object ThrowingIOExceptionLikeHadoop12074 { + /** + * A latch to allow the user to wait until [[ThrowingIOExceptionLikeHadoop12074.createSource]] is + * called. 
+ */ + @volatile var createSourceLatch: CountDownLatch = null +} + +/** A fake source that throws InterruptedIOException like Hadoop 2.8+ when it's interrupted. */ +class ThrowingInterruptedIOException extends FakeSource { + import ThrowingInterruptedIOException._ + + override def createSource( + spark: SQLContext, + metadataPath: String, + schema: Option[StructType], + providerName: String, + parameters: Map[String, String]): Source = { + createSourceLatch.countDown() + try { + Thread.sleep(30000) + throw new TimeoutException("sleep was not interrupted in 30 seconds") + } catch { + case ie: InterruptedException => + val iie = new InterruptedIOException(ie.toString) + iie.initCause(ie) + throw iie + } + } +} + +object ThrowingInterruptedIOException { + /** + * A latch to allow the user to wait until [[ThrowingInterruptedIOException.createSource]] is + * called. + */ + @volatile var createSourceLatch: CountDownLatch = null +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index af2f31a34d8da..60e2375a9817d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -208,6 +208,12 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts { } } + /** Execute arbitrary code */ + object Execute { + def apply(func: StreamExecution => Any): AssertOnQuery = + AssertOnQuery(query => { func(query); true }) + } + class StreamManualClock(time: Long = 0L) extends ManualClock(time) with Serializable { private var waitStartTime: Option[Long] = None @@ -472,7 +478,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts { case a: AssertOnQuery => verify(currentStream != null || lastStream != null, - "cannot assert when not stream has been started") + "cannot assert when no stream has been started") val streamToAssert = Option(currentStream).getOrElse(lastStream) verify(a.condition(streamToAssert), s"Assert on query failed: ${a.message}") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala index 4596aa1d348e3..eb09b9ffcfc5d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala @@ -133,6 +133,31 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter { } } + test("SPARK-19594: all of listeners should receive QueryTerminatedEvent") { + val df = MemoryStream[Int].toDS().as[Long] + val listeners = (1 to 5).map(_ => new EventCollector) + try { + listeners.foreach(listener => spark.streams.addListener(listener)) + testStream(df, OutputMode.Append)( + StartStream(), + StopStream, + AssertOnQuery { query => + eventually(Timeout(streamingTimeout)) { + listeners.foreach(listener => assert(listener.terminationEvent !== null)) + listeners.foreach(listener => assert(listener.terminationEvent.id === query.id)) + listeners.foreach(listener => assert(listener.terminationEvent.runId === query.runId)) + listeners.foreach(listener => assert(listener.terminationEvent.exception === None)) + } + listeners.foreach(listener => listener.checkAsyncErrors()) + listeners.foreach(listener => listener.reset()) + true + } + ) + } finally { + listeners.foreach(spark.streams.removeListener) + 
} + } + test("adding and removing listener") { def isListenerActive(listener: EventCollector): Boolean = { listener.reset() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala index 1525ad5fd5178..a0a2b2b4c9b3b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala @@ -20,10 +20,12 @@ package org.apache.spark.sql.streaming import java.util.concurrent.CountDownLatch import org.apache.commons.lang3.RandomStringUtils +import org.mockito.Mockito._ import org.scalactic.TolerantNumerics import org.scalatest.concurrent.Eventually._ import org.scalatest.BeforeAndAfter import org.scalatest.concurrent.PatienceConfiguration.Timeout +import org.scalatest.mock.MockitoSugar import org.apache.spark.internal.Logging import org.apache.spark.sql.{DataFrame, Dataset} @@ -32,11 +34,11 @@ import org.apache.spark.SparkException import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.streaming.util.BlockingSource +import org.apache.spark.sql.streaming.util.{BlockingSource, MockSourceProvider} import org.apache.spark.util.ManualClock -class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging { +class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging with MockitoSugar { import AwaitTerminationTester._ import testImplicits._ @@ -481,6 +483,75 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging { } } + test("StreamExecution should call stop() on sources when a stream is stopped") { + var calledStop = false + val source = new Source { + override def stop(): Unit = { + calledStop = true + } + override def getOffset: Option[Offset] = None + override def getBatch(start: Option[Offset], end: Offset): DataFrame = { + spark.emptyDataFrame + } + override def schema: StructType = MockSourceProvider.fakeSchema + } + + MockSourceProvider.withMockSources(source) { + val df = spark.readStream + .format("org.apache.spark.sql.streaming.util.MockSourceProvider") + .load() + + testStream(df)(StopStream) + + assert(calledStop, "Did not call stop on source for stopped stream") + } + } + + testQuietly("SPARK-19774: StreamExecution should call stop() on sources when a stream fails") { + var calledStop = false + val source1 = new Source { + override def stop(): Unit = { + throw new RuntimeException("Oh no!") + } + override def getOffset: Option[Offset] = Some(LongOffset(1)) + override def getBatch(start: Option[Offset], end: Offset): DataFrame = { + spark.range(2).toDF(MockSourceProvider.fakeSchema.fieldNames: _*) + } + override def schema: StructType = MockSourceProvider.fakeSchema + } + val source2 = new Source { + override def stop(): Unit = { + calledStop = true + } + override def getOffset: Option[Offset] = None + override def getBatch(start: Option[Offset], end: Offset): DataFrame = { + spark.emptyDataFrame + } + override def schema: StructType = MockSourceProvider.fakeSchema + } + + MockSourceProvider.withMockSources(source1, source2) { + val df1 = spark.readStream + .format("org.apache.spark.sql.streaming.util.MockSourceProvider") + .load() + .as[Int] + + val df2 = spark.readStream + .format("org.apache.spark.sql.streaming.util.MockSourceProvider") + .load() + .as[Int] + + testStream(df1.union(df2).map(i => 
i / 0))( + AssertOnQuery { sq => + intercept[StreamingQueryException](sq.processAllAvailable()) + sq.exception.isDefined && !sq.isActive + } + ) + + assert(calledStop, "Did not call stop on source for stopped stream") + } + } + /** Create a streaming DF that only execute one batch in which it returns the given static DF */ private def createSingleTriggerStreamingDF(triggerDF: DataFrame): DataFrame = { require(!triggerDF.isStreaming) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/util/MockSourceProvider.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/util/MockSourceProvider.scala new file mode 100644 index 0000000000000..0bf05381a7f36 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/util/MockSourceProvider.scala @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming.util + +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.execution.streaming.Source +import org.apache.spark.sql.sources.StreamSourceProvider +import org.apache.spark.sql.types.{IntegerType, StructField, StructType} + +/** + * A StreamSourceProvider that provides mocked Sources for unit testing. Example usage: + * + * {{{ + * MockSourceProvider.withMockSources(source1, source2) { + * val df1 = spark.readStream + * .format("org.apache.spark.sql.streaming.util.MockSourceProvider") + * .load() + * + * val df2 = spark.readStream + * .format("org.apache.spark.sql.streaming.util.MockSourceProvider") + * .load() + * + * df1.union(df2) + * ... + * } + * }}} + */ +class MockSourceProvider extends StreamSourceProvider { + override def sourceSchema( + spark: SQLContext, + schema: Option[StructType], + providerName: String, + parameters: Map[String, String]): (String, StructType) = { + ("dummySource", MockSourceProvider.fakeSchema) + } + + override def createSource( + spark: SQLContext, + metadataPath: String, + schema: Option[StructType], + providerName: String, + parameters: Map[String, String]): Source = { + MockSourceProvider.sourceProviderFunction() + } +} + +object MockSourceProvider { + // Function to generate sources. May provide multiple sources if the user implements such a + // function. 
+ private var sourceProviderFunction: () => Source = _ + + final val fakeSchema = StructType(StructField("a", IntegerType) :: Nil) + + def withMockSources(source: Source, otherSources: Source*)(f: => Unit): Unit = { + var i = 0 + val sources = source +: otherSources + sourceProviderFunction = () => { + val source = sources(i % sources.length) + i += 1 + source + } + try { + f + } finally { + sourceProviderFunction = null + } + } +} diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala index 78a309497ab57..c0b299411e94a 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala @@ -40,6 +40,7 @@ private[hive] object SparkSQLEnv extends Logging { val maybeAppName = sparkConf .getOption("spark.app.name") .filterNot(_ == classOf[SparkSQLCLIDriver].getName) + .filterNot(_ == classOf[HiveThriftServer2].getName) sparkConf .setAppName(maybeAppName.getOrElse(s"SparkSQL::${Utils.localHostName()}")) diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala index 7ba5790c2979d..c7d953a731b9b 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala @@ -95,9 +95,9 @@ class HiveWindowFunctionQuerySuite extends HiveComparisonTest with BeforeAndAfte // This is used to generate golden files. sql("set hive.plan.serialization.format=kryo") // Explicitly set fs to local fs. - sql(s"set fs.default.name=file://$testTempDir/") + sql(s"set fs.defaultFS=file://$testTempDir/") // Ask Hive to run jobs in-process as a single map and reduce task. - sql("set mapred.job.tracker=local") + sql("set mapreduce.jobtracker.address=local") } override def afterAll() { @@ -764,9 +764,9 @@ class HiveWindowFunctionQueryFileSuite // This is used to generate golden files. // sql("set hive.plan.serialization.format=kryo") // Explicitly set fs to local fs. - // sql(s"set fs.default.name=file://$testTempDir/") + // sql(s"set fs.defaultFS=file://$testTempDir/") // Ask Hive to run jobs in-process as a single map and reduce task. 
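// Sketch, not part of the patch: the same settings expressed against a plain Hadoop
// Configuration, using the current key names that replace the deprecated
// fs.default.name and mapred.job.tracker keys being swapped out above.
import org.apache.hadoop.conf.Configuration

val hadoopConf = new Configuration()
hadoopConf.set("fs.defaultFS", "file:///some/local/test/dir/") // replaces fs.default.name
hadoopConf.set("mapreduce.jobtracker.address", "local") // replaces mapred.job.tracker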
- // sql("set mapred.job.tracker=local") + // sql("set mapreduce.jobtracker.address=local") } override def afterAll() { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index 50bb44f7d4e6e..43d9c2bec6823 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -38,7 +38,7 @@ import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils.escapePathName import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.ColumnStat -import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap +import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.PartitioningUtils import org.apache.spark.sql.hive.client.HiveClient @@ -1008,7 +1008,8 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat override def listPartitionsByFilter( db: String, table: String, - predicates: Seq[Expression]): Seq[CatalogTablePartition] = withClient { + predicates: Seq[Expression], + defaultTimeZoneId: String): Seq[CatalogTablePartition] = withClient { val rawTable = getRawTable(db, table) val catalogTable = restoreTableMetadata(rawTable) val partitionColumnNames = catalogTable.partitionColumnNames.toSet @@ -1034,7 +1035,9 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat val index = partitionSchema.indexWhere(_.name == att.name) BoundReference(index, partitionSchema(index).dataType, nullable = true) }) - clientPrunedPartitions.filter { p => boundPredicate(p.toRow(partitionSchema)) } + clientPrunedPartitions.filter { p => + boundPredicate(p.toRow(partitionSchema, defaultTimeZoneId)) + } } else { client.getPartitions(catalogTable).map { part => part.copy(spec = restorePartitionSpec(part.spec, partColNameMap)) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 677da0dbdc654..151a69aebf1d8 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.hive +import java.net.URI + import com.google.common.util.concurrent.Striped import org.apache.hadoop.fs.Path @@ -26,6 +28,7 @@ import org.apache.spark.sql.catalyst.{QualifiedTableName, TableIdentifier} import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ +import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat, ParquetOptions} import org.apache.spark.sql.hive.orc.OrcFileFormat @@ -71,10 +74,8 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log private def getCached( tableIdentifier: QualifiedTableName, pathsInMetastore: Seq[Path], - metastoreRelation: MetastoreRelation, schemaInMetastore: StructType, expectedFileFormat: Class[_ <: FileFormat], - expectedBucketSpec: Option[BucketSpec], partitionSchema: Option[StructType]): Option[LogicalRelation] = { 
tableRelationCache.getIfPresent(tableIdentifier) match { @@ -89,7 +90,9 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log val useCached = relation.location.rootPaths.toSet == pathsInMetastore.toSet && logical.schema.sameType(schemaInMetastore) && - relation.bucketSpec == expectedBucketSpec && + // We don't support hive bucketed tables. This function `getCached` is only used for + // converting supported Hive tables to data source tables. + relation.bucketSpec.isEmpty && relation.partitionSchema == partitionSchema.getOrElse(StructType(Nil)) if (useCached) { @@ -100,52 +103,48 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log None } case _ => - logWarning( - s"${metastoreRelation.databaseName}.${metastoreRelation.tableName} " + - s"should be stored as $expectedFileFormat. However, we are getting " + - s"a ${relation.fileFormat} from the metastore cache. This cached " + - s"entry will be invalidated.") + logWarning(s"Table $tableIdentifier should be stored as $expectedFileFormat. " + + s"However, we are getting a ${relation.fileFormat} from the metastore cache. " + + "This cached entry will be invalidated.") tableRelationCache.invalidate(tableIdentifier) None } case other => - logWarning( - s"${metastoreRelation.databaseName}.${metastoreRelation.tableName} should be stored " + - s"as $expectedFileFormat. However, we are getting a $other from the metastore cache. " + - s"This cached entry will be invalidated.") + logWarning(s"Table $tableIdentifier should be stored as $expectedFileFormat. " + + s"However, we are getting a $other from the metastore cache. " + + "This cached entry will be invalidated.") tableRelationCache.invalidate(tableIdentifier) None } } private def convertToLogicalRelation( - metastoreRelation: MetastoreRelation, + relation: CatalogRelation, options: Map[String, String], - defaultSource: FileFormat, fileFormatClass: Class[_ <: FileFormat], fileType: String): LogicalRelation = { - val metastoreSchema = StructType.fromAttributes(metastoreRelation.output) + val metastoreSchema = relation.tableMeta.schema val tableIdentifier = - QualifiedTableName(metastoreRelation.databaseName, metastoreRelation.tableName) - val bucketSpec = None // We don't support hive bucketed tables, only ones we write out. + QualifiedTableName(relation.tableMeta.database, relation.tableMeta.identifier.table) val lazyPruningEnabled = sparkSession.sqlContext.conf.manageFilesourcePartitions - val result = if (metastoreRelation.hiveQlTable.isPartitioned) { - val partitionSchema = StructType.fromAttributes(metastoreRelation.partitionKeys) - + val tablePath = new Path(new URI(relation.tableMeta.location)) + val result = if (relation.isPartitioned) { + val partitionSchema = relation.tableMeta.partitionSchema val rootPaths: Seq[Path] = if (lazyPruningEnabled) { - Seq(metastoreRelation.hiveQlTable.getDataLocation) + Seq(tablePath) } else { // By convention (for example, see CatalogFileIndex), the definition of a // partitioned table's paths depends on whether that table has any actual partitions. // Partitioned tables without partitions use the location of the table's base path. // Partitioned tables with partitions use the locations of those partitions' data // locations,_omitting_ the table's base path. 
- val paths = metastoreRelation.getHiveQlPartitions().map { p => - new Path(p.getLocation) - } + val paths = sparkSession.sharedState.externalCatalog + .listPartitions(tableIdentifier.database, tableIdentifier.name) + .map(p => new Path(new URI(p.storage.locationUri.get))) + if (paths.isEmpty) { - Seq(metastoreRelation.hiveQlTable.getDataLocation) + Seq(tablePath) } else { paths } @@ -155,39 +154,31 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log val cached = getCached( tableIdentifier, rootPaths, - metastoreRelation, metastoreSchema, fileFormatClass, - bucketSpec, Some(partitionSchema)) val logicalRelation = cached.getOrElse { - val sizeInBytes = - metastoreRelation.stats(sparkSession.sessionState.conf).sizeInBytes.toLong + val sizeInBytes = relation.stats(sparkSession.sessionState.conf).sizeInBytes.toLong val fileIndex = { - val index = new CatalogFileIndex( - sparkSession, metastoreRelation.catalogTable, sizeInBytes) + val index = new CatalogFileIndex(sparkSession, relation.tableMeta, sizeInBytes) if (lazyPruningEnabled) { index } else { index.filterPartitions(Nil) // materialize all the partitions in memory } } - val partitionSchemaColumnNames = partitionSchema.map(_.name.toLowerCase).toSet - val dataSchema = - StructType(metastoreSchema - .filterNot(field => partitionSchemaColumnNames.contains(field.name.toLowerCase))) - val relation = HadoopFsRelation( + val fsRelation = HadoopFsRelation( location = fileIndex, partitionSchema = partitionSchema, - dataSchema = dataSchema, - bucketSpec = bucketSpec, - fileFormat = defaultSource, + dataSchema = relation.tableMeta.dataSchema, + // We don't support hive bucketed tables, only ones we write out. + bucketSpec = None, + fileFormat = fileFormatClass.newInstance(), options = options)(sparkSession = sparkSession) - val created = LogicalRelation(relation, - catalogTable = Some(metastoreRelation.catalogTable)) + val created = LogicalRelation(fsRelation, catalogTable = Some(relation.tableMeta)) tableRelationCache.put(tableIdentifier, created) created } @@ -195,14 +186,13 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log logicalRelation }) } else { - val rootPath = metastoreRelation.hiveQlTable.getDataLocation + val rootPath = tablePath withTableCreationLock(tableIdentifier, { - val cached = getCached(tableIdentifier, + val cached = getCached( + tableIdentifier, Seq(rootPath), - metastoreRelation, metastoreSchema, fileFormatClass, - bucketSpec, None) val logicalRelation = cached.getOrElse { val created = @@ -210,11 +200,12 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log DataSource( sparkSession = sparkSession, paths = rootPath.toString :: Nil, - userSpecifiedSchema = Some(metastoreRelation.schema), - bucketSpec = bucketSpec, + userSpecifiedSchema = Some(metastoreSchema), + // We don't support hive bucketed tables, only ones we write out. 
+ bucketSpec = None, options = options, className = fileType).resolveRelation(), - catalogTable = Some(metastoreRelation.catalogTable)) + catalogTable = Some(relation.tableMeta)) tableRelationCache.put(tableIdentifier, created) created @@ -223,7 +214,7 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log logicalRelation }) } - result.copy(expectedOutputAttributes = Some(metastoreRelation.output)) + result.copy(expectedOutputAttributes = Some(relation.output)) } /** @@ -231,33 +222,32 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log * data source relations for better performance. */ object ParquetConversions extends Rule[LogicalPlan] { - private def shouldConvertMetastoreParquet(relation: MetastoreRelation): Boolean = { - relation.tableDesc.getSerdeClassName.toLowerCase.contains("parquet") && + private def shouldConvertMetastoreParquet(relation: CatalogRelation): Boolean = { + relation.tableMeta.storage.serde.getOrElse("").toLowerCase.contains("parquet") && sessionState.convertMetastoreParquet } - private def convertToParquetRelation(relation: MetastoreRelation): LogicalRelation = { - val defaultSource = new ParquetFileFormat() + private def convertToParquetRelation(relation: CatalogRelation): LogicalRelation = { val fileFormatClass = classOf[ParquetFileFormat] - val mergeSchema = sessionState.convertMetastoreParquetWithSchemaMerging val options = Map(ParquetOptions.MERGE_SCHEMA -> mergeSchema.toString) - convertToLogicalRelation(relation, options, defaultSource, fileFormatClass, "parquet") + convertToLogicalRelation(relation, options, fileFormatClass, "parquet") } override def apply(plan: LogicalPlan): LogicalPlan = { plan transformUp { // Write path - case InsertIntoTable(r: MetastoreRelation, partition, query, overwrite, ifNotExists) + case InsertIntoTable(r: CatalogRelation, partition, query, overwrite, ifNotExists) // Inserting into partitioned table is not supported in Parquet data source (yet). - if query.resolved && !r.hiveQlTable.isPartitioned && shouldConvertMetastoreParquet(r) => + if query.resolved && DDLUtils.isHiveTable(r.tableMeta) && + !r.isPartitioned && shouldConvertMetastoreParquet(r) => InsertIntoTable(convertToParquetRelation(r), partition, query, overwrite, ifNotExists) // Read path - case relation: MetastoreRelation if shouldConvertMetastoreParquet(relation) => - val parquetRelation = convertToParquetRelation(relation) - SubqueryAlias(relation.tableName, parquetRelation, None) + case relation: CatalogRelation if DDLUtils.isHiveTable(relation.tableMeta) && + shouldConvertMetastoreParquet(relation) => + convertToParquetRelation(relation) } } } @@ -267,31 +257,31 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log * for better performance. 
*/ object OrcConversions extends Rule[LogicalPlan] { - private def shouldConvertMetastoreOrc(relation: MetastoreRelation): Boolean = { - relation.tableDesc.getSerdeClassName.toLowerCase.contains("orc") && + private def shouldConvertMetastoreOrc(relation: CatalogRelation): Boolean = { + relation.tableMeta.storage.serde.getOrElse("").toLowerCase.contains("orc") && sessionState.convertMetastoreOrc } - private def convertToOrcRelation(relation: MetastoreRelation): LogicalRelation = { - val defaultSource = new OrcFileFormat() + private def convertToOrcRelation(relation: CatalogRelation): LogicalRelation = { val fileFormatClass = classOf[OrcFileFormat] val options = Map[String, String]() - convertToLogicalRelation(relation, options, defaultSource, fileFormatClass, "orc") + convertToLogicalRelation(relation, options, fileFormatClass, "orc") } override def apply(plan: LogicalPlan): LogicalPlan = { plan transformUp { // Write path - case InsertIntoTable(r: MetastoreRelation, partition, query, overwrite, ifNotExists) + case InsertIntoTable(r: CatalogRelation, partition, query, overwrite, ifNotExists) // Inserting into partitioned table is not supported in Orc data source (yet). - if query.resolved && !r.hiveQlTable.isPartitioned && shouldConvertMetastoreOrc(r) => + if query.resolved && DDLUtils.isHiveTable(r.tableMeta) && + !r.isPartitioned && shouldConvertMetastoreOrc(r) => InsertIntoTable(convertToOrcRelation(r), partition, query, overwrite, ifNotExists) // Read path - case relation: MetastoreRelation if shouldConvertMetastoreOrc(relation) => - val orcRelation = convertToOrcRelation(relation) - SubqueryAlias(relation.tableName, orcRelation, None) + case relation: CatalogRelation if DDLUtils.isHiveTable(relation.tableMeta) && + shouldConvertMetastoreOrc(relation) => + convertToOrcRelation(relation) } } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala index 273cf85df33a2..5a08a6bc66f6b 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala @@ -62,10 +62,10 @@ private[hive] class HiveSessionState(sparkSession: SparkSession) override val extendedResolutionRules = new ResolveHiveSerdeTable(sparkSession) :: new FindDataSourceTable(sparkSession) :: - new FindHiveSerdeTable(sparkSession) :: new ResolveSQLOnFile(sparkSession) :: Nil override val postHocResolutionRules = + new DetermineTableStats(sparkSession) :: catalog.ParquetConversions :: catalog.OrcConversions :: PreprocessTableCreation(sparkSession) :: diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index f45532cc38453..624cfa206eeb2 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -17,8 +17,14 @@ package org.apache.spark.sql.hive +import java.io.IOException +import java.net.URI + +import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.hadoop.hive.common.StatsSetupConst + import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable, SimpleCatalogRelation} +import org.apache.spark.sql.catalyst.catalog.{CatalogRelation, CatalogStatistics, CatalogStorageFormat, CatalogTable} import org.apache.spark.sql.catalyst.expressions._ import 
org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, LogicalPlan, ScriptTransformation} @@ -91,18 +97,56 @@ class ResolveHiveSerdeTable(session: SparkSession) extends Rule[LogicalPlan] { // Infers the schema, if empty, because the schema could be determined by Hive // serde. - val catalogTable = if (query.isEmpty) { - val withSchema = HiveUtils.inferSchema(withStorage) - if (withSchema.schema.length <= 0) { + val withSchema = if (query.isEmpty) { + val inferred = HiveUtils.inferSchema(withStorage) + if (inferred.schema.length <= 0) { throw new AnalysisException("Unable to infer the schema. " + - s"The schema specification is required to create the table ${withSchema.identifier}.") + s"The schema specification is required to create the table ${inferred.identifier}.") } - withSchema + inferred } else { withStorage } - c.copy(tableDesc = catalogTable) + c.copy(tableDesc = withSchema) + } +} + +class DetermineTableStats(session: SparkSession) extends Rule[LogicalPlan] { + override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + case relation: CatalogRelation + if DDLUtils.isHiveTable(relation.tableMeta) && relation.tableMeta.stats.isEmpty => + val table = relation.tableMeta + // TODO: check if this estimate is valid for tables after partition pruning. + // NOTE: getting `totalSize` directly from params is kind of hacky, but this should be + // relatively cheap if parameters for the table are populated into the metastore. + // Besides `totalSize`, there are also `numFiles`, `numRows`, `rawDataSize` keys + // (see StatsSetupConst in Hive) that we can look at in the future. + // When table is external,`totalSize` is always zero, which will influence join strategy + // so when `totalSize` is zero, use `rawDataSize` instead. 
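// Sketch, not part of the patch: the session flag the HDFS fallback below is gated
// on. The key name is assumed to be spark.sql.statistics.fallBackToHdfs, matching
// the fallBackToHdfsForStatsEnabled accessor; `spark` is an active SparkSession.
// Without it, a Hive table carrying neither totalSize nor rawDataSize is sized at
// defaultSizeInBytes, which is deliberately large and steers the planner away from
// broadcasting that table.
spark.conf.set("spark.sql.statistics.fallBackToHdfs", "true")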
+ val totalSize = table.properties.get(StatsSetupConst.TOTAL_SIZE).map(_.toLong) + val rawDataSize = table.properties.get(StatsSetupConst.RAW_DATA_SIZE).map(_.toLong) + val sizeInBytes = if (totalSize.isDefined && totalSize.get > 0) { + totalSize.get + } else if (rawDataSize.isDefined && rawDataSize.get > 0) { + rawDataSize.get + } else if (session.sessionState.conf.fallBackToHdfsForStatsEnabled) { + try { + val hadoopConf = session.sessionState.newHadoopConf() + val tablePath = new Path(new URI(table.location)) + val fs: FileSystem = tablePath.getFileSystem(hadoopConf) + fs.getContentSummary(tablePath).getLength + } catch { + case e: IOException => + logWarning("Failed to get table size from hdfs.", e) + session.sessionState.conf.defaultSizeInBytes + } + } else { + session.sessionState.conf.defaultSizeInBytes + } + + val withStats = table.copy(stats = Some(CatalogStatistics(sizeInBytes = BigInt(sizeInBytes)))) + relation.copy(tableMeta = withStats) } } @@ -114,8 +158,9 @@ class ResolveHiveSerdeTable(session: SparkSession) extends Rule[LogicalPlan] { */ object HiveAnalysis extends Rule[LogicalPlan] { override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { - case InsertIntoTable(table: MetastoreRelation, partSpec, query, overwrite, ifNotExists) => - InsertIntoHiveTable(table, partSpec, query, overwrite, ifNotExists) + case InsertIntoTable(relation: CatalogRelation, partSpec, query, overwrite, ifNotExists) + if DDLUtils.isHiveTable(relation.tableMeta) => + InsertIntoHiveTable(relation.tableMeta, partSpec, query, overwrite, ifNotExists) case CreateTable(tableDesc, mode, None) if DDLUtils.isHiveTable(tableDesc) => CreateTableCommand(tableDesc, ignoreIfExists = mode == SaveMode.Ignore) @@ -125,21 +170,6 @@ object HiveAnalysis extends Rule[LogicalPlan] { } } -/** - * Replaces `SimpleCatalogRelation` with [[MetastoreRelation]] if its table provider is hive. - */ -class FindHiveSerdeTable(session: SparkSession) extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case i @ InsertIntoTable(s: SimpleCatalogRelation, _, _, _, _) - if DDLUtils.isHiveTable(s.metadata) => - i.copy(table = - MetastoreRelation(s.metadata.database, s.metadata.identifier.table)(s.metadata, session)) - - case s: SimpleCatalogRelation if DDLUtils.isHiveTable(s.metadata) => - MetastoreRelation(s.metadata.database, s.metadata.identifier.table)(s.metadata, session) - } -} - private[hive] trait HiveStrategies { // Possibly being too clever with types here... or not clever enough. self: SparkPlanner => @@ -161,10 +191,10 @@ private[hive] trait HiveStrategies { */ object HiveTableScans extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case PhysicalOperation(projectList, predicates, relation: MetastoreRelation) => + case PhysicalOperation(projectList, predicates, relation: CatalogRelation) => // Filter out all predicates that only deal with partition keys, these are given to the // hive table scan operator to be used for partition pruning. 
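// A simplified, self-contained sketch of the split described above: predicates whose
// referenced columns are a non-empty subset of the partition columns can be handed to the
// scan for partition pruning; the rest stay as ordinary filters. The case class and string
// column names are hypothetical stand-ins for Catalyst expressions and AttributeSets.
object PruningSplitSketch {
  final case class Predicate(sql: String, references: Set[String])

  def split(
      predicates: Seq[Predicate],
      partitionCols: Set[String]): (Seq[Predicate], Seq[Predicate]) =
    predicates.partition(p => p.references.nonEmpty && p.references.subsetOf(partitionCols))

  def main(args: Array[String]): Unit = {
    val preds = Seq(
      Predicate("ds = '2017-01-01'", Set("ds")),
      Predicate("value > 10", Set("value")))
    val (pruning, other) = split(preds, Set("ds", "hr"))
    println(s"pruning=$pruning other=$other")
  }
}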
- val partitionKeyIds = AttributeSet(relation.partitionKeys) + val partitionKeyIds = AttributeSet(relation.partitionCols) val (pruningPredicates, otherPredicates) = predicates.partition { predicate => !predicate.references.isEmpty && predicate.references.subsetOf(partitionKeyIds) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/MetastoreRelation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/MetastoreRelation.scala deleted file mode 100644 index 97b120758ba45..0000000000000 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/MetastoreRelation.scala +++ /dev/null @@ -1,179 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.hive - -import java.io.IOException - -import com.google.common.base.Objects -import org.apache.hadoop.fs.FileSystem -import org.apache.hadoop.hive.common.StatsSetupConst -import org.apache.hadoop.hive.ql.metadata.{Partition, Table => HiveTable} -import org.apache.hadoop.hive.ql.plan.TableDesc - -import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.catalyst.CatalystConf -import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation -import org.apache.spark.sql.catalyst.catalog._ -import org.apache.spark.sql.catalyst.expressions.{AttributeMap, AttributeReference, Expression} -import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} -import org.apache.spark.sql.execution.FileRelation -import org.apache.spark.sql.hive.client.HiveClientImpl -import org.apache.spark.sql.types.StructField - - -private[hive] case class MetastoreRelation( - databaseName: String, - tableName: String) - (val catalogTable: CatalogTable, - @transient private val sparkSession: SparkSession) - extends LeafNode with MultiInstanceRelation with FileRelation with CatalogRelation { - - override def equals(other: Any): Boolean = other match { - case relation: MetastoreRelation => - databaseName == relation.databaseName && - tableName == relation.tableName && - output == relation.output - case _ => false - } - - override def hashCode(): Int = { - Objects.hashCode(databaseName, tableName, output) - } - - override protected def otherCopyArgs: Seq[AnyRef] = catalogTable :: sparkSession :: Nil - - @transient val hiveQlTable: HiveTable = HiveClientImpl.toHiveTable(catalogTable) - - @transient override def computeStats(conf: CatalystConf): Statistics = { - catalogTable.stats.map(_.toPlanStats(output)).getOrElse(Statistics( - sizeInBytes = { - val totalSize = hiveQlTable.getParameters.get(StatsSetupConst.TOTAL_SIZE) - val rawDataSize = hiveQlTable.getParameters.get(StatsSetupConst.RAW_DATA_SIZE) - // TODO: check if this estimate is valid for tables after partition pruning. 
- // NOTE: getting `totalSize` directly from params is kind of hacky, but this should be - // relatively cheap if parameters for the table are populated into the metastore. - // Besides `totalSize`, there are also `numFiles`, `numRows`, `rawDataSize` keys - // (see StatsSetupConst in Hive) that we can look at in the future. - BigInt( - // When table is external,`totalSize` is always zero, which will influence join strategy - // so when `totalSize` is zero, use `rawDataSize` instead - // when `rawDataSize` is also zero, use `HiveExternalCatalog.STATISTICS_TOTAL_SIZE`, - // which is generated by analyze command. - if (totalSize != null && totalSize.toLong > 0L) { - totalSize.toLong - } else if (rawDataSize != null && rawDataSize.toLong > 0) { - rawDataSize.toLong - } else if (sparkSession.sessionState.conf.fallBackToHdfsForStatsEnabled) { - try { - val hadoopConf = sparkSession.sessionState.newHadoopConf() - val fs: FileSystem = hiveQlTable.getPath.getFileSystem(hadoopConf) - fs.getContentSummary(hiveQlTable.getPath).getLength - } catch { - case e: IOException => - logWarning("Failed to get table size from hdfs.", e) - sparkSession.sessionState.conf.defaultSizeInBytes - } - } else { - sparkSession.sessionState.conf.defaultSizeInBytes - }) - } - )) - } - - // When metastore partition pruning is turned off, we cache the list of all partitions to - // mimic the behavior of Spark < 1.5 - private lazy val allPartitions: Seq[CatalogTablePartition] = { - sparkSession.sharedState.externalCatalog.listPartitions( - catalogTable.database, - catalogTable.identifier.table) - } - - def getHiveQlPartitions(predicates: Seq[Expression] = Nil): Seq[Partition] = { - val rawPartitions = if (sparkSession.sessionState.conf.metastorePartitionPruning) { - sparkSession.sharedState.externalCatalog.listPartitionsByFilter( - catalogTable.database, - catalogTable.identifier.table, - predicates) - } else { - allPartitions - } - - rawPartitions.map(HiveClientImpl.toHivePartition(_, hiveQlTable)) - } - - /** Only compare database and tablename, not alias. */ - override def sameResult(plan: LogicalPlan): Boolean = { - plan.canonicalized match { - case mr: MetastoreRelation => - mr.databaseName == databaseName && mr.tableName == tableName - case _ => false - } - } - - val tableDesc = new TableDesc( - hiveQlTable.getInputFormatClass, - // The class of table should be org.apache.hadoop.hive.ql.metadata.Table because - // getOutputFormatClass will use HiveFileFormatUtils.getOutputFormatSubstitute to - // substitute some output formats, e.g. substituting SequenceFileOutputFormat to - // HiveSequenceFileOutputFormat. - hiveQlTable.getOutputFormatClass, - hiveQlTable.getMetadata - ) - - implicit class SchemaAttribute(f: StructField) { - def toAttribute: AttributeReference = AttributeReference( - f.name, - f.dataType, - // Since data can be dumped in randomly with no validation, everything is nullable. - nullable = true - )(qualifier = Some(tableName)) - } - - /** PartitionKey attributes */ - val partitionKeys = catalogTable.partitionSchema.map(_.toAttribute) - - /** Non-partitionKey attributes */ - val dataColKeys = catalogTable.schema - .filter { c => !catalogTable.partitionColumnNames.contains(c.name) } - .map(_.toAttribute) - - val output = dataColKeys ++ partitionKeys - - /** An attribute map that can be used to lookup original attributes based on expression id. */ - val attributeMap = AttributeMap(output.map(o => (o, o))) - - /** An attribute map for determining the ordinal for non-partition columns. 
*/ - val columnOrdinals = AttributeMap(dataColKeys.zipWithIndex) - - override def inputFiles: Array[String] = { - val partLocations = allPartitions - .flatMap(_.storage.locationUri) - .toArray - if (partLocations.nonEmpty) { - partLocations - } else { - Array( - catalogTable.storage.locationUri.getOrElse( - sys.error(s"Could not get the location of ${catalogTable.qualifiedName}."))) - } - } - - override def newInstance(): MetastoreRelation = { - MetastoreRelation(databaseName, tableName)(catalogTable, sparkSession) - } -} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala index b4b63032ab261..16c1103dd1ea3 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala @@ -61,19 +61,22 @@ private[hive] sealed trait TableReader { private[hive] class HadoopTableReader( @transient private val attributes: Seq[Attribute], - @transient private val relation: MetastoreRelation, + @transient private val partitionKeys: Seq[Attribute], + @transient private val tableDesc: TableDesc, @transient private val sparkSession: SparkSession, hadoopConf: Configuration) extends TableReader with Logging { - // Hadoop honors "mapred.map.tasks" as hint, but will ignore when mapred.job.tracker is "local". - // https://hadoop.apache.org/docs/r1.0.4/mapred-default.html + // Hadoop honors "mapreduce.job.maps" as hint, + // but will ignore when mapreduce.jobtracker.address is "local". + // https://hadoop.apache.org/docs/r2.6.5/hadoop-mapreduce-client/hadoop-mapreduce-client-core/ + // mapred-default.xml // // In order keep consistency with Hive, we will let it be 0 in local mode also. private val _minSplitsPerRDD = if (sparkSession.sparkContext.isLocal) { 0 // will splitted based on block by default. } else { - math.max(hadoopConf.getInt("mapred.map.tasks", 1), + math.max(hadoopConf.getInt("mapreduce.job.maps", 1), sparkSession.sparkContext.defaultMinPartitions) } @@ -86,7 +89,7 @@ class HadoopTableReader( override def makeRDDForTable(hiveTable: HiveTable): RDD[InternalRow] = makeRDDForTable( hiveTable, - Utils.classForName(relation.tableDesc.getSerdeClassName).asInstanceOf[Class[Deserializer]], + Utils.classForName(tableDesc.getSerdeClassName).asInstanceOf[Class[Deserializer]], filterOpt = None) /** @@ -108,7 +111,7 @@ class HadoopTableReader( // Create local references to member variables, so that the entire `this` object won't be // serialized in the closure below. 
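// A small sketch of the closure-capture issue the comment above guards against: referring
// to a member inside a closure captures the enclosing (possibly non-serializable) object,
// while copying the member into a local val first captures only that value. The class
// below is hypothetical; in the table reader the member in question is the TableDesc.
import java.io.{ByteArrayOutputStream, ObjectOutputStream}

class NotSerializableOwner(val suffix: String) {        // note: does NOT extend Serializable
  def badClosure: String => String = s => s + suffix    // references a member, so captures `this`
  def goodClosure: String => String = {
    val localSuffix = suffix                            // local copy, captured by value
    s => s + localSuffix
  }
}

object ClosureCaptureSketch {
  private def serializable(obj: AnyRef): Boolean =
    try { new ObjectOutputStream(new ByteArrayOutputStream()).writeObject(obj); true }
    catch { case _: java.io.NotSerializableException => false }

  def main(args: Array[String]): Unit = {
    val owner = new NotSerializableOwner("_suffix")
    println(serializable(owner.badClosure))   // false: drags NotSerializableOwner along
    println(serializable(owner.goodClosure))  // true: only the captured String is written
  }
}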
- val tableDesc = relation.tableDesc + val localTableDesc = tableDesc val broadcastedHadoopConf = _broadcastedHadoopConf val tablePath = hiveTable.getPath @@ -117,7 +120,7 @@ class HadoopTableReader( // logDebug("Table input: %s".format(tablePath)) val ifc = hiveTable.getInputFormatClass .asInstanceOf[java.lang.Class[InputFormat[Writable, Writable]]] - val hadoopRDD = createHadoopRdd(tableDesc, inputPathStr, ifc) + val hadoopRDD = createHadoopRdd(localTableDesc, inputPathStr, ifc) val attrsWithIndex = attributes.zipWithIndex val mutableRow = new SpecificInternalRow(attributes.map(_.dataType)) @@ -125,7 +128,7 @@ class HadoopTableReader( val deserializedHadoopRDD = hadoopRDD.mapPartitions { iter => val hconf = broadcastedHadoopConf.value.value val deserializer = deserializerClass.newInstance() - deserializer.initialize(hconf, tableDesc.getProperties) + deserializer.initialize(hconf, localTableDesc.getProperties) HadoopTableReader.fillObject(iter, deserializer, attrsWithIndex, mutableRow, deserializer) } @@ -210,8 +213,6 @@ class HadoopTableReader( partCols.map(col => new String(partSpec.get(col))).toArray } - // Create local references so that the outer object isn't serialized. - val tableDesc = relation.tableDesc val broadcastedHiveConf = _broadcastedHadoopConf val localDeserializer = partDeserializer val mutableRow = new SpecificInternalRow(attributes.map(_.dataType)) @@ -220,12 +221,12 @@ class HadoopTableReader( // Attached indices indicate the position of each attribute in the output schema. val (partitionKeyAttrs, nonPartitionKeyAttrs) = attributes.zipWithIndex.partition { case (attr, _) => - relation.partitionKeys.contains(attr) + partitionKeys.contains(attr) } def fillPartitionKeys(rawPartValues: Array[String], row: InternalRow): Unit = { partitionKeyAttrs.foreach { case (attr, ordinal) => - val partOrdinal = relation.partitionKeys.indexOf(attr) + val partOrdinal = partitionKeys.indexOf(attr) row(ordinal) = Cast(Literal(rawPartValues(partOrdinal)), attr.dataType).eval(null) } } @@ -233,9 +234,11 @@ class HadoopTableReader( // Fill all partition keys to the given MutableRow object fillPartitionKeys(partValues, mutableRow) - val tableProperties = relation.tableDesc.getProperties + val tableProperties = tableDesc.getProperties - createHadoopRdd(tableDesc, inputPathStr, ifc).mapPartitions { iter => + // Create local references so that the outer object isn't serialized. + val localTableDesc = tableDesc + createHadoopRdd(localTableDesc, inputPathStr, ifc).mapPartitions { iter => val hconf = broadcastedHiveConf.value.value val deserializer = localDeserializer.newInstance() // SPARK-13709: For SerDes like AvroSerDe, some essential information (e.g. 
Avro schema @@ -249,8 +252,8 @@ class HadoopTableReader( } deserializer.initialize(hconf, props) // get the table deserializer - val tableSerDe = tableDesc.getDeserializerClass.newInstance() - tableSerDe.initialize(hconf, tableDesc.getProperties) + val tableSerDe = localTableDesc.getDeserializerClass.newInstance() + tableSerDe.initialize(hconf, localTableDesc.getProperties) // fill the non partition key attributes HadoopTableReader.fillObject(iter, deserializer, nonPartitionKeyAttrs, diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index c326ac4cc1a53..8f98c8f447037 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -96,6 +96,7 @@ private[hive] class HiveClientImpl( case hive.v1_0 => new Shim_v1_0() case hive.v1_1 => new Shim_v1_1() case hive.v1_2 => new Shim_v1_2() + case hive.v2_0 => new Shim_v2_0() } // Create an internal session state for this HiveClientImpl. diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala index 9fe1c76d3325d..7280748361d60 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala @@ -833,3 +833,77 @@ private[client] class Shim_v1_2 extends Shim_v1_1 { } } + +private[client] class Shim_v2_0 extends Shim_v1_2 { + private lazy val loadPartitionMethod = + findMethod( + classOf[Hive], + "loadPartition", + classOf[Path], + classOf[String], + classOf[JMap[String, String]], + JBoolean.TYPE, + JBoolean.TYPE, + JBoolean.TYPE, + JBoolean.TYPE, + JBoolean.TYPE) + private lazy val loadTableMethod = + findMethod( + classOf[Hive], + "loadTable", + classOf[Path], + classOf[String], + JBoolean.TYPE, + JBoolean.TYPE, + JBoolean.TYPE, + JBoolean.TYPE) + private lazy val loadDynamicPartitionsMethod = + findMethod( + classOf[Hive], + "loadDynamicPartitions", + classOf[Path], + classOf[String], + classOf[JMap[String, String]], + JBoolean.TYPE, + JInteger.TYPE, + JBoolean.TYPE, + JBoolean.TYPE, + JLong.TYPE) + + override def loadPartition( + hive: Hive, + loadPath: Path, + tableName: String, + partSpec: JMap[String, String], + replace: Boolean, + inheritTableSpecs: Boolean, + isSkewedStoreAsSubdir: Boolean, + isSrcLocal: Boolean): Unit = { + loadPartitionMethod.invoke(hive, loadPath, tableName, partSpec, replace: JBoolean, + inheritTableSpecs: JBoolean, isSkewedStoreAsSubdir: JBoolean, + isSrcLocal: JBoolean, JBoolean.FALSE) + } + + override def loadTable( + hive: Hive, + loadPath: Path, + tableName: String, + replace: Boolean, + isSrcLocal: Boolean): Unit = { + loadTableMethod.invoke(hive, loadPath, tableName, replace: JBoolean, isSrcLocal: JBoolean, + JBoolean.FALSE, JBoolean.FALSE) + } + + override def loadDynamicPartitions( + hive: Hive, + loadPath: Path, + tableName: String, + partSpec: JMap[String, String], + replace: Boolean, + numDP: Int, + listBucketingEnabled: Boolean): Unit = { + loadDynamicPartitionsMethod.invoke(hive, loadPath, tableName, partSpec, replace: JBoolean, + numDP: JInteger, listBucketingEnabled: JBoolean, JBoolean.FALSE, 0L: JLong) + } + +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala 
b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala index d2487a2c034c0..6f69a4adf29d5 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala @@ -94,6 +94,7 @@ private[hive] object IsolatedClientLoader extends Logging { case "1.0" | "1.0.0" => hive.v1_0 case "1.1" | "1.1.0" => hive.v1_1 case "1.2" | "1.2.0" | "1.2.1" => hive.v1_2 + case "2.0" | "2.0.0" | "2.0.1" => hive.v2_0 } private def downloadVersion( diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala index 4e2193b6abc3f..790ad74e6639e 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.hive /** Support for interacting with different versions of the HiveMetastoreClient */ package object client { - private[hive] abstract class HiveVersion( + private[hive] sealed abstract class HiveVersion( val fullVersion: String, val extraDeps: Seq[String] = Nil, val exclusions: Seq[String] = Nil) @@ -62,6 +62,12 @@ package object client { "org.pentaho:pentaho-aggdesigner-algorithm", "net.hydromatic:linq4j", "net.hydromatic:quidem")) + + case object v2_0 extends HiveVersion("2.0.1", + exclusions = Seq("org.apache.curator:*", + "org.pentaho:pentaho-aggdesigner-algorithm")) + + val allSupportedHiveVersions = Set(v12, v13, v14, v1_0, v1_1, v1_2, v2_0) } // scalastyle:on diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala index 140c352fa6f8d..28f074849c0f5 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala @@ -21,6 +21,7 @@ import scala.collection.JavaConverters._ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.hive.ql.metadata.{Partition => HivePartition} +import org.apache.hadoop.hive.ql.plan.TableDesc import org.apache.hadoop.hive.serde.serdeConstants import org.apache.hadoop.hive.serde2.objectinspector._ import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorCopyOption @@ -29,10 +30,12 @@ import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils import org.apache.spark.rdd.RDD import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.catalog.CatalogRelation import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.hive._ +import org.apache.spark.sql.hive.client.HiveClientImpl import org.apache.spark.sql.types.{BooleanType, DataType} import org.apache.spark.util.Utils @@ -46,12 +49,12 @@ import org.apache.spark.util.Utils private[hive] case class HiveTableScanExec( requestedAttributes: Seq[Attribute], - relation: MetastoreRelation, + relation: CatalogRelation, partitionPruningPred: Seq[Expression])( @transient private val sparkSession: SparkSession) extends LeafExecNode { - require(partitionPruningPred.isEmpty || relation.hiveQlTable.isPartitioned, + require(partitionPruningPred.isEmpty || relation.isPartitioned, "Partition pruning 
predicates only supported for partitioned tables.") override lazy val metrics = Map( @@ -60,42 +63,54 @@ case class HiveTableScanExec( override def producedAttributes: AttributeSet = outputSet ++ AttributeSet(partitionPruningPred.flatMap(_.references)) - // Retrieve the original attributes based on expression ID so that capitalization matches. - val attributes = requestedAttributes.map(relation.attributeMap) + private val originalAttributes = AttributeMap(relation.output.map(a => a -> a)) + + override val output: Seq[Attribute] = { + // Retrieve the original attributes based on expression ID so that capitalization matches. + requestedAttributes.map(originalAttributes) + } // Bind all partition key attribute references in the partition pruning predicate for later // evaluation. - private[this] val boundPruningPred = partitionPruningPred.reduceLeftOption(And).map { pred => + private val boundPruningPred = partitionPruningPred.reduceLeftOption(And).map { pred => require( pred.dataType == BooleanType, s"Data type of predicate $pred must be BooleanType rather than ${pred.dataType}.") - BindReferences.bindReference(pred, relation.partitionKeys) + BindReferences.bindReference(pred, relation.partitionCols) } // Create a local copy of hadoopConf,so that scan specific modifications should not impact // other queries - @transient - private[this] val hadoopConf = sparkSession.sessionState.newHadoopConf() + @transient private val hadoopConf = sparkSession.sessionState.newHadoopConf() + + @transient private val hiveQlTable = HiveClientImpl.toHiveTable(relation.tableMeta) + @transient private val tableDesc = new TableDesc( + hiveQlTable.getInputFormatClass, + hiveQlTable.getOutputFormatClass, + hiveQlTable.getMetadata) // append columns ids and names before broadcast addColumnMetadataToConf(hadoopConf) - @transient - private[this] val hadoopReader = - new HadoopTableReader(attributes, relation, sparkSession, hadoopConf) + @transient private val hadoopReader = new HadoopTableReader( + output, + relation.partitionCols, + tableDesc, + sparkSession, + hadoopConf) - private[this] def castFromString(value: String, dataType: DataType) = { + private def castFromString(value: String, dataType: DataType) = { Cast(Literal(value), dataType).eval(null) } private def addColumnMetadataToConf(hiveConf: Configuration) { // Specifies needed column IDs for those non-partitioning columns. 
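// Self-contained sketch of the idea behind this method: the ordinals and names of the
// needed non-partition columns are comma-joined and pushed into the Hadoop Configuration
// so the underlying record reader can prune columns. The property keys below are
// placeholders for illustration, not the actual Hive serde constants used in the hunk.
import org.apache.hadoop.conf.Configuration

object ColumnMetadataConfSketch {
  def addColumnMetadata(conf: Configuration, dataCols: Seq[String], needed: Seq[String]): Unit = {
    val ordinals = dataCols.zipWithIndex.toMap
    val neededIds = needed.flatMap(ordinals.get)
    conf.set("example.needed.column.ids", neededIds.mkString(","))
    conf.set("example.needed.column.names", needed.mkString(","))
  }

  def main(args: Array[String]): Unit = {
    val conf = new Configuration(false)
    addColumnMetadata(conf, Seq("key", "value", "extra"), Seq("value"))
    println(conf.get("example.needed.column.ids"))    // 1
    println(conf.get("example.needed.column.names"))  // value
  }
}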
- val neededColumnIDs = attributes.flatMap(relation.columnOrdinals.get).map(o => o: Integer) + val columnOrdinals = AttributeMap(relation.dataCols.zipWithIndex) + val neededColumnIDs = output.flatMap(columnOrdinals.get).map(o => o: Integer) - HiveShim.appendReadColumns(hiveConf, neededColumnIDs, attributes.map(_.name)) + HiveShim.appendReadColumns(hiveConf, neededColumnIDs, output.map(_.name)) - val tableDesc = relation.tableDesc val deserializer = tableDesc.getDeserializerClass.newInstance deserializer.initialize(hiveConf, tableDesc.getProperties) @@ -113,7 +128,7 @@ case class HiveTableScanExec( .mkString(",") hiveConf.set(serdeConstants.LIST_COLUMN_TYPES, columnTypeNames) - hiveConf.set(serdeConstants.LIST_COLUMNS, relation.dataColKeys.map(_.name).mkString(",")) + hiveConf.set(serdeConstants.LIST_COLUMNS, relation.dataCols.map(_.name).mkString(",")) } /** @@ -126,7 +141,7 @@ case class HiveTableScanExec( boundPruningPred match { case None => partitions case Some(shouldKeep) => partitions.filter { part => - val dataTypes = relation.partitionKeys.map(_.dataType) + val dataTypes = relation.partitionCols.map(_.dataType) val castedValues = part.getValues.asScala.zip(dataTypes) .map { case (value, dataType) => castFromString(value, dataType) } @@ -138,27 +153,36 @@ case class HiveTableScanExec( } } + // exposed for tests + @transient lazy val rawPartitions = { + val prunedPartitions = if (sparkSession.sessionState.conf.metastorePartitionPruning) { + // Retrieve the original attributes based on expression ID so that capitalization matches. + val normalizedFilters = partitionPruningPred.map(_.transform { + case a: AttributeReference => originalAttributes(a) + }) + sparkSession.sharedState.externalCatalog.listPartitionsByFilter( + relation.tableMeta.database, + relation.tableMeta.identifier.table, + normalizedFilters, + sparkSession.sessionState.conf.sessionLocalTimeZone) + } else { + sparkSession.sharedState.externalCatalog.listPartitions( + relation.tableMeta.database, + relation.tableMeta.identifier.table) + } + prunedPartitions.map(HiveClientImpl.toHivePartition(_, hiveQlTable)) + } + protected override def doExecute(): RDD[InternalRow] = { // Using dummyCallSite, as getCallSite can turn out to be expensive with // with multiple partitions. - val rdd = if (!relation.hiveQlTable.isPartitioned) { + val rdd = if (!relation.isPartitioned) { Utils.withDummyCallSite(sqlContext.sparkContext) { - hadoopReader.makeRDDForTable(relation.hiveQlTable) + hadoopReader.makeRDDForTable(hiveQlTable) } } else { - // The attribute name of predicate could be different than the one in schema in case of - // case insensitive, we should change them to match the one in schema, so we do not need to - // worry about case sensitivity anymore. 
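// A plain-Scala sketch of the normalization the removed comment describes: column names
// referenced by a predicate are rewritten to the exact-case names declared in the schema,
// using a case-insensitive lookup. Catalyst does this on AttributeReferences; the strings
// here are simplified stand-ins.
object CaseNormalizationSketch {
  def normalize(referenced: Seq[String], schema: Seq[String]): Seq[String] = {
    val byLowerCase = schema.map(name => name.toLowerCase -> name).toMap
    referenced.map(name => byLowerCase.getOrElse(name.toLowerCase, name))
  }

  def main(args: Array[String]): Unit =
    // "SALES_DT" in a filter resolves to the schema's "sales_dt"
    println(normalize(Seq("SALES_DT", "Value"), Seq("sales_dt", "value")))
}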
- val normalizedFilters = partitionPruningPred.map { e => - e transform { - case a: AttributeReference => - a.withName(relation.output.find(_.semanticEquals(a)).get.name) - } - } - Utils.withDummyCallSite(sqlContext.sparkContext) { - hadoopReader.makeRDDForPartitionedTable( - prunePartitions(relation.getHiveQlPartitions(normalizedFilters))) + hadoopReader.makeRDDForPartitionedTable(prunePartitions(rawPartitions)) } } val numOutputRows = longMetric("numOutputRows") @@ -174,8 +198,6 @@ case class HiveTableScanExec( } } - override def output: Seq[Attribute] = attributes - override def sameResult(plan: SparkPlan): Boolean = plan match { case other: HiveTableScanExec => val thisPredicates = partitionPruningPred.map(cleanExpression) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index 5d5688ecb36b4..3c57ee4c8b8f6 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -29,16 +29,18 @@ import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.hive.common.FileUtils import org.apache.hadoop.hive.ql.exec.TaskRunner import org.apache.hadoop.hive.ql.ErrorMsg +import org.apache.hadoop.hive.ql.plan.TableDesc import org.apache.spark.internal.io.FileCommitProtocol import org.apache.spark.sql.{AnalysisException, Dataset, Row, SparkSession} +import org.apache.spark.sql.catalyst.catalog.CatalogTable import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.command.RunnableCommand import org.apache.spark.sql.execution.datasources.FileFormatWriter import org.apache.spark.sql.hive._ import org.apache.spark.sql.hive.HiveShim.{ShimFileSinkDesc => FileSinkDesc} -import org.apache.spark.sql.hive.client.HiveVersion +import org.apache.spark.sql.hive.client.{HiveClientImpl, HiveVersion} import org.apache.spark.SparkException @@ -52,9 +54,7 @@ import org.apache.spark.SparkException * In the future we should converge the write path for Hive with the normal data source write path, * as defined in `org.apache.spark.sql.execution.datasources.FileFormatWriter`. * - * @param table the logical plan representing the table. In the future this should be a - * `org.apache.spark.sql.catalyst.catalog.CatalogTable` once we converge Hive tables - * and data source tables. + * @param table the metadata of the table. * @param partition a map from the partition key to the partition value (optional). If the partition * value is optional, dynamic partition insert will be performed. * As an example, `INSERT INTO tbl PARTITION (a=1, b=2) AS ...` would have @@ -74,7 +74,7 @@ import org.apache.spark.SparkException * @param ifNotExists If true, only write if the table or partition does not exist. */ case class InsertIntoHiveTable( - table: MetastoreRelation, + table: CatalogTable, partition: Map[String, Option[String]], query: LogicalPlan, overwrite: Boolean, @@ -148,9 +148,16 @@ case class InsertIntoHiveTable( // We have to follow the Hive behavior here, to avoid troubles. For example, if we create // staging directory under the table director for Hive prior to 1.1, the staging directory will // be removed by Hive when Hive is trying to empty the table directory. 
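// A hedged sketch of the configuration lookup involved in picking a temporary write
// location: per the comments above, older Hive versions place it under the scratch
// directory while newer ones use a ".hive-staging" directory near the table. Only the
// two config reads and the version split are shown; the exact path layout used by the
// patch is left out.
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path

object ExternalTmpPathSketch {
  def chooseBaseDir(conf: Configuration, tableDir: Path, usesOldScheme: Boolean): Path = {
    val stagingDir = conf.get("hive.exec.stagingdir", ".hive-staging")
    val scratchDir = conf.get("hive.exec.scratchdir", "/tmp/hive")
    if (usesOldScheme) new Path(scratchDir) else new Path(tableDir, stagingDir)
  }

  def main(args: Array[String]): Unit =
    println(chooseBaseDir(new Configuration(false), new Path("/warehouse/db/tbl"), false))
}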
- if (hiveVersion == v12 || hiveVersion == v13 || hiveVersion == v14 || hiveVersion == v1_0) { + val hiveVersionsUsingOldExternalTempPath: Set[HiveVersion] = Set(v12, v13, v14, v1_0) + val hiveVersionsUsingNewExternalTempPath: Set[HiveVersion] = Set(v1_1, v1_2, v2_0) + + // Ensure all the supported versions are considered here. + assert(hiveVersionsUsingNewExternalTempPath ++ hiveVersionsUsingOldExternalTempPath == + allSupportedHiveVersions) + + if (hiveVersionsUsingOldExternalTempPath.contains(hiveVersion)) { oldVersionExternalTempPath(path, hadoopConf, scratchDir) - } else if (hiveVersion == v1_1 || hiveVersion == v1_2) { + } else if (hiveVersionsUsingNewExternalTempPath.contains(hiveVersion)) { newVersionExternalTempPath(path, hadoopConf, stagingDir) } else { throw new IllegalStateException("Unsupported hive version: " + hiveVersion.fullVersion) @@ -218,23 +225,35 @@ case class InsertIntoHiveTable( val stagingDir = hadoopConf.get("hive.exec.stagingdir", ".hive-staging") val scratchDir = hadoopConf.get("hive.exec.scratchdir", "/tmp/hive") + val hiveQlTable = HiveClientImpl.toHiveTable(table) // Have to pass the TableDesc object to RDD.mapPartitions and then instantiate new serializer // instances within the closure, since Serializer is not serializable while TableDesc is. - val tableDesc = table.tableDesc - val tableLocation = table.hiveQlTable.getDataLocation + val tableDesc = new TableDesc( + hiveQlTable.getInputFormatClass, + // The class of table should be org.apache.hadoop.hive.ql.metadata.Table because + // getOutputFormatClass will use HiveFileFormatUtils.getOutputFormatSubstitute to + // substitute some output formats, e.g. substituting SequenceFileOutputFormat to + // HiveSequenceFileOutputFormat. + hiveQlTable.getOutputFormatClass, + hiveQlTable.getMetadata + ) + val tableLocation = hiveQlTable.getDataLocation val tmpLocation = getExternalTmpPath(tableLocation, hiveVersion, hadoopConf, stagingDir, scratchDir) val fileSinkConf = new FileSinkDesc(tmpLocation.toString, tableDesc, false) val isCompressed = hadoopConf.get("hive.exec.compress.output", "false").toBoolean if (isCompressed) { - // Please note that isCompressed, "mapred.output.compress", "mapred.output.compression.codec", - // and "mapred.output.compression.type" have no impact on ORC because it uses table properties - // to store compression information. - hadoopConf.set("mapred.output.compress", "true") + // Please note that isCompressed, "mapreduce.output.fileoutputformat.compress", + // "mapreduce.output.fileoutputformat.compress.codec", and + // "mapreduce.output.fileoutputformat.compress.type" + // have no impact on ORC because it uses table properties to store compression information. 
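// Sketch of the property renaming this hunk performs: the pre-MR2 keys on the left map to
// the mapreduce.* equivalents on the right (the pairs are taken from the patch itself).
// The helper simply copies any value set under an old key to the new key on a Hadoop
// Configuration; it is an illustration, not code from the patch.
import org.apache.hadoop.conf.Configuration

object CompressionKeyMigrationSketch {
  val renamedKeys: Map[String, String] = Map(
    "mapred.output.compress" -> "mapreduce.output.fileoutputformat.compress",
    "mapred.output.compression.codec" -> "mapreduce.output.fileoutputformat.compress.codec",
    "mapred.output.compression.type" -> "mapreduce.output.fileoutputformat.compress.type")

  def migrate(conf: Configuration): Unit =
    renamedKeys.foreach { case (oldKey, newKey) =>
      Option(conf.get(oldKey)).foreach(value => conf.set(newKey, value))
    }

  def main(args: Array[String]): Unit = {
    val conf = new Configuration(false)
    conf.set("mapred.output.compress", "true")
    migrate(conf)
    println(conf.get("mapreduce.output.fileoutputformat.compress")) // true
  }
}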
+ hadoopConf.set("mapreduce.output.fileoutputformat.compress", "true") fileSinkConf.setCompressed(true) - fileSinkConf.setCompressCodec(hadoopConf.get("mapred.output.compression.codec")) - fileSinkConf.setCompressType(hadoopConf.get("mapred.output.compression.type")) + fileSinkConf.setCompressCodec(hadoopConf + .get("mapreduce.output.fileoutputformat.compress.codec")) + fileSinkConf.setCompressType(hadoopConf + .get("mapreduce.output.fileoutputformat.compress.type")) } val numDynamicPartitions = partition.values.count(_.isEmpty) @@ -251,9 +270,9 @@ case class InsertIntoHiveTable( // By this time, the partition map must match the table's partition columns if (partitionColumnNames.toSet != partition.keySet) { throw new SparkException( - s"""Requested partitioning does not match the ${table.tableName} table: + s"""Requested partitioning does not match the ${table.identifier.table} table: |Requested partitions: ${partition.keys.mkString(",")} - |Table partitions: ${table.partitionKeys.map(_.name).mkString(",")}""".stripMargin) + |Table partitions: ${table.partitionColumnNames.mkString(",")}""".stripMargin) } // Validate partition spec if there exist any dynamic partitions @@ -304,8 +323,8 @@ case class InsertIntoHiveTable( if (partition.nonEmpty) { if (numDynamicPartitions > 0) { externalCatalog.loadDynamicPartitions( - db = table.catalogTable.database, - table = table.catalogTable.identifier.table, + db = table.database, + table = table.identifier.table, tmpLocation.toString, partitionSpec, overwrite, @@ -317,8 +336,8 @@ case class InsertIntoHiveTable( // scalastyle:on val oldPart = externalCatalog.getPartitionOption( - table.catalogTable.database, - table.catalogTable.identifier.table, + table.database, + table.identifier.table, partitionSpec) var doHiveOverwrite = overwrite @@ -347,8 +366,8 @@ case class InsertIntoHiveTable( // which is currently considered as a Hive native command. val inheritTableSpecs = true externalCatalog.loadPartition( - table.catalogTable.database, - table.catalogTable.identifier.table, + table.database, + table.identifier.table, tmpLocation.toString, partitionSpec, isOverwrite = doHiveOverwrite, @@ -358,8 +377,8 @@ case class InsertIntoHiveTable( } } else { externalCatalog.loadTable( - table.catalogTable.database, - table.catalogTable.identifier.table, + table.database, + table.identifier.table, tmpLocation.toString, // TODO: URI overwrite, isSrcLocal = false) @@ -375,8 +394,8 @@ case class InsertIntoHiveTable( } // Invalidate the cache. - sparkSession.sharedState.cacheManager.invalidateCache(table) - sparkSession.sessionState.catalog.refreshTable(table.catalogTable.identifier) + sparkSession.catalog.uncacheTable(table.qualifiedName) + sparkSession.sessionState.catalog.refreshTable(table.identifier) // It would be nice to just return the childRdd unchanged so insert operations could be chained, // however for now we return an empty list to simplify compatibility checks with hive, which diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index fd139119472db..efc2f0098454b 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -439,7 +439,7 @@ private[hive] class TestHiveSparkSession( foreach { udfName => FunctionRegistry.unregisterTemporaryUDF(udfName) } // Some tests corrupt this value on purpose, which breaks the RESET call below. 
- sessionState.conf.setConfString("fs.default.name", new File(".").toURI.toString) + sessionState.conf.setConfString("fs.defaultFS", new File(".").toURI.toString) // It is important that we RESET first as broken hooks that might have been set could break // other sql exec here. sessionState.metadataHive.runSqlHive("RESET") diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaDataFrameSuite.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaDataFrameSuite.java index f664d5a4cdada..aefc9cc77da88 100644 --- a/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaDataFrameSuite.java +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaDataFrameSuite.java @@ -26,7 +26,6 @@ import org.junit.Before; import org.junit.Test; -import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.sql.*; import org.apache.spark.sql.expressions.Window; import org.apache.spark.sql.expressions.UserDefinedAggregateFunction; @@ -35,7 +34,6 @@ import org.apache.spark.sql.hive.aggregate.MyDoubleSum; public class JavaDataFrameSuite { - private transient JavaSparkContext sc; private transient SQLContext hc; Dataset df; @@ -50,13 +48,11 @@ private static void checkAnswer(Dataset actual, List expected) { @Before public void setUp() throws IOException { hc = TestHive$.MODULE$; - sc = new JavaSparkContext(hc.sparkContext()); - List jsonObjects = new ArrayList<>(10); for (int i = 0; i < 10; i++) { jsonObjects.add("{\"key\":" + i + ", \"value\":\"str" + i + "\"}"); } - df = hc.read().json(sc.parallelize(jsonObjects)); + df = hc.read().json(hc.createDataset(jsonObjects, Encoders.STRING())); df.createOrReplaceTempView("window_table"); } diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java index 061c7431a6362..0b157a45e6e05 100644 --- a/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java @@ -31,9 +31,9 @@ import org.junit.Before; import org.junit.Test; -import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Encoders; import org.apache.spark.sql.QueryTest$; import org.apache.spark.sql.Row; import org.apache.spark.sql.SQLContext; @@ -81,8 +81,8 @@ public void setUp() throws IOException { for (int i = 0; i < 10; i++) { jsonObjects.add("{\"a\":" + i + ", \"b\":\"str" + i + "\"}"); } - JavaRDD rdd = sc.parallelize(jsonObjects); - df = sqlContext.read().json(rdd); + Dataset ds = sqlContext.createDataset(jsonObjects, Encoders.STRING()); + df = sqlContext.read().json(ds); df.createOrReplaceTempView("jsonTable"); } diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFRawList.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFRawList.java index 6adb1657bf25b..8211cbf16f7bf 100644 --- a/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFRawList.java +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFRawList.java @@ -25,6 +25,7 @@ * UDF that returns a raw (non-parameterized) java List. 
*/ public class UDFRawList extends UDF { + @SuppressWarnings("rawtypes") public List evaluate(Object o) { return Collections.singletonList("data1"); } diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFRawMap.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFRawMap.java index 4731b6eee85cd..58c81f9945d7e 100644 --- a/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFRawMap.java +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFRawMap.java @@ -25,6 +25,7 @@ * UDF that returns a raw (non-parameterized) java Map. */ public class UDFRawMap extends UDF { + @SuppressWarnings("rawtypes") public Map evaluate(Object o) { return Collections.singletonMap("a", "1"); } diff --git a/sql/hive/src/test/resources/golden/auto_join14_hadoop20-2-db1cd54a4cb36de2087605f32e41824f b/sql/hive/src/test/resources/golden/auto_join14_hadoop20-2-2b9ccaa793eae0e73bf76335d3d6880 similarity index 100% rename from sql/hive/src/test/resources/golden/auto_join14_hadoop20-2-db1cd54a4cb36de2087605f32e41824f rename to sql/hive/src/test/resources/golden/auto_join14_hadoop20-2-2b9ccaa793eae0e73bf76335d3d6880 diff --git a/sql/hive/src/test/resources/golden/combine1-2-c95dc367df88c9e5cf77157f29ba2daf b/sql/hive/src/test/resources/golden/combine1-2-6142f47d3fcdd4323162014d5eb35e07 similarity index 100% rename from sql/hive/src/test/resources/golden/combine1-2-c95dc367df88c9e5cf77157f29ba2daf rename to sql/hive/src/test/resources/golden/combine1-2-6142f47d3fcdd4323162014d5eb35e07 diff --git a/sql/hive/src/test/resources/golden/combine1-3-6e53a3ac93113f20db3a12f1dcf30e86 b/sql/hive/src/test/resources/golden/combine1-3-10266e3d5dd4c841c0d65030b1edba7c similarity index 100% rename from sql/hive/src/test/resources/golden/combine1-3-6e53a3ac93113f20db3a12f1dcf30e86 rename to sql/hive/src/test/resources/golden/combine1-3-10266e3d5dd4c841c0d65030b1edba7c diff --git a/sql/hive/src/test/resources/golden/combine1-4-84967075baa3e56fff2a23f8ab9ba076 b/sql/hive/src/test/resources/golden/combine1-4-9cbd6d400fb6c3cd09010e3dbd76601 similarity index 100% rename from sql/hive/src/test/resources/golden/combine1-4-84967075baa3e56fff2a23f8ab9ba076 rename to sql/hive/src/test/resources/golden/combine1-4-9cbd6d400fb6c3cd09010e3dbd76601 diff --git a/sql/hive/src/test/resources/golden/combine1-5-2ee5d706fe3a3bcc38b795f6e94970ea b/sql/hive/src/test/resources/golden/combine1-5-1ba2d6f3bb3348da3fee7fab4f283f34 similarity index 100% rename from sql/hive/src/test/resources/golden/combine1-5-2ee5d706fe3a3bcc38b795f6e94970ea rename to sql/hive/src/test/resources/golden/combine1-5-1ba2d6f3bb3348da3fee7fab4f283f34 diff --git a/sql/hive/src/test/resources/golden/combine2-2-c95dc367df88c9e5cf77157f29ba2daf b/sql/hive/src/test/resources/golden/combine2-2-6142f47d3fcdd4323162014d5eb35e07 similarity index 100% rename from sql/hive/src/test/resources/golden/combine2-2-c95dc367df88c9e5cf77157f29ba2daf rename to sql/hive/src/test/resources/golden/combine2-2-6142f47d3fcdd4323162014d5eb35e07 diff --git a/sql/hive/src/test/resources/golden/combine2-3-6e53a3ac93113f20db3a12f1dcf30e86 b/sql/hive/src/test/resources/golden/combine2-3-10266e3d5dd4c841c0d65030b1edba7c similarity index 100% rename from sql/hive/src/test/resources/golden/combine2-3-6e53a3ac93113f20db3a12f1dcf30e86 rename to sql/hive/src/test/resources/golden/combine2-3-10266e3d5dd4c841c0d65030b1edba7c diff --git a/sql/hive/src/test/resources/golden/combine2-4-84967075baa3e56fff2a23f8ab9ba076 
b/sql/hive/src/test/resources/golden/combine2-4-9cbd6d400fb6c3cd09010e3dbd76601 similarity index 100% rename from sql/hive/src/test/resources/golden/combine2-4-84967075baa3e56fff2a23f8ab9ba076 rename to sql/hive/src/test/resources/golden/combine2-4-9cbd6d400fb6c3cd09010e3dbd76601 diff --git a/sql/hive/src/test/resources/golden/combine2-5-2ee5d706fe3a3bcc38b795f6e94970ea b/sql/hive/src/test/resources/golden/combine2-5-1ba2d6f3bb3348da3fee7fab4f283f34 similarity index 100% rename from sql/hive/src/test/resources/golden/combine2-5-2ee5d706fe3a3bcc38b795f6e94970ea rename to sql/hive/src/test/resources/golden/combine2-5-1ba2d6f3bb3348da3fee7fab4f283f34 diff --git a/sql/hive/src/test/resources/golden/groupby1-3-d57ed4bbfee1ffaffaeba0a4be84c31d b/sql/hive/src/test/resources/golden/groupby1-3-c8478dac3497697b4375ee35118a5c3e similarity index 100% rename from sql/hive/src/test/resources/golden/groupby1-3-d57ed4bbfee1ffaffaeba0a4be84c31d rename to sql/hive/src/test/resources/golden/groupby1-3-c8478dac3497697b4375ee35118a5c3e diff --git a/sql/hive/src/test/resources/golden/groupby1-5-dd7bf298b8c921355edd8665c6b0c168 b/sql/hive/src/test/resources/golden/groupby1-5-c9cee6382b64bd3d71177527961b8be2 similarity index 100% rename from sql/hive/src/test/resources/golden/groupby1-5-dd7bf298b8c921355edd8665c6b0c168 rename to sql/hive/src/test/resources/golden/groupby1-5-c9cee6382b64bd3d71177527961b8be2 diff --git a/sql/hive/src/test/resources/golden/groupby1_limit-0-83c59d378571a6e487aa20217bd87817 b/sql/hive/src/test/resources/golden/groupby1_limit-0-be2c0b32a02a1154bfdee1a52530f387 similarity index 100% rename from sql/hive/src/test/resources/golden/groupby1_limit-0-83c59d378571a6e487aa20217bd87817 rename to sql/hive/src/test/resources/golden/groupby1_limit-0-be2c0b32a02a1154bfdee1a52530f387 diff --git a/sql/hive/src/test/resources/golden/groupby1_map-2-83c59d378571a6e487aa20217bd87817 b/sql/hive/src/test/resources/golden/groupby1_map-2-be2c0b32a02a1154bfdee1a52530f387 similarity index 100% rename from sql/hive/src/test/resources/golden/groupby1_map-2-83c59d378571a6e487aa20217bd87817 rename to sql/hive/src/test/resources/golden/groupby1_map-2-be2c0b32a02a1154bfdee1a52530f387 diff --git a/sql/hive/src/test/resources/golden/groupby1_map_skew-2-83c59d378571a6e487aa20217bd87817 b/sql/hive/src/test/resources/golden/groupby1_map_skew-2-be2c0b32a02a1154bfdee1a52530f387 similarity index 100% rename from sql/hive/src/test/resources/golden/groupby1_map_skew-2-83c59d378571a6e487aa20217bd87817 rename to sql/hive/src/test/resources/golden/groupby1_map_skew-2-be2c0b32a02a1154bfdee1a52530f387 diff --git a/sql/hive/src/test/resources/golden/groupby1_noskew-2-83c59d378571a6e487aa20217bd87817 b/sql/hive/src/test/resources/golden/groupby1_noskew-2-be2c0b32a02a1154bfdee1a52530f387 similarity index 100% rename from sql/hive/src/test/resources/golden/groupby1_noskew-2-83c59d378571a6e487aa20217bd87817 rename to sql/hive/src/test/resources/golden/groupby1_noskew-2-be2c0b32a02a1154bfdee1a52530f387 diff --git a/sql/hive/src/test/resources/golden/groupby2_limit-0-83c59d378571a6e487aa20217bd87817 b/sql/hive/src/test/resources/golden/groupby2_limit-0-be2c0b32a02a1154bfdee1a52530f387 similarity index 100% rename from sql/hive/src/test/resources/golden/groupby2_limit-0-83c59d378571a6e487aa20217bd87817 rename to sql/hive/src/test/resources/golden/groupby2_limit-0-be2c0b32a02a1154bfdee1a52530f387 diff --git a/sql/hive/src/test/resources/golden/groupby2_map-2-83c59d378571a6e487aa20217bd87817 
b/sql/hive/src/test/resources/golden/groupby2_map-2-be2c0b32a02a1154bfdee1a52530f387 similarity index 100% rename from sql/hive/src/test/resources/golden/groupby2_map-2-83c59d378571a6e487aa20217bd87817 rename to sql/hive/src/test/resources/golden/groupby2_map-2-be2c0b32a02a1154bfdee1a52530f387 diff --git a/sql/hive/src/test/resources/golden/groupby2_map_skew-2-83c59d378571a6e487aa20217bd87817 b/sql/hive/src/test/resources/golden/groupby2_map_skew-2-be2c0b32a02a1154bfdee1a52530f387 similarity index 100% rename from sql/hive/src/test/resources/golden/groupby2_map_skew-2-83c59d378571a6e487aa20217bd87817 rename to sql/hive/src/test/resources/golden/groupby2_map_skew-2-be2c0b32a02a1154bfdee1a52530f387 diff --git a/sql/hive/src/test/resources/golden/groupby2_noskew-2-83c59d378571a6e487aa20217bd87817 b/sql/hive/src/test/resources/golden/groupby2_noskew-2-be2c0b32a02a1154bfdee1a52530f387 similarity index 100% rename from sql/hive/src/test/resources/golden/groupby2_noskew-2-83c59d378571a6e487aa20217bd87817 rename to sql/hive/src/test/resources/golden/groupby2_noskew-2-be2c0b32a02a1154bfdee1a52530f387 diff --git a/sql/hive/src/test/resources/golden/groupby4_map-2-83c59d378571a6e487aa20217bd87817 b/sql/hive/src/test/resources/golden/groupby4_map-2-be2c0b32a02a1154bfdee1a52530f387 similarity index 100% rename from sql/hive/src/test/resources/golden/groupby4_map-2-83c59d378571a6e487aa20217bd87817 rename to sql/hive/src/test/resources/golden/groupby4_map-2-be2c0b32a02a1154bfdee1a52530f387 diff --git a/sql/hive/src/test/resources/golden/groupby4_map_skew-2-83c59d378571a6e487aa20217bd87817 b/sql/hive/src/test/resources/golden/groupby4_map_skew-2-be2c0b32a02a1154bfdee1a52530f387 similarity index 100% rename from sql/hive/src/test/resources/golden/groupby4_map_skew-2-83c59d378571a6e487aa20217bd87817 rename to sql/hive/src/test/resources/golden/groupby4_map_skew-2-be2c0b32a02a1154bfdee1a52530f387 diff --git a/sql/hive/src/test/resources/golden/groupby4_noskew-2-83c59d378571a6e487aa20217bd87817 b/sql/hive/src/test/resources/golden/groupby4_noskew-2-be2c0b32a02a1154bfdee1a52530f387 similarity index 100% rename from sql/hive/src/test/resources/golden/groupby4_noskew-2-83c59d378571a6e487aa20217bd87817 rename to sql/hive/src/test/resources/golden/groupby4_noskew-2-be2c0b32a02a1154bfdee1a52530f387 diff --git a/sql/hive/src/test/resources/golden/groupby5_map-2-83c59d378571a6e487aa20217bd87817 b/sql/hive/src/test/resources/golden/groupby5_map-2-be2c0b32a02a1154bfdee1a52530f387 similarity index 100% rename from sql/hive/src/test/resources/golden/groupby5_map-2-83c59d378571a6e487aa20217bd87817 rename to sql/hive/src/test/resources/golden/groupby5_map-2-be2c0b32a02a1154bfdee1a52530f387 diff --git a/sql/hive/src/test/resources/golden/groupby5_map_skew-2-83c59d378571a6e487aa20217bd87817 b/sql/hive/src/test/resources/golden/groupby5_map_skew-2-be2c0b32a02a1154bfdee1a52530f387 similarity index 100% rename from sql/hive/src/test/resources/golden/groupby5_map_skew-2-83c59d378571a6e487aa20217bd87817 rename to sql/hive/src/test/resources/golden/groupby5_map_skew-2-be2c0b32a02a1154bfdee1a52530f387 diff --git a/sql/hive/src/test/resources/golden/groupby5_noskew-2-83c59d378571a6e487aa20217bd87817 b/sql/hive/src/test/resources/golden/groupby5_noskew-2-be2c0b32a02a1154bfdee1a52530f387 similarity index 100% rename from sql/hive/src/test/resources/golden/groupby5_noskew-2-83c59d378571a6e487aa20217bd87817 rename to sql/hive/src/test/resources/golden/groupby5_noskew-2-be2c0b32a02a1154bfdee1a52530f387 diff --git 
a/sql/hive/src/test/resources/golden/groupby6_map-2-83c59d378571a6e487aa20217bd87817 b/sql/hive/src/test/resources/golden/groupby6_map-2-be2c0b32a02a1154bfdee1a52530f387 similarity index 100% rename from sql/hive/src/test/resources/golden/groupby6_map-2-83c59d378571a6e487aa20217bd87817 rename to sql/hive/src/test/resources/golden/groupby6_map-2-be2c0b32a02a1154bfdee1a52530f387 diff --git a/sql/hive/src/test/resources/golden/groupby6_map_skew-2-83c59d378571a6e487aa20217bd87817 b/sql/hive/src/test/resources/golden/groupby6_map_skew-2-be2c0b32a02a1154bfdee1a52530f387 similarity index 100% rename from sql/hive/src/test/resources/golden/groupby6_map_skew-2-83c59d378571a6e487aa20217bd87817 rename to sql/hive/src/test/resources/golden/groupby6_map_skew-2-be2c0b32a02a1154bfdee1a52530f387 diff --git a/sql/hive/src/test/resources/golden/groupby6_noskew-2-83c59d378571a6e487aa20217bd87817 b/sql/hive/src/test/resources/golden/groupby6_noskew-2-be2c0b32a02a1154bfdee1a52530f387 similarity index 100% rename from sql/hive/src/test/resources/golden/groupby6_noskew-2-83c59d378571a6e487aa20217bd87817 rename to sql/hive/src/test/resources/golden/groupby6_noskew-2-be2c0b32a02a1154bfdee1a52530f387 diff --git a/sql/hive/src/test/resources/golden/groupby7_map-3-83c59d378571a6e487aa20217bd87817 b/sql/hive/src/test/resources/golden/groupby7_map-3-be2c0b32a02a1154bfdee1a52530f387 similarity index 100% rename from sql/hive/src/test/resources/golden/groupby7_map-3-83c59d378571a6e487aa20217bd87817 rename to sql/hive/src/test/resources/golden/groupby7_map-3-be2c0b32a02a1154bfdee1a52530f387 diff --git a/sql/hive/src/test/resources/golden/groupby7_map_multi_single_reducer-2-83c59d378571a6e487aa20217bd87817 b/sql/hive/src/test/resources/golden/groupby7_map_multi_single_reducer-2-be2c0b32a02a1154bfdee1a52530f387 similarity index 100% rename from sql/hive/src/test/resources/golden/groupby7_map_multi_single_reducer-2-83c59d378571a6e487aa20217bd87817 rename to sql/hive/src/test/resources/golden/groupby7_map_multi_single_reducer-2-be2c0b32a02a1154bfdee1a52530f387 diff --git a/sql/hive/src/test/resources/golden/groupby7_map_skew-2-83c59d378571a6e487aa20217bd87817 b/sql/hive/src/test/resources/golden/groupby7_map_skew-2-be2c0b32a02a1154bfdee1a52530f387 similarity index 100% rename from sql/hive/src/test/resources/golden/groupby7_map_skew-2-83c59d378571a6e487aa20217bd87817 rename to sql/hive/src/test/resources/golden/groupby7_map_skew-2-be2c0b32a02a1154bfdee1a52530f387 diff --git a/sql/hive/src/test/resources/golden/groupby7_noskew-3-83c59d378571a6e487aa20217bd87817 b/sql/hive/src/test/resources/golden/groupby7_noskew-3-be2c0b32a02a1154bfdee1a52530f387 similarity index 100% rename from sql/hive/src/test/resources/golden/groupby7_noskew-3-83c59d378571a6e487aa20217bd87817 rename to sql/hive/src/test/resources/golden/groupby7_noskew-3-be2c0b32a02a1154bfdee1a52530f387 diff --git a/sql/hive/src/test/resources/golden/groupby7_noskew_multi_single_reducer-2-83c59d378571a6e487aa20217bd87817 b/sql/hive/src/test/resources/golden/groupby7_noskew_multi_single_reducer-2-be2c0b32a02a1154bfdee1a52530f387 similarity index 100% rename from sql/hive/src/test/resources/golden/groupby7_noskew_multi_single_reducer-2-83c59d378571a6e487aa20217bd87817 rename to sql/hive/src/test/resources/golden/groupby7_noskew_multi_single_reducer-2-be2c0b32a02a1154bfdee1a52530f387 diff --git a/sql/hive/src/test/resources/golden/groupby8_map-2-83c59d378571a6e487aa20217bd87817 b/sql/hive/src/test/resources/golden/groupby8_map-2-be2c0b32a02a1154bfdee1a52530f387 similarity 
index 100% rename from sql/hive/src/test/resources/golden/groupby8_map-2-83c59d378571a6e487aa20217bd87817 rename to sql/hive/src/test/resources/golden/groupby8_map-2-be2c0b32a02a1154bfdee1a52530f387 diff --git a/sql/hive/src/test/resources/golden/groupby8_map_skew-2-83c59d378571a6e487aa20217bd87817 b/sql/hive/src/test/resources/golden/groupby8_map_skew-2-be2c0b32a02a1154bfdee1a52530f387 similarity index 100% rename from sql/hive/src/test/resources/golden/groupby8_map_skew-2-83c59d378571a6e487aa20217bd87817 rename to sql/hive/src/test/resources/golden/groupby8_map_skew-2-be2c0b32a02a1154bfdee1a52530f387 diff --git a/sql/hive/src/test/resources/golden/groupby8_noskew-2-83c59d378571a6e487aa20217bd87817 b/sql/hive/src/test/resources/golden/groupby8_noskew-2-be2c0b32a02a1154bfdee1a52530f387 similarity index 100% rename from sql/hive/src/test/resources/golden/groupby8_noskew-2-83c59d378571a6e487aa20217bd87817 rename to sql/hive/src/test/resources/golden/groupby8_noskew-2-be2c0b32a02a1154bfdee1a52530f387 diff --git a/sql/hive/src/test/resources/golden/groupby_map_ppr-2-83c59d378571a6e487aa20217bd87817 b/sql/hive/src/test/resources/golden/groupby_map_ppr-2-be2c0b32a02a1154bfdee1a52530f387 similarity index 100% rename from sql/hive/src/test/resources/golden/groupby_map_ppr-2-83c59d378571a6e487aa20217bd87817 rename to sql/hive/src/test/resources/golden/groupby_map_ppr-2-be2c0b32a02a1154bfdee1a52530f387 diff --git a/sql/hive/src/test/resources/golden/input12_hadoop20-0-db1cd54a4cb36de2087605f32e41824f b/sql/hive/src/test/resources/golden/input12_hadoop20-0-2b9ccaa793eae0e73bf76335d3d6880 similarity index 100% rename from sql/hive/src/test/resources/golden/input12_hadoop20-0-db1cd54a4cb36de2087605f32e41824f rename to sql/hive/src/test/resources/golden/input12_hadoop20-0-2b9ccaa793eae0e73bf76335d3d6880 diff --git a/sql/hive/src/test/resources/golden/input_testsequencefile-0-68975193b30cb34102b380e647d8d5f4 b/sql/hive/src/test/resources/golden/input_testsequencefile-0-dd959af1968381d0ed90178d349b01a7 similarity index 100% rename from sql/hive/src/test/resources/golden/input_testsequencefile-0-68975193b30cb34102b380e647d8d5f4 rename to sql/hive/src/test/resources/golden/input_testsequencefile-0-dd959af1968381d0ed90178d349b01a7 diff --git a/sql/hive/src/test/resources/golden/input_testsequencefile-1-1c0f3be2d837dee49312e0a80440447e b/sql/hive/src/test/resources/golden/input_testsequencefile-1-ddbb8d5e5dc0988bda96ac2b4aec8f94 similarity index 100% rename from sql/hive/src/test/resources/golden/input_testsequencefile-1-1c0f3be2d837dee49312e0a80440447e rename to sql/hive/src/test/resources/golden/input_testsequencefile-1-ddbb8d5e5dc0988bda96ac2b4aec8f94 diff --git a/sql/hive/src/test/resources/golden/input_testsequencefile-5-3708198aac609695b22e19e89306034c b/sql/hive/src/test/resources/golden/input_testsequencefile-5-25715870c569b0f8c3d483e3a38b3199 similarity index 100% rename from sql/hive/src/test/resources/golden/input_testsequencefile-5-3708198aac609695b22e19e89306034c rename to sql/hive/src/test/resources/golden/input_testsequencefile-5-25715870c569b0f8c3d483e3a38b3199 diff --git a/sql/hive/src/test/resources/golden/join14_hadoop20-1-db1cd54a4cb36de2087605f32e41824f b/sql/hive/src/test/resources/golden/join14_hadoop20-1-2b9ccaa793eae0e73bf76335d3d6880 similarity index 100% rename from sql/hive/src/test/resources/golden/join14_hadoop20-1-db1cd54a4cb36de2087605f32e41824f rename to sql/hive/src/test/resources/golden/join14_hadoop20-1-2b9ccaa793eae0e73bf76335d3d6880 diff --git 
a/sql/hive/src/test/resources/golden/leftsemijoin_mr-7-8e9c2969b999557363e40f9ebb3f6d7c b/sql/hive/src/test/resources/golden/leftsemijoin_mr-7-6b9861b999092f1ea4fa1fd27a666af6 similarity index 100% rename from sql/hive/src/test/resources/golden/leftsemijoin_mr-7-8e9c2969b999557363e40f9ebb3f6d7c rename to sql/hive/src/test/resources/golden/leftsemijoin_mr-7-6b9861b999092f1ea4fa1fd27a666af6 diff --git a/sql/hive/src/test/resources/golden/merge2-2-c95dc367df88c9e5cf77157f29ba2daf b/sql/hive/src/test/resources/golden/merge2-2-6142f47d3fcdd4323162014d5eb35e07 similarity index 100% rename from sql/hive/src/test/resources/golden/merge2-2-c95dc367df88c9e5cf77157f29ba2daf rename to sql/hive/src/test/resources/golden/merge2-2-6142f47d3fcdd4323162014d5eb35e07 diff --git a/sql/hive/src/test/resources/golden/merge2-3-6e53a3ac93113f20db3a12f1dcf30e86 b/sql/hive/src/test/resources/golden/merge2-3-10266e3d5dd4c841c0d65030b1edba7c similarity index 100% rename from sql/hive/src/test/resources/golden/merge2-3-6e53a3ac93113f20db3a12f1dcf30e86 rename to sql/hive/src/test/resources/golden/merge2-3-10266e3d5dd4c841c0d65030b1edba7c diff --git a/sql/hive/src/test/resources/golden/merge2-4-84967075baa3e56fff2a23f8ab9ba076 b/sql/hive/src/test/resources/golden/merge2-4-9cbd6d400fb6c3cd09010e3dbd76601 similarity index 100% rename from sql/hive/src/test/resources/golden/merge2-4-84967075baa3e56fff2a23f8ab9ba076 rename to sql/hive/src/test/resources/golden/merge2-4-9cbd6d400fb6c3cd09010e3dbd76601 diff --git a/sql/hive/src/test/resources/golden/merge2-5-2ee5d706fe3a3bcc38b795f6e94970ea b/sql/hive/src/test/resources/golden/merge2-5-1ba2d6f3bb3348da3fee7fab4f283f34 similarity index 100% rename from sql/hive/src/test/resources/golden/merge2-5-2ee5d706fe3a3bcc38b795f6e94970ea rename to sql/hive/src/test/resources/golden/merge2-5-1ba2d6f3bb3348da3fee7fab4f283f34 diff --git a/sql/hive/src/test/resources/golden/parallel-0-23a4feaede17467a8cc26e4d86ec30f9 b/sql/hive/src/test/resources/golden/parallel-0-6dc30e2de057022e63bd2a645fbec4c2 similarity index 100% rename from sql/hive/src/test/resources/golden/parallel-0-23a4feaede17467a8cc26e4d86ec30f9 rename to sql/hive/src/test/resources/golden/parallel-0-6dc30e2de057022e63bd2a645fbec4c2 diff --git a/sql/hive/src/test/resources/golden/rcfile_lazydecompress-11-3708198aac609695b22e19e89306034c b/sql/hive/src/test/resources/golden/rcfile_lazydecompress-11-25715870c569b0f8c3d483e3a38b3199 similarity index 100% rename from sql/hive/src/test/resources/golden/rcfile_lazydecompress-11-3708198aac609695b22e19e89306034c rename to sql/hive/src/test/resources/golden/rcfile_lazydecompress-11-25715870c569b0f8c3d483e3a38b3199 diff --git a/sql/hive/src/test/resources/golden/rcfile_lazydecompress-5-68975193b30cb34102b380e647d8d5f4 b/sql/hive/src/test/resources/golden/rcfile_lazydecompress-5-dd959af1968381d0ed90178d349b01a7 similarity index 100% rename from sql/hive/src/test/resources/golden/rcfile_lazydecompress-5-68975193b30cb34102b380e647d8d5f4 rename to sql/hive/src/test/resources/golden/rcfile_lazydecompress-5-dd959af1968381d0ed90178d349b01a7 diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/auto_join14_hadoop20.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/auto_join14_hadoop20.q index 235b7c1b3fcd2..6a9a20f3207b5 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/auto_join14_hadoop20.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/auto_join14_hadoop20.q @@ -5,7 +5,7 @@ set hive.auto.convert.join = true; 
CREATE TABLE dest1(c1 INT, c2 STRING) STORED AS TEXTFILE; -set mapred.job.tracker=localhost:58; +set mapreduce.jobtracker.address=localhost:58; set hive.exec.mode.local.auto=true; explain diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/bucket5.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/bucket5.q index 877f8a50a0e35..87f6eca4dd4e3 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/bucket5.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/bucket5.q @@ -4,7 +4,7 @@ set hive.enforce.sorting = true; set hive.exec.reducers.max = 1; set hive.merge.mapfiles = true; set hive.merge.mapredfiles = true; -set mapred.reduce.tasks = 2; +set mapreduce.job.reduces = 2; -- Tests that when a multi insert inserts into a bucketed table and a table which is not bucketed -- the bucketed table is not merged and the table which is not bucketed is diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/bucket_num_reducers.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/bucket_num_reducers.q index 37ae6cc7adeae..84fe3919d7a68 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/bucket_num_reducers.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/bucket_num_reducers.q @@ -1,6 +1,6 @@ set hive.enforce.bucketing = true; set hive.exec.mode.local.auto=false; -set mapred.reduce.tasks = 10; +set mapreduce.job.reduces = 10; -- This test sets number of mapred tasks to 10 for a database with 50 buckets, -- and uses a post-hook to confirm that 10 tasks were created diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/bucketizedhiveinputformat.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/bucketizedhiveinputformat.q index d2e12e82d4a26..ae72f98fa424c 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/bucketizedhiveinputformat.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/bucketizedhiveinputformat.q @@ -1,5 +1,5 @@ set hive.input.format=org.apache.hadoop.hive.ql.io.BucketizedHiveInputFormat; -set mapred.min.split.size = 64; +set mapreduce.input.fileinputformat.split.minsize = 64; CREATE TABLE T1(name STRING) STORED AS TEXTFILE; diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/combine1.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/combine1.q index 86abf0996057b..5ecfc21724788 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/combine1.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/combine1.q @@ -1,11 +1,11 @@ set hive.exec.compress.output = true; set hive.input.format=org.apache.hadoop.hive.ql.io.CombineHiveInputFormat; -set mapred.min.split.size=256; -set mapred.min.split.size.per.node=256; -set mapred.min.split.size.per.rack=256; -set mapred.max.split.size=256; +set mapreduce.input.fileinputformat.split.minsize=256; +set mapreduce.input.fileinputformat.split.minsize.per.node=256; +set mapreduce.input.fileinputformat.split.minsize.per.rack=256; +set mapreduce.input.fileinputformat.split.maxsize=256; -set mapred.output.compression.codec=org.apache.hadoop.io.compress.GzipCodec; +set mapreduce.output.fileoutputformat.compress.codec=org.apache.hadoop.io.compress.GzipCodec; create table combine1_1(key string, value string) stored as textfile; diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/combine2.q 
b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/combine2.q index cfd9856f0868a..acd0dd5e5bc96 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/combine2.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/combine2.q @@ -1,10 +1,10 @@ USE default; set hive.input.format=org.apache.hadoop.hive.ql.io.CombineHiveInputFormat; -set mapred.min.split.size=256; -set mapred.min.split.size.per.node=256; -set mapred.min.split.size.per.rack=256; -set mapred.max.split.size=256; +set mapreduce.input.fileinputformat.split.minsize=256; +set mapreduce.input.fileinputformat.split.minsize.per.node=256; +set mapreduce.input.fileinputformat.split.minsize.per.rack=256; +set mapreduce.input.fileinputformat.split.maxsize=256; set hive.exec.dynamic.partition=true; set hive.exec.dynamic.partition.mode=nonstrict; set mapred.cache.shared.enabled=false; @@ -18,7 +18,7 @@ set hive.merge.smallfiles.avgsize=0; create table combine2(key string) partitioned by (value string); -- EXCLUDE_HADOOP_MAJOR_VERSIONS(0.20, 0.20S) --- This test sets mapred.max.split.size=256 and hive.merge.smallfiles.avgsize=0 +-- This test sets mapreduce.input.fileinputformat.split.maxsize=256 and hive.merge.smallfiles.avgsize=0 -- in an attempt to force the generation of multiple splits and multiple output files. -- However, Hadoop 0.20 is incapable of generating splits smaller than the block size -- when using CombineFileInputFormat, so only one split is generated. This has a diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/combine2_hadoop20.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/combine2_hadoop20.q index 8f9a59d497536..597d3ae479b97 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/combine2_hadoop20.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/combine2_hadoop20.q @@ -1,10 +1,10 @@ USE default; set hive.input.format=org.apache.hadoop.hive.ql.io.CombineHiveInputFormat; -set mapred.min.split.size=256; -set mapred.min.split.size.per.node=256; -set mapred.min.split.size.per.rack=256; -set mapred.max.split.size=256; +set mapreduce.input.fileinputformat.split.minsize=256; +set mapreduce.input.fileinputformat.split.minsize.per.node=256; +set mapreduce.input.fileinputformat.split.minsize.per.rack=256; +set mapreduce.input.fileinputformat.split.maxsize=256; set hive.exec.dynamic.partition=true; set hive.exec.dynamic.partition.mode=nonstrict; set mapred.cache.shared.enabled=false; @@ -17,7 +17,7 @@ set hive.merge.smallfiles.avgsize=0; create table combine2(key string) partitioned by (value string); -- INCLUDE_HADOOP_MAJOR_VERSIONS(0.20, 0.20S) --- This test sets mapred.max.split.size=256 and hive.merge.smallfiles.avgsize=0 +-- This test sets mapreduce.input.fileinputformat.split.maxsize=256 and hive.merge.smallfiles.avgsize=0 -- in an attempt to force the generation of multiple splits and multiple output files. -- However, Hadoop 0.20 is incapable of generating splits smaller than the block size -- when using CombineFileInputFormat, so only one split is generated. 
This has a diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/combine2_win.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/combine2_win.q index f6090bb99b29a..4f7174a1b6365 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/combine2_win.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/combine2_win.q @@ -1,8 +1,8 @@ set hive.input.format=org.apache.hadoop.hive.ql.io.CombineHiveInputFormat; -set mapred.min.split.size=256; -set mapred.min.split.size.per.node=256; -set mapred.min.split.size.per.rack=256; -set mapred.max.split.size=256; +set mapreduce.input.fileinputformat.split.minsize=256; +set mapreduce.input.fileinputformat.split.minsize.per.node=256; +set mapreduce.input.fileinputformat.split.minsize.per.rack=256; +set mapreduce.input.fileinputformat.split.maxsize=256; set hive.exec.dynamic.partition=true; set hive.exec.dynamic.partition.mode=nonstrict; set mapred.cache.shared.enabled=false; diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/combine3.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/combine3.q index c9afc91bb4561..35dd442027b45 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/combine3.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/combine3.q @@ -1,9 +1,9 @@ set hive.exec.compress.output = true; set hive.input.format=org.apache.hadoop.hive.ql.io.CombineHiveInputFormat; -set mapred.min.split.size=256; -set mapred.min.split.size.per.node=256; -set mapred.min.split.size.per.rack=256; -set mapred.max.split.size=256; +set mapreduce.input.fileinputformat.split.minsize=256; +set mapreduce.input.fileinputformat.split.minsize.per.node=256; +set mapreduce.input.fileinputformat.split.minsize.per.rack=256; +set mapreduce.input.fileinputformat.split.maxsize=256; drop table combine_3_srcpart_seq_rc; diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/create_1.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/create_1.q index f348e5902263a..5e51d11864dd2 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/create_1.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/create_1.q @@ -1,4 +1,4 @@ -set fs.default.name=invalidscheme:///; +set fs.defaultFS=invalidscheme:///; CREATE TABLE table1 (a STRING, b STRING) STORED AS TEXTFILE; DESCRIBE table1; diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/ctas_hadoop20.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/ctas_hadoop20.q index f39689de03a55..979c9072303c4 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/ctas_hadoop20.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/ctas_hadoop20.q @@ -49,7 +49,7 @@ describe formatted nzhang_CTAS4; explain extended create table nzhang_ctas5 row format delimited fields terminated by ',' lines terminated by '\012' stored as textfile as select key, value from src sort by key, value limit 10; -set mapred.job.tracker=localhost:58; +set mapreduce.jobtracker.address=localhost:58; set hive.exec.mode.local.auto=true; create table nzhang_ctas5 row format delimited fields terminated by ',' lines terminated by '\012' stored as textfile as select key, value from src sort by key, value limit 10; diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby1.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby1.q index 
1275eab281f42..0d75857e54e54 100755 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby1.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby1.q @@ -3,12 +3,12 @@ set hive.groupby.skewindata=true; CREATE TABLE dest_g1(key INT, value DOUBLE) STORED AS TEXTFILE; -set fs.default.name=invalidscheme:///; +set fs.defaultFS=invalidscheme:///; EXPLAIN FROM src INSERT OVERWRITE TABLE dest_g1 SELECT src.key, sum(substr(src.value,5)) GROUP BY src.key; -set fs.default.name=file:///; +set fs.defaultFS=file:///; FROM src INSERT OVERWRITE TABLE dest_g1 SELECT src.key, sum(substr(src.value,5)) GROUP BY src.key; diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby1_limit.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby1_limit.q index 55133332a8662..bbb2859a9d452 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby1_limit.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby1_limit.q @@ -1,4 +1,4 @@ -set mapred.reduce.tasks=31; +set mapreduce.job.reduces=31; CREATE TABLE dest1(key INT, value DOUBLE) STORED AS TEXTFILE; diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby1_map.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby1_map.q index dde37dfd47145..7883d948d0672 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby1_map.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby1_map.q @@ -1,6 +1,6 @@ set hive.map.aggr=true; set hive.groupby.skewindata=false; -set mapred.reduce.tasks=31; +set mapreduce.job.reduces=31; CREATE TABLE dest1(key INT, value DOUBLE) STORED AS TEXTFILE; diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby1_map_skew.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby1_map_skew.q index f346cb7e90147..a5ac3762ce798 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby1_map_skew.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby1_map_skew.q @@ -1,6 +1,6 @@ set hive.map.aggr=true; set hive.groupby.skewindata=true; -set mapred.reduce.tasks=31; +set mapreduce.job.reduces=31; CREATE TABLE dest1(key INT, value DOUBLE) STORED AS TEXTFILE; diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby1_noskew.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby1_noskew.q index c587b5f658f68..6341eefb50434 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby1_noskew.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby1_noskew.q @@ -1,6 +1,6 @@ set hive.map.aggr=false; set hive.groupby.skewindata=false; -set mapred.reduce.tasks=31; +set mapreduce.job.reduces=31; CREATE TABLE dest_g1(key INT, value DOUBLE) STORED AS TEXTFILE; diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby2_limit.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby2_limit.q index 30499248cac15..df4693446d6c8 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby2_limit.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby2_limit.q @@ -1,4 +1,4 @@ -set mapred.reduce.tasks=31; +set mapreduce.job.reduces=31; EXPLAIN SELECT src.key, sum(substr(src.value,5)) FROM src GROUP BY src.key ORDER BY src.key LIMIT 5; diff --git 
a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby2_map.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby2_map.q index 794ec758e9edb..7b6e175c2df0a 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby2_map.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby2_map.q @@ -1,6 +1,6 @@ set hive.map.aggr=true; set hive.groupby.skewindata=false; -set mapred.reduce.tasks=31; +set mapreduce.job.reduces=31; CREATE TABLE dest1(key STRING, c1 INT, c2 STRING) STORED AS TEXTFILE; diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby2_map_multi_distinct.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby2_map_multi_distinct.q index 55d1a34b3c921..3aeae0d5c33d6 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby2_map_multi_distinct.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby2_map_multi_distinct.q @@ -1,6 +1,6 @@ set hive.map.aggr=true; set hive.groupby.skewindata=false; -set mapred.reduce.tasks=31; +set mapreduce.job.reduces=31; CREATE TABLE dest1(key STRING, c1 INT, c2 STRING, c3 INT, c4 INT) STORED AS TEXTFILE; diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby2_map_skew.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby2_map_skew.q index 39a2a178e3a5e..998156d05f99a 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby2_map_skew.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby2_map_skew.q @@ -1,6 +1,6 @@ set hive.map.aggr=true; set hive.groupby.skewindata=true; -set mapred.reduce.tasks=31; +set mapreduce.job.reduces=31; CREATE TABLE dest1(key STRING, c1 INT, c2 STRING) STORED AS TEXTFILE; diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby2_noskew.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby2_noskew.q index 6d7cb61e2d44a..fab4f5d097f16 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby2_noskew.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby2_noskew.q @@ -1,6 +1,6 @@ set hive.map.aggr=false; set hive.groupby.skewindata=false; -set mapred.reduce.tasks=31; +set mapreduce.job.reduces=31; CREATE TABLE dest_g2(key STRING, c1 INT, c2 STRING) STORED AS TEXTFILE; diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby2_noskew_multi_distinct.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby2_noskew_multi_distinct.q index b2450c9ea04e1..9ef556cdc5834 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby2_noskew_multi_distinct.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby2_noskew_multi_distinct.q @@ -1,6 +1,6 @@ set hive.map.aggr=false; set hive.groupby.skewindata=false; -set mapred.reduce.tasks=31; +set mapreduce.job.reduces=31; CREATE TABLE dest_g2(key STRING, c1 INT, c2 STRING, c3 INT, c4 INT) STORED AS TEXTFILE; diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby3_map.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby3_map.q index 7ecc71dfab64a..36ba5d89c0f72 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby3_map.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby3_map.q @@ -1,6 +1,6 @@ set hive.map.aggr=true; set 
hive.groupby.skewindata=false; -set mapred.reduce.tasks=31; +set mapreduce.job.reduces=31; CREATE TABLE dest1(c1 DOUBLE, c2 DOUBLE, c3 DOUBLE, c4 DOUBLE, c5 DOUBLE, c6 DOUBLE, c7 DOUBLE, c8 DOUBLE, c9 DOUBLE) STORED AS TEXTFILE; diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby3_map_multi_distinct.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby3_map_multi_distinct.q index 50243beca9efa..6f0a9635a284f 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby3_map_multi_distinct.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby3_map_multi_distinct.q @@ -1,6 +1,6 @@ set hive.map.aggr=true; set hive.groupby.skewindata=false; -set mapred.reduce.tasks=31; +set mapreduce.job.reduces=31; CREATE TABLE dest1(c1 DOUBLE, c2 DOUBLE, c3 DOUBLE, c4 DOUBLE, c5 DOUBLE, c6 DOUBLE, c7 DOUBLE, c8 DOUBLE, c9 DOUBLE, c10 DOUBLE, c11 DOUBLE) STORED AS TEXTFILE; diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby3_map_skew.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby3_map_skew.q index 07d10c2d741d8..64a49e2525edf 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby3_map_skew.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby3_map_skew.q @@ -1,6 +1,6 @@ set hive.map.aggr=true; set hive.groupby.skewindata=true; -set mapred.reduce.tasks=31; +set mapreduce.job.reduces=31; CREATE TABLE dest1(c1 DOUBLE, c2 DOUBLE, c3 DOUBLE, c4 DOUBLE, c5 DOUBLE, c6 DOUBLE, c7 DOUBLE, c8 DOUBLE, c9 DOUBLE) STORED AS TEXTFILE; diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby3_noskew.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby3_noskew.q index d33f12c5744e9..4fd98efd6ef41 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby3_noskew.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby3_noskew.q @@ -1,7 +1,7 @@ set hive.map.aggr=false; set hive.groupby.skewindata=false; -set mapred.reduce.tasks=31; +set mapreduce.job.reduces=31; CREATE TABLE dest1(c1 DOUBLE, c2 DOUBLE, c3 DOUBLE, c4 DOUBLE, c5 DOUBLE, c6 DOUBLE, c7 DOUBLE, c8 DOUBLE, c9 DOUBLE) STORED AS TEXTFILE; diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby3_noskew_multi_distinct.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby3_noskew_multi_distinct.q index 86d8986f1df7d..85ee8ac43e526 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby3_noskew_multi_distinct.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby3_noskew_multi_distinct.q @@ -1,7 +1,7 @@ set hive.map.aggr=false; set hive.groupby.skewindata=false; -set mapred.reduce.tasks=31; +set mapreduce.job.reduces=31; CREATE TABLE dest1(c1 DOUBLE, c2 DOUBLE, c3 DOUBLE, c4 DOUBLE, c5 DOUBLE, c6 DOUBLE, c7 DOUBLE, c8 DOUBLE, c9 DOUBLE, c10 DOUBLE, c11 DOUBLE) STORED AS TEXTFILE; diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby4_map.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby4_map.q index 8ecce23eb8321..d71721875bbff 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby4_map.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby4_map.q @@ -1,6 +1,6 @@ set hive.map.aggr=true; set hive.groupby.skewindata=false; -set mapred.reduce.tasks=31; +set 
mapreduce.job.reduces=31; CREATE TABLE dest1(key INT) STORED AS TEXTFILE; diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby4_map_skew.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby4_map_skew.q index eb2001c6b21b0..d1ecba143d622 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby4_map_skew.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby4_map_skew.q @@ -1,6 +1,6 @@ set hive.map.aggr=true; set hive.groupby.skewindata=true; -set mapred.reduce.tasks=31; +set mapreduce.job.reduces=31; CREATE TABLE dest1(key INT) STORED AS TEXTFILE; diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby4_noskew.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby4_noskew.q index a1ebf90aadfea..63530c262c147 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby4_noskew.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby4_noskew.q @@ -1,7 +1,7 @@ set hive.map.aggr=false; set hive.groupby.skewindata=false; -set mapred.reduce.tasks=31; +set mapreduce.job.reduces=31; CREATE TABLE dest1(c1 STRING) STORED AS TEXTFILE; diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby5_map.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby5_map.q index 4fd6445d7927c..4418bbffec7ab 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby5_map.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby5_map.q @@ -1,6 +1,6 @@ set hive.map.aggr=true; set hive.groupby.skewindata=false; -set mapred.reduce.tasks=31; +set mapreduce.job.reduces=31; CREATE TABLE dest1(key INT) STORED AS TEXTFILE; diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby5_map_skew.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby5_map_skew.q index eccd45dd5b422..ef20dacf05992 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby5_map_skew.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby5_map_skew.q @@ -1,6 +1,6 @@ set hive.map.aggr=true; set hive.groupby.skewindata=true; -set mapred.reduce.tasks=31; +set mapreduce.job.reduces=31; CREATE TABLE dest1(key INT) STORED AS TEXTFILE; diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby5_noskew.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby5_noskew.q index e96568b398d87..17b322b890ff6 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby5_noskew.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby5_noskew.q @@ -1,7 +1,7 @@ set hive.map.aggr=false; set hive.groupby.skewindata=false; -set mapred.reduce.tasks=31; +set mapreduce.job.reduces=31; CREATE TABLE dest1(key INT, value STRING) STORED AS TEXTFILE; diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby6_map.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby6_map.q index ced122fae3f50..bef0eeee0e898 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby6_map.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby6_map.q @@ -1,6 +1,6 @@ set hive.map.aggr=true; set hive.groupby.skewindata=false; -set mapred.reduce.tasks=31; +set mapreduce.job.reduces=31; CREATE TABLE dest1(c1 STRING) STORED AS TEXTFILE; diff --git 
a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby6_map_skew.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby6_map_skew.q index 0d3727b052858..ee93b218ac788 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby6_map_skew.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby6_map_skew.q @@ -1,6 +1,6 @@ set hive.map.aggr=true; set hive.groupby.skewindata=true; -set mapred.reduce.tasks=31; +set mapreduce.job.reduces=31; CREATE TABLE dest1(c1 STRING) STORED AS TEXTFILE; diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby6_noskew.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby6_noskew.q index 466c13222f29f..72fff08decf0f 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby6_noskew.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby6_noskew.q @@ -1,7 +1,7 @@ set hive.map.aggr=false; set hive.groupby.skewindata=false; -set mapred.reduce.tasks=31; +set mapreduce.job.reduces=31; CREATE TABLE dest1(c1 STRING) STORED AS TEXTFILE; diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby7_map.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby7_map.q index 2b8c5db41ea92..75149b140415f 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby7_map.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby7_map.q @@ -1,7 +1,7 @@ set hive.map.aggr=true; set hive.multigroupby.singlereducer=false; set hive.groupby.skewindata=false; -set mapred.reduce.tasks=31; +set mapreduce.job.reduces=31; CREATE TABLE DEST1(key INT, value STRING) STORED AS TEXTFILE; CREATE TABLE DEST2(key INT, value STRING) STORED AS TEXTFILE; diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby7_map_multi_single_reducer.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby7_map_multi_single_reducer.q index 5895ed4599849..7c7829aac2d6e 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby7_map_multi_single_reducer.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby7_map_multi_single_reducer.q @@ -1,6 +1,6 @@ set hive.map.aggr=true; set hive.groupby.skewindata=false; -set mapred.reduce.tasks=31; +set mapreduce.job.reduces=31; CREATE TABLE DEST1(key INT, value STRING) STORED AS TEXTFILE; CREATE TABLE DEST2(key INT, value STRING) STORED AS TEXTFILE; diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby7_map_skew.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby7_map_skew.q index ee6d7bf83084e..905986d417dff 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby7_map_skew.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby7_map_skew.q @@ -1,6 +1,6 @@ set hive.map.aggr=true; set hive.groupby.skewindata=true; -set mapred.reduce.tasks=31; +set mapreduce.job.reduces=31; CREATE TABLE DEST1(key INT, value STRING) STORED AS TEXTFILE; CREATE TABLE DEST2(key INT, value STRING) STORED AS TEXTFILE; diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby7_noskew.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby7_noskew.q index 8c2308e5d75c3..1f63453672a40 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby7_noskew.q +++ 
b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby7_noskew.q @@ -1,7 +1,7 @@ set hive.map.aggr=false; set hive.multigroupby.singlereducer=false; set hive.groupby.skewindata=false; -set mapred.reduce.tasks=31; +set mapreduce.job.reduces=31; CREATE TABLE DEST1(key INT, value STRING) STORED AS TEXTFILE; CREATE TABLE DEST2(key INT, value STRING) STORED AS TEXTFILE; diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby7_noskew_multi_single_reducer.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby7_noskew_multi_single_reducer.q index e673cc61622c8..2ce57e98072f2 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby7_noskew_multi_single_reducer.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby7_noskew_multi_single_reducer.q @@ -1,6 +1,6 @@ set hive.map.aggr=false; set hive.groupby.skewindata=false; -set mapred.reduce.tasks=31; +set mapreduce.job.reduces=31; CREATE TABLE DEST1(key INT, value STRING) STORED AS TEXTFILE; CREATE TABLE DEST2(key INT, value STRING) STORED AS TEXTFILE; diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby8_map.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby8_map.q index 0252e993363aa..9def7d64721eb 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby8_map.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby8_map.q @@ -1,6 +1,6 @@ set hive.map.aggr=true; set hive.groupby.skewindata=false; -set mapred.reduce.tasks=31; +set mapreduce.job.reduces=31; CREATE TABLE DEST1(key INT, value STRING) STORED AS TEXTFILE; CREATE TABLE DEST2(key INT, value STRING) STORED AS TEXTFILE; diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby8_map_skew.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby8_map_skew.q index b5e1f63a45257..788bc683697d6 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby8_map_skew.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby8_map_skew.q @@ -1,6 +1,6 @@ set hive.map.aggr=true; set hive.groupby.skewindata=true; -set mapred.reduce.tasks=31; +set mapreduce.job.reduces=31; CREATE TABLE DEST1(key INT, value STRING) STORED AS TEXTFILE; CREATE TABLE DEST2(key INT, value STRING) STORED AS TEXTFILE; diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby8_noskew.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby8_noskew.q index da85504ca18c6..17885c56b3f1f 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby8_noskew.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby8_noskew.q @@ -1,7 +1,7 @@ set hive.map.aggr=false; set hive.groupby.skewindata=false; -set mapred.reduce.tasks=31; +set mapreduce.job.reduces=31; CREATE TABLE DEST1(key INT, value STRING) STORED AS TEXTFILE; CREATE TABLE DEST2(key INT, value STRING) STORED AS TEXTFILE; diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby_map_ppr.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby_map_ppr.q index 4a199365cf968..9cb98aa909e1b 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby_map_ppr.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby_map_ppr.q @@ -1,6 +1,6 @@ set hive.map.aggr=true; set hive.groupby.skewindata=false; -set 
mapred.reduce.tasks=31; +set mapreduce.job.reduces=31; CREATE TABLE dest1(key STRING, c1 INT, c2 STRING) STORED AS TEXTFILE; diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby_map_ppr_multi_distinct.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby_map_ppr_multi_distinct.q index cb3ee82918611..841df75af18bb 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby_map_ppr_multi_distinct.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby_map_ppr_multi_distinct.q @@ -1,6 +1,6 @@ set hive.map.aggr=true; set hive.groupby.skewindata=false; -set mapred.reduce.tasks=31; +set mapreduce.job.reduces=31; CREATE TABLE dest1(key STRING, c1 INT, c2 STRING, C3 INT, c4 INT) STORED AS TEXTFILE; diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby_sort_1.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby_sort_1.q index 7401a9ca1d9bd..cdf4bb1cac9dc 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby_sort_1.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby_sort_1.q @@ -248,7 +248,7 @@ SELECT * FROM outputTbl4 ORDER BY key1, key2, key3; set hive.map.aggr=true; set hive.multigroupby.singlereducer=false; -set mapred.reduce.tasks=31; +set mapreduce.job.reduces=31; CREATE TABLE DEST1(key INT, cnt INT); CREATE TABLE DEST2(key INT, val STRING, cnt INT); diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby_sort_skew_1.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby_sort_skew_1.q index db0faa04da0ec..1c23fad76eff7 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby_sort_skew_1.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby_sort_skew_1.q @@ -249,7 +249,7 @@ SELECT * FROM outputTbl4 ORDER BY key1, key2, key3; set hive.map.aggr=true; set hive.multigroupby.singlereducer=false; -set mapred.reduce.tasks=31; +set mapreduce.job.reduces=31; CREATE TABLE DEST1(key INT, cnt INT); CREATE TABLE DEST2(key INT, val STRING, cnt INT); diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/hook_context_cs.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/hook_context_cs.q index 94ba14802f015..996c9d99f0b98 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/hook_context_cs.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/hook_context_cs.q @@ -5,7 +5,7 @@ ALTER TABLE vcsc ADD partition (ds='dummy') location '${system:test.tmp.dir}/Ver set hive.exec.pre.hooks=org.apache.hadoop.hive.ql.hooks.VerifyContentSummaryCacheHook; SELECT a.c, b.c FROM vcsc a JOIN vcsc b ON a.ds = 'dummy' AND b.ds = 'dummy' AND a.c = b.c; -set mapred.job.tracker=local; +set mapreduce.jobtracker.address=local; set hive.exec.pre.hooks = ; set hive.exec.post.hooks=org.apache.hadoop.hive.ql.hooks.VerifyContentSummaryCacheHook; SELECT a.c, b.c FROM vcsc a JOIN vcsc b ON a.ds = 'dummy' AND b.ds = 'dummy' AND a.c = b.c; diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/infer_bucket_sort_dyn_part.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/infer_bucket_sort_dyn_part.q index 728b8cc4a9497..5d3c6c43c6408 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/infer_bucket_sort_dyn_part.q +++ 
b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/infer_bucket_sort_dyn_part.q @@ -63,7 +63,7 @@ set hive.merge.mapredfiles=true; set hive.merge.smallfiles.avgsize=200; set hive.exec.compress.output=false; set hive.exec.dynamic.partition=true; -set mapred.reduce.tasks=2; +set mapreduce.job.reduces=2; -- Tests dynamic partitions where bucketing/sorting can be inferred, but some partitions are -- merged and some are moved. Currently neither should be bucketed or sorted, in the future, diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/infer_bucket_sort_merge.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/infer_bucket_sort_merge.q index 41c1a13980cfe..aa49b0dc64c46 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/infer_bucket_sort_merge.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/infer_bucket_sort_merge.q @@ -1,7 +1,7 @@ set hive.exec.infer.bucket.sort=true; set hive.exec.infer.bucket.sort.num.buckets.power.two=true; set hive.merge.mapredfiles=true; -set mapred.reduce.tasks=2; +set mapreduce.job.reduces=2; -- This tests inferring how data is bucketed/sorted from the operators in the reducer -- and populating that information in partitions' metadata. In particular, those cases diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/infer_bucket_sort_num_buckets.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/infer_bucket_sort_num_buckets.q index 2255bdb34913d..3a454f77bc4dd 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/infer_bucket_sort_num_buckets.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/infer_bucket_sort_num_buckets.q @@ -1,7 +1,7 @@ set hive.exec.infer.bucket.sort=true; set hive.merge.mapfiles=false; set hive.merge.mapredfiles=false; -set mapred.reduce.tasks=2; +set mapreduce.job.reduces=2; CREATE TABLE test_table (key INT, value STRING) PARTITIONED BY (ds STRING, hr STRING); diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/input12_hadoop20.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/input12_hadoop20.q index 318cd378db137..31e99e8d94644 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/input12_hadoop20.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/input12_hadoop20.q @@ -1,4 +1,4 @@ -set mapred.job.tracker=localhost:58; +set mapreduce.jobtracker.address=localhost:58; set hive.exec.mode.local.auto=true; -- INCLUDE_HADOOP_MAJOR_VERSIONS(0.20, 0.20S) diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/input39_hadoop20.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/input39_hadoop20.q index 29e9fae1da9e3..362c164176a9c 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/input39_hadoop20.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/input39_hadoop20.q @@ -15,7 +15,7 @@ select key, value from src; set hive.test.mode=true; set hive.mapred.mode=strict; -set mapred.job.tracker=localhost:58; +set mapreduce.jobtracker.address=localhost:58; set hive.exec.mode.local.auto=true; explain @@ -24,7 +24,7 @@ select count(1) from t1 join t2 on t1.key=t2.key where t1.ds='1' and t2.ds='1'; select count(1) from t1 join t2 on t1.key=t2.key where t1.ds='1' and t2.ds='1'; set hive.test.mode=false; -set mapred.job.tracker; +set mapreduce.jobtracker.address; diff --git 
a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/input_testsequencefile.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/input_testsequencefile.q index d9926888cef9c..2b16c5cd08649 100755 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/input_testsequencefile.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/input_testsequencefile.q @@ -1,5 +1,5 @@ -set mapred.output.compress=true; -set mapred.output.compression.type=BLOCK; +set mapreduce.output.fileoutputformat.compress=true; +set mapreduce.output.fileoutputformat.compress.type=BLOCK; CREATE TABLE dest4_sequencefile(key INT, value STRING) STORED AS SEQUENCEFILE; @@ -10,5 +10,5 @@ INSERT OVERWRITE TABLE dest4_sequencefile SELECT src.key, src.value; FROM src INSERT OVERWRITE TABLE dest4_sequencefile SELECT src.key, src.value; -set mapred.output.compress=false; +set mapreduce.output.fileoutputformat.compress=false; SELECT dest4_sequencefile.* FROM dest4_sequencefile; diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/join14_hadoop20.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/join14_hadoop20.q index a12ef1afb055f..b3d75b63bd400 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/join14_hadoop20.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/join14_hadoop20.q @@ -2,7 +2,7 @@ CREATE TABLE dest1(c1 INT, c2 STRING) STORED AS TEXTFILE; -set mapred.job.tracker=localhost:58; +set mapreduce.jobtracker.address=localhost:58; set hive.exec.mode.local.auto=true; EXPLAIN diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/leftsemijoin_mr.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/leftsemijoin_mr.q index c9ebe0e8fad12..d98247b63d34f 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/leftsemijoin_mr.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/leftsemijoin_mr.q @@ -9,7 +9,7 @@ SELECT * FROM T1; SELECT * FROM T2; set hive.auto.convert.join=false; -set mapred.reduce.tasks=2; +set mapreduce.job.reduces=2; set hive.join.emit.interval=100; diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/merge2.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/merge2.q index 8b77bd2fe19ba..9189e7c0d1af0 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/merge2.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/merge2.q @@ -1,9 +1,9 @@ set hive.merge.mapfiles=true; set hive.merge.mapredfiles=true; -set mapred.min.split.size=256; -set mapred.min.split.size.per.node=256; -set mapred.min.split.size.per.rack=256; -set mapred.max.split.size=256; +set mapreduce.input.fileinputformat.split.minsize=256; +set mapreduce.input.fileinputformat.split.minsize.per.node=256; +set mapreduce.input.fileinputformat.split.minsize.per.rack=256; +set mapreduce.input.fileinputformat.split.maxsize=256; create table test1(key int, val int); diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/orc_createas1.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/orc_createas1.q index 872692567b37d..dcb2a853bae59 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/orc_createas1.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/orc_createas1.q @@ -1,5 +1,5 @@ -set mapred.max.split.size=100; -set mapred.min.split.size=1; +set mapreduce.input.fileinputformat.split.maxsize=100; 
+set mapreduce.input.fileinputformat.split.minsize=1; DROP TABLE orc_createas1a; DROP TABLE orc_createas1b; diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/orc_ppd_char.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/orc_ppd_char.q index 1f5f54ae19ee8..93f8f519cf21e 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/orc_ppd_char.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/orc_ppd_char.q @@ -1,6 +1,6 @@ SET hive.input.format=org.apache.hadoop.hive.ql.io.HiveInputFormat; -SET mapred.min.split.size=1000; -SET mapred.max.split.size=5000; +SET mapreduce.input.fileinputformat.split.minsize=1000; +SET mapreduce.input.fileinputformat.split.maxsize=5000; create table newtypesorc(c char(10), v varchar(10), d decimal(5,3), da date) stored as orc tblproperties("orc.stripe.size"="16777216"); diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/orc_ppd_date.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/orc_ppd_date.q index c34be867e484f..3a74de82a4725 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/orc_ppd_date.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/orc_ppd_date.q @@ -1,6 +1,6 @@ SET hive.input.format=org.apache.hadoop.hive.ql.io.HiveInputFormat; -SET mapred.min.split.size=1000; -SET mapred.max.split.size=5000; +SET mapreduce.input.fileinputformat.split.minsize=1000; +SET mapreduce.input.fileinputformat.split.maxsize=5000; create table newtypesorc(c char(10), v varchar(10), d decimal(5,3), da date) stored as orc tblproperties("orc.stripe.size"="16777216"); diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/orc_ppd_decimal.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/orc_ppd_decimal.q index a93590eacca01..82f68a9ae56b3 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/orc_ppd_decimal.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/orc_ppd_decimal.q @@ -1,6 +1,6 @@ SET hive.input.format=org.apache.hadoop.hive.ql.io.HiveInputFormat; -SET mapred.min.split.size=1000; -SET mapred.max.split.size=5000; +SET mapreduce.input.fileinputformat.split.minsize=1000; +SET mapreduce.input.fileinputformat.split.maxsize=5000; create table newtypesorc(c char(10), v varchar(10), d decimal(5,3), da date) stored as orc tblproperties("orc.stripe.size"="16777216"); diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/orc_ppd_varchar.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/orc_ppd_varchar.q index 0fecc664e46db..99f58cd73f79f 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/orc_ppd_varchar.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/orc_ppd_varchar.q @@ -1,6 +1,6 @@ SET hive.input.format=org.apache.hadoop.hive.ql.io.HiveInputFormat; -SET mapred.min.split.size=1000; -SET mapred.max.split.size=5000; +SET mapreduce.input.fileinputformat.split.minsize=1000; +SET mapreduce.input.fileinputformat.split.maxsize=5000; create table newtypesorc(c char(10), v varchar(10), d decimal(5,3), da date) stored as orc tblproperties("orc.stripe.size"="16777216"); diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/orc_split_elimination.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/orc_split_elimination.q index 54eb23e776b88..9aa868f9d2f07 100644 --- 
a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/orc_split_elimination.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/orc_split_elimination.q @@ -3,8 +3,8 @@ create table orc_split_elim (userid bigint, string1 string, subtype double, deci load data local inpath '../../data/files/orc_split_elim.orc' into table orc_split_elim; SET hive.input.format=org.apache.hadoop.hive.ql.io.HiveInputFormat; -SET mapred.min.split.size=1000; -SET mapred.max.split.size=5000; +SET mapreduce.input.fileinputformat.split.minsize=1000; +SET mapreduce.input.fileinputformat.split.maxsize=5000; SET hive.optimize.index.filter=false; -- The above table will have 5 splits with the followings stats diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/parallel.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/parallel.q index 03edeaadeef51..3ac60306551e8 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/parallel.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/parallel.q @@ -1,4 +1,4 @@ -set mapred.job.name='test_parallel'; +set mapreduce.job.name='test_parallel'; set hive.exec.parallel=true; set hive.input.format=org.apache.hadoop.hive.ql.io.HiveInputFormat; diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/parallel_orderby.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/parallel_orderby.q index 73c3940644844..777771f227634 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/parallel_orderby.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/parallel_orderby.q @@ -2,7 +2,7 @@ create table src5 (key string, value string); load data local inpath '../../data/files/kv5.txt' into table src5; load data local inpath '../../data/files/kv5.txt' into table src5; -set mapred.reduce.tasks = 4; +set mapreduce.job.reduces = 4; set hive.optimize.sampling.orderby=true; set hive.optimize.sampling.orderby.percent=0.66f; diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/rcfile_createas1.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/rcfile_createas1.q index f36203724c15f..14e13c56b1dba 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/rcfile_createas1.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/rcfile_createas1.q @@ -1,6 +1,6 @@ set hive.merge.rcfile.block.level=true; -set mapred.max.split.size=100; -set mapred.min.split.size=1; +set mapreduce.input.fileinputformat.split.maxsize=100; +set mapreduce.input.fileinputformat.split.minsize=1; DROP TABLE rcfile_createas1a; DROP TABLE rcfile_createas1b; diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/rcfile_lazydecompress.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/rcfile_lazydecompress.q index 7f55d10bd6458..43a15a06f8709 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/rcfile_lazydecompress.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/rcfile_lazydecompress.q @@ -10,7 +10,7 @@ SELECT key, value FROM rcfileTableLazyDecompress where key > 238 and key < 400 O SELECT key, count(1) FROM rcfileTableLazyDecompress where key > 238 group by key ORDER BY key ASC; -set mapred.output.compress=true; +set mapreduce.output.fileoutputformat.compress=true; set hive.exec.compress.output=true; FROM src @@ -22,6 +22,6 @@ SELECT key, value FROM rcfileTableLazyDecompress where key > 238 and key < 400 O 
SELECT key, count(1) FROM rcfileTableLazyDecompress where key > 238 group by key ORDER BY key ASC; -set mapred.output.compress=false; +set mapreduce.output.fileoutputformat.compress=false; set hive.exec.compress.output=false; diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/rcfile_merge1.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/rcfile_merge1.q index 1f6f1bd251c25..25071579cb049 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/rcfile_merge1.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/rcfile_merge1.q @@ -1,6 +1,6 @@ set hive.merge.rcfile.block.level=false; set hive.exec.dynamic.partition=true; -set mapred.max.split.size=100; +set mapreduce.input.fileinputformat.split.maxsize=100; set mapref.min.split.size=1; DROP TABLE rcfile_merge1; diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/rcfile_merge2.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/rcfile_merge2.q index 215d5ebc4a25b..15ffb90bf6271 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/rcfile_merge2.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/rcfile_merge2.q @@ -1,7 +1,7 @@ set hive.merge.rcfile.block.level=true; set hive.exec.dynamic.partition=true; -set mapred.max.split.size=100; -set mapred.min.split.size=1; +set mapreduce.input.fileinputformat.split.maxsize=100; +set mapreduce.input.fileinputformat.split.minsize=1; DROP TABLE rcfile_merge2a; diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/rcfile_merge3.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/rcfile_merge3.q index 39fbd2564664b..787ab4a8d7fa5 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/rcfile_merge3.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/rcfile_merge3.q @@ -1,6 +1,6 @@ set hive.merge.rcfile.block.level=true; -set mapred.max.split.size=100; -set mapred.min.split.size=1; +set mapreduce.input.fileinputformat.split.maxsize=100; +set mapreduce.input.fileinputformat.split.minsize=1; DROP TABLE rcfile_merge3a; DROP TABLE rcfile_merge3b; diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/rcfile_merge4.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/rcfile_merge4.q index fe6df28566cf0..77ac381c65bb7 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/rcfile_merge4.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/rcfile_merge4.q @@ -1,6 +1,6 @@ set hive.merge.rcfile.block.level=true; -set mapred.max.split.size=100; -set mapred.min.split.size=1; +set mapreduce.input.fileinputformat.split.maxsize=100; +set mapreduce.input.fileinputformat.split.minsize=1; DROP TABLE rcfile_merge3a; DROP TABLE rcfile_merge3b; diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/sample_islocalmode_hook.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/sample_islocalmode_hook.q index 12f2bcd46ec8f..bf12ba5ed8e61 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/sample_islocalmode_hook.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/sample_islocalmode_hook.q @@ -1,8 +1,8 @@ set hive.input.format=org.apache.hadoop.hive.ql.io.CombineHiveInputFormat; -set mapred.max.split.size=300; -set mapred.min.split.size=300; -set mapred.min.split.size.per.node=300; -set mapred.min.split.size.per.rack=300; +set 
mapreduce.input.fileinputformat.split.maxsize=300; +set mapreduce.input.fileinputformat.split.minsize=300; +set mapreduce.input.fileinputformat.split.minsize.per.node=300; +set mapreduce.input.fileinputformat.split.minsize.per.rack=300; set hive.exec.mode.local.auto=true; set hive.merge.smallfiles.avgsize=1; diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/sample_islocalmode_hook_hadoop20.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/sample_islocalmode_hook_hadoop20.q index 484e1fa617d8a..5d1bd184d2ad2 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/sample_islocalmode_hook_hadoop20.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/sample_islocalmode_hook_hadoop20.q @@ -1,15 +1,15 @@ USE default; set hive.input.format=org.apache.hadoop.hive.ql.io.CombineHiveInputFormat; -set mapred.max.split.size=300; -set mapred.min.split.size=300; -set mapred.min.split.size.per.node=300; -set mapred.min.split.size.per.rack=300; +set mapreduce.input.fileinputformat.split.maxsize=300; +set mapreduce.input.fileinputformat.split.minsize=300; +set mapreduce.input.fileinputformat.split.minsize.per.node=300; +set mapreduce.input.fileinputformat.split.minsize.per.rack=300; set hive.exec.mode.local.auto=true; set hive.merge.smallfiles.avgsize=1; -- INCLUDE_HADOOP_MAJOR_VERSIONS(0.20, 0.20S) --- This test sets mapred.max.split.size=300 and hive.merge.smallfiles.avgsize=1 +-- This test sets mapreduce.input.fileinputformat.split.maxsize=300 and hive.merge.smallfiles.avgsize=1 -- in an attempt to force the generation of multiple splits and multiple output files. -- However, Hadoop 0.20 is incapable of generating splits smaller than the block size -- when using CombineFileInputFormat, so only one split is generated. This has a @@ -25,7 +25,7 @@ create table sih_src as select key, value from sih_i_part order by key, value; create table sih_src2 as select key, value from sih_src order by key, value; set hive.exec.post.hooks = org.apache.hadoop.hive.ql.hooks.VerifyIsLocalModeHook ; -set mapred.job.tracker=localhost:58; +set mapreduce.jobtracker.address=localhost:58; set hive.exec.mode.local.auto.input.files.max=1; -- Sample split, running locally limited by num tasks diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/split_sample.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/split_sample.q index 952eaf72f10c1..eb774f15829b3 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/split_sample.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/split_sample.q @@ -1,14 +1,14 @@ USE default; set hive.input.format=org.apache.hadoop.hive.ql.io.CombineHiveInputFormat; -set mapred.max.split.size=300; -set mapred.min.split.size=300; -set mapred.min.split.size.per.node=300; -set mapred.min.split.size.per.rack=300; +set mapreduce.input.fileinputformat.split.maxsize=300; +set mapreduce.input.fileinputformat.split.minsize=300; +set mapreduce.input.fileinputformat.split.minsize.per.node=300; +set mapreduce.input.fileinputformat.split.minsize.per.rack=300; set hive.merge.smallfiles.avgsize=1; -- INCLUDE_HADOOP_MAJOR_VERSIONS(0.20) --- This test sets mapred.max.split.size=300 and hive.merge.smallfiles.avgsize=1 +-- This test sets mapreduce.input.fileinputformat.split.maxsize=300 and hive.merge.smallfiles.avgsize=1 -- in an attempt to force the generation of multiple splits and multiple output files. 
-- However, Hadoop 0.20 is incapable of generating splits smaller than the block size -- when using CombineFileInputFormat, so only one split is generated. This has a @@ -72,10 +72,10 @@ select t1.key as k1, t2.key as k from ss_src1 tablesample(80 percent) t1 full ou -- shrink last split explain select count(1) from ss_src2 tablesample(1 percent); -set mapred.max.split.size=300000; -set mapred.min.split.size=300000; -set mapred.min.split.size.per.node=300000; -set mapred.min.split.size.per.rack=300000; +set mapreduce.input.fileinputformat.split.maxsize=300000; +set mapreduce.input.fileinputformat.split.minsize=300000; +set mapreduce.input.fileinputformat.split.minsize.per.node=300000; +set mapreduce.input.fileinputformat.split.minsize.per.rack=300000; select count(1) from ss_src2 tablesample(1 percent); select count(1) from ss_src2 tablesample(50 percent); diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/stats_partscan_1.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/stats_partscan_1.q index cdf92e44cf676..caf359c9e6b43 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/stats_partscan_1.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/stats_partscan_1.q @@ -2,13 +2,13 @@ set datanucleus.cache.collections=false; set hive.stats.autogather=false; set hive.exec.dynamic.partition=true; set hive.exec.dynamic.partition.mode=nonstrict; -set mapred.min.split.size=256; -set mapred.min.split.size.per.node=256; -set mapred.min.split.size.per.rack=256; -set mapred.max.split.size=256; +set mapreduce.input.fileinputformat.split.minsize=256; +set mapreduce.input.fileinputformat.split.minsize.per.node=256; +set mapreduce.input.fileinputformat.split.minsize.per.rack=256; +set mapreduce.input.fileinputformat.split.maxsize=256; -- INCLUDE_HADOOP_MAJOR_VERSIONS(0.20,0.20S) --- This test uses mapred.max.split.size/mapred.max.split.size for controlling +-- This test uses mapreduce.input.fileinputformat.split.maxsize/mapred.max.split.size for controlling -- number of input splits, which is not effective in hive 0.20. -- stats_partscan_1_23.q is the same test with this but has different result. diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/stats_partscan_1_23.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/stats_partscan_1_23.q index 1e5f360b20cbb..07694891fd6ff 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/stats_partscan_1_23.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/stats_partscan_1_23.q @@ -2,13 +2,13 @@ set datanucleus.cache.collections=false; set hive.stats.autogather=false; set hive.exec.dynamic.partition=true; set hive.exec.dynamic.partition.mode=nonstrict; -set mapred.min.split.size=256; -set mapred.min.split.size.per.node=256; -set mapred.min.split.size.per.rack=256; -set mapred.max.split.size=256; +set mapreduce.input.fileinputformat.split.minsize=256; +set mapreduce.input.fileinputformat.split.minsize.per.node=256; +set mapreduce.input.fileinputformat.split.minsize.per.rack=256; +set mapreduce.input.fileinputformat.split.maxsize=256; -- INCLUDE_HADOOP_MAJOR_VERSIONS(0.23) --- This test uses mapred.max.split.size/mapred.max.split.size for controlling +-- This test uses mapreduce.input.fileinputformat.split.maxsize/mapred.max.split.size for controlling -- number of input splits. -- stats_partscan_1.q is the same test with this but has different result. 
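
The `.q` hunks above all follow the same pattern: deprecated MR1 property names (`mapred.max.split.size`, `mapred.min.split.size*`, `mapred.reduce.tasks`, `mapred.job.tracker`, `mapred.output.compress`) are replaced by their MRv2 equivalents. As a point of reference only, here is a minimal Scala sketch showing the same renamed keys being set through a Hive-enabled `SparkSession` instead of Hive `set` commands; the object name and the values are illustrative and not part of this patch.

```scala
// Minimal sketch (illustrative only): setting the renamed MRv2 properties through
// Spark SQL's SET command, mirroring the `set` statements in the .q files above.
import org.apache.spark.sql.SparkSession

object Mrv2ConfigSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("mrv2-config-sketch")
      .enableHiveSupport()
      .getOrCreate()

    // Replaces mapred.max.split.size / mapred.min.split.size
    spark.sql("SET mapreduce.input.fileinputformat.split.maxsize=300")
    spark.sql("SET mapreduce.input.fileinputformat.split.minsize=300")
    // Replaces mapred.reduce.tasks
    spark.sql("SET mapreduce.job.reduces=1")
    // Replaces mapred.output.compress
    spark.sql("SET mapreduce.output.fileoutputformat.compress=false")

    spark.stop()
  }
}
```
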
diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/udaf_context_ngrams.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/udaf_context_ngrams.q index f065385688a1d..5b5d669a7c12d 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/udaf_context_ngrams.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/udaf_context_ngrams.q @@ -1,6 +1,6 @@ CREATE TABLE kafka (contents STRING); LOAD DATA LOCAL INPATH '../../data/files/text-en.txt' INTO TABLE kafka; -set mapred.reduce.tasks=1; +set mapreduce.job.reduces=1; set hive.exec.reducers.max=1; SELECT context_ngrams(sentences(lower(contents)), array(null), 100, 1000).estfrequency FROM kafka; diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/udaf_ngrams.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/udaf_ngrams.q index 6a2fde52e42f6..39e6e30ae6945 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/udaf_ngrams.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/udaf_ngrams.q @@ -1,6 +1,6 @@ CREATE TABLE kafka (contents STRING); LOAD DATA LOCAL INPATH '../../data/files/text-en.txt' INTO TABLE kafka; -set mapred.reduce.tasks=1; +set mapreduce.job.reduces=1; set hive.exec.reducers.max=1; SELECT ngrams(sentences(lower(contents)), 1, 100, 1000).estfrequency FROM kafka; diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala index 3871b3d785882..8ccc2b7527f24 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala @@ -204,13 +204,8 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with TestHiveSingleto assertCached(table("refreshTable")) // Append new data. table("src").write.mode(SaveMode.Append).parquet(tempPath.toString) - // We are still using the old data. assertCached(table("refreshTable")) - checkAnswer( - table("refreshTable"), - table("src").collect()) - // Refresh the table. - sql("REFRESH TABLE refreshTable") + // We are using the new data. assertCached(table("refreshTable")) checkAnswer( @@ -249,13 +244,8 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with TestHiveSingleto assertCached(table("refreshTable")) // Append new data. table("src").write.mode(SaveMode.Append).parquet(tempPath.toString) - // We are still using the old data. assertCached(table("refreshTable")) - checkAnswer( - table("refreshTable"), - table("src").collect()) - // Refresh the table. - sql(s"REFRESH ${tempPath.toString}") + // We are using the new data. 
assertCached(table("refreshTable")) checkAnswer( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala index a60c210b04c8a..4349f1aa23be0 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala @@ -52,7 +52,7 @@ class HiveExternalCatalogSuite extends ExternalCatalogSuite { test("list partitions by filter") { val catalog = newBasicCatalog() - val selectedPartitions = catalog.listPartitionsByFilter("db2", "tbl2", Seq('a.int === 1)) + val selectedPartitions = catalog.listPartitionsByFilter("db2", "tbl2", Seq('a.int === 1), "GMT") assert(selectedPartitions.length == 1) assert(selectedPartitions.head.spec == part1.spec) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala index 71ce5a7c4a15a..d6999af84eac0 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala @@ -284,19 +284,6 @@ class InsertIntoHiveTableSuite extends QueryTest with TestHiveSingleton with Bef sql("DROP TABLE hiveTableWithStructValue") } - test("Reject partitioning that does not match table") { - withSQLConf(("hive.exec.dynamic.partition.mode", "nonstrict")) { - sql("CREATE TABLE partitioned (id bigint, data string) PARTITIONED BY (part string)") - val data = (1 to 10).map(i => (i, s"data-$i", if ((i % 2) == 0) "even" else "odd")) - .toDF("id", "data", "part") - - intercept[AnalysisException] { - // cannot partition by 2 fields when there is only one in the table definition - data.write.partitionBy("part", "data").insertInto("partitioned") - } - } - } - test("Test partition mode = strict") { withSQLConf(("hive.exec.dynamic.partition.mode", "strict")) { sql("CREATE TABLE partitioned (id bigint, data string) PARTITIONED BY (part string)") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index e951bbe1dcbf7..03ea0c8c77682 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -511,9 +511,9 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv test("create external table") { withTempPath { tempPath => withTable("savedJsonTable", "createdJsonTable") { - val df = read.json(sparkContext.parallelize((1 to 10).map { i => + val df = read.json((1 to 10).map { i => s"""{ "a": $i, "b": "str$i" }""" - })) + }.toDS()) withSQLConf(SQLConf.DEFAULT_DATA_SOURCE_NAME.key -> "not a source name") { df.write diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreRelationSuite.scala deleted file mode 100644 index 91ff711445e82..0000000000000 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreRelationSuite.scala +++ /dev/null @@ -1,55 +0,0 @@ -/* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. 
-* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ - -package org.apache.spark.sql.hive - -import org.apache.spark.sql.{QueryTest, Row} -import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable, CatalogTableType} -import org.apache.spark.sql.hive.test.TestHiveSingleton -import org.apache.spark.sql.test.SQLTestUtils -import org.apache.spark.sql.types.{IntegerType, StructField, StructType} - -class MetastoreRelationSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { - test("makeCopy and toJSON should work") { - val table = CatalogTable( - identifier = TableIdentifier("test", Some("db")), - tableType = CatalogTableType.VIEW, - storage = CatalogStorageFormat.empty, - schema = StructType(StructField("a", IntegerType, true) :: Nil)) - val relation = MetastoreRelation("db", "test")(table, null) - - // No exception should be thrown - relation.makeCopy(Array("db", "test")) - // No exception should be thrown - relation.toJSON - } - - test("SPARK-17409: Do Not Optimize Query in CTAS (Hive Serde Table) More Than Once") { - withTable("bar") { - withTempView("foo") { - sql("select 0 as id").createOrReplaceTempView("foo") - // If we optimize the query in CTAS more than once, the following saveAsTable will fail - // with the error: `GROUP BY position 0 is not in select list (valid range is [1, 1])` - sql("CREATE TABLE bar AS SELECT * FROM foo group by id") - checkAnswer(spark.table("bar"), Row(0) :: Nil) - val tableMetadata = spark.sessionState.catalog.getTableMetadata(TableIdentifier("bar")) - assert(tableMetadata.provider == Some("hive"), "the expected table is a Hive serde table") - } - } - } -} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionedTablePerfStatsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionedTablePerfStatsSuite.scala index b792a168a4f9b..50506197b3138 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionedTablePerfStatsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionedTablePerfStatsSuite.scala @@ -411,4 +411,15 @@ class PartitionedTablePerfStatsSuite } } } + + test("resolveRelation for a FileFormat DataSource without userSchema scan filesystem only once") { + withTempDir { dir => + import spark.implicits._ + Seq(1).toDF("a").write.mode("overwrite").save(dir.getAbsolutePath) + HiveCatalogMetrics.reset() + spark.read.parquet(dir.getAbsolutePath) + assert(HiveCatalogMetrics.METRIC_FILES_DISCOVERED.getCount() == 1) + assert(HiveCatalogMetrics.METRIC_FILE_CACHE_HITS.getCount() == 1) + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index e2fcd2fd41fa1..962998ea6fb68 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -23,7 +23,7 @@ import 
scala.reflect.ClassTag import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.catalog.CatalogStatistics +import org.apache.spark.sql.catalyst.catalog.{CatalogRelation, CatalogStatistics} import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.execution.joins._ @@ -33,52 +33,46 @@ import org.apache.spark.sql.types._ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleton { - test("MetastoreRelations fallback to HDFS for size estimation") { - val enableFallBackToHdfsForStats = spark.sessionState.conf.fallBackToHdfsForStatsEnabled - try { - withTempDir { tempDir => - - // EXTERNAL OpenCSVSerde table pointing to LOCATION - - val file1 = new File(tempDir + "/data1") - val writer1 = new PrintWriter(file1) - writer1.write("1,2") - writer1.close() - - val file2 = new File(tempDir + "/data2") - val writer2 = new PrintWriter(file2) - writer2.write("1,2") - writer2.close() - - sql( - s"""CREATE EXTERNAL TABLE csv_table(page_id INT, impressions INT) - ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.OpenCSVSerde' - WITH SERDEPROPERTIES ( - \"separatorChar\" = \",\", - \"quoteChar\" = \"\\\"\", - \"escapeChar\" = \"\\\\\") - LOCATION '${tempDir.toURI}' - """) - - spark.conf.set(SQLConf.ENABLE_FALL_BACK_TO_HDFS_FOR_STATS.key, true) - - val relation = spark.table("csv_table").queryExecution.analyzed.children.head - .asInstanceOf[MetastoreRelation] - - val properties = relation.hiveQlTable.getParameters - assert(properties.get("totalSize").toLong <= 0, "external table totalSize must be <= 0") - assert(properties.get("rawDataSize").toLong <= 0, "external table rawDataSize must be <= 0") - - val sizeInBytes = relation.stats(conf).sizeInBytes - assert(sizeInBytes === BigInt(file1.length() + file2.length())) + test("Hive serde tables should fallback to HDFS for size estimation") { + withSQLConf(SQLConf.ENABLE_FALL_BACK_TO_HDFS_FOR_STATS.key -> "true") { + withTable("csv_table") { + withTempDir { tempDir => + // EXTERNAL OpenCSVSerde table pointing to LOCATION + val file1 = new File(tempDir + "/data1") + val writer1 = new PrintWriter(file1) + writer1.write("1,2") + writer1.close() + + val file2 = new File(tempDir + "/data2") + val writer2 = new PrintWriter(file2) + writer2.write("1,2") + writer2.close() + + sql( + s""" + |CREATE EXTERNAL TABLE csv_table(page_id INT, impressions INT) + |ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.OpenCSVSerde' + |WITH SERDEPROPERTIES ( + |\"separatorChar\" = \",\", + |\"quoteChar\" = \"\\\"\", + |\"escapeChar\" = \"\\\\\") + |LOCATION '${tempDir.toURI}'""".stripMargin) + + val relation = spark.table("csv_table").queryExecution.analyzed.children.head + .asInstanceOf[CatalogRelation] + + val properties = relation.tableMeta.properties + assert(properties("totalSize").toLong <= 0, "external table totalSize must be <= 0") + assert(properties("rawDataSize").toLong <= 0, "external table rawDataSize must be <= 0") + + val sizeInBytes = relation.stats(conf).sizeInBytes + assert(sizeInBytes === BigInt(file1.length() + file2.length())) + } } - } finally { - spark.conf.set(SQLConf.ENABLE_FALL_BACK_TO_HDFS_FOR_STATS.key, enableFallBackToHdfsForStats) - sql("DROP TABLE csv_table ") } } - test("analyze MetastoreRelations") { + test("analyze Hive serde tables") { def queryTotalSize(tableName: String): BigInt = spark.table(tableName).queryExecution.analyzed.stats(conf).sizeInBytes @@ -152,9 +146,11 @@ class 
StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto } private def checkTableStats( - stats: Option[CatalogStatistics], + tableName: String, hasSizeInBytes: Boolean, - expectedRowCounts: Option[Int]): Unit = { + expectedRowCounts: Option[Int]): Option[CatalogStatistics] = { + val stats = spark.sessionState.catalog.getTableMetadata(TableIdentifier(tableName)).stats + if (hasSizeInBytes || expectedRowCounts.nonEmpty) { assert(stats.isDefined) assert(stats.get.sizeInBytes > 0) @@ -162,26 +158,8 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto } else { assert(stats.isEmpty) } - } - private def checkTableStats( - tableName: String, - isDataSourceTable: Boolean, - hasSizeInBytes: Boolean, - expectedRowCounts: Option[Int]): Option[CatalogStatistics] = { - val df = sql(s"SELECT * FROM $tableName") - val stats = df.queryExecution.analyzed.collect { - case rel: MetastoreRelation => - checkTableStats(rel.catalogTable.stats, hasSizeInBytes, expectedRowCounts) - assert(!isDataSourceTable, "Expected a Hive serde table, but got a data source table") - rel.catalogTable.stats - case rel: LogicalRelation => - checkTableStats(rel.catalogTable.get.stats, hasSizeInBytes, expectedRowCounts) - assert(isDataSourceTable, "Expected a data source table, but got a Hive serde table") - rel.catalogTable.get.stats - } - assert(stats.size == 1) - stats.head + stats } test("test table-level statistics for hive tables created in HiveExternalCatalog") { @@ -192,25 +170,23 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto sql(s"CREATE TABLE $textTable (key STRING, value STRING) STORED AS TEXTFILE") checkTableStats( textTable, - isDataSourceTable = false, hasSizeInBytes = false, expectedRowCounts = None) sql(s"INSERT INTO TABLE $textTable SELECT * FROM src") checkTableStats( textTable, - isDataSourceTable = false, hasSizeInBytes = false, expectedRowCounts = None) // noscan won't count the number of rows sql(s"ANALYZE TABLE $textTable COMPUTE STATISTICS noscan") - val fetchedStats1 = checkTableStats( - textTable, isDataSourceTable = false, hasSizeInBytes = true, expectedRowCounts = None) + val fetchedStats1 = + checkTableStats(textTable, hasSizeInBytes = true, expectedRowCounts = None) // without noscan, we count the number of rows sql(s"ANALYZE TABLE $textTable COMPUTE STATISTICS") - val fetchedStats2 = checkTableStats( - textTable, isDataSourceTable = false, hasSizeInBytes = true, expectedRowCounts = Some(500)) + val fetchedStats2 = + checkTableStats(textTable, hasSizeInBytes = true, expectedRowCounts = Some(500)) assert(fetchedStats1.get.sizeInBytes == fetchedStats2.get.sizeInBytes) } } @@ -221,25 +197,25 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto sql(s"CREATE TABLE $textTable (key STRING, value STRING) STORED AS TEXTFILE") sql(s"INSERT INTO TABLE $textTable SELECT * FROM src") sql(s"ANALYZE TABLE $textTable COMPUTE STATISTICS") - val fetchedStats1 = checkTableStats( - textTable, isDataSourceTable = false, hasSizeInBytes = true, expectedRowCounts = Some(500)) + val fetchedStats1 = + checkTableStats(textTable, hasSizeInBytes = true, expectedRowCounts = Some(500)) sql(s"ANALYZE TABLE $textTable COMPUTE STATISTICS noscan") // when the total size is not changed, the old row count is kept - val fetchedStats2 = checkTableStats( - textTable, isDataSourceTable = false, hasSizeInBytes = true, expectedRowCounts = Some(500)) + val fetchedStats2 = + checkTableStats(textTable, hasSizeInBytes = true, 
expectedRowCounts = Some(500)) assert(fetchedStats1 == fetchedStats2) sql(s"INSERT INTO TABLE $textTable SELECT * FROM src") sql(s"ANALYZE TABLE $textTable COMPUTE STATISTICS noscan") // update total size and remove the old and invalid row count - val fetchedStats3 = checkTableStats( - textTable, isDataSourceTable = false, hasSizeInBytes = true, expectedRowCounts = None) + val fetchedStats3 = + checkTableStats(textTable, hasSizeInBytes = true, expectedRowCounts = None) assert(fetchedStats3.get.sizeInBytes > fetchedStats2.get.sizeInBytes) } } - test("test statistics of LogicalRelation converted from MetastoreRelation") { + test("test statistics of LogicalRelation converted from Hive serde tables") { val parquetTable = "parquetTable" val orcTable = "orcTable" withTable(parquetTable, orcTable) { @@ -251,21 +227,14 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto // the default value for `spark.sql.hive.convertMetastoreParquet` is true, here we just set it // for robustness withSQLConf("spark.sql.hive.convertMetastoreParquet" -> "true") { - checkTableStats( - parquetTable, isDataSourceTable = true, hasSizeInBytes = false, expectedRowCounts = None) + checkTableStats(parquetTable, hasSizeInBytes = false, expectedRowCounts = None) sql(s"ANALYZE TABLE $parquetTable COMPUTE STATISTICS") - checkTableStats( - parquetTable, - isDataSourceTable = true, - hasSizeInBytes = true, - expectedRowCounts = Some(500)) + checkTableStats(parquetTable, hasSizeInBytes = true, expectedRowCounts = Some(500)) } withSQLConf("spark.sql.hive.convertMetastoreOrc" -> "true") { - checkTableStats( - orcTable, isDataSourceTable = true, hasSizeInBytes = false, expectedRowCounts = None) + checkTableStats(orcTable, hasSizeInBytes = false, expectedRowCounts = None) sql(s"ANALYZE TABLE $orcTable COMPUTE STATISTICS") - checkTableStats( - orcTable, isDataSourceTable = true, hasSizeInBytes = true, expectedRowCounts = Some(500)) + checkTableStats(orcTable, hasSizeInBytes = true, expectedRowCounts = Some(500)) } } } @@ -385,27 +354,23 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto // Add a filter to avoid creating too many partitions sql(s"INSERT INTO TABLE $parquetTable SELECT * FROM src WHERE key < 10") - checkTableStats( - parquetTable, isDataSourceTable = true, hasSizeInBytes = false, expectedRowCounts = None) + checkTableStats(parquetTable, hasSizeInBytes = false, expectedRowCounts = None) // noscan won't count the number of rows sql(s"ANALYZE TABLE $parquetTable COMPUTE STATISTICS noscan") - val fetchedStats1 = checkTableStats( - parquetTable, isDataSourceTable = true, hasSizeInBytes = true, expectedRowCounts = None) + val fetchedStats1 = + checkTableStats(parquetTable, hasSizeInBytes = true, expectedRowCounts = None) sql(s"INSERT INTO TABLE $parquetTable SELECT * FROM src WHERE key < 10") sql(s"ANALYZE TABLE $parquetTable COMPUTE STATISTICS noscan") - val fetchedStats2 = checkTableStats( - parquetTable, isDataSourceTable = true, hasSizeInBytes = true, expectedRowCounts = None) + val fetchedStats2 = + checkTableStats(parquetTable, hasSizeInBytes = true, expectedRowCounts = None) assert(fetchedStats2.get.sizeInBytes > fetchedStats1.get.sizeInBytes) // without noscan, we count the number of rows sql(s"ANALYZE TABLE $parquetTable COMPUTE STATISTICS") - val fetchedStats3 = checkTableStats( - parquetTable, - isDataSourceTable = true, - hasSizeInBytes = true, - expectedRowCounts = Some(20)) + val fetchedStats3 = + checkTableStats(parquetTable, hasSizeInBytes = 
true, expectedRowCounts = Some(20)) assert(fetchedStats3.get.sizeInBytes == fetchedStats2.get.sizeInBytes) } } @@ -426,11 +391,7 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto val dfNoCols = spark.createDataFrame(rddNoCols, StructType(Seq.empty)) dfNoCols.write.format("json").saveAsTable(table_no_cols) sql(s"ANALYZE TABLE $table_no_cols COMPUTE STATISTICS") - checkTableStats( - table_no_cols, - isDataSourceTable = true, - hasSizeInBytes = true, - expectedRowCounts = Some(10)) + checkTableStats(table_no_cols, hasSizeInBytes = true, expectedRowCounts = Some(10)) } } @@ -478,10 +439,10 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto assert(statsAfterUpdate.rowCount == Some(2)) } - test("estimates the size of a test MetastoreRelation") { + test("estimates the size of a test Hive serde tables") { val df = sql("""SELECT * FROM src""") - val sizes = df.queryExecution.analyzed.collect { case mr: MetastoreRelation => - mr.stats(conf).sizeInBytes + val sizes = df.queryExecution.analyzed.collect { + case relation: CatalogRelation => relation.stats(conf).sizeInBytes } assert(sizes.size === 1, s"Size wrong for:\n ${df.queryExecution}") assert(sizes(0).equals(BigInt(5812)), @@ -533,7 +494,7 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto after() } - /** Tests for MetastoreRelation */ + /** Tests for Hive serde tables */ val metastoreQuery = """SELECT * FROM src a JOIN src b ON a.key = 238 AND a.key = b.key""" val metastoreAnswer = Seq.fill(4)(Row(238, "val_238", 238, "val_238")) mkTest( @@ -541,7 +502,7 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto () => (), metastoreQuery, metastoreAnswer, - implicitly[ClassTag[MetastoreRelation]] + implicitly[ClassTag[CatalogRelation]] ) } @@ -555,9 +516,7 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto // Assert src has a size smaller than the threshold. 
val sizes = df.queryExecution.analyzed.collect { - case r if implicitly[ClassTag[MetastoreRelation]].runtimeClass - .isAssignableFrom(r.getClass) => - r.stats(conf).sizeInBytes + case relation: CatalogRelation => relation.stats(conf).sizeInBytes } assert(sizes.size === 2 && sizes(1) <= spark.sessionState.conf.autoBroadcastJoinThreshold && sizes(0) <= spark.sessionState.conf.autoBroadcastJoinThreshold, diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientBuilder.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientBuilder.scala index 591a968c82847..e85ea5a59427d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientBuilder.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientBuilder.scala @@ -35,22 +35,26 @@ private[client] class HiveClientBuilder { Some(new File(sys.props("java.io.tmpdir"), "hive-ivy-cache").getAbsolutePath)) } - private def buildConf() = { + private def buildConf(extraConf: Map[String, String]) = { lazy val warehousePath = Utils.createTempDir() lazy val metastorePath = Utils.createTempDir() metastorePath.delete() - Map( + extraConf ++ Map( "javax.jdo.option.ConnectionURL" -> s"jdbc:derby:;databaseName=$metastorePath;create=true", "hive.metastore.warehouse.dir" -> warehousePath.toString) } - def buildClient(version: String, hadoopConf: Configuration): HiveClient = { + // for testing only + def buildClient( + version: String, + hadoopConf: Configuration, + extraConf: Map[String, String] = Map.empty): HiveClient = { IsolatedClientLoader.forVersion( hiveMetastoreVersion = version, hadoopVersion = VersionInfo.getVersion, sparkConf = sparkConf, hadoopConf = hadoopConf, - config = buildConf(), + config = buildConf(extraConf), ivyPath = ivyPath).createClient() } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala index 6feb277ca88e9..d61d10bf869e2 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala @@ -88,7 +88,7 @@ class VersionsSuite extends QueryTest with SQLTestUtils with TestHiveSingleton w assert(getNestedMessages(e) contains "Unknown column 'A0.OWNER_NAME' in 'field list'") } - private val versions = Seq("0.12", "0.13", "0.14", "1.0", "1.1", "1.2") + private val versions = Seq("0.12", "0.13", "0.14", "1.0", "1.1", "1.2", "2.0") private var client: HiveClient = null @@ -98,7 +98,12 @@ class VersionsSuite extends QueryTest with SQLTestUtils with TestHiveSingleton w System.gc() // Hack to avoid SEGV on some JVM versions. 
val hadoopConf = new Configuration() hadoopConf.set("test", "success") - client = buildClient(version, hadoopConf) + // Hive changed the default of datanucleus.schema.autoCreateAll from true to false since 2.0 + // For details, see the JIRA HIVE-6113 + if (version == "2.0") { + hadoopConf.set("datanucleus.schema.autoCreateAll", "true") + } + client = buildClient(version, hadoopConf, HiveUtils.hiveClientConfigurations(hadoopConf)) } def table(database: String, tableName: String): CatalogTable = { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala index f3151d52f20ac..536ca8fd9d45d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala @@ -385,7 +385,7 @@ abstract class HiveComparisonTest // also print out the query plans and results for those. val computedTablesMessages: String = try { val tablesRead = new TestHiveQueryExecution(query).executedPlan.collect { - case ts: HiveTableScanExec => ts.relation.tableName + case ts: HiveTableScanExec => ts.relation.tableMeta.identifier }.toSet TestHive.reset() @@ -393,7 +393,7 @@ abstract class HiveComparisonTest executions.foreach(_.toRdd) val tablesGenerated = queryList.zip(executions).flatMap { case (q, e) => e.analyzed.collect { - case i: InsertIntoHiveTable if tablesRead contains i.table.tableName => + case i: InsertIntoHiveTable if tablesRead contains i.table.identifier => (q, e, i) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index 792ac1e259494..81ae5b7bdb672 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -1587,4 +1587,103 @@ class HiveDDLSuite } } } + + Seq(true, false).foreach { shouldDelete => + val tcName = if (shouldDelete) "non-existent" else "existed" + test(s"CTAS for external data source table with a $tcName location") { + withTable("t", "t1") { + withTempDir { + dir => + if (shouldDelete) { + dir.delete() + } + spark.sql( + s""" + |CREATE TABLE t + |USING parquet + |LOCATION '$dir' + |AS SELECT 3 as a, 4 as b, 1 as c, 2 as d + """.stripMargin) + + val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) + assert(table.location == dir.getAbsolutePath) + + checkAnswer(spark.table("t"), Row(3, 4, 1, 2)) + } + // partition table + withTempDir { + dir => + if (shouldDelete) { + dir.delete() + } + spark.sql( + s""" + |CREATE TABLE t1 + |USING parquet + |PARTITIONED BY(a, b) + |LOCATION '$dir' + |AS SELECT 3 as a, 4 as b, 1 as c, 2 as d + """.stripMargin) + + val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t1")) + assert(table.location == dir.getAbsolutePath) + + val partDir = new File(dir, "a=3") + assert(partDir.exists()) + + checkAnswer(spark.table("t1"), Row(1, 2, 3, 4)) + } + } + } + + test(s"CTAS for external hive table with a $tcName location") { + withTable("t", "t1") { + withSQLConf("hive.exec.dynamic.partition.mode" -> "nonstrict") { + withTempDir { + dir => + if (shouldDelete) { + dir.delete() + } + spark.sql( + s""" + |CREATE TABLE t + |USING hive + |LOCATION '$dir' + |AS SELECT 3 as a, 4 as b, 1 as c, 2 as d + """.stripMargin) + val dirPath = new 
Path(dir.getAbsolutePath) + val fs = dirPath.getFileSystem(spark.sessionState.newHadoopConf()) + val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) + assert(new Path(table.location) == fs.makeQualified(dirPath)) + + checkAnswer(spark.table("t"), Row(3, 4, 1, 2)) + } + // partition table + withTempDir { + dir => + if (shouldDelete) { + dir.delete() + } + spark.sql( + s""" + |CREATE TABLE t1 + |USING hive + |PARTITIONED BY(a, b) + |LOCATION '$dir' + |AS SELECT 3 as a, 4 as b, 1 as c, 2 as d + """.stripMargin) + val dirPath = new Path(dir.getAbsolutePath) + val fs = dirPath.getFileSystem(spark.sessionState.newHadoopConf()) + val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t1")) + assert(new Path(table.location) == fs.makeQualified(dirPath)) + + val partDir = new File(dir, "a=3") + assert(partDir.exists()) + + checkAnswer(spark.table("t1"), Row(1, 2, 3, 4)) + } + } + } + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala index cfca1d79836b2..8a37bc3665d32 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.test.SQLTestUtils * A set of tests that validates support for Hive Explain command. */ class HiveExplainSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { + import testImplicits._ test("show cost in explain command") { // Only has sizeInBytes before ANALYZE command @@ -92,8 +93,8 @@ class HiveExplainSuite extends QueryTest with SQLTestUtils with TestHiveSingleto test("SPARK-17409: The EXPLAIN output of CTAS only shows the analyzed plan") { withTempView("jt") { - val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str$i"}""")) - spark.read.json(rdd).createOrReplaceTempView("jt") + val ds = (1 to 10).map(i => s"""{"a":$i, "b":"str$i"}""").toDS() + spark.read.json(ds).createOrReplaceTempView("jt") val outputs = sql( s""" |EXPLAIN EXTENDED diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala index b2f19d7753956..ce92fbf349420 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala @@ -31,15 +31,15 @@ case class Data(a: Int, B: Int, n: Nested, nestedArray: Seq[Nested]) class HiveResolutionSuite extends HiveComparisonTest { test("SPARK-3698: case insensitive test for nested data") { - read.json(sparkContext.makeRDD( - """{"a": [{"a": {"a": 1}}]}""" :: Nil)).createOrReplaceTempView("nested") + read.json(Seq("""{"a": [{"a": {"a": 1}}]}""").toDS()) + .createOrReplaceTempView("nested") // This should be successfully analyzed sql("SELECT a[0].A.A from nested").queryExecution.analyzed } test("SPARK-5278: check ambiguous reference to fields") { - read.json(sparkContext.makeRDD( - """{"a": [{"b": 1, "B": 2}]}""" :: Nil)).createOrReplaceTempView("nested") + read.json(Seq("""{"a": [{"b": 1, "B": 2}]}""").toDS()) + .createOrReplaceTempView("nested") // there are 2 filed matching field name "b", we should report Ambiguous reference error val exception = intercept[AnalysisException] { diff --git 
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala index 5c460d25f3723..90e037e292790 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.hive.execution import org.apache.spark.sql.Row -import org.apache.spark.sql.hive.MetastoreRelation import org.apache.spark.sql.hive.test.{TestHive, TestHiveSingleton} import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.sql.hive.test.TestHive.implicits._ @@ -95,8 +94,7 @@ class HiveTableScanSuite extends HiveComparisonTest with SQLTestUtils with TestH private def checkNumScannedPartitions(stmt: String, expectedNumParts: Int): Unit = { val plan = sql(stmt).queryExecution.sparkPlan val numPartitions = plan.collectFirst { - case p: HiveTableScanExec => - p.relation.getHiveQlPartitions(p.partitionPruningPred).length + case p: HiveTableScanExec => p.rawPartitions.length }.getOrElse(0) assert(numPartitions == expectedNumParts) } @@ -170,11 +168,11 @@ class HiveTableScanSuite extends HiveComparisonTest with SQLTestUtils with TestH s""" |SELECT * FROM $table """.stripMargin).queryExecution.sparkPlan - val relation = plan.collectFirst { - case p: HiveTableScanExec => p.relation + val scan = plan.collectFirst { + case p: HiveTableScanExec => p }.get - val tableCols = relation.hiveQlTable.getCols - relation.getHiveQlPartitions().foreach(p => assert(p.getCols.size == tableCols.size)) + val numDataCols = scan.relation.dataCols.length + scan.rawPartitions.foreach(p => assert(p.getCols.size == numDataCols)) } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala index 24df73b40ea0e..d535bef4cc787 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala @@ -153,8 +153,8 @@ class PruningSuite extends HiveComparisonTest with BeforeAndAfter { val (actualScannedColumns, actualPartValues) = plan.collect { case p @ HiveTableScanExec(columns, relation, _) => val columnNames = columns.map(_.name) - val partValues = if (relation.catalogTable.partitionColumnNames.nonEmpty) { - p.prunePartitions(relation.getHiveQlPartitions()).map(_.getValues) + val partValues = if (relation.isPartitioned) { + p.prunePartitions(p.rawPartitions).map(_.getValues) } else { Seq.empty } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index faed8b504649f..ef2d451e6b2d6 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -28,12 +28,12 @@ import org.apache.spark.TestUtils import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, FunctionRegistry, NoSuchPartitionException} -import org.apache.spark.sql.catalyst.catalog.CatalogTableType +import org.apache.spark.sql.catalyst.catalog.{CatalogRelation, CatalogTableType} import org.apache.spark.sql.catalyst.parser.ParseException import 
org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias} import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation} import org.apache.spark.sql.functions._ -import org.apache.spark.sql.hive.{HiveUtils, MetastoreRelation} +import org.apache.spark.sql.hive.HiveUtils import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SQLTestUtils @@ -526,7 +526,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { case LogicalRelation(r: HadoopFsRelation, _, _) => if (!isDataSourceTable) { fail( - s"${classOf[MetastoreRelation].getCanonicalName} is expected, but found " + + s"${classOf[CatalogRelation].getCanonicalName} is expected, but found " + s"${HadoopFsRelation.getClass.getCanonicalName}.") } userSpecifiedLocation match { @@ -536,15 +536,15 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } assert(catalogTable.provider.get === format) - case r: MetastoreRelation => + case r: CatalogRelation => if (isDataSourceTable) { fail( s"${HadoopFsRelation.getClass.getCanonicalName} is expected, but found " + - s"${classOf[MetastoreRelation].getCanonicalName}.") + s"${classOf[CatalogRelation].getCanonicalName}.") } userSpecifiedLocation match { case Some(location) => - assert(r.catalogTable.location === location) + assert(r.tableMeta.location === location) case None => // OK. } // Also make sure that the format and serde are as desired. @@ -973,30 +973,30 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } test("SPARK-4296 Grouping field with Hive UDF as sub expression") { - val rdd = sparkContext.makeRDD( """{"a": "str", "b":"1", "c":"1970-01-01 00:00:00"}""" :: Nil) - read.json(rdd).createOrReplaceTempView("data") + val ds = Seq("""{"a": "str", "b":"1", "c":"1970-01-01 00:00:00"}""").toDS() + read.json(ds).createOrReplaceTempView("data") checkAnswer( sql("SELECT concat(a, '-', b), year(c) FROM data GROUP BY concat(a, '-', b), year(c)"), Row("str-1", 1970)) dropTempTable("data") - read.json(rdd).createOrReplaceTempView("data") + read.json(ds).createOrReplaceTempView("data") checkAnswer(sql("SELECT year(c) + 1 FROM data GROUP BY year(c) + 1"), Row(1971)) dropTempTable("data") } test("resolve udtf in projection #1") { - val rdd = sparkContext.makeRDD((1 to 5).map(i => s"""{"a":[$i, ${i + 1}]}""")) - read.json(rdd).createOrReplaceTempView("data") + val ds = (1 to 5).map(i => s"""{"a":[$i, ${i + 1}]}""").toDS() + read.json(ds).createOrReplaceTempView("data") val df = sql("SELECT explode(a) AS val FROM data") val col = df("val") } test("resolve udtf in projection #2") { - val rdd = sparkContext.makeRDD((1 to 2).map(i => s"""{"a":[$i, ${i + 1}]}""")) - read.json(rdd).createOrReplaceTempView("data") + val ds = (1 to 2).map(i => s"""{"a":[$i, ${i + 1}]}""").toDS() + read.json(ds).createOrReplaceTempView("data") checkAnswer(sql("SELECT explode(map(1, 1)) FROM data LIMIT 1"), Row(1, 1) :: Nil) checkAnswer(sql("SELECT explode(map(1, 1)) as (k1, k2) FROM data LIMIT 1"), Row(1, 1) :: Nil) intercept[AnalysisException] { @@ -1010,8 +1010,8 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { // TGF with non-TGF in project is allowed in Spark SQL, but not in Hive test("TGF with non-TGF in projection") { - val rdd = sparkContext.makeRDD( """{"a": "1", "b":"1"}""" :: Nil) - read.json(rdd).createOrReplaceTempView("data") + val ds = Seq("""{"a": "1", "b":"1"}""").toDS() + 
read.json(ds).createOrReplaceTempView("data") checkAnswer( sql("SELECT explode(map(a, b)) as (k1, k2), a, b FROM data"), Row("1", "1", "1", "1") :: Nil) @@ -1024,13 +1024,13 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { // is not in a valid state (cannot be executed). Because of this bug, the analysis rule of // PreInsertionCasts will actually start to work before ImplicitGenerate and then // generates an invalid query plan. - val rdd = sparkContext.makeRDD((1 to 5).map(i => s"""{"a":[$i, ${i + 1}]}""")) - read.json(rdd).createOrReplaceTempView("data") + val ds = (1 to 5).map(i => s"""{"a":[$i, ${i + 1}]}""").toDS() + read.json(ds).createOrReplaceTempView("data") withSQLConf(SQLConf.CONVERT_CTAS.key -> "false") { sql("CREATE TABLE explodeTest (key bigInt)") table("explodeTest").queryExecution.analyzed match { - case SubqueryAlias(_, r: MetastoreRelation, _) => // OK + case SubqueryAlias(_, r: CatalogRelation, _) => // OK case _ => fail("To correctly test the fix of SPARK-5875, explodeTest should be a MetastoreRelation") } @@ -1262,9 +1262,8 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } test("SPARK-9371: fix the support for special chars in column names for hive context") { - read.json(sparkContext.makeRDD( - """{"a": {"c.b": 1}, "b.$q": [{"a@!.q": 1}], "q.w": {"w.i&": [1]}}""" :: Nil)) - .createOrReplaceTempView("t") + val ds = Seq("""{"a": {"c.b": 1}, "b.$q": [{"a@!.q": 1}], "q.w": {"w.i&": [1]}}""").toDS() + read.json(ds).createOrReplaceTempView("t") checkAnswer(sql("SELECT a.`c.b`, `b.$q`[0].`a@!.q`, `q.w`.`w.i&`[0] FROM t"), Row(1, 1, 1)) } @@ -2044,4 +2043,18 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } } } + + test("SPARK-17409: Do Not Optimize Query in CTAS (Hive Serde Table) More Than Once") { + withTable("bar") { + withTempView("foo") { + sql("select 0 as id").createOrReplaceTempView("foo") + // If we optimize the query in CTAS more than once, the following saveAsTable will fail + // with the error: `GROUP BY position 0 is not in select list (valid range is [1, 1])` + sql("SELECT * FROM foo group by id").toDF().write.format("hive").saveAsTable("bar") + checkAnswer(spark.table("bar"), Row(0) :: Nil) + val tableMetadata = spark.sessionState.catalog.getTableMetadata(TableIdentifier("bar")) + assert(tableMetadata.provider == Some("hive"), "the expected table is a Hive serde table") + } + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala index 9fa1fb931d763..38a5477796a4a 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala @@ -26,8 +26,9 @@ import org.scalatest.BeforeAndAfterAll import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.catalog.CatalogRelation import org.apache.spark.sql.execution.datasources.{LogicalRelation, RecordReaderIterator} -import org.apache.spark.sql.hive.{HiveUtils, MetastoreRelation} +import org.apache.spark.sql.hive.HiveUtils import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.sql.hive.test.TestHive.implicits._ import org.apache.spark.sql.internal.SQLConf @@ -473,7 +474,7 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { } } else { queryExecution.analyzed.collectFirst { - case _: MetastoreRelation 
=> () + case _: CatalogRelation => () }.getOrElse { fail(s"Expecting no conversion from orc to data sources, " + s"but got:\n$queryExecution") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala index 1a1b2571b67b1..3512c4a890313 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala @@ -21,8 +21,8 @@ import java.io.File import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.catalog.CatalogRelation import org.apache.spark.sql.execution.DataSourceScanExec -import org.apache.spark.sql.execution.command.ExecutedCommandExec import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.hive.execution.HiveTableScanExec import org.apache.spark.sql.hive.test.TestHiveSingleton @@ -806,7 +806,7 @@ class ParquetSourceSuite extends ParquetPartitioningTest { } } else { queryExecution.analyzed.collectFirst { - case _: MetastoreRelation => + case _: CatalogRelation => }.getOrElse { fail(s"Expecting no conversion from parquet to data sources, " + s"but got:\n$queryExecution") diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaMapWithStateSuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaMapWithStateSuite.java index cb8ed83e5a49d..b1367b8f2aed2 100644 --- a/streaming/src/test/java/org/apache/spark/streaming/JavaMapWithStateSuite.java +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaMapWithStateSuite.java @@ -145,8 +145,8 @@ private void testOperation( List>> expectedStateSnapshots) { int numBatches = expectedOutputs.size(); JavaDStream inputStream = JavaTestUtils.attachTestInputStream(ssc, input, 2); - JavaMapWithStateDStream mapWithStateDStream = - JavaPairDStream.fromJavaDStream(inputStream.map(x -> new Tuple2<>(x, 1))).mapWithState(mapWithStateSpec); + JavaMapWithStateDStream mapWithStateDStream = JavaPairDStream.fromJavaDStream( + inputStream.map(x -> new Tuple2<>(x, 1))).mapWithState(mapWithStateSpec); List> collectedOutputs = Collections.synchronizedList(new ArrayList>()); diff --git a/streaming/src/test/java/test/org/apache/spark/streaming/Java8APISuite.java b/streaming/src/test/java/test/org/apache/spark/streaming/Java8APISuite.java index 9948a4074cdc7..80513de4ee117 100644 --- a/streaming/src/test/java/test/org/apache/spark/streaming/Java8APISuite.java +++ b/streaming/src/test/java/test/org/apache/spark/streaming/Java8APISuite.java @@ -20,10 +20,13 @@ import java.io.Serializable; import java.util.*; +import org.apache.spark.api.java.function.Function3; +import org.apache.spark.api.java.function.Function4; import org.apache.spark.streaming.Duration; import org.apache.spark.streaming.Durations; import org.apache.spark.streaming.JavaTestUtils; import org.apache.spark.streaming.LocalJavaStreamingContext; +import org.apache.spark.streaming.State; import org.apache.spark.streaming.StateSpec; import org.apache.spark.streaming.Time; import scala.Tuple2; @@ -142,8 +145,8 @@ public void testReduceByWindow() { Arrays.asList(24)); JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); - JavaDStream reducedWindowed = stream.reduceByWindow((x, y) -> x + y, - (x, y) -> x - y, new Duration(2000), new Duration(1000)); + JavaDStream reducedWindowed = stream.reduceByWindow( + (x, y) -> x + y, (x, y) -> x - y, new Duration(2000), new Duration(1000)); 
JavaTestUtils.attachTestOutputStream(reducedWindowed); List> result = JavaTestUtils.runStreams(ssc, 4, 4); @@ -850,36 +853,44 @@ public void testMapWithStateAPI() { JavaPairRDD initialRDD = null; JavaPairDStream wordsDstream = null; + Function4, State, Optional> mapFn = + (time, key, value, state) -> { + // Use all State's methods here + state.exists(); + state.get(); + state.isTimingOut(); + state.remove(); + state.update(true); + return Optional.of(2.0); + }; + JavaMapWithStateDStream stateDstream = - wordsDstream.mapWithState( - StateSpec.function((time, key, value, state) -> { - // Use all State's methods here - state.exists(); - state.get(); - state.isTimingOut(); - state.remove(); - state.update(true); - return Optional.of(2.0); - }).initialState(initialRDD) - .numPartitions(10) - .partitioner(new HashPartitioner(10)) - .timeout(Durations.seconds(10))); + wordsDstream.mapWithState( + StateSpec.function(mapFn) + .initialState(initialRDD) + .numPartitions(10) + .partitioner(new HashPartitioner(10)) + .timeout(Durations.seconds(10))); JavaPairDStream emittedRecords = stateDstream.stateSnapshots(); + Function3, State, Double> mapFn2 = + (key, value, state) -> { + state.exists(); + state.get(); + state.isTimingOut(); + state.remove(); + state.update(true); + return 2.0; + }; + JavaMapWithStateDStream stateDstream2 = - wordsDstream.mapWithState( - StateSpec.function((key, value, state) -> { - state.exists(); - state.get(); - state.isTimingOut(); - state.remove(); - state.update(true); - return 2.0; - }).initialState(initialRDD) - .numPartitions(10) - .partitioner(new HashPartitioner(10)) - .timeout(Durations.seconds(10))); + wordsDstream.mapWithState( + StateSpec.function(mapFn2) + .initialState(initialRDD) + .numPartitions(10) + .partitioner(new HashPartitioner(10)) + .timeout(Durations.seconds(10))); JavaPairDStream mappedDStream = stateDstream2.stateSnapshots(); } diff --git a/streaming/src/test/java/test/org/apache/spark/streaming/JavaAPISuite.java b/streaming/src/test/java/test/org/apache/spark/streaming/JavaAPISuite.java index b966cbdca076d..96f8d9593d630 100644 --- a/streaming/src/test/java/test/org/apache/spark/streaming/JavaAPISuite.java +++ b/streaming/src/test/java/test/org/apache/spark/streaming/JavaAPISuite.java @@ -29,7 +29,6 @@ import org.apache.spark.streaming.Seconds; import org.apache.spark.streaming.StreamingContextState; import org.apache.spark.streaming.StreamingContextSuite; -import org.apache.spark.streaming.Time; import scala.Tuple2; import org.apache.hadoop.conf.Configuration; @@ -608,7 +607,8 @@ public void testFlatMap() { Arrays.asList("a","t","h","l","e","t","i","c","s")); JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); - JavaDStream flatMapped = stream.flatMap(x -> Arrays.asList(x.split("(?!^)")).iterator()); + JavaDStream flatMapped = + stream.flatMap(x -> Arrays.asList(x.split("(?!^)")).iterator()); JavaTestUtils.attachTestOutputStream(flatMapped); List> result = JavaTestUtils.runStreams(ssc, 3, 3); @@ -1314,7 +1314,8 @@ public void testMapValues() { ssc, inputData, 1); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); - JavaPairDStream mapped = pairStream.mapValues(s -> s.toUpperCase(Locale.ENGLISH)); + JavaPairDStream mapped = + pairStream.mapValues(s -> s.toUpperCase(Locale.ENGLISH)); JavaTestUtils.attachTestOutputStream(mapped); List>> result = JavaTestUtils.runStreams(ssc, 2, 2);
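
For comparison with the Java 8 `mapWithState` tests restructured above, a minimal Scala sketch of the same pattern; the stream, key, and function names are illustrative, and it assumes an existing `StreamingContext` elsewhere in the application.

```scala
// Minimal sketch (illustrative names): Scala counterpart of the mapWithState usage
// exercised in the Java suites above, keeping a running count per key in State.
import org.apache.spark.streaming.{State, StateSpec}
import org.apache.spark.streaming.dstream.DStream

def runningCounts(words: DStream[String]): DStream[(String, Int)] = {
  // (key, new value, state) => mapped record, matching StateSpec.function's 3-arg form
  val countFn = (word: String, one: Option[Int], state: State[Int]) => {
    val sum = one.getOrElse(0) + state.getOption.getOrElse(0)
    state.update(sum)   // persist the running total for this key
    (word, sum)         // emit the updated count downstream
  }
  words.map(w => (w, 1))
    .mapWithState(StateSpec.function(countFn).numPartitions(10))
    .stateSnapshots()   // snapshot of (key, current state), as in the Java tests
}
```
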