diff --git a/.travis.yml b/.travis.yml index d94872db6437..d7e9f8c0290e 100644 --- a/.travis.yml +++ b/.travis.yml @@ -28,7 +28,6 @@ dist: trusty # 2. Choose language and target JDKs for parallel builds. language: java jdk: - - oraclejdk7 - oraclejdk8 # 3. Setup cache directory for SBT and Maven. diff --git a/R/WINDOWS.md b/R/WINDOWS.md index 1afcbfcabe85..cb2eebb9ffe6 100644 --- a/R/WINDOWS.md +++ b/R/WINDOWS.md @@ -38,6 +38,6 @@ To run the SparkR unit tests on Windows, the following steps are required —ass ``` R -e "install.packages('testthat', repos='http://cran.us.r-project.org')" - .\bin\spark-submit2.cmd --conf spark.hadoop.fs.default.name="file:///" R\pkg\tests\run-all.R + .\bin\spark-submit2.cmd --conf spark.hadoop.fs.defaultFS="file:///" R\pkg\tests\run-all.R ``` diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index cf331bab47c6..e33d0d8e29d4 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -280,7 +280,7 @@ setMethod("dtypes", #' Column Names of SparkDataFrame #' -#' Return all column names as a list. +#' Return a vector of column names. #' #' @param x a SparkDataFrame. #' @@ -338,7 +338,7 @@ setMethod("colnames", }) #' @param value a character vector. Must have the same length as the number -#' of columns in the SparkDataFrame. +#' of columns to be renamed. #' @rdname columns #' @aliases colnames<-,SparkDataFrame-method #' @name colnames<- @@ -1804,6 +1804,10 @@ setClassUnion("numericOrcharacter", c("numeric", "character")) #' @note [[ since 1.4.0 setMethod("[[", signature(x = "SparkDataFrame", i = "numericOrcharacter"), function(x, i) { + if (length(i) > 1) { + warning("Subset index has length > 1. Only the first index is used.") + i <- i[1] + } if (is.numeric(i)) { cols <- columns(x) i <- cols[[i]] @@ -1817,6 +1821,10 @@ setMethod("[[", signature(x = "SparkDataFrame", i = "numericOrcharacter"), #' @note [[<- since 2.1.1 setMethod("[[<-", signature(x = "SparkDataFrame", i = "numericOrcharacter"), function(x, i, value) { + if (length(i) > 1) { + warning("Subset index has length > 1. Only the first index is used.") + i <- i[1] + } if (is.numeric(i)) { cols <- columns(x) i <- cols[[i]] diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R index e771a057e244..8354f705f6de 100644 --- a/R/pkg/R/SQLContext.R +++ b/R/pkg/R/SQLContext.R @@ -332,8 +332,10 @@ setMethod("toDF", signature(x = "RDD"), #' Create a SparkDataFrame from a JSON file. #' -#' Loads a JSON file (\href{http://jsonlines.org/}{JSON Lines text format or newline-delimited JSON} -#' ), returning the result as a SparkDataFrame +#' Loads a JSON file, returning the result as a SparkDataFrame +#' By default, (\href{http://jsonlines.org/}{JSON Lines text format or newline-delimited JSON} +#' ) is supported. For JSON (one record per file), set a named property \code{wholeFile} to +#' \code{TRUE}. #' It goes through the entire dataset once to determine the schema. #' #' @param path Path of file to read. A vector of multiple paths is allowed. 
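For reference, a minimal SparkR sketch of the `wholeFile` behaviour documented in the hunk above; the JSON path is a placeholder and a running `sparkR.session()` is assumed:

```r
library(SparkR)
sparkR.session()

# Default: JSON Lines input, one record per line.
df1 <- read.json("path/to/file.json")

# wholeFile = TRUE: each file is parsed as a single JSON record.
df2 <- read.json("path/to/file.json", wholeFile = TRUE)
head(df2)
```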
@@ -346,6 +348,7 @@ setMethod("toDF", signature(x = "RDD"), #' sparkR.session() #' path <- "path/to/file.json" #' df <- read.json(path) +#' df <- read.json(path, wholeFile = TRUE) #' df <- jsonFile(path) #' } #' @name read.json @@ -778,6 +781,7 @@ dropTempView <- function(viewName) { #' @return SparkDataFrame #' @rdname read.df #' @name read.df +#' @seealso \link{read.json} #' @export #' @examples #'\dontrun{ @@ -785,7 +789,7 @@ dropTempView <- function(viewName) { #' df1 <- read.df("path/to/file.json", source = "json") #' schema <- structType(structField("name", "string"), #' structField("info", "map")) -#' df2 <- read.df(mapTypeJsonPath, "json", schema) +#' df2 <- read.df(mapTypeJsonPath, "json", schema, wholeFile = TRUE) #' df3 <- loadDF("data/test_table", "parquet", mergeSchema = "true") #' } #' @name read.df diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 11940d356039..647cbbdd825e 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -1406,7 +1406,7 @@ setGeneric("spark.randomForest", #' @rdname spark.survreg #' @export -setGeneric("spark.survreg", function(data, formula) { standardGeneric("spark.survreg") }) +setGeneric("spark.survreg", function(data, formula, ...) { standardGeneric("spark.survreg") }) #' @rdname spark.svmLinear #' @export diff --git a/R/pkg/R/mllib_classification.R b/R/pkg/R/mllib_classification.R index fa0d795faa10..4db9cc30fb0c 100644 --- a/R/pkg/R/mllib_classification.R +++ b/R/pkg/R/mllib_classification.R @@ -75,9 +75,9 @@ setClass("NaiveBayesModel", representation(jobj = "jobj")) #' @examples #' \dontrun{ #' sparkR.session() -#' df <- createDataFrame(iris) -#' training <- df[df$Species %in% c("versicolor", "virginica"), ] -#' model <- spark.svmLinear(training, Species ~ ., regParam = 0.5) +#' t <- as.data.frame(Titanic) +#' training <- createDataFrame(t) +#' model <- spark.svmLinear(training, Survived ~ ., regParam = 0.5) #' summary <- summary(model) #' #' # fitted values on training data @@ -207,6 +207,9 @@ function(object, path, overwrite = FALSE) { #' excepting that at most one value may be 0. The class with largest value p/t is predicted, where p #' is the original probability of that class and t is the class's threshold. #' @param weightCol The weight column name. +#' @param aggregationDepth The depth for treeAggregate (greater than or equal to 2). If the dimensions of features +#' or the number of partitions are large, this param could be adjusted to a larger size. +#' This is an expert parameter. Default value should be good for most cases. #' @param ... additional arguments passed to the method. #' @return \code{spark.logit} returns a fitted logistic regression model. 
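As a hedged illustration of the two `spark.logit` parameters documented above, a SparkR sketch using the Titanic data from the surrounding examples; the `weight` column is invented purely for the illustration, and `aggregationDepth = 2` is simply the documented default:

```r
library(SparkR)
sparkR.session()

t <- as.data.frame(Titanic)
t$weight <- ifelse(t$Survived == "Yes", 2.0, 1.0)  # hypothetical per-row weights
training <- createDataFrame(t)

model <- spark.logit(training, Survived ~ Class + Sex + Age, regParam = 0.5,
                     weightCol = "weight", aggregationDepth = 2)
summary(model)
head(select(predict(model, training), "prediction"))
```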
#' @rdname spark.logit @@ -217,9 +220,9 @@ function(object, path, overwrite = FALSE) { #' \dontrun{ #' sparkR.session() #' # binary logistic regression -#' df <- createDataFrame(iris) -#' training <- df[df$Species %in% c("versicolor", "virginica"), ] -#' model <- spark.logit(training, Species ~ ., regParam = 0.5) +#' t <- as.data.frame(Titanic) +#' training <- createDataFrame(t) +#' model <- spark.logit(training, Survived ~ ., regParam = 0.5) #' summary <- summary(model) #' #' # fitted values on training data @@ -236,8 +239,7 @@ function(object, path, overwrite = FALSE) { #' #' # multinomial logistic regression #' -#' df <- createDataFrame(iris) -#' model <- spark.logit(df, Species ~ ., regParam = 0.5) +#' model <- spark.logit(training, Class ~ ., regParam = 0.5) #' summary <- summary(model) #' #' } @@ -245,11 +247,13 @@ function(object, path, overwrite = FALSE) { setMethod("spark.logit", signature(data = "SparkDataFrame", formula = "formula"), function(data, formula, regParam = 0.0, elasticNetParam = 0.0, maxIter = 100, tol = 1E-6, family = "auto", standardization = TRUE, - thresholds = 0.5, weightCol = NULL) { + thresholds = 0.5, weightCol = NULL, aggregationDepth = 2) { formula <- paste(deparse(formula), collapse = "") - if (is.null(weightCol)) { - weightCol <- "" + if (!is.null(weightCol) && weightCol == "") { + weightCol <- NULL + } else if (!is.null(weightCol)) { + weightCol <- as.character(weightCol) } jobj <- callJStatic("org.apache.spark.ml.r.LogisticRegressionWrapper", "fit", @@ -257,7 +261,7 @@ setMethod("spark.logit", signature(data = "SparkDataFrame", formula = "formula") as.numeric(elasticNetParam), as.integer(maxIter), as.numeric(tol), as.character(family), as.logical(standardization), as.array(thresholds), - as.character(weightCol)) + weightCol, as.integer(aggregationDepth)) new("LogisticRegressionModel", jobj = jobj) }) diff --git a/R/pkg/R/mllib_clustering.R b/R/pkg/R/mllib_clustering.R index 8823f9077596..0ebdb5a27308 100644 --- a/R/pkg/R/mllib_clustering.R +++ b/R/pkg/R/mllib_clustering.R @@ -72,8 +72,9 @@ setClass("LDAModel", representation(jobj = "jobj")) #' @examples #' \dontrun{ #' sparkR.session() -#' df <- createDataFrame(iris) -#' model <- spark.bisectingKmeans(df, Sepal_Length ~ Sepal_Width, k = 4) +#' t <- as.data.frame(Titanic) +#' df <- createDataFrame(t) +#' model <- spark.bisectingKmeans(df, Class ~ Survived, k = 4) #' summary(model) #' #' # get fitted result from a bisecting k-means model @@ -82,7 +83,7 @@ setClass("LDAModel", representation(jobj = "jobj")) #' #' # fitted values on training data #' fitted <- predict(model, df) -#' head(select(fitted, "Sepal_Length", "prediction")) +#' head(select(fitted, "Class", "prediction")) #' #' # save fitted model to input path #' path <- "path/to/model" @@ -338,14 +339,14 @@ setMethod("write.ml", signature(object = "GaussianMixtureModel", path = "charact #' @examples #' \dontrun{ #' sparkR.session() -#' data(iris) -#' df <- createDataFrame(iris) -#' model <- spark.kmeans(df, Sepal_Length ~ Sepal_Width, k = 4, initMode = "random") +#' t <- as.data.frame(Titanic) +#' df <- createDataFrame(t) +#' model <- spark.kmeans(df, Class ~ Survived, k = 4, initMode = "random") #' summary(model) #' #' # fitted values on training data #' fitted <- predict(model, df) -#' head(select(fitted, "Sepal_Length", "prediction")) +#' head(select(fitted, "Class", "prediction")) #' #' # save fitted model to input path #' path <- "path/to/model" diff --git a/R/pkg/R/mllib_regression.R b/R/pkg/R/mllib_regression.R index 
96ee220bc411..648d363f1a25 100644 --- a/R/pkg/R/mllib_regression.R +++ b/R/pkg/R/mllib_regression.R @@ -68,14 +68,14 @@ setClass("IsotonicRegressionModel", representation(jobj = "jobj")) #' @examples #' \dontrun{ #' sparkR.session() -#' data(iris) -#' df <- createDataFrame(iris) -#' model <- spark.glm(df, Sepal_Length ~ Sepal_Width, family = "gaussian") +#' t <- as.data.frame(Titanic) +#' df <- createDataFrame(t) +#' model <- spark.glm(df, Freq ~ Sex + Age, family = "gaussian") #' summary(model) #' #' # fitted values on training data #' fitted <- predict(model, df) -#' head(select(fitted, "Sepal_Length", "prediction")) +#' head(select(fitted, "Freq", "prediction")) #' #' # save fitted model to input path #' path <- "path/to/model" @@ -102,14 +102,16 @@ setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"), } formula <- paste(deparse(formula), collapse = "") - if (is.null(weightCol)) { - weightCol <- "" + if (!is.null(weightCol) && weightCol == "") { + weightCol <- NULL + } else if (!is.null(weightCol)) { + weightCol <- as.character(weightCol) } # For known families, Gamma is upper-cased jobj <- callJStatic("org.apache.spark.ml.r.GeneralizedLinearRegressionWrapper", "fit", formula, data@sdf, tolower(family$family), family$link, - tol, as.integer(maxIter), as.character(weightCol), regParam) + tol, as.integer(maxIter), weightCol, regParam) new("GeneralizedLinearRegressionModel", jobj = jobj) }) @@ -135,9 +137,9 @@ setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"), #' @examples #' \dontrun{ #' sparkR.session() -#' data(iris) -#' df <- createDataFrame(iris) -#' model <- glm(Sepal_Length ~ Sepal_Width, df, family = "gaussian") +#' t <- as.data.frame(Titanic) +#' df <- createDataFrame(t) +#' model <- glm(Freq ~ Sex + Age, df, family = "gaussian") #' summary(model) #' } #' @note glm since 1.5.0 @@ -305,13 +307,15 @@ setMethod("spark.isoreg", signature(data = "SparkDataFrame", formula = "formula" function(data, formula, isotonic = TRUE, featureIndex = 0, weightCol = NULL) { formula <- paste(deparse(formula), collapse = "") - if (is.null(weightCol)) { - weightCol <- "" + if (!is.null(weightCol) && weightCol == "") { + weightCol <- NULL + } else if (!is.null(weightCol)) { + weightCol <- as.character(weightCol) } jobj <- callJStatic("org.apache.spark.ml.r.IsotonicRegressionWrapper", "fit", data@sdf, formula, as.logical(isotonic), as.integer(featureIndex), - as.character(weightCol)) + weightCol) new("IsotonicRegressionModel", jobj = jobj) }) @@ -372,6 +376,10 @@ setMethod("write.ml", signature(object = "IsotonicRegressionModel", path = "char #' @param formula a symbolic description of the model to be fitted. Currently only a few formula #' operators are supported, including '~', ':', '+', and '-'. #' Note that operator '.' is not supported currently. +#' @param aggregationDepth The depth for treeAggregate (greater than or equal to 2). If the dimensions of features +#' or the number of partitions are large, this param could be adjusted to a larger size. +#' This is an expert parameter. Default value should be good for most cases. +#' @param ... additional arguments passed to the method. #' @return \code{spark.survreg} returns a fitted AFT survival regression model. 
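A minimal sketch of the `aggregationDepth` argument added to `spark.survreg` above, assuming the `ovarian` data from the R `survival` package (SparkR's `createDataFrame` rewrites dots in column names as underscores, hence `ecog_ps`) and the default depth of 2:

```r
library(SparkR)
library(survival)  # provides Surv() and the ovarian data set
sparkR.session()

ovarianDF <- suppressWarnings(createDataFrame(ovarian))
model <- spark.survreg(ovarianDF, Surv(futime, fustat) ~ ecog_ps + rx,
                       aggregationDepth = 2)
summary(model)
```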
#' @rdname spark.survreg #' @seealso survival: \url{https://cran.r-project.org/package=survival} @@ -396,10 +404,10 @@ setMethod("write.ml", signature(object = "IsotonicRegressionModel", path = "char #' } #' @note spark.survreg since 2.0.0 setMethod("spark.survreg", signature(data = "SparkDataFrame", formula = "formula"), - function(data, formula) { + function(data, formula, aggregationDepth = 2) { formula <- paste(deparse(formula), collapse = "") jobj <- callJStatic("org.apache.spark.ml.r.AFTSurvivalRegressionWrapper", - "fit", formula, data@sdf) + "fit", formula, data@sdf, as.integer(aggregationDepth)) new("AFTSurvivalRegressionModel", jobj = jobj) }) diff --git a/R/pkg/R/mllib_tree.R b/R/pkg/R/mllib_tree.R index 0d53fad06180..40a806c41bad 100644 --- a/R/pkg/R/mllib_tree.R +++ b/R/pkg/R/mllib_tree.R @@ -143,14 +143,15 @@ print.summary.treeEnsemble <- function(x) { #' #' # fit a Gradient Boosted Tree Classification Model #' # label must be binary - Only binary classification is supported for GBT. -#' df <- createDataFrame(iris[iris$Species != "virginica", ]) -#' model <- spark.gbt(df, Species ~ Petal_Length + Petal_Width, "classification") +#' t <- as.data.frame(Titanic) +#' df <- createDataFrame(t) +#' model <- spark.gbt(df, Survived ~ Age + Freq, "classification") #' #' # numeric label is also supported -#' iris2 <- iris[iris$Species != "virginica", ] -#' iris2$NumericSpecies <- ifelse(iris2$Species == "setosa", 0, 1) -#' df <- createDataFrame(iris2) -#' model <- spark.gbt(df, NumericSpecies ~ ., type = "classification") +#' t2 <- as.data.frame(Titanic) +#' t2$NumericGender <- ifelse(t2$Sex == "Male", 0, 1) +#' df <- createDataFrame(t2) +#' model <- spark.gbt(df, NumericGender ~ ., type = "classification") #' } #' @note spark.gbt since 2.1.0 setMethod("spark.gbt", signature(data = "SparkDataFrame", formula = "formula"), @@ -351,8 +352,9 @@ setMethod("write.ml", signature(object = "GBTClassificationModel", path = "chara #' summary(savedModel) #' #' # fit a Random Forest Classification Model -#' df <- createDataFrame(iris) -#' model <- spark.randomForest(df, Species ~ Petal_Length + Petal_Width, "classification") +#' t <- as.data.frame(Titanic) +#' df <- createDataFrame(t) +#' model <- spark.randomForest(df, Survived ~ Freq + Age, "classification") #' } #' @note spark.randomForest since 2.1.0 setMethod("spark.randomForest", signature(data = "SparkDataFrame", formula = "formula"), diff --git a/R/pkg/inst/tests/testthat/test_mllib_classification.R b/R/pkg/inst/tests/testthat/test_mllib_classification.R index 620f528f2e6c..459254d271a5 100644 --- a/R/pkg/inst/tests/testthat/test_mllib_classification.R +++ b/R/pkg/inst/tests/testthat/test_mllib_classification.R @@ -211,7 +211,15 @@ test_that("spark.logit", { df <- createDataFrame(data) model <- spark.logit(df, label ~ feature) prediction <- collect(select(predict(model, df), "prediction")) - expect_equal(prediction$prediction, c("0.0", "0.0", "1.0", "1.0", "0.0")) + expect_equal(sort(prediction$prediction), c("0.0", "0.0", "0.0", "1.0", "1.0")) + + # Test prediction with weightCol + weight <- c(2.0, 2.0, 2.0, 1.0, 1.0) + data2 <- as.data.frame(cbind(label, feature, weight)) + df2 <- createDataFrame(data2) + model2 <- spark.logit(df2, label ~ feature, weightCol = "weight") + prediction2 <- collect(select(predict(model2, df2), "prediction")) + expect_equal(sort(prediction2$prediction), c("0.0", "0.0", "0.0", "0.0", "0.0")) }) test_that("spark.mlp", { diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R 
b/R/pkg/inst/tests/testthat/test_sparkSQL.R index a7259f362ebe..1dd8c5ce6cb3 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -898,6 +898,12 @@ test_that("names() colnames() set the column names", { expect_equal(names(z)[3], "c") names(z)[3] <- "c2" expect_equal(names(z)[3], "c2") + + # Test subset assignment + colnames(df)[1] <- "col5" + expect_equal(colnames(df)[1], "col5") + names(df)[2] <- "col6" + expect_equal(names(df)[2], "col6") }) test_that("head() and first() return the correct data", { @@ -1015,6 +1021,18 @@ test_that("select operators", { expect_is(df[[2]], "Column") expect_is(df[["age"]], "Column") + expect_warning(df[[1:2]], + "Subset index has length > 1. Only the first index is used.") + expect_is(suppressWarnings(df[[1:2]]), "Column") + expect_warning(df[[c("name", "age")]], + "Subset index has length > 1. Only the first index is used.") + expect_is(suppressWarnings(df[[c("name", "age")]]), "Column") + + expect_warning(df[[1:2]] <- df[[1]], + "Subset index has length > 1. Only the first index is used.") + expect_warning(df[[c("name", "age")]] <- df[[1]], + "Subset index has length > 1. Only the first index is used.") + expect_is(df[, 1, drop = F], "SparkDataFrame") expect_equal(columns(df[, 1, drop = F]), c("name")) expect_equal(columns(df[, "age", drop = F]), c("age")) diff --git a/R/pkg/vignettes/sparkr-vignettes.Rmd b/R/pkg/vignettes/sparkr-vignettes.Rmd index bc8bc3c26c11..43c255cff302 100644 --- a/R/pkg/vignettes/sparkr-vignettes.Rmd +++ b/R/pkg/vignettes/sparkr-vignettes.Rmd @@ -565,11 +565,10 @@ We use a simple example to demonstrate `spark.logit` usage. In general, there ar and 3). Obtain the coefficient matrix of the fitted model using `summary` and use the model for prediction with `predict`. Binomial logistic regression -```{r, warning=FALSE} -df <- createDataFrame(iris) -# Create a DataFrame containing two classes -training <- df[df$Species %in% c("versicolor", "virginica"), ] -model <- spark.logit(training, Species ~ ., regParam = 0.00042) +```{r} +t <- as.data.frame(Titanic) +training <- createDataFrame(t) +model <- spark.logit(training, Survived ~ ., regParam = 0.04741301) summary(model) ``` @@ -579,10 +578,11 @@ fitted <- predict(model, training) ``` Multinomial logistic regression against three classes -```{r, warning=FALSE} -df <- createDataFrame(iris) +```{r} +t <- as.data.frame(Titanic) +training <- createDataFrame(t) # Note in this case, Spark infers it is multinomial logistic regression, so family = "multinomial" is optional. -model <- spark.logit(df, Species ~ ., regParam = 0.056) +model <- spark.logit(training, Class ~ ., regParam = 0.07815179) summary(model) ``` @@ -609,11 +609,12 @@ MLPC employs backpropagation for learning the model. We use the logistic loss fu `spark.mlp` requires at least two columns in `data`: one named `"label"` and the other one `"features"`. The `"features"` column should be in libSVM-format. -We use iris data set to show how to use `spark.mlp` in classification. -```{r, warning=FALSE} -df <- createDataFrame(iris) +We use Titanic data set to show how to use `spark.mlp` in classification. 
+```{r} +t <- as.data.frame(Titanic) +training <- createDataFrame(t) # fit a Multilayer Perceptron Classification Model -model <- spark.mlp(df, Species ~ ., blockSize = 128, layers = c(4, 3), solver = "l-bfgs", maxIter = 100, tol = 0.5, stepSize = 1, seed = 1, initialWeights = c(0, 0, 0, 0, 0, 5, 5, 5, 5, 5, 9, 9, 9, 9, 9)) +model <- spark.mlp(training, Survived ~ Age + Sex, blockSize = 128, layers = c(2, 3), solver = "l-bfgs", maxIter = 100, tol = 0.5, stepSize = 1, seed = 1, initialWeights = c( 0, 0, 0, 5, 5, 5, 9, 9, 9)) ``` To avoid lengthy display, we only present partial results of the model summary. You can check the full result from your sparkR shell. @@ -630,7 +631,7 @@ options(ops) ``` ```{r} # make predictions use the fitted model -predictions <- predict(model, df) +predictions <- predict(model, training) head(select(predictions, predictions$prediction)) ``` @@ -769,12 +770,13 @@ predictions <- predict(rfModel, df) `spark.bisectingKmeans` is a kind of [hierarchical clustering](https://en.wikipedia.org/wiki/Hierarchical_clustering) using a divisive (or "top-down") approach: all observations start in one cluster, and splits are performed recursively as one moves down the hierarchy. -```{r, warning=FALSE} -df <- createDataFrame(iris) -model <- spark.bisectingKmeans(df, Sepal_Length ~ Sepal_Width, k = 4) +```{r} +t <- as.data.frame(Titanic) +training <- createDataFrame(t) +model <- spark.bisectingKmeans(training, Class ~ Survived, k = 4) summary(model) -fitted <- predict(model, df) -head(select(fitted, "Sepal_Length", "prediction")) +fitted <- predict(model, training) +head(select(fitted, "Class", "prediction")) ``` #### Gaussian Mixture Model @@ -912,9 +914,10 @@ testSummary ### Model Persistence The following example shows how to save/load an ML model by SparkR. 
-```{r, warning=FALSE} -irisDF <- createDataFrame(iris) -gaussianGLM <- spark.glm(irisDF, Sepal_Length ~ Sepal_Width + Species, family = "gaussian") +```{r} +t <- as.data.frame(Titanic) +training <- createDataFrame(t) +gaussianGLM <- spark.glm(training, Freq ~ Sex + Age, family = "gaussian") # Save and then load a fitted MLlib model modelPath <- tempfile(pattern = "ml", fileext = ".tmp") @@ -925,7 +928,7 @@ gaussianGLM2 <- read.ml(modelPath) summary(gaussianGLM2) # Check model prediction -gaussianPredictions <- predict(gaussianGLM2, irisDF) +gaussianPredictions <- predict(gaussianGLM2, training) head(gaussianPredictions) unlink(modelPath) diff --git a/R/run-tests.sh b/R/run-tests.sh index 5e4dafaf76f3..742a2c5ed76d 100755 --- a/R/run-tests.sh +++ b/R/run-tests.sh @@ -23,7 +23,7 @@ FAILED=0 LOGFILE=$FWDIR/unit-tests.out rm -f $LOGFILE -SPARK_TESTING=1 $FWDIR/../bin/spark-submit --driver-java-options "-Dlog4j.configuration=file:$FWDIR/log4j.properties" --conf spark.hadoop.fs.default.name="file:///" $FWDIR/pkg/tests/run-all.R 2>&1 | tee -a $LOGFILE +SPARK_TESTING=1 $FWDIR/../bin/spark-submit --driver-java-options "-Dlog4j.configuration=file:$FWDIR/log4j.properties" --conf spark.hadoop.fs.defaultFS="file:///" $FWDIR/pkg/tests/run-all.R 2>&1 | tee -a $LOGFILE FAILED=$((PIPESTATUS[0]||$FAILED)) NUM_TEST_WARNING="$(grep -c -e 'Warnings ----------------' $LOGFILE)" diff --git a/appveyor.yml b/appveyor.yml index 6bc66c0ea54d..5adf1b4bedb4 100644 --- a/appveyor.yml +++ b/appveyor.yml @@ -46,7 +46,7 @@ build_script: - cmd: mvn -DskipTests -Psparkr -Phive -Phive-thriftserver package test_script: - - cmd: .\bin\spark-submit2.cmd --conf spark.hadoop.fs.default.name="file:///" R\pkg\tests\run-all.R + - cmd: .\bin\spark-submit2.cmd --conf spark.hadoop.fs.defaultFS="file:///" R\pkg\tests\run-all.R notifications: - provider: Email diff --git a/common/network-common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java b/common/network-common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java index 4477c9a935f2..09fc80d12d51 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java @@ -26,7 +26,6 @@ import static org.junit.Assert.assertEquals; import static org.mockito.Mockito.*; -import org.apache.spark.network.buffer.ManagedBuffer; import org.apache.spark.network.buffer.NioManagedBuffer; import org.apache.spark.network.client.ChunkReceivedCallback; import org.apache.spark.network.client.RpcResponseCallback; diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/expressions/HiveHasher.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/expressions/HiveHasher.java index c7ea9085eba6..73577437ac50 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/expressions/HiveHasher.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/expressions/HiveHasher.java @@ -20,7 +20,7 @@ import org.apache.spark.unsafe.Platform; /** - * Simulates Hive's hashing function at + * Simulates Hive's hashing function from Hive v1.2.1 * org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils#hashcode() */ public class HiveHasher { diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index 87b9e8eb445a..10a7cb1d0665 100644 --- 
a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -153,7 +153,8 @@ public void writeTo(ByteBuffer buffer) { * * Unlike getBytes this will not create a copy the array if this is a slice. */ - public @Nonnull ByteBuffer getByteBuffer() { + @Nonnull + public ByteBuffer getByteBuffer() { if (base instanceof byte[] && offset >= BYTE_ARRAY_OFFSET) { final byte[] bytes = (byte[]) base; diff --git a/core/src/main/java/org/apache/spark/SparkFirehoseListener.java b/core/src/main/java/org/apache/spark/SparkFirehoseListener.java index 9fe97b4d9c20..140c52fd12f9 100644 --- a/core/src/main/java/org/apache/spark/SparkFirehoseListener.java +++ b/core/src/main/java/org/apache/spark/SparkFirehoseListener.java @@ -30,116 +30,117 @@ */ public class SparkFirehoseListener implements SparkListenerInterface { - public void onEvent(SparkListenerEvent event) { } - - @Override - public final void onStageCompleted(SparkListenerStageCompleted stageCompleted) { - onEvent(stageCompleted); - } - - @Override - public final void onStageSubmitted(SparkListenerStageSubmitted stageSubmitted) { - onEvent(stageSubmitted); - } - - @Override - public final void onTaskStart(SparkListenerTaskStart taskStart) { - onEvent(taskStart); - } - - @Override - public final void onTaskGettingResult(SparkListenerTaskGettingResult taskGettingResult) { - onEvent(taskGettingResult); - } - - @Override - public final void onTaskEnd(SparkListenerTaskEnd taskEnd) { - onEvent(taskEnd); - } - - @Override - public final void onJobStart(SparkListenerJobStart jobStart) { - onEvent(jobStart); - } - - @Override - public final void onJobEnd(SparkListenerJobEnd jobEnd) { - onEvent(jobEnd); - } - - @Override - public final void onEnvironmentUpdate(SparkListenerEnvironmentUpdate environmentUpdate) { - onEvent(environmentUpdate); - } - - @Override - public final void onBlockManagerAdded(SparkListenerBlockManagerAdded blockManagerAdded) { - onEvent(blockManagerAdded); - } - - @Override - public final void onBlockManagerRemoved(SparkListenerBlockManagerRemoved blockManagerRemoved) { - onEvent(blockManagerRemoved); - } - - @Override - public final void onUnpersistRDD(SparkListenerUnpersistRDD unpersistRDD) { - onEvent(unpersistRDD); - } - - @Override - public final void onApplicationStart(SparkListenerApplicationStart applicationStart) { - onEvent(applicationStart); - } - - @Override - public final void onApplicationEnd(SparkListenerApplicationEnd applicationEnd) { - onEvent(applicationEnd); - } - - @Override - public final void onExecutorMetricsUpdate( - SparkListenerExecutorMetricsUpdate executorMetricsUpdate) { - onEvent(executorMetricsUpdate); - } - - @Override - public final void onExecutorAdded(SparkListenerExecutorAdded executorAdded) { - onEvent(executorAdded); - } - - @Override - public final void onExecutorRemoved(SparkListenerExecutorRemoved executorRemoved) { - onEvent(executorRemoved); - } - - @Override - public final void onExecutorBlacklisted(SparkListenerExecutorBlacklisted executorBlacklisted) { - onEvent(executorBlacklisted); - } - - @Override - public final void onExecutorUnblacklisted(SparkListenerExecutorUnblacklisted executorUnblacklisted) { - onEvent(executorUnblacklisted); - } - - @Override - public final void onNodeBlacklisted(SparkListenerNodeBlacklisted nodeBlacklisted) { - onEvent(nodeBlacklisted); - } - - @Override - public final void onNodeUnblacklisted(SparkListenerNodeUnblacklisted nodeUnblacklisted) { - 
onEvent(nodeUnblacklisted); - } - - @Override - public void onBlockUpdated(SparkListenerBlockUpdated blockUpdated) { - onEvent(blockUpdated); - } - - @Override - public void onOtherEvent(SparkListenerEvent event) { - onEvent(event); - } + public void onEvent(SparkListenerEvent event) { } + + @Override + public final void onStageCompleted(SparkListenerStageCompleted stageCompleted) { + onEvent(stageCompleted); + } + + @Override + public final void onStageSubmitted(SparkListenerStageSubmitted stageSubmitted) { + onEvent(stageSubmitted); + } + + @Override + public final void onTaskStart(SparkListenerTaskStart taskStart) { + onEvent(taskStart); + } + + @Override + public final void onTaskGettingResult(SparkListenerTaskGettingResult taskGettingResult) { + onEvent(taskGettingResult); + } + + @Override + public final void onTaskEnd(SparkListenerTaskEnd taskEnd) { + onEvent(taskEnd); + } + + @Override + public final void onJobStart(SparkListenerJobStart jobStart) { + onEvent(jobStart); + } + + @Override + public final void onJobEnd(SparkListenerJobEnd jobEnd) { + onEvent(jobEnd); + } + + @Override + public final void onEnvironmentUpdate(SparkListenerEnvironmentUpdate environmentUpdate) { + onEvent(environmentUpdate); + } + + @Override + public final void onBlockManagerAdded(SparkListenerBlockManagerAdded blockManagerAdded) { + onEvent(blockManagerAdded); + } + + @Override + public final void onBlockManagerRemoved(SparkListenerBlockManagerRemoved blockManagerRemoved) { + onEvent(blockManagerRemoved); + } + + @Override + public final void onUnpersistRDD(SparkListenerUnpersistRDD unpersistRDD) { + onEvent(unpersistRDD); + } + + @Override + public final void onApplicationStart(SparkListenerApplicationStart applicationStart) { + onEvent(applicationStart); + } + + @Override + public final void onApplicationEnd(SparkListenerApplicationEnd applicationEnd) { + onEvent(applicationEnd); + } + + @Override + public final void onExecutorMetricsUpdate( + SparkListenerExecutorMetricsUpdate executorMetricsUpdate) { + onEvent(executorMetricsUpdate); + } + + @Override + public final void onExecutorAdded(SparkListenerExecutorAdded executorAdded) { + onEvent(executorAdded); + } + + @Override + public final void onExecutorRemoved(SparkListenerExecutorRemoved executorRemoved) { + onEvent(executorRemoved); + } + + @Override + public final void onExecutorBlacklisted(SparkListenerExecutorBlacklisted executorBlacklisted) { + onEvent(executorBlacklisted); + } + + @Override + public final void onExecutorUnblacklisted( + SparkListenerExecutorUnblacklisted executorUnblacklisted) { + onEvent(executorUnblacklisted); + } + + @Override + public final void onNodeBlacklisted(SparkListenerNodeBlacklisted nodeBlacklisted) { + onEvent(nodeBlacklisted); + } + + @Override + public final void onNodeUnblacklisted(SparkListenerNodeUnblacklisted nodeUnblacklisted) { + onEvent(nodeUnblacklisted); + } + + @Override + public void onBlockUpdated(SparkListenerBlockUpdated blockUpdated) { + onEvent(blockUpdated); + } + + @Override + public void onOtherEvent(SparkListenerEvent event) { + onEvent(event); + } } diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java index 29aca04a3d11..f312fa2b2ddd 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -161,7 +161,9 @@ 
private UnsafeExternalSorter( // Register a cleanup task with TaskContext to ensure that memory is guaranteed to be freed at // the end of the task. This is necessary to avoid memory leaks in when the downstream operator // does not fully consume the sorter's output (e.g. sort followed by limit). - taskContext.addTaskCompletionListener(context -> { cleanupResources(); }); + taskContext.addTaskCompletionListener(context -> { + cleanupResources(); + }); } /** diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index e4d83893e740..0e36a30c933d 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -961,12 +961,11 @@ class SparkContext(config: SparkConf) extends Logging { classOf[LongWritable], classOf[BytesWritable], conf = conf) - val data = br.map { case (k, v) => - val bytes = v.getBytes + br.map { case (k, v) => + val bytes = v.copyBytes() assert(bytes.length == recordLength, "Byte array does not have correct length") bytes } - data } /** @@ -1816,10 +1815,18 @@ class SparkContext(config: SparkConf) extends Logging { // A JAR file which exists only on the driver node case null | "file" => try { + val file = new File(uri.getPath) + if (!file.exists()) { + throw new FileNotFoundException(s"Jar ${file.getAbsolutePath} not found") + } + if (file.isDirectory) { + throw new IllegalArgumentException( + s"Directory ${file.getAbsoluteFile} is not allowed for addJar") + } env.rpcEnv.fileServer.addJar(new File(uri.getPath)) } catch { - case exc: FileNotFoundException => - logError(s"Jar not found at $path") + case NonFatal(e) => + logError(s"Failed to add $path to Spark environment", e) null } // A JAR file which exists locally on every worker node diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala index 0fd777ed1282..f0867ecb16ea 100644 --- a/core/src/main/scala/org/apache/spark/TaskContext.scala +++ b/core/src/main/scala/org/apache/spark/TaskContext.scala @@ -24,6 +24,7 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.executor.TaskMetrics import org.apache.spark.memory.TaskMemoryManager import org.apache.spark.metrics.source.Source +import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.util.{AccumulatorV2, TaskCompletionListener, TaskFailureListener} @@ -190,4 +191,10 @@ abstract class TaskContext extends Serializable { */ private[spark] def registerAccumulator(a: AccumulatorV2[_, _]): Unit + /** + * Record that this task has failed due to a fetch failure from a remote host. This allows + * fetch-failure handling to get triggered by the driver, regardless of intervening user-code. 
+ */ + private[spark] def setFetchFailed(fetchFailed: FetchFailedException): Unit + } diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala index c904e083911c..dc0d12878550 100644 --- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala +++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala @@ -26,6 +26,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.memory.TaskMemoryManager import org.apache.spark.metrics.MetricsSystem import org.apache.spark.metrics.source.Source +import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.util._ private[spark] class TaskContextImpl( @@ -56,6 +57,10 @@ private[spark] class TaskContextImpl( // Whether the task has failed. @volatile private var failed: Boolean = false + // If there was a fetch failure in the task, we store it here, to make sure user-code doesn't + // hide the exception. See SPARK-19276 + @volatile private var _fetchFailedException: Option[FetchFailedException] = None + override def addTaskCompletionListener(listener: TaskCompletionListener): this.type = { onCompleteCallbacks += listener this @@ -126,4 +131,10 @@ private[spark] class TaskContextImpl( taskMetrics.registerAccumulator(a) } + private[spark] override def setFetchFailed(fetchFailed: FetchFailedException): Unit = { + this._fetchFailedException = Option(fetchFailed) + } + + private[spark] def fetchFailed: Option[FetchFailedException] = _fetchFailedException + } diff --git a/core/src/main/scala/org/apache/spark/TestUtils.scala b/core/src/main/scala/org/apache/spark/TestUtils.scala index 109104f0a537..3f912dc19151 100644 --- a/core/src/main/scala/org/apache/spark/TestUtils.scala +++ b/core/src/main/scala/org/apache/spark/TestUtils.scala @@ -200,9 +200,13 @@ private[spark] object TestUtils { /** * Returns the response code from an HTTP(S) URL. */ - def httpResponseCode(url: URL, method: String = "GET"): Int = { + def httpResponseCode( + url: URL, + method: String = "GET", + headers: Seq[(String, String)] = Nil): Int = { val connection = url.openConnection().asInstanceOf[HttpURLConnection] connection.setRequestMethod(method) + headers.foreach { case (k, v) => connection.setRequestProperty(k, v) } // Disable cert and host name validation for HTTPS tests. if (connection.isInstanceOf[HttpsURLConnection]) { diff --git a/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala b/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala index 0b1cec2df830..a8f732b11f6c 100644 --- a/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala @@ -85,6 +85,7 @@ object PythonRunner { // pass conf spark.pyspark.python to python process, the only way to pass info to // python process is through environment variable. 
sparkConf.get(PYSPARK_PYTHON).foreach(env.put("PYSPARK_PYTHON", _)) + sys.env.get("PYTHONHASHSEED").foreach(env.put("PYTHONHASHSEED", _)) builder.redirectErrorStream(true) // Ugly but needed for stdout and stderr to synchronize try { val process = builder.start() diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala index 941e2d13fb28..f475ce87540a 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala @@ -82,17 +82,20 @@ class SparkHadoopUtil extends Logging { // the behavior of the old implementation of this code, for backwards compatibility. if (conf != null) { // Explicitly check for S3 environment variables - if (System.getenv("AWS_ACCESS_KEY_ID") != null && - System.getenv("AWS_SECRET_ACCESS_KEY") != null) { - val keyId = System.getenv("AWS_ACCESS_KEY_ID") - val accessKey = System.getenv("AWS_SECRET_ACCESS_KEY") - + val keyId = System.getenv("AWS_ACCESS_KEY_ID") + val accessKey = System.getenv("AWS_SECRET_ACCESS_KEY") + if (keyId != null && accessKey != null) { hadoopConf.set("fs.s3.awsAccessKeyId", keyId) hadoopConf.set("fs.s3n.awsAccessKeyId", keyId) hadoopConf.set("fs.s3a.access.key", keyId) hadoopConf.set("fs.s3.awsSecretAccessKey", accessKey) hadoopConf.set("fs.s3n.awsSecretAccessKey", accessKey) hadoopConf.set("fs.s3a.secret.key", accessKey) + + val sessionToken = System.getenv("AWS_SESSION_TOKEN") + if (sessionToken != null) { + hadoopConf.set("fs.s3a.session.token", sessionToken) + } } // Copy any "spark.hadoop.foo=bar" system properties into conf as "foo=bar" conf.getAll.foreach { case (key, value) => diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 5ffdedd1658a..1e50eb663565 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -665,7 +665,8 @@ object SparkSubmit extends CommandLineUtils { if (verbose) { printStream.println(s"Main class:\n$childMainClass") printStream.println(s"Arguments:\n${childArgs.mkString("\n")}") - printStream.println(s"System properties:\n${sysProps.mkString("\n")}") + // sysProps may contain sensitive information, so redact before printing + printStream.println(s"System properties:\n${Utils.redact(sysProps).mkString("\n")}") printStream.println(s"Classpath elements:\n${childClasspath.mkString("\n")}") printStream.println("\n") } diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index dee77343d806..0614d80b60e1 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -84,9 +84,15 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S // scalastyle:off println if (verbose) SparkSubmit.printStream.println(s"Using properties file: $propertiesFile") Option(propertiesFile).foreach { filename => - Utils.getPropertiesFromFile(filename).foreach { case (k, v) => + val properties = Utils.getPropertiesFromFile(filename) + properties.foreach { case (k, v) => defaultProperties(k) = v - if (verbose) SparkSubmit.printStream.println(s"Adding default property: $k=$v") + } + // Property files may contain sensitive information, so redact before printing + if 
(verbose) { + Utils.redact(properties).foreach { case (k, v) => + SparkSubmit.printStream.println(s"Adding default property: $k=$v") + } } } // scalastyle:on println @@ -318,7 +324,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S | |Spark properties used, including those specified through | --conf and those from the properties file $propertiesFile: - |${sparkProperties.mkString(" ", "\n ", "\n")} + |${Utils.redact(sparkProperties).mkString(" ", "\n ", "\n")} """.stripMargin } diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index d762f1112551..790c1ae94247 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -18,6 +18,7 @@ package org.apache.spark.executor import java.io.{File, NotSerializableException} +import java.lang.Thread.UncaughtExceptionHandler import java.lang.management.ManagementFactory import java.net.{URI, URL} import java.nio.ByteBuffer @@ -52,7 +53,8 @@ private[spark] class Executor( executorHostname: String, env: SparkEnv, userClassPath: Seq[URL] = Nil, - isLocal: Boolean = false) + isLocal: Boolean = false, + uncaughtExceptionHandler: UncaughtExceptionHandler = SparkUncaughtExceptionHandler) extends Logging { logInfo(s"Starting executor ID $executorId on host $executorHostname") @@ -78,7 +80,7 @@ private[spark] class Executor( // Setup an uncaught exception handler for non-local mode. // Make any thread terminations due to uncaught exceptions kill the entire // executor process to avoid surprising stalls. - Thread.setDefaultUncaughtExceptionHandler(SparkUncaughtExceptionHandler) + Thread.setDefaultUncaughtExceptionHandler(uncaughtExceptionHandler) } // Start worker thread pool @@ -148,6 +150,8 @@ private[spark] class Executor( startDriverHeartbeater() + private[executor] def numRunningTasks: Int = runningTasks.size() + def launchTask(context: ExecutorBackend, taskDescription: TaskDescription): Unit = { val tr = new TaskRunner(context, taskDescription) runningTasks.put(taskDescription.taskId, tr) @@ -340,6 +344,14 @@ private[spark] class Executor( } } } + task.context.fetchFailed.foreach { fetchFailure => + // uh-oh. it appears the user code has caught the fetch-failure without throwing any + // other exceptions. Its *possible* this is what the user meant to do (though highly + // unlikely). So we will log an error and keep going. + logError(s"TID ${taskId} completed successfully though internally it encountered " + + s"unrecoverable fetch failures! Most likely this means user code is incorrectly " + + s"swallowing Spark's internal ${classOf[FetchFailedException]}", fetchFailure) + } val taskFinish = System.currentTimeMillis() val taskFinishCpu = if (threadMXBean.isCurrentThreadCpuTimeSupported) { threadMXBean.getCurrentThreadCpuTime @@ -400,8 +412,17 @@ private[spark] class Executor( execBackend.statusUpdate(taskId, TaskState.FINISHED, serializedResult) } catch { - case ffe: FetchFailedException => - val reason = ffe.toTaskFailedReason + case t: Throwable if hasFetchFailure && !Utils.isFatalError(t) => + val reason = task.context.fetchFailed.get.toTaskFailedReason + if (!t.isInstanceOf[FetchFailedException]) { + // there was a fetch failure in the task, but some user code wrapped that exception + // and threw something else. Regardless, we treat it as a fetch failure. 
+ val fetchFailedCls = classOf[FetchFailedException].getName + logWarning(s"TID ${taskId} encountered a ${fetchFailedCls} and " + + s"failed, but the ${fetchFailedCls} was hidden by another " + + s"exception. Spark is handling this like a fetch failure and ignoring the " + + s"other exception: $t") + } setTaskFinishedAndClearInterruptStatus() execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason)) @@ -453,13 +474,17 @@ private[spark] class Executor( // Don't forcibly exit unless the exception was inherently fatal, to avoid // stopping other tasks unnecessarily. if (Utils.isFatalError(t)) { - SparkUncaughtExceptionHandler.uncaughtException(t) + uncaughtExceptionHandler.uncaughtException(Thread.currentThread(), t) } } finally { runningTasks.remove(taskId) } } + + private def hasFetchFailure: Boolean = { + task != null && task.context != null && task.context.fetchFailed.isDefined + } } /** diff --git a/core/src/main/scala/org/apache/spark/internal/Logging.scala b/core/src/main/scala/org/apache/spark/internal/Logging.scala index 013cd1c1bc03..c7f2847731fc 100644 --- a/core/src/main/scala/org/apache/spark/internal/Logging.scala +++ b/core/src/main/scala/org/apache/spark/internal/Logging.scala @@ -28,7 +28,7 @@ import org.apache.spark.util.Utils * logging messages at different levels using methods that only evaluate parameters lazily if the * log level is enabled. */ -private[spark] trait Logging { +trait Logging { // Make the log field transient so that objects with Logging can // be serialized and used on another machine diff --git a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala index 2c1b5636888a..22e26799138b 100644 --- a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala +++ b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala @@ -113,11 +113,11 @@ class HadoopMapReduceCommitProtocol(jobId: String, path: String) val taskAttemptId = new TaskAttemptID(taskId, 0) // Set up the configuration object - jobContext.getConfiguration.set("mapred.job.id", jobId.toString) - jobContext.getConfiguration.set("mapred.tip.id", taskAttemptId.getTaskID.toString) - jobContext.getConfiguration.set("mapred.task.id", taskAttemptId.toString) - jobContext.getConfiguration.setBoolean("mapred.task.is.map", true) - jobContext.getConfiguration.setInt("mapred.task.partition", 0) + jobContext.getConfiguration.set("mapreduce.job.id", jobId.toString) + jobContext.getConfiguration.set("mapreduce.task.id", taskAttemptId.getTaskID.toString) + jobContext.getConfiguration.set("mapreduce.task.attempt.id", taskAttemptId.toString) + jobContext.getConfiguration.setBoolean("mapreduce.task.ismap", true) + jobContext.getConfiguration.setInt("mapreduce.task.partition", 0) val taskAttemptContext = new TaskAttemptContextImpl(jobContext.getConfiguration, taskAttemptId) committer = setupCommitter(taskAttemptContext) diff --git a/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopMapReduceWriter.scala b/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopMapReduceWriter.scala index 1e0a1e605cfb..659ad5d0bad8 100644 --- a/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopMapReduceWriter.scala +++ b/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopMapReduceWriter.scala @@ -79,7 +79,7 @@ object SparkHadoopMapReduceWriter extends Logging { val committer = FileCommitProtocol.instantiate( className = 
classOf[HadoopMapReduceCommitProtocol].getName, jobId = stageId.toString, - outputPath = conf.value.get("mapred.output.dir"), + outputPath = conf.value.get("mapreduce.output.fileoutputformat.outputdir"), isAppend = false).asInstanceOf[HadoopMapReduceCommitProtocol] committer.setupJob(jobContext) diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index 5fa6a7ed315f..4bf8ecc38354 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -365,11 +365,11 @@ private[spark] object HadoopRDD extends Logging { val jobID = new JobID(jobTrackerId, jobId) val taId = new TaskAttemptID(new TaskID(jobID, TaskType.MAP, splitId), attemptId) - conf.set("mapred.tip.id", taId.getTaskID.toString) - conf.set("mapred.task.id", taId.toString) - conf.setBoolean("mapred.task.is.map", true) - conf.setInt("mapred.task.partition", splitId) - conf.set("mapred.job.id", jobID.toString) + conf.set("mapreduce.task.id", taId.getTaskID.toString) + conf.set("mapreduce.task.attempt.id", taId.toString) + conf.setBoolean("mapreduce.task.ismap", true) + conf.setInt("mapreduce.task.partition", splitId) + conf.set("mapreduce.job.id", jobID.toString) } /** diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index 567a3183e224..52ce03ff8cde 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -998,7 +998,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) job.setOutputValueClass(valueClass) job.setOutputFormatClass(outputFormatClass) val jobConfiguration = job.getConfiguration - jobConfiguration.set("mapred.output.dir", path) + jobConfiguration.set("mapreduce.output.fileoutputformat.outputdir", path) saveAsNewAPIHadoopDataset(jobConfiguration) } @@ -1039,10 +1039,11 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) conf.setOutputFormat(outputFormatClass) for (c <- codec) { hadoopConf.setCompressMapOutput(true) - hadoopConf.set("mapred.output.compress", "true") + hadoopConf.set("mapreduce.output.fileoutputformat.compress", "true") hadoopConf.setMapOutputCompressorClass(c) - hadoopConf.set("mapred.output.compression.codec", c.getCanonicalName) - hadoopConf.set("mapred.output.compression.type", CompressionType.BLOCK.toString) + hadoopConf.set("mapreduce.output.fileoutputformat.compress.codec", c.getCanonicalName) + hadoopConf.set("mapreduce.output.fileoutputformat.compress.type", + CompressionType.BLOCK.toString) } // Use configured output committer if already set diff --git a/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala b/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala index 08d220b40b6f..83d87b548a43 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala @@ -48,25 +48,29 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean) private type StageId = Int private type PartitionId = Int private type TaskAttemptNumber = Int - private val NO_AUTHORIZED_COMMITTER: TaskAttemptNumber = -1 + private case class StageState(numPartitions: Int) { + val authorizedCommitters = Array.fill[TaskAttemptNumber](numPartitions)(NO_AUTHORIZED_COMMITTER) + val failures = mutable.Map[PartitionId, 
mutable.Set[TaskAttemptNumber]]() + } /** - * Map from active stages's id => partition id => task attempt with exclusive lock on committing - * output for that partition. + * Map from active stages's id => authorized task attempts for each partition id, which hold an + * exclusive lock on committing task output for that partition, as well as any known failed + * attempts in the stage. * * Entries are added to the top-level map when stages start and are removed they finish * (either successfully or unsuccessfully). * * Access to this map should be guarded by synchronizing on the OutputCommitCoordinator instance. */ - private val authorizedCommittersByStage = mutable.Map[StageId, Array[TaskAttemptNumber]]() + private val stageStates = mutable.Map[StageId, StageState]() /** * Returns whether the OutputCommitCoordinator's internal data structures are all empty. */ def isEmpty: Boolean = { - authorizedCommittersByStage.isEmpty + stageStates.isEmpty } /** @@ -105,19 +109,13 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean) * @param maxPartitionId the maximum partition id that could appear in this stage's tasks (i.e. * the maximum possible value of `context.partitionId`). */ - private[scheduler] def stageStart( - stage: StageId, - maxPartitionId: Int): Unit = { - val arr = new Array[TaskAttemptNumber](maxPartitionId + 1) - java.util.Arrays.fill(arr, NO_AUTHORIZED_COMMITTER) - synchronized { - authorizedCommittersByStage(stage) = arr - } + private[scheduler] def stageStart(stage: StageId, maxPartitionId: Int): Unit = synchronized { + stageStates(stage) = new StageState(maxPartitionId + 1) } // Called by DAGScheduler private[scheduler] def stageEnd(stage: StageId): Unit = synchronized { - authorizedCommittersByStage.remove(stage) + stageStates.remove(stage) } // Called by DAGScheduler @@ -126,7 +124,7 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean) partition: PartitionId, attemptNumber: TaskAttemptNumber, reason: TaskEndReason): Unit = synchronized { - val authorizedCommitters = authorizedCommittersByStage.getOrElse(stage, { + val stageState = stageStates.getOrElse(stage, { logDebug(s"Ignoring task completion for completed stage") return }) @@ -137,10 +135,12 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean) logInfo(s"Task was denied committing, stage: $stage, partition: $partition, " + s"attempt: $attemptNumber") case otherReason => - if (authorizedCommitters(partition) == attemptNumber) { + // Mark the attempt as failed to blacklist from future commit protocol + stageState.failures.getOrElseUpdate(partition, mutable.Set()) += attemptNumber + if (stageState.authorizedCommitters(partition) == attemptNumber) { logDebug(s"Authorized committer (attemptNumber=$attemptNumber, stage=$stage, " + s"partition=$partition) failed; clearing lock") - authorizedCommitters(partition) = NO_AUTHORIZED_COMMITTER + stageState.authorizedCommitters(partition) = NO_AUTHORIZED_COMMITTER } } } @@ -149,7 +149,7 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean) if (isDriver) { coordinatorRef.foreach(_ send StopCoordinator) coordinatorRef = None - authorizedCommittersByStage.clear() + stageStates.clear() } } @@ -158,13 +158,17 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean) stage: StageId, partition: PartitionId, attemptNumber: TaskAttemptNumber): Boolean = synchronized { - authorizedCommittersByStage.get(stage) match { - case Some(authorizedCommitters) 
=> - authorizedCommitters(partition) match { + stageStates.get(stage) match { + case Some(state) if attemptFailed(state, partition, attemptNumber) => + logInfo(s"Denying attemptNumber=$attemptNumber to commit for stage=$stage," + + s" partition=$partition as task attempt $attemptNumber has already failed.") + false + case Some(state) => + state.authorizedCommitters(partition) match { case NO_AUTHORIZED_COMMITTER => logDebug(s"Authorizing attemptNumber=$attemptNumber to commit for stage=$stage, " + s"partition=$partition") - authorizedCommitters(partition) = attemptNumber + state.authorizedCommitters(partition) = attemptNumber true case existingCommitter => // Coordinator should be idempotent when receiving AskPermissionToCommit. @@ -181,11 +185,18 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean) } } case None => - logDebug(s"Stage $stage has completed, so not allowing attempt number $attemptNumber of" + - s"partition $partition to commit") + logDebug(s"Stage $stage has completed, so not allowing" + + s" attempt number $attemptNumber of partition $partition to commit") false } } + + private def attemptFailed( + stageState: StageState, + partition: PartitionId, + attempt: TaskAttemptNumber): Boolean = synchronized { + stageState.failures.get(partition).exists(_.contains(attempt)) + } } private[spark] object OutputCommitCoordinator { diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index 7b726d5659e9..70213722aae4 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -17,19 +17,14 @@ package org.apache.spark.scheduler -import java.io.{DataInputStream, DataOutputStream} import java.nio.ByteBuffer import java.util.Properties -import scala.collection.mutable -import scala.collection.mutable.HashMap - import org.apache.spark._ import org.apache.spark.executor.TaskMetrics import org.apache.spark.internal.config.APP_CALLER_CONTEXT import org.apache.spark.memory.{MemoryMode, TaskMemoryManager} import org.apache.spark.metrics.MetricsSystem -import org.apache.spark.serializer.SerializerInstance import org.apache.spark.util._ /** @@ -137,6 +132,8 @@ private[spark] abstract class Task[T]( memoryManager.synchronized { memoryManager.notifyAll() } } } finally { + // Though we unset the ThreadLocal here, the context member variable itself is still queried + // directly in the TaskRunner to check for FetchFailedExceptions. TaskContext.unset() } } @@ -156,7 +153,7 @@ private[spark] abstract class Task[T]( var epoch: Long = -1 // Task context, to be initialized in run(). - @transient protected var context: TaskContextImpl = _ + @transient var context: TaskContextImpl = _ // The actual Thread on which the task is running, if any. Initialized in run(). @volatile @transient private var taskThread: Thread = _ diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 3b25513bea05..19ebaf817e24 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -874,7 +874,8 @@ private[spark] class TaskSetManager( // and we are not using an external shuffle server which could serve the shuffle outputs. 
// The reason is the next stage wouldn't be able to fetch the data from this dead executor // so we would need to rerun these tasks on other executors. - if (tasks(0).isInstanceOf[ShuffleMapTask] && !env.blockManager.externalShuffleServiceEnabled) { + if (tasks(0).isInstanceOf[ShuffleMapTask] && !env.blockManager.externalShuffleServiceEnabled + && !isZombie) { for ((tid, info) <- taskInfos if info.executorId == execId) { val index = taskInfos(tid).index if (successful(index)) { @@ -906,8 +907,6 @@ private[spark] class TaskSetManager( * Check for tasks to be speculated and return true if there are any. This is called periodically * by the TaskScheduler. * - * TODO: To make this scale to large jobs, we need to maintain a list of running tasks, so that - * we don't scan the whole task set. It might also help to make this sorted by launch time. */ override def checkSpeculatableTasks(minTimeToSpeculation: Int): Boolean = { // Can't speculate if we only have one task, and no need to speculate if the task set is a @@ -927,7 +926,8 @@ private[spark] class TaskSetManager( // TODO: Threshold should also look at standard deviation of task durations and have a lower // bound based on that. logDebug("Task length threshold for speculation: " + threshold) - for ((tid, info) <- taskInfos) { + for (tid <- runningTasksSet) { + val info = taskInfos(tid) val index = info.index if (!successful(index) && copiesRunning(index) == 1 && info.timeRunning(time) > threshold && !speculatableTasks.contains(index)) { diff --git a/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala b/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala index 498c12e196ce..265a8acfa8d6 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala @@ -17,7 +17,7 @@ package org.apache.spark.shuffle -import org.apache.spark.{FetchFailed, TaskFailedReason} +import org.apache.spark.{FetchFailed, TaskContext, TaskFailedReason} import org.apache.spark.storage.BlockManagerId import org.apache.spark.util.Utils @@ -26,6 +26,11 @@ import org.apache.spark.util.Utils * back to DAGScheduler (through TaskEndReason) so we'd resubmit the previous stage. * * Note that bmAddress can be null. + * + * To prevent user code from hiding this fetch failure, in the constructor we call + * [[TaskContext.setFetchFailed()]]. This means that you *must* throw this exception immediately + * after creating it -- you cannot create it, check some condition, and then decide to ignore it + * (or risk triggering any other exceptions). See SPARK-19276. */ private[spark] class FetchFailedException( bmAddress: BlockManagerId, @@ -45,6 +50,12 @@ private[spark] class FetchFailedException( this(bmAddress, shuffleId, mapId, reduceId, cause.getMessage, cause) } + // SPARK-19276. We set the fetch failure in the task context, so that even if there is user-code + // which intercepts this exception (possibly wrapping it), the Executor can still tell there was + // a fetch failure, and send the correct error msg back to the driver. We wrap with an Option + // because the TaskContext is not defined in some test cases. 
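As a rough illustration of the situation the comment above describes, this is the kind of user-level catch block that used to mask a fetch failure; the object and method names here are made up for the example:

```scala
object HiddenFetchFailureExample {
  // Hypothetical user code of the kind SPARK-19276 is about: the original
  // FetchFailedException is swallowed and rethrown inside a generic wrapper.
  def readWithUserErrorHandling(fetchShuffleBlock: () => Iterator[Int]): Iterator[Int] = {
    try {
      fetchShuffleBlock()   // may throw FetchFailedException while reading shuffle data
    } catch {
      case e: Exception =>
        // Anyone inspecting only the wrapper no longer sees a FetchFailedException...
        throw new RuntimeException("user exception hiding the original failure", e)
    }
  }
  // ...but because FetchFailedException now registers itself with the TaskContext in its
  // constructor (as the comment above explains), the executor can still report FetchFailed
  // to the driver and trigger a stage retry; the new ExecutorSuite tests below exercise this.
}
```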
+ Option(TaskContext.get()).map(_.setFetchFailed(this)) + def toTaskFailedReason: TaskFailedReason = FetchFailed(bmAddress, shuffleId, mapId, reduceId, Utils.exceptionString(this)) } diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala index 17bc04303fa8..00f918c09c66 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala @@ -18,6 +18,7 @@ package org.apache.spark.status.api.v1 import java.util.zip.ZipOutputStream import javax.servlet.ServletContext +import javax.servlet.http.HttpServletRequest import javax.ws.rs._ import javax.ws.rs.core.{Context, Response} @@ -40,7 +41,7 @@ import org.apache.spark.ui.SparkUI * HistoryServerSuite. */ @Path("/v1") -private[v1] class ApiRootResource extends UIRootFromServletContext { +private[v1] class ApiRootResource extends ApiRequestContext { @Path("applications") def getApplicationList(): ApplicationListResource = { @@ -56,21 +57,21 @@ private[v1] class ApiRootResource extends UIRootFromServletContext { def getJobs( @PathParam("appId") appId: String, @PathParam("attemptId") attemptId: String): AllJobsResource = { - uiRoot.withSparkUI(appId, Some(attemptId)) { ui => + withSparkUI(appId, Some(attemptId)) { ui => new AllJobsResource(ui) } } @Path("applications/{appId}/jobs") def getJobs(@PathParam("appId") appId: String): AllJobsResource = { - uiRoot.withSparkUI(appId, None) { ui => + withSparkUI(appId, None) { ui => new AllJobsResource(ui) } } @Path("applications/{appId}/jobs/{jobId: \\d+}") def getJob(@PathParam("appId") appId: String): OneJobResource = { - uiRoot.withSparkUI(appId, None) { ui => + withSparkUI(appId, None) { ui => new OneJobResource(ui) } } @@ -79,21 +80,21 @@ private[v1] class ApiRootResource extends UIRootFromServletContext { def getJob( @PathParam("appId") appId: String, @PathParam("attemptId") attemptId: String): OneJobResource = { - uiRoot.withSparkUI(appId, Some(attemptId)) { ui => + withSparkUI(appId, Some(attemptId)) { ui => new OneJobResource(ui) } } @Path("applications/{appId}/executors") def getExecutors(@PathParam("appId") appId: String): ExecutorListResource = { - uiRoot.withSparkUI(appId, None) { ui => + withSparkUI(appId, None) { ui => new ExecutorListResource(ui) } } @Path("applications/{appId}/allexecutors") def getAllExecutors(@PathParam("appId") appId: String): AllExecutorListResource = { - uiRoot.withSparkUI(appId, None) { ui => + withSparkUI(appId, None) { ui => new AllExecutorListResource(ui) } } @@ -102,7 +103,7 @@ private[v1] class ApiRootResource extends UIRootFromServletContext { def getExecutors( @PathParam("appId") appId: String, @PathParam("attemptId") attemptId: String): ExecutorListResource = { - uiRoot.withSparkUI(appId, Some(attemptId)) { ui => + withSparkUI(appId, Some(attemptId)) { ui => new ExecutorListResource(ui) } } @@ -111,15 +112,14 @@ private[v1] class ApiRootResource extends UIRootFromServletContext { def getAllExecutors( @PathParam("appId") appId: String, @PathParam("attemptId") attemptId: String): AllExecutorListResource = { - uiRoot.withSparkUI(appId, Some(attemptId)) { ui => + withSparkUI(appId, Some(attemptId)) { ui => new AllExecutorListResource(ui) } } - @Path("applications/{appId}/stages") def getStages(@PathParam("appId") appId: String): AllStagesResource = { - uiRoot.withSparkUI(appId, None) { ui => + withSparkUI(appId, None) { ui => new AllStagesResource(ui) } } @@ 
-128,14 +128,14 @@ private[v1] class ApiRootResource extends UIRootFromServletContext { def getStages( @PathParam("appId") appId: String, @PathParam("attemptId") attemptId: String): AllStagesResource = { - uiRoot.withSparkUI(appId, Some(attemptId)) { ui => + withSparkUI(appId, Some(attemptId)) { ui => new AllStagesResource(ui) } } @Path("applications/{appId}/stages/{stageId: \\d+}") def getStage(@PathParam("appId") appId: String): OneStageResource = { - uiRoot.withSparkUI(appId, None) { ui => + withSparkUI(appId, None) { ui => new OneStageResource(ui) } } @@ -144,14 +144,14 @@ private[v1] class ApiRootResource extends UIRootFromServletContext { def getStage( @PathParam("appId") appId: String, @PathParam("attemptId") attemptId: String): OneStageResource = { - uiRoot.withSparkUI(appId, Some(attemptId)) { ui => + withSparkUI(appId, Some(attemptId)) { ui => new OneStageResource(ui) } } @Path("applications/{appId}/storage/rdd") def getRdds(@PathParam("appId") appId: String): AllRDDResource = { - uiRoot.withSparkUI(appId, None) { ui => + withSparkUI(appId, None) { ui => new AllRDDResource(ui) } } @@ -160,14 +160,14 @@ private[v1] class ApiRootResource extends UIRootFromServletContext { def getRdds( @PathParam("appId") appId: String, @PathParam("attemptId") attemptId: String): AllRDDResource = { - uiRoot.withSparkUI(appId, Some(attemptId)) { ui => + withSparkUI(appId, Some(attemptId)) { ui => new AllRDDResource(ui) } } @Path("applications/{appId}/storage/rdd/{rddId: \\d+}") def getRdd(@PathParam("appId") appId: String): OneRDDResource = { - uiRoot.withSparkUI(appId, None) { ui => + withSparkUI(appId, None) { ui => new OneRDDResource(ui) } } @@ -176,7 +176,7 @@ private[v1] class ApiRootResource extends UIRootFromServletContext { def getRdd( @PathParam("appId") appId: String, @PathParam("attemptId") attemptId: String): OneRDDResource = { - uiRoot.withSparkUI(appId, Some(attemptId)) { ui => + withSparkUI(appId, Some(attemptId)) { ui => new OneRDDResource(ui) } } @@ -199,6 +199,21 @@ private[v1] class ApiRootResource extends UIRootFromServletContext { new VersionResource(uiRoot) } + @Path("applications/{appId}/environment") + def getEnvironment(@PathParam("appId") appId: String): ApplicationEnvironmentResource = { + withSparkUI(appId, None) { ui => + new ApplicationEnvironmentResource(ui) + } + } + + @Path("applications/{appId}/{attemptId}/environment") + def getEnvironment( + @PathParam("appId") appId: String, + @PathParam("attemptId") attemptId: String): ApplicationEnvironmentResource = { + withSparkUI(appId, Some(attemptId)) { ui => + new ApplicationEnvironmentResource(ui) + } + } } private[spark] object ApiRootResource { @@ -234,19 +249,6 @@ private[spark] trait UIRoot { .status(Response.Status.SERVICE_UNAVAILABLE) .build() } - - /** - * Get the spark UI with the given appID, and apply a function - * to it. 
If there is no such app, throw an appropriate exception - */ - def withSparkUI[T](appId: String, attemptId: Option[String])(f: SparkUI => T): T = { - val appKey = attemptId.map(appId + "/" + _).getOrElse(appId) - getSparkUI(appKey) match { - case Some(ui) => - f(ui) - case None => throw new NotFoundException("no such app: " + appId) - } - } def securityManager: SecurityManager } @@ -263,13 +265,38 @@ private[v1] object UIRootFromServletContext { } } -private[v1] trait UIRootFromServletContext { +private[v1] trait ApiRequestContext { + @Context + protected var servletContext: ServletContext = _ + @Context - var servletContext: ServletContext = _ + protected var httpRequest: HttpServletRequest = _ def uiRoot: UIRoot = UIRootFromServletContext.getUiRoot(servletContext) + + + /** + * Get the spark UI with the given appID, and apply a function + * to it. If there is no such app, throw an appropriate exception + */ + def withSparkUI[T](appId: String, attemptId: Option[String])(f: SparkUI => T): T = { + val appKey = attemptId.map(appId + "/" + _).getOrElse(appId) + uiRoot.getSparkUI(appKey) match { + case Some(ui) => + val user = httpRequest.getRemoteUser() + if (!ui.securityManager.checkUIViewPermissions(user)) { + throw new ForbiddenException(raw"""user "$user" is not authorized""") + } + f(ui) + case None => throw new NotFoundException("no such app: " + appId) + } + } + } +private[v1] class ForbiddenException(msg: String) extends WebApplicationException( + Response.status(Response.Status.FORBIDDEN).entity(msg).build()) + private[v1] class NotFoundException(msg: String) extends WebApplicationException( new NoSuchElementException(msg), Response diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/ApplicationEnvironmentResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/ApplicationEnvironmentResource.scala new file mode 100644 index 000000000000..739a8aceae86 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/status/api/v1/ApplicationEnvironmentResource.scala @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.status.api.v1 + +import javax.ws.rs._ +import javax.ws.rs.core.MediaType + +import org.apache.spark.ui.SparkUI + +@Produces(Array(MediaType.APPLICATION_JSON)) +private[v1] class ApplicationEnvironmentResource(ui: SparkUI) { + + @GET + def getEnvironmentInfo(): ApplicationEnvironmentInfo = { + val listener = ui.environmentListener + listener.synchronized { + val jvmInfo = Map(listener.jvmInformation: _*) + val runtime = new RuntimeInfo( + jvmInfo("Java Version"), + jvmInfo("Java Home"), + jvmInfo("Scala Version")) + + new ApplicationEnvironmentInfo( + runtime, + listener.sparkProperties, + listener.systemProperties, + listener.classpathEntries) + } + } + +} diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/SecurityFilter.scala b/core/src/main/scala/org/apache/spark/status/api/v1/SecurityFilter.scala index b4a991eda35f..1cd37185d660 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/SecurityFilter.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/SecurityFilter.scala @@ -21,14 +21,14 @@ import javax.ws.rs.core.Response import javax.ws.rs.ext.Provider @Provider -private[v1] class SecurityFilter extends ContainerRequestFilter with UIRootFromServletContext { +private[v1] class SecurityFilter extends ContainerRequestFilter with ApiRequestContext { override def filter(req: ContainerRequestContext): Unit = { - val user = Option(req.getSecurityContext.getUserPrincipal).map { _.getName }.orNull + val user = httpRequest.getRemoteUser() if (!uiRoot.securityManager.checkUIViewPermissions(user)) { req.abortWith( Response .status(Response.Status.FORBIDDEN) - .entity(raw"""user "$user"is not authorized""") + .entity(raw"""user "$user" is not authorized""") .build() ) } diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala index c509398db1ec..5b9227350eda 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala @@ -252,3 +252,14 @@ class AccumulableInfo private[spark]( class VersionInfo private[spark]( val spark: String) + +class ApplicationEnvironmentInfo private[spark] ( + val runtime: RuntimeInfo, + val sparkProperties: Seq[(String, String)], + val systemProperties: Seq[(String, String)], + val classpathEntries: Seq[(String, String)]) + +class RuntimeInfo private[spark]( + val javaVersion: String, + val javaHome: String, + val scalaVersion: String) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 6946a98cdda6..45b73380806d 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -1159,6 +1159,34 @@ private[spark] class BlockManager( } } + /** + * Called for pro-active replenishment of blocks lost due to executor failures + * + * @param blockId blockId being replicate + * @param existingReplicas existing block managers that have a replica + * @param maxReplicas maximum replicas needed + */ + def replicateBlock( + blockId: BlockId, + existingReplicas: Set[BlockManagerId], + maxReplicas: Int): Unit = { + logInfo(s"Pro-actively replicating $blockId") + blockInfoManager.lockForReading(blockId).foreach { info => + val data = doGetLocalBytes(blockId, info) + val storageLevel = StorageLevel( + useDisk = info.level.useDisk, + useMemory = info.level.useMemory, + useOffHeap = 
info.level.useOffHeap, + deserialized = info.level.deserialized, + replication = maxReplicas) + try { + replicate(blockId, data, storageLevel, info.classTag, existingReplicas) + } finally { + releaseLock(blockId) + } + } + } + /** * Replicate block to another node. Note that this is a blocking call that returns after * the block has been replicated. @@ -1167,7 +1195,8 @@ private[spark] class BlockManager( blockId: BlockId, data: ChunkedByteBuffer, level: StorageLevel, - classTag: ClassTag[_]): Unit = { + classTag: ClassTag[_], + existingReplicas: Set[BlockManagerId] = Set.empty): Unit = { val maxReplicationFailures = conf.getInt("spark.storage.maxReplicationFailures", 1) val tLevel = StorageLevel( @@ -1181,20 +1210,22 @@ private[spark] class BlockManager( val startTime = System.nanoTime - var peersReplicatedTo = mutable.HashSet.empty[BlockManagerId] + var peersReplicatedTo = mutable.HashSet.empty ++ existingReplicas var peersFailedToReplicateTo = mutable.HashSet.empty[BlockManagerId] var numFailures = 0 + val initialPeers = getPeers(false).filterNot(existingReplicas.contains(_)) + var peersForReplication = blockReplicationPolicy.prioritize( blockManagerId, - getPeers(false), - mutable.HashSet.empty, + initialPeers, + peersReplicatedTo, blockId, numPeersToReplicateTo) while(numFailures <= maxReplicationFailures && - !peersForReplication.isEmpty && - peersReplicatedTo.size != numPeersToReplicateTo) { + !peersForReplication.isEmpty && + peersReplicatedTo.size < numPeersToReplicateTo) { val peer = peersForReplication.head try { val onePeerStartTime = System.nanoTime diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala index 145c434a4f0c..84c04d22600a 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala @@ -22,6 +22,7 @@ import java.util.{HashMap => JHashMap} import scala.collection.mutable import scala.collection.JavaConverters._ import scala.concurrent.{ExecutionContext, Future} +import scala.util.Random import org.apache.spark.SparkConf import org.apache.spark.annotation.DeveloperApi @@ -65,6 +66,8 @@ class BlockManagerMasterEndpoint( mapper } + val proactivelyReplicate = conf.get("spark.storage.replication.proactive", "false").toBoolean + logInfo("BlockManagerMasterEndpoint up") override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { @@ -195,17 +198,38 @@ class BlockManagerMasterEndpoint( // Remove it from blockManagerInfo and remove all the blocks. blockManagerInfo.remove(blockManagerId) + val iterator = info.blocks.keySet.iterator while (iterator.hasNext) { val blockId = iterator.next val locations = blockLocations.get(blockId) locations -= blockManagerId + // De-register the block if none of the block managers have it. Otherwise, if pro-active + // replication is enabled, and a block is either an RDD or a test block (the latter is used + // for unit testing), we send a message to a randomly chosen executor location to replicate + // the given block. Note that we ignore other block types (such as broadcast/shuffle blocks + // etc.) as replication doesn't make much sense in that context. 
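A sketch of how the feature would be switched on from a driver program, assuming only the configuration key read above (`spark.storage.replication.proactive`, default `false`); the app name and data are placeholders:

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.storage.StorageLevel

object ProactiveReplicationDemo {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
      .setAppName("proactive-replication-demo")   // master is supplied by spark-submit
      .set("spark.storage.replication.proactive", "true")
    val sc = new SparkContext(conf)

    // Persist with 2 replicas; if an executor holding one replica is lost, the master
    // now asks a surviving holder to re-replicate the RDD block back up to 2 copies.
    val rdd = sc.parallelize(1 to 1000000).persist(StorageLevel.MEMORY_ONLY_2)
    rdd.count()
    sc.stop()
  }
}
```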
if (locations.size == 0) { blockLocations.remove(blockId) + logWarning(s"No more replicas available for $blockId !") + } else if (proactivelyReplicate && (blockId.isRDD || blockId.isInstanceOf[TestBlockId])) { + // As a heuristic, assume single executor failure to find out the number of replicas that + // existed before failure + val maxReplicas = locations.size + 1 + val i = (new Random(blockId.hashCode)).nextInt(locations.size) + val blockLocations = locations.toSeq + val candidateBMId = blockLocations(i) + blockManagerInfo.get(candidateBMId).foreach { bm => + val remainingLocations = locations.toSeq.filter(bm => bm != candidateBMId) + val replicateMsg = ReplicateBlock(blockId, remainingLocations, maxReplicas) + bm.slaveEndpoint.ask[Boolean](replicateMsg) + } } } + listenerBus.post(SparkListenerBlockManagerRemoved(System.currentTimeMillis(), blockManagerId)) logInfo(s"Removing block manager $blockManagerId") + } private def removeExecutor(execId: String) { diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala index d71acbb4cf77..0aea438e7f47 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala @@ -32,6 +32,10 @@ private[spark] object BlockManagerMessages { // blocks that the master knows about. case class RemoveBlock(blockId: BlockId) extends ToBlockManagerSlave + // Replicate blocks that were lost due to executor failure + case class ReplicateBlock(blockId: BlockId, replicas: Seq[BlockManagerId], maxReplicas: Int) + extends ToBlockManagerSlave + // Remove all blocks belonging to a specific RDD. case class RemoveRdd(rddId: Int) extends ToBlockManagerSlave diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala index d17ddbc16257..1aaa42459df6 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala @@ -74,6 +74,10 @@ class BlockManagerSlaveEndpoint( case TriggerThreadDump => context.reply(Utils.getThreadDump()) + + case ReplicateBlock(blockId, replicas, maxReplicas) => + context.reply(blockManager.replicateBlock(blockId, replicas.toSet, maxReplicas)) + } private def doAsync[T](actionMessage: String, context: RpcCallContext)(body: => T) { diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala index 7909821db954..bdbdba578085 100644 --- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala @@ -90,9 +90,9 @@ private[spark] object JettyUtils extends Logging { response.setHeader("X-Frame-Options", xFrameOptionsValue) response.getWriter.print(servletParams.extractFn(result)) } else { - response.setStatus(HttpServletResponse.SC_UNAUTHORIZED) + response.setStatus(HttpServletResponse.SC_FORBIDDEN) response.setHeader("Cache-Control", "no-cache, no-store, must-revalidate") - response.sendError(HttpServletResponse.SC_UNAUTHORIZED, + response.sendError(HttpServletResponse.SC_FORBIDDEN, "User is not authorized to access this page.") } } catch { diff --git a/core/src/main/scala/org/apache/spark/util/ListenerBus.scala b/core/src/main/scala/org/apache/spark/util/ListenerBus.scala index 79fc2e94599c..fa5ad4e8d81e 100644 ---
a/core/src/main/scala/org/apache/spark/util/ListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/util/ListenerBus.scala @@ -52,7 +52,7 @@ private[spark] trait ListenerBus[L <: AnyRef, E] extends Logging { * Post the event to all registered listeners. The `postToAll` caller should guarantee calling * `postToAll` in the same thread for all events. */ - final def postToAll(event: E): Unit = { + def postToAll(event: E): Unit = { // JavaConverters can create a JIterableWrapper if we use asScala. // However, this method will be called frequently. To avoid the wrapper cost, here we use // Java Iterator directly. diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 1e6e9a223e29..1af34e3da231 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -19,6 +19,7 @@ package org.apache.spark.util import java.io._ import java.lang.management.{LockInfo, ManagementFactory, MonitorInfo, ThreadInfo} +import java.math.{MathContext, RoundingMode} import java.net._ import java.nio.ByteBuffer import java.nio.channels.Channels @@ -38,6 +39,7 @@ import scala.io.Source import scala.reflect.ClassTag import scala.util.Try import scala.util.control.{ControlThrowable, NonFatal} +import scala.util.matching.Regex import _root_.io.netty.channel.unix.Errors.NativeIoException import com.google.common.cache.{CacheBuilder, CacheLoader, LoadingCache} @@ -1109,26 +1111,39 @@ private[spark] object Utils extends Logging { /** * Convert a quantity in bytes to a human-readable string such as "4.0 MB". */ - def bytesToString(size: Long): String = { + def bytesToString(size: Long): String = bytesToString(BigInt(size)) + + def bytesToString(size: BigInt): String = { + val EB = 1L << 60 + val PB = 1L << 50 val TB = 1L << 40 val GB = 1L << 30 val MB = 1L << 20 val KB = 1L << 10 - val (value, unit) = { - if (size >= 2*TB) { - (size.asInstanceOf[Double] / TB, "TB") - } else if (size >= 2*GB) { - (size.asInstanceOf[Double] / GB, "GB") - } else if (size >= 2*MB) { - (size.asInstanceOf[Double] / MB, "MB") - } else if (size >= 2*KB) { - (size.asInstanceOf[Double] / KB, "KB") - } else { - (size.asInstanceOf[Double], "B") + if (size >= BigInt(1L << 11) * EB) { + // The number is too large, show it in scientific notation. 
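Assuming the helper is invoked from code inside the `org.apache.spark` package (Utils is `private[spark]`), the extended formatting should behave roughly as follows; the expected strings are inferred from the 2x-unit cutoffs and the 3-significant-digit scientific fallback shown above:

```scala
import org.apache.spark.util.Utils

object BytesToStringDemo {
  def main(args: Array[String]): Unit = {
    println(Utils.bytesToString(1L << 20))          // 1024.0 KB (still below the 2 MB cutover)
    println(Utils.bytesToString(5L * (1L << 30)))   // 5.0 GB
    println(Utils.bytesToString(3L * (1L << 50)))   // 3.0 PB (would previously have shown as TB)
    println(Utils.bytesToString(BigInt(1) << 80))   // ~1.21E+24 B (scientific-notation fallback)
  }
}
```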
+ BigDecimal(size, new MathContext(3, RoundingMode.HALF_UP)).toString() + " B" + } else { + val (value, unit) = { + if (size >= 2 * EB) { + (BigDecimal(size) / EB, "EB") + } else if (size >= 2 * PB) { + (BigDecimal(size) / PB, "PB") + } else if (size >= 2 * TB) { + (BigDecimal(size) / TB, "TB") + } else if (size >= 2 * GB) { + (BigDecimal(size) / GB, "GB") + } else if (size >= 2 * MB) { + (BigDecimal(size) / MB, "MB") + } else if (size >= 2 * KB) { + (BigDecimal(size) / KB, "KB") + } else { + (BigDecimal(size), "B") + } } + "%.1f %s".formatLocal(Locale.US, value, unit) } - "%.1f %s".formatLocal(Locale.US, value, unit) } /** @@ -1989,7 +2004,7 @@ private[spark] object Utils extends Logging { if (paths == null || paths.trim.isEmpty) { "" } else { - paths.split(",").map { p => Utils.resolveURI(p) }.mkString(",") + paths.split(",").filter(_.trim.nonEmpty).map { p => Utils.resolveURI(p) }.mkString(",") } } @@ -2210,17 +2225,32 @@ private[spark] object Utils extends Logging { } catch { case e: Exception if isBindCollision(e) => if (offset >= maxRetries) { - val exceptionMessage = s"${e.getMessage}: Service$serviceString failed after " + - s"$maxRetries retries (starting from $startPort)! Consider explicitly setting " + - s"the appropriate port for the service$serviceString (for example spark.ui.port " + - s"for SparkUI) to an available port or increasing spark.port.maxRetries." + val exceptionMessage = if (startPort == 0) { + s"${e.getMessage}: Service$serviceString failed after " + + s"$maxRetries retries (on a random free port)! " + + s"Consider explicitly setting the appropriate binding address for " + + s"the service$serviceString (for example spark.driver.bindAddress " + + s"for SparkDriver) to the correct binding address." + } else { + s"${e.getMessage}: Service$serviceString failed after " + + s"$maxRetries retries (starting from $startPort)! Consider explicitly setting " + + s"the appropriate port for the service$serviceString (for example spark.ui.port " + + s"for SparkUI) to an available port or increasing spark.port.maxRetries." + } val exception = new BindException(exceptionMessage) // restore original stack trace exception.setStackTrace(e.getStackTrace) throw exception } - logWarning(s"Service$serviceString could not bind on port $tryPort. " + - s"Attempting port ${tryPort + 1}.") + if (startPort == 0) { + // As startPort 0 is for a random free port, it is most possibly binding address is + // not correct. + logWarning(s"Service$serviceString could not bind on a random free port. " + + "You may check whether configuring an appropriate binding address.") + } else { + logWarning(s"Service$serviceString could not bind on port $tryPort. " + + s"Attempting port ${tryPort + 1}.") + } } } // Should never happen @@ -2559,13 +2589,31 @@ private[spark] object Utils extends Logging { def redact(conf: SparkConf, kvs: Seq[(String, String)]): Seq[(String, String)] = { val redactionPattern = conf.get(SECRET_REDACTION_PATTERN).r + redact(redactionPattern, kvs) + } + + private def redact(redactionPattern: Regex, kvs: Seq[(String, String)]): Seq[(String, String)] = { kvs.map { kv => redactionPattern.findFirstIn(kv._1) - .map { ignore => (kv._1, REDACTION_REPLACEMENT_TEXT) } + .map { _ => (kv._1, REDACTION_REPLACEMENT_TEXT) } .getOrElse(kv) } } + /** + * Looks up the redaction regex from within the key value pairs and uses it to redact the rest + * of the key value pairs. No care is taken to make sure the redaction property itself is not + * redacted. 
So theoretically, the property itself could be configured to redact its own value + * when printing. + */ + def redact(kvs: Map[String, String]): Seq[(String, String)] = { + val redactionPattern = kvs.getOrElse( + SECRET_REDACTION_PATTERN.key, + SECRET_REDACTION_PATTERN.defaultValueString + ).r + redact(redactionPattern, kvs.toArray) + } + } private[util] object CallerContext extends Logging { diff --git a/core/src/test/java/test/org/apache/spark/JavaAPISuite.java b/core/src/test/java/test/org/apache/spark/JavaAPISuite.java index 512149127d72..01b5fb7b4668 100644 --- a/core/src/test/java/test/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/java/test/org/apache/spark/JavaAPISuite.java @@ -358,7 +358,7 @@ public void groupByOnPairRDD() { // Regression test for SPARK-4459 JavaRDD rdd = sc.parallelize(Arrays.asList(1, 1, 2, 3, 5, 8, 13)); Function, Boolean> areOdd = - x -> (x._1() % 2 == 0) && (x._2() % 2 == 0); + x -> (x._1() % 2 == 0) && (x._2() % 2 == 0); JavaPairRDD pairRDD = rdd.zip(rdd); JavaPairRDD>> oddsAndEvens = pairRDD.groupBy(areOdd); assertEquals(2, oddsAndEvens.count()); @@ -528,14 +528,14 @@ public void aggregateByKey() { new Tuple2<>(5, 3)), 2); Map> sets = pairs.aggregateByKey(new HashSet(), - (a, b) -> { - a.add(b); - return a; - }, - (a, b) -> { - a.addAll(b); - return a; - }).collectAsMap(); + (a, b) -> { + a.add(b); + return a; + }, + (a, b) -> { + a.addAll(b); + return a; + }).collectAsMap(); assertEquals(3, sets.size()); assertEquals(new HashSet<>(Arrays.asList(1)), sets.get(1)); assertEquals(new HashSet<>(Arrays.asList(2)), sets.get(3)); @@ -666,8 +666,8 @@ public void javaDoubleRDDHistoGram() { assertArrayEquals(expected_counts, histogram); // SPARK-5744 assertArrayEquals( - new long[] {0}, - sc.parallelizeDoubles(new ArrayList<>(0), 1).histogram(new double[]{0.0, 1.0})); + new long[] {0}, + sc.parallelizeDoubles(new ArrayList<>(0), 1).histogram(new double[]{0.0, 1.0})); } private static class DoubleComparator implements Comparator, Serializable { @@ -807,7 +807,7 @@ public void mapsFromPairsToPairs() { // Regression test for SPARK-668: JavaPairRDD swapped = pairRDD.flatMapToPair( - item -> Collections.singletonList(item.swap()).iterator()); + item -> Collections.singletonList(item.swap()).iterator()); swapped.collect(); // There was never a bug here, but it's worth testing: @@ -845,11 +845,13 @@ public void mapPartitionsWithIndex() { public void getNumPartitions(){ JavaRDD rdd1 = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8), 3); JavaDoubleRDD rdd2 = sc.parallelizeDoubles(Arrays.asList(1.0, 2.0, 3.0, 4.0), 2); - JavaPairRDD rdd3 = sc.parallelizePairs(Arrays.asList( - new Tuple2<>("a", 1), - new Tuple2<>("aa", 2), - new Tuple2<>("aaa", 3) - ), 2); + JavaPairRDD rdd3 = sc.parallelizePairs( + Arrays.asList( + new Tuple2<>("a", 1), + new Tuple2<>("aa", 2), + new Tuple2<>("aaa", 3) + ), + 2); assertEquals(3, rdd1.getNumPartitions()); assertEquals(2, rdd2.getNumPartitions()); assertEquals(2, rdd3.getNumPartitions()); @@ -977,7 +979,7 @@ public void sequenceFile() { JavaPairRDD rdd = sc.parallelizePairs(pairs); rdd.mapToPair(pair -> new Tuple2<>(new IntWritable(pair._1()), new Text(pair._2()))) - .saveAsHadoopFile(outputDir, IntWritable.class, Text.class, SequenceFileOutputFormat.class); + .saveAsHadoopFile(outputDir, IntWritable.class, Text.class, SequenceFileOutputFormat.class); // Try reading the output back as an object file JavaPairRDD readRDD = sc.sequenceFile(outputDir, IntWritable.class, @@ -1068,11 +1070,11 @@ public void writeWithNewAPIHadoopFile() { 
JavaPairRDD rdd = sc.parallelizePairs(pairs); rdd.mapToPair(pair -> new Tuple2<>(new IntWritable(pair._1()), new Text(pair._2()))) - .saveAsNewAPIHadoopFile(outputDir, IntWritable.class, Text.class, + .saveAsNewAPIHadoopFile(outputDir, IntWritable.class, Text.class, org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat.class); JavaPairRDD output = - sc.sequenceFile(outputDir, IntWritable.class, Text.class); + sc.sequenceFile(outputDir, IntWritable.class, Text.class); assertEquals(pairs.toString(), output.map(Tuple2::toString).collect().toString()); } @@ -1088,11 +1090,11 @@ public void readWithNewAPIHadoopFile() throws IOException { JavaPairRDD rdd = sc.parallelizePairs(pairs); rdd.mapToPair(pair -> new Tuple2<>(new IntWritable(pair._1()), new Text(pair._2()))) - .saveAsHadoopFile(outputDir, IntWritable.class, Text.class, SequenceFileOutputFormat.class); + .saveAsHadoopFile(outputDir, IntWritable.class, Text.class, SequenceFileOutputFormat.class); JavaPairRDD output = sc.newAPIHadoopFile(outputDir, - org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat.class, - IntWritable.class, Text.class, Job.getInstance().getConfiguration()); + org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat.class, + IntWritable.class, Text.class, Job.getInstance().getConfiguration()); assertEquals(pairs.toString(), output.map(Tuple2::toString).collect().toString()); } @@ -1135,10 +1137,10 @@ public void hadoopFile() { JavaPairRDD rdd = sc.parallelizePairs(pairs); rdd.mapToPair(pair -> new Tuple2<>(new IntWritable(pair._1()), new Text(pair._2()))) - .saveAsHadoopFile(outputDir, IntWritable.class, Text.class, SequenceFileOutputFormat.class); + .saveAsHadoopFile(outputDir, IntWritable.class, Text.class, SequenceFileOutputFormat.class); JavaPairRDD output = sc.hadoopFile(outputDir, - SequenceFileInputFormat.class, IntWritable.class, Text.class); + SequenceFileInputFormat.class, IntWritable.class, Text.class); assertEquals(pairs.toString(), output.map(Tuple2::toString).collect().toString()); } @@ -1154,10 +1156,11 @@ public void hadoopFileCompressed() { JavaPairRDD rdd = sc.parallelizePairs(pairs); rdd.mapToPair(pair -> new Tuple2<>(new IntWritable(pair._1()), new Text(pair._2()))) - .saveAsHadoopFile(outputDir, IntWritable.class, Text.class, SequenceFileOutputFormat.class, DefaultCodec.class); + .saveAsHadoopFile(outputDir, IntWritable.class, Text.class, + SequenceFileOutputFormat.class, DefaultCodec.class); JavaPairRDD output = sc.hadoopFile(outputDir, - SequenceFileInputFormat.class, IntWritable.class, Text.class); + SequenceFileInputFormat.class, IntWritable.class, Text.class); assertEquals(pairs.toString(), output.map(Tuple2::toString).collect().toString()); } @@ -1263,23 +1266,23 @@ public void combineByKey() { Function2 mergeValueFunction = (v1, v2) -> v1 + v2; JavaPairRDD combinedRDD = originalRDD.keyBy(keyFunction) - .combineByKey(createCombinerFunction, mergeValueFunction, mergeValueFunction); + .combineByKey(createCombinerFunction, mergeValueFunction, mergeValueFunction); Map results = combinedRDD.collectAsMap(); ImmutableMap expected = ImmutableMap.of(0, 9, 1, 5, 2, 7); assertEquals(expected, results); Partitioner defaultPartitioner = Partitioner.defaultPartitioner( - combinedRDD.rdd(), - JavaConverters.collectionAsScalaIterableConverter( - Collections.>emptyList()).asScala().toSeq()); + combinedRDD.rdd(), + JavaConverters.collectionAsScalaIterableConverter( + Collections.>emptyList()).asScala().toSeq()); combinedRDD = originalRDD.keyBy(keyFunction) - .combineByKey( - 
createCombinerFunction, - mergeValueFunction, - mergeValueFunction, - defaultPartitioner, - false, - new KryoSerializer(new SparkConf())); + .combineByKey( + createCombinerFunction, + mergeValueFunction, + mergeValueFunction, + defaultPartitioner, + false, + new KryoSerializer(new SparkConf())); results = combinedRDD.collectAsMap(); assertEquals(expected, results); } @@ -1291,11 +1294,10 @@ public void mapOnPairRDD() { JavaPairRDD rdd2 = rdd1.mapToPair(i -> new Tuple2<>(i, i % 2)); JavaPairRDD rdd3 = rdd2.mapToPair(in -> new Tuple2<>(in._2(), in._1())); assertEquals(Arrays.asList( - new Tuple2<>(1, 1), - new Tuple2<>(0, 2), - new Tuple2<>(1, 3), - new Tuple2<>(0, 4)), rdd3.collect()); - + new Tuple2<>(1, 1), + new Tuple2<>(0, 2), + new Tuple2<>(1, 3), + new Tuple2<>(0, 4)), rdd3.collect()); } @SuppressWarnings("unchecked") @@ -1312,16 +1314,18 @@ public void collectPartitions() { assertEquals(Arrays.asList(3, 4), parts[0]); assertEquals(Arrays.asList(5, 6, 7), parts[1]); - assertEquals(Arrays.asList(new Tuple2<>(1, 1), - new Tuple2<>(2, 0)), - rdd2.collectPartitions(new int[] {0})[0]); + assertEquals( + Arrays.asList(new Tuple2<>(1, 1), new Tuple2<>(2, 0)), + rdd2.collectPartitions(new int[] {0})[0]); List>[] parts2 = rdd2.collectPartitions(new int[] {1, 2}); assertEquals(Arrays.asList(new Tuple2<>(3, 1), new Tuple2<>(4, 0)), parts2[0]); - assertEquals(Arrays.asList(new Tuple2<>(5, 1), - new Tuple2<>(6, 0), - new Tuple2<>(7, 1)), - parts2[1]); + assertEquals( + Arrays.asList( + new Tuple2<>(5, 1), + new Tuple2<>(6, 0), + new Tuple2<>(7, 1)), + parts2[1]); } @Test @@ -1352,7 +1356,6 @@ public void countApproxDistinctByKey() { double error = Math.abs((resCount - count) / count); assertTrue(error < 0.1); } - } @Test @@ -1531,8 +1534,8 @@ public void testRegisterKryoClasses() { SparkConf conf = new SparkConf(); conf.registerKryoClasses(new Class[]{ Class1.class, Class2.class }); assertEquals( - Class1.class.getName() + "," + Class2.class.getName(), - conf.get("spark.kryo.classesToRegister")); + Class1.class.getName() + "," + Class2.class.getName(), + conf.get("spark.kryo.classesToRegister")); } @Test diff --git a/core/src/test/scala/org/apache/spark/FileSuite.scala b/core/src/test/scala/org/apache/spark/FileSuite.scala index 6538507d407e..5be0121db58a 100644 --- a/core/src/test/scala/org/apache/spark/FileSuite.scala +++ b/core/src/test/scala/org/apache/spark/FileSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark import java.io._ +import java.nio.ByteBuffer import java.util.zip.GZIPOutputStream import scala.io.Source @@ -30,7 +31,6 @@ import org.apache.hadoop.mapreduce.Job import org.apache.hadoop.mapreduce.lib.input.{FileSplit => NewFileSplit, TextInputFormat => NewTextInputFormat} import org.apache.hadoop.mapreduce.lib.output.{TextOutputFormat => NewTextOutputFormat} -import org.apache.spark.input.PortableDataStream import org.apache.spark.internal.config.IGNORE_CORRUPT_FILES import org.apache.spark.rdd.{HadoopRDD, NewHadoopRDD} import org.apache.spark.storage.StorageLevel @@ -237,24 +237,26 @@ class FileSuite extends SparkFunSuite with LocalSparkContext { assert(output.map(_.toString).collect().toList === List("(1,a)", "(2,aa)", "(3,aaa)")) } - test("binary file input as byte array") { - sc = new SparkContext("local", "test") + private def writeBinaryData(testOutput: Array[Byte], testOutputCopies: Int): File = { val outFile = new File(tempDir, "record-bytestream-00000.bin") - val outFileName = outFile.getAbsolutePath() - - // create file - val testOutput = Array[Byte](1, 2, 3, 4, 5, 6) - 
val bbuf = java.nio.ByteBuffer.wrap(testOutput) - // write data to file - val file = new java.io.FileOutputStream(outFile) + val file = new FileOutputStream(outFile) val channel = file.getChannel - channel.write(bbuf) + for (i <- 0 until testOutputCopies) { + // Shift values by i so that they're different in the output + val alteredOutput = testOutput.map(b => (b + i).toByte) + channel.write(ByteBuffer.wrap(alteredOutput)) + } channel.close() file.close() + outFile + } - val inRdd = sc.binaryFiles(outFileName) - val (infile: String, indata: PortableDataStream) = inRdd.collect.head - + test("binary file input as byte array") { + sc = new SparkContext("local", "test") + val testOutput = Array[Byte](1, 2, 3, 4, 5, 6) + val outFile = writeBinaryData(testOutput, 1) + val inRdd = sc.binaryFiles(outFile.getAbsolutePath) + val (infile, indata) = inRdd.collect().head // Make sure the name and array match assert(infile.contains(outFile.toURI.getPath)) // a prefix may get added assert(indata.toArray === testOutput) @@ -262,159 +264,55 @@ class FileSuite extends SparkFunSuite with LocalSparkContext { test("portabledatastream caching tests") { sc = new SparkContext("local", "test") - val outFile = new File(tempDir, "record-bytestream-00000.bin") - val outFileName = outFile.getAbsolutePath() - - // create file val testOutput = Array[Byte](1, 2, 3, 4, 5, 6) - val bbuf = java.nio.ByteBuffer.wrap(testOutput) - // write data to file - val file = new java.io.FileOutputStream(outFile) - val channel = file.getChannel - channel.write(bbuf) - channel.close() - file.close() - - val inRdd = sc.binaryFiles(outFileName).cache() - inRdd.foreach{ - curData: (String, PortableDataStream) => - curData._2.toArray() // force the file to read - } - val mappedRdd = inRdd.map { - curData: (String, PortableDataStream) => - (curData._2.getPath(), curData._2) - } - val (infile: String, indata: PortableDataStream) = mappedRdd.collect.head - + val outFile = writeBinaryData(testOutput, 1) + val inRdd = sc.binaryFiles(outFile.getAbsolutePath).cache() + inRdd.foreach(_._2.toArray()) // force the file to read // Try reading the output back as an object file - - assert(indata.toArray === testOutput) + assert(inRdd.values.collect().head.toArray === testOutput) } test("portabledatastream persist disk storage") { sc = new SparkContext("local", "test") - val outFile = new File(tempDir, "record-bytestream-00000.bin") - val outFileName = outFile.getAbsolutePath() - - // create file val testOutput = Array[Byte](1, 2, 3, 4, 5, 6) - val bbuf = java.nio.ByteBuffer.wrap(testOutput) - // write data to file - val file = new java.io.FileOutputStream(outFile) - val channel = file.getChannel - channel.write(bbuf) - channel.close() - file.close() - - val inRdd = sc.binaryFiles(outFileName).persist(StorageLevel.DISK_ONLY) - inRdd.foreach{ - curData: (String, PortableDataStream) => - curData._2.toArray() // force the file to read - } - val mappedRdd = inRdd.map { - curData: (String, PortableDataStream) => - (curData._2.getPath(), curData._2) - } - val (infile: String, indata: PortableDataStream) = mappedRdd.collect.head - - // Try reading the output back as an object file - - assert(indata.toArray === testOutput) + val outFile = writeBinaryData(testOutput, 1) + val inRdd = sc.binaryFiles(outFile.getAbsolutePath).persist(StorageLevel.DISK_ONLY) + inRdd.foreach(_._2.toArray()) // force the file to read + assert(inRdd.values.collect().head.toArray === testOutput) } test("portabledatastream flatmap tests") { sc = new SparkContext("local", "test") - val 
outFile = new File(tempDir, "record-bytestream-00000.bin") - val outFileName = outFile.getAbsolutePath() - - // create file val testOutput = Array[Byte](1, 2, 3, 4, 5, 6) + val outFile = writeBinaryData(testOutput, 1) + val inRdd = sc.binaryFiles(outFile.getAbsolutePath) val numOfCopies = 3 - val bbuf = java.nio.ByteBuffer.wrap(testOutput) - // write data to file - val file = new java.io.FileOutputStream(outFile) - val channel = file.getChannel - channel.write(bbuf) - channel.close() - file.close() - - val inRdd = sc.binaryFiles(outFileName) - val mappedRdd = inRdd.map { - curData: (String, PortableDataStream) => - (curData._2.getPath(), curData._2) - } - val copyRdd = mappedRdd.flatMap { - curData: (String, PortableDataStream) => - for (i <- 1 to numOfCopies) yield (i, curData._2) - } - - val copyArr: Array[(Int, PortableDataStream)] = copyRdd.collect() - - // Try reading the output back as an object file + val copyRdd = inRdd.flatMap(curData => (0 until numOfCopies).map(_ => curData._2)) + val copyArr = copyRdd.collect() assert(copyArr.length == numOfCopies) - copyArr.foreach{ - cEntry: (Int, PortableDataStream) => - assert(cEntry._2.toArray === testOutput) + for (i <- copyArr.indices) { + assert(copyArr(i).toArray === testOutput) } - } test("fixed record length binary file as byte array") { - // a fixed length of 6 bytes - sc = new SparkContext("local", "test") - - val outFile = new File(tempDir, "record-bytestream-00000.bin") - val outFileName = outFile.getAbsolutePath() - - // create file val testOutput = Array[Byte](1, 2, 3, 4, 5, 6) val testOutputCopies = 10 - - // write data to file - val file = new java.io.FileOutputStream(outFile) - val channel = file.getChannel - for(i <- 1 to testOutputCopies) { - val bbuf = java.nio.ByteBuffer.wrap(testOutput) - channel.write(bbuf) - } - channel.close() - file.close() - - val inRdd = sc.binaryRecords(outFileName, testOutput.length) - // make sure there are enough elements + val outFile = writeBinaryData(testOutput, testOutputCopies) + val inRdd = sc.binaryRecords(outFile.getAbsolutePath, testOutput.length) assert(inRdd.count == testOutputCopies) - - // now just compare the first one - val indata: Array[Byte] = inRdd.collect.head - assert(indata === testOutput) + val inArr = inRdd.collect() + for (i <- inArr.indices) { + assert(inArr(i) === testOutput.map(b => (b + i).toByte)) + } } test ("negative binary record length should raise an exception") { - // a fixed length of 6 bytes sc = new SparkContext("local", "test") - - val outFile = new File(tempDir, "record-bytestream-00000.bin") - val outFileName = outFile.getAbsolutePath() - - // create file - val testOutput = Array[Byte](1, 2, 3, 4, 5, 6) - val testOutputCopies = 10 - - // write data to file - val file = new java.io.FileOutputStream(outFile) - val channel = file.getChannel - for(i <- 1 to testOutputCopies) { - val bbuf = java.nio.ByteBuffer.wrap(testOutput) - channel.write(bbuf) - } - channel.close() - file.close() - - val inRdd = sc.binaryRecords(outFileName, -1) - + val outFile = writeBinaryData(Array[Byte](1, 2, 3, 4, 5, 6), 1) intercept[SparkException] { - inRdd.count + sc.binaryRecords(outFile.getAbsolutePath, -1).count() } } @@ -503,7 +401,7 @@ class FileSuite extends SparkFunSuite with LocalSparkContext { job.setOutputKeyClass(classOf[String]) job.setOutputValueClass(classOf[String]) job.set("mapred.output.format.class", classOf[TextOutputFormat[String, String]].getName) - job.set("mapred.output.dir", tempDir.getPath + "/outputDataset_old") + 
job.set("mapreduce.output.fileoutputformat.outputdir", tempDir.getPath + "/outputDataset_old") randomRDD.saveAsHadoopDataset(job) assert(new File(tempDir.getPath + "/outputDataset_old/part-00000").exists() === true) } @@ -517,7 +415,8 @@ class FileSuite extends SparkFunSuite with LocalSparkContext { job.setOutputValueClass(classOf[String]) job.setOutputFormatClass(classOf[NewTextOutputFormat[String, String]]) val jobConfig = job.getConfiguration - jobConfig.set("mapred.output.dir", tempDir.getPath + "/outputDataset_new") + jobConfig.set("mapreduce.output.fileoutputformat.outputdir", + tempDir.getPath + "/outputDataset_new") randomRDD.saveAsNewAPIHadoopDataset(jobConfig) assert(new File(tempDir.getPath + "/outputDataset_new/part-r-00000").exists() === true) } diff --git a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala index 5a41e1c61908..f97a112ec127 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala @@ -292,6 +292,22 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu } } + test("add jar with invalid path") { + val tmpDir = Utils.createTempDir() + val tmpJar = File.createTempFile("test", ".jar", tmpDir) + + sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local")) + sc.addJar(tmpJar.getAbsolutePath) + + // Invaid jar path will only print the error log, will not add to file server. + sc.addJar("dummy.jar") + sc.addJar("") + sc.addJar(tmpDir.getAbsolutePath) + + sc.listJars().size should be (1) + sc.listJars().head should include (tmpJar.getName) + } + test("Cancelling job group should not cause SparkContext to shutdown (SPARK-6414)") { try { sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local")) diff --git a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala index b2eded43ba71..dcf83cb530a9 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala @@ -20,7 +20,8 @@ import java.io.{File, FileInputStream, FileWriter, InputStream, IOException} import java.net.{HttpURLConnection, URL} import java.nio.charset.StandardCharsets import java.util.zip.ZipInputStream -import javax.servlet.http.{HttpServletRequest, HttpServletResponse} +import javax.servlet._ +import javax.servlet.http.{HttpServletRequest, HttpServletRequestWrapper, HttpServletResponse} import scala.concurrent.duration._ import scala.language.postfixOps @@ -68,11 +69,12 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers private var server: HistoryServer = null private var port: Int = -1 - def init(): Unit = { + def init(extraConf: (String, String)*): Unit = { val conf = new SparkConf() .set("spark.history.fs.logDirectory", logDir) .set("spark.history.fs.update.interval", "0") .set("spark.testing", "true") + conf.setAll(extraConf) provider = new FsHistoryProvider(conf) provider.checkForLogs() val securityManager = HistoryServer.createSecurityManager(conf) @@ -566,6 +568,39 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers } + test("ui and api authorization checks") { + val appId = "app-20161115172038-0000" + val owner = "jose" + val admin = "root" + val other = "alice" + + stop() + init( + "spark.ui.filters" -> 
classOf[FakeAuthFilter].getName(), + "spark.history.ui.acls.enable" -> "true", + "spark.history.ui.admin.acls" -> admin) + + val tests = Seq( + (owner, HttpServletResponse.SC_OK), + (admin, HttpServletResponse.SC_OK), + (other, HttpServletResponse.SC_FORBIDDEN), + // When the remote user is null, the code behaves as if auth were disabled. + (null, HttpServletResponse.SC_OK)) + + val port = server.boundPort + val testUrls = Seq( + s"http://localhost:$port/api/v1/applications/$appId/jobs", + s"http://localhost:$port/history/$appId/jobs/") + + tests.foreach { case (user, expectedCode) => + testUrls.foreach { url => + val headers = if (user != null) Seq(FakeAuthFilter.FAKE_HTTP_USER -> user) else Nil + val sc = TestUtils.httpResponseCode(new URL(url), headers = headers) + assert(sc === expectedCode, s"Unexpected status code $sc for $url (user = $user)") + } + } + } + def getContentAndCode(path: String, port: Int = port): (Int, Option[String], Option[String]) = { HistoryServerSuite.getContentAndCode(new URL(s"http://localhost:$port/api/v1/$path")) } @@ -648,3 +683,26 @@ object HistoryServerSuite { } } } + +/** + * A filter used for auth tests; sets the request's user to the value of the "HTTP_USER" header. + */ +class FakeAuthFilter extends Filter { + + override def destroy(): Unit = { } + + override def init(config: FilterConfig): Unit = { } + + override def doFilter(req: ServletRequest, res: ServletResponse, chain: FilterChain): Unit = { + val hreq = req.asInstanceOf[HttpServletRequest] + val wrapped = new HttpServletRequestWrapper(hreq) { + override def getRemoteUser(): String = hreq.getHeader(FakeAuthFilter.FAKE_HTTP_USER) + } + chain.doFilter(wrapped, res) + } + +} + +object FakeAuthFilter { + val FAKE_HTTP_USER = "HTTP_USER" +} diff --git a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala index f94baaa30d18..8150fff2d018 100644 --- a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala +++ b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala @@ -17,54 +17,43 @@ package org.apache.spark.executor +import java.io.{Externalizable, ObjectInput, ObjectOutput} +import java.lang.Thread.UncaughtExceptionHandler import java.nio.ByteBuffer import java.util.Properties -import java.util.concurrent.CountDownLatch +import java.util.concurrent.{CountDownLatch, TimeUnit} import scala.collection.mutable.Map +import scala.concurrent.duration._ -import org.mockito.Matchers._ -import org.mockito.Mockito.{mock, when} +import org.mockito.ArgumentCaptor +import org.mockito.Matchers.{any, eq => meq} +import org.mockito.Mockito.{inOrder, verify, when} import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer +import org.scalatest.concurrent.Eventually +import org.scalatest.mock.MockitoSugar import org.apache.spark._ import org.apache.spark.TaskState.TaskState import org.apache.spark.memory.MemoryManager import org.apache.spark.metrics.MetricsSystem +import org.apache.spark.rdd.RDD import org.apache.spark.rpc.RpcEnv -import org.apache.spark.scheduler.{FakeTask, TaskDescription} +import org.apache.spark.scheduler.{FakeTask, ResultTask, TaskDescription} import org.apache.spark.serializer.JavaSerializer +import org.apache.spark.shuffle.FetchFailedException +import org.apache.spark.storage.BlockManagerId -class ExecutorSuite extends SparkFunSuite { +class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSugar with Eventually { test("SPARK-15963: Catch 
`TaskKilledException` correctly in Executor.TaskRunner") { // mock some objects to make Executor.launchTask() happy val conf = new SparkConf val serializer = new JavaSerializer(conf) - val mockEnv = mock(classOf[SparkEnv]) - val mockRpcEnv = mock(classOf[RpcEnv]) - val mockMetricsSystem = mock(classOf[MetricsSystem]) - val mockMemoryManager = mock(classOf[MemoryManager]) - when(mockEnv.conf).thenReturn(conf) - when(mockEnv.serializer).thenReturn(serializer) - when(mockEnv.rpcEnv).thenReturn(mockRpcEnv) - when(mockEnv.metricsSystem).thenReturn(mockMetricsSystem) - when(mockEnv.memoryManager).thenReturn(mockMemoryManager) - when(mockEnv.closureSerializer).thenReturn(serializer) - val fakeTaskMetrics = serializer.newInstance().serialize(TaskMetrics.registered).array() - val serializedTask = serializer.newInstance().serialize( - new FakeTask(0, 0, Nil, fakeTaskMetrics)) - val taskDescription = new TaskDescription( - taskId = 0, - attemptNumber = 0, - executorId = "", - name = "", - index = 0, - addedFiles = Map[String, Long](), - addedJars = Map[String, Long](), - properties = new Properties, - serializedTask) + val env = createMockEnv(conf, serializer) + val serializedTask = serializer.newInstance().serialize(new FakeTask(0, 0)) + val taskDescription = createFakeTaskDescription(serializedTask) // we use latches to force the program to run in this order: // +-----------------------------+---------------------------------------+ @@ -86,7 +75,7 @@ class ExecutorSuite extends SparkFunSuite { val executorSuiteHelper = new ExecutorSuiteHelper - val mockExecutorBackend = mock(classOf[ExecutorBackend]) + val mockExecutorBackend = mock[ExecutorBackend] when(mockExecutorBackend.statusUpdate(any(), any(), any())) .thenAnswer(new Answer[Unit] { var firstTime = true @@ -102,8 +91,8 @@ class ExecutorSuite extends SparkFunSuite { val taskState = invocationOnMock.getArguments()(1).asInstanceOf[TaskState] executorSuiteHelper.taskState = taskState val taskEndReason = invocationOnMock.getArguments()(2).asInstanceOf[ByteBuffer] - executorSuiteHelper.testFailedReason - = serializer.newInstance().deserialize(taskEndReason) + executorSuiteHelper.testFailedReason = + serializer.newInstance().deserialize(taskEndReason) // let the main test thread check `taskState` and `testFailedReason` executorSuiteHelper.latch3.countDown() } @@ -112,16 +101,20 @@ class ExecutorSuite extends SparkFunSuite { var executor: Executor = null try { - executor = new Executor("id", "localhost", mockEnv, userClassPath = Nil, isLocal = true) + executor = new Executor("id", "localhost", env, userClassPath = Nil, isLocal = true) // the task will be launched in a dedicated worker thread executor.launchTask(mockExecutorBackend, taskDescription) - executorSuiteHelper.latch1.await() + if (!executorSuiteHelper.latch1.await(5, TimeUnit.SECONDS)) { + fail("executor did not send first status update in time") + } // we know the task will be started, but not yet deserialized, because of the latches we // use in mockExecutorBackend. 
executor.killAllTasks(true) executorSuiteHelper.latch2.countDown() - executorSuiteHelper.latch3.await() + if (!executorSuiteHelper.latch3.await(5, TimeUnit.SECONDS)) { + fail("executor did not send second status update in time") + } // `testFailedReason` should be `TaskKilled`; `taskState` should be `KILLED` assert(executorSuiteHelper.testFailedReason === TaskKilled) @@ -133,6 +126,204 @@ class ExecutorSuite extends SparkFunSuite { } } } + + test("SPARK-19276: Handle FetchFailedExceptions that are hidden by user exceptions") { + val conf = new SparkConf().setMaster("local").setAppName("executor suite test") + sc = new SparkContext(conf) + val serializer = SparkEnv.get.closureSerializer.newInstance() + val resultFunc = (context: TaskContext, itr: Iterator[Int]) => itr.size + + // Submit a job where a fetch failure is thrown, but user code has a try/catch which hides + // the fetch failure. The executor should still tell the driver that the task failed due to a + // fetch failure, not a generic exception from user code. + val inputRDD = new FetchFailureThrowingRDD(sc) + val secondRDD = new FetchFailureHidingRDD(sc, inputRDD, throwOOM = false) + val taskBinary = sc.broadcast(serializer.serialize((secondRDD, resultFunc)).array()) + val serializedTaskMetrics = serializer.serialize(TaskMetrics.registered).array() + val task = new ResultTask( + stageId = 1, + stageAttemptId = 0, + taskBinary = taskBinary, + partition = secondRDD.partitions(0), + locs = Seq(), + outputId = 0, + localProperties = new Properties(), + serializedTaskMetrics = serializedTaskMetrics + ) + + val serTask = serializer.serialize(task) + val taskDescription = createFakeTaskDescription(serTask) + + val failReason = runTaskAndGetFailReason(taskDescription) + assert(failReason.isInstanceOf[FetchFailed]) + } + + test("SPARK-19276: OOMs correctly handled with a FetchFailure") { + // when there is a fatal error like an OOM, we don't do normal fetch failure handling, since it + // may be a false positive. And we should call the uncaught exception handler. + val conf = new SparkConf().setMaster("local").setAppName("executor suite test") + sc = new SparkContext(conf) + val serializer = SparkEnv.get.closureSerializer.newInstance() + val resultFunc = (context: TaskContext, itr: Iterator[Int]) => itr.size + + // Submit a job where a fetch failure is thrown, but then there is an OOM. We should treat + // the fetch failure as a false positive, and just do normal OOM handling. 
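A condensed sketch (not the executor's actual control flow) of the precedence these two tests pin down: a fatal error such as `OutOfMemoryError` overrides a recorded fetch failure, while an ordinary wrapper exception does not:

```scala
import scala.util.control.NonFatal

object FailureClassificationSketch extends App {
  sealed trait Reported
  case object ReportFetchFailed extends Reported
  case object ReportUserException extends Reported
  case object EscalateFatalError extends Reported

  def classify(t: Throwable, fetchFailureRecorded: Boolean): Reported = t match {
    case NonFatal(_) if fetchFailureRecorded => ReportFetchFailed   // hidden fetch failure still wins
    case NonFatal(_)                         => ReportUserException // ordinary task failure
    case _                                   => EscalateFatalError  // e.g. OOM: uncaught-exception handler
  }

  assert(classify(new RuntimeException("wrapper"), fetchFailureRecorded = true) == ReportFetchFailed)
  assert(classify(new OutOfMemoryError("oom"), fetchFailureRecorded = true) == EscalateFatalError)
}
```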
+ val inputRDD = new FetchFailureThrowingRDD(sc) + val secondRDD = new FetchFailureHidingRDD(sc, inputRDD, throwOOM = true) + val taskBinary = sc.broadcast(serializer.serialize((secondRDD, resultFunc)).array()) + val serializedTaskMetrics = serializer.serialize(TaskMetrics.registered).array() + val task = new ResultTask( + stageId = 1, + stageAttemptId = 0, + taskBinary = taskBinary, + partition = secondRDD.partitions(0), + locs = Seq(), + outputId = 0, + localProperties = new Properties(), + serializedTaskMetrics = serializedTaskMetrics + ) + + val serTask = serializer.serialize(task) + val taskDescription = createFakeTaskDescription(serTask) + + val (failReason, uncaughtExceptionHandler) = + runTaskGetFailReasonAndExceptionHandler(taskDescription) + // make sure the task failure just looks like a OOM, not a fetch failure + assert(failReason.isInstanceOf[ExceptionFailure]) + val exceptionCaptor = ArgumentCaptor.forClass(classOf[Throwable]) + verify(uncaughtExceptionHandler).uncaughtException(any(), exceptionCaptor.capture()) + assert(exceptionCaptor.getAllValues.size === 1) + assert(exceptionCaptor.getAllValues.get(0).isInstanceOf[OutOfMemoryError]) + } + + test("Gracefully handle error in task deserialization") { + val conf = new SparkConf + val serializer = new JavaSerializer(conf) + val env = createMockEnv(conf, serializer) + val serializedTask = serializer.newInstance().serialize(new NonDeserializableTask) + val taskDescription = createFakeTaskDescription(serializedTask) + + val failReason = runTaskAndGetFailReason(taskDescription) + failReason match { + case ef: ExceptionFailure => + assert(ef.exception.isDefined) + assert(ef.exception.get.getMessage() === NonDeserializableTask.errorMsg) + case _ => + fail(s"unexpected failure type: $failReason") + } + } + + private def createMockEnv(conf: SparkConf, serializer: JavaSerializer): SparkEnv = { + val mockEnv = mock[SparkEnv] + val mockRpcEnv = mock[RpcEnv] + val mockMetricsSystem = mock[MetricsSystem] + val mockMemoryManager = mock[MemoryManager] + when(mockEnv.conf).thenReturn(conf) + when(mockEnv.serializer).thenReturn(serializer) + when(mockEnv.rpcEnv).thenReturn(mockRpcEnv) + when(mockEnv.metricsSystem).thenReturn(mockMetricsSystem) + when(mockEnv.memoryManager).thenReturn(mockMemoryManager) + when(mockEnv.closureSerializer).thenReturn(serializer) + SparkEnv.set(mockEnv) + mockEnv + } + + private def createFakeTaskDescription(serializedTask: ByteBuffer): TaskDescription = { + new TaskDescription( + taskId = 0, + attemptNumber = 0, + executorId = "", + name = "", + index = 0, + addedFiles = Map[String, Long](), + addedJars = Map[String, Long](), + properties = new Properties, + serializedTask) + } + + private def runTaskAndGetFailReason(taskDescription: TaskDescription): TaskFailedReason = { + runTaskGetFailReasonAndExceptionHandler(taskDescription)._1 + } + + private def runTaskGetFailReasonAndExceptionHandler( + taskDescription: TaskDescription): (TaskFailedReason, UncaughtExceptionHandler) = { + val mockBackend = mock[ExecutorBackend] + val mockUncaughtExceptionHandler = mock[UncaughtExceptionHandler] + var executor: Executor = null + try { + executor = new Executor("id", "localhost", SparkEnv.get, userClassPath = Nil, isLocal = true, + uncaughtExceptionHandler = mockUncaughtExceptionHandler) + // the task will be launched in a dedicated worker thread + executor.launchTask(mockBackend, taskDescription) + eventually(timeout(5.seconds), interval(10.milliseconds)) { + assert(executor.numRunningTasks === 0) + } + } finally { + if 
(executor != null) { + executor.stop() + } + } + val orderedMock = inOrder(mockBackend) + val statusCaptor = ArgumentCaptor.forClass(classOf[ByteBuffer]) + orderedMock.verify(mockBackend) + .statusUpdate(meq(0L), meq(TaskState.RUNNING), statusCaptor.capture()) + orderedMock.verify(mockBackend) + .statusUpdate(meq(0L), meq(TaskState.FAILED), statusCaptor.capture()) + // first statusUpdate for RUNNING has empty data + assert(statusCaptor.getAllValues().get(0).remaining() === 0) + // second update is more interesting + val failureData = statusCaptor.getAllValues.get(1) + val failReason = + SparkEnv.get.closureSerializer.newInstance().deserialize[TaskFailedReason](failureData) + (failReason, mockUncaughtExceptionHandler) + } +} + +class FetchFailureThrowingRDD(sc: SparkContext) extends RDD[Int](sc, Nil) { + override def compute(split: Partition, context: TaskContext): Iterator[Int] = { + new Iterator[Int] { + override def hasNext: Boolean = true + override def next(): Int = { + throw new FetchFailedException( + bmAddress = BlockManagerId("1", "hostA", 1234), + shuffleId = 0, + mapId = 0, + reduceId = 0, + message = "fake fetch failure" + ) + } + } + } + override protected def getPartitions: Array[Partition] = { + Array(new SimplePartition) + } +} + +class SimplePartition extends Partition { + override def index: Int = 0 +} + +class FetchFailureHidingRDD( + sc: SparkContext, + val input: FetchFailureThrowingRDD, + throwOOM: Boolean) extends RDD[Int](input) { + override def compute(split: Partition, context: TaskContext): Iterator[Int] = { + val inItr = input.compute(split, context) + try { + Iterator(inItr.size) + } catch { + case t: Throwable => + if (throwOOM) { + throw new OutOfMemoryError("OOM while handling another exception") + } else { + throw new RuntimeException("User Exception that hides the original exception", t) + } + } + } + + override protected def getPartitions: Array[Partition] = { + Array(new SimplePartition) + } } // Helps to test("SPARK-15963") @@ -145,3 +336,14 @@ private class ExecutorSuiteHelper { @volatile var taskState: TaskState = _ @volatile var testFailedReason: TaskFailedReason = _ } + +private class NonDeserializableTask extends FakeTask(0, 0) with Externalizable { + def writeExternal(out: ObjectOutput): Unit = {} + def readExternal(in: ObjectInput): Unit = { + throw new RuntimeException(NonDeserializableTask.errorMsg) + } +} + +private object NonDeserializableTask { + val errorMsg = "failure in deserialization" +} diff --git a/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala b/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala index becf3829e724..5d522189a0c2 100644 --- a/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala +++ b/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala @@ -259,7 +259,7 @@ class InputOutputMetricsSuite extends SparkFunSuite with SharedSparkContext test("output metrics on records written") { val file = new File(tmpDir, getClass.getSimpleName) - val filePath = "file://" + file.getAbsolutePath + val filePath = file.toURI.toURL.toString val records = runAndReturnRecordsWritten { sc.parallelize(1 to numRecords).saveAsTextFile(filePath) @@ -269,7 +269,7 @@ class InputOutputMetricsSuite extends SparkFunSuite with SharedSparkContext test("output metrics on records written - new Hadoop API") { val file = new File(tmpDir, getClass.getSimpleName) - val filePath = "file://" + file.getAbsolutePath + val filePath = file.toURI.toURL.toString val records = 
runAndReturnRecordsWritten { sc.parallelize(1 to numRecords).map(key => (key.toString, key.toString)) diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index c735220da2e1..8eaf9dfcf49b 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -1569,24 +1569,45 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou assertDataStructuresEmpty() } - test("run trivial shuffle with out-of-band failure and retry") { + /** + * In this test, we run a map stage where one of the executors fails but we still receive a + * "zombie" complete message from a task that ran on that executor. We want to make sure the + * stage is resubmitted so that the task that ran on the failed executor is re-executed, and + * that the stage is only marked as finished once that task completes. + */ + test("run trivial shuffle with out-of-band executor failure and retry") { val shuffleMapRdd = new MyRDD(sc, 2, Nil) val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(2)) val shuffleId = shuffleDep.shuffleId val reduceRdd = new MyRDD(sc, 1, List(shuffleDep), tracker = mapOutputTracker) submit(reduceRdd, Array(0)) - // blockManagerMaster.removeExecutor("exec-hostA") - // pretend we were told hostA went away + // Tell the DAGScheduler that hostA was lost. runEvent(ExecutorLost("exec-hostA", ExecutorKilled)) - // DAGScheduler will immediately resubmit the stage after it appears to have no pending tasks - // rather than marking it is as failed and waiting. complete(taskSets(0), Seq( (Success, makeMapStatus("hostA", 1)), (Success, makeMapStatus("hostB", 1)))) + + // At this point, no more tasks are running for the stage (and the TaskSetManager considers the + // stage complete), but the tasks that ran on HostA need to be re-run, so the DAGScheduler + // should re-submit the stage with one task (the task that originally ran on HostA). + assert(taskSets.size === 2) + assert(taskSets(1).tasks.size === 1) + + // Make sure that the stage that was re-submitted was the ShuffleMapStage (not the reduce + // stage, which shouldn't be run until all of the tasks in the ShuffleMapStage complete on + // alive executors). + assert(taskSets(1).tasks(0).isInstanceOf[ShuffleMapTask]) + // have hostC complete the resubmitted task complete(taskSets(1), Seq((Success, makeMapStatus("hostC", 1)))) assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1).toSet === HashSet(makeBlockManagerId("hostC"), makeBlockManagerId("hostB"))) + + // Make sure that the reduce stage was now submitted. + assert(taskSets.size === 3) + assert(taskSets(2).tasks(0).isInstanceOf[ResultTask[_, _]]) + + // Complete the reduce stage. complete(taskSets(2), Seq((Success, 42))) assert(results === Map(0 -> 42)) assertDataStructuresEmpty() @@ -2031,6 +2052,11 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou * In this test, we run a map stage where one of the executors fails but we still receive a * "zombie" complete message from that executor. We want to make sure the stage is not reported * as done until all tasks have completed. + * + * Most of the functionality in this test is tested in "run trivial shuffle with out-of-band + * executor failure and retry". 
However, that test uses ShuffleMapStages that are followed by + * a ResultStage, whereas in this test, the ShuffleMapStage is tested in isolation, without a + * ResultStage after it. */ test("map stage submission with executor failure late map task completions") { val shuffleMapRdd = new MyRDD(sc, 3, Nil) @@ -2042,7 +2068,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou runEvent(makeCompletionEvent(oldTaskSet.tasks(0), Success, makeMapStatus("hostA", 2))) assert(results.size === 0) // Map stage job should not be complete yet - // Pretend host A was lost + // Pretend host A was lost. This will cause the TaskSetManager to resubmit task 0, because it + // completed on hostA. val oldEpoch = mapOutputTracker.getEpoch runEvent(ExecutorLost("exec-hostA", ExecutorKilled)) val newEpoch = mapOutputTracker.getEpoch @@ -2054,13 +2081,26 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou // A completion from another task should work because it's a non-failed host runEvent(makeCompletionEvent(oldTaskSet.tasks(2), Success, makeMapStatus("hostB", 2))) - assert(results.size === 0) // Map stage job should not be complete yet + + // At this point, no more tasks are running for the stage (and the TaskSetManager considers + // the stage complete), but the task that ran on hostA needs to be re-run, so the map stage + // shouldn't be marked as complete, and the DAGScheduler should re-submit the stage. + assert(results.size === 0) + assert(taskSets.size === 2) // Now complete tasks in the second task set val newTaskSet = taskSets(1) - assert(newTaskSet.tasks.size === 2) // Both tasks 0 and 1 were on hostA + // 2 tasks should have been re-submitted, for tasks 0 and 1 (which ran on hostA). + assert(newTaskSet.tasks.size === 2) + // Complete task 0 from the original task set (i.e., not hte one that's currently active). + // This should still be counted towards the job being complete (but there's still one + // outstanding task). runEvent(makeCompletionEvent(newTaskSet.tasks(0), Success, makeMapStatus("hostB", 2))) - assert(results.size === 0) // Map stage job should not be complete yet + assert(results.size === 0) + + // Complete the final task, from the currently active task set. There's still one + // running task, task 0 in the currently active stage attempt, but the success of task 0 means + // the DAGScheduler can mark the stage as finished. 
runEvent(makeCompletionEvent(newTaskSet.tasks(1), Success, makeMapStatus("hostB", 2))) assert(results.size === 1) // Map stage job should now finally be complete assertDataStructuresEmpty() diff --git a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala index 0c362b881d91..83ed12752074 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala @@ -195,6 +195,17 @@ class OutputCommitCoordinatorSuite extends SparkFunSuite with BeforeAndAfter { sc.runJob(rdd, OutputCommitFunctions(tempDir.getAbsolutePath).callCanCommitMultipleTimes _, 0 until rdd.partitions.size) } + + test("SPARK-19631: Do not allow failed attempts to be authorized for committing") { + val stage: Int = 1 + val partition: Int = 1 + val failedAttempt: Int = 0 + outputCommitCoordinator.stageStart(stage, maxPartitionId = 1) + outputCommitCoordinator.taskCompleted(stage, partition, attemptNumber = failedAttempt, + reason = ExecutorLostFailure("0", exitCausedByApp = true, None)) + assert(!outputCommitCoordinator.canCommit(stage, partition, failedAttempt)) + assert(outputCommitCoordinator.canCommit(stage, partition, failedAttempt + 1)) + } } /** diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index d03a0c990a02..2c2cda9f318e 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.scheduler -import java.util.Random +import java.util.{Properties, Random} import scala.collection.mutable import scala.collection.mutable.ArrayBuffer @@ -28,6 +28,7 @@ import org.mockito.Mockito.{mock, never, spy, verify, when} import org.apache.spark._ import org.apache.spark.internal.config import org.apache.spark.internal.Logging +import org.apache.spark.serializer.SerializerInstance import org.apache.spark.storage.BlockManagerId import org.apache.spark.util.{AccumulatorV2, ManualClock} @@ -664,6 +665,67 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg assert(thrown2.getMessage().contains("bigger than spark.driver.maxResultSize")) } + test("[SPARK-13931] taskSetManager should not send Resubmitted tasks after being a zombie") { + val conf = new SparkConf().set("spark.speculation", "true") + sc = new SparkContext("local", "test", conf) + + val sched = new FakeTaskScheduler(sc, ("execA", "host1"), ("execB", "host2")) + sched.initialize(new FakeSchedulerBackend() { + override def killTask(taskId: Long, executorId: String, interruptThread: Boolean): Unit = {} + }) + + // Keep track of the number of tasks that are resubmitted, + // so that the test can check that no tasks were resubmitted. 
+ var resubmittedTasks = 0 + val dagScheduler = new FakeDAGScheduler(sc, sched) { + override def taskEnded( + task: Task[_], + reason: TaskEndReason, + result: Any, + accumUpdates: Seq[AccumulatorV2[_, _]], + taskInfo: TaskInfo): Unit = { + super.taskEnded(task, reason, result, accumUpdates, taskInfo) + reason match { + case Resubmitted => resubmittedTasks += 1 + case _ => + } + } + } + sched.setDAGScheduler(dagScheduler) + + val singleTask = new ShuffleMapTask(0, 0, null, new Partition { + override def index: Int = 0 + }, Seq(TaskLocation("host1", "execA")), new Properties, null) + val taskSet = new TaskSet(Array(singleTask), 0, 0, 0, null) + val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES) + + // Offer host1, which should be accepted as a PROCESS_LOCAL location + // by the one task in the task set + val task1 = manager.resourceOffer("execA", "host1", TaskLocality.PROCESS_LOCAL).get + + // Mark the task as available for speculation, and then offer another resource, + // which should be used to launch a speculative copy of the task. + manager.speculatableTasks += singleTask.partitionId + val task2 = manager.resourceOffer("execB", "host2", TaskLocality.ANY).get + + assert(manager.runningTasks === 2) + assert(manager.isZombie === false) + + val directTaskResult = new DirectTaskResult[String](null, Seq()) { + override def value(resultSer: SerializerInstance): String = "" + } + // Complete one copy of the task, which should result in the task set manager + // being marked as a zombie, because at least one copy of its only task has completed. + manager.handleSuccessfulTask(task1.taskId, directTaskResult) + assert(manager.isZombie === true) + assert(resubmittedTasks === 0) + assert(manager.runningTasks === 1) + + manager.executorLost("execB", "host2", new SlaveLost()) + assert(manager.runningTasks === 0) + assert(resubmittedTasks === 0) + } + test("speculative and noPref task should be scheduled after node-local") { sc = new SparkContext("local", "test") sched = new FakeTaskScheduler( diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala index f4bfdc2fd69a..ccede34b8cb4 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala @@ -37,32 +37,31 @@ import org.apache.spark.serializer.{KryoSerializer, SerializerManager} import org.apache.spark.shuffle.sort.SortShuffleManager import org.apache.spark.storage.StorageLevel._ -/** Testsuite that tests block replication in BlockManager */ -class BlockManagerReplicationSuite extends SparkFunSuite - with Matchers - with BeforeAndAfter - with LocalSparkContext { - - private val conf = new SparkConf(false).set("spark.app.id", "test") - private var rpcEnv: RpcEnv = null - private var master: BlockManagerMaster = null - private val securityMgr = new SecurityManager(conf) - private val bcastManager = new BroadcastManager(true, conf, securityMgr) - private val mapOutputTracker = new MapOutputTrackerMaster(conf, bcastManager, true) - private val shuffleManager = new SortShuffleManager(conf) +trait BlockManagerReplicationBehavior extends SparkFunSuite + with Matchers + with BeforeAndAfter + with LocalSparkContext { + + val conf: SparkConf + protected var rpcEnv: RpcEnv = null + protected var master: BlockManagerMaster = null + protected lazy val securityMgr = new SecurityManager(conf) + protected lazy val 
bcastManager = new BroadcastManager(true, conf, securityMgr) + protected lazy val mapOutputTracker = new MapOutputTrackerMaster(conf, bcastManager, true) + protected lazy val shuffleManager = new SortShuffleManager(conf) // List of block manager created during an unit test, so that all of the them can be stopped // after the unit test. - private val allStores = new ArrayBuffer[BlockManager] + protected val allStores = new ArrayBuffer[BlockManager] // Reuse a serializer across tests to avoid creating a new thread-local buffer on each test - conf.set("spark.kryoserializer.buffer", "1m") - private val serializer = new KryoSerializer(conf) + + protected lazy val serializer = new KryoSerializer(conf) // Implicitly convert strings to BlockIds for test clarity. - private implicit def StringToBlockId(value: String): BlockId = new TestBlockId(value) + protected implicit def StringToBlockId(value: String): BlockId = new TestBlockId(value) - private def makeBlockManager( + protected def makeBlockManager( maxMem: Long, name: String = SparkContext.DRIVER_IDENTIFIER): BlockManager = { conf.set("spark.testing.memory", maxMem.toString) @@ -355,7 +354,7 @@ class BlockManagerReplicationSuite extends SparkFunSuite * is correct. Then it also drops the block from memory of each store (using LRU) and * again checks whether the master's knowledge gets updated. */ - private def testReplication(maxReplication: Int, storageLevels: Seq[StorageLevel]) { + protected def testReplication(maxReplication: Int, storageLevels: Seq[StorageLevel]) { import org.apache.spark.storage.StorageLevel._ assert(maxReplication > 1, @@ -448,3 +447,61 @@ class BlockManagerReplicationSuite extends SparkFunSuite } } } + +class BlockManagerReplicationSuite extends BlockManagerReplicationBehavior { + val conf = new SparkConf(false).set("spark.app.id", "test") + conf.set("spark.kryoserializer.buffer", "1m") +} + +class BlockManagerProactiveReplicationSuite extends BlockManagerReplicationBehavior { + val conf = new SparkConf(false).set("spark.app.id", "test") + conf.set("spark.kryoserializer.buffer", "1m") + conf.set("spark.storage.replication.proactive", "true") + conf.set("spark.storage.exceptionOnPinLeak", "true") + + (2 to 5).foreach{ i => + test(s"proactive block replication - $i replicas - ${i - 1} block manager deletions") { + testProactiveReplication(i) + } + } + + def testProactiveReplication(replicationFactor: Int) { + val blockSize = 1000 + val storeSize = 10000 + val initialStores = (1 to 10).map { i => makeBlockManager(storeSize, s"store$i") } + + val blockId = "a1" + + val storageLevel = StorageLevel(true, true, false, true, replicationFactor) + initialStores.head.putSingle(blockId, new Array[Byte](blockSize), storageLevel) + + val blockLocations = master.getLocations(blockId) + logInfo(s"Initial locations : $blockLocations") + + assert(blockLocations.size === replicationFactor) + + // remove a random blockManager + val executorsToRemove = blockLocations.take(replicationFactor - 1) + logInfo(s"Removing $executorsToRemove") + executorsToRemove.foreach{exec => + master.removeExecutor(exec.executorId) + // giving enough time for replication to happen and new block be reported to master + Thread.sleep(200) + } + + // giving enough time for replication complete and locks released + Thread.sleep(500) + + val newLocations = master.getLocations(blockId).toSet + logInfo(s"New locations : $newLocations") + assert(newLocations.size === replicationFactor) + // there should only be one common block manager between initial and new locations + 
assert(newLocations.intersect(blockLocations.toSet).size === 1) + + // check if all the read locks have been released + initialStores.filter(bm => newLocations.contains(bm.blockManagerId)).foreach { bm => + val locks = bm.releaseAllLocksForTask(BlockInfo.NON_TASK_WRITER) + assert(locks.size === 0, "Read locks unreleased!") + } + } +} diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index 43f77e68c153..8ed09749ffd5 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -200,7 +200,10 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { assert(Utils.bytesToString(2097152) === "2.0 MB") assert(Utils.bytesToString(2306867) === "2.2 MB") assert(Utils.bytesToString(5368709120L) === "5.0 GB") - assert(Utils.bytesToString(5L * 1024L * 1024L * 1024L * 1024L) === "5.0 TB") + assert(Utils.bytesToString(5L * (1L << 40)) === "5.0 TB") + assert(Utils.bytesToString(5L * (1L << 50)) === "5.0 PB") + assert(Utils.bytesToString(5L * (1L << 60)) === "5.0 EB") + assert(Utils.bytesToString(BigInt(1L << 11) * (1L << 60)) === "2.36E+21 B") } test("copyStream") { @@ -507,6 +510,7 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { assertResolves("""hdfs:/jar1,file:/jar2,jar3,C:\pi.py#py.pi,C:\path to\jar4""", s"hdfs:/jar1,file:/jar2,file:$cwd/jar3,file:/C:/pi.py%23py.pi,file:/C:/path%20to/jar4") } + assertResolves(",jar1,jar2", s"file:$cwd/jar1,file:$cwd/jar2") } test("nonLocalPaths") { diff --git a/docs/building-spark.md b/docs/building-spark.md index 56b892696ee2..8353b7a520b8 100644 --- a/docs/building-spark.md +++ b/docs/building-spark.md @@ -132,20 +132,6 @@ Thus, the full flow for running continuous-compilation of the `core` submodule m $ cd core $ ../build/mvn scala:cc -## Speeding up Compilation with Zinc - -[Zinc](https://github.com/typesafehub/zinc) is a long-running server version of SBT's incremental -compiler. When run locally as a background process, it speeds up builds of Scala-based projects -like Spark. Developers who regularly recompile Spark with Maven will be the most interested in -Zinc. The project site gives instructions for building and running `zinc`; OS X users can -install it using `brew install zinc`. - -If using the `build/mvn` package `zinc` will automatically be downloaded and leveraged for all -builds. This process will auto-start after the first time `build/mvn` is called and bind to port -3030 unless the `ZINC_PORT` environment variable is set. The `zinc` process can subsequently be -shut down at any time by running `build/zinc-/bin/zinc -shutdown` and will automatically -restart whenever `build/mvn` is called. - ## Building with SBT Maven is the official build tool recommended for packaging Spark, and is the *build of reference*. @@ -159,8 +145,14 @@ can be set to control the SBT build. For example: To avoid the overhead of launching sbt each time you need to re-compile, you can launch sbt in interactive mode by running `build/sbt`, and then run all build commands at the command -prompt. For more recommendations on reducing build time, refer to the -[Useful Developer Tools page](http://spark.apache.org/developer-tools.html). +prompt. 
+ +## Speeding up Compilation + +Developers who compile Spark frequently may want to speed up compilation; e.g., by using Zinc +(for developers who build with Maven) or by avoiding re-compilation of the assembly JAR (for +developers who build with SBT). For more information about how to do this, refer to the +[Useful Developer Tools page](http://spark.apache.org/developer-tools.html#reducing-build-times). ## Encrypted Filesystems @@ -190,29 +182,16 @@ The following is an example of a command to run the tests: ./build/mvn test -The ScalaTest plugin also supports running only a specific Scala test suite as follows: - - ./build/mvn -P... -Dtest=none -DwildcardSuites=org.apache.spark.repl.ReplSuite test - ./build/mvn -P... -Dtest=none -DwildcardSuites=org.apache.spark.repl.* test - -or a Java test: - - ./build/mvn test -P... -DwildcardSuites=none -Dtest=org.apache.spark.streaming.JavaAPISuite - ## Testing with SBT The following is an example of a command to run the tests: ./build/sbt test -To run only a specific test suite as follows: - - ./build/sbt "test-only org.apache.spark.repl.ReplSuite" - ./build/sbt "test-only org.apache.spark.repl.*" - -To run test suites of a specific sub project as follows: +## Running Individual Tests - ./build/sbt core/test +For information about how to run individual tests, refer to the +[Useful Developer Tools page](http://spark.apache.org/developer-tools.html#running-individual-tests). ## PySpark pip installable diff --git a/docs/configuration.md b/docs/configuration.md index 2fcb3a096aea..63392a741a1f 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -1000,6 +1000,15 @@ Apart from these, the following properties are also available, and may be useful storage space to unroll the new block in its entirety. + + spark.storage.replication.proactive + false + + Enables proactive block replication for RDD blocks. Cached RDD block replicas lost due to + executor failures are replenished if there are any existing available replicas. This tries + to get the replication level of the block to the initial number. + + ### Execution Behavior diff --git a/docs/hardware-provisioning.md b/docs/hardware-provisioning.md index bb6f616b18a2..896f9302ef30 100644 --- a/docs/hardware-provisioning.md +++ b/docs/hardware-provisioning.md @@ -15,8 +15,8 @@ possible**. We recommend the following: * If at all possible, run Spark on the same nodes as HDFS. The simplest way is to set up a Spark [standalone mode cluster](spark-standalone.html) on the same nodes, and configure Spark and Hadoop's memory and CPU usage to avoid interference (for Hadoop, the relevant options are -`mapred.child.java.opts` for the per-task memory and `mapred.tasktracker.map.tasks.maximum` -and `mapred.tasktracker.reduce.tasks.maximum` for number of tasks). Alternatively, you can run +`mapred.child.java.opts` for the per-task memory and `mapreduce.tasktracker.map.tasks.maximum` +and `mapreduce.tasktracker.reduce.tasks.maximum` for number of tasks). Alternatively, you can run Hadoop and Spark on a common cluster manager like [Mesos](running-on-mesos.html) or [Hadoop YARN](running-on-yarn.html). diff --git a/docs/ml-classification-regression.md b/docs/ml-classification-regression.md index 782ee5818893..ab6f587e09ef 100644 --- a/docs/ml-classification-regression.md +++ b/docs/ml-classification-regression.md @@ -363,6 +363,50 @@ Refer to the [R API docs](api/R/spark.mlp.html) for more details. 
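The `spark.storage.replication.proactive` setting documented in the configuration.md hunk above (and exercised by the new BlockManagerProactiveReplicationSuite earlier in this patch) is enabled per application. Below is a minimal sketch of turning it on; only the conf key and its default come from the patch, while the app name, data, and storage level are illustrative placeholders.

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.storage.StorageLevel

// Illustrative sketch: the conf key is the one documented above; everything else is a placeholder.
val conf = new SparkConf()
  .setAppName("proactive-replication-sketch")
  .set("spark.storage.replication.proactive", "true") // default: false

val sc = new SparkContext(conf)

// Cache with two replicas; if an executor holding one replica is lost, the surviving replica
// is used to bring the block back up to its requested replication level.
val data = sc.parallelize(1 to 100000).persist(StorageLevel.MEMORY_ONLY_2)
println(data.count())

sc.stop()
```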
+## Linear Support Vector Machine
+
+A [support vector machine](https://en.wikipedia.org/wiki/Support_vector_machine) constructs a hyperplane
+or set of hyperplanes in a high- or infinite-dimensional space, which can be used for classification,
+regression, or other tasks. Intuitively, a good separation is achieved by the hyperplane that has
+the largest distance to the nearest training-data points of any class (the so-called functional margin),
+since in general the larger the margin, the lower the generalization error of the classifier. LinearSVC
+in Spark ML supports binary classification with a linear SVM. Internally, it optimizes the
+[Hinge Loss](https://en.wikipedia.org/wiki/Hinge_loss) using the OWLQN optimizer.
+
+
+**Examples**
+
+<div class="codetabs">
+
+<div data-lang="scala" markdown="1">
+
+Refer to the [Scala API docs](api/scala/index.html#org.apache.spark.ml.classification.LinearSVC) for more details.
+
+{% include_example scala/org/apache/spark/examples/ml/LinearSVCExample.scala %}
+</div>
+
+<div data-lang="java" markdown="1">
+
+Refer to the [Java API docs](api/java/org/apache/spark/ml/classification/LinearSVC.html) for more details.
+
+{% include_example java/org/apache/spark/examples/ml/JavaLinearSVCExample.java %}
+</div>
+
+<div data-lang="python" markdown="1">
+
+Refer to the [Python API docs](api/python/pyspark.ml.html#pyspark.ml.classification.LinearSVC) for more details.
+
+{% include_example python/ml/linearsvc.py %}
+</div>
+
+<div data-lang="r" markdown="1">
+
+Refer to the [R API docs](api/R/spark.svmLinear.html) for more details.
+
+{% include_example r/ml/svmLinear.R %}
+</div>
+
+</div>
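The Scala tab above pulls in `LinearSVCExample.scala` via `include_example`, but that file is not part of this excerpt. For reference, here is a minimal sketch of the same flow, mirroring the Java and Python examples added later in this patch (the data path and parameter values come from those examples; the object and app names are placeholders):

```scala
import org.apache.spark.ml.classification.LinearSVC
import org.apache.spark.sql.SparkSession

object LinearSVCSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("LinearSVCSketch").getOrCreate()

    // Load training data in LIBSVM format, as the Java and Python examples do.
    val training = spark.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt")

    val lsvc = new LinearSVC()
      .setMaxIter(10)
      .setRegParam(0.1)

    // Fit the model and print the learned coefficients and intercept.
    val lsvcModel = lsvc.fit(training)
    println(s"Coefficients: ${lsvcModel.coefficients} Intercept: ${lsvcModel.intercept}")

    spark.stop()
  }
}
```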
## One-vs-Rest classifier (a.k.a. One-vs-All) @@ -585,6 +629,11 @@ others. Continuous Inverse*, Idenity, Log + + Tweedie + Zero-inflated continuous + Power link function + * Canonical Link diff --git a/docs/ml-collaborative-filtering.md b/docs/ml-collaborative-filtering.md index cfe835172ab4..58f2d4b531e7 100644 --- a/docs/ml-collaborative-filtering.md +++ b/docs/ml-collaborative-filtering.md @@ -59,6 +59,34 @@ This approach is named "ALS-WR" and discussed in the paper It makes `regParam` less dependent on the scale of the dataset, so we can apply the best parameter learned from a sampled subset to the full dataset and expect similar performance. +### Cold-start strategy + +When making predictions using an `ALSModel`, it is common to encounter users and/or items in the +test dataset that were not present during training the model. This typically occurs in two +scenarios: + +1. In production, for new users or items that have no rating history and on which the model has not +been trained (this is the "cold start problem"). +2. During cross-validation, the data is split between training and evaluation sets. When using +simple random splits as in Spark's `CrossValidator` or `TrainValidationSplit`, it is actually +very common to encounter users and/or items in the evaluation set that are not in the training set + +By default, Spark assigns `NaN` predictions during `ALSModel.transform` when a user and/or item +factor is not present in the model. This can be useful in a production system, since it indicates +a new user or item, and so the system can make a decision on some fallback to use as the prediction. + +However, this is undesirable during cross-validation, since any `NaN` predicted values will result +in `NaN` results for the evaluation metric (for example when using `RegressionEvaluator`). +This makes model selection impossible. + +Spark allows users to set the `coldStartStrategy` parameter +to "drop" in order to drop any rows in the `DataFrame` of predictions that contain `NaN` values. +The evaluation metric will then be computed over the non-`NaN` data and will be valid. +Usage of this parameter is illustrated in the example below. + +**Note:** currently the supported cold start strategies are "nan" (the default behavior mentioned +above) and "drop". Further strategies may be supported in future. + **Examples**
diff --git a/docs/ml-pipeline.md b/docs/ml-pipeline.md index 7cbb14654e9d..aa92c0a37c0f 100644 --- a/docs/ml-pipeline.md +++ b/docs/ml-pipeline.md @@ -132,7 +132,7 @@ The `Pipeline.fit()` method is called on the original `DataFrame`, which has raw The `Tokenizer.transform()` method splits the raw text documents into words, adding a new column with words to the `DataFrame`. The `HashingTF.transform()` method converts the words column into feature vectors, adding a new column with those vectors to the `DataFrame`. Now, since `LogisticRegression` is an `Estimator`, the `Pipeline` first calls `LogisticRegression.fit()` to produce a `LogisticRegressionModel`. -If the `Pipeline` had more stages, it would call the `LogisticRegressionModel`'s `transform()` +If the `Pipeline` had more `Estimator`s, it would call the `LogisticRegressionModel`'s `transform()` method on the `DataFrame` before passing the `DataFrame` to the next stage. A `Pipeline` is an `Estimator`. diff --git a/docs/monitoring.md b/docs/monitoring.md index 7ba4824d463f..80519525af0c 100644 --- a/docs/monitoring.md +++ b/docs/monitoring.md @@ -381,6 +381,10 @@ can be identified by their `[attempt-id]`. In the API listed below, when running /applications/[app-id]/streaming/batches/[batch-id]/operations/[outputOp-id] Details of the given operation and given batch. + + + /applications/[app-id]/environment + Environment details of the given application. diff --git a/docs/quick-start.md b/docs/quick-start.md index 04ac27876252..aa4319a23325 100644 --- a/docs/quick-start.md +++ b/docs/quick-start.md @@ -260,7 +260,7 @@ object which contains information about our application. Our application depends on the Spark API, so we'll also include an sbt configuration file, -`simple.sbt`, which explains that Spark is a dependency. This file also adds a repository that +`build.sbt`, which explains that Spark is a dependency. This file also adds a repository that Spark depends on: {% highlight scala %} @@ -273,7 +273,7 @@ scalaVersion := "{{site.SCALA_VERSION}}" libraryDependencies += "org.apache.spark" %% "spark-core" % "{{site.SPARK_VERSION}}" {% endhighlight %} -For sbt to work correctly, we'll need to layout `SimpleApp.scala` and `simple.sbt` +For sbt to work correctly, we'll need to layout `SimpleApp.scala` and `build.sbt` according to the typical directory structure. Once that is in place, we can create a JAR package containing the application's code, then use the `spark-submit` script to run our program. @@ -281,7 +281,7 @@ containing the application's code, then use the `spark-submit` script to run our # Your directory layout should look like this $ find . . -./simple.sbt +./build.sbt ./src ./src/main ./src/main/scala diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index cf95b95afd2e..e9ddaa76a797 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -604,3 +604,18 @@ spark.yarn.am.extraJavaOptions -Dsun.security.krb5.debug=true -Dsun.security.spn Finally, if the log level for `org.apache.spark.deploy.yarn.Client` is set to `DEBUG`, the log will include a list of all tokens obtained, and their expiry details + +## Using the Spark History Server to replace the Spark Web UI + +It is possible to use the Spark History Server application page as the tracking URL for running +applications when the application UI is disabled. This may be desirable on secure clusters, or to +reduce the memory usage of the Spark driver. 
To set up tracking through the Spark History Server, +do the following: + +- On the application side, set spark.yarn.historyServer.allowTracking=true in Spark's + configuration. This will tell Spark to use the history server's URL as the tracking URL if + the application's UI is disabled. +- On the Spark History Server, add org.apache.spark.deploy.yarn.YarnProxyRedirectFilter + to the list of filters in the spark.ui.filters configuration. + +Be aware that the history server information may not be up-to-date with the application's state. diff --git a/docs/security.md b/docs/security.md index a4796767832b..9eda42888637 100644 --- a/docs/security.md +++ b/docs/security.md @@ -12,7 +12,7 @@ Spark currently supports authentication via a shared secret. Authentication can ## Web UI The Spark UI can be secured by using [javax servlet filters](http://docs.oracle.com/javaee/6/api/javax/servlet/Filter.html) via the `spark.ui.filters` setting -and by using [https/SSL](http://en.wikipedia.org/wiki/HTTPS) via the `spark.ui.https.enabled` setting. +and by using [https/SSL](http://en.wikipedia.org/wiki/HTTPS) via [SSL settings](security.html#ssl-configuration). ### Authentication diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 235f5ecc40c9..b077575155eb 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -386,8 +386,8 @@ For example: The [built-in DataFrames functions](api/scala/index.html#org.apache.spark.sql.functions$) provide common aggregations such as `count()`, `countDistinct()`, `avg()`, `max()`, `min()`, etc. -While those functions are designed for DataFrames, Spark SQL also has type-safe versions for some of them in -[Scala](api/scala/index.html#org.apache.spark.sql.expressions.scalalang.typed$) and +While those functions are designed for DataFrames, Spark SQL also has type-safe versions for some of them in +[Scala](api/scala/index.html#org.apache.spark.sql.expressions.scalalang.typed$) and [Java](api/java/org/apache/spark/sql/expressions/javalang/typed.html) to work with strongly typed Datasets. Moreover, users are not limited to the predefined aggregate functions and can create their own. @@ -397,7 +397,7 @@ Moreover, users are not limited to the predefined aggregate functions and can cr
-Users have to extend the [UserDefinedAggregateFunction](api/scala/index.html#org.apache.spark.sql.expressions.UserDefinedAggregateFunction) +Users have to extend the [UserDefinedAggregateFunction](api/scala/index.html#org.apache.spark.sql.expressions.UserDefinedAggregateFunction) abstract class to implement a custom untyped aggregate function. For example, a user-defined average can look like: @@ -888,8 +888,9 @@ or a JSON file. Note that the file that is offered as _a json file_ is not a typical JSON file. Each line must contain a separate, self-contained valid JSON object. For more information, please see -[JSON Lines text format, also called newline-delimited JSON](http://jsonlines.org/). As a -consequence, a regular multi-line JSON file will most often fail. +[JSON Lines text format, also called newline-delimited JSON](http://jsonlines.org/). + +For a regular multi-line JSON file, set the `wholeFile` option to `true`. {% include_example json_dataset scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala %}
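As a quick illustration of the option described above (a sketch only; the paths are placeholders and the option name is the one this patch documents):

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().appName("whole-file-json-sketch").getOrCreate()

// Default behaviour: JSON Lines, one self-contained JSON object per line.
val jsonLinesDF = spark.read.json("path/to/json-lines-file.json")

// Regular multi-line JSON: one JSON document may span the whole file.
val multiLineDF = spark.read
  .option("wholeFile", true)
  .json("path/to/multi-line-file.json")
multiLineDF.printSchema()
```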
@@ -901,8 +902,9 @@ or a JSON file. Note that the file that is offered as _a json file_ is not a typical JSON file. Each line must contain a separate, self-contained valid JSON object. For more information, please see -[JSON Lines text format, also called newline-delimited JSON](http://jsonlines.org/). As a -consequence, a regular multi-line JSON file will most often fail. +[JSON Lines text format, also called newline-delimited JSON](http://jsonlines.org/). + +For a regular multi-line JSON file, set the `wholeFile` option to `true`. {% include_example json_dataset java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java %}
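Besides files, `spark.read.json` also accepts an in-memory `Dataset[String]` holding one JSON document per element, which is the variant the `JavaSQLDataSourceExample` change later in this patch switches to. A minimal Scala sketch of the same idea (the sample record mirrors the one used in that example; the app name is a placeholder):

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().appName("json-dataset-sketch").getOrCreate()
import spark.implicits._

// One JSON document per element; no files involved.
val jsonStrings = Seq("""{"name":"Yin","address":{"city":"Columbus","state":"Ohio"}}""").toDS()
val people = spark.read.json(jsonStrings)
people.show() // two columns: a nested `address` struct and `name`
```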
@@ -913,8 +915,9 @@ This conversion can be done using `SparkSession.read.json` on a JSON file. Note that the file that is offered as _a json file_ is not a typical JSON file. Each line must contain a separate, self-contained valid JSON object. For more information, please see -[JSON Lines text format, also called newline-delimited JSON](http://jsonlines.org/). As a -consequence, a regular multi-line JSON file will most often fail. +[JSON Lines text format, also called newline-delimited JSON](http://jsonlines.org/). + +For a regular multi-line JSON file, set the `wholeFile` parameter to `True`. {% include_example json_dataset python/sql/datasource.py %} @@ -926,8 +929,9 @@ files is a JSON object. Note that the file that is offered as _a json file_ is not a typical JSON file. Each line must contain a separate, self-contained valid JSON object. For more information, please see -[JSON Lines text format, also called newline-delimited JSON](http://jsonlines.org/). As a -consequence, a regular multi-line JSON file will most often fail. +[JSON Lines text format, also called newline-delimited JSON](http://jsonlines.org/). + +For a regular multi-line JSON file, set a named parameter `wholeFile` to `TRUE`. {% include_example json_dataset r/RSparkSQLExample.R %} @@ -1410,7 +1414,7 @@ Thrift JDBC server also supports sending thrift RPC messages over HTTP transport Use the following setting to enable HTTP mode as system property or in `hive-site.xml` file in `conf/`: hive.server2.transport.mode - Set this to value: http - hive.server2.thrift.http.port - HTTP port number fo listen on; default is 10001 + hive.server2.thrift.http.port - HTTP port number to listen on; default is 10001 hive.server2.http.endpoint - HTTP endpoint; default is cliservice To test, use beeline to connect to the JDBC/ODBC server in http mode with: diff --git a/docs/structured-streaming-programming-guide.md b/docs/structured-streaming-programming-guide.md index ad3b2fb26dd6..6af47b6efba2 100644 --- a/docs/structured-streaming-programming-guide.md +++ b/docs/structured-streaming-programming-guide.md @@ -392,7 +392,7 @@ data, thus relieving the users from reasoning about it. As an example, let’s see how this model handles event-time based processing and late arriving data. ## Handling Event-time and Late Data -Event-time is the time embedded in the data itself. For many applications, you may want to operate on this event-time. For example, if you want to get the number of events generated by IoT devices every minute, then you probably want to use the time when the data was generated (that is, event-time in the data), rather than the time Spark receives them. This event-time is very naturally expressed in this model -- each event from the devices is a row in the table, and event-time is a column value in the row. This allows window-based aggregations (e.g. number of events every minute) to be just a special type of grouping and aggregation on the even-time column -- each time window is a group and each row can belong to multiple windows/groups. Therefore, such event-time-window-based aggregation queries can be defined consistently on both a static dataset (e.g. from collected device events logs) as well as on a data stream, making the life of the user much easier. +Event-time is the time embedded in the data itself. For many applications, you may want to operate on this event-time. 
For example, if you want to get the number of events generated by IoT devices every minute, then you probably want to use the time when the data was generated (that is, event-time in the data), rather than the time Spark receives them. This event-time is very naturally expressed in this model -- each event from the devices is a row in the table, and event-time is a column value in the row. This allows window-based aggregations (e.g. number of events every minute) to be just a special type of grouping and aggregation on the event-time column -- each time window is a group and each row can belong to multiple windows/groups. Therefore, such event-time-window-based aggregation queries can be defined consistently on both a static dataset (e.g. from collected device events logs) as well as on a data stream, making the life of the user much easier. Furthermore, this model naturally handles data that has arrived later than expected based on its event-time. Since Spark is updating the Result Table, @@ -401,7 +401,7 @@ as well as cleaning up old aggregates to limit the size of intermediate state data. Since Spark 2.1, we have support for watermarking which allows the user to specify the threshold of late data, and allows the engine to accordingly clean up old state. These are explained later in more -details in the [Window Operations](#window-operations-on-event-time) section. +detail in the [Window Operations](#window-operations-on-event-time) section. ## Fault Tolerance Semantics Delivering end-to-end exactly-once semantics was one of key goals behind the design of Structured Streaming. To achieve that, we have designed the Structured Streaming sources, the sinks and the execution engine to reliably track the exact progress of the processing so that it can handle any kind of failure by restarting and/or reprocessing. Every streaming source is assumed to have offsets (similar to Kafka offsets, or Kinesis sequence numbers) @@ -647,7 +647,7 @@ df.groupBy("deviceType").count() ### Window Operations on Event Time -Aggregations over a sliding event-time window are straightforward with Structured Streaming. The key idea to understand about window-based aggregations are very similar to grouped aggregations. In a grouped aggregation, aggregate values (e.g. counts) are maintained for each unique value in the user-specified grouping column. In case of window-based aggregations, aggregate values are maintained for each window the event-time of a row falls into. Let's understand this with an illustration. +Aggregations over a sliding event-time window are straightforward with Structured Streaming and are very similar to grouped aggregations. In a grouped aggregation, aggregate values (e.g. counts) are maintained for each unique value in the user-specified grouping column. In case of window-based aggregations, aggregate values are maintained for each window the event-time of a row falls into. Let's understand this with an illustration. Imagine our [quick example](#quick-example) is modified and the stream now contains lines along with the time when the line was generated. Instead of running word counts, we want to count words within 10 minute windows, updating every 5 minutes. That is, word counts in words received between 10 minute windows 12:00 - 12:10, 12:05 - 12:15, 12:10 - 12:20, etc. Note that 12:00 - 12:10 means data that arrived after 12:00 but before 12:10. Now, consider a word that was received at 12:07. 
This word should increment the counts corresponding to two windows 12:00 - 12:10 and 12:05 - 12:15. So the counts will be indexed by both, the grouping key (i.e. the word) and the window (can be calculated from the event-time). @@ -713,7 +713,7 @@ old windows correctly, as illustrated below. ![Handling Late Data](img/structured-streaming-late-data.png) -However, to run this query for days, its necessary for the system to bound the amount of +However, to run this query for days, it's necessary for the system to bound the amount of intermediate in-memory state it accumulates. This means the system needs to know when an old aggregate can be dropped from the in-memory state because the application is not going to receive late data for that aggregate any more. To enable this, in Spark 2.1, we have introduced diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaALSExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaALSExample.java index 33ba668b32fc..81970b7c81f4 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaALSExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaALSExample.java @@ -103,6 +103,8 @@ public static void main(String[] args) { ALSModel model = als.fit(training); // Evaluate the model by computing the RMSE on the test data + // Note we set cold start strategy to 'drop' to ensure we don't get NaN evaluation metrics + model.setColdStartStrategy("drop"); Dataset predictions = model.transform(test); RegressionEvaluator evaluator = new RegressionEvaluator() diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaBucketedRandomProjectionLSHExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaBucketedRandomProjectionLSHExample.java index 4594e3462b2a..ff917b720c8b 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaBucketedRandomProjectionLSHExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaBucketedRandomProjectionLSHExample.java @@ -42,7 +42,7 @@ /** * An example demonstrating BucketedRandomProjectionLSH. * Run with: - * bin/run-example org.apache.spark.examples.ml.JavaBucketedRandomProjectionLSHExample + * bin/run-example ml.JavaBucketedRandomProjectionLSHExample */ public class JavaBucketedRandomProjectionLSHExample { public static void main(String[] args) { diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaLinearSVCExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaLinearSVCExample.java new file mode 100644 index 000000000000..a18ed1d0b48f --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaLinearSVCExample.java @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.examples.ml; + +// $example on$ +import org.apache.spark.ml.classification.LinearSVC; +import org.apache.spark.ml.classification.LinearSVCModel; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; +// $example off$ + +public class JavaLinearSVCExample { + public static void main(String[] args) { + SparkSession spark = SparkSession + .builder() + .appName("JavaLinearSVCExample") + .getOrCreate(); + + // $example on$ + // Load training data + Dataset training = spark.read().format("libsvm") + .load("data/mllib/sample_libsvm_data.txt"); + + LinearSVC lsvc = new LinearSVC() + .setMaxIter(10) + .setRegParam(0.1); + + // Fit the model + LinearSVCModel lsvcModel = lsvc.fit(training); + + // Print the coefficients and intercept for LinearSVC + System.out.println("Coefficients: " + + lsvcModel.coefficients() + " Intercept: " + lsvcModel.intercept()); + // $example off$ + + spark.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaMinHashLSHExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaMinHashLSHExample.java index 0aace4693925..e164598e3ef8 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaMinHashLSHExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaMinHashLSHExample.java @@ -42,7 +42,7 @@ /** * An example demonstrating MinHashLSH. * Run with: - * bin/run-example org.apache.spark.examples.ml.JavaMinHashLSHExample + * bin/run-example ml.JavaMinHashLSHExample */ public class JavaMinHashLSHExample { public static void main(String[] args) { diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaTokenizerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaTokenizerExample.java index 3f809eba7fff..a0979aa2d24e 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaTokenizerExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaTokenizerExample.java @@ -27,7 +27,6 @@ import org.apache.spark.ml.feature.RegexTokenizer; import org.apache.spark.ml.feature.Tokenizer; -import org.apache.spark.sql.api.java.UDF1; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; @@ -69,7 +68,8 @@ public static void main(String[] args) { .setOutputCol("words") .setPattern("\\W"); // alternatively .setPattern("\\w+").setGaps(false); - spark.udf().register("countTokens", (WrappedArray words) -> words.size(), DataTypes.IntegerType); + spark.udf().register( + "countTokens", (WrappedArray words) -> words.size(), DataTypes.IntegerType); Dataset tokenized = tokenizer.transform(sentenceDataFrame); tokenized.select("sentence", "words") diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaRankingMetricsExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaRankingMetricsExample.java index bd49f059b29f..dc9970d88527 100644 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaRankingMetricsExample.java +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaRankingMetricsExample.java @@ -118,7 +118,9 @@ public static void main(String[] args) { new Tuple2<>(new Tuple2<>(r.user(), r.product()), r.rating()))); JavaRDD> ratesAndPreds = JavaPairRDD.fromJavaRDD(ratings.map(r -> - new Tuple2, Object>(new Tuple2<>(r.user(), r.product()), r.rating()) + new Tuple2, Object>( + new Tuple2<>(r.user(), r.product()), + r.rating()) )).join(predictions).values(); // Create regression 
metrics object diff --git a/examples/src/main/java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java b/examples/src/main/java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java index adb96dd8bf00..82bb284ea3e5 100644 --- a/examples/src/main/java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java +++ b/examples/src/main/java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java @@ -25,8 +25,6 @@ import java.util.Properties; // $example on:basic_parquet_example$ -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.MapFunction; import org.apache.spark.sql.Encoders; // $example on:schema_merging$ @@ -217,12 +215,11 @@ private static void runJsonDatasetExample(SparkSession spark) { // +------+ // Alternatively, a DataFrame can be created for a JSON dataset represented by - // an RDD[String] storing one JSON object per string. + // an Dataset[String] storing one JSON object per string. List jsonData = Arrays.asList( "{\"name\":\"Yin\",\"address\":{\"city\":\"Columbus\",\"state\":\"Ohio\"}}"); - JavaRDD anotherPeopleRDD = - new JavaSparkContext(spark.sparkContext()).parallelize(jsonData); - Dataset anotherPeople = spark.read().json(anotherPeopleRDD); + Dataset anotherPeopleDataset = spark.createDataset(jsonData, Encoders.STRING()); + Dataset anotherPeople = spark.read().json(anotherPeopleDataset); anotherPeople.show(); // +---------------+----+ // | address|name| diff --git a/examples/src/main/python/ml/als_example.py b/examples/src/main/python/ml/als_example.py index 1a979ff5b5be..2e7214ed56f9 100644 --- a/examples/src/main/python/ml/als_example.py +++ b/examples/src/main/python/ml/als_example.py @@ -44,7 +44,9 @@ (training, test) = ratings.randomSplit([0.8, 0.2]) # Build the recommendation model using ALS on the training data - als = ALS(maxIter=5, regParam=0.01, userCol="userId", itemCol="movieId", ratingCol="rating") + # Note we set cold start strategy to 'drop' to ensure we don't get NaN evaluation metrics + als = ALS(maxIter=5, regParam=0.01, userCol="userId", itemCol="movieId", ratingCol="rating", + coldStartStrategy="drop") model = als.fit(training) # Evaluate the model by computing the RMSE on the test data diff --git a/examples/src/main/python/ml/linearsvc.py b/examples/src/main/python/ml/linearsvc.py new file mode 100644 index 000000000000..18cbf87a1069 --- /dev/null +++ b/examples/src/main/python/ml/linearsvc.py @@ -0,0 +1,46 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +from __future__ import print_function + +# $example on$ +from pyspark.ml.classification import LinearSVC +# $example off$ +from pyspark.sql import SparkSession + +if __name__ == "__main__": + spark = SparkSession\ + .builder\ + .appName("linearSVC Example")\ + .getOrCreate() + + # $example on$ + # Load training data + training = spark.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + + lsvc = LinearSVC(maxIter=10, regParam=0.1) + + # Fit the model + lsvcModel = lsvc.fit(training) + + # Print the coefficients and intercept for linearsSVC + print("Coefficients: " + str(lsvcModel.coefficients)) + print("Intercept: " + str(lsvcModel.intercept)) + + # $example off$ + + spark.stop() diff --git a/examples/src/main/python/sql/basic.py b/examples/src/main/python/sql/basic.py index ebcf66995b47..c07fa8f2752b 100644 --- a/examples/src/main/python/sql/basic.py +++ b/examples/src/main/python/sql/basic.py @@ -187,9 +187,6 @@ def programmatic_schema_example(spark): # Creates a temporary view using the DataFrame schemaPeople.createOrReplaceTempView("people") - # Creates a temporary view using the DataFrame - schemaPeople.createOrReplaceTempView("people") - # SQL can be run over DataFrames that have been registered as a table. results = spark.sql("SELECT name FROM people") diff --git a/examples/src/main/r/ml/bisectingKmeans.R b/examples/src/main/r/ml/bisectingKmeans.R index 5fb5bfb0fa5a..b3eaa6dd86d7 100644 --- a/examples/src/main/r/ml/bisectingKmeans.R +++ b/examples/src/main/r/ml/bisectingKmeans.R @@ -25,20 +25,21 @@ library(SparkR) sparkR.session(appName = "SparkR-ML-bisectingKmeans-example") # $example on$ -irisDF <- createDataFrame(iris) +t <- as.data.frame(Titanic) +training <- createDataFrame(t) # Fit bisecting k-means model with four centers -model <- spark.bisectingKmeans(df, Sepal_Length ~ Sepal_Width, k = 4) +model <- spark.bisectingKmeans(training, Class ~ Survived, k = 4) # get fitted result from a bisecting k-means model fitted.model <- fitted(model, "centers") # Model summary -summary(fitted.model) +head(summary(fitted.model)) # fitted values on training data -fitted <- predict(model, df) -head(select(fitted, "Sepal_Length", "prediction")) +fitted <- predict(model, training) +head(select(fitted, "Class", "prediction")) # $example off$ sparkR.session.stop() diff --git a/examples/src/main/r/ml/glm.R b/examples/src/main/r/ml/glm.R index e41af97751d3..ee13910382c5 100644 --- a/examples/src/main/r/ml/glm.R +++ b/examples/src/main/r/ml/glm.R @@ -25,11 +25,12 @@ library(SparkR) sparkR.session(appName = "SparkR-ML-glm-example") # $example on$ -irisDF <- suppressWarnings(createDataFrame(iris)) +training <- read.df("data/mllib/sample_multiclass_classification_data.txt", source = "libsvm") # Fit a generalized linear model of family "gaussian" with spark.glm -gaussianDF <- irisDF -gaussianTestDF <- irisDF -gaussianGLM <- spark.glm(gaussianDF, Sepal_Length ~ Sepal_Width + Species, family = "gaussian") +df_list <- randomSplit(training, c(7,3), 2) +gaussianDF <- df_list[[1]] +gaussianTestDF <- df_list[[2]] +gaussianGLM <- spark.glm(gaussianDF, label ~ features, family = "gaussian") # Model summary summary(gaussianGLM) @@ -39,14 +40,15 @@ gaussianPredictions <- predict(gaussianGLM, gaussianTestDF) head(gaussianPredictions) # Fit a generalized linear model with glm (R-compliant) -gaussianGLM2 <- glm(Sepal_Length ~ Sepal_Width + Species, gaussianDF, family = "gaussian") +gaussianGLM2 <- glm(label ~ features, gaussianDF, family = "gaussian") summary(gaussianGLM2) # Fit a generalized 
linear model of family "binomial" with spark.glm -# Note: Filter out "setosa" from label column (two labels left) to match "binomial" family. -binomialDF <- filter(irisDF, irisDF$Species != "setosa") -binomialTestDF <- binomialDF -binomialGLM <- spark.glm(binomialDF, Species ~ Sepal_Length + Sepal_Width, family = "binomial") +training2 <- read.df("data/mllib/sample_binary_classification_data.txt", source = "libsvm") +df_list2 <- randomSplit(training2, c(7,3), 2) +binomialDF <- df_list2[[1]] +binomialTestDF <- df_list2[[2]] +binomialGLM <- spark.glm(binomialDF, label ~ features, family = "binomial") # Model summary summary(binomialGLM) diff --git a/examples/src/main/r/ml/kmeans.R b/examples/src/main/r/ml/kmeans.R index 288e2f9724e0..824df20644fa 100644 --- a/examples/src/main/r/ml/kmeans.R +++ b/examples/src/main/r/ml/kmeans.R @@ -26,10 +26,12 @@ sparkR.session(appName = "SparkR-ML-kmeans-example") # $example on$ # Fit a k-means model with spark.kmeans -irisDF <- suppressWarnings(createDataFrame(iris)) -kmeansDF <- irisDF -kmeansTestDF <- irisDF -kmeansModel <- spark.kmeans(kmeansDF, ~ Sepal_Length + Sepal_Width + Petal_Length + Petal_Width, +t <- as.data.frame(Titanic) +training <- createDataFrame(t) +df_list <- randomSplit(training, c(7,3), 2) +kmeansDF <- df_list[[1]] +kmeansTestDF <- df_list[[2]] +kmeansModel <- spark.kmeans(kmeansDF, ~ Class + Sex + Age + Freq, k = 3) # Model summary diff --git a/examples/src/main/r/ml/ml.R b/examples/src/main/r/ml/ml.R index b96819418bad..41b7867f64e3 100644 --- a/examples/src/main/r/ml/ml.R +++ b/examples/src/main/r/ml/ml.R @@ -26,11 +26,12 @@ sparkR.session(appName = "SparkR-ML-example") ############################ model read/write ############################################## # $example on:read_write$ -irisDF <- suppressWarnings(createDataFrame(iris)) +training <- read.df("data/mllib/sample_multiclass_classification_data.txt", source = "libsvm") # Fit a generalized linear model of family "gaussian" with spark.glm -gaussianDF <- irisDF -gaussianTestDF <- irisDF -gaussianGLM <- spark.glm(gaussianDF, Sepal_Length ~ Sepal_Width + Species, family = "gaussian") +df_list <- randomSplit(training, c(7,3), 2) +gaussianDF <- df_list[[1]] +gaussianTestDF <- df_list[[2]] +gaussianGLM <- spark.glm(gaussianDF, label ~ features, family = "gaussian") # Save and then load a fitted MLlib model modelPath <- tempfile(pattern = "ml", fileext = ".tmp") diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/ALSExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/ALSExample.scala index bb5d16360849..868f49b16f21 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/ALSExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/ALSExample.scala @@ -65,6 +65,8 @@ object ALSExample { val model = als.fit(training) // Evaluate the model by computing the RMSE on the test data + // Note we set cold start strategy to 'drop' to ensure we don't get NaN evaluation metrics + model.setColdStartStrategy("drop") val predictions = model.transform(test) val evaluator = new RegressionEvaluator() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/BucketedRandomProjectionLSHExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/BucketedRandomProjectionLSHExample.scala index 654535c264a3..16da4fa887aa 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/BucketedRandomProjectionLSHExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/BucketedRandomProjectionLSHExample.scala @@ 
-28,7 +28,7 @@ import org.apache.spark.sql.SparkSession /** * An example demonstrating BucketedRandomProjectionLSH. * Run with: - * bin/run-example org.apache.spark.examples.ml.BucketedRandomProjectionLSHExample + * bin/run-example ml.BucketedRandomProjectionLSHExample */ object BucketedRandomProjectionLSHExample { def main(args: Array[String]): Unit = { diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/LinearSVCExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/LinearSVCExample.scala new file mode 100644 index 000000000000..5f43e65712b5 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/LinearSVCExample.scala @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.classification.LinearSVC +// $example off$ +import org.apache.spark.sql.SparkSession + +object LinearSVCExample { + + def main(args: Array[String]): Unit = { + val spark = SparkSession + .builder + .appName("LinearSVCExample") + .getOrCreate() + + // $example on$ + // Load training data + val training = spark.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + + val lsvc = new LinearSVC() + .setMaxIter(10) + .setRegParam(0.1) + + // Fit the model + val lsvcModel = lsvc.fit(training) + + // Print the coefficients and intercept for linear svc + println(s"Coefficients: ${lsvcModel.coefficients} Intercept: ${lsvcModel.intercept}") + // $example off$ + + spark.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/MinHashLSHExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/MinHashLSHExample.scala index 6c1e22268ad2..b94ab9b8bedc 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/MinHashLSHExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/MinHashLSHExample.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.SparkSession /** * An example demonstrating MinHashLSH. 
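The new LinearSVC examples above (Java, Python, and Scala) stop after printing the fitted coefficients and intercept. As a hedged sketch that simply continues the Scala LinearSVCExample (reusing the `lsvcModel` and `training` values it defines), scoring data with the fitted model is just another transform; `rawPrediction` and `prediction` are the default spark.ml output column names:

```scala
// Sketch: continues LinearSVCExample above (lsvcModel and training are defined there).
// LinearSVCModel is a Transformer, so transform() appends rawPrediction and prediction columns.
val scored = lsvcModel.transform(training)
scored.select("label", "rawPrediction", "prediction").show(5)
```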
* Run with: - * bin/run-example org.apache.spark.examples.ml.MinHashLSHExample + * bin/run-example ml.MinHashLSHExample */ object MinHashLSHExample { def main(args: Array[String]): Unit = { diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala b/examples/src/main/scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala index 66f7cb1b53f4..381e69cda841 100644 --- a/examples/src/main/scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala @@ -111,6 +111,10 @@ object SQLDataSourceExample { private def runJsonDatasetExample(spark: SparkSession): Unit = { // $example on:json_dataset$ + // Primitive types (Int, String, etc) and Product types (case classes) encoders are + // supported by importing this when creating a Dataset. + import spark.implicits._ + // A JSON dataset is pointed to by path. // The path can be either a single text file or a directory storing text files val path = "examples/src/main/resources/people.json" @@ -135,10 +139,10 @@ object SQLDataSourceExample { // +------+ // Alternatively, a DataFrame can be created for a JSON dataset represented by - // an RDD[String] storing one JSON object per string - val otherPeopleRDD = spark.sparkContext.makeRDD( + // an Dataset[String] storing one JSON object per string + val otherPeopleDataset = spark.createDataset( """{"name":"Yin","address":{"city":"Columbus","state":"Ohio"}}""" :: Nil) - val otherPeople = spark.read.json(otherPeopleRDD) + val otherPeople = spark.read.json(otherPeopleDataset) otherPeople.show() // +---------------+----+ // | address|name| diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaRelationSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaRelationSuite.scala index 673d60ff6f87..68bc3e3e2e9a 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaRelationSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaRelationSuite.scala @@ -24,6 +24,7 @@ import org.scalatest.BeforeAndAfter import org.apache.spark.sql.QueryTest import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.util.Utils class KafkaRelationSuite extends QueryTest with BeforeAndAfter with SharedSQLContext { @@ -147,6 +148,9 @@ class KafkaRelationSuite extends QueryTest with BeforeAndAfter with SharedSQLCon } test("test late binding start offsets") { + // Kafka fails to remove the logs on Windows. See KAFKA-1194. 
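Regarding the SQLDataSourceExample change above (building a DataFrame from a Dataset[String] of JSON instead of an RDD[String]): the same `spark.read.json` overload the example now relies on also accepts a multi-row Dataset[String], so the pattern scales beyond the single-record snippet. A small hedged sketch using `spark.implicits`:

```scala
// Sketch: read several JSON documents held in a Dataset[String].
import spark.implicits._

val jsonStrings = Seq(
  """{"name":"Yin","address":{"city":"Columbus","state":"Ohio"}}""",
  """{"name":"Michael","address":{"city":null,"state":"California"}}""").toDS()
val people = spark.read.json(jsonStrings)
people.printSchema()
people.show()
```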
+ assume(!Utils.isWindows) + var kafkaUtils: KafkaTestUtils = null try { /** diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala index 4f82b133cb4c..534fb77c9ce1 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala @@ -39,6 +39,7 @@ import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.functions.{count, window} import org.apache.spark.sql.streaming.{ProcessingTime, StreamTest} import org.apache.spark.sql.test.{SharedSQLContext, TestSparkSession} +import org.apache.spark.util.Utils abstract class KafkaSourceTest extends StreamTest with SharedSQLContext { @@ -161,11 +162,12 @@ class KafkaSourceSuite extends KafkaSourceTest { // Make sure Spark 2.1.0 will throw an exception when reading the new log intercept[java.lang.IllegalArgumentException] { // Simulate how Spark 2.1.0 reads the log - val in = new FileInputStream(metadataPath.getAbsolutePath + "/0") - val length = in.read() - val bytes = new Array[Byte](length) - in.read(bytes) - KafkaSourceOffset(SerializedOffset(new String(bytes, UTF_8))) + Utils.tryWithResource(new FileInputStream(metadataPath.getAbsolutePath + "/0")) { in => + val length = in.read() + val bytes = new Array[Byte](length) + in.read(bytes) + KafkaSourceOffset(SerializedOffset(new String(bytes, UTF_8))) + } } } } @@ -181,13 +183,13 @@ class KafkaSourceSuite extends KafkaSourceTest { "subscribe" -> topic ) - val from = Paths.get( - getClass.getResource("/kafka-source-initial-offset-version-2.1.0.bin").getPath) + val from = new File( + getClass.getResource("/kafka-source-initial-offset-version-2.1.0.bin").toURI).toPath val to = Paths.get(s"${metadataPath.getAbsolutePath}/0") Files.copy(from, to) - val source = provider.createSource(spark.sqlContext, metadataPath.getAbsolutePath, None, - "", parameters) + val source = provider.createSource( + spark.sqlContext, metadataPath.toURI.toString, None, "", parameters) val deserializedOffset = source.getOffset.get val referenceOffset = KafkaSourceOffset((topic, 0, 0L), (topic, 1, 0L), (topic, 2, 0L)) assert(referenceOffset == deserializedOffset) diff --git a/external/kinesis-asl/pom.xml b/external/kinesis-asl/pom.xml index b2bac7c938ab..daa79e79163b 100644 --- a/external/kinesis-asl/pom.xml +++ b/external/kinesis-asl/pom.xml @@ -58,6 +58,11 @@ amazon-kinesis-client ${aws.kinesis.client.version} + + com.amazonaws + aws-java-sdk-sts + ${aws.java.sdk.version} + com.amazonaws amazon-kinesis-producer diff --git a/external/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java b/external/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java index d40bd3ff560d..626bde48e1a8 100644 --- a/external/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java +++ b/external/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java @@ -23,7 +23,6 @@ import java.util.List; import java.util.regex.Pattern; -import com.amazonaws.regions.RegionUtils; import org.apache.spark.SparkConf; import org.apache.spark.api.java.function.FlatMapFunction; import org.apache.spark.api.java.function.Function2; @@ -127,7 +126,7 @@ public static void main(String[] args) throws Exception { // Get the region name from the 
endpoint URL to save Kinesis Client Library metadata in // DynamoDB of the same region as the Kinesis stream - String regionName = RegionUtils.getRegionByEndpoint(endpointUrl).getName(); + String regionName = KinesisExampleUtils.getRegionNameByEndpoint(endpointUrl); // Setup the Spark config and StreamingContext SparkConf sparkConfig = new SparkConf().setAppName("JavaKinesisWordCountASL"); diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisExampleUtils.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisExampleUtils.scala new file mode 100644 index 000000000000..2eebd6130d4d --- /dev/null +++ b/external/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisExampleUtils.scala @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.streaming + +import scala.collection.JavaConverters._ + +import com.amazonaws.regions.RegionUtils +import com.amazonaws.services.kinesis.AmazonKinesis + +private[streaming] object KinesisExampleUtils { + def getRegionNameByEndpoint(endpoint: String): String = { + val uri = new java.net.URI(endpoint) + RegionUtils.getRegionsForService(AmazonKinesis.ENDPOINT_PREFIX) + .asScala + .find(_.getAvailableEndpoints.asScala.toSeq.contains(uri.getHost)) + .map(_.getName) + .getOrElse( + throw new IllegalArgumentException(s"Could not resolve region for endpoint: $endpoint")) + } +} diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala index a70c13d7d68a..f14117b708a0 100644 --- a/external/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala +++ b/external/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala @@ -127,7 +127,7 @@ object KinesisWordCountASL extends Logging { // Get the region name from the endpoint URL to save Kinesis Client Library metadata in // DynamoDB of the same region as the Kinesis stream - val regionName = RegionUtils.getRegionByEndpoint(endpointUrl).getName() + val regionName = KinesisExampleUtils.getRegionNameByEndpoint(endpointUrl) // Setup the SparkConfig and StreamingContext val sparkConfig = new SparkConf().setAppName("KinesisWordCountASL") diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala index 45dc3c388cb8..23c4d99e50f5 100644 --- a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala +++ 
b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala @@ -79,7 +79,7 @@ class KinesisBackedBlockRDD[T: ClassTag]( @transient private val isBlockIdValid: Array[Boolean] = Array.empty, val retryTimeoutMs: Int = 10000, val messageHandler: Record => T = KinesisUtils.defaultMessageHandler _, - val awsCredentialsOption: Option[SerializableAWSCredentials] = None + val kinesisCredsProvider: SerializableCredentialsProvider = DefaultCredentialsProvider ) extends BlockRDD[T](sc, _blockIds) { require(_blockIds.length == arrayOfseqNumberRanges.length, @@ -105,9 +105,7 @@ class KinesisBackedBlockRDD[T: ClassTag]( } def getBlockFromKinesis(): Iterator[T] = { - val credentials = awsCredentialsOption.getOrElse { - new DefaultAWSCredentialsProviderChain().getCredentials() - } + val credentials = kinesisCredsProvider.provider.getCredentials partition.seqNumberRanges.ranges.iterator.flatMap { range => new KinesisSequenceRangeIterator(credentials, endpointUrl, regionName, range, retryTimeoutMs).map(messageHandler) @@ -143,7 +141,7 @@ class KinesisSequenceRangeIterator( private var lastSeqNumber: String = null private var internalIterator: Iterator[Record] = null - client.setEndpoint(endpointUrl, "kinesis", regionId) + client.setEndpoint(endpointUrl) override protected def getNext(): Record = { var nextRecord: Record = null diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointer.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointer.scala index c445c15a5f64..5fb83b26f838 100644 --- a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointer.scala +++ b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointer.scala @@ -21,7 +21,7 @@ import java.util.concurrent._ import scala.util.control.NonFatal import com.amazonaws.services.kinesis.clientlibrary.interfaces.IRecordProcessorCheckpointer -import com.amazonaws.services.kinesis.clientlibrary.types.ShutdownReason +import com.amazonaws.services.kinesis.clientlibrary.lib.worker.ShutdownReason import org.apache.spark.internal.Logging import org.apache.spark.streaming.Duration diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala index 5223c81a8e0e..fbc6b99443ed 100644 --- a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala +++ b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala @@ -39,7 +39,7 @@ private[kinesis] class KinesisInputDStream[T: ClassTag]( checkpointInterval: Duration, storageLevel: StorageLevel, messageHandler: Record => T, - awsCredentialsOption: Option[SerializableAWSCredentials] + kinesisCredsProvider: SerializableCredentialsProvider ) extends ReceiverInputDStream[T](_ssc) { private[streaming] @@ -61,7 +61,7 @@ private[kinesis] class KinesisInputDStream[T: ClassTag]( isBlockIdValid = isBlockIdValid, retryTimeoutMs = ssc.graph.batchDuration.milliseconds.toInt, messageHandler = messageHandler, - awsCredentialsOption = awsCredentialsOption) + kinesisCredsProvider = kinesisCredsProvider) } else { logWarning("Kinesis sequence number information was not present with some block metadata," + " it may not be possible to recover from failures") @@ -71,6 +71,7 @@ private[kinesis] class KinesisInputDStream[T: 
ClassTag]( override def getReceiver(): Receiver[T] = { new KinesisReceiver(streamName, endpointUrl, regionName, initialPositionInStream, - checkpointAppName, checkpointInterval, storageLevel, messageHandler, awsCredentialsOption) + checkpointAppName, checkpointInterval, storageLevel, messageHandler, + kinesisCredsProvider) } } diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala index 393e56a39320..13fc54e531dd 100644 --- a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala +++ b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala @@ -23,7 +23,6 @@ import scala.collection.JavaConverters._ import scala.collection.mutable import scala.util.control.NonFatal -import com.amazonaws.auth.{AWSCredentials, AWSCredentialsProvider, DefaultAWSCredentialsProviderChain} import com.amazonaws.services.kinesis.clientlibrary.interfaces.{IRecordProcessor, IRecordProcessorCheckpointer, IRecordProcessorFactory} import com.amazonaws.services.kinesis.clientlibrary.lib.worker.{InitialPositionInStream, KinesisClientLibConfiguration, Worker} import com.amazonaws.services.kinesis.model.Record @@ -34,13 +33,6 @@ import org.apache.spark.streaming.Duration import org.apache.spark.streaming.receiver.{BlockGenerator, BlockGeneratorListener, Receiver} import org.apache.spark.util.Utils -private[kinesis] -case class SerializableAWSCredentials(accessKeyId: String, secretKey: String) - extends AWSCredentials { - override def getAWSAccessKeyId: String = accessKeyId - override def getAWSSecretKey: String = secretKey -} - /** * Custom AWS Kinesis-specific implementation of Spark Streaming's Receiver. * This implementation relies on the Kinesis Client Library (KCL) Worker as described here: @@ -78,8 +70,9 @@ case class SerializableAWSCredentials(accessKeyId: String, secretKey: String) * See the Kinesis Spark Streaming documentation for more * details on the different types of checkpoints. * @param storageLevel Storage level to use for storing the received objects - * @param awsCredentialsOption Optional AWS credentials, used when user directly specifies - * the credentials + * @param kinesisCredsProvider SerializableCredentialsProvider instance that will be used to + * generate the AWSCredentialsProvider instance used for KCL + * authorization. 
*/ private[kinesis] class KinesisReceiver[T]( val streamName: String, @@ -90,7 +83,7 @@ private[kinesis] class KinesisReceiver[T]( checkpointInterval: Duration, storageLevel: StorageLevel, messageHandler: Record => T, - awsCredentialsOption: Option[SerializableAWSCredentials]) + kinesisCredsProvider: SerializableCredentialsProvider) extends Receiver[T](storageLevel) with Logging { receiver => /* @@ -147,14 +140,15 @@ private[kinesis] class KinesisReceiver[T]( workerId = Utils.localHostName() + ":" + UUID.randomUUID() kinesisCheckpointer = new KinesisCheckpointer(receiver, checkpointInterval, workerId) - // KCL config instance - val awsCredProvider = resolveAWSCredentialsProvider() - val kinesisClientLibConfiguration = - new KinesisClientLibConfiguration(checkpointAppName, streamName, awsCredProvider, workerId) - .withKinesisEndpoint(endpointUrl) - .withInitialPositionInStream(initialPositionInStream) - .withTaskBackoffTimeMillis(500) - .withRegionName(regionName) + val kinesisClientLibConfiguration = new KinesisClientLibConfiguration( + checkpointAppName, + streamName, + kinesisCredsProvider.provider, + workerId) + .withKinesisEndpoint(endpointUrl) + .withInitialPositionInStream(initialPositionInStream) + .withTaskBackoffTimeMillis(500) + .withRegionName(regionName) /* * RecordProcessorFactory creates impls of IRecordProcessor. @@ -305,25 +299,6 @@ private[kinesis] class KinesisReceiver[T]( } } - /** - * If AWS credential is provided, return a AWSCredentialProvider returning that credential. - * Otherwise, return the DefaultAWSCredentialsProviderChain. - */ - private def resolveAWSCredentialsProvider(): AWSCredentialsProvider = { - awsCredentialsOption match { - case Some(awsCredentials) => - logInfo("Using provided AWS credentials") - new AWSCredentialsProvider { - override def getCredentials: AWSCredentials = awsCredentials - override def refresh(): Unit = { } - } - case None => - logInfo("Using DefaultAWSCredentialsProviderChain") - new DefaultAWSCredentialsProviderChain() - } - } - - /** * Class to handle blocks generated by this receiver's block generator. Specifically, in * the context of the Kinesis Receiver, this handler does the following. 
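The KinesisReceiver changes above replace the old `Option[SerializableAWSCredentials]` plus `resolveAWSCredentialsProvider()` pair with a single `kinesisCredsProvider` whose `.provider` call yields the `AWSCredentialsProvider` handed to the KCL. Below is a rough sketch of how the pieces defined later in this patch (`DefaultCredentialsProvider`, `BasicCredentialsProvider`, `STSCredentialsProvider` in SerializableCredentialsProvider.scala) compose; it is written as if inside the `org.apache.spark.streaming.kinesis` package because the classes are `private[kinesis]`, and all keys and ARNs are placeholders:

```scala
package org.apache.spark.streaming.kinesis

import com.amazonaws.services.kinesis.clientlibrary.lib.worker.KinesisClientLibConfiguration

object CredentialsPlumbingSketch {
  // Long-lived keypair credentials; BasicCredentialsProvider falls back to the
  // DefaultAWSCredentialsProviderChain if the keypair is unusable.
  val basic = BasicCredentialsProvider(
    awsAccessKeyId = "fakeAccessKey",
    awsSecretKey = "fakeSecretKey")

  // STS session credentials that assume an IAM role using the long-lived credentials above.
  val sts = STSCredentialsProvider(
    stsRoleArn = "arn:aws:iam::123456789012:role/fake-role",
    stsSessionName = "spark-kinesis",
    stsExternalId = Some("fake-external-id"),
    longLivedCredsProvider = basic)

  // The receiver only ever calls .provider, which is what the KCL configuration consumes.
  val kclConfig = new KinesisClientLibConfiguration("appName", "streamName", sts.provider, "workerId")
    .withRegionName("us-west-2")
}
```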
diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala index 73ccc4ad23f6..8c6a399dd763 100644 --- a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala +++ b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala @@ -23,7 +23,7 @@ import scala.util.control.NonFatal import com.amazonaws.services.kinesis.clientlibrary.exceptions.{InvalidStateException, KinesisClientLibDependencyException, ShutdownException, ThrottlingException} import com.amazonaws.services.kinesis.clientlibrary.interfaces.{IRecordProcessor, IRecordProcessorCheckpointer} -import com.amazonaws.services.kinesis.clientlibrary.types.ShutdownReason +import com.amazonaws.services.kinesis.clientlibrary.lib.worker.ShutdownReason import com.amazonaws.services.kinesis.model.Record import org.apache.spark.internal.Logging diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala index f183ef00b33c..73ac7a3cd235 100644 --- a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala +++ b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala @@ -30,7 +30,7 @@ import com.amazonaws.auth.{AWSCredentials, DefaultAWSCredentialsProviderChain} import com.amazonaws.regions.RegionUtils import com.amazonaws.services.dynamodbv2.AmazonDynamoDBClient import com.amazonaws.services.dynamodbv2.document.DynamoDB -import com.amazonaws.services.kinesis.AmazonKinesisClient +import com.amazonaws.services.kinesis.{AmazonKinesis, AmazonKinesisClient} import com.amazonaws.services.kinesis.model._ import org.apache.spark.internal.Logging @@ -43,7 +43,7 @@ import org.apache.spark.internal.Logging private[kinesis] class KinesisTestUtils(streamShardCount: Int = 2) extends Logging { val endpointUrl = KinesisTestUtils.endpointUrl - val regionName = RegionUtils.getRegionByEndpoint(endpointUrl).getName() + val regionName = KinesisTestUtils.getRegionNameByEndpoint(endpointUrl) private val createStreamTimeoutSeconds = 300 private val describeStreamPollTimeSeconds = 1 @@ -205,6 +205,16 @@ private[kinesis] object KinesisTestUtils { val endVarNameForEndpoint = "KINESIS_TEST_ENDPOINT_URL" val defaultEndpointUrl = "https://kinesis.us-west-2.amazonaws.com" + def getRegionNameByEndpoint(endpoint: String): String = { + val uri = new java.net.URI(endpoint) + RegionUtils.getRegionsForService(AmazonKinesis.ENDPOINT_PREFIX) + .asScala + .find(_.getAvailableEndpoints.asScala.toSeq.contains(uri.getHost)) + .map(_.getName) + .getOrElse( + throw new IllegalArgumentException(s"Could not resolve region for endpoint: $endpoint")) + } + lazy val shouldRunTests = { val isEnvSet = sys.env.get(envVarNameForEnablingTests) == Some("1") if (isEnvSet) { diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala index b2daffa34ccb..2d777982e760 100644 --- a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala +++ b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala @@ -73,7 +73,7 @@ object KinesisUtils { 
ssc.withNamedScope("kinesis stream") { new KinesisInputDStream[T](ssc, streamName, endpointUrl, validateRegion(regionName), initialPositionInStream, kinesisAppName, checkpointInterval, storageLevel, - cleanedHandler, None) + cleanedHandler, DefaultCredentialsProvider) } } @@ -123,9 +123,80 @@ object KinesisUtils { // scalastyle:on val cleanedHandler = ssc.sc.clean(messageHandler) ssc.withNamedScope("kinesis stream") { + val kinesisCredsProvider = BasicCredentialsProvider( + awsAccessKeyId = awsAccessKeyId, + awsSecretKey = awsSecretKey) new KinesisInputDStream[T](ssc, streamName, endpointUrl, validateRegion(regionName), initialPositionInStream, kinesisAppName, checkpointInterval, storageLevel, - cleanedHandler, Some(SerializableAWSCredentials(awsAccessKeyId, awsSecretKey))) + cleanedHandler, kinesisCredsProvider) + } + } + + /** + * Create an input stream that pulls messages from a Kinesis stream. + * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis. + * + * @param ssc StreamingContext object + * @param kinesisAppName Kinesis application name used by the Kinesis Client Library + * (KCL) to update DynamoDB + * @param streamName Kinesis stream name + * @param endpointUrl Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com) + * @param regionName Name of region used by the Kinesis Client Library (KCL) to update + * DynamoDB (lease coordination and checkpointing) and CloudWatch (metrics) + * @param initialPositionInStream In the absence of Kinesis checkpoint info, this is the + * worker's initial starting position in the stream. + * The values are either the beginning of the stream + * per Kinesis' limit of 24 hours + * (InitialPositionInStream.TRIM_HORIZON) or + * the tip of the stream (InitialPositionInStream.LATEST). + * @param checkpointInterval Checkpoint interval for Kinesis checkpointing. + * See the Kinesis Spark Streaming documentation for more + * details on the different types of checkpoints. + * @param storageLevel Storage level to use for storing the received objects. + * StorageLevel.MEMORY_AND_DISK_2 is recommended. + * @param messageHandler A custom message handler that can generate a generic output from a + * Kinesis `Record`, which contains both message data, and metadata. + * @param awsAccessKeyId AWS AccessKeyId (if null, will use DefaultAWSCredentialsProviderChain) + * @param awsSecretKey AWS SecretKey (if null, will use DefaultAWSCredentialsProviderChain) + * @param stsAssumeRoleArn ARN of IAM role to assume when using STS sessions to read from + * Kinesis stream. + * @param stsSessionName Name to uniquely identify STS sessions if multiple princples assume + * the same role. + * @param stsExternalId External ID that can be used to validate against the assumed IAM role's + * trust policy. + * + * @note The given AWS credentials will get saved in DStream checkpoints if checkpointing + * is enabled. Make sure that your checkpoint directory is secure. 
+ */ + // scalastyle:off + def createStream[T: ClassTag]( + ssc: StreamingContext, + kinesisAppName: String, + streamName: String, + endpointUrl: String, + regionName: String, + initialPositionInStream: InitialPositionInStream, + checkpointInterval: Duration, + storageLevel: StorageLevel, + messageHandler: Record => T, + awsAccessKeyId: String, + awsSecretKey: String, + stsAssumeRoleArn: String, + stsSessionName: String, + stsExternalId: String): ReceiverInputDStream[T] = { + // scalastyle:on + val cleanedHandler = ssc.sc.clean(messageHandler) + ssc.withNamedScope("kinesis stream") { + val kinesisCredsProvider = STSCredentialsProvider( + stsRoleArn = stsAssumeRoleArn, + stsSessionName = stsSessionName, + stsExternalId = Option(stsExternalId), + longLivedCredsProvider = BasicCredentialsProvider( + awsAccessKeyId = awsAccessKeyId, + awsSecretKey = awsSecretKey)) + new KinesisInputDStream[T](ssc, streamName, endpointUrl, validateRegion(regionName), + initialPositionInStream, kinesisAppName, checkpointInterval, storageLevel, + cleanedHandler, kinesisCredsProvider) } } @@ -169,7 +240,7 @@ object KinesisUtils { ssc.withNamedScope("kinesis stream") { new KinesisInputDStream[Array[Byte]](ssc, streamName, endpointUrl, validateRegion(regionName), initialPositionInStream, kinesisAppName, checkpointInterval, storageLevel, - defaultMessageHandler, None) + defaultMessageHandler, DefaultCredentialsProvider) } } @@ -213,9 +284,12 @@ object KinesisUtils { awsAccessKeyId: String, awsSecretKey: String): ReceiverInputDStream[Array[Byte]] = { ssc.withNamedScope("kinesis stream") { + val kinesisCredsProvider = BasicCredentialsProvider( + awsAccessKeyId = awsAccessKeyId, + awsSecretKey = awsSecretKey) new KinesisInputDStream[Array[Byte]](ssc, streamName, endpointUrl, validateRegion(regionName), initialPositionInStream, kinesisAppName, checkpointInterval, storageLevel, - defaultMessageHandler, Some(SerializableAWSCredentials(awsAccessKeyId, awsSecretKey))) + defaultMessageHandler, kinesisCredsProvider) } } @@ -319,6 +393,68 @@ object KinesisUtils { awsAccessKeyId, awsSecretKey) } + /** + * Create an input stream that pulls messages from a Kinesis stream. + * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis. + * + * @param jssc Java StreamingContext object + * @param kinesisAppName Kinesis application name used by the Kinesis Client Library + * (KCL) to update DynamoDB + * @param streamName Kinesis stream name + * @param endpointUrl Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com) + * @param regionName Name of region used by the Kinesis Client Library (KCL) to update + * DynamoDB (lease coordination and checkpointing) and CloudWatch (metrics) + * @param initialPositionInStream In the absence of Kinesis checkpoint info, this is the + * worker's initial starting position in the stream. + * The values are either the beginning of the stream + * per Kinesis' limit of 24 hours + * (InitialPositionInStream.TRIM_HORIZON) or + * the tip of the stream (InitialPositionInStream.LATEST). + * @param checkpointInterval Checkpoint interval for Kinesis checkpointing. + * See the Kinesis Spark Streaming documentation for more + * details on the different types of checkpoints. + * @param storageLevel Storage level to use for storing the received objects. + * StorageLevel.MEMORY_AND_DISK_2 is recommended. + * @param messageHandler A custom message handler that can generate a generic output from a + * Kinesis `Record`, which contains both message data, and metadata. 
+ * @param recordClass Class of the records in DStream + * @param awsAccessKeyId AWS AccessKeyId (if null, will use DefaultAWSCredentialsProviderChain) + * @param awsSecretKey AWS SecretKey (if null, will use DefaultAWSCredentialsProviderChain) + * @param stsAssumeRoleArn ARN of IAM role to assume when using STS sessions to read from + * Kinesis stream. + * @param stsSessionName Name to uniquely identify STS sessions if multiple princples assume + * the same role. + * @param stsExternalId External ID that can be used to validate against the assumed IAM role's + * trust policy. + * + * @note The given AWS credentials will get saved in DStream checkpoints if checkpointing + * is enabled. Make sure that your checkpoint directory is secure. + */ + // scalastyle:off + def createStream[T]( + jssc: JavaStreamingContext, + kinesisAppName: String, + streamName: String, + endpointUrl: String, + regionName: String, + initialPositionInStream: InitialPositionInStream, + checkpointInterval: Duration, + storageLevel: StorageLevel, + messageHandler: JFunction[Record, T], + recordClass: Class[T], + awsAccessKeyId: String, + awsSecretKey: String, + stsAssumeRoleArn: String, + stsSessionName: String, + stsExternalId: String): JavaReceiverInputDStream[T] = { + // scalastyle:on + implicit val recordCmt: ClassTag[T] = ClassTag(recordClass) + val cleanedHandler = jssc.sparkContext.clean(messageHandler.call(_)) + createStream[T](jssc.ssc, kinesisAppName, streamName, endpointUrl, regionName, + initialPositionInStream, checkpointInterval, storageLevel, cleanedHandler, + awsAccessKeyId, awsSecretKey, stsAssumeRoleArn, stsSessionName, stsExternalId) + } + /** * Create an input stream that pulls messages from a Kinesis stream. * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis. 
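For reference, a hedged sketch of calling the new Scala `createStream` overload with STS parameters (the Java variant just above mirrors it, adding `recordClass` and a `JFunction` handler). The StreamingContext, stream name, keys and role ARN are placeholders, and the handler is passed as `handler(_)` to keep overload resolution unambiguous:

```scala
// Sketch only: wire up the STS-aware createStream overload with placeholder values.
import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream
import com.amazonaws.services.kinesis.model.Record

import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.{Seconds, StreamingContext}
import org.apache.spark.streaming.kinesis.KinesisUtils

object StsStreamSketch {
  // Trivial message handler: just report the payload size of each Kinesis record.
  def handler(r: Record): Int = r.getData.remaining()

  def buildStream(ssc: StreamingContext) =
    KinesisUtils.createStream[Int](
      ssc, "myKinesisApp", "myStream",
      "https://kinesis.us-west-2.amazonaws.com", "us-west-2",
      InitialPositionInStream.LATEST, Seconds(10), StorageLevel.MEMORY_AND_DISK_2,
      handler(_),
      "fakeAccessKey", "fakeSecretKey",
      "arn:aws:iam::123456789012:role/fake-role", "fakeSessionName", "fakeExternalId")
}
```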
@@ -404,10 +540,6 @@ object KinesisUtils { defaultMessageHandler(_), awsAccessKeyId, awsSecretKey) } - private def getRegionByEndpoint(endpointUrl: String): String = { - RegionUtils.getRegionByEndpoint(endpointUrl).getName() - } - private def validateRegion(regionName: String): String = { Option(RegionUtils.getRegion(regionName)).map { _.getName }.getOrElse { throw new IllegalArgumentException(s"Region name '$regionName' is not valid") @@ -439,6 +571,7 @@ private class KinesisUtilsPythonHelper { } } + // scalastyle:off def createStream( jssc: JavaStreamingContext, kinesisAppName: String, @@ -449,22 +582,43 @@ private class KinesisUtilsPythonHelper { checkpointInterval: Duration, storageLevel: StorageLevel, awsAccessKeyId: String, - awsSecretKey: String - ): JavaReceiverInputDStream[Array[Byte]] = { + awsSecretKey: String, + stsAssumeRoleArn: String, + stsSessionName: String, + stsExternalId: String): JavaReceiverInputDStream[Array[Byte]] = { + // scalastyle:on + if (!(stsAssumeRoleArn != null && stsSessionName != null && stsExternalId != null) + && !(stsAssumeRoleArn == null && stsSessionName == null && stsExternalId == null)) { + throw new IllegalArgumentException("stsAssumeRoleArn, stsSessionName, and stsExtenalId " + + "must all be defined or all be null") + } + + if (stsAssumeRoleArn != null && stsSessionName != null && stsExternalId != null) { + validateAwsCreds(awsAccessKeyId, awsSecretKey) + KinesisUtils.createStream(jssc.ssc, kinesisAppName, streamName, endpointUrl, regionName, + getInitialPositionInStream(initialPositionInStream), checkpointInterval, storageLevel, + KinesisUtils.defaultMessageHandler(_), awsAccessKeyId, awsSecretKey, + stsAssumeRoleArn, stsSessionName, stsExternalId) + } else { + validateAwsCreds(awsAccessKeyId, awsSecretKey) + if (awsAccessKeyId == null && awsSecretKey == null) { + KinesisUtils.createStream(jssc, kinesisAppName, streamName, endpointUrl, regionName, + getInitialPositionInStream(initialPositionInStream), checkpointInterval, storageLevel) + } else { + KinesisUtils.createStream(jssc, kinesisAppName, streamName, endpointUrl, regionName, + getInitialPositionInStream(initialPositionInStream), checkpointInterval, storageLevel, + awsAccessKeyId, awsSecretKey) + } + } + } + + // Throw IllegalArgumentException unless both values are null or neither are. 
+ private def validateAwsCreds(awsAccessKeyId: String, awsSecretKey: String) { if (awsAccessKeyId == null && awsSecretKey != null) { throw new IllegalArgumentException("awsSecretKey is set but awsAccessKeyId is null") } if (awsAccessKeyId != null && awsSecretKey == null) { throw new IllegalArgumentException("awsAccessKeyId is set but awsSecretKey is null") } - if (awsAccessKeyId == null && awsSecretKey == null) { - KinesisUtils.createStream(jssc, kinesisAppName, streamName, endpointUrl, regionName, - getInitialPositionInStream(initialPositionInStream), checkpointInterval, storageLevel) - } else { - KinesisUtils.createStream(jssc, kinesisAppName, streamName, endpointUrl, regionName, - getInitialPositionInStream(initialPositionInStream), checkpointInterval, storageLevel, - awsAccessKeyId, awsSecretKey) - } } - } diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/SerializableCredentialsProvider.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/SerializableCredentialsProvider.scala new file mode 100644 index 000000000000..aa6fe12edf74 --- /dev/null +++ b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/SerializableCredentialsProvider.scala @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.streaming.kinesis + +import scala.collection.JavaConverters._ + +import com.amazonaws.auth._ + +import org.apache.spark.internal.Logging + +/** + * Serializable interface providing a method executors can call to obtain an + * AWSCredentialsProvider instance for authenticating to AWS services. + */ +private[kinesis] sealed trait SerializableCredentialsProvider extends Serializable { + /** + * Return an AWSCredentialProvider instance that can be used by the Kinesis Client + * Library to authenticate to AWS services (Kinesis, CloudWatch and DynamoDB). + */ + def provider: AWSCredentialsProvider +} + +/** Returns DefaultAWSCredentialsProviderChain for authentication. */ +private[kinesis] final case object DefaultCredentialsProvider + extends SerializableCredentialsProvider { + + def provider: AWSCredentialsProvider = new DefaultAWSCredentialsProviderChain +} + +/** + * Returns AWSStaticCredentialsProvider constructed using basic AWS keypair. Falls back to using + * DefaultAWSCredentialsProviderChain if unable to construct a AWSCredentialsProviderChain + * instance with the provided arguments (e.g. if they are null). 
+ */ +private[kinesis] final case class BasicCredentialsProvider( + awsAccessKeyId: String, + awsSecretKey: String) extends SerializableCredentialsProvider with Logging { + + def provider: AWSCredentialsProvider = try { + new AWSStaticCredentialsProvider(new BasicAWSCredentials(awsAccessKeyId, awsSecretKey)) + } catch { + case e: IllegalArgumentException => + logWarning("Unable to construct AWSStaticCredentialsProvider with provided keypair; " + + "falling back to DefaultAWSCredentialsProviderChain.", e) + new DefaultAWSCredentialsProviderChain + } +} + +/** + * Returns an STSAssumeRoleSessionCredentialsProvider instance which assumes an IAM + * role in order to authenticate against resources in an external account. + */ +private[kinesis] final case class STSCredentialsProvider( + stsRoleArn: String, + stsSessionName: String, + stsExternalId: Option[String] = None, + longLivedCredsProvider: SerializableCredentialsProvider = DefaultCredentialsProvider) + extends SerializableCredentialsProvider { + + def provider: AWSCredentialsProvider = { + val builder = new STSAssumeRoleSessionCredentialsProvider.Builder(stsRoleArn, stsSessionName) + .withLongLivedCredentialsProvider(longLivedCredsProvider.provider) + stsExternalId match { + case Some(stsExternalId) => + builder.withExternalId(stsExternalId) + .build() + case None => + builder.build() + } + } +} diff --git a/external/kinesis-asl/src/test/java/org/apache/spark/streaming/kinesis/JavaKinesisStreamSuite.java b/external/kinesis-asl/src/test/java/org/apache/spark/streaming/kinesis/JavaKinesisStreamSuite.java index f078973c6c28..b37b08746792 100644 --- a/external/kinesis-asl/src/test/java/org/apache/spark/streaming/kinesis/JavaKinesisStreamSuite.java +++ b/external/kinesis-asl/src/test/java/org/apache/spark/streaming/kinesis/JavaKinesisStreamSuite.java @@ -17,7 +17,6 @@ package org.apache.spark.streaming.kinesis; -import com.amazonaws.regions.RegionUtils; import com.amazonaws.services.kinesis.model.Record; import org.junit.Test; @@ -36,7 +35,7 @@ public class JavaKinesisStreamSuite extends LocalJavaStreamingContext { @Test public void testKinesisStream() { String dummyEndpointUrl = KinesisTestUtils.defaultEndpointUrl(); - String dummyRegionName = RegionUtils.getRegionByEndpoint(dummyEndpointUrl).getName(); + String dummyRegionName = KinesisTestUtils.getRegionNameByEndpoint(dummyEndpointUrl); // Tests the API, does not actually test data receiving JavaDStream kinesisStream = KinesisUtils.createStream(ssc, "myAppName", "mySparkStream", @@ -45,6 +44,17 @@ dummyEndpointUrl, dummyRegionName, InitialPositionInStream.LATEST, new Duration( ssc.stop(); } + @Test + public void testAwsCreds() { + String dummyEndpointUrl = KinesisTestUtils.defaultEndpointUrl(); + String dummyRegionName = KinesisTestUtils.getRegionNameByEndpoint(dummyEndpointUrl); + + // Tests the API, does not actually test data receiving + JavaDStream kinesisStream = KinesisUtils.createStream(ssc, "myAppName", "mySparkStream", + dummyEndpointUrl, dummyRegionName, InitialPositionInStream.LATEST, new Duration(2000), + StorageLevel.MEMORY_AND_DISK_2(), "fakeAccessKey", "fakeSecretKey"); + ssc.stop(); + } private static Function handler = new Function() { @Override @@ -62,4 +72,27 @@ public void testCustomHandler() { ssc.stop(); } + + @Test + public void testCustomHandlerAwsCreds() { + // Tests the API, does not actually test data receiving + JavaDStream kinesisStream = KinesisUtils.createStream(ssc, "testApp", "mySparkStream", + "https://kinesis.us-west-2.amazonaws.com", "us-west-2", 
InitialPositionInStream.LATEST, + new Duration(2000), StorageLevel.MEMORY_AND_DISK_2(), handler, String.class, + "fakeAccessKey", "fakeSecretKey"); + + ssc.stop(); + } + + @Test + public void testCustomHandlerAwsStsCreds() { + // Tests the API, does not actually test data receiving + JavaDStream kinesisStream = KinesisUtils.createStream(ssc, "testApp", "mySparkStream", + "https://kinesis.us-west-2.amazonaws.com", "us-west-2", InitialPositionInStream.LATEST, + new Duration(2000), StorageLevel.MEMORY_AND_DISK_2(), handler, String.class, + "fakeAccessKey", "fakeSecretKey", "fakeSTSRoleArn", "fakeSTSSessionName", + "fakeSTSExternalId"); + + ssc.stop(); + } } diff --git a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala index 800502a77d12..deb411d73e58 100644 --- a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala +++ b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala @@ -22,7 +22,7 @@ import java.util.Arrays import com.amazonaws.services.kinesis.clientlibrary.exceptions._ import com.amazonaws.services.kinesis.clientlibrary.interfaces.IRecordProcessorCheckpointer -import com.amazonaws.services.kinesis.clientlibrary.types.ShutdownReason +import com.amazonaws.services.kinesis.clientlibrary.lib.worker.ShutdownReason import com.amazonaws.services.kinesis.model.Record import org.mockito.Matchers._ import org.mockito.Matchers.{eq => meq} @@ -62,9 +62,26 @@ class KinesisReceiverSuite extends TestSuiteBase with Matchers with BeforeAndAft checkpointerMock = mock[IRecordProcessorCheckpointer] } - test("check serializability of SerializableAWSCredentials") { - Utils.deserialize[SerializableAWSCredentials]( - Utils.serialize(new SerializableAWSCredentials("x", "y"))) + test("check serializability of credential provider classes") { + Utils.deserialize[BasicCredentialsProvider]( + Utils.serialize(BasicCredentialsProvider( + awsAccessKeyId = "x", + awsSecretKey = "y"))) + + Utils.deserialize[STSCredentialsProvider]( + Utils.serialize(STSCredentialsProvider( + stsRoleArn = "fakeArn", + stsSessionName = "fakeSessionName", + stsExternalId = Some("fakeExternalId")))) + + Utils.deserialize[STSCredentialsProvider]( + Utils.serialize(STSCredentialsProvider( + stsRoleArn = "fakeArn", + stsSessionName = "fakeSessionName", + stsExternalId = Some("fakeExternalId"), + longLivedCredsProvider = BasicCredentialsProvider( + awsAccessKeyId = "x", + awsSecretKey = "y")))) } test("process records including store and set checkpointer") { diff --git a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala index 404b673c0117..387a96f26b30 100644 --- a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala +++ b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala @@ -49,7 +49,7 @@ abstract class KinesisStreamTests(aggregateTestData: Boolean) extends KinesisFun // Dummy parameters for API testing private val dummyEndpointUrl = defaultEndpointUrl - private val dummyRegionName = RegionUtils.getRegionByEndpoint(dummyEndpointUrl).getName() + private val dummyRegionName = KinesisTestUtils.getRegionNameByEndpoint(dummyEndpointUrl) private val dummyAWSAccessKey = 
"dummyAccessKey" private val dummyAWSSecretKey = "dummySecretKey" @@ -138,8 +138,9 @@ abstract class KinesisStreamTests(aggregateTestData: Boolean) extends KinesisFun assert(kinesisRDD.regionName === dummyRegionName) assert(kinesisRDD.endpointUrl === dummyEndpointUrl) assert(kinesisRDD.retryTimeoutMs === batchDuration.milliseconds) - assert(kinesisRDD.awsCredentialsOption === - Some(SerializableAWSCredentials(dummyAWSAccessKey, dummyAWSSecretKey))) + assert(kinesisRDD.kinesisCredsProvider === BasicCredentialsProvider( + awsAccessKeyId = dummyAWSAccessKey, + awsSecretKey = dummyAWSSecretKey)) assert(nonEmptyRDD.partitions.size === blockInfos.size) nonEmptyRDD.partitions.foreach { _ shouldBe a [KinesisBackedBlockRDDPartition] } val partitions = nonEmptyRDD.partitions.map { @@ -201,7 +202,7 @@ abstract class KinesisStreamTests(aggregateTestData: Boolean) extends KinesisFun def addFive(r: Record): Int = JavaUtils.bytesToString(r.getData).toInt + 5 val stream = KinesisUtils.createStream(ssc, appName, testUtils.streamName, testUtils.endpointUrl, testUtils.regionName, InitialPositionInStream.LATEST, - Seconds(10), StorageLevel.MEMORY_ONLY, addFive, + Seconds(10), StorageLevel.MEMORY_ONLY, addFive(_), awsCredentials.getAWSAccessKeyId, awsCredentials.getAWSSecretKey) stream shouldBe a [ReceiverInputDStream[_]] diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala index bf6e76d7ac44..f76b14eeeb54 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala @@ -440,19 +440,14 @@ private class LinearSVCAggregator( private val numFeatures: Int = bcFeaturesStd.value.length private val numFeaturesPlusIntercept: Int = if (fitIntercept) numFeatures + 1 else numFeatures - private val coefficients: Vector = bcCoefficients.value private var weightSum: Double = 0.0 private var lossSum: Double = 0.0 - require(numFeaturesPlusIntercept == coefficients.size, s"Dimension mismatch. Coefficients " + - s"length ${coefficients.size}, FeaturesStd length ${numFeatures}, fitIntercept: $fitIntercept") - - private val coefficientsArray = coefficients match { - case dv: DenseVector => dv.values - case _ => - throw new IllegalArgumentException( - s"coefficients only supports dense vector but got type ${coefficients.getClass}.") + @transient private lazy val coefficientsArray = bcCoefficients.value match { + case DenseVector(values) => values + case _ => throw new IllegalArgumentException(s"coefficients only supports dense vector" + + s" but got type ${bcCoefficients.value.getClass}.") } - private val gradientSumArray = Array.fill[Double](coefficientsArray.length)(0) + private lazy val gradientSumArray = new Array[Double](numFeaturesPlusIntercept) /** * Add a new training instance to this LinearSVCAggregator, and update the loss and gradient @@ -463,6 +458,9 @@ private class LinearSVCAggregator( */ def add(instance: Instance): this.type = { instance match { case Instance(label, weight, features) => + require(weight >= 0.0, s"instance weight, $weight has to be >= 0.0") + require(numFeatures == features.size, s"Dimensions mismatch when adding new instance." 
+ + s" Expecting $numFeatures but got ${features.size}.") if (weight == 0.0) return this val localFeaturesStd = bcFeaturesStd.value val localCoefficients = coefficientsArray @@ -530,18 +528,15 @@ private class LinearSVCAggregator( this } - def loss: Double = { - if (weightSum != 0) { - lossSum / weightSum - } else 0.0 - } + def loss: Double = if (weightSum != 0) lossSum / weightSum else 0.0 def gradient: Vector = { if (weightSum != 0) { val result = Vectors.dense(gradientSumArray.clone()) scal(1.0 / weightSum, result) result - } else Vectors.dense(Array.fill[Double](coefficientsArray.length)(0)) + } else { + Vectors.dense(new Array[Double](numFeaturesPlusIntercept)) + } } - } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 892e00fa6041..1a78187d4f8e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -1431,7 +1431,12 @@ private class LogisticAggregator( private var weightSum = 0.0 private var lossSum = 0.0 - private val gradientSumArray = Array.fill[Double](coefficientSize)(0.0D) + @transient private lazy val coefficientsArray: Array[Double] = bcCoefficients.value match { + case DenseVector(values) => values + case _ => throw new IllegalArgumentException(s"coefficients only supports dense vector but " + + s"got type ${bcCoefficients.value.getClass}.)") + } + private lazy val gradientSumArray = new Array[Double](coefficientSize) if (multinomial && numClasses <= 2) { logInfo(s"Multinomial logistic regression for binary classification yields separate " + @@ -1447,7 +1452,7 @@ private class LogisticAggregator( label: Double): Unit = { val localFeaturesStd = bcFeaturesStd.value - val localCoefficients = bcCoefficients.value + val localCoefficients = coefficientsArray val localGradientArray = gradientSumArray val margin = - { var sum = 0.0 @@ -1491,7 +1496,7 @@ private class LogisticAggregator( logistic regression without pivoting. 
*/ val localFeaturesStd = bcFeaturesStd.value - val localCoefficients = bcCoefficients.value + val localCoefficients = coefficientsArray val localGradientArray = gradientSumArray // marginOfLabel is margins(label) in the formula diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala index ea2dc6cfd8d3..a9c1a7ba0bc8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala @@ -580,10 +580,10 @@ private class ExpectationAggregator( private val k: Int = bcWeights.value.length private var totalCnt: Long = 0L private var newLogLikelihood: Double = 0.0 - private val newWeights: Array[Double] = new Array[Double](k) - private val newMeans: Array[DenseVector] = Array.fill(k)( + private lazy val newWeights: Array[Double] = new Array[Double](k) + private lazy val newMeans: Array[DenseVector] = Array.fill(k)( new DenseVector(Array.fill[Double](numFeatures)(0.0))) - private val newCovs: Array[DenseVector] = Array.fill(k)( + private lazy val newCovs: Array[DenseVector] = Array.fill(k)( new DenseVector(Array.fill[Double](numFeatures * (numFeatures + 1) / 2)(0.0))) @transient private lazy val oldGaussians = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala index bbcef3502d1d..55720e2d613d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala @@ -437,6 +437,9 @@ abstract class LDAModel private[ml] ( @Since("1.6.0") def setFeaturesCol(value: String): this.type = set(featuresCol, value) + @Since("2.2.0") + def setTopicDistributionCol(value: String): this.type = set(topicDistributionCol, value) + /** @group setParam */ @Since("1.6.0") def setSeed(value: Long): this.type = set(seed, value) diff --git a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala new file mode 100644 index 000000000000..417968d9b817 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala @@ -0,0 +1,339 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.fpm + +import scala.collection.mutable.ArrayBuffer +import scala.reflect.ClassTag + +import org.apache.hadoop.fs.Path + +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasPredictionCol} +import org.apache.spark.ml.util._ +import org.apache.spark.mllib.fpm.{AssociationRules => MLlibAssociationRules, + FPGrowth => MLlibFPGrowth} +import org.apache.spark.mllib.fpm.FPGrowth.FreqItemset +import org.apache.spark.sql._ +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types._ + +/** + * Common params for FPGrowth and FPGrowthModel + */ +private[fpm] trait FPGrowthParams extends Params with HasFeaturesCol with HasPredictionCol { + + /** + * Minimal support level of the frequent pattern. [0.0, 1.0]. Any pattern that appears + * more than (minSupport * size-of-the-dataset) times will be output + * Default: 0.3 + * @group param + */ + @Since("2.2.0") + val minSupport: DoubleParam = new DoubleParam(this, "minSupport", + "the minimal support level of a frequent pattern", + ParamValidators.inRange(0.0, 1.0)) + setDefault(minSupport -> 0.3) + + /** @group getParam */ + @Since("2.2.0") + def getMinSupport: Double = $(minSupport) + + /** + * Number of partitions (>=1) used by parallel FP-growth. By default the param is not set, and + * partition number of the input dataset is used. + * @group expertParam + */ + @Since("2.2.0") + val numPartitions: IntParam = new IntParam(this, "numPartitions", + "Number of partitions used by parallel FP-growth", ParamValidators.gtEq[Int](1)) + + /** @group expertGetParam */ + @Since("2.2.0") + def getNumPartitions: Int = $(numPartitions) + + /** + * Minimal confidence for generating Association Rule. + * Note that minConfidence has no effect during fitting. + * Default: 0.8 + * @group param + */ + @Since("2.2.0") + val minConfidence: DoubleParam = new DoubleParam(this, "minConfidence", + "minimal confidence for generating Association Rule", + ParamValidators.inRange(0.0, 1.0)) + setDefault(minConfidence -> 0.8) + + /** @group getParam */ + @Since("2.2.0") + def getMinConfidence: Double = $(minConfidence) + + /** + * Validates and transforms the input schema. + * @param schema input schema + * @return output schema + */ + @Since("2.2.0") + protected def validateAndTransformSchema(schema: StructType): StructType = { + val inputType = schema($(featuresCol)).dataType + require(inputType.isInstanceOf[ArrayType], + s"The input column must be ArrayType, but got $inputType.") + SchemaUtils.appendColumn(schema, $(predictionCol), schema($(featuresCol)).dataType) + } +} + +/** + * :: Experimental :: + * A parallel FP-growth algorithm to mine frequent itemsets. The algorithm is described in + * Li et al., PFP: Parallel FP-Growth for Query + * Recommendation. PFP distributes computation in such a way that each worker executes an + * independent group of mining tasks. The FP-Growth algorithm is described in + * Han et al., Mining frequent patterns without + * candidate generation. Note null values in the feature column are ignored during fit(). 
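As a hedged usage sketch (not part of the patch) of the estimator this new file introduces: the object name, local SparkSession setup, and toy transactions below are illustrative assumptions, with the data mirroring the values used by the new FPGrowthSuite later in this patch.

```scala
import org.apache.spark.ml.fpm.FPGrowth
import org.apache.spark.sql.SparkSession

object FPGrowthSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("FPGrowthSketch").getOrCreate()
    import spark.implicits._

    // Same toy transactions as the new FPGrowthSuite: three [1, 2] baskets and one [1, 3].
    val dataset = Seq(
      Array("1", "2"),
      Array("1", "2"),
      Array("1", "2"),
      Array("1", "3")
    ).toDF("features")

    val model = new FPGrowth()
      .setMinSupport(0.5)     // keep itemsets appearing in at least half of the rows
      .setMinConfidence(0.5)  // only used when generating rules, has no effect during fit
      .fit(dataset)

    model.freqItemsets.show()      // DataFrame("items", "freq")
    model.associationRules.show()  // DataFrame("antecedent", "consequent", "confidence")

    // The prediction column contains rule consequents not already present in the row:
    // empty for the [1, 2] baskets, [2] for the [1, 3] basket (via the rule {1} -> {2}).
    model.transform(dataset).show()

    spark.stop()
  }
}
```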
+ * + * @see + * Association rule learning (Wikipedia) + */ +@Since("2.2.0") +@Experimental +class FPGrowth @Since("2.2.0") ( + @Since("2.2.0") override val uid: String) + extends Estimator[FPGrowthModel] with FPGrowthParams with DefaultParamsWritable { + + @Since("2.2.0") + def this() = this(Identifiable.randomUID("fpgrowth")) + + /** @group setParam */ + @Since("2.2.0") + def setMinSupport(value: Double): this.type = set(minSupport, value) + + /** @group expertSetParam */ + @Since("2.2.0") + def setNumPartitions(value: Int): this.type = set(numPartitions, value) + + /** @group setParam */ + @Since("2.2.0") + def setMinConfidence(value: Double): this.type = set(minConfidence, value) + + /** @group setParam */ + @Since("2.2.0") + def setFeaturesCol(value: String): this.type = set(featuresCol, value) + + /** @group setParam */ + @Since("2.2.0") + def setPredictionCol(value: String): this.type = set(predictionCol, value) + + @Since("2.2.0") + override def fit(dataset: Dataset[_]): FPGrowthModel = { + transformSchema(dataset.schema, logging = true) + genericFit(dataset) + } + + private def genericFit[T: ClassTag](dataset: Dataset[_]): FPGrowthModel = { + val data = dataset.select($(featuresCol)) + val items = data.where(col($(featuresCol)).isNotNull).rdd.map(r => r.getSeq[T](0).toArray) + val mllibFP = new MLlibFPGrowth().setMinSupport($(minSupport)) + if (isSet(numPartitions)) { + mllibFP.setNumPartitions($(numPartitions)) + } + val parentModel = mllibFP.run(items) + val rows = parentModel.freqItemsets.map(f => Row(f.items, f.freq)) + + val schema = StructType(Seq( + StructField("items", dataset.schema($(featuresCol)).dataType, nullable = false), + StructField("freq", LongType, nullable = false))) + val frequentItems = dataset.sparkSession.createDataFrame(rows, schema) + copyValues(new FPGrowthModel(uid, frequentItems)).setParent(this) + } + + @Since("2.2.0") + override def transformSchema(schema: StructType): StructType = { + validateAndTransformSchema(schema) + } + + @Since("2.2.0") + override def copy(extra: ParamMap): FPGrowth = defaultCopy(extra) +} + + +@Since("2.2.0") +object FPGrowth extends DefaultParamsReadable[FPGrowth] { + + @Since("2.2.0") + override def load(path: String): FPGrowth = super.load(path) +} + +/** + * :: Experimental :: + * Model fitted by FPGrowth. + * + * @param freqItemsets frequent items in the format of DataFrame("items"[Seq], "freq"[Long]) + */ +@Since("2.2.0") +@Experimental +class FPGrowthModel private[ml] ( + @Since("2.2.0") override val uid: String, + @transient val freqItemsets: DataFrame) + extends Model[FPGrowthModel] with FPGrowthParams with MLWritable { + + /** @group setParam */ + @Since("2.2.0") + def setMinConfidence(value: Double): this.type = set(minConfidence, value) + + /** @group setParam */ + @Since("2.2.0") + def setFeaturesCol(value: String): this.type = set(featuresCol, value) + + /** @group setParam */ + @Since("2.2.0") + def setPredictionCol(value: String): this.type = set(predictionCol, value) + + /** + * Get association rules fitted by AssociationRules using the minConfidence. Returns a dataframe + * with three fields, "antecedent", "consequent" and "confidence", where "antecedent" and + * "consequent" are Array[T] and "confidence" is Double. + */ + @Since("2.2.0") + @transient lazy val associationRules: DataFrame = { + AssociationRules.getAssociationRulesFromFP(freqItemsets, "items", "freq", $(minConfidence)) + } + + /** + * The transform method first generates the association rules according to the frequent itemsets. 
+ * Then for each association rule, it will examine the input items against antecedents and + * summarize the consequents as prediction. The prediction column has the same data type as the + * input column(Array[T]) and will not contain existing items in the input column. The null + * values in the feature columns are treated as empty sets. + * WARNING: internally it collects association rules to the driver and uses broadcast for + * efficiency. This may bring pressure to driver memory for large set of association rules. + */ + @Since("2.2.0") + override def transform(dataset: Dataset[_]): DataFrame = { + transformSchema(dataset.schema, logging = true) + genericTransform(dataset) + } + + private def genericTransform(dataset: Dataset[_]): DataFrame = { + val rules: Array[(Seq[Any], Seq[Any])] = associationRules.select("antecedent", "consequent") + .rdd.map(r => (r.getSeq(0), r.getSeq(1))) + .collect().asInstanceOf[Array[(Seq[Any], Seq[Any])]] + val brRules = dataset.sparkSession.sparkContext.broadcast(rules) + + val dt = dataset.schema($(featuresCol)).dataType + // For each rule, examine the input items and summarize the consequents + val predictUDF = udf((items: Seq[_]) => { + if (items != null) { + val itemset = items.toSet + brRules.value.flatMap(rule => + if (items != null && rule._1.forall(item => itemset.contains(item))) { + rule._2.filter(item => !itemset.contains(item)) + } else { + Seq.empty + }) + } else { + Seq.empty + }.distinct }, dt) + dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) + } + + @Since("2.2.0") + override def transformSchema(schema: StructType): StructType = { + validateAndTransformSchema(schema) + } + + @Since("2.2.0") + override def copy(extra: ParamMap): FPGrowthModel = { + val copied = new FPGrowthModel(uid, freqItemsets) + copyValues(copied, extra).setParent(this.parent) + } + + @Since("2.2.0") + override def write: MLWriter = new FPGrowthModel.FPGrowthModelWriter(this) +} + +@Since("2.2.0") +object FPGrowthModel extends MLReadable[FPGrowthModel] { + + @Since("2.2.0") + override def read: MLReader[FPGrowthModel] = new FPGrowthModelReader + + @Since("2.2.0") + override def load(path: String): FPGrowthModel = super.load(path) + + /** [[MLWriter]] instance for [[FPGrowthModel]] */ + private[FPGrowthModel] + class FPGrowthModelWriter(instance: FPGrowthModel) extends MLWriter { + + override protected def saveImpl(path: String): Unit = { + DefaultParamsWriter.saveMetadata(instance, path, sc) + val dataPath = new Path(path, "data").toString + instance.freqItemsets.write.parquet(dataPath) + } + } + + private class FPGrowthModelReader extends MLReader[FPGrowthModel] { + + /** Checked against metadata when loading model */ + private val className = classOf[FPGrowthModel].getName + + override def load(path: String): FPGrowthModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val dataPath = new Path(path, "data").toString + val frequentItems = sparkSession.read.parquet(dataPath) + val model = new FPGrowthModel(metadata.uid, frequentItems) + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } +} + +private[fpm] object AssociationRules { + + /** + * Computes the association rules with confidence above minConfidence. + * @param dataset DataFrame("items", "freq") containing frequent itemset obtained from + * algorithms like [[FPGrowth]]. 
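For reference, the confidence reported by this helper follows the usual association-rule definition, confidence(X -> Y) = freq(X union Y) / freq(X). A small worked sketch (editor-added, using the toy frequencies implied by the FPGrowthSuite data above) is shown below; the values match the expected rules in that suite.

```scala
// Frequencies for the four toy transactions ([1,2] x 3 and [1,3] x 1).
val freq = Map(Set("1") -> 4L, Set("2") -> 3L, Set("1", "2") -> 3L)

val confOneToTwo = freq(Set("1", "2")).toDouble / freq(Set("1"))  // 0.75, rule {1} -> {2}
val confTwoToOne = freq(Set("1", "2")).toDouble / freq(Set("2"))  // 1.0,  rule {2} -> {1}
// Only rules with confidence >= minConfidence are returned.
```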
+ * @param itemsCol column name for frequent itemsets + * @param freqCol column name for frequent itemsets count + * @param minConfidence minimum confidence for the result association rules + * @return a DataFrame("antecedent", "consequent", "confidence") containing the association + * rules. + */ + def getAssociationRulesFromFP[T: ClassTag]( + dataset: Dataset[_], + itemsCol: String, + freqCol: String, + minConfidence: Double): DataFrame = { + + val freqItemSetRdd = dataset.select(itemsCol, freqCol).rdd + .map(row => new FreqItemset(row.getSeq[T](0).toArray, row.getLong(1))) + val rows = new MLlibAssociationRules() + .setMinConfidence(minConfidence) + .run(freqItemSetRdd) + .map(r => Row(r.antecedent, r.consequent, r.confidence)) + + val dt = dataset.schema(itemsCol).dataType + val schema = StructType(Seq( + StructField("antecedent", dt, nullable = false), + StructField("consequent", dt, nullable = false), + StructField("confidence", DoubleType, nullable = false))) + val rules = dataset.sparkSession.createDataFrame(rows, schema) + rules + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala index bd965acf5694..0bf543d88894 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala @@ -82,7 +82,10 @@ private[r] object AFTSurvivalRegressionWrapper extends MLReadable[AFTSurvivalReg } - def fit(formula: String, data: DataFrame): AFTSurvivalRegressionWrapper = { + def fit( + formula: String, + data: DataFrame, + aggregationDepth: Int): AFTSurvivalRegressionWrapper = { val (rewritedFormula, censorCol) = formulaRewrite(formula) @@ -100,6 +103,7 @@ private[r] object AFTSurvivalRegressionWrapper extends MLReadable[AFTSurvivalReg .setCensorCol(censorCol) .setFitIntercept(rFormula.hasIntercept) .setFeaturesCol(rFormula.getFeaturesCol) + .setAggregationDepth(aggregationDepth) val pipeline = new Pipeline() .setStages(Array(rFormulaModel, aft)) diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala index 78f401f29b00..cbd6cd1c7933 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala @@ -87,9 +87,11 @@ private[r] object GeneralizedLinearRegressionWrapper .setFitIntercept(rFormula.hasIntercept) .setTol(tol) .setMaxIter(maxIter) - .setWeightCol(weightCol) .setRegParam(regParam) .setFeaturesCol(rFormula.getFeaturesCol) + + if (weightCol != null) glr.setWeightCol(weightCol) + val pipeline = new Pipeline() .setStages(Array(rFormulaModel, glr)) .fit(data) diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/IsotonicRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/IsotonicRegressionWrapper.scala index 48632316f395..d31ebb46afb9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/IsotonicRegressionWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/IsotonicRegressionWrapper.scala @@ -74,9 +74,10 @@ private[r] object IsotonicRegressionWrapper val isotonicRegression = new IsotonicRegression() .setIsotonic(isotonic) .setFeatureIndex(featureIndex) - .setWeightCol(weightCol) .setFeaturesCol(rFormula.getFeaturesCol) + if (weightCol != null) isotonicRegression.setWeightCol(weightCol) + val 
pipeline = new Pipeline() .setStages(Array(rFormulaModel, isotonicRegression)) .fit(data) diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/LogisticRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/LogisticRegressionWrapper.scala index 645bc7247f30..c96f99cb8343 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/LogisticRegressionWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/LogisticRegressionWrapper.scala @@ -96,7 +96,8 @@ private[r] object LogisticRegressionWrapper family: String, standardization: Boolean, thresholds: Array[Double], - weightCol: String + weightCol: String, + aggregationDepth: Int ): LogisticRegressionWrapper = { val rFormula = new RFormula() @@ -119,10 +120,10 @@ private[r] object LogisticRegressionWrapper .setFitIntercept(fitIntercept) .setFamily(family) .setStandardization(standardization) - .setWeightCol(weightCol) .setFeaturesCol(rFormula.getFeaturesCol) .setLabelCol(rFormula.getLabelCol) .setPredictionCol(PREDICTED_LABEL_INDEX_COL) + .setAggregationDepth(aggregationDepth) if (thresholds.length > 1) { lr.setThresholds(thresholds) @@ -130,6 +131,8 @@ private[r] object LogisticRegressionWrapper lr.setThreshold(thresholds(0)) } + if (weightCol != null) lr.setWeightCol(weightCol) + val idxToStr = new IndexToString() .setInputCol(PREDICTED_LABEL_INDEX_COL) .setOutputCol(PREDICTED_LABEL_COL) diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index 97c865529860..799e881fad74 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -80,16 +80,47 @@ private[recommendation] trait ALSModelParams extends Params with HasPredictionCo /** * Attempts to safely cast a user/item id to an Int. Throws an exception if the value is - * out of integer range. + * out of integer range or contains a fractional part. */ - protected val checkedCast = udf { (n: Double) => - if (n > Int.MaxValue || n < Int.MinValue) { - throw new IllegalArgumentException(s"ALS only supports values in Integer range for columns " + - s"${$(userCol)} and ${$(itemCol)}. Value $n was out of Integer range.") - } else { - n.toInt + protected[recommendation] val checkedCast = udf { (n: Any) => + n match { + case v: Int => v // Avoid unnecessary casting + case v: Number => + val intV = v.intValue + // Checks if number within Int range and has no fractional part. + if (v.doubleValue == intV) { + intV + } else { + throw new IllegalArgumentException(s"ALS only supports values in Integer range " + + s"and without fractional part for columns ${$(userCol)} and ${$(itemCol)}. " + + s"Value $n was either out of Integer range or contained a fractional part that " + + s"could not be converted.") + } + case _ => throw new IllegalArgumentException(s"ALS only supports values in Integer range " + + s"for columns ${$(userCol)} and ${$(itemCol)}. Value $n was not numeric.") } } + + /** + * Param for strategy for dealing with unknown or new users/items at prediction time. + * This may be useful in cross-validation or production scenarios, for handling user/item ids + * the model has not seen in the training data. + * Supported values: + * - "nan": predicted value for unknown ids will be NaN. + * - "drop": rows in the input DataFrame containing unknown ids will be dropped from + * the output DataFrame containing predictions. + * Default: "nan". 
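As a hedged sketch of how the two cold-start strategies behave at prediction time (the object name, SparkSession setup, and toy data are assumptions for illustration, echoing the new ALSSuite test later in this patch):

```scala
import org.apache.spark.ml.recommendation.ALS
import org.apache.spark.sql.SparkSession

object ColdStartSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("ColdStartSketch").getOrCreate()
    import spark.implicits._

    val training = Seq((0, 0, 4.0f), (0, 1, 2.0f), (1, 0, 3.0f))
      .toDF("user", "item", "rating")
    // User 7 and item 9 never appear in the training data.
    val test = Seq((0, 0), (7, 0), (0, 9)).toDF("user", "item")

    val model = new ALS().setMaxIter(1).setRank(1).fit(training)

    // Default "nan": the rows for user 7 and item 9 get NaN predictions.
    model.transform(test).show()

    // "drop": those rows are removed, which keeps evaluators such as RMSE meaningful
    // in cross-validation.
    model.setColdStartStrategy("drop").transform(test).show()

    spark.stop()
  }
}
```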
+ * @group expertParam + */ + val coldStartStrategy = new Param[String](this, "coldStartStrategy", + "strategy for dealing with unknown or new users/items at prediction time. This may be " + + "useful in cross-validation or production scenarios, for handling user/item ids the model " + + "has not seen in the training data. Supported values: " + + s"${ALSModel.supportedColdStartStrategies.mkString(",")}.", + (s: String) => ALSModel.supportedColdStartStrategies.contains(s.toLowerCase)) + + /** @group expertGetParam */ + def getColdStartStrategy: String = $(coldStartStrategy).toLowerCase } /** @@ -203,7 +234,8 @@ private[recommendation] trait ALSParams extends ALSModelParams with HasMaxIter w setDefault(rank -> 10, maxIter -> 10, regParam -> 0.1, numUserBlocks -> 10, numItemBlocks -> 10, implicitPrefs -> false, alpha -> 1.0, userCol -> "user", itemCol -> "item", ratingCol -> "rating", nonnegative -> false, checkpointInterval -> 10, - intermediateStorageLevel -> "MEMORY_AND_DISK", finalStorageLevel -> "MEMORY_AND_DISK") + intermediateStorageLevel -> "MEMORY_AND_DISK", finalStorageLevel -> "MEMORY_AND_DISK", + coldStartStrategy -> "nan") /** * Validates and transforms the input schema. @@ -248,6 +280,10 @@ class ALSModel private[ml] ( @Since("1.3.0") def setPredictionCol(value: String): this.type = set(predictionCol, value) + /** @group expertSetParam */ + @Since("2.2.0") + def setColdStartStrategy(value: String): this.type = set(coldStartStrategy, value) + @Since("2.0.0") override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema) @@ -260,13 +296,19 @@ class ALSModel private[ml] ( Float.NaN } } - dataset + val predictions = dataset .join(userFactors, - checkedCast(dataset($(userCol)).cast(DoubleType)) === userFactors("id"), "left") + checkedCast(dataset($(userCol))) === userFactors("id"), "left") .join(itemFactors, - checkedCast(dataset($(itemCol)).cast(DoubleType)) === itemFactors("id"), "left") + checkedCast(dataset($(itemCol))) === itemFactors("id"), "left") .select(dataset("*"), predict(userFactors("features"), itemFactors("features")).as($(predictionCol))) + getColdStartStrategy match { + case ALSModel.Drop => + predictions.na.drop("all", Seq($(predictionCol))) + case ALSModel.NaN => + predictions + } } @Since("1.3.0") @@ -290,6 +332,10 @@ class ALSModel private[ml] ( @Since("1.6.0") object ALSModel extends MLReadable[ALSModel] { + private val NaN = "nan" + private val Drop = "drop" + private[recommendation] final val supportedColdStartStrategies = Array(NaN, Drop) + @Since("1.6.0") override def read: MLReader[ALSModel] = new ALSModelReader @@ -432,6 +478,10 @@ class ALS(@Since("1.4.0") override val uid: String) extends Estimator[ALSModel] @Since("2.0.0") def setFinalStorageLevel(value: String): this.type = set(finalStorageLevel, value) + /** @group expertSetParam */ + @Since("2.2.0") + def setColdStartStrategy(value: String): this.type = set(coldStartStrategy, value) + /** * Sets both numUserBlocks and numItemBlocks to the specific value. 
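The stricter id validation earlier in this file's changes accepts any numeric user/item column whose values fit in Int with no fractional part, and rejects fractional or out-of-range ids with the clearer message above. A hedged sketch (toy data and object name are assumptions):

```scala
import org.apache.spark.ml.recommendation.ALS
import org.apache.spark.sql.SparkSession

object AlsIdTypesSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("AlsIdTypesSketch").getOrCreate()
    import spark.implicits._

    // Long ids are fine as long as every value fits in Int and has no fractional part.
    val ratings = Seq((0L, 10L, 4.0f), (1L, 11L, 2.5f), (1L, 10L, 3.0f))
      .toDF("user", "item", "rating")
    new ALS().setMaxIter(1).setRank(1).fit(ratings)

    // A fractional id such as 1.5 now fails fast instead of being silently truncated
    // by the previous cast to DoubleType followed by toInt.
    val badIds = Seq((1.5, 10.0, 1.0f)).toDF("user", "item", "rating")
    // new ALS().fit(badIds)  // would throw: "... contained a fractional part ..."

    spark.stop()
  }
}
```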
* @@ -451,8 +501,7 @@ class ALS(@Since("1.4.0") override val uid: String) extends Estimator[ALSModel] val r = if ($(ratingCol) != "") col($(ratingCol)).cast(FloatType) else lit(1.0f) val ratings = dataset - .select(checkedCast(col($(userCol)).cast(DoubleType)), - checkedCast(col($(itemCol)).cast(DoubleType)), r) + .select(checkedCast(col($(userCol))), checkedCast(col($(itemCol))), r) .rdd .map { row => Rating(row.getInt(0), row.getInt(1), row.getFloat(2)) @@ -671,7 +720,7 @@ object ALS extends DefaultParamsReadable[ALS] with Logging { numUserBlocks: Int = 10, numItemBlocks: Int = 10, maxIter: Int = 10, - regParam: Double = 1.0, + regParam: Double = 0.1, implicitPrefs: Boolean = false, alpha: Double = 1.0, nonnegative: Boolean = false, diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala index 2f78dd30b3af..094853b6f480 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala @@ -106,7 +106,7 @@ private[regression] trait AFTSurvivalRegressionParams extends Params fitting: Boolean): StructType = { SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT) if (fitting) { - SchemaUtils.checkColumnType(schema, $(censorCol), DoubleType) + SchemaUtils.checkNumericType(schema, $(censorCol)) SchemaUtils.checkNumericType(schema, $(labelCol)) } if (hasQuantilesCol) { @@ -200,8 +200,8 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S * and put it in an RDD with strong types. */ protected[ml] def extractAFTPoints(dataset: Dataset[_]): RDD[AFTPoint] = { - dataset.select(col($(featuresCol)), col($(labelCol)).cast(DoubleType), col($(censorCol))) - .rdd.map { + dataset.select(col($(featuresCol)), col($(labelCol)).cast(DoubleType), + col($(censorCol)).cast(DoubleType)).rdd.map { case Row(features: Vector, label: Double, censor: Double) => AFTPoint(features, label, censor) } @@ -526,7 +526,7 @@ private class AFTAggregator( private var totalCnt: Long = 0L private var lossSum = 0.0 // Here we optimize loss function over log(sigma), intercept and coefficients - private val gradientSumArray = Array.ofDim[Double](length) + private lazy val gradientSumArray = Array.ofDim[Double](length) def count: Long = totalCnt def loss: Double = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala index fdeadaf27497..110764dc074f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala @@ -109,6 +109,8 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam * Param for the index in the power link function. Only applicable for the Tweedie family. * Note that link power 0, 1, -1 or 0.5 corresponds to the Log, Identity, Inverse or Sqrt * link, respectively. + * When not set, this value defaults to 1 - [[variancePower]], which matches the R "statmod" + * package. 
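A hedged illustration of the documented default: the `setFamily`, `setVariancePower`, and `setLinkPower` setters are assumed from the Tweedie support this parameter belongs to and are not shown in this hunk.

```scala
import org.apache.spark.ml.regression.GeneralizedLinearRegression

// Tweedie with variancePower = 1.5 and linkPower left unset: per the note above, the
// link power defaults to 1 - 1.5 = -0.5, matching R's statmod package.
val glrDefaultLink = new GeneralizedLinearRegression()
  .setFamily("tweedie")
  .setVariancePower(1.5)

// Requesting the log link explicitly instead (link power 0).
val glrLogLink = new GeneralizedLinearRegression()
  .setFamily("tweedie")
  .setVariancePower(1.5)
  .setLinkPower(0.0)
```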
* * @group param */ diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala index a6c29433d730..529f66eadbcf 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala @@ -49,7 +49,7 @@ private[regression] trait IsotonicRegressionBase extends Params with HasFeatures */ final val isotonic: BooleanParam = new BooleanParam(this, "isotonic", - "whether the output sequence should be isotonic/increasing (true) or" + + "whether the output sequence should be isotonic/increasing (true) or " + "antitonic/decreasing (false)") /** @group getParam */ diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 2de7e81d8d41..45df1d9be647 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -959,7 +959,7 @@ private class LeastSquaresAggregator( @transient private lazy val effectiveCoefficientsVector = effectiveCoefAndOffset._1 @transient private lazy val offset = effectiveCoefAndOffset._2 - private val gradientSumArray = Array.ofDim[Double](dim) + private lazy val gradientSumArray = Array.ofDim[Double](dim) /** * Add a new training instance to this LeastSquaresAggregator, and update the loss and gradient diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala index f3bace818157..4c525c0714ec 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala @@ -226,7 +226,7 @@ private[spark] object GradientBoostedTrees extends Logging { (a, b) => treesIndices.map(idx => a(idx) + b(idx))) .map(_ / dataCount) - broadcastTrees.destroy() + broadcastTrees.destroy(blocking = false) evaluation.toArray } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala index 7a714db85335..efedebe30138 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala @@ -261,7 +261,7 @@ object LBFGS extends Logging { val (gradientSum, lossSum) = data.treeAggregate((zeroSparseVector, 0.0))(seqOp, combOp) // broadcasted model is not needed anymore - bcW.destroy() + bcW.destroy(blocking = false) /** * regVal is sum of weight squares if it's L2 updater; diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala index fc1d4125a564..b1e82656a240 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala @@ -167,7 +167,7 @@ class GradientBoostedTreesModel @Since("1.2.0") ( (a, b) => treesIndices.map(idx => a(idx) + b(idx))) .map(_ / dataCount) - broadcastTrees.destroy() + broadcastTrees.destroy(blocking = false) evaluation.toArray } diff --git a/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java 
b/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java index 0f71deb9ea52..d2fe6bb2ca71 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java @@ -33,7 +33,8 @@ public class JavaDecisionTreeSuite extends SharedSparkSession { - private static int validatePrediction(List validationData, DecisionTreeModel model) { + private static int validatePrediction( + List validationData, DecisionTreeModel model) { int numCorrect = 0; for (LabeledPoint point : validationData) { Double prediction = model.predict(point.features()); diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala index ee2aefee7a6d..a165d8a9345c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala @@ -23,7 +23,7 @@ import breeze.linalg.{DenseVector => BDV} import org.apache.spark.SparkFunSuite import org.apache.spark.ml.classification.LinearSVCSuite._ -import org.apache.spark.ml.feature.LabeledPoint +import org.apache.spark.ml.feature.{Instance, LabeledPoint} import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} @@ -123,6 +123,21 @@ class LinearSVCSuite extends SparkFunSuite with MLlibTestSparkContext with Defau assert(model2.intercept !== 0.0) } + test("sparse coefficients in SVCAggregator") { + val bcCoefficients = spark.sparkContext.broadcast(Vectors.sparse(2, Array(0), Array(1.0))) + val bcFeaturesStd = spark.sparkContext.broadcast(Array(1.0)) + val agg = new LinearSVCAggregator(bcCoefficients, bcFeaturesStd, true) + val thrown = withClue("LinearSVCAggregator cannot handle sparse coefficients") { + intercept[IllegalArgumentException] { + agg.add(Instance(1.0, 1.0, Vectors.dense(1.0))) + } + } + assert(thrown.getMessage.contains("coefficients only supports dense")) + + bcCoefficients.destroy(blocking = false) + bcFeaturesStd.destroy(blocking = false) + } + test("linearSVC with sample weights") { def modelEquals(m1: LinearSVCModel, m2: LinearSVCModel): Unit = { assert(m1.coefficients ~== m2.coefficients absTol 0.05) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index 43547a4aafcb..d89a958eed45 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -456,6 +456,32 @@ class LogisticRegressionSuite assert(blrModel.intercept !== 0.0) } + test("sparse coefficients in LogisticAggregator") { + val bcCoefficientsBinary = spark.sparkContext.broadcast(Vectors.sparse(2, Array(0), Array(1.0))) + val bcFeaturesStd = spark.sparkContext.broadcast(Array(1.0)) + val binaryAgg = new LogisticAggregator(bcCoefficientsBinary, bcFeaturesStd, 2, + fitIntercept = true, multinomial = false) + val thrownBinary = withClue("binary logistic aggregator cannot handle sparse coefficients") { + intercept[IllegalArgumentException] { + binaryAgg.add(Instance(1.0, 1.0, Vectors.dense(1.0))) + } + } + assert(thrownBinary.getMessage.contains("coefficients only supports dense")) + + val bcCoefficientsMulti = 
spark.sparkContext.broadcast(Vectors.sparse(6, Array(0), Array(1.0))) + val multinomialAgg = new LogisticAggregator(bcCoefficientsMulti, bcFeaturesStd, 3, + fitIntercept = true, multinomial = true) + val thrown = withClue("multinomial logistic aggregator cannot handle sparse coefficients") { + intercept[IllegalArgumentException] { + multinomialAgg.add(Instance(1.0, 1.0, Vectors.dense(1.0))) + } + } + assert(thrown.getMessage.contains("coefficients only supports dense")) + bcCoefficientsBinary.destroy(blocking = false) + bcFeaturesStd.destroy(blocking = false) + bcCoefficientsMulti.destroy(blocking = false) + } + test("overflow prediction for multiclass") { val model = new LogisticRegressionModel("mLogReg", Matrices.dense(3, 2, Array(0.0, 0.0, 0.0, 1.0, 2.0, 3.0)), diff --git a/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala new file mode 100644 index 000000000000..74c746140190 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala @@ -0,0 +1,130 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.ml.fpm + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession} +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types._ + +class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + + @transient var dataset: Dataset[_] = _ + + override def beforeAll(): Unit = { + super.beforeAll() + dataset = FPGrowthSuite.getFPGrowthData(spark) + } + + test("FPGrowth fit and transform with different data types") { + Array(IntegerType, StringType, ShortType, LongType, ByteType).foreach { dt => + val data = dataset.withColumn("features", col("features").cast(ArrayType(dt))) + val model = new FPGrowth().setMinSupport(0.5).fit(data) + val generatedRules = model.setMinConfidence(0.5).associationRules + val expectedRules = spark.createDataFrame(Seq( + (Array("2"), Array("1"), 1.0), + (Array("1"), Array("2"), 0.75) + )).toDF("antecedent", "consequent", "confidence") + .withColumn("antecedent", col("antecedent").cast(ArrayType(dt))) + .withColumn("consequent", col("consequent").cast(ArrayType(dt))) + assert(expectedRules.sort("antecedent").rdd.collect().sameElements( + generatedRules.sort("antecedent").rdd.collect())) + + val transformed = model.transform(data) + val expectedTransformed = spark.createDataFrame(Seq( + (0, Array("1", "2"), Array.emptyIntArray), + (0, Array("1", "2"), Array.emptyIntArray), + (0, Array("1", "2"), Array.emptyIntArray), + (0, Array("1", "3"), Array(2)) + )).toDF("id", "features", "prediction") + .withColumn("features", col("features").cast(ArrayType(dt))) + .withColumn("prediction", col("prediction").cast(ArrayType(dt))) + assert(expectedTransformed.collect().toSet.equals( + transformed.collect().toSet)) + } + } + + test("FPGrowth getFreqItems") { + val model = new FPGrowth().setMinSupport(0.7).fit(dataset) + val expectedFreq = spark.createDataFrame(Seq( + (Array("1"), 4L), + (Array("2"), 3L), + (Array("1", "2"), 3L), + (Array("2", "1"), 3L) // duplicate as the items sequence is not guaranteed + )).toDF("items", "expectedFreq") + val freqItems = model.freqItemsets + + val checkDF = freqItems.join(expectedFreq, "items") + assert(checkDF.count() == 3 && checkDF.filter(col("freq") === col("expectedFreq")).count() == 3) + } + + test("FPGrowth getFreqItems with Null") { + val df = spark.createDataFrame(Seq( + (1, Array("1", "2", "3", "5")), + (2, Array("1", "2", "3", "4")), + (3, null.asInstanceOf[Array[String]]) + )).toDF("id", "features") + val model = new FPGrowth().setMinSupport(0.7).fit(dataset) + val prediction = model.transform(df) + assert(prediction.select("prediction").where("id=3").first().getSeq[String](0).isEmpty) + } + + test("FPGrowth parameter check") { + val fpGrowth = new FPGrowth().setMinSupport(0.4567) + val model = fpGrowth.fit(dataset) + .setMinConfidence(0.5678) + assert(fpGrowth.getMinSupport === 0.4567) + assert(model.getMinConfidence === 0.5678) + } + + test("read/write") { + def checkModelData(model: FPGrowthModel, model2: FPGrowthModel): Unit = { + assert(model.freqItemsets.sort("items").collect() === + model2.freqItemsets.sort("items").collect()) + } + val fPGrowth = new FPGrowth() + testEstimatorAndModelReadWrite( + fPGrowth, dataset, FPGrowthSuite.allParamSettings, checkModelData) + } + +} + +object FPGrowthSuite { + + def getFPGrowthData(spark: SparkSession): DataFrame = { + spark.createDataFrame(Seq( + 
(0, Array("1", "2")), + (0, Array("1", "2")), + (0, Array("1", "2")), + (0, Array("1", "3")) + )).toDF("id", "features") + } + + /** + * Mapping from all Params to valid settings which differ from the defaults. + * This is useful for tests which need to exercise all Params, such as save/load. + * This excludes input columns to simplify some tests. + */ + val allParamSettings: Map[String, Any] = Map( + "minSupport" -> 0.321, + "minConfidence" -> 0.456, + "numPartitions" -> 5, + "predictionCol" -> "myPrediction" + ) +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala index b923bacce23c..c8228dd00437 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala @@ -40,7 +40,8 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD import org.apache.spark.scheduler.{SparkListener, SparkListenerStageCompleted} import org.apache.spark.sql.{DataFrame, Row, SparkSession} -import org.apache.spark.sql.types.{FloatType, IntegerType} +import org.apache.spark.sql.functions.lit +import org.apache.spark.sql.types._ import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils @@ -205,6 +206,70 @@ class ALSSuite assert(decompressed.toSet === expected) } + test("CheckedCast") { + val checkedCast = new ALS().checkedCast + val df = spark.range(1) + + withClue("Valid Integer Ids") { + df.select(checkedCast(lit(123))).collect() + } + + withClue("Valid Long Ids") { + df.select(checkedCast(lit(1231L))).collect() + } + + withClue("Valid Decimal Ids") { + df.select(checkedCast(lit(123).cast(DecimalType(15, 2)))).collect() + } + + withClue("Valid Double Ids") { + df.select(checkedCast(lit(123.0))).collect() + } + + val msg = "either out of Integer range or contained a fractional part" + withClue("Invalid Long: out of range") { + val e: SparkException = intercept[SparkException] { + df.select(checkedCast(lit(1231000000000L))).collect() + } + assert(e.getMessage.contains(msg)) + } + + withClue("Invalid Decimal: out of range") { + val e: SparkException = intercept[SparkException] { + df.select(checkedCast(lit(1231000000000.0).cast(DecimalType(15, 2)))).collect() + } + assert(e.getMessage.contains(msg)) + } + + withClue("Invalid Decimal: fractional part") { + val e: SparkException = intercept[SparkException] { + df.select(checkedCast(lit(123.1).cast(DecimalType(15, 2)))).collect() + } + assert(e.getMessage.contains(msg)) + } + + withClue("Invalid Double: out of range") { + val e: SparkException = intercept[SparkException] { + df.select(checkedCast(lit(1231000000000.0))).collect() + } + assert(e.getMessage.contains(msg)) + } + + withClue("Invalid Double: fractional part") { + val e: SparkException = intercept[SparkException] { + df.select(checkedCast(lit(123.1))).collect() + } + assert(e.getMessage.contains(msg)) + } + + withClue("Invalid Type") { + val e: SparkException = intercept[SparkException] { + df.select(checkedCast(lit("123.1"))).collect() + } + assert(e.getMessage.contains("was not numeric")) + } + } + /** * Generates an explicit feedback dataset for testing ALS. 
* @param numUsers number of users @@ -498,8 +563,8 @@ class ALSSuite (ex, act) => ex.userFactors.first().getSeq[Float](1) === act.userFactors.first.getSeq[Float](1) } { (ex, act, _) => - ex.transform(_: DataFrame).select("prediction").first.getFloat(0) ~== - act.transform(_: DataFrame).select("prediction").first.getFloat(0) absTol 1e-6 + ex.transform(_: DataFrame).select("prediction").first.getDouble(0) ~== + act.transform(_: DataFrame).select("prediction").first.getDouble(0) absTol 1e-6 } } // check user/item ids falling outside of Int range @@ -510,34 +575,35 @@ class ALSSuite (0, big, small, 0, big, small, 2.0), (1, 1L, 1d, 0, 0L, 0d, 5.0) ).toDF("user", "user_big", "user_small", "item", "item_big", "item_small", "rating") + val msg = "either out of Integer range or contained a fractional part" withClue("fit should fail when ids exceed integer range. ") { assert(intercept[SparkException] { als.fit(df.select(df("user_big").as("user"), df("item"), df("rating"))) - }.getCause.getMessage.contains("was out of Integer range")) + }.getCause.getMessage.contains(msg)) assert(intercept[SparkException] { als.fit(df.select(df("user_small").as("user"), df("item"), df("rating"))) - }.getCause.getMessage.contains("was out of Integer range")) + }.getCause.getMessage.contains(msg)) assert(intercept[SparkException] { als.fit(df.select(df("item_big").as("item"), df("user"), df("rating"))) - }.getCause.getMessage.contains("was out of Integer range")) + }.getCause.getMessage.contains(msg)) assert(intercept[SparkException] { als.fit(df.select(df("item_small").as("item"), df("user"), df("rating"))) - }.getCause.getMessage.contains("was out of Integer range")) + }.getCause.getMessage.contains(msg)) } withClue("transform should fail when ids exceed integer range. ") { val model = als.fit(df) assert(intercept[SparkException] { model.transform(df.select(df("user_big").as("user"), df("item"))).first - }.getMessage.contains("was out of Integer range")) + }.getMessage.contains(msg)) assert(intercept[SparkException] { model.transform(df.select(df("user_small").as("user"), df("item"))).first - }.getMessage.contains("was out of Integer range")) + }.getMessage.contains(msg)) assert(intercept[SparkException] { model.transform(df.select(df("item_big").as("item"), df("user"))).first - }.getMessage.contains("was out of Integer range")) + }.getMessage.contains(msg)) assert(intercept[SparkException] { model.transform(df.select(df("item_small").as("item"), df("user"))).first - }.getMessage.contains("was out of Integer range")) + }.getMessage.contains(msg)) } } @@ -547,6 +613,53 @@ class ALSSuite ALS.train(ratings) } } + + test("ALS cold start user/item prediction strategy") { + val spark = this.spark + import spark.implicits._ + import org.apache.spark.sql.functions._ + + val (ratings, _) = genExplicitTestData(numUsers = 4, numItems = 4, rank = 1) + val data = ratings.toDF + val knownUser = data.select(max("user")).as[Int].first() + val unknownUser = knownUser + 10 + val knownItem = data.select(max("item")).as[Int].first() + val unknownItem = knownItem + 20 + val test = Seq( + (unknownUser, unknownItem), + (knownUser, unknownItem), + (unknownUser, knownItem), + (knownUser, knownItem) + ).toDF("user", "item") + + val als = new ALS().setMaxIter(1).setRank(1) + // default is 'nan' + val defaultModel = als.fit(data) + val defaultPredictions = defaultModel.transform(test).select("prediction").as[Float].collect() + assert(defaultPredictions.length == 4) + assert(defaultPredictions.slice(0, 3).forall(_.isNaN)) + 
assert(!defaultPredictions.last.isNaN) + + // check 'drop' strategy should filter out rows with unknown users/items + val dropPredictions = defaultModel + .setColdStartStrategy("drop") + .transform(test) + .select("prediction").as[Float].collect() + assert(dropPredictions.length == 1) + assert(!dropPredictions.head.isNaN) + assert(dropPredictions.head ~== defaultPredictions.last relTol 1e-14) + } + + test("case insensitive cold start param value") { + val spark = this.spark + import spark.implicits._ + val (ratings, _) = genExplicitTestData(numUsers = 2, numItems = 2, rank = 1) + val data = ratings.toDF + val model = new ALS().fit(data) + Seq("nan", "NaN", "Nan", "drop", "DROP", "Drop").foreach { s => + model.setColdStartStrategy(s).transform(data) + } + } } class ALSCleanerSuite extends SparkFunSuite { diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala index 0fdfdf37cf38..3cd4b0ac308e 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala @@ -27,6 +27,8 @@ import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.random.{ExponentialGenerator, WeibullGenerator} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.functions.{col, lit} +import org.apache.spark.sql.types._ class AFTSurvivalRegressionSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { @@ -352,7 +354,7 @@ class AFTSurvivalRegressionSuite } } - test("should support all NumericType labels") { + test("should support all NumericType labels, and not support other types") { val aft = new AFTSurvivalRegression().setMaxIter(1) MLTestingUtils.checkNumericTypes[AFTSurvivalRegressionModel, AFTSurvivalRegression]( aft, spark, isClassification = false) { (expected, actual) => @@ -361,6 +363,36 @@ class AFTSurvivalRegressionSuite } } + test("should support all NumericType censors, and not support other types") { + val df = spark.createDataFrame(Seq( + (0, Vectors.dense(0)), + (1, Vectors.dense(1)), + (2, Vectors.dense(2)), + (3, Vectors.dense(3)), + (4, Vectors.dense(4)) + )).toDF("label", "features") + .withColumn("censor", lit(0.0)) + val aft = new AFTSurvivalRegression().setMaxIter(1) + val expected = aft.fit(df) + + val types = Seq(ShortType, LongType, IntegerType, FloatType, ByteType, DecimalType(10, 0)) + types.foreach { t => + val actual = aft.fit(df.select(col("label"), col("features"), + col("censor").cast(t))) + assert(expected.intercept === actual.intercept) + assert(expected.coefficients === actual.coefficients) + } + + val dfWithStringCensors = spark.createDataFrame(Seq( + (0, Vectors.dense(0, 2, 3), "0") + )).toDF("label", "features", "censor") + val thrown = intercept[IllegalArgumentException] { + aft.fit(dfWithStringCensors) + } + assert(thrown.getMessage.contains( + "Column censor must be of type NumericType but was actually of type StringType")) + } + test("numerical stability of standardization") { val trainer = new AFTSurvivalRegression() val model1 = trainer.fit(datasetUnivariate) diff --git a/pom.xml b/pom.xml index 60e4c7269eaf..c1174593c192 100644 --- a/pom.xml +++ b/pom.xml @@ -145,7 +145,9 @@ 1.7.7 hadoop2 0.9.3 - 1.6.2 + 1.7.3 + + 1.11.76 0.10.2 diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 
9d359427f27a..56b8c0b95e8a 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -36,6 +36,10 @@ object MimaExcludes { // Exclude rules for 2.2.x lazy val v22excludes = v21excludes ++ Seq( + // [SPARK-19652][UI] Do auth checks for REST API access. + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.history.HistoryServer.withSparkUI"), + ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.status.api.v1.UIRootFromServletContext"), + // [SPARK-18663][SQL] Simplify CountMinSketch aggregate implementation ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.util.sketch.CountMinSketch.toByteArray"), @@ -51,6 +55,9 @@ object MimaExcludes { // [SPARK-14272][ML] Add logLikelihood in GaussianMixtureSummary ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.clustering.GaussianMixtureSummary.this"), + // [SPARK-19267] Fetch Failure handling robust to user error handling + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.TaskContext.setFetchFailed"), + // [SPARK-19069] [CORE] Expose task 'status' and 'duration' in spark history server REST API. ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.TaskData.this"), ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.TaskData.$default$10"), diff --git a/python/pyspark/__init__.py b/python/pyspark/__init__.py index 9331e74eede5..14c51a306e1c 100644 --- a/python/pyspark/__init__.py +++ b/python/pyspark/__init__.py @@ -93,13 +93,15 @@ def keyword_only(func): """ A decorator that forces keyword arguments in the wrapped method and saves actual input keyword arguments in `_input_kwargs`. + + .. note:: Should only be used to wrap a method where first arg is `self` """ @wraps(func) - def wrapper(*args, **kwargs): - if len(args) > 1: + def wrapper(self, *args, **kwargs): + if len(args) > 0: raise TypeError("Method %s forces keyword arguments." 
% func.__name__) - wrapper._input_kwargs = kwargs - return func(*args, **kwargs) + self._input_kwargs = kwargs + return func(self, **kwargs) return wrapper diff --git a/python/pyspark/context.py b/python/pyspark/context.py index ac4b2b035f5c..2961cda553d6 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -173,10 +173,8 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize, if k.startswith("spark.executorEnv."): varName = k[len("spark.executorEnv."):] self.environment[varName] = v - if sys.version >= '3.3' and 'PYTHONHASHSEED' not in os.environ: - # disable randomness of hash of string in worker, if this is not - # launched by spark-submit - self.environment["PYTHONHASHSEED"] = "0" + + self.environment["PYTHONHASHSEED"] = os.environ.get("PYTHONHASHSEED", "0") # Create the Java SparkContext through Py4J self._jsc = jsc or self._initialize_context(self._conf._jconf) diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index ac40fceaf8e9..b4fc357e42d7 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -124,7 +124,7 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred "org.apache.spark.ml.classification.LinearSVC", self.uid) self._setDefault(maxIter=100, regParam=0.0, tol=1e-6, fitIntercept=True, standardization=True, threshold=0.0, aggregationDepth=2) - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -140,7 +140,7 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre aggregationDepth=2): Sets params for Linear SVM Classifier. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) def _create_model(self, java_model): @@ -266,7 +266,7 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred self._java_obj = self._new_java_obj( "org.apache.spark.ml.classification.LogisticRegression", self.uid) self._setDefault(maxIter=100, regParam=0.0, tol=1E-6, threshold=0.5, family="auto") - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) self._checkThresholdConsistency() @@ -286,7 +286,7 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre Sets params for logistic regression. If the threshold and thresholds Params are both set, they must be equivalent. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs self._set(**kwargs) self._checkThresholdConsistency() return self @@ -760,7 +760,7 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini") - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -778,7 +778,7 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre seed=None) Sets params for the DecisionTreeClassifier. 
""" - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) def _create_model(self, java_model): @@ -890,7 +890,7 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini", numTrees=20, featureSubsetStrategy="auto", subsamplingRate=1.0) - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -908,7 +908,7 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre impurity="gini", numTrees=20, featureSubsetStrategy="auto", subsamplingRate=1.0) Sets params for linear classification. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) def _create_model(self, java_model): @@ -1031,7 +1031,7 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, lossType="logistic", maxIter=20, stepSize=0.1, subsamplingRate=1.0) - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -1047,7 +1047,7 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre lossType="logistic", maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0) Sets params for Gradient Boosted Tree Classification. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) def _create_model(self, java_model): @@ -1174,7 +1174,7 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred self._java_obj = self._new_java_obj( "org.apache.spark.ml.classification.NaiveBayes", self.uid) self._setDefault(smoothing=1.0, modelType="multinomial") - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -1188,7 +1188,7 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre modelType="multinomial", thresholds=None, weightCol=None) Sets params for Naive Bayes. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) def _create_model(self, java_model): @@ -1329,7 +1329,7 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred self._java_obj = self._new_java_obj( "org.apache.spark.ml.classification.MultilayerPerceptronClassifier", self.uid) self._setDefault(maxIter=100, tol=1E-4, blockSize=128, stepSize=0.03, solver="l-bfgs") - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -1343,7 +1343,7 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre solver="l-bfgs", initialWeights=None) Sets params for MultilayerPerceptronClassifier. 
""" - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) def _create_model(self, java_model): @@ -1519,7 +1519,7 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred classifier=None) """ super(OneVsRest, self).__init__() - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self._set(**kwargs) @keyword_only @@ -1529,7 +1529,7 @@ def setParams(self, featuresCol=None, labelCol=None, predictionCol=None, classif setParams(self, featuresCol=None, labelCol=None, predictionCol=None, classifier=None): Sets params for OneVsRest. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) def _fit(self, dataset): diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py index c6c1a0033190..88ac7e275e38 100644 --- a/python/pyspark/ml/clustering.py +++ b/python/pyspark/ml/clustering.py @@ -224,7 +224,7 @@ def __init__(self, featuresCol="features", predictionCol="prediction", k=2, self._java_obj = self._new_java_obj("org.apache.spark.ml.clustering.GaussianMixture", self.uid) self._setDefault(k=2, tol=0.01, maxIter=100) - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) def _create_model(self, java_model): @@ -240,7 +240,7 @@ def setParams(self, featuresCol="features", predictionCol="prediction", k=2, Sets params for GaussianMixture. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) @since("2.0.0") @@ -414,7 +414,7 @@ def __init__(self, featuresCol="features", predictionCol="prediction", k=2, super(KMeans, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.clustering.KMeans", self.uid) self._setDefault(k=2, initMode="k-means||", initSteps=2, tol=1e-4, maxIter=20) - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) def _create_model(self, java_model): @@ -430,7 +430,7 @@ def setParams(self, featuresCol="features", predictionCol="prediction", k=2, Sets params for KMeans. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) @since("1.5.0") @@ -591,7 +591,7 @@ def __init__(self, featuresCol="features", predictionCol="prediction", maxIter=2 self._java_obj = self._new_java_obj("org.apache.spark.ml.clustering.BisectingKMeans", self.uid) self._setDefault(maxIter=20, k=4, minDivisibleClusterSize=1.0) - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -603,7 +603,7 @@ def setParams(self, featuresCol="features", predictionCol="prediction", maxIter= seed=None, k=4, minDivisibleClusterSize=1.0) Sets params for BisectingKMeans. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) @since("2.0.0") @@ -916,7 +916,7 @@ def __init__(self, featuresCol="features", maxIter=20, seed=None, checkpointInte k=10, optimizer="online", learningOffset=1024.0, learningDecay=0.51, subsamplingRate=0.05, optimizeDocConcentration=True, topicDistributionCol="topicDistribution", keepLastCheckpoint=True) - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) def _create_model(self, java_model): @@ -941,7 +941,7 @@ def setParams(self, featuresCol="features", maxIter=20, seed=None, checkpointInt Sets params for LDA. 
""" - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) @since("2.0.0") diff --git a/python/pyspark/ml/evaluation.py b/python/pyspark/ml/evaluation.py index 7aa16fa5b90f..7cb8d62f212c 100644 --- a/python/pyspark/ml/evaluation.py +++ b/python/pyspark/ml/evaluation.py @@ -148,7 +148,7 @@ def __init__(self, rawPredictionCol="rawPrediction", labelCol="label", "org.apache.spark.ml.evaluation.BinaryClassificationEvaluator", self.uid) self._setDefault(rawPredictionCol="rawPrediction", labelCol="label", metricName="areaUnderROC") - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self._set(**kwargs) @since("1.4.0") @@ -174,7 +174,7 @@ def setParams(self, rawPredictionCol="rawPrediction", labelCol="label", metricName="areaUnderROC") Sets params for binary classification evaluator. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) @@ -226,7 +226,7 @@ def __init__(self, predictionCol="prediction", labelCol="label", "org.apache.spark.ml.evaluation.RegressionEvaluator", self.uid) self._setDefault(predictionCol="prediction", labelCol="label", metricName="rmse") - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self._set(**kwargs) @since("1.4.0") @@ -252,7 +252,7 @@ def setParams(self, predictionCol="prediction", labelCol="label", metricName="rmse") Sets params for regression evaluator. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) @@ -299,7 +299,7 @@ def __init__(self, predictionCol="prediction", labelCol="label", "org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator", self.uid) self._setDefault(predictionCol="prediction", labelCol="label", metricName="f1") - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self._set(**kwargs) @since("1.5.0") @@ -325,7 +325,7 @@ def setParams(self, predictionCol="prediction", labelCol="label", metricName="f1") Sets params for multiclass classification evaluator. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) if __name__ == "__main__": diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index c2eafbefcdec..92f8549e9cb9 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -94,7 +94,7 @@ def __init__(self, threshold=0.0, inputCol=None, outputCol=None): super(Binarizer, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.Binarizer", self.uid) self._setDefault(threshold=0.0) - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -104,7 +104,7 @@ def setParams(self, threshold=0.0, inputCol=None, outputCol=None): setParams(self, threshold=0.0, inputCol=None, outputCol=None) Sets params for this Binarizer. 
""" - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) @since("1.4.0") @@ -258,14 +258,14 @@ class BucketedRandomProjectionLSH(JavaEstimator, LSHParams, HasInputCol, HasOutp def __init__(self, inputCol=None, outputCol=None, seed=None, numHashTables=1, bucketLength=None): """ - __init__(self, inputCol=None, outputCol=None, seed=None, numHashTables=1, + __init__(self, inputCol=None, outputCol=None, seed=None, numHashTables=1, \ bucketLength=None) """ super(BucketedRandomProjectionLSH, self).__init__() self._java_obj = \ self._new_java_obj("org.apache.spark.ml.feature.BucketedRandomProjectionLSH", self.uid) self._setDefault(numHashTables=1) - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -277,7 +277,7 @@ def setParams(self, inputCol=None, outputCol=None, seed=None, numHashTables=1, bucketLength=None) Sets params for this BucketedRandomProjectionLSH. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) @since("2.2.0") @@ -370,7 +370,7 @@ def __init__(self, splits=None, inputCol=None, outputCol=None, handleInvalid="er super(Bucketizer, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.Bucketizer", self.uid) self._setDefault(handleInvalid="error") - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -380,7 +380,7 @@ def setParams(self, splits=None, inputCol=None, outputCol=None, handleInvalid="e setParams(self, splits=None, inputCol=None, outputCol=None, handleInvalid="error") Sets params for this Bucketizer. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) @since("1.4.0") @@ -484,7 +484,7 @@ def __init__(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, binary=False, inputC self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.CountVectorizer", self.uid) self._setDefault(minTF=1.0, minDF=1.0, vocabSize=1 << 18, binary=False) - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -496,7 +496,7 @@ def setParams(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, binary=False, input outputCol=None) Set the params for the CountVectorizer """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) @since("1.6.0") @@ -616,7 +616,7 @@ def __init__(self, inverse=False, inputCol=None, outputCol=None): super(DCT, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.DCT", self.uid) self._setDefault(inverse=False) - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -626,7 +626,7 @@ def setParams(self, inverse=False, inputCol=None, outputCol=None): setParams(self, inverse=False, inputCol=None, outputCol=None) Sets params for this DCT. 
""" - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) @since("1.6.0") @@ -680,7 +680,7 @@ def __init__(self, scalingVec=None, inputCol=None, outputCol=None): super(ElementwiseProduct, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.ElementwiseProduct", self.uid) - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -690,7 +690,7 @@ def setParams(self, scalingVec=None, inputCol=None, outputCol=None): setParams(self, scalingVec=None, inputCol=None, outputCol=None) Sets params for this ElementwiseProduct. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) @since("2.0.0") @@ -750,7 +750,7 @@ def __init__(self, numFeatures=1 << 18, binary=False, inputCol=None, outputCol=N super(HashingTF, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.HashingTF", self.uid) self._setDefault(numFeatures=1 << 18, binary=False) - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -760,7 +760,7 @@ def setParams(self, numFeatures=1 << 18, binary=False, inputCol=None, outputCol= setParams(self, numFeatures=1 << 18, binary=False, inputCol=None, outputCol=None) Sets params for this HashingTF. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) @since("2.0.0") @@ -823,7 +823,7 @@ def __init__(self, minDocFreq=0, inputCol=None, outputCol=None): super(IDF, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.IDF", self.uid) self._setDefault(minDocFreq=0) - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -833,7 +833,7 @@ def setParams(self, minDocFreq=0, inputCol=None, outputCol=None): setParams(self, minDocFreq=0, inputCol=None, outputCol=None) Sets params for this IDF. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) @since("1.4.0") @@ -913,7 +913,7 @@ def __init__(self, inputCol=None, outputCol=None): super(MaxAbsScaler, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.MaxAbsScaler", self.uid) self._setDefault() - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -923,7 +923,7 @@ def setParams(self, inputCol=None, outputCol=None): setParams(self, inputCol=None, outputCol=None) Sets params for this MaxAbsScaler. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) def _create_model(self, java_model): @@ -1011,7 +1011,7 @@ def __init__(self, inputCol=None, outputCol=None, seed=None, numHashTables=1): super(MinHashLSH, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.MinHashLSH", self.uid) self._setDefault(numHashTables=1) - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -1021,7 +1021,7 @@ def setParams(self, inputCol=None, outputCol=None, seed=None, numHashTables=1): setParams(self, inputCol=None, outputCol=None, seed=None, numHashTables=1) Sets params for this MinHashLSH. 
""" - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) def _create_model(self, java_model): @@ -1106,7 +1106,7 @@ def __init__(self, min=0.0, max=1.0, inputCol=None, outputCol=None): super(MinMaxScaler, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.MinMaxScaler", self.uid) self._setDefault(min=0.0, max=1.0) - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -1116,7 +1116,7 @@ def setParams(self, min=0.0, max=1.0, inputCol=None, outputCol=None): setParams(self, min=0.0, max=1.0, inputCol=None, outputCol=None) Sets params for this MinMaxScaler. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) @since("1.6.0") @@ -1224,7 +1224,7 @@ def __init__(self, n=2, inputCol=None, outputCol=None): super(NGram, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.NGram", self.uid) self._setDefault(n=2) - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -1234,7 +1234,7 @@ def setParams(self, n=2, inputCol=None, outputCol=None): setParams(self, n=2, inputCol=None, outputCol=None) Sets params for this NGram. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) @since("1.5.0") @@ -1288,7 +1288,7 @@ def __init__(self, p=2.0, inputCol=None, outputCol=None): super(Normalizer, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.Normalizer", self.uid) self._setDefault(p=2.0) - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -1298,7 +1298,7 @@ def setParams(self, p=2.0, inputCol=None, outputCol=None): setParams(self, p=2.0, inputCol=None, outputCol=None) Sets params for this Normalizer. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) @since("1.4.0") @@ -1363,12 +1363,12 @@ class OneHotEncoder(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, @keyword_only def __init__(self, dropLast=True, inputCol=None, outputCol=None): """ - __init__(self, includeFirst=True, inputCol=None, outputCol=None) + __init__(self, dropLast=True, inputCol=None, outputCol=None) """ super(OneHotEncoder, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.OneHotEncoder", self.uid) self._setDefault(dropLast=True) - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -1378,7 +1378,7 @@ def setParams(self, dropLast=True, inputCol=None, outputCol=None): setParams(self, dropLast=True, inputCol=None, outputCol=None) Sets params for this OneHotEncoder. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) @since("1.4.0") @@ -1434,7 +1434,7 @@ def __init__(self, degree=2, inputCol=None, outputCol=None): self._java_obj = self._new_java_obj( "org.apache.spark.ml.feature.PolynomialExpansion", self.uid) self._setDefault(degree=2) - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -1444,7 +1444,7 @@ def setParams(self, degree=2, inputCol=None, outputCol=None): setParams(self, degree=2, inputCol=None, outputCol=None) Sets params for this PolynomialExpansion. 
""" - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) @since("1.4.0") @@ -1540,7 +1540,7 @@ def __init__(self, numBuckets=2, inputCol=None, outputCol=None, relativeError=0. self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.QuantileDiscretizer", self.uid) self._setDefault(numBuckets=2, relativeError=0.001, handleInvalid="error") - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -1552,7 +1552,7 @@ def setParams(self, numBuckets=2, inputCol=None, outputCol=None, relativeError=0 handleInvalid="error") Set the params for the QuantileDiscretizer """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) @since("2.0.0") @@ -1665,7 +1665,7 @@ def __init__(self, minTokenLength=1, gaps=True, pattern="\\s+", inputCol=None, super(RegexTokenizer, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.RegexTokenizer", self.uid) self._setDefault(minTokenLength=1, gaps=True, pattern="\\s+", toLowercase=True) - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -1677,7 +1677,7 @@ def setParams(self, minTokenLength=1, gaps=True, pattern="\\s+", inputCol=None, outputCol=None, toLowercase=True) Sets params for this RegexTokenizer. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) @since("1.4.0") @@ -1768,7 +1768,7 @@ def __init__(self, statement=None): """ super(SQLTransformer, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.SQLTransformer", self.uid) - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -1778,7 +1778,7 @@ def setParams(self, statement=None): setParams(self, statement=None) Sets params for this SQLTransformer. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) @since("1.6.0") @@ -1847,7 +1847,7 @@ def __init__(self, withMean=False, withStd=True, inputCol=None, outputCol=None): super(StandardScaler, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.StandardScaler", self.uid) self._setDefault(withMean=False, withStd=True) - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -1857,7 +1857,7 @@ def setParams(self, withMean=False, withStd=True, inputCol=None, outputCol=None) setParams(self, withMean=False, withStd=True, inputCol=None, outputCol=None) Sets params for this StandardScaler. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) @since("1.4.0") @@ -1963,7 +1963,7 @@ def __init__(self, inputCol=None, outputCol=None, handleInvalid="error"): super(StringIndexer, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.StringIndexer", self.uid) self._setDefault(handleInvalid="error") - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -1973,7 +1973,7 @@ def setParams(self, inputCol=None, outputCol=None, handleInvalid="error"): setParams(self, inputCol=None, outputCol=None, handleInvalid="error") Sets params for this StringIndexer. 
""" - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) def _create_model(self, java_model): @@ -2021,7 +2021,7 @@ def __init__(self, inputCol=None, outputCol=None, labels=None): super(IndexToString, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.IndexToString", self.uid) - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -2031,7 +2031,7 @@ def setParams(self, inputCol=None, outputCol=None, labels=None): setParams(self, inputCol=None, outputCol=None, labels=None) Sets params for this IndexToString. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) @since("1.6.0") @@ -2085,7 +2085,7 @@ def __init__(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive= self.uid) self._setDefault(stopWords=StopWordsRemover.loadDefaultStopWords("english"), caseSensitive=False) - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -2095,7 +2095,7 @@ def setParams(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive setParams(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive=false) Sets params for this StopWordRemover. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) @since("1.6.0") @@ -2178,7 +2178,7 @@ def __init__(self, inputCol=None, outputCol=None): """ super(Tokenizer, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.Tokenizer", self.uid) - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -2188,7 +2188,7 @@ def setParams(self, inputCol=None, outputCol=None): setParams(self, inputCol=None, outputCol=None) Sets params for this Tokenizer. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) @@ -2222,7 +2222,7 @@ def __init__(self, inputCols=None, outputCol=None): """ super(VectorAssembler, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.VectorAssembler", self.uid) - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -2232,7 +2232,7 @@ def setParams(self, inputCols=None, outputCol=None): setParams(self, inputCols=None, outputCol=None) Sets params for this VectorAssembler. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) @@ -2320,7 +2320,7 @@ def __init__(self, maxCategories=20, inputCol=None, outputCol=None): super(VectorIndexer, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.VectorIndexer", self.uid) self._setDefault(maxCategories=20) - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -2330,7 +2330,7 @@ def setParams(self, maxCategories=20, inputCol=None, outputCol=None): setParams(self, maxCategories=20, inputCol=None, outputCol=None) Sets params for this VectorIndexer. 
""" - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) @since("1.4.0") @@ -2435,7 +2435,7 @@ def __init__(self, inputCol=None, outputCol=None, indices=None, names=None): super(VectorSlicer, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.VectorSlicer", self.uid) self._setDefault(indices=[], names=[]) - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -2445,7 +2445,7 @@ def setParams(self, inputCol=None, outputCol=None, indices=None, names=None): setParams(self, inputCol=None, outputCol=None, indices=None, names=None): Sets params for this VectorSlicer. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) @since("1.6.0") @@ -2558,7 +2558,7 @@ def __init__(self, vectorSize=100, minCount=5, numPartitions=1, stepSize=0.025, self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.Word2Vec", self.uid) self._setDefault(vectorSize=100, minCount=5, numPartitions=1, stepSize=0.025, maxIter=1, windowSize=5, maxSentenceLength=1000) - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -2570,7 +2570,7 @@ def setParams(self, vectorSize=100, minCount=5, numPartitions=1, stepSize=0.025, inputCol=None, outputCol=None, windowSize=5, maxSentenceLength=1000) Sets params for this Word2Vec. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) @since("1.4.0") @@ -2718,7 +2718,7 @@ def __init__(self, k=None, inputCol=None, outputCol=None): """ super(PCA, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.PCA", self.uid) - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -2728,7 +2728,7 @@ def setParams(self, k=None, inputCol=None, outputCol=None): setParams(self, k=None, inputCol=None, outputCol=None) Set params for this PCA. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) @since("1.5.0") @@ -2858,7 +2858,7 @@ def __init__(self, formula=None, featuresCol="features", labelCol="label", super(RFormula, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.RFormula", self.uid) self._setDefault(forceIndexLabel=False) - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -2870,7 +2870,7 @@ def setParams(self, formula=None, featuresCol="features", labelCol="label", forceIndexLabel=False) Sets params for RFormula. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) @since("1.5.0") @@ -3017,7 +3017,7 @@ def __init__(self, numTopFeatures=50, featuresCol="features", outputCol=None, self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.ChiSqSelector", self.uid) self._setDefault(numTopFeatures=50, selectorType="numTopFeatures", percentile=0.1, fpr=0.05, fdr=0.05, fwe=0.05) - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -3031,7 +3031,7 @@ def setParams(self, numTopFeatures=50, featuresCol="features", outputCol=None, fdr=0.05, fwe=0.05) Sets params for this ChiSqSelector. 
""" - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) @since("2.1.0") diff --git a/python/pyspark/ml/param/__init__.py b/python/pyspark/ml/param/__init__.py index dc3d23ff1661..99d8fa3a5b73 100644 --- a/python/pyspark/ml/param/__init__.py +++ b/python/pyspark/ml/param/__init__.py @@ -372,6 +372,7 @@ def copy(self, extra=None): extra = dict() that = copy.copy(self) that._paramMap = {} + that._defaultParamMap = {} return self._copyValues(that, extra) def _shouldOwn(self, param): @@ -452,12 +453,16 @@ def _copyValues(self, to, extra=None): :param extra: extra params to be copied :return: the target instance with param values copied """ - if extra is None: - extra = dict() - paramMap = self.extractParamMap(extra) - for p in self.params: - if p in paramMap and to.hasParam(p.name): - to._set(**{p.name: paramMap[p]}) + paramMap = self._paramMap.copy() + if extra is not None: + paramMap.update(extra) + for param in self.params: + # copy default params + if param in self._defaultParamMap and to.hasParam(param.name): + to._defaultParamMap[to.getParam(param.name)] = self._defaultParamMap[param] + # copy explicitly set params + if param in paramMap and to.hasParam(param.name): + to._set(**{param.name: paramMap[param]}) return to def _resetUid(self, newUid): diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py index a78e3b49fbcf..4aac6a4466b5 100644 --- a/python/pyspark/ml/pipeline.py +++ b/python/pyspark/ml/pipeline.py @@ -58,7 +58,7 @@ def __init__(self, stages=None): __init__(self, stages=None) """ super(Pipeline, self).__init__() - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @since("1.3.0") @@ -85,7 +85,7 @@ def setParams(self, stages=None): setParams(self, stages=None) Sets params for Pipeline. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) def _fit(self, dataset): diff --git a/python/pyspark/ml/recommendation.py b/python/pyspark/ml/recommendation.py index e28d38bd19f8..8bc899a0788b 100644 --- a/python/pyspark/ml/recommendation.py +++ b/python/pyspark/ml/recommendation.py @@ -125,19 +125,25 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha finalStorageLevel = Param(Params._dummy(), "finalStorageLevel", "StorageLevel for ALS model factors.", typeConverter=TypeConverters.toString) + coldStartStrategy = Param(Params._dummy(), "coldStartStrategy", "strategy for dealing with " + + "unknown or new users/items at prediction time. This may be useful " + + "in cross-validation or production scenarios, for handling " + + "user/item ids the model has not seen in the training data. 
" + + "Supported values: 'nan', 'drop'.", + typeConverter=TypeConverters.toString) @keyword_only def __init__(self, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemBlocks=10, implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item", seed=None, ratingCol="rating", nonnegative=False, checkpointInterval=10, intermediateStorageLevel="MEMORY_AND_DISK", - finalStorageLevel="MEMORY_AND_DISK"): + finalStorageLevel="MEMORY_AND_DISK", coldStartStrategy="nan"): """ __init__(self, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemBlocks=10, \ implicitPrefs=false, alpha=1.0, userCol="user", itemCol="item", seed=None, \ ratingCol="rating", nonnegative=false, checkpointInterval=10, \ intermediateStorageLevel="MEMORY_AND_DISK", \ - finalStorageLevel="MEMORY_AND_DISK") + finalStorageLevel="MEMORY_AND_DISK", coldStartStrategy="nan") """ super(ALS, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.recommendation.ALS", self.uid) @@ -145,8 +151,8 @@ def __init__(self, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemB implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item", ratingCol="rating", nonnegative=False, checkpointInterval=10, intermediateStorageLevel="MEMORY_AND_DISK", - finalStorageLevel="MEMORY_AND_DISK") - kwargs = self.__init__._input_kwargs + finalStorageLevel="MEMORY_AND_DISK", coldStartStrategy="nan") + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -155,16 +161,16 @@ def setParams(self, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItem implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item", seed=None, ratingCol="rating", nonnegative=False, checkpointInterval=10, intermediateStorageLevel="MEMORY_AND_DISK", - finalStorageLevel="MEMORY_AND_DISK"): + finalStorageLevel="MEMORY_AND_DISK", coldStartStrategy="nan"): """ setParams(self, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemBlocks=10, \ implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item", seed=None, \ ratingCol="rating", nonnegative=False, checkpointInterval=10, \ intermediateStorageLevel="MEMORY_AND_DISK", \ - finalStorageLevel="MEMORY_AND_DISK") + finalStorageLevel="MEMORY_AND_DISK", coldStartStrategy="nan") Sets params for ALS. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) def _create_model(self, java_model): @@ -332,6 +338,20 @@ def getFinalStorageLevel(self): """ return self.getOrDefault(self.finalStorageLevel) + @since("2.2.0") + def setColdStartStrategy(self, value): + """ + Sets the value of :py:attr:`coldStartStrategy`. + """ + return self._set(coldStartStrategy=value) + + @since("2.2.0") + def getColdStartStrategy(self): + """ + Gets the value of coldStartStrategy or its default value. 
+ """ + return self.getOrDefault(self.coldStartStrategy) + class ALSModel(JavaModel, JavaMLWritable, JavaMLReadable): """ diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index b42e80706980..b199bf282e4f 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -108,7 +108,7 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred self._java_obj = self._new_java_obj( "org.apache.spark.ml.regression.LinearRegression", self.uid) self._setDefault(maxIter=100, regParam=0.0, tol=1e-6) - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -122,7 +122,7 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre standardization=True, solver="auto", weightCol=None, aggregationDepth=2) Sets params for linear regression. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) def _create_model(self, java_model): @@ -464,7 +464,7 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred self._java_obj = self._new_java_obj( "org.apache.spark.ml.regression.IsotonicRegression", self.uid) self._setDefault(isotonic=True, featureIndex=0) - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -475,7 +475,7 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre weightCol=None, isotonic=True, featureIndex=0): Set the params for IsotonicRegression. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) def _create_model(self, java_model): @@ -704,7 +704,7 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="variance") - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -720,7 +720,7 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre impurity="variance", seed=None, varianceCol=None) Sets params for the DecisionTreeRegressor. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) def _create_model(self, java_model): @@ -895,7 +895,7 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="variance", subsamplingRate=1.0, numTrees=20, featureSubsetStrategy="auto") - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -913,7 +913,7 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre featureSubsetStrategy="auto") Sets params for linear regression. 
""" - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) def _create_model(self, java_model): @@ -1022,7 +1022,7 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0, checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1, impurity="variance") - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -1040,7 +1040,7 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre impurity="variance") Sets params for Gradient Boosted Tree Regression. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) def _create_model(self, java_model): @@ -1171,7 +1171,7 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred self._setDefault(censorCol="censor", quantileProbabilities=[0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99], maxIter=100, tol=1E-6) - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -1186,7 +1186,7 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre quantileProbabilities=[0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99], \ quantilesCol=None, aggregationDepth=2): """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) def _create_model(self, java_model): @@ -1366,7 +1366,7 @@ def __init__(self, labelCol="label", featuresCol="features", predictionCol="pred self._java_obj = self._new_java_obj( "org.apache.spark.ml.regression.GeneralizedLinearRegression", self.uid) self._setDefault(family="gaussian", maxIter=25, tol=1e-6, regParam=0.0, solver="irls") - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -1380,7 +1380,7 @@ def setParams(self, labelCol="label", featuresCol="features", predictionCol="pre regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=None) Sets params for generalized linear regression. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) def _create_model(self, java_model): diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 53204cde29b7..352416055791 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -250,7 +250,7 @@ class TestParams(HasMaxIter, HasInputCol, HasSeed): def __init__(self, seed=None): super(TestParams, self).__init__() self._setDefault(maxIter=10) - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -259,7 +259,7 @@ def setParams(self, seed=None): setParams(self, seed=None) Sets params for this test. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) @@ -271,7 +271,7 @@ class OtherTestParams(HasMaxIter, HasInputCol, HasSeed): def __init__(self, seed=None): super(OtherTestParams, self).__init__() self._setDefault(maxIter=10) - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -280,7 +280,7 @@ def setParams(self, seed=None): setParams(self, seed=None) Sets params for this test. 
""" - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) @@ -389,6 +389,22 @@ def test_word2vec_param(self): # Check windowSize is set properly self.assertEqual(model.getWindowSize(), 6) + def test_copy_param_extras(self): + tp = TestParams(seed=42) + extra = {tp.getParam(TestParams.inputCol.name): "copy_input"} + tp_copy = tp.copy(extra=extra) + self.assertEqual(tp.uid, tp_copy.uid) + self.assertEqual(tp.params, tp_copy.params) + for k, v in extra.items(): + self.assertTrue(tp_copy.isDefined(k)) + self.assertEqual(tp_copy.getOrDefault(k), v) + copied_no_extra = {} + for k, v in tp_copy._paramMap.items(): + if k not in extra: + copied_no_extra[k] = v + self.assertEqual(tp._paramMap, copied_no_extra) + self.assertEqual(tp._defaultParamMap, tp_copy._defaultParamMap) + class EvaluatorTests(SparkSessionTestCase): diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py index 2dcc99cef8aa..ffeb4459e1aa 100644 --- a/python/pyspark/ml/tuning.py +++ b/python/pyspark/ml/tuning.py @@ -186,7 +186,7 @@ def __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numF """ super(CrossValidator, self).__init__() self._setDefault(numFolds=3) - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self._set(**kwargs) @keyword_only @@ -198,7 +198,7 @@ def setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, num seed=None): Sets params for cross validator. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) @since("1.4.0") @@ -346,7 +346,7 @@ def __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, trai """ super(TrainValidationSplit, self).__init__() self._setDefault(trainRatio=0.75) - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self._set(**kwargs) @since("2.0.0") @@ -358,7 +358,7 @@ def setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, tra seed=None): Sets params for the train validation split. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) @since("2.0.0") diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index b384b2b50733..a5e6e2b05496 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -68,7 +68,8 @@ def portable_hash(x): >>> portable_hash((None, 1)) & 0xffffffff 219750521 """ - if sys.version >= '3.3' and 'PYTHONHASHSEED' not in os.environ: + + if sys.version_info >= (3, 2, 3) and 'PYTHONHASHSEED' not in os.environ: raise Exception("Randomness of hash of string should be disabled via PYTHONHASHSEED") if x is None: diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py index 0df187a9d3c3..c10ab9638a21 100644 --- a/python/pyspark/sql/column.py +++ b/python/pyspark/sql/column.py @@ -248,6 +248,7 @@ def __iter__(self): raise TypeError("Column is not iterable") # string methods + contains = _bin_op("contains") rlike = _bin_op("rlike") like = _bin_op("like") startswith = _bin_op("startsWith") diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 70efeaf0160c..bb6df2268209 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1158,6 +1158,12 @@ def dropDuplicates(self, subset=None): """Return a new :class:`DataFrame` with duplicate rows removed, optionally only considering certain columns. + For a static batch :class:`DataFrame`, it just drops duplicate rows. 
For a streaming + :class:`DataFrame`, it will keep all data across triggers as intermediate state to drop + duplicates rows. You can use :func:`withWatermark` to limit how late the duplicate data can + be and system will accordingly limit the state. In addition, too late data older than + watermark will be dropped to avoid any possibility of duplicates. + :func:`drop_duplicates` is an alias for :func:`dropDuplicates`. >>> from pyspark.sql import Row diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index d2617203140f..426a4a8c93a6 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1940,7 +1940,16 @@ def udf(f=None, returnType=StringType()): +----------+--------------+------------+ """ def _udf(f, returnType=StringType()): - return UserDefinedFunction(f, returnType) + udf_obj = UserDefinedFunction(f, returnType) + + @functools.wraps(f) + def wrapper(*args): + return udf_obj(*args) + + wrapper.func = udf_obj.func + wrapper.returnType = udf_obj.returnType + + return wrapper # decorator @udf, @udf() or @udf(dataType()) if f is None or isinstance(f, (str, DataType)): diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 6bed390e60c9..45fb9b759152 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -163,8 +163,8 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, """ Loads a JSON file and returns the results as a :class:`DataFrame`. - Both JSON (one record per file) and `JSON Lines `_ - (newline-delimited JSON) are supported and can be selected with the `wholeFile` parameter. + `JSON Lines `_(newline-delimited JSON) is supported by default. + For JSON (one record per file), set the `wholeFile` parameter to ``true``. If the ``schema`` parameter is not specified, this function goes through the input once to determine the input schema. @@ -191,10 +191,13 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, :param mode: allows a mode for dealing with corrupt records during parsing. If None is set, it uses the default value, ``PERMISSIVE``. - * ``PERMISSIVE`` : sets other fields to ``null`` when it meets a corrupted \ - record and puts the malformed string into a new field configured by \ - ``columnNameOfCorruptRecord``. When a schema is set by user, it sets \ - ``null`` for extra fields. + * ``PERMISSIVE`` : sets other fields to ``null`` when it meets a corrupted \ + record, and puts the malformed string into a field configured by \ + ``columnNameOfCorruptRecord``. To keep corrupt records, an user can set \ + a string type field named ``columnNameOfCorruptRecord`` in an user-defined \ + schema. If a schema does not have the field, it drops corrupt records during \ + parsing. When inferring a schema, it implicitly adds a \ + ``columnNameOfCorruptRecord`` field in an output schema. * ``DROPMALFORMED`` : ignores the whole corrupted records. * ``FAILFAST`` : throws an exception when it meets corrupted records. 
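Editor's note (not part of the patch): the docstring changes above describe how the new `wholeFile` option and the `PERMISSIVE` / `columnNameOfCorruptRecord` behavior are meant to combine when reading JSON. The sketch below is a minimal, hypothetical illustration of that usage against the 2.2-era `DataFrameReader.json` signature shown in this diff; the file path and the schema field names are made up for illustration only.

```python
# Minimal sketch: read multi-line JSON (one record per file) and keep
# malformed records in a corrupt-record column, as documented above.
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, StringType

spark = SparkSession.builder.getOrCreate()

# To keep corrupt records under PERMISSIVE mode, the user-defined schema
# must contain a string field whose name matches columnNameOfCorruptRecord.
schema = StructType([
    StructField("name", StringType(), True),            # illustrative field
    StructField("_corrupt_record", StringType(), True),  # holds malformed input
])

df = spark.read.json(
    "path/to/file.json",                        # hypothetical path
    schema=schema,
    wholeFile=True,                             # JSON (one record per file) instead of JSON Lines
    mode="PERMISSIVE",                          # default: null out fields, keep the bad string
    columnNameOfCorruptRecord="_corrupt_record",
)
```

Rows that fail to parse end up with `null` data columns and the raw text in `_corrupt_record`; omitting that field from the schema would instead drop corrupt records, per the docstring above.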
@@ -304,7 +307,8 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non comment=None, header=None, inferSchema=None, ignoreLeadingWhiteSpace=None, ignoreTrailingWhiteSpace=None, nullValue=None, nanValue=None, positiveInf=None, negativeInf=None, dateFormat=None, timestampFormat=None, maxColumns=None, - maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None, timeZone=None): + maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None, timeZone=None, + columnNameOfCorruptRecord=None, wholeFile=None): """Loads a CSV file and returns the result as a :class:`DataFrame`. This function will go through the input once to determine the input schema if @@ -366,11 +370,24 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non :param timeZone: sets the string that indicates a timezone to be used to parse timestamps. If None is set, it uses the default value, session local timezone. - * ``PERMISSIVE`` : sets other fields to ``null`` when it meets a corrupted record. - When a schema is set by user, it sets ``null`` for extra fields. + * ``PERMISSIVE`` : sets other fields to ``null`` when it meets a corrupted \ + record, and puts the malformed string into a field configured by \ + ``columnNameOfCorruptRecord``. To keep corrupt records, an user can set \ + a string type field named ``columnNameOfCorruptRecord`` in an \ + user-defined schema. If a schema does not have the field, it drops corrupt \ + records during parsing. When a length of parsed CSV tokens is shorter than \ + an expected length of a schema, it sets `null` for extra fields. * ``DROPMALFORMED`` : ignores the whole corrupted records. * ``FAILFAST`` : throws an exception when it meets corrupted records. + :param columnNameOfCorruptRecord: allows renaming the new field having malformed string + created by ``PERMISSIVE`` mode. This overrides + ``spark.sql.columnNameOfCorruptRecord``. If None is set, + it uses the value specified in + ``spark.sql.columnNameOfCorruptRecord``. + :param wholeFile: parse records, which may span multiple lines. If None is + set, it uses the default value, ``false``. + >>> df = spark.read.csv('python/test_support/sql/ages.csv') >>> df.dtypes [('_c0', 'string'), ('_c1', 'string')] @@ -382,7 +399,8 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non nanValue=nanValue, positiveInf=positiveInf, negativeInf=negativeInf, dateFormat=dateFormat, timestampFormat=timestampFormat, maxColumns=maxColumns, maxCharsPerColumn=maxCharsPerColumn, - maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode, timeZone=timeZone) + maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode, timeZone=timeZone, + columnNameOfCorruptRecord=columnNameOfCorruptRecord, wholeFile=wholeFile) if isinstance(path, basestring): path = [path] return self._df(self._jreader.csv(self._spark._sc._jvm.PythonUtils.toSeq(path))) diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index 965c8f6b269e..625fb9ba385a 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -433,8 +433,8 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, """ Loads a JSON file stream and returns the results as a :class:`DataFrame`. - Both JSON (one record per file) and `JSON Lines `_ - (newline-delimited JSON) are supported and can be selected with the `wholeFile` parameter. + `JSON Lines `_(newline-delimited JSON) is supported by default. 
+ For JSON (one record per file), set the `wholeFile` parameter to ``true``. If the ``schema`` parameter is not specified, this function goes through the input once to determine the input schema. @@ -463,10 +463,13 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, :param mode: allows a mode for dealing with corrupt records during parsing. If None is set, it uses the default value, ``PERMISSIVE``. - * ``PERMISSIVE`` : sets other fields to ``null`` when it meets a corrupted \ - record and puts the malformed string into a new field configured by \ - ``columnNameOfCorruptRecord``. When a schema is set by user, it sets \ - ``null`` for extra fields. + * ``PERMISSIVE`` : sets other fields to ``null`` when it meets a corrupted \ + record, and puts the malformed string into a field configured by \ + ``columnNameOfCorruptRecord``. To keep corrupt records, an user can set \ + a string type field named ``columnNameOfCorruptRecord`` in an user-defined \ + schema. If a schema does not have the field, it drops corrupt records during \ + parsing. When inferring a schema, it implicitly adds a \ + ``columnNameOfCorruptRecord`` field in an output schema. * ``DROPMALFORMED`` : ignores the whole corrupted records. * ``FAILFAST`` : throws an exception when it meets corrupted records. @@ -558,7 +561,8 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non comment=None, header=None, inferSchema=None, ignoreLeadingWhiteSpace=None, ignoreTrailingWhiteSpace=None, nullValue=None, nanValue=None, positiveInf=None, negativeInf=None, dateFormat=None, timestampFormat=None, maxColumns=None, - maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None, timeZone=None): + maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None, timeZone=None, + columnNameOfCorruptRecord=None, wholeFile=None): """Loads a CSV file stream and returns the result as a :class:`DataFrame`. This function will go through the input once to determine the input schema if @@ -618,11 +622,24 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non :param timeZone: sets the string that indicates a timezone to be used to parse timestamps. If None is set, it uses the default value, session local timezone. - * ``PERMISSIVE`` : sets other fields to ``null`` when it meets a corrupted record. - When a schema is set by user, it sets ``null`` for extra fields. + * ``PERMISSIVE`` : sets other fields to ``null`` when it meets a corrupted \ + record, and puts the malformed string into a field configured by \ + ``columnNameOfCorruptRecord``. To keep corrupt records, an user can set \ + a string type field named ``columnNameOfCorruptRecord`` in an \ + user-defined schema. If a schema does not have the field, it drops corrupt \ + records during parsing. When a length of parsed CSV tokens is shorter than \ + an expected length of a schema, it sets `null` for extra fields. * ``DROPMALFORMED`` : ignores the whole corrupted records. * ``FAILFAST`` : throws an exception when it meets corrupted records. + :param columnNameOfCorruptRecord: allows renaming the new field having malformed string + created by ``PERMISSIVE`` mode. This overrides + ``spark.sql.columnNameOfCorruptRecord``. If None is set, + it uses the value specified in + ``spark.sql.columnNameOfCorruptRecord``. + :param wholeFile: parse one record, which may span multiple lines. If None is + set, it uses the default value, ``false``. 
+ >>> csv_sdf = spark.readStream.csv(tempfile.mkdtemp(), schema = sdf_schema) >>> csv_sdf.isStreaming True @@ -636,7 +653,8 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non nanValue=nanValue, positiveInf=positiveInf, negativeInf=negativeInf, dateFormat=dateFormat, timestampFormat=timestampFormat, maxColumns=maxColumns, maxCharsPerColumn=maxCharsPerColumn, - maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode, timeZone=timeZone) + maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode, timeZone=timeZone, + columnNameOfCorruptRecord=columnNameOfCorruptRecord, wholeFile=wholeFile) if isinstance(path, basestring): return self._df(self._jreader.csv(path)) else: diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 9058443285ac..e943f8da3db1 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -266,9 +266,6 @@ def test_explode(self): self.assertEqual(result[0][0], "a") self.assertEqual(result[0][1], "b") - with self.assertRaises(ValueError): - data.select(explode(data.mapfield).alias("a", "b", metadata={'max': 99})).count() - def test_and_in_expression(self): self.assertEqual(4, self.df.filter((self.df.key <= 10) & (self.df.value <= "2")).count()) self.assertRaises(ValueError, lambda: (self.df.key <= 10) and (self.df.value <= "2")) @@ -440,12 +437,19 @@ def test_udf_with_order_by_and_limit(self): self.assertEqual(res.collect(), [Row(id=0, copy=0)]) def test_wholefile_json(self): - from pyspark.sql.types import StringType people1 = self.spark.read.json("python/test_support/sql/people.json") people_array = self.spark.read.json("python/test_support/sql/people_array.json", wholeFile=True) self.assertEqual(people1.collect(), people_array.collect()) + def test_wholefile_csv(self): + ages_newlines = self.spark.read.csv( + "python/test_support/sql/ages_newlines.csv", wholeFile=True) + expected = [Row(_c0=u'Joe', _c1=u'20', _c2=u'Hi,\nI am Jeo'), + Row(_c0=u'Tom', _c1=u'30', _c2=u'My name is Tom'), + Row(_c0=u'Hyukjin', _c1=u'25', _c2=u'I am Hyukjin\n\nI love Spark!')] + self.assertEqual(ages_newlines.collect(), expected) + def test_udf_with_input_file_name(self): from pyspark.sql.functions import udf, input_file_name from pyspark.sql.types import StringType @@ -578,6 +582,21 @@ def as_double(x): [2, 3.0, "FOO", "foo", "foo", 3, 1.0] ) + def test_udf_wrapper(self): + from pyspark.sql.functions import udf + from pyspark.sql.types import IntegerType + + def f(x): + """Identity""" + return x + + return_type = IntegerType() + f_ = udf(f, return_type) + + self.assertTrue(f.__doc__ in f_.__doc__) + self.assertEqual(f, f_.func) + self.assertEqual(return_type, f_.returnType) + def test_basic_functions(self): rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}']) df = self.spark.read.json(rdd) @@ -944,7 +963,8 @@ def test_column_operators(self): self.assertTrue(all(isinstance(c, Column) for c in cb)) cbool = (ci & ci), (ci | ci), (~ci) self.assertTrue(all(isinstance(c, Column) for c in cbool)) - css = cs.like('a'), cs.rlike('a'), cs.asc(), cs.desc(), cs.startswith('a'), cs.endswith('a') + css = cs.contains('a'), cs.like('a'), cs.rlike('a'), cs.asc(), cs.desc(),\ + cs.startswith('a'), cs.endswith('a') self.assertTrue(all(isinstance(c, Column) for c in css)) self.assertTrue(isinstance(ci.cast(LongType()), Column)) @@ -962,13 +982,6 @@ def test_column_select(self): self.assertEqual(self.testData, df.select(df.key, df.value).collect()) self.assertEqual([Row(value='1')], df.where(df.key == 
1).select(df.value).collect()) - def test_column_alias_metadata(self): - df = self.df - df_with_meta = df.select(df.key.alias('pk', metadata={'label': 'Primary Key'})) - self.assertEqual(df_with_meta.schema['pk'].metadata['label'], 'Primary Key') - with self.assertRaises(AssertionError): - df.select(df.key.alias('pk', metdata={'label': 'Primary Key'})) - def test_freqItems(self): vals = [Row(a=1, b=-2.0) if i % 2 == 0 else Row(a=i, b=i * 1.0) for i in range(100)] df = self.sc.parallelize(vals).toDF() diff --git a/python/pyspark/streaming/kinesis.py b/python/pyspark/streaming/kinesis.py index 3a8d8b819fd3..b839859c4525 100644 --- a/python/pyspark/streaming/kinesis.py +++ b/python/pyspark/streaming/kinesis.py @@ -37,7 +37,8 @@ class KinesisUtils(object): def createStream(ssc, kinesisAppName, streamName, endpointUrl, regionName, initialPositionInStream, checkpointInterval, storageLevel=StorageLevel.MEMORY_AND_DISK_2, - awsAccessKeyId=None, awsSecretKey=None, decoder=utf8_decoder): + awsAccessKeyId=None, awsSecretKey=None, decoder=utf8_decoder, + stsAssumeRoleArn=None, stsSessionName=None, stsExternalId=None): """ Create an input stream that pulls messages from a Kinesis stream. This uses the Kinesis Client Library (KCL) to pull messages from Kinesis. @@ -67,6 +68,12 @@ def createStream(ssc, kinesisAppName, streamName, endpointUrl, regionName, :param awsSecretKey: AWS SecretKey (default is None. If None, will use DefaultAWSCredentialsProviderChain) :param decoder: A function used to decode value (default is utf8_decoder) + :param stsAssumeRoleArn: ARN of IAM role to assume when using STS sessions to read from + the Kinesis stream (default is None). + :param stsSessionName: Name to uniquely identify STS sessions used to read from Kinesis + stream, if STS is being used (default is None). + :param stsExternalId: External ID that can be used to validate against the assumed IAM + role's trust policy, if STS is being used (default is None). 
:return: A DStream object """ jlevel = ssc._sc._getJavaStorageLevel(storageLevel) @@ -81,7 +88,8 @@ def createStream(ssc, kinesisAppName, streamName, endpointUrl, regionName, raise jstream = helper.createStream(ssc._jssc, kinesisAppName, streamName, endpointUrl, regionName, initialPositionInStream, jduration, jlevel, - awsAccessKeyId, awsSecretKey) + awsAccessKeyId, awsSecretKey, stsAssumeRoleArn, + stsSessionName, stsExternalId) stream = DStream(jstream, ssc, NoOpSerializer()) return stream.map(lambda v: decoder(v)) diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index e908b1e739bb..c6c87a9ea555 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -58,6 +58,7 @@ from StringIO import StringIO +from pyspark import keyword_only from pyspark.conf import SparkConf from pyspark.context import SparkContext from pyspark.rdd import RDD @@ -1347,7 +1348,7 @@ def test_oldhadoop(self): self.assertEqual(ints, ei) hellopath = os.path.join(SPARK_HOME, "python/test_support/hello/hello.txt") - oldconf = {"mapred.input.dir": hellopath} + oldconf = {"mapreduce.input.fileinputformat.inputdir": hellopath} hello = self.sc.hadoopRDD("org.apache.hadoop.mapred.TextInputFormat", "org.apache.hadoop.io.LongWritable", "org.apache.hadoop.io.Text", @@ -1366,7 +1367,7 @@ def test_newhadoop(self): self.assertEqual(ints, ei) hellopath = os.path.join(SPARK_HOME, "python/test_support/hello/hello.txt") - newconf = {"mapred.input.dir": hellopath} + newconf = {"mapreduce.input.fileinputformat.inputdir": hellopath} hello = self.sc.newAPIHadoopRDD("org.apache.hadoop.mapreduce.lib.input.TextInputFormat", "org.apache.hadoop.io.LongWritable", "org.apache.hadoop.io.Text", @@ -1515,12 +1516,12 @@ def test_oldhadoop(self): conf = { "mapred.output.format.class": "org.apache.hadoop.mapred.SequenceFileOutputFormat", - "mapred.output.key.class": "org.apache.hadoop.io.IntWritable", - "mapred.output.value.class": "org.apache.hadoop.io.MapWritable", - "mapred.output.dir": basepath + "/olddataset/" + "mapreduce.job.output.key.class": "org.apache.hadoop.io.IntWritable", + "mapreduce.job.output.value.class": "org.apache.hadoop.io.MapWritable", + "mapreduce.output.fileoutputformat.outputdir": basepath + "/olddataset/" } self.sc.parallelize(dict_data).saveAsHadoopDataset(conf) - input_conf = {"mapred.input.dir": basepath + "/olddataset/"} + input_conf = {"mapreduce.input.fileinputformat.inputdir": basepath + "/olddataset/"} result = self.sc.hadoopRDD( "org.apache.hadoop.mapred.SequenceFileInputFormat", "org.apache.hadoop.io.IntWritable", @@ -1547,14 +1548,14 @@ def test_newhadoop(self): self.assertEqual(result, data) conf = { - "mapreduce.outputformat.class": + "mapreduce.job.outputformat.class": "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat", - "mapred.output.key.class": "org.apache.hadoop.io.IntWritable", - "mapred.output.value.class": "org.apache.hadoop.io.Text", - "mapred.output.dir": basepath + "/newdataset/" + "mapreduce.job.output.key.class": "org.apache.hadoop.io.IntWritable", + "mapreduce.job.output.value.class": "org.apache.hadoop.io.Text", + "mapreduce.output.fileoutputformat.outputdir": basepath + "/newdataset/" } self.sc.parallelize(data).saveAsNewAPIHadoopDataset(conf) - input_conf = {"mapred.input.dir": basepath + "/newdataset/"} + input_conf = {"mapreduce.input.fileinputformat.inputdir": basepath + "/newdataset/"} new_dataset = sorted(self.sc.newAPIHadoopRDD( "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat", "org.apache.hadoop.io.IntWritable", @@ -1584,16 +1585,16 
@@ def test_newhadoop_with_array(self): self.assertEqual(result, array_data) conf = { - "mapreduce.outputformat.class": + "mapreduce.job.outputformat.class": "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat", - "mapred.output.key.class": "org.apache.hadoop.io.IntWritable", - "mapred.output.value.class": "org.apache.spark.api.python.DoubleArrayWritable", - "mapred.output.dir": basepath + "/newdataset/" + "mapreduce.job.output.key.class": "org.apache.hadoop.io.IntWritable", + "mapreduce.job.output.value.class": "org.apache.spark.api.python.DoubleArrayWritable", + "mapreduce.output.fileoutputformat.outputdir": basepath + "/newdataset/" } self.sc.parallelize(array_data).saveAsNewAPIHadoopDataset( conf, valueConverter="org.apache.spark.api.python.DoubleArrayToWritableConverter") - input_conf = {"mapred.input.dir": basepath + "/newdataset/"} + input_conf = {"mapreduce.input.fileinputformat.inputdir": basepath + "/newdataset/"} new_dataset = sorted(self.sc.newAPIHadoopRDD( "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat", "org.apache.hadoop.io.IntWritable", @@ -1663,18 +1664,19 @@ def test_reserialization(self): conf4 = { "mapred.output.format.class": "org.apache.hadoop.mapred.SequenceFileOutputFormat", - "mapred.output.key.class": "org.apache.hadoop.io.IntWritable", - "mapred.output.value.class": "org.apache.hadoop.io.IntWritable", - "mapred.output.dir": basepath + "/reserialize/dataset"} + "mapreduce.job.output.key.class": "org.apache.hadoop.io.IntWritable", + "mapreduce.job.output.value.class": "org.apache.hadoop.io.IntWritable", + "mapreduce.output.fileoutputformat.outputdir": basepath + "/reserialize/dataset"} rdd.saveAsHadoopDataset(conf4) result4 = sorted(self.sc.sequenceFile(basepath + "/reserialize/dataset").collect()) self.assertEqual(result4, data) - conf5 = {"mapreduce.outputformat.class": + conf5 = {"mapreduce.job.outputformat.class": "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat", - "mapred.output.key.class": "org.apache.hadoop.io.IntWritable", - "mapred.output.value.class": "org.apache.hadoop.io.IntWritable", - "mapred.output.dir": basepath + "/reserialize/newdataset"} + "mapreduce.job.output.key.class": "org.apache.hadoop.io.IntWritable", + "mapreduce.job.output.value.class": "org.apache.hadoop.io.IntWritable", + "mapreduce.output.fileoutputformat.outputdir": basepath + "/reserialize/newdataset" + } rdd.saveAsNewAPIHadoopDataset(conf5) result5 = sorted(self.sc.sequenceFile(basepath + "/reserialize/newdataset").collect()) self.assertEqual(result5, data) @@ -2160,6 +2162,44 @@ def test_memory_conf(self): sc.stop() +class KeywordOnlyTests(unittest.TestCase): + class Wrapped(object): + @keyword_only + def set(self, x=None, y=None): + if "x" in self._input_kwargs: + self._x = self._input_kwargs["x"] + if "y" in self._input_kwargs: + self._y = self._input_kwargs["y"] + return x, y + + def test_keywords(self): + w = self.Wrapped() + x, y = w.set(y=1) + self.assertEqual(y, 1) + self.assertEqual(y, w._y) + self.assertIsNone(x) + self.assertFalse(hasattr(w, "_x")) + + def test_non_keywords(self): + w = self.Wrapped() + self.assertRaises(TypeError, lambda: w.set(0, y=1)) + + def test_kwarg_ownership(self): + # test _input_kwargs is owned by each class instance and not a shared static variable + class Setter(object): + @keyword_only + def set(self, x=None, other=None, other_x=None): + if "other" in self._input_kwargs: + self._input_kwargs["other"].set(x=self._input_kwargs["other_x"]) + self._x = self._input_kwargs["x"] + + a = Setter() + b = 
Setter() + a.set(x=1, other=b, other_x=2) + self.assertEqual(a._x, 1) + self.assertEqual(b._x, 2) + + @unittest.skipIf(not _have_scipy, "SciPy not installed") class SciPyTests(PySparkTestCase): diff --git a/python/test_support/sql/ages_newlines.csv b/python/test_support/sql/ages_newlines.csv new file mode 100644 index 000000000000..d19f6731625f --- /dev/null +++ b/python/test_support/sql/ages_newlines.csv @@ -0,0 +1,6 @@ +Joe,20,"Hi, +I am Jeo" +Tom,30,"My name is Tom" +Hyukjin,25,"I am Hyukjin + +I love Spark!" diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala index 792ade8f0bdb..38b082ac0119 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala @@ -25,7 +25,7 @@ import org.apache.spark.deploy.mesos.ui.MesosClusterUI import org.apache.spark.deploy.rest.mesos.MesosRestServer import org.apache.spark.internal.Logging import org.apache.spark.scheduler.cluster.mesos._ -import org.apache.spark.util.{CommandLineUtils, ShutdownHookManager, Utils} +import org.apache.spark.util.{CommandLineUtils, ShutdownHookManager, SparkUncaughtExceptionHandler, Utils} /* * A dispatcher that is responsible for managing and launching drivers, and is intended to be @@ -97,6 +97,7 @@ private[mesos] object MesosClusterDispatcher with CommandLineUtils { override def main(args: Array[String]) { + Thread.setDefaultUncaughtExceptionHandler(SparkUncaughtExceptionHandler) Utils.initDaemon(log) val conf = new SparkConf val dispatcherArgs = new MesosClusterDispatcherArguments(args, conf) diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala index f555072c3842..f69c223ab9b6 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala @@ -54,14 +54,17 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( with org.apache.mesos.Scheduler with MesosSchedulerUtils { - val MAX_SLAVE_FAILURES = 2 // Blacklist a slave after this many failures + // Blacklist a slave after this many failures + private val MAX_SLAVE_FAILURES = 2 - // Maximum number of cores to acquire (TODO: we'll need more flexible controls here) - val maxCores = conf.get("spark.cores.max", Int.MaxValue.toString).toInt + private val maxCoresOption = conf.getOption("spark.cores.max").map(_.toInt) - val useFetcherCache = conf.getBoolean("spark.mesos.fetcherCache.enable", false) + // Maximum number of cores to acquire + private val maxCores = maxCoresOption.getOrElse(Int.MaxValue) - val maxGpus = conf.getInt("spark.mesos.gpus.max", 0) + private val useFetcherCache = conf.getBoolean("spark.mesos.fetcherCache.enable", false) + + private val maxGpus = conf.getInt("spark.mesos.gpus.max", 0) private[this] val shutdownTimeoutMS = conf.getTimeAsMs("spark.mesos.coarse.shutdownTimeout", "10s") @@ -75,10 +78,10 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( private val shuffleServiceEnabled = conf.getBoolean("spark.shuffle.service.enabled", false) // 
Cores we have acquired with each Mesos task ID - val coresByTaskId = new mutable.HashMap[String, Int] - val gpusByTaskId = new mutable.HashMap[String, Int] - var totalCoresAcquired = 0 - var totalGpusAcquired = 0 + private val coresByTaskId = new mutable.HashMap[String, Int] + private val gpusByTaskId = new mutable.HashMap[String, Int] + private var totalCoresAcquired = 0 + private var totalGpusAcquired = 0 // SlaveID -> Slave // This map accumulates entries for the duration of the job. Slaves are never deleted, because @@ -108,7 +111,7 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( // may lead to deadlocks since the superclass might also try to lock private val stateLock = new ReentrantLock - val extraCoresPerExecutor = conf.getInt("spark.mesos.extra.cores", 0) + private val extraCoresPerExecutor = conf.getInt("spark.mesos.extra.cores", 0) // Offer constraints private val slaveOfferConstraints = @@ -139,7 +142,7 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( securityManager.isAuthenticationEnabled()) } - var nextMesosTaskId = 0 + private var nextMesosTaskId = 0 @volatile var appId: String = _ @@ -256,7 +259,7 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( } override def sufficientResourcesRegistered(): Boolean = { - totalCoresAcquired >= maxCores * minRegisteredRatio + totalCoreCount.get >= maxCoresOption.getOrElse(0) * minRegisteredRatio } override def disconnected(d: org.apache.mesos.SchedulerDriver) {} diff --git a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala index cdb3b6848965..78346e974495 100644 --- a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala +++ b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala @@ -20,9 +20,7 @@ package org.apache.spark.scheduler.cluster.mesos import java.util.concurrent.TimeUnit import scala.collection.JavaConverters._ -import scala.collection.mutable.ArrayBuffer import scala.concurrent.duration._ -import scala.concurrent.Promise import scala.reflect.ClassTag import org.apache.mesos.{Protos, Scheduler, SchedulerDriver} @@ -37,8 +35,8 @@ import org.scalatest.BeforeAndAfter import org.apache.spark.{LocalSparkContext, SecurityManager, SparkConf, SparkContext, SparkFunSuite} import org.apache.spark.internal.config._ import org.apache.spark.network.shuffle.mesos.MesosExternalShuffleClient -import org.apache.spark.rpc.RpcEndpointRef -import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.RemoveExecutor +import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef} +import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.{RegisterExecutor, RemoveExecutor} import org.apache.spark.scheduler.TaskSchedulerImpl import org.apache.spark.scheduler.cluster.mesos.Utils._ @@ -304,25 +302,29 @@ class MesosCoarseGrainedSchedulerBackendSuite extends SparkFunSuite } test("weburi is set in created scheduler driver") { - setBackend() + initializeSparkConf() + sc = new SparkContext(sparkConf) + val taskScheduler = mock[TaskSchedulerImpl] when(taskScheduler.sc).thenReturn(sc) + val driver = mock[SchedulerDriver] when(driver.start()).thenReturn(Protos.Status.DRIVER_RUNNING) + val securityManager = mock[SecurityManager] val backend = new 
MesosCoarseGrainedSchedulerBackend( - taskScheduler, sc, "master", securityManager) { + taskScheduler, sc, "master", securityManager) { override protected def createSchedulerDriver( - masterUrl: String, - scheduler: Scheduler, - sparkUser: String, - appName: String, - conf: SparkConf, - webuiUrl: Option[String] = None, - checkpoint: Option[Boolean] = None, - failoverTimeout: Option[Double] = None, - frameworkId: Option[String] = None): SchedulerDriver = { + masterUrl: String, + scheduler: Scheduler, + sparkUser: String, + appName: String, + conf: SparkConf, + webuiUrl: Option[String] = None, + checkpoint: Option[Boolean] = None, + failoverTimeout: Option[Double] = None, + frameworkId: Option[String] = None): SchedulerDriver = { markRegistered() assert(webuiUrl.isDefined) assert(webuiUrl.get.equals("http://webui")) @@ -419,37 +421,11 @@ class MesosCoarseGrainedSchedulerBackendSuite extends SparkFunSuite assert(!dockerInfo.getForcePullImage) } - test("Do not call removeExecutor() after backend is stopped") { - setBackend() - - // launches a task on a valid offer - val offers = List(Resources(backend.executorMemory(sc), 1)) - offerResources(offers) - verifyTaskLaunched(driver, "o1") - - // launches a thread simulating status update - val statusUpdateThread = new Thread { - override def run(): Unit = { - while (!stopCalled) { - Thread.sleep(100) - } - - val status = createTaskStatus("0", "s1", TaskState.TASK_FINISHED) - backend.statusUpdate(driver, status) - } - }.start - - backend.stop() - // Any method of the backend involving sending messages to the driver endpoint should not - // be called after the backend is stopped. - verify(driverEndpoint, never()).askSync(isA(classOf[RemoveExecutor]))(any[ClassTag[_]]) - } - test("mesos supports spark.executor.uri") { val url = "spark.spark.spark.com" setBackend(Map( "spark.executor.uri" -> url - ), false) + ), null) val (mem, cpu) = (backend.executorMemory(sc), 4) @@ -465,7 +441,7 @@ class MesosCoarseGrainedSchedulerBackendSuite extends SparkFunSuite setBackend(Map( "spark.mesos.fetcherCache.enable" -> "true", "spark.executor.uri" -> url - ), false) + ), null) val offers = List(Resources(backend.executorMemory(sc), 1)) offerResources(offers) val launchedTasks = verifyTaskLaunched(driver, "o1") @@ -479,7 +455,7 @@ class MesosCoarseGrainedSchedulerBackendSuite extends SparkFunSuite setBackend(Map( "spark.mesos.fetcherCache.enable" -> "false", "spark.executor.uri" -> url - ), false) + ), null) val offers = List(Resources(backend.executorMemory(sc), 1)) offerResources(offers) val launchedTasks = verifyTaskLaunched(driver, "o1") @@ -504,8 +480,31 @@ class MesosCoarseGrainedSchedulerBackendSuite extends SparkFunSuite assert(networkInfos.get(0).getName == "test-network-name") } + test("supports spark.scheduler.minRegisteredResourcesRatio") { + val expectedCores = 1 + setBackend(Map( + "spark.cores.max" -> expectedCores.toString, + "spark.scheduler.minRegisteredResourcesRatio" -> "1.0")) + + val offers = List(Resources(backend.executorMemory(sc), expectedCores)) + offerResources(offers) + val launchedTasks = verifyTaskLaunched(driver, "o1") + assert(!backend.isReady) + + registerMockExecutor(launchedTasks(0).getTaskId.getValue, "s1", expectedCores) + assert(backend.isReady) + } + private case class Resources(mem: Int, cpus: Int, gpus: Int = 0) + private def registerMockExecutor(executorId: String, slaveId: String, cores: Integer) = { + val mockEndpointRef = mock[RpcEndpointRef] + val mockAddress = mock[RpcAddress] + val message = RegisterExecutor(executorId, 
mockEndpointRef, slaveId, cores, Map.empty) + + backend.driverEndpoint.askSync[Boolean](message) + } + private def verifyDeclinedOffer(driver: SchedulerDriver, offerId: OfferID, filter: Boolean = false): Unit = { @@ -534,8 +533,7 @@ class MesosCoarseGrainedSchedulerBackendSuite extends SparkFunSuite private def createSchedulerBackend( taskScheduler: TaskSchedulerImpl, driver: SchedulerDriver, - shuffleClient: MesosExternalShuffleClient, - endpoint: RpcEndpointRef): MesosCoarseGrainedSchedulerBackend = { + shuffleClient: MesosExternalShuffleClient) = { val securityManager = mock[SecurityManager] val backend = new MesosCoarseGrainedSchedulerBackend( @@ -553,9 +551,6 @@ class MesosCoarseGrainedSchedulerBackendSuite extends SparkFunSuite override protected def getShuffleClient(): MesosExternalShuffleClient = shuffleClient - override protected def createDriverEndpointRef( - properties: ArrayBuffer[(String, String)]): RpcEndpointRef = endpoint - // override to avoid race condition with the driver thread on `mesosDriver` override def startScheduler(newDriver: SchedulerDriver): Unit = { mesosDriver = newDriver @@ -571,31 +566,35 @@ class MesosCoarseGrainedSchedulerBackendSuite extends SparkFunSuite backend } - private def setBackend(sparkConfVars: Map[String, String] = null, - setHome: Boolean = true) { + private def initializeSparkConf( + sparkConfVars: Map[String, String] = null, + home: String = "/path"): Unit = { sparkConf = (new SparkConf) .setMaster("local[*]") .setAppName("test-mesos-dynamic-alloc") .set("spark.mesos.driver.webui.url", "http://webui") - if (setHome) { - sparkConf.setSparkHome("/path") + if (home != null) { + sparkConf.setSparkHome(home) } if (sparkConfVars != null) { sparkConf.setAll(sparkConfVars) } + } + private def setBackend(sparkConfVars: Map[String, String] = null, home: String = "/path") { + initializeSparkConf(sparkConfVars, home) sc = new SparkContext(sparkConf) driver = mock[SchedulerDriver] when(driver.start()).thenReturn(Protos.Status.DRIVER_RUNNING) + taskScheduler = mock[TaskSchedulerImpl] when(taskScheduler.sc).thenReturn(sc) + externalShuffleClient = mock[MesosExternalShuffleClient] - driverEndpoint = mock[RpcEndpointRef] - when(driverEndpoint.ask(any())(any())).thenReturn(Promise().future) - backend = createSchedulerBackend(taskScheduler, driver, externalShuffleClient, driverEndpoint) + backend = createSchedulerBackend(taskScheduler, driver, externalShuffleClient) } } diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 9df43aea3f3d..864c834d110f 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -332,7 +332,7 @@ private[spark] class ApplicationMaster( _sparkConf: SparkConf, _rpcEnv: RpcEnv, driverRef: RpcEndpointRef, - uiAddress: String, + uiAddress: Option[String], securityMgr: SecurityManager) = { val appId = client.getAttemptId().getApplicationId().toString() val attemptId = client.getAttemptId().getAttemptId().toString() @@ -408,8 +408,7 @@ private[spark] class ApplicationMaster( sc.getConf.get("spark.driver.host"), sc.getConf.get("spark.driver.port"), isClusterMode = true) - registerAM(sc.getConf, rpcEnv, driverRef, sc.ui.map(_.webUrl).getOrElse(""), - securityMgr) + registerAM(sc.getConf, rpcEnv, driverRef, sc.ui.map(_.webUrl), securityMgr) } else { // 
Sanity check; should never happen in normal operation, since sc should only be null // if the user app did not create a SparkContext. @@ -435,7 +434,7 @@ private[spark] class ApplicationMaster( clientMode = true) val driverRef = waitForSparkDriver() addAmIpFilter() - registerAM(sparkConf, rpcEnv, driverRef, sparkConf.get("spark.driver.appUIAddress", ""), + registerAM(sparkConf, rpcEnv, driverRef, sparkConf.getOption("spark.driver.appUIAddress"), securityMgr) // In client mode the actor will stop the reporter thread. diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index a00234c2b416..e86bd5459311 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -100,6 +100,7 @@ private[spark] class Client( private var principal: String = null private var keytab: String = null private var credentials: Credentials = null + private var amKeytabFileName: String = null private val launcherBackend = new LauncherBackend() { override def onStopRequest(): Unit = { @@ -471,7 +472,7 @@ private[spark] class Client( logInfo("To enable the AM to login from keytab, credentials are being copied over to the AM" + " via the YARN Secure Distributed Cache.") val (_, localizedPath) = distribute(keytab, - destName = sparkConf.get(KEYTAB), + destName = Some(amKeytabFileName), appMasterOnly = true) require(localizedPath != null, "Keytab file already distributed.") } @@ -708,6 +709,9 @@ private[spark] class Client( // Save Spark configuration to a file in the archive. val props = new Properties() sparkConf.getAll.foreach { case (k, v) => props.setProperty(k, v) } + // Override spark.yarn.key to point to the location in distributed cache which will be used + // by AM. + Option(amKeytabFileName).foreach { k => props.setProperty(KEYTAB.key, k) } confStream.putNextEntry(new ZipEntry(SPARK_CONF_FILE)) val writer = new OutputStreamWriter(confStream, StandardCharsets.UTF_8) props.store(writer, "Spark configuration.") @@ -813,6 +817,7 @@ private[spark] class Client( sys.env.get(envname).foreach(env(envname) = _) } } + sys.env.get("PYTHONHASHSEED").foreach(env.put("PYTHONHASHSEED", _)) } sys.env.get(ENV_DIST_CLASSPATH).foreach { dcp => @@ -995,8 +1000,7 @@ private[spark] class Client( val f = new File(keytab) // Generate a file name that can be used for the keytab file, that does not conflict // with any user file. - val keytabFileName = f.getName + "-" + UUID.randomUUID().toString - sparkConf.set(KEYTAB.key, keytabFileName) + amKeytabFileName = f.getName + "-" + UUID.randomUUID().toString sparkConf.set(PRINCIPAL.key, principal) } // Defensive copy of the credentials diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnProxyRedirectFilter.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnProxyRedirectFilter.scala new file mode 100644 index 000000000000..ae625df75362 --- /dev/null +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnProxyRedirectFilter.scala @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
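For context on the keytab handling above: users keep supplying the principal and keytab exactly as before; the patch only stops rewriting the user-visible spark.yarn.keytab value and keeps the randomized staged-file name in an AM-only override instead. A hedged PySpark sketch under that assumption (the principal, keytab path and YARN environment are placeholders):

```python
from pyspark.sql import SparkSession

spark = (SparkSession.builder
         .appName("kerberized-yarn-app")
         .master("yarn")                                           # assumes a configured YARN client
         .config("spark.yarn.principal", "alice@EXAMPLE.COM")      # placeholder principal
         .config("spark.yarn.keytab", "/home/alice/alice.keytab")  # placeholder local path
         .getOrCreate())

# With the change above, this still reports the path supplied by the user; the
# distributed-cache file name is written only into the AM-side configuration.
print(spark.conf.get("spark.yarn.keytab"))

spark.range(10).count()
spark.stop()
```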
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.yarn + +import javax.servlet._ +import javax.servlet.http.{HttpServletRequest, HttpServletResponse} + +import org.apache.spark.internal.Logging + +/** + * A filter to be used in the Spark History Server for redirecting YARN proxy requests to the + * main SHS address. This is useful for applications that are using the history server as the + * tracking URL, since the SHS-generated pages cannot be rendered in that case without extra + * configuration to set up a proxy base URI (meaning the SHS cannot be ever used directly). + */ +class YarnProxyRedirectFilter extends Filter with Logging { + + import YarnProxyRedirectFilter._ + + override def destroy(): Unit = { } + + override def init(config: FilterConfig): Unit = { } + + override def doFilter(req: ServletRequest, res: ServletResponse, chain: FilterChain): Unit = { + val hreq = req.asInstanceOf[HttpServletRequest] + + // The YARN proxy will send a request with the "proxy-user" cookie set to the YARN's client + // user name. We don't expect any other clients to set this cookie, since the SHS does not + // use cookies for anything. + Option(hreq.getCookies()).flatMap(_.find(_.getName() == COOKIE_NAME)) match { + case Some(_) => + doRedirect(hreq, res.asInstanceOf[HttpServletResponse]) + + case _ => + chain.doFilter(req, res) + } + } + + private def doRedirect(req: HttpServletRequest, res: HttpServletResponse): Unit = { + val redirect = req.getRequestURL().toString() + + // Need a client-side redirect instead of an HTTP one, otherwise the YARN proxy itself + // will handle the redirect and get into an infinite loop. + val content = s""" + | + | + | Spark History Server Redirect + | + | + | + |

+ |  <p>The requested page can be found at: <a href="$redirect">$redirect</a>.</p>

+ | + | + """.stripMargin + + logDebug(s"Redirecting YARN proxy request to $redirect.") + res.setStatus(HttpServletResponse.SC_OK) + res.setContentType("text/html") + res.getWriter().write(content) + } + +} + +private[spark] object YarnProxyRedirectFilter { + val COOKIE_NAME = "proxy-user" +} diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala index 163dfb5a605c..53fb467f6408 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala @@ -55,7 +55,7 @@ private[spark] class YarnRMClient extends Logging { driverRef: RpcEndpointRef, conf: YarnConfiguration, sparkConf: SparkConf, - uiAddress: String, + uiAddress: Option[String], uiHistoryAddress: String, securityMgr: SecurityManager, localResources: Map[String, LocalResource] @@ -65,9 +65,13 @@ private[spark] class YarnRMClient extends Logging { amClient.start() this.uiHistoryAddress = uiHistoryAddress + val trackingUrl = uiAddress.getOrElse { + if (sparkConf.get(ALLOW_HISTORY_SERVER_TRACKING_URL)) uiHistoryAddress else "" + } + logInfo("Registering the ApplicationMaster") synchronized { - amClient.registerApplicationMaster(Utils.localHostName(), 0, uiAddress) + amClient.registerApplicationMaster(Utils.localHostName(), 0, trackingUrl) registered = true } new YarnAllocator(driverUrl, driverRef, conf, sparkConf, amClient, getAttemptId(), securityMgr, diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala index f19a5b22a757..d8c96c35ca71 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala @@ -82,6 +82,13 @@ package object config { .stringConf .createOptional + private[spark] val ALLOW_HISTORY_SERVER_TRACKING_URL = + ConfigBuilder("spark.yarn.historyServer.allowTracking") + .doc("Allow using the History Server URL for the application as the tracking URL for the " + + "application when the Web UI is not enabled.") + .booleanConf + .createWithDefault(false) + /* File distribution. 
*/ private[spark] val SPARK_ARCHIVE = ConfigBuilder("spark.yarn.archive") diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/CredentialUpdater.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/CredentialUpdater.scala index 5df4fbd9c153..2fdb70a73c75 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/CredentialUpdater.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/CredentialUpdater.scala @@ -55,7 +55,7 @@ private[spark] class CredentialUpdater( /** Start the credential updater task */ def start(): Unit = { - val startTime = sparkConf.get(CREDENTIALS_RENEWAL_TIME) + val startTime = sparkConf.get(CREDENTIALS_UPDATE_TIME) val remainingTime = startTime - System.currentTimeMillis() if (remainingTime <= 0) { credentialUpdater.schedule(credentialUpdaterRunnable, 1, TimeUnit.MINUTES) diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnProxyRedirectFilterSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnProxyRedirectFilterSuite.scala new file mode 100644 index 000000000000..54dbe9d50a68 --- /dev/null +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnProxyRedirectFilterSuite.scala @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.yarn + +import java.io.{PrintWriter, StringWriter} +import javax.servlet.FilterChain +import javax.servlet.http.{Cookie, HttpServletRequest, HttpServletResponse} + +import org.mockito.Mockito._ + +import org.apache.spark.SparkFunSuite + +class YarnProxyRedirectFilterSuite extends SparkFunSuite { + + test("redirect proxied requests, pass-through others") { + val requestURL = "http://example.com:1234/foo?" + val filter = new YarnProxyRedirectFilter() + val cookies = Array(new Cookie(YarnProxyRedirectFilter.COOKIE_NAME, "dr.who")) + + val req = mock(classOf[HttpServletRequest]) + + // First request mocks a YARN proxy request (with the cookie set), second one has no cookies. + when(req.getCookies()).thenReturn(cookies, null) + when(req.getRequestURL()).thenReturn(new StringBuffer(requestURL)) + + val res = mock(classOf[HttpServletResponse]) + when(res.getWriter()).thenReturn(new PrintWriter(new StringWriter())) + + val chain = mock(classOf[FilterChain]) + + // First request is proxied. + filter.doFilter(req, res, chain) + verify(chain, never()).doFilter(req, res) + + // Second request is not, so should invoke the filter chain. 
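Putting the YARN pieces above together: an application that runs without a web UI can now advertise the history server as its YARN tracking URL, and the new filter lets the history server bounce requests arriving through the YARN proxy back to its own address. A hedged configuration sketch; the SHS host is a placeholder, and wiring the filter in through spark.ui.filters on the history server is one plausible setup, not something this patch prescribes:

```python
from pyspark import SparkConf, SparkContext

conf = (SparkConf()
        .setAppName("headless-yarn-job")
        # No driver UI for this job...
        .set("spark.ui.enabled", "false")
        # ...so let the AM register the history server as the tracking URL instead.
        .set("spark.yarn.historyServer.allowTracking", "true")
        .set("spark.yarn.historyServer.address", "shs.example.com:18080")  # placeholder host
        .set("spark.eventLog.enabled", "true"))

# The master is expected to come from spark-submit (--master yarn).
sc = SparkContext(conf=conf)

# On the history server side, one option (an assumption, not part of this patch) is to
# install the redirect filter in the SHS's own configuration, e.g. in spark-defaults.conf:
#   spark.ui.filters=org.apache.spark.deploy.yarn.YarnProxyRedirectFilter
```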
+ filter.doFilter(req, res, chain) + verify(chain, times(1)).doFilter(req, res) + } + +} diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index d8cd68e2d9e9..59f93b3c469d 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -123,7 +123,8 @@ statement | CREATE TEMPORARY? FUNCTION qualifiedName AS className=STRING (USING resource (',' resource)*)? #createFunction | DROP TEMPORARY? FUNCTION (IF EXISTS)? qualifiedName #dropFunction - | EXPLAIN (LOGICAL | FORMATTED | EXTENDED | CODEGEN)? statement #explain + | EXPLAIN (LOGICAL | FORMATTED | EXTENDED | CODEGEN | COST)? + statement #explain | SHOW TABLES ((FROM | IN) db=identifier)? (LIKE? pattern=STRING)? #showTables | SHOW TABLE EXTENDED ((FROM | IN) db=identifier)? @@ -693,7 +694,7 @@ nonReserved | DELIMITED | FIELDS | TERMINATED | COLLECTION | ITEMS | KEYS | ESCAPED | LINES | SEPARATED | EXTENDED | REFRESH | CLEAR | CACHE | UNCACHE | LAZY | GLOBAL | TEMPORARY | OPTIONS | GROUPING | CUBE | ROLLUP - | EXPLAIN | FORMAT | LOGICAL | FORMATTED | CODEGEN + | EXPLAIN | FORMAT | LOGICAL | FORMATTED | CODEGEN | COST | TABLESAMPLE | USE | TO | BUCKET | PERCENTLIT | OUT | OF | SET | RESET | VIEW | REPLACE @@ -794,6 +795,7 @@ EXPLAIN: 'EXPLAIN'; FORMAT: 'FORMAT'; LOGICAL: 'LOGICAL'; CODEGEN: 'CODEGEN'; +COST: 'COST'; CAST: 'CAST'; SHOW: 'SHOW'; TABLES: 'TABLES'; diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java index afea4676893e..791e8d80e6cb 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java @@ -117,7 +117,7 @@ public void setNullShort(int ordinal) { public void setNullInt(int ordinal) { setNullBit(ordinal); // put zero into the corresponding field when set null - Platform.putInt(holder.buffer, getElementOffset(ordinal, 4), (int)0); + Platform.putInt(holder.buffer, getElementOffset(ordinal, 4), 0); } public void setNullLong(int ordinal) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala index 8b53d988cbc5..e9d9508e5adf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala @@ -117,11 +117,10 @@ object JavaTypeInference { val (valueDataType, nullable) = inferDataType(valueType) (MapType(keyDataType, valueDataType, nullable), true) - case _ => + case other => // TODO: we should only collect properties that have getter and setter. However, some tests // pass in scala case class as java bean class which doesn't have getter and setter. 
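The grammar change above only makes COST an accepted EXPLAIN modifier; a quick way to exercise the syntax from PySpark, treating the exact plan text (and how much statistics information it carries) as unspecified here:

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("explain-cost-demo").getOrCreate()

spark.range(1000).createOrReplaceTempView("t")

# COST now parses alongside LOGICAL/FORMATTED/EXTENDED/CODEGEN; like any other EXPLAIN,
# the result comes back as a single-column DataFrame holding the plan text.
spark.sql("EXPLAIN COST SELECT id FROM t WHERE id > 500").show(truncate=False)
```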
- val beanInfo = Introspector.getBeanInfo(typeToken.getRawType) - val properties = beanInfo.getPropertyDescriptors.filterNot(_.getName == "class") + val properties = getJavaBeanReadableProperties(other) val fields = properties.map { property => val returnType = typeToken.method(property.getReadMethod).getReturnType val (dataType, nullable) = inferDataType(returnType) @@ -131,10 +130,15 @@ object JavaTypeInference { } } - private def getJavaBeanProperties(beanClass: Class[_]): Array[PropertyDescriptor] = { + def getJavaBeanReadableProperties(beanClass: Class[_]): Array[PropertyDescriptor] = { val beanInfo = Introspector.getBeanInfo(beanClass) - beanInfo.getPropertyDescriptors - .filter(p => p.getReadMethod != null && p.getWriteMethod != null) + beanInfo.getPropertyDescriptors.filterNot(_.getName == "class") + .filter(_.getReadMethod != null) + } + + private def getJavaBeanReadableAndWritableProperties( + beanClass: Class[_]): Array[PropertyDescriptor] = { + getJavaBeanReadableProperties(beanClass).filter(_.getWriteMethod != null) } private def elementType(typeToken: TypeToken[_]): TypeToken[_] = { @@ -298,9 +302,7 @@ object JavaTypeInference { keyData :: valueData :: Nil) case other => - val properties = getJavaBeanProperties(other) - assert(properties.length > 0) - + val properties = getJavaBeanReadableAndWritableProperties(other) val setters = properties.map { p => val fieldName = p.getName val fieldType = typeToken.method(p.getReadMethod).getReturnType @@ -417,21 +419,16 @@ object JavaTypeInference { ) case other => - val properties = getJavaBeanProperties(other) - if (properties.length > 0) { - CreateNamedStruct(properties.flatMap { p => - val fieldName = p.getName - val fieldType = typeToken.method(p.getReadMethod).getReturnType - val fieldValue = Invoke( - inputObject, - p.getReadMethod.getName, - inferExternalType(fieldType.getRawType)) - expressions.Literal(fieldName) :: serializerFor(fieldValue, fieldType) :: Nil - }) - } else { - throw new UnsupportedOperationException( - s"Cannot infer type for class ${other.getName} because it is not bean-compliant") - } + val properties = getJavaBeanReadableAndWritableProperties(other) + CreateNamedStruct(properties.flatMap { p => + val fieldName = p.getName + val fieldType = typeToken.method(p.getReadMethod).getReturnType + val fieldValue = Invoke( + inputObject, + p.getReadMethod.getName, + inferExternalType(fieldType.getRawType)) + expressions.Literal(fieldName) :: serializerFor(fieldValue, fieldType) :: Nil + }) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index cd517a98aca1..6d569b612de7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -146,7 +146,7 @@ class Analyzer( GlobalAggregates :: ResolveAggregateFunctions :: TimeWindowing :: - ResolveInlineTables :: + ResolveInlineTables(conf) :: TypeCoercion.typeCoercionRules ++ extendedResolutionRules : _*), Batch("Post-Hoc Resolution", Once, postHocResolutionRules: _*), @@ -180,12 +180,8 @@ class Analyzer( def substituteCTE(plan: LogicalPlan, cteRelations: Seq[(String, LogicalPlan)]): LogicalPlan = { plan transformDown { case u : UnresolvedRelation => - val substituted = cteRelations.find(x => resolver(x._1, u.tableIdentifier.table)) - .map(_._2).map { relation => - val withAlias = u.alias.map(SubqueryAlias(_, relation, 
None)) - withAlias.getOrElse(relation) - } - substituted.getOrElse(u) + cteRelations.find(x => resolver(x._1, u.tableIdentifier.table)) + .map(_._2).getOrElse(u) case other => // This cannot be done in ResolveSubquery because ResolveSubquery does not know the CTE. other transformExpressions { @@ -226,6 +222,7 @@ class Analyzer( expr.transformUp { case u @ UnresolvedAlias(child, optGenAliasFunc) => child match { case ne: NamedExpression => ne + case go @ GeneratorOuter(g: Generator) if g.resolved => MultiAlias(go, Nil) case e if !e.resolved => u case g: Generator => MultiAlias(g, Nil) case c @ Cast(ne: NamedExpression, _, _) => Alias(c, ne.name)() @@ -623,7 +620,7 @@ class Analyzer( val tableIdentWithDb = u.tableIdentifier.copy( database = u.tableIdentifier.database.orElse(defaultDatabase)) try { - catalog.lookupRelation(tableIdentWithDb, u.alias) + catalog.lookupRelation(tableIdentWithDb) } catch { case _: NoSuchTableException => u.failAnalysis(s"Table or view not found: ${tableIdentWithDb.unquotedString}") @@ -1669,7 +1666,6 @@ class Analyzer( var resolvedGenerator: Generate = null val newProjectList = projectList.flatMap { - case AliasedGenerator(generator, names, outer) if generator.childrenResolved => // It's a sanity check, this should not happen as the previous case will throw // exception earlier. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 36ab8b8527b4..7529f9028498 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -18,10 +18,8 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.catalog.SimpleCatalogRelation import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression -import org.apache.spark.sql.catalyst.plans.UsingJoin import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.types._ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala index 2124177461b3..70438eb5912b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala @@ -54,10 +54,8 @@ object ResolveHints { val newNode = CurrentOrigin.withOrigin(plan.origin) { plan match { - case r: UnresolvedRelation => - val alias = r.alias.getOrElse(r.tableIdentifier.table) - if (toBroadcast.exists(resolver(_, alias))) BroadcastHint(plan) else plan - + case u: UnresolvedRelation if toBroadcast.exists(resolver(_, u.tableIdentifier.table)) => + BroadcastHint(plan) case r: SubqueryAlias if toBroadcast.exists(resolver(_, r.alias)) => BroadcastHint(plan) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala index 7323197b10f6..d5b3ea8c37c6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala @@ -19,8 +19,8 @@ package 
org.apache.spark.sql.catalyst.analysis import scala.util.control.NonFatal -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.Cast +import org.apache.spark.sql.catalyst.{CatalystConf, InternalRow} +import org.apache.spark.sql.catalyst.expressions.{Cast, TimeZoneAwareExpression} import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.types.{StructField, StructType} @@ -28,7 +28,7 @@ import org.apache.spark.sql.types.{StructField, StructType} /** * An analyzer rule that replaces [[UnresolvedInlineTable]] with [[LocalRelation]]. */ -object ResolveInlineTables extends Rule[LogicalPlan] { +case class ResolveInlineTables(conf: CatalystConf) extends Rule[LogicalPlan] { override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case table: UnresolvedInlineTable if table.expressionsResolved => validateInputDimension(table) @@ -95,11 +95,15 @@ object ResolveInlineTables extends Rule[LogicalPlan] { InternalRow.fromSeq(row.zipWithIndex.map { case (e, ci) => val targetType = fields(ci).dataType try { - if (e.dataType.sameType(targetType)) { - e.eval() + val castedExpr = if (e.dataType.sameType(targetType)) { + e } else { - Cast(e, targetType).eval() + Cast(e, targetType) } + castedExpr.transform { + case e: TimeZoneAwareExpression if e.timeZoneId.isEmpty => + e.withTimeZone(conf.sessionLocalTimeZone) + }.eval() } catch { case NonFatal(ex) => table.failAnalysis(s"failed to evaluate expression ${e.sql}: ${ex.getMessage}") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala index 07b3558ee2f5..397f5cfe2a54 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala @@ -75,7 +75,7 @@ object UnsupportedOperationChecker { if (watermarkAttributes.isEmpty) { throwError( s"$outputMode output mode not supported when there are streaming aggregations on " + - s"streaming DataFrames/DataSets")(plan) + s"streaming DataFrames/DataSets without watermark")(plan) } case InternalOutputModes.Complete if aggregates.isEmpty => @@ -120,6 +120,10 @@ object UnsupportedOperationChecker { throwError("(map/flatMap)GroupsWithState is not supported after aggregation on a " + "streaming DataFrame/Dataset") + case d: Deduplicate if collectStreamingAggregates(d).nonEmpty => + throwError("dropDuplicates is not supported after aggregation on a " + + "streaming DataFrame/Dataset") + case Join(left, right, joinType, _) => joinType match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index 36ed9ba50372..262b894e2a0a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -37,10 +37,7 @@ class UnresolvedException[TreeType <: TreeNode[_]](tree: TreeType, function: Str /** * Holds the name of a relation that has yet to be looked up in a catalog. 
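Because ResolveInlineTables now receives the catalyst conf, timezone-aware expressions inside a VALUES list (which the rule evaluates eagerly) pick up the session-local time zone instead of being left without one. A small sketch of the kind of query this affects, assuming the spark.sql.session.timeZone setting that backs the session-local zone:

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("inline-table-tz-demo").getOrCreate()

# This is the zone ResolveInlineTables now injects into the CASTs below before
# turning the inline table into a LocalRelation.
spark.conf.set("spark.sql.session.timeZone", "America/Los_Angeles")

spark.sql("""
    SELECT * FROM VALUES
      (CAST('2017-01-01 00:00:00' AS TIMESTAMP)),
      (CAST('2017-07-01 00:00:00' AS TIMESTAMP)) AS t(ts)
""").show(truncate=False)
```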
*/ -case class UnresolvedRelation( - tableIdentifier: TableIdentifier, - alias: Option[String] = None) extends LeafNode { - +case class UnresolvedRelation(tableIdentifier: TableIdentifier) extends LeafNode { /** Returns a `.` separated name for this relation. */ def tableName: String = tableIdentifier.unquotedString diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala index 5233699facae..31eded4deba7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala @@ -125,7 +125,6 @@ abstract class ExternalCatalog { table: String, loadPath: String, isOverwrite: Boolean, - holdDDLTime: Boolean, isSrcLocal: Boolean): Unit /** @@ -140,7 +139,6 @@ abstract class ExternalCatalog { loadPath: String, partition: TablePartitionSpec, isOverwrite: Boolean, - holdDDLTime: Boolean, inheritTableSpecs: Boolean, isSrcLocal: Boolean): Unit @@ -150,8 +148,7 @@ abstract class ExternalCatalog { loadPath: String, partition: TablePartitionSpec, replace: Boolean, - numDP: Int, - holdDDLTime: Boolean): Unit + numDP: Int): Unit // -------------------------------------------------------------------------- // Partitions @@ -247,11 +244,13 @@ abstract class ExternalCatalog { * @param db database name * @param table table name * @param predicates partition-pruning predicates + * @param defaultTimeZoneId default timezone id to parse partition values of TimestampType */ def listPartitionsByFilter( db: String, table: String, - predicates: Seq[Expression]): Seq[CatalogTablePartition] + predicates: Seq[Expression], + defaultTimeZoneId: String): Seq[CatalogTablePartition] // -------------------------------------------------------------------------- // Functions diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala index 15aed5f9b1bd..340e8451f14e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala @@ -325,7 +325,6 @@ class InMemoryCatalog( table: String, loadPath: String, isOverwrite: Boolean, - holdDDLTime: Boolean, isSrcLocal: Boolean): Unit = { throw new UnsupportedOperationException("loadTable is not implemented") } @@ -336,7 +335,6 @@ class InMemoryCatalog( loadPath: String, partition: TablePartitionSpec, isOverwrite: Boolean, - holdDDLTime: Boolean, inheritTableSpecs: Boolean, isSrcLocal: Boolean): Unit = { throw new UnsupportedOperationException("loadPartition is not implemented.") @@ -348,8 +346,7 @@ class InMemoryCatalog( loadPath: String, partition: TablePartitionSpec, replace: Boolean, - numDP: Int, - holdDDLTime: Boolean): Unit = { + numDP: Int): Unit = { throw new UnsupportedOperationException("loadDynamicPartitions is not implemented.") } @@ -547,7 +544,8 @@ class InMemoryCatalog( override def listPartitionsByFilter( db: String, table: String, - predicates: Seq[Expression]): Seq[CatalogTablePartition] = { + predicates: Seq[Expression], + defaultTimeZoneId: String): Seq[CatalogTablePartition] = { // TODO: Provide an implementation throw new UnsupportedOperationException( "listPartitionsByFilter is not implemented for InMemoryCatalog") diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index dd0c5cb7066f..f6412e42c13d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -322,13 +322,12 @@ class SessionCatalog( name: TableIdentifier, loadPath: String, isOverwrite: Boolean, - holdDDLTime: Boolean, isSrcLocal: Boolean): Unit = { val db = formatDatabaseName(name.database.getOrElse(getCurrentDatabase)) val table = formatTableName(name.table) requireDbExists(db) requireTableExists(TableIdentifier(table, Some(db))) - externalCatalog.loadTable(db, table, loadPath, isOverwrite, holdDDLTime, isSrcLocal) + externalCatalog.loadTable(db, table, loadPath, isOverwrite, isSrcLocal) } /** @@ -341,7 +340,6 @@ class SessionCatalog( loadPath: String, spec: TablePartitionSpec, isOverwrite: Boolean, - holdDDLTime: Boolean, inheritTableSpecs: Boolean, isSrcLocal: Boolean): Unit = { val db = formatDatabaseName(name.database.getOrElse(getCurrentDatabase)) @@ -350,7 +348,7 @@ class SessionCatalog( requireTableExists(TableIdentifier(table, Some(db))) requireNonEmptyValueInPartitionSpec(Seq(spec)) externalCatalog.loadPartition( - db, table, loadPath, spec, isOverwrite, holdDDLTime, inheritTableSpecs, isSrcLocal) + db, table, loadPath, spec, isOverwrite, inheritTableSpecs, isSrcLocal) } def defaultTablePath(tableIdent: TableIdentifier): String = { @@ -572,16 +570,14 @@ class SessionCatalog( * wrap the logical plan in a [[SubqueryAlias]] which will track the name of the view. * * @param name The name of the table/view that we look up. - * @param alias The alias name of the table/view that we look up. */ - def lookupRelation(name: TableIdentifier, alias: Option[String] = None): LogicalPlan = { + def lookupRelation(name: TableIdentifier): LogicalPlan = { synchronized { val db = formatDatabaseName(name.database.getOrElse(currentDb)) val table = formatTableName(name.table) - val relationAlias = alias.getOrElse(table) if (db == globalTempViewManager.database) { globalTempViewManager.get(table).map { viewDef => - SubqueryAlias(relationAlias, viewDef, None) + SubqueryAlias(table, viewDef, None) }.getOrElse(throw new NoSuchTableException(db, table)) } else if (name.database.isDefined || !tempTables.contains(table)) { val metadata = externalCatalog.getTable(db, table) @@ -594,12 +590,17 @@ class SessionCatalog( desc = metadata, output = metadata.schema.toAttributes, child = parser.parsePlan(viewText)) - SubqueryAlias(relationAlias, child, Some(name.copy(table = table, database = Some(db)))) + SubqueryAlias(table, child, Some(name.copy(table = table, database = Some(db)))) } else { - SubqueryAlias(relationAlias, SimpleCatalogRelation(metadata), None) + val tableRelation = CatalogRelation( + metadata, + // we assume all the columns are nullable. 
+ metadata.dataSchema.asNullable.toAttributes, + metadata.partitionSchema.asNullable.toAttributes) + SubqueryAlias(table, tableRelation, None) } } else { - SubqueryAlias(relationAlias, tempTables(table), None) + SubqueryAlias(table, tempTables(table), None) } } } @@ -840,7 +841,7 @@ class SessionCatalog( val table = formatTableName(tableName.table) requireDbExists(db) requireTableExists(TableIdentifier(table, Option(db))) - externalCatalog.listPartitionsByFilter(db, table, predicates) + externalCatalog.listPartitionsByFilter(db, table, predicates, conf.sessionLocalTimeZone) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala index 2b3b575b4c06..887caf07d148 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala @@ -19,14 +19,15 @@ package org.apache.spark.sql.catalyst.catalog import java.util.Date -import scala.collection.mutable +import com.google.common.base.Objects import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow, TableIdentifier} +import org.apache.spark.sql.catalyst.{CatalystConf, FunctionIdentifier, InternalRow, TableIdentifier} +import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, Cast, Literal} import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} import org.apache.spark.sql.catalyst.util.quoteIdentifier -import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types.StructType @@ -112,11 +113,11 @@ case class CatalogTablePartition( /** * Given the partition schema, returns a row with that schema holding the partition values. */ - def toRow(partitionSchema: StructType): InternalRow = { + def toRow(partitionSchema: StructType, defaultTimeZondId: String): InternalRow = { + val caseInsensitiveProperties = CaseInsensitiveMap(storage.properties) + val timeZoneId = caseInsensitiveProperties.getOrElse("timeZone", defaultTimeZondId) InternalRow.fromSeq(partitionSchema.map { field => - // TODO: use correct timezone for partition values. - Cast(Literal(spec(field.name)), field.dataType, - Option(DateTimeUtils.defaultTimeZone().getID)).eval() + Cast(Literal(spec(field.name)), field.dataType, Option(timeZoneId)).eval() }) } } @@ -349,36 +350,43 @@ object CatalogTypes { /** - * An interface that is implemented by logical plans to return the underlying catalog table. - * If we can in the future consolidate SimpleCatalogRelation and MetastoreRelation, we should - * probably remove this interface. + * A [[LogicalPlan]] that represents a table. */ -trait CatalogRelation { - def catalogTable: CatalogTable - def output: Seq[Attribute] -} +case class CatalogRelation( + tableMeta: CatalogTable, + dataCols: Seq[Attribute], + partitionCols: Seq[Attribute]) extends LeafNode with MultiInstanceRelation { + assert(tableMeta.identifier.database.isDefined) + assert(tableMeta.partitionSchema.sameType(partitionCols.toStructType)) + assert(tableMeta.dataSchema.sameType(dataCols.toStructType)) + + // The partition column should always appear after data columns. 
+ override def output: Seq[Attribute] = dataCols ++ partitionCols + + def isPartitioned: Boolean = partitionCols.nonEmpty + + override def equals(relation: Any): Boolean = relation match { + case other: CatalogRelation => tableMeta == other.tableMeta && output == other.output + case _ => false + } + override def hashCode(): Int = { + Objects.hashCode(tableMeta.identifier, output) + } -/** - * A [[LogicalPlan]] that wraps [[CatalogTable]]. - * - * Note that in the future we should consolidate this and HiveCatalogRelation. - */ -case class SimpleCatalogRelation( - metadata: CatalogTable) - extends LeafNode with CatalogRelation { - - override def catalogTable: CatalogTable = metadata - - override lazy val resolved: Boolean = false - - override val output: Seq[Attribute] = { - val (partCols, dataCols) = metadata.schema.toAttributes - // Since data can be dumped in randomly with no validation, everything is nullable. - .map(_.withNullability(true).withQualifier(Some(metadata.identifier.table))) - .partition { a => - metadata.partitionColumnNames.contains(a.name) - } - dataCols ++ partCols + /** Only compare table identifier. */ + override lazy val cleanArgs: Seq[Any] = Seq(tableMeta.identifier) + + override def computeStats(conf: CatalystConf): Statistics = { + // For data source tables, we will create a `LogicalRelation` and won't call this method, for + // hive serde tables, we will always generate a statistics. + // TODO: unify the table stats generation. + tableMeta.stats.map(_.toPlanStats(output)).getOrElse { + throw new IllegalStateException("table stats must be specified.") + } } + + override def newInstance(): LogicalPlan = copy( + dataCols = dataCols.map(_.newInstance()), + partitionCols = partitionCols.map(_.newInstance())) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 66e52ca68af1..c062e4e84bcd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -280,11 +280,10 @@ package object dsl { object expressions extends ExpressionConversions // scalastyle:ignore object plans { // scalastyle:ignore - def table(ref: String): LogicalPlan = - UnresolvedRelation(TableIdentifier(ref), None) + def table(ref: String): LogicalPlan = UnresolvedRelation(TableIdentifier(ref)) def table(db: String, ref: String): LogicalPlan = - UnresolvedRelation(TableIdentifier(ref, Option(db)), None) + UnresolvedRelation(TableIdentifier(ref, Option(db))) implicit class DslLogicalPlan(val logicalPlan: LogicalPlan) { def select(exprs: Expression*): LogicalPlan = { @@ -369,16 +368,13 @@ package object dsl { analysis.UnresolvedRelation(TableIdentifier(tableName)), Map.empty, logicalPlan, overwrite, false) - def as(alias: String): LogicalPlan = logicalPlan match { - case UnresolvedRelation(tbl, _) => UnresolvedRelation(tbl, Option(alias)) - case plan => SubqueryAlias(alias, plan, None) - } + def as(alias: String): LogicalPlan = SubqueryAlias(alias, logicalPlan, None) def repartition(num: Integer): LogicalPlan = Repartition(num, shuffle = true, logicalPlan) - def distribute(exprs: Expression*)(n: Int = -1): LogicalPlan = - RepartitionByExpression(exprs, logicalPlan, numPartitions = if (n < 0) None else Some(n)) + def distribute(exprs: Expression*)(n: Int): LogicalPlan = + RepartitionByExpression(exprs, logicalPlan, numPartitions = n) def analyze: LogicalPlan = 
EliminateSubqueryAliases(analysis.SimpleAnalyzer.execute(logicalPlan)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala index 1504a522798b..9f4a0f2b7017 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala @@ -28,7 +28,7 @@ object AttributeMap { } } -class AttributeMap[A](baseMap: Map[ExprId, (Attribute, A)]) +class AttributeMap[A](val baseMap: Map[ExprId, (Attribute, A)]) extends Map[Attribute, A] with Serializable { override def get(k: Attribute): Option[A] = baseMap.get(k.exprId).map(_._2) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala index 6b7cf7991d39..8433a93ea303 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions.aggregate import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream} import java.util -import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} import org.apache.spark.sql.catalyst.expressions._ @@ -61,7 +61,7 @@ case class Percentile( frequencyExpression : Expression, mutableAggBufferOffset: Int = 0, inputAggBufferOffset: Int = 0) - extends TypedImperativeAggregate[OpenHashMap[Number, Long]] with ImplicitCastInputTypes { + extends TypedImperativeAggregate[OpenHashMap[AnyRef, Long]] with ImplicitCastInputTypes { def this(child: Expression, percentageExpression: Expression) = { this(child, percentageExpression, Literal(1L), 0, 0) @@ -130,15 +130,20 @@ case class Percentile( } } - override def createAggregationBuffer(): OpenHashMap[Number, Long] = { + private def toDoubleValue(d: Any): Double = d match { + case d: Decimal => d.toDouble + case n: Number => n.doubleValue + } + + override def createAggregationBuffer(): OpenHashMap[AnyRef, Long] = { // Initialize new counts map instance here. - new OpenHashMap[Number, Long]() + new OpenHashMap[AnyRef, Long]() } override def update( - buffer: OpenHashMap[Number, Long], - input: InternalRow): OpenHashMap[Number, Long] = { - val key = child.eval(input).asInstanceOf[Number] + buffer: OpenHashMap[AnyRef, Long], + input: InternalRow): OpenHashMap[AnyRef, Long] = { + val key = child.eval(input).asInstanceOf[AnyRef] val frqValue = frequencyExpression.eval(input) // Null values are ignored in counts map. 
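Switching the aggregation buffer from OpenHashMap[Number, Long] to OpenHashMap[AnyRef, Long] (with toDoubleValue handling Decimal) is what lets percentile run over DECIMAL columns instead of failing on the Number cast. A short PySpark SQL sketch of the now-supported case:

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("percentile-decimal-demo").getOrCreate()

# percentile() buffers (value -> frequency) pairs; Decimal keys now go into the buffer
# as-is and are converted to Double only when the final interpolation happens.
spark.sql("""
    SELECT percentile(amount, 0.5) AS median_amount
    FROM VALUES
      (CAST(1.00  AS DECIMAL(10, 2))),
      (CAST(2.00  AS DECIMAL(10, 2))),
      (CAST(10.00 AS DECIMAL(10, 2))) AS t(amount)
""").show()
```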
@@ -155,32 +160,32 @@ case class Percentile( } override def merge( - buffer: OpenHashMap[Number, Long], - other: OpenHashMap[Number, Long]): OpenHashMap[Number, Long] = { + buffer: OpenHashMap[AnyRef, Long], + other: OpenHashMap[AnyRef, Long]): OpenHashMap[AnyRef, Long] = { other.foreach { case (key, count) => buffer.changeValue(key, count, _ + count) } buffer } - override def eval(buffer: OpenHashMap[Number, Long]): Any = { + override def eval(buffer: OpenHashMap[AnyRef, Long]): Any = { generateOutput(getPercentiles(buffer)) } - private def getPercentiles(buffer: OpenHashMap[Number, Long]): Seq[Double] = { + private def getPercentiles(buffer: OpenHashMap[AnyRef, Long]): Seq[Double] = { if (buffer.isEmpty) { return Seq.empty } val sortedCounts = buffer.toSeq.sortBy(_._1)( - child.dataType.asInstanceOf[NumericType].ordering.asInstanceOf[Ordering[Number]]) + child.dataType.asInstanceOf[NumericType].ordering.asInstanceOf[Ordering[AnyRef]]) val accumlatedCounts = sortedCounts.scanLeft(sortedCounts.head._1, 0L) { case ((key1, count1), (key2, count2)) => (key2, count1 + count2) }.tail val maxPosition = accumlatedCounts.last._2 - 1 percentages.map { percentile => - getPercentile(accumlatedCounts, maxPosition * percentile).doubleValue() + getPercentile(accumlatedCounts, maxPosition * percentile) } } @@ -200,7 +205,7 @@ case class Percentile( * This function has been based upon similar function from HIVE * `org.apache.hadoop.hive.ql.udf.UDAFPercentile.getPercentile()`. */ - private def getPercentile(aggreCounts: Seq[(Number, Long)], position: Double): Number = { + private def getPercentile(aggreCounts: Seq[(AnyRef, Long)], position: Double): Double = { // We may need to do linear interpolation to get the exact percentile val lower = position.floor.toLong val higher = position.ceil.toLong @@ -213,18 +218,17 @@ case class Percentile( val lowerKey = aggreCounts(lowerIndex)._1 if (higher == lower) { // no interpolation needed because position does not have a fraction - return lowerKey + return toDoubleValue(lowerKey) } val higherKey = aggreCounts(higherIndex)._1 if (higherKey == lowerKey) { // no interpolation needed because lower position and higher position has the same key - return lowerKey + return toDoubleValue(lowerKey) } // Linear interpolation to get the exact percentile - return (higher - position) * lowerKey.doubleValue() + - (position - lower) * higherKey.doubleValue() + (higher - position) * toDoubleValue(lowerKey) + (position - lower) * toDoubleValue(higherKey) } /** @@ -238,7 +242,7 @@ case class Percentile( } } - override def serialize(obj: OpenHashMap[Number, Long]): Array[Byte] = { + override def serialize(obj: OpenHashMap[AnyRef, Long]): Array[Byte] = { val buffer = new Array[Byte](4 << 10) // 4K val bos = new ByteArrayOutputStream() val out = new DataOutputStream(bos) @@ -261,11 +265,11 @@ case class Percentile( } } - override def deserialize(bytes: Array[Byte]): OpenHashMap[Number, Long] = { + override def deserialize(bytes: Array[Byte]): OpenHashMap[AnyRef, Long] = { val bis = new ByteArrayInputStream(bytes) val ins = new DataInputStream(bis) try { - val counts = new OpenHashMap[Number, Long] + val counts = new OpenHashMap[AnyRef, Long] // Read unsafeRow size and content in bytes. var sizeOfNextRow = ins.readInt() while (sizeOfNextRow >= 0) { @@ -274,7 +278,7 @@ case class Percentile( val row = new UnsafeRow(2) row.pointTo(bs, sizeOfNextRow) // Insert the pairs into counts map. 
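The reworked getPercentile above now returns a Double and converts both neighbouring keys with toDoubleValue before mixing them, so DecimalType children take the same path as plain numerics. A minimal standalone sketch of the interpolation arithmetic (illustrative names, not Spark API):

    // Linear interpolation between the keys found at the two integer positions
    // surrounding `position`, mirroring the expression used in getPercentile above.
    def interpolate(lowerKey: Double, higherKey: Double, position: Double): Double = {
      val lower = math.floor(position)
      val higher = math.ceil(position)
      if (higher == lower) lowerKey
      else (higher - position) * lowerKey + (position - lower) * higherKey
    }

    interpolate(10.0, 20.0, 2.25)   // 12.5: 75% weight on the lower key, 25% on the higher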
- val key = row.get(0, child.dataType).asInstanceOf[Number] + val key = row.get(0, child.dataType) val count = row.get(1, LongType).asInstanceOf[Long] counts.update(key, count) sizeOfNextRow = ins.readInt() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index 1b98c30d3760..e84796f2edad 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -173,7 +173,6 @@ case class Stack(children: Seq[Expression]) extends Generator { } } - /** * Only support code generation when stack produces 50 rows or less. */ @@ -204,6 +203,10 @@ case class Stack(children: Seq[Expression]) extends Generator { } } +/** + * Wrapper around another generator to specify outer behavior. This is used to implement functions + * such as explode_outer. This expression gets replaced during analysis. + */ case class GeneratorOuter(child: Generator) extends UnaryExpression with Generator { final override def eval(input: InternalRow = null): TraversableOnce[InternalRow] = throw new UnsupportedOperationException(s"Cannot evaluate expression: $this") @@ -212,7 +215,10 @@ case class GeneratorOuter(child: Generator) extends UnaryExpression with Generat throw new UnsupportedOperationException(s"Cannot evaluate expression: $this") override def elementSchema: StructType = child.elementSchema + + override lazy val resolved: Boolean = false } + /** * A base class for [[Explode]] and [[PosExplode]]. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala index e14f0544c2b8..2d9c2e42064b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala @@ -573,10 +573,9 @@ object XxHash64Function extends InterpretedHashFunction { } } - /** - * Simulates Hive's hashing function at - * org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils#hashcode() in Hive + * Simulates Hive's hashing function from Hive v1.2.1 at + * org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils#hashcode() * * We should use this hash function for both shuffle and bucket of Hive tables, so that * we can guarantee shuffle and bucketing have same data distribution @@ -595,7 +594,7 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] { override protected def hasherClassName: String = classOf[HiveHasher].getName override protected def computeHash(value: Any, dataType: DataType, seed: Int): Int = { - HiveHashFunction.hash(value, dataType, seed).toInt + HiveHashFunction.hash(value, dataType, this.seed).toInt } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { @@ -781,12 +780,12 @@ object HiveHashFunction extends InterpretedHashFunction { var i = 0 val length = struct.numFields while (i < length) { - result = (31 * result) + hash(struct.get(i, types(i)), types(i), seed + 1).toInt + result = (31 * result) + hash(struct.get(i, types(i)), types(i), 0).toInt i += 1 } result - case _ => super.hash(value, dataType, seed) + case _ => super.hash(value, dataType, 0) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/identifiers.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/identifiers.scala index 26697e9867b3..a3cc4529b545 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/identifiers.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/identifiers.scala @@ -63,7 +63,9 @@ case class TableIdentifier(table: String, database: Option[String]) } /** A fully qualified identifier for a table (i.e., database.tableName) */ -case class QualifiedTableName(database: String, name: String) +case class QualifiedTableName(database: String, name: String) { + override def toString: String = s"$database.$name" +} object TableIdentifier { def apply(tableName: String): TableIdentifier = new TableIdentifier(tableName) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala index 995095969d7a..9b80c0fc87c9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala @@ -58,7 +58,10 @@ class JacksonParser( private val emptyRow: Seq[InternalRow] = Seq(new GenericInternalRow(schema.length)) private val corruptFieldIndex = schema.getFieldIndex(options.columnNameOfCorruptRecord) - corruptFieldIndex.foreach(idx => require(schema(idx).dataType == StringType)) + corruptFieldIndex.foreach { corrFieldIndex => + require(schema(corrFieldIndex).dataType == StringType) + require(schema(corrFieldIndex).nullable) + } @transient private[this] var isWarningPrinted: Boolean = false diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 0c13e3e93a42..036da3ad2062 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -56,7 +56,8 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf) ReplaceExpressions, ComputeCurrentTime, GetCurrentDatabase(sessionCatalog), - RewriteDistinctAggregates) :: + RewriteDistinctAggregates, + ReplaceDeduplicateWithAggregate) :: ////////////////////////////////////////////////////////////////////////////////////////// // Optimizer rules start here ////////////////////////////////////////////////////////////////////////////////////////// @@ -578,7 +579,7 @@ object CollapseRepartition extends Rule[LogicalPlan] { RepartitionByExpression(exprs, child, numPartitions) // Case 3 case Repartition(numPartitions, _, r: RepartitionByExpression) => - r.copy(numPartitions = Some(numPartitions)) + r.copy(numPartitions = numPartitions) // Case 3 case RepartitionByExpression(exprs, Repartition(_, _, child), numPartitions) => RepartitionByExpression(exprs, child, numPartitions) @@ -1142,6 +1143,24 @@ object ReplaceDistinctWithAggregate extends Rule[LogicalPlan] { } } +/** + * Replaces logical [[Deduplicate]] operator with an [[Aggregate]] operator. 
+ */ +object ReplaceDeduplicateWithAggregate extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case Deduplicate(keys, child, streaming) if !streaming => + val keyExprIds = keys.map(_.exprId) + val aggCols = child.output.map { attr => + if (keyExprIds.contains(attr.exprId)) { + attr + } else { + Alias(new First(attr).toAggregateExpression(), attr.name)(attr.exprId) + } + } + Aggregate(keys, aggCols, child) + } +} + /** * Replaces logical [[Intersect]] operator with a left-semi [[Join]] operator. * {{{ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 4f593c894acd..21d1cd593262 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -457,7 +457,7 @@ object FoldablePropagation extends Rule[LogicalPlan] { // join is not always picked from its children, but can also be null. // TODO(cloud-fan): It seems more reasonable to use new attributes as the output attributes // of outer join. - case j @ Join(_, _, Inner, _) => + case j @ Join(_, _, Inner, _) if !stop => j.transformExpressions(replaceFoldable) // We can fold the projections an expand holds. However expand changes the output columns diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index bbb9922c187d..d2e091f4dda6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -76,7 +76,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { } override def visitSingleDataType(ctx: SingleDataTypeContext): DataType = withOrigin(ctx) { - visit(ctx.dataType).asInstanceOf[DataType] + visitSparkDataType(ctx.dataType) } /* ******************************************************************************************** @@ -179,7 +179,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { } InsertIntoTable( - UnresolvedRelation(tableIdent, None), + UnresolvedRelation(tableIdent), partitionKeys, query, ctx.OVERWRITE != null, @@ -242,20 +242,20 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { Sort(sort.asScala.map(visitSortItem), global = false, query) } else if (order.isEmpty && sort.isEmpty && !distributeBy.isEmpty && clusterBy.isEmpty) { // DISTRIBUTE BY ... - RepartitionByExpression(expressionList(distributeBy), query) + withRepartitionByExpression(ctx, expressionList(distributeBy), query) } else if (order.isEmpty && !sort.isEmpty && !distributeBy.isEmpty && clusterBy.isEmpty) { // SORT BY ... DISTRIBUTE BY ... Sort( sort.asScala.map(visitSortItem), global = false, - RepartitionByExpression(expressionList(distributeBy), query)) + withRepartitionByExpression(ctx, expressionList(distributeBy), query)) } else if (order.isEmpty && sort.isEmpty && distributeBy.isEmpty && !clusterBy.isEmpty) { // CLUSTER BY ... 
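The ReplaceDeduplicateWithAggregate rule above only fires for non-streaming plans, and it aliases each first(...) back onto the original exprId so downstream references stay resolved. At the DataFrame level the rewrite is roughly equivalent to the following sketch, assuming a DataFrame `df` with columns `a` and `b` (names are illustrative only):

    import org.apache.spark.sql.functions.first

    // df.dropDuplicates("a") is planned, for batch queries, like:
    val deduped = df.groupBy("a").agg(first("b").as("b"))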
val expressions = expressionList(clusterBy) Sort( expressions.map(SortOrder(_, Ascending)), global = false, - RepartitionByExpression(expressions, query)) + withRepartitionByExpression(ctx, expressions, query)) } else if (order.isEmpty && sort.isEmpty && distributeBy.isEmpty && clusterBy.isEmpty) { // [EMPTY] query @@ -273,6 +273,16 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { } } + /** + * Create a clause for DISTRIBUTE BY. + */ + protected def withRepartitionByExpression( + ctx: QueryOrganizationContext, + expressions: Seq[Expression], + query: LogicalPlan): LogicalPlan = { + throw new ParseException("DISTRIBUTE BY is not supported", ctx) + } + /** * Create a logical plan using a query specification. */ @@ -645,17 +655,21 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { * }}} */ override def visitTable(ctx: TableContext): LogicalPlan = withOrigin(ctx) { - UnresolvedRelation(visitTableIdentifier(ctx.tableIdentifier), None) + UnresolvedRelation(visitTableIdentifier(ctx.tableIdentifier)) } /** * Create an aliased table reference. This is typically used in FROM clauses. */ override def visitTableName(ctx: TableNameContext): LogicalPlan = withOrigin(ctx) { - val table = UnresolvedRelation( - visitTableIdentifier(ctx.tableIdentifier), - Option(ctx.strictIdentifier).map(_.getText)) - table.optionalMap(ctx.sample)(withSample) + val table = UnresolvedRelation(visitTableIdentifier(ctx.tableIdentifier)) + + val tableWithAlias = Option(ctx.strictIdentifier).map(_.getText) match { + case Some(strictIdentifier) => + SubqueryAlias(strictIdentifier, table, None) + case _ => table + } + tableWithAlias.optionalMap(ctx.sample)(withSample) } /** @@ -992,7 +1006,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { * Create a [[Cast]] expression. */ override def visitCast(ctx: CastContext): Expression = withOrigin(ctx) { - Cast(expression(ctx.expression), typedVisit(ctx.dataType)) + Cast(expression(ctx.expression), visitSparkDataType(ctx.dataType)) } /** @@ -1410,6 +1424,13 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { /* ******************************************************************************************** * DataType parsing * ******************************************************************************************** */ + /** + * Create a Spark DataType. + */ + private def visitSparkDataType(ctx: DataTypeContext): DataType = { + HiveStringType.replaceCharType(typedVisit(ctx)) + } + /** * Resolve/create a primitive type. 
*/ @@ -1424,8 +1445,9 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { case ("double", Nil) => DoubleType case ("date", Nil) => DateType case ("timestamp", Nil) => TimestampType - case ("char" | "varchar" | "string", Nil) => StringType - case ("char" | "varchar", _ :: Nil) => StringType + case ("string", Nil) => StringType + case ("char", length :: Nil) => CharType(length.getText.toInt) + case ("varchar", length :: Nil) => VarcharType(length.getText.toInt) case ("binary", Nil) => BinaryType case ("decimal", Nil) => DecimalType.USER_DEFAULT case ("decimal", precision :: Nil) => DecimalType(precision.getText.toInt, 0) @@ -1447,7 +1469,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { case SqlBaseParser.MAP => MapType(typedVisit(ctx.dataType(0)), typedVisit(ctx.dataType(1))) case SqlBaseParser.STRUCT => - createStructType(ctx.complexColTypeList()) + StructType(Option(ctx.complexColTypeList).toSeq.flatMap(visitComplexColTypeList)) } } @@ -1466,7 +1488,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { } /** - * Create a [[StructField]] from a column definition. + * Create a top level [[StructField]] from a column definition. */ override def visitColType(ctx: ColTypeContext): StructField = withOrigin(ctx) { import ctx._ @@ -1477,19 +1499,15 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { builder.putString("comment", string(STRING)) } // Add Hive type string to metadata. - dataType match { - case p: PrimitiveDataTypeContext => - p.identifier.getText.toLowerCase match { - case "varchar" | "char" => - builder.putString(HIVE_TYPE_STRING, dataType.getText.toLowerCase) - case _ => - } - case _ => + val rawDataType = typedVisit[DataType](ctx.dataType) + val cleanedDataType = HiveStringType.replaceCharType(rawDataType) + if (rawDataType != cleanedDataType) { + builder.putString(HIVE_TYPE_STRING, rawDataType.catalogString) } StructField( identifier.getText, - typedVisit(dataType), + cleanedDataType, nullable = true, builder.build()) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 0937825e273a..e22b429aec68 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -115,6 +115,10 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { Statistics(sizeInBytes = children.map(_.stats(conf).sizeInBytes).product) } + override def verboseStringWithSuffix: String = { + super.verboseString + statsCache.map(", " + _.toString).getOrElse("") + } + /** * Returns the maximum number of rows that this plan may compute. 
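With the primitive-type change above, char(n) and varchar(n) are no longer collapsed to StringType at parse time; they come out as the CharType/VarcharType placeholders introduced later in this patch, and visitSparkDataType/visitColType normalize them afterwards, recording the raw Hive type string under the HIVE_TYPE_STRING metadata key whenever something was replaced. A small sketch of the normalization step, assuming the HiveStringType helper from this patch:

    import org.apache.spark.sql.types._

    val raw: DataType = CharType(10)                            // what the parser now produces
    val cleaned: DataType = HiveStringType.replaceCharType(raw) // StringType

    // Nested occurrences are rewritten recursively as well.
    HiveStringType.replaceCharType(
      StructType(Seq(StructField("c", VarcharType(5)))))        // struct<c:string>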
* diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala index 91404d4bb81b..f24b240956a6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.plans.logical +import java.math.{MathContext, RoundingMode} + import scala.util.control.NonFatal import org.apache.spark.internal.Logging @@ -24,6 +26,7 @@ import org.apache.spark.sql.{AnalysisException, Row} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.types._ +import org.apache.spark.util.Utils /** @@ -54,8 +57,13 @@ case class Statistics( /** Readable string representation for the Statistics. */ def simpleString: String = { - Seq(s"sizeInBytes=$sizeInBytes", - if (rowCount.isDefined) s"rowCount=${rowCount.get}" else "", + Seq(s"sizeInBytes=${Utils.bytesToString(sizeInBytes)}", + if (rowCount.isDefined) { + // Show row count in scientific notation. + s"rowCount=${BigDecimal(rowCount.get, new MathContext(3, RoundingMode.HALF_UP)).toString()}" + } else { + "" + }, s"isBroadcastable=$isBroadcastable" ).filter(_.nonEmpty).mkString(", ") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index af5763251679..ccebae3cc270 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogTypes} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.{AggregateEstimation, EstimationUtils, JoinEstimation, ProjectEstimation} +import org.apache.spark.sql.catalyst.plans.logical.statsEstimation._ import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -129,6 +129,14 @@ case class Filter(condition: Expression, child: LogicalPlan) .filterNot(SubqueryExpression.hasCorrelatedSubquery) child.constraints.union(predicates.toSet) } + + override def computeStats(conf: CatalystConf): Statistics = { + if (conf.cboEnabled) { + FilterEstimation(this, conf).estimate.getOrElse(super.computeStats(conf)) + } else { + super.computeStats(conf) + } + } } abstract class SetOperation(left: LogicalPlan, right: LogicalPlan) extends BinaryNode { @@ -844,18 +852,13 @@ case class Repartition(numPartitions: Int, shuffle: Boolean, child: LogicalPlan) * information about the number of partitions during execution. Used when a specific ordering or * distribution is expected by the consumer of the query result. Use [[Repartition]] for RDD-like * `coalesce` and `repartition`. - * If `numPartitions` is not specified, the number of partitions will be the number set by - * `spark.sql.shuffle.partitions`. 
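The Statistics.simpleString change above trades raw numbers for readability: sizeInBytes goes through Utils.bytesToString and the row count is rounded to three significant digits in scientific notation. A quick check of the rounding, mirroring the call used in the patch (the exact byte formatting depends on Utils.bytesToString):

    import java.math.{MathContext, RoundingMode}

    // 123456 rows should render as "1.23E+5" in the new output.
    BigDecimal(BigInt(123456), new MathContext(3, RoundingMode.HALF_UP)).toString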
*/ case class RepartitionByExpression( partitionExpressions: Seq[Expression], child: LogicalPlan, - numPartitions: Option[Int] = None) extends UnaryNode { + numPartitions: Int) extends UnaryNode { - numPartitions match { - case Some(n) => require(n > 0, s"Number of partitions ($n) must be positive.") - case None => // Ok - } + require(numPartitions > 0, s"Number of partitions ($numPartitions) must be positive.") override def maxRows: Option[Long] = child.maxRows override def output: Seq[Attribute] = child.output @@ -869,3 +872,12 @@ case object OneRowRelation extends LeafNode { override def output: Seq[Attribute] = Nil override def computeStats(conf: CatalystConf): Statistics = Statistics(sizeInBytes = 1) } + +/** A logical plan for `dropDuplicates`. */ +case class Deduplicate( + keys: Seq[Attribute], + child: LogicalPlan, + streaming: Boolean) extends UnaryNode { + + override def output: Seq[Attribute] = child.output +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala new file mode 100644 index 000000000000..0c928832d7d2 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala @@ -0,0 +1,507 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.plans.logical.statsEstimation + +import scala.collection.immutable.HashSet +import scala.collection.mutable + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.CatalystConf +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.types._ + +case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Logging { + + private val childStats = plan.child.stats(catalystConf) + + /** + * We will update the corresponding ColumnStats for a column after we apply a predicate condition. + * For example, column c has [min, max] value as [0, 100]. In a range condition such as + * (c > 40 AND c <= 50), we need to set the column's [min, max] value to [40, 100] after we + * evaluate the first condition c > 40. We need to set the column's [min, max] value to [40, 50] + * after we evaluate the second condition c <= 50. + */ + private val colStatsMap = new ColumnStatsMap + + /** + * Returns an option of Statistics for a Filter logical plan node. 
+ * For a given compound expression condition, this method computes filter selectivity + * (or the percentage of rows meeting the filter condition), which + * is used to compute row count, size in bytes, and the updated statistics after a given + * predicated is applied. + * + * @return Option[Statistics] When there is no statistics collected, it returns None. + */ + def estimate: Option[Statistics] = { + if (childStats.rowCount.isEmpty) return None + + // save a mutable copy of colStats so that we can later change it recursively + colStatsMap.setInitValues(childStats.attributeStats) + + // estimate selectivity of this filter predicate + val filterSelectivity: Double = calculateFilterSelectivity(plan.condition) match { + case Some(percent) => percent + // for not-supported condition, set filter selectivity to a conservative estimate 100% + case None => 1.0 + } + + val newColStats = colStatsMap.toColumnStats + + val filteredRowCount: BigInt = + EstimationUtils.ceil(BigDecimal(childStats.rowCount.get) * filterSelectivity) + val filteredSizeInBytes: BigInt = + EstimationUtils.getOutputSize(plan.output, filteredRowCount, newColStats) + + Some(childStats.copy(sizeInBytes = filteredSizeInBytes, rowCount = Some(filteredRowCount), + attributeStats = newColStats)) + } + + /** + * Returns a percentage of rows meeting a compound condition in Filter node. + * A compound condition is decomposed into multiple single conditions linked with AND, OR, NOT. + * For logical AND conditions, we need to update stats after a condition estimation + * so that the stats will be more accurate for subsequent estimation. This is needed for + * range condition such as (c > 40 AND c <= 50) + * For logical OR conditions, we do not update stats after a condition estimation. + * + * @param condition the compound logical expression + * @param update a boolean flag to specify if we need to update ColumnStat of a column + * for subsequent conditions + * @return an optional double value to show the percentage of rows meeting a given condition. + * It returns None if the condition is not supported. + */ + def calculateFilterSelectivity(condition: Expression, update: Boolean = true): Option[Double] = { + condition match { + case And(cond1, cond2) => + // For ease of debugging, we compute percent1 and percent2 in 2 statements. + val percent1 = calculateFilterSelectivity(cond1, update) + val percent2 = calculateFilterSelectivity(cond2, update) + (percent1, percent2) match { + case (Some(p1), Some(p2)) => Some(p1 * p2) + case (Some(p1), None) => Some(p1) + case (None, Some(p2)) => Some(p2) + case (None, None) => None + } + + case Or(cond1, cond2) => + // For ease of debugging, we compute percent1 and percent2 in 2 statements. + val percent1 = calculateFilterSelectivity(cond1, update = false) + val percent2 = calculateFilterSelectivity(cond2, update = false) + (percent1, percent2) match { + case (Some(p1), Some(p2)) => Some(math.min(1.0, p1 + p2 - (p1 * p2))) + case (Some(p1), None) => Some(1.0) + case (None, Some(p2)) => Some(1.0) + case (None, None) => None + } + + case Not(cond) => calculateFilterSelectivity(cond, update = false) match { + case Some(percent) => Some(1.0 - percent) + // for not-supported condition, set filter selectivity to a conservative estimate 100% + case None => None + } + + case _ => calculateSingleCondition(condition, update) + } + } + + /** + * Returns a percentage of rows meeting a single condition in Filter node. 
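The combination logic above is the core of the new estimator: conjuncts multiply, disjuncts combine by inclusion-exclusion capped at 1.0, a supported negation flips its operand, and an unsupported leaf contributes no reduction. The overall estimate then becomes ceil(childRowCount * selectivity). A toy restatement with made-up leaf selectivities (not Spark API):

    def andSel(p1: Double, p2: Double): Double = p1 * p2
    def orSel(p1: Double, p2: Double): Double  = math.min(1.0, p1 + p2 - p1 * p2)
    def notSel(p: Double): Double              = 1.0 - p

    // e.g. (a = 1 AND b < 50) where the leaves are estimated at 0.25 and 0.5:
    val selectivity = andSel(0.25, 0.5)                        // 0.125
    val estimatedRows = math.ceil(1000 * selectivity).toLong   // 125 of 1000 input rows

For a range leaf such as b < 50 on a column with min 0 and max 100, the leaf estimate itself comes from prorating the literal over the column's range, (50 - 0) / (100 - 0) = 0.5, as evaluateBinaryForNumeric below shows.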
+ * Currently we only support binary predicates where one side is a column, + * and the other is a literal. + * + * @param condition a single logical expression + * @param update a boolean flag to specify if we need to update ColumnStat of a column + * for subsequent conditions + * @return an optional double value to show the percentage of rows meeting a given condition. + * It returns None if the condition is not supported. + */ + def calculateSingleCondition(condition: Expression, update: Boolean): Option[Double] = { + condition match { + // For evaluateBinary method, we assume the literal on the right side of an operator. + // So we will change the order if not. + + // EqualTo/EqualNullSafe does not care about the order + case op @ Equality(ar: Attribute, l: Literal) => + evaluateEquality(ar, l, update) + case op @ Equality(l: Literal, ar: Attribute) => + evaluateEquality(ar, l, update) + + case op @ LessThan(ar: Attribute, l: Literal) => + evaluateBinary(op, ar, l, update) + case op @ LessThan(l: Literal, ar: Attribute) => + evaluateBinary(GreaterThan(ar, l), ar, l, update) + + case op @ LessThanOrEqual(ar: Attribute, l: Literal) => + evaluateBinary(op, ar, l, update) + case op @ LessThanOrEqual(l: Literal, ar: Attribute) => + evaluateBinary(GreaterThanOrEqual(ar, l), ar, l, update) + + case op @ GreaterThan(ar: Attribute, l: Literal) => + evaluateBinary(op, ar, l, update) + case op @ GreaterThan(l: Literal, ar: Attribute) => + evaluateBinary(LessThan(ar, l), ar, l, update) + + case op @ GreaterThanOrEqual(ar: Attribute, l: Literal) => + evaluateBinary(op, ar, l, update) + case op @ GreaterThanOrEqual(l: Literal, ar: Attribute) => + evaluateBinary(LessThanOrEqual(ar, l), ar, l, update) + + case In(ar: Attribute, expList) + if expList.forall(e => e.isInstanceOf[Literal]) => + // Expression [In (value, seq[Literal])] will be replaced with optimized version + // [InSet (value, HashSet[Literal])] in Optimizer, but only for list.size > 10. + // Here we convert In into InSet anyway, because they share the same processing logic. + val hSet = expList.map(e => e.eval()) + evaluateInSet(ar, HashSet() ++ hSet, update) + + case InSet(ar: Attribute, set) => + evaluateInSet(ar, set, update) + + case IsNull(ar: Attribute) => + evaluateNullCheck(ar, isNull = true, update) + + case IsNotNull(ar: Attribute) => + evaluateNullCheck(ar, isNull = false, update) + + case _ => + // TODO: it's difficult to support string operators without advanced statistics. + // Hence, these string operators Like(_, _) | Contains(_, _) | StartsWith(_, _) + // | EndsWith(_, _) are not supported yet + logDebug("[CBO] Unsupported filter condition: " + condition) + None + } + } + + /** + * Returns a percentage of rows meeting "IS NULL" or "IS NOT NULL" condition. + * + * @param attr an Attribute (or a column) + * @param isNull set to true for "IS NULL" condition. set to false for "IS NOT NULL" condition + * @param update a boolean flag to specify if we need to update ColumnStat of a given column + * for subsequent conditions + * @return an optional double value to show the percentage of rows meeting a given condition + * It returns None if no statistics collected for a given column. 
+ */ + def evaluateNullCheck( + attr: Attribute, + isNull: Boolean, + update: Boolean): Option[Double] = { + if (!colStatsMap.contains(attr)) { + logDebug("[CBO] No statistics for " + attr) + return None + } + val colStat = colStatsMap(attr) + val rowCountValue = childStats.rowCount.get + val nullPercent: BigDecimal = if (rowCountValue == 0) { + 0 + } else { + BigDecimal(colStat.nullCount) / BigDecimal(rowCountValue) + } + + if (update) { + val newStats = if (isNull) { + colStat.copy(distinctCount = 0, min = None, max = None) + } else { + colStat.copy(nullCount = 0) + } + colStatsMap(attr) = newStats + } + + val percent = if (isNull) { + nullPercent.toDouble + } else { + 1.0 - nullPercent.toDouble + } + + Some(percent) + } + + /** + * Returns a percentage of rows meeting a binary comparison expression. + * + * @param op a binary comparison operator uch as =, <, <=, >, >= + * @param attr an Attribute (or a column) + * @param literal a literal value (or constant) + * @param update a boolean flag to specify if we need to update ColumnStat of a given column + * for subsequent conditions + * @return an optional double value to show the percentage of rows meeting a given condition + * It returns None if no statistics exists for a given column or wrong value. + */ + def evaluateBinary( + op: BinaryComparison, + attr: Attribute, + literal: Literal, + update: Boolean): Option[Double] = { + attr.dataType match { + case _: NumericType | DateType | TimestampType => + evaluateBinaryForNumeric(op, attr, literal, update) + case StringType | BinaryType => + // TODO: It is difficult to support other binary comparisons for String/Binary + // type without min/max and advanced statistics like histogram. + logDebug("[CBO] No range comparison statistics for String/Binary type " + attr) + None + case _ => + // TODO: support boolean type. + None + } + } + + /** + * For a SQL data type, its internal data type may be different from its external type. + * For DateType, its internal type is Int, and its external data type is Java Date type. + * The min/max values in ColumnStat are saved in their corresponding external type. + * + * @param attrDataType the column data type + * @param litValue the literal value + * @return a BigDecimal value + */ + def convertBoundValue(attrDataType: DataType, litValue: Any): Option[Any] = { + attrDataType match { + case DateType => + Some(DateTimeUtils.toJavaDate(litValue.toString.toInt)) + case TimestampType => + Some(DateTimeUtils.toJavaTimestamp(litValue.toString.toLong)) + case _: DecimalType => + Some(litValue.asInstanceOf[Decimal].toJavaBigDecimal) + case StringType | BinaryType => + None + case _ => + Some(litValue) + } + } + + /** + * Returns a percentage of rows meeting an equality (=) expression. + * This method evaluates the equality predicate for all data types. + * + * @param attr an Attribute (or a column) + * @param literal a literal value (or constant) + * @param update a boolean flag to specify if we need to update ColumnStat of a given column + * for subsequent conditions + * @return an optional double value to show the percentage of rows meeting a given condition + */ + def evaluateEquality( + attr: Attribute, + literal: Literal, + update: Boolean): Option[Double] = { + if (!colStatsMap.contains(attr)) { + logDebug("[CBO] No statistics for " + attr) + return None + } + val colStat = colStatsMap(attr) + val ndv = colStat.distinctCount + + // decide if the value is in [min, max] of the column. + // We currently don't store min/max for binary/string type. 
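evaluateNullCheck above derives the fraction directly from the stored null count. With toy numbers, a column carrying 200 nulls out of 1000 rows gives IS NULL a selectivity of 0.2 and IS NOT NULL 0.8; when update is set, the surviving column stats are also adjusted (nullCount becomes 0 for IS NOT NULL, distinctCount/min/max are cleared for IS NULL). A two-line sketch of the arithmetic:

    val nullFraction = 200.0 / 1000.0          // 0.2, i.e. the IS NULL estimate
    val isNotNullSelectivity = 1.0 - nullFraction  // 0.8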
+ // Hence, we assume it is in boundary for binary/string type. + val statsRange = Range(colStat.min, colStat.max, attr.dataType) + if (statsRange.contains(literal)) { + if (update) { + // We update ColumnStat structure after apply this equality predicate. + // Set distinctCount to 1. Set nullCount to 0. + // Need to save new min/max using the external type value of the literal + val newValue = convertBoundValue(attr.dataType, literal.value) + val newStats = colStat.copy(distinctCount = 1, min = newValue, + max = newValue, nullCount = 0) + colStatsMap(attr) = newStats + } + + Some(1.0 / ndv.toDouble) + } else { + Some(0.0) + } + + } + + /** + * Returns a percentage of rows meeting "IN" operator expression. + * This method evaluates the equality predicate for all data types. + * + * @param attr an Attribute (or a column) + * @param hSet a set of literal values + * @param update a boolean flag to specify if we need to update ColumnStat of a given column + * for subsequent conditions + * @return an optional double value to show the percentage of rows meeting a given condition + * It returns None if no statistics exists for a given column. + */ + + def evaluateInSet( + attr: Attribute, + hSet: Set[Any], + update: Boolean): Option[Double] = { + if (!colStatsMap.contains(attr)) { + logDebug("[CBO] No statistics for " + attr) + return None + } + + val colStat = colStatsMap(attr) + val ndv = colStat.distinctCount + val dataType = attr.dataType + var newNdv = ndv + + // use [min, max] to filter the original hSet + dataType match { + case _: NumericType | BooleanType | DateType | TimestampType => + val statsRange = Range(colStat.min, colStat.max, dataType).asInstanceOf[NumericRange] + val validQuerySet = hSet.filter { v => + v != null && statsRange.contains(Literal(v, dataType)) + } + + if (validQuerySet.isEmpty) { + return Some(0.0) + } + + // Need to save new min/max using the external type value of the literal + val newMax = convertBoundValue( + attr.dataType, validQuerySet.maxBy(v => BigDecimal(v.toString))) + val newMin = convertBoundValue( + attr.dataType, validQuerySet.minBy(v => BigDecimal(v.toString))) + + // newNdv should not be greater than the old ndv. For example, column has only 2 values + // 1 and 6. The predicate column IN (1, 2, 3, 4, 5). validQuerySet.size is 5. + newNdv = ndv.min(BigInt(validQuerySet.size)) + if (update) { + val newStats = colStat.copy(distinctCount = newNdv, min = newMin, + max = newMax, nullCount = 0) + colStatsMap(attr) = newStats + } + + // We assume the whole set since there is no min/max information for String/Binary type + case StringType | BinaryType => + newNdv = ndv.min(BigInt(hSet.size)) + if (update) { + val newStats = colStat.copy(distinctCount = newNdv, nullCount = 0) + colStatsMap(attr) = newStats + } + } + + // return the filter selectivity. Without advanced statistics such as histograms, + // we have to assume uniform distribution. + Some(math.min(1.0, newNdv.toDouble / ndv.toDouble)) + } + + /** + * Returns a percentage of rows meeting a binary comparison expression. + * This method evaluate expression for Numeric columns only. 
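evaluateEquality and evaluateInSet above both lean on the uniform-distribution assumption: a literal inside [min, max] is estimated at 1/ndv, and an IN list keeps newNdv/ndv after dropping out-of-range values and capping at the old ndv. Reusing the toy numbers from the comment above (a column holding only the values 1 and 6):

    val ndv = BigInt(2)                          // distinct values: 1 and 6
    // col = 1            -> 1.0 / 2  = 0.5
    // col = 7 (outside [1, 6]) -> 0.0
    // col IN (1, 2, 3, 4, 5) -> all five literals fall in [1, 6], but newNdv is capped:
    val newNdv = ndv.min(BigInt(5))              // 2
    val inSelectivity = math.min(1.0, newNdv.toDouble / ndv.toDouble)  // 1.0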
+ * + * @param op a binary comparison operator uch as =, <, <=, >, >= + * @param attr an Attribute (or a column) + * @param literal a literal value (or constant) + * @param update a boolean flag to specify if we need to update ColumnStat of a given column + * for subsequent conditions + * @return an optional double value to show the percentage of rows meeting a given condition + */ + def evaluateBinaryForNumeric( + op: BinaryComparison, + attr: Attribute, + literal: Literal, + update: Boolean): Option[Double] = { + + var percent = 1.0 + val colStat = colStatsMap(attr) + val statsRange = + Range(colStat.min, colStat.max, attr.dataType).asInstanceOf[NumericRange] + + // determine the overlapping degree between predicate range and column's range + val literalValueBD = BigDecimal(literal.value.toString) + val (noOverlap: Boolean, completeOverlap: Boolean) = op match { + case _: LessThan => + (literalValueBD <= statsRange.min, literalValueBD > statsRange.max) + case _: LessThanOrEqual => + (literalValueBD < statsRange.min, literalValueBD >= statsRange.max) + case _: GreaterThan => + (literalValueBD >= statsRange.max, literalValueBD < statsRange.min) + case _: GreaterThanOrEqual => + (literalValueBD > statsRange.max, literalValueBD <= statsRange.min) + } + + if (noOverlap) { + percent = 0.0 + } else if (completeOverlap) { + percent = 1.0 + } else { + // this is partial overlap case + val literalDouble = literalValueBD.toDouble + val maxDouble = BigDecimal(statsRange.max).toDouble + val minDouble = BigDecimal(statsRange.min).toDouble + + // Without advanced statistics like histogram, we assume uniform data distribution. + // We just prorate the adjusted range over the initial range to compute filter selectivity. + // For ease of computation, we convert all relevant numeric values to Double. 
+ percent = op match { + case _: LessThan => + (literalDouble - minDouble) / (maxDouble - minDouble) + case _: LessThanOrEqual => + if (literalValueBD == BigDecimal(statsRange.min)) { + 1.0 / colStat.distinctCount.toDouble + } else { + (literalDouble - minDouble) / (maxDouble - minDouble) + } + case _: GreaterThan => + (maxDouble - literalDouble) / (maxDouble - minDouble) + case _: GreaterThanOrEqual => + if (literalValueBD == BigDecimal(statsRange.max)) { + 1.0 / colStat.distinctCount.toDouble + } else { + (maxDouble - literalDouble) / (maxDouble - minDouble) + } + } + + if (update) { + // Need to save new min/max using the external type value of the literal + val newValue = convertBoundValue(attr.dataType, literal.value) + var newMax = colStat.max + var newMin = colStat.min + op match { + case _: GreaterThan => newMin = newValue + case _: GreaterThanOrEqual => newMin = newValue + case _: LessThan => newMax = newValue + case _: LessThanOrEqual => newMax = newValue + } + + val newNdv = math.max(math.round(colStat.distinctCount.toDouble * percent), 1) + val newStats = colStat.copy(distinctCount = newNdv, min = newMin, + max = newMax, nullCount = 0) + + colStatsMap(attr) = newStats + } + } + + Some(percent) + } + +} + +class ColumnStatsMap { + private val baseMap: mutable.Map[ExprId, (Attribute, ColumnStat)] = mutable.HashMap.empty + + def setInitValues(colStats: AttributeMap[ColumnStat]): Unit = { + baseMap.clear() + baseMap ++= colStats.baseMap + } + + def contains(a: Attribute): Boolean = baseMap.contains(a.exprId) + + def apply(a: Attribute): ColumnStat = baseMap(a.exprId)._2 + + def update(a: Attribute, stats: ColumnStat): Unit = baseMap.update(a.exprId, a -> stats) + + def toColumnStats: AttributeMap[ColumnStat] = AttributeMap(baseMap.values.toSeq) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala index 982a5a8bb89b..9782c0bb0a93 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala @@ -59,7 +59,7 @@ case class InnerOuterEstimation(conf: CatalystConf, join: Join) extends Logging case _ if !rowCountsExist(conf, join.left, join.right) => None - case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) => + case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, _, _, _) => // 1. Compute join selectivity val joinKeyPairs = extractJoinKeysWithColStats(leftKeys, rightKeys) val selectivity = joinSelectivity(joinKeyPairs) @@ -94,9 +94,9 @@ case class InnerOuterEstimation(conf: CatalystConf, join: Join) extends Logging val outputStats: Seq[(Attribute, ColumnStat)] = if (outputRows == 0) { // The output is empty, we don't need to keep column stats. Nil - } else if (innerJoinedRows == 0) { + } else if (selectivity == 0) { joinType match { - // For outer joins, if the inner join part is empty, the number of output rows is the + // For outer joins, if the join selectivity is 0, the number of output rows is the // same as that of the outer side. And column stats of join keys from the outer side // keep unchanged, while column stats of join keys from the other side should be updated // based on added null values. 
@@ -116,6 +116,9 @@ case class InnerOuterEstimation(conf: CatalystConf, join: Join) extends Logging } case _ => Nil } + } else if (selectivity == 1) { + // Cartesian product, just propagate the original column stats + inputAttrStats.toSeq } else { val joinKeyStats = getIntersectedStats(joinKeyPairs) join.joinType match { @@ -138,8 +141,7 @@ case class InnerOuterEstimation(conf: CatalystConf, join: Join) extends Logging Some(Statistics( sizeInBytes = getOutputSize(join.output, outputRows, outputAttrStats), rowCount = Some(outputRows), - attributeStats = outputAttrStats, - isBroadcastable = false)) + attributeStats = outputAttrStats)) case _ => // When there is no equi-join condition, we do estimation like cartesian product. @@ -150,8 +152,7 @@ case class InnerOuterEstimation(conf: CatalystConf, join: Join) extends Logging Some(Statistics( sizeInBytes = getOutputSize(join.output, outputRows, inputAttrStats), rowCount = Some(outputRows), - attributeStats = inputAttrStats, - isBroadcastable = false)) + attributeStats = inputAttrStats)) } // scalastyle:off @@ -189,8 +190,7 @@ case class InnerOuterEstimation(conf: CatalystConf, join: Join) extends Logging } if (ndvDenom < 0) { - // There isn't join keys or column stats for any of the join key pairs, we do estimation like - // cartesian product. + // We can't find any join key pairs with column stats, estimate it as cartesian join. 1 } else if (ndvDenom == 0) { // One of the join key pairs is disjoint, thus the two sides of join is disjoint. @@ -202,9 +202,6 @@ case class InnerOuterEstimation(conf: CatalystConf, join: Join) extends Logging /** * Propagate or update column stats for output attributes. - * 1. For cartesian product, all values are preserved, so there's no need to change column stats. - * 2. For other cases, a) update max/min of join keys based on their intersected range. b) update - * distinct count of other attributes based on output rows after join. */ private def updateAttrStats( outputRows: BigInt, @@ -214,35 +211,38 @@ case class InnerOuterEstimation(conf: CatalystConf, join: Join) extends Logging val outputAttrStats = new ArrayBuffer[(Attribute, ColumnStat)]() val leftRows = leftStats.rowCount.get val rightRows = rightStats.rowCount.get - if (outputRows == leftRows * rightRows) { - // Cartesian product, just propagate the original column stats - attributes.foreach(a => outputAttrStats += a -> oldAttrStats(a)) - } else { - val leftRatio = - if (leftRows != 0) BigDecimal(outputRows) / BigDecimal(leftRows) else BigDecimal(0) - val rightRatio = - if (rightRows != 0) BigDecimal(outputRows) / BigDecimal(rightRows) else BigDecimal(0) - attributes.foreach { a => - // check if this attribute is a join key - if (joinKeyStats.contains(a)) { - outputAttrStats += a -> joinKeyStats(a) + + attributes.foreach { a => + // check if this attribute is a join key + if (joinKeyStats.contains(a)) { + outputAttrStats += a -> joinKeyStats(a) + } else { + val leftRatio = if (leftRows != 0) { + BigDecimal(outputRows) / BigDecimal(leftRows) + } else { + BigDecimal(0) + } + val rightRatio = if (rightRows != 0) { + BigDecimal(outputRows) / BigDecimal(rightRows) } else { - val oldColStat = oldAttrStats(a) - val oldNdv = oldColStat.distinctCount - // We only change (scale down) the number of distinct values if the number of rows - // decreases after join, because join won't produce new values even if the number of - // rows increases. 
- val newNdv = if (join.left.outputSet.contains(a) && leftRatio < 1) { - ceil(BigDecimal(oldNdv) * leftRatio) - } else if (join.right.outputSet.contains(a) && rightRatio < 1) { - ceil(BigDecimal(oldNdv) * rightRatio) - } else { - oldNdv - } - // TODO: support nullCount updates for specific outer joins - outputAttrStats += a -> oldColStat.copy(distinctCount = newNdv) + BigDecimal(0) } + val oldColStat = oldAttrStats(a) + val oldNdv = oldColStat.distinctCount + // We only change (scale down) the number of distinct values if the number of rows + // decreases after join, because join won't produce new values even if the number of + // rows increases. + val newNdv = if (join.left.outputSet.contains(a) && leftRatio < 1) { + ceil(BigDecimal(oldNdv) * leftRatio) + } else if (join.right.outputSet.contains(a) && rightRatio < 1) { + ceil(BigDecimal(oldNdv) * rightRatio) + } else { + oldNdv + } + // TODO: support nullCount updates for specific outer joins + outputAttrStats += a -> oldColStat.copy(distinctCount = newNdv) } + } outputAttrStats } @@ -263,12 +263,14 @@ case class InnerOuterEstimation(conf: CatalystConf, join: Join) extends Logging // Update intersected column stats assert(leftKey.dataType.sameType(rightKey.dataType)) - val minNdv = leftKeyStats.distinctCount.min(rightKeyStats.distinctCount) + val newNdv = leftKeyStats.distinctCount.min(rightKeyStats.distinctCount) val (newMin, newMax) = Range.intersect(lRange, rRange, leftKey.dataType) - intersectedStats.put(leftKey, - leftKeyStats.copy(distinctCount = minNdv, min = newMin, max = newMax, nullCount = 0)) - intersectedStats.put(rightKey, - rightKeyStats.copy(distinctCount = minNdv, min = newMin, max = newMax, nullCount = 0)) + val newMaxLen = math.min(leftKeyStats.maxLen, rightKeyStats.maxLen) + val newAvgLen = (leftKeyStats.avgLen + rightKeyStats.avgLen) / 2 + val newStats = ColumnStat(newNdv, newMin, newMax, 0, newAvgLen, newMaxLen) + + intersectedStats.put(leftKey, newStats) + intersectedStats.put(rightKey, newStats) } AttributeMap(intersectedStats.toSeq) } @@ -298,8 +300,7 @@ case class LeftSemiAntiEstimation(conf: CatalystConf, join: Join) { Some(Statistics( sizeInBytes = getOutputSize(join.output, outputRows, leftStats.attributeStats), rowCount = Some(outputRows), - attributeStats = leftStats.attributeStats, - isBroadcastable = false)) + attributeStats = leftStats.attributeStats)) } else { None } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/Range.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/Range.scala index 5aa6b9353bc4..3d13967cb62a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/Range.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/Range.scala @@ -20,24 +20,39 @@ package org.apache.spark.sql.catalyst.plans.logical.statsEstimation import java.math.{BigDecimal => JDecimal} import java.sql.{Date, Timestamp} +import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types.{BooleanType, DateType, TimestampType, _} /** Value range of a column. */ -trait Range +trait Range { + def contains(l: Literal): Boolean +} /** For simplicity we use decimal to unify operations of numeric ranges. 
*/ -case class NumericRange(min: JDecimal, max: JDecimal) extends Range +case class NumericRange(min: JDecimal, max: JDecimal) extends Range { + override def contains(l: Literal): Boolean = { + val decimal = l.dataType match { + case BooleanType => if (l.value.asInstanceOf[Boolean]) new JDecimal(1) else new JDecimal(0) + case _ => new JDecimal(l.value.toString) + } + min.compareTo(decimal) <= 0 && max.compareTo(decimal) >= 0 + } +} /** * This version of Spark does not have min/max for binary/string types, we define their default * behaviors by this class. */ -class DefaultRange extends Range +class DefaultRange extends Range { + override def contains(l: Literal): Boolean = true +} /** This is for columns with only null values. */ -class NullRange extends Range +class NullRange extends Range { + override def contains(l: Literal): Boolean = false +} object Range { def apply(min: Option[Any], max: Option[Any], dataType: DataType): Range = dataType match { @@ -113,4 +128,5 @@ object Range { DateTimeUtils.toJavaTimestamp(n.max.longValue())) } } + } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index f37661c31584..cc4c0835954b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -453,13 +453,16 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { /** ONE line description of this node with more information */ def verboseString: String + /** ONE line description of this node with some suffix information */ + def verboseStringWithSuffix: String = verboseString + override def toString: String = treeString /** Returns a string representation of the nodes in this tree */ def treeString: String = treeString(verbose = true) - def treeString(verbose: Boolean): String = { - generateTreeString(0, Nil, new StringBuilder, verbose).toString + def treeString(verbose: Boolean, addSuffix: Boolean = false): String = { + generateTreeString(0, Nil, new StringBuilder, verbose = verbose, addSuffix = addSuffix).toString } /** @@ -524,7 +527,8 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { lastChildren: Seq[Boolean], builder: StringBuilder, verbose: Boolean, - prefix: String = ""): StringBuilder = { + prefix: String = "", + addSuffix: Boolean = false): StringBuilder = { if (depth > 0) { lastChildren.init.foreach { isLast => @@ -533,22 +537,29 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { builder.append(if (lastChildren.last) "+- " else ":- ") } + val str = if (verbose) { + if (addSuffix) verboseStringWithSuffix else verboseString + } else { + simpleString + } builder.append(prefix) - builder.append(if (verbose) verboseString else simpleString) + builder.append(str) builder.append("\n") if (innerChildren.nonEmpty) { innerChildren.init.foreach(_.generateTreeString( - depth + 2, lastChildren :+ children.isEmpty :+ false, builder, verbose)) + depth + 2, lastChildren :+ children.isEmpty :+ false, builder, verbose, + addSuffix = addSuffix)) innerChildren.last.generateTreeString( - depth + 2, lastChildren :+ children.isEmpty :+ true, builder, verbose) + depth + 2, lastChildren :+ children.isEmpty :+ true, builder, verbose, + addSuffix = addSuffix) } if (children.nonEmpty) { children.init.foreach(_.generateTreeString( - depth + 1, lastChildren :+ false, builder, verbose, prefix)) + depth + 1, 
lastChildren :+ false, builder, verbose, prefix, addSuffix)) children.last.generateTreeString( - depth + 1, lastChildren :+ true, builder, verbose, prefix) + depth + 1, lastChildren :+ true, builder, verbose, prefix, addSuffix) } builder diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/HiveStringType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/HiveStringType.scala new file mode 100644 index 000000000000..b319eb70bc13 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/HiveStringType.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.types + +import scala.math.Ordering +import scala.reflect.runtime.universe.typeTag + +import org.apache.spark.sql.catalyst.ScalaReflectionLock +import org.apache.spark.unsafe.types.UTF8String + +/** + * A hive string type for compatibility. These datatypes should only used for parsing, + * and should NOT be used anywhere else. Any instance of these data types should be + * replaced by a [[StringType]] before analysis. + */ +sealed abstract class HiveStringType extends AtomicType { + private[sql] type InternalType = UTF8String + + private[sql] val ordering = implicitly[Ordering[InternalType]] + + @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { + typeTag[InternalType] + } + + override def defaultSize: Int = length + + private[spark] override def asNullable: HiveStringType = this + + def length: Int +} + +object HiveStringType { + def replaceCharType(dt: DataType): DataType = dt match { + case ArrayType(et, nullable) => + ArrayType(replaceCharType(et), nullable) + case MapType(kt, vt, nullable) => + MapType(replaceCharType(kt), replaceCharType(vt), nullable) + case StructType(fields) => + StructType(fields.map { field => + field.copy(dataType = replaceCharType(field.dataType)) + }) + case _: HiveStringType => StringType + case _ => dt + } +} + +/** + * Hive char type. + */ +case class CharType(length: Int) extends HiveStringType { + override def simpleString: String = s"char($length)" +} + +/** + * Hive varchar type. 
+ */ +case class VarcharType(length: Int) extends HiveStringType { + override def simpleString: String = s"varchar($length)" +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 81a97dc1ff3f..01737e0a1734 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -21,11 +21,12 @@ import java.util.TimeZone import org.scalatest.ShouldMatchers -import org.apache.spark.sql.catalyst.{SimpleCatalystConf, TableIdentifier} +import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.{Cross, Inner} +import org.apache.spark.sql.catalyst.parser.CatalystSqlParser +import org.apache.spark.sql.catalyst.plans.Cross import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.types._ @@ -61,23 +62,23 @@ class AnalysisSuite extends AnalysisTest with ShouldMatchers { checkAnalysis( Project(Seq(UnresolvedAttribute("TbL.a")), - UnresolvedRelation(TableIdentifier("TaBlE"), Some("TbL"))), + SubqueryAlias("TbL", UnresolvedRelation(TableIdentifier("TaBlE")), None)), Project(testRelation.output, testRelation)) assertAnalysisError( - Project(Seq(UnresolvedAttribute("tBl.a")), UnresolvedRelation( - TableIdentifier("TaBlE"), Some("TbL"))), + Project(Seq(UnresolvedAttribute("tBl.a")), + SubqueryAlias("TbL", UnresolvedRelation(TableIdentifier("TaBlE")), None)), Seq("cannot resolve")) checkAnalysis( - Project(Seq(UnresolvedAttribute("TbL.a")), UnresolvedRelation( - TableIdentifier("TaBlE"), Some("TbL"))), + Project(Seq(UnresolvedAttribute("TbL.a")), + SubqueryAlias("TbL", UnresolvedRelation(TableIdentifier("TaBlE")), None)), Project(testRelation.output, testRelation), caseSensitive = false) checkAnalysis( - Project(Seq(UnresolvedAttribute("tBl.a")), UnresolvedRelation( - TableIdentifier("TaBlE"), Some("TbL"))), + Project(Seq(UnresolvedAttribute("tBl.a")), + SubqueryAlias("TbL", UnresolvedRelation(TableIdentifier("TaBlE")), None)), Project(testRelation.output, testRelation), caseSensitive = false) } @@ -166,12 +167,12 @@ class AnalysisSuite extends AnalysisTest with ShouldMatchers { } test("resolve relations") { - assertAnalysisError(UnresolvedRelation(TableIdentifier("tAbLe"), None), Seq()) - checkAnalysis(UnresolvedRelation(TableIdentifier("TaBlE"), None), testRelation) + assertAnalysisError(UnresolvedRelation(TableIdentifier("tAbLe")), Seq()) + checkAnalysis(UnresolvedRelation(TableIdentifier("TaBlE")), testRelation) checkAnalysis( - UnresolvedRelation(TableIdentifier("tAbLe"), None), testRelation, caseSensitive = false) + UnresolvedRelation(TableIdentifier("tAbLe")), testRelation, caseSensitive = false) checkAnalysis( - UnresolvedRelation(TableIdentifier("TaBlE"), None), testRelation, caseSensitive = false) + UnresolvedRelation(TableIdentifier("TaBlE")), testRelation, caseSensitive = false) } test("divide should be casted into fractional types") { @@ -192,12 +193,13 @@ class AnalysisSuite extends AnalysisTest with ShouldMatchers { } test("pull out nondeterministic expressions from RepartitionByExpression") { - val plan = RepartitionByExpression(Seq(Rand(33)), testRelation) + val plan = RepartitionByExpression(Seq(Rand(33)), 
testRelation, numPartitions = 10) val projected = Alias(Rand(33), "_nondeterministic")() val expected = Project(testRelation.output, RepartitionByExpression(Seq(projected.toAttribute), - Project(testRelation.output :+ projected, testRelation))) + Project(testRelation.output :+ projected, testRelation), + numPartitions = 10)) checkAnalysis(plan, expected) } @@ -429,4 +431,14 @@ class AnalysisSuite extends AnalysisTest with ShouldMatchers { assertAnalysisSuccess(r1) assertAnalysisSuccess(r2) } + + test("resolve as with an already existed alias") { + checkAnalysis( + Project(Seq(UnresolvedAttribute("tbl2.a")), + SubqueryAlias("tbl", testRelation, None).as("tbl2")), + Project(testRelation.output, testRelation), + caseSensitive = false) + + checkAnalysis(SubqueryAlias("tbl", testRelation, None).as("tbl2"), testRelation) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTablesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTablesSuite.scala index 920c6ea50f4b..f45a82686984 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTablesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTablesSuite.scala @@ -20,68 +20,67 @@ package org.apache.spark.sql.catalyst.analysis import org.scalatest.BeforeAndAfter import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.expressions.{Literal, Rand} +import org.apache.spark.sql.catalyst.expressions.{Cast, Literal, Rand} import org.apache.spark.sql.catalyst.expressions.aggregate.Count -import org.apache.spark.sql.catalyst.plans.PlanTest -import org.apache.spark.sql.types.{LongType, NullType} +import org.apache.spark.sql.types.{LongType, NullType, TimestampType} /** * Unit tests for [[ResolveInlineTables]]. Note that there are also test cases defined in * end-to-end tests (in sql/core module) for verifying the correct error messages are shown * in negative cases. 
*/ -class ResolveInlineTablesSuite extends PlanTest with BeforeAndAfter { +class ResolveInlineTablesSuite extends AnalysisTest with BeforeAndAfter { private def lit(v: Any): Literal = Literal(v) test("validate inputs are foldable") { - ResolveInlineTables.validateInputEvaluable( + ResolveInlineTables(conf).validateInputEvaluable( UnresolvedInlineTable(Seq("c1", "c2"), Seq(Seq(lit(1))))) // nondeterministic (rand) should not work intercept[AnalysisException] { - ResolveInlineTables.validateInputEvaluable( + ResolveInlineTables(conf).validateInputEvaluable( UnresolvedInlineTable(Seq("c1"), Seq(Seq(Rand(1))))) } // aggregate should not work intercept[AnalysisException] { - ResolveInlineTables.validateInputEvaluable( + ResolveInlineTables(conf).validateInputEvaluable( UnresolvedInlineTable(Seq("c1"), Seq(Seq(Count(lit(1)))))) } // unresolved attribute should not work intercept[AnalysisException] { - ResolveInlineTables.validateInputEvaluable( + ResolveInlineTables(conf).validateInputEvaluable( UnresolvedInlineTable(Seq("c1"), Seq(Seq(UnresolvedAttribute("A"))))) } } test("validate input dimensions") { - ResolveInlineTables.validateInputDimension( + ResolveInlineTables(conf).validateInputDimension( UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(lit(2))))) // num alias != data dimension intercept[AnalysisException] { - ResolveInlineTables.validateInputDimension( + ResolveInlineTables(conf).validateInputDimension( UnresolvedInlineTable(Seq("c1", "c2"), Seq(Seq(lit(1)), Seq(lit(2))))) } // num alias == data dimension, but data themselves are inconsistent intercept[AnalysisException] { - ResolveInlineTables.validateInputDimension( + ResolveInlineTables(conf).validateInputDimension( UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(lit(21), lit(22))))) } } test("do not fire the rule if not all expressions are resolved") { val table = UnresolvedInlineTable(Seq("c1", "c2"), Seq(Seq(UnresolvedAttribute("A")))) - assert(ResolveInlineTables(table) == table) + assert(ResolveInlineTables(conf)(table) == table) } test("convert") { val table = UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(lit(2L)))) - val converted = ResolveInlineTables.convert(table) + val converted = ResolveInlineTables(conf).convert(table) assert(converted.output.map(_.dataType) == Seq(LongType)) assert(converted.data.size == 2) @@ -89,13 +88,24 @@ class ResolveInlineTablesSuite extends PlanTest with BeforeAndAfter { assert(converted.data(1).getLong(0) == 2L) } + test("convert TimeZoneAwareExpression") { + val table = UnresolvedInlineTable(Seq("c1"), + Seq(Seq(Cast(lit("1991-12-06 00:00:00.0"), TimestampType)))) + val converted = ResolveInlineTables(conf).convert(table) + val correct = Cast(lit("1991-12-06 00:00:00.0"), TimestampType) + .withTimeZone(conf.sessionLocalTimeZone).eval().asInstanceOf[Long] + assert(converted.output.map(_.dataType) == Seq(TimestampType)) + assert(converted.data.size == 1) + assert(converted.data(0).getLong(0) == correct) + } + test("nullability inference in convert") { val table1 = UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(lit(2L)))) - val converted1 = ResolveInlineTables.convert(table1) + val converted1 = ResolveInlineTables(conf).convert(table1) assert(!converted1.schema.fields(0).nullable) val table2 = UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(Literal(null, NullType)))) - val converted2 = ResolveInlineTables.convert(table2) + val converted2 = ResolveInlineTables(conf).convert(table2) assert(converted2.schema.fields(0).nullable) } } diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala index 3b756e89d903..82be69a0f7d7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala @@ -28,7 +28,8 @@ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.{MapGroupsWithState, _} import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ import org.apache.spark.sql.streaming.OutputMode -import org.apache.spark.sql.types.{IntegerType, LongType} +import org.apache.spark.sql.types.{IntegerType, LongType, MetadataBuilder} +import org.apache.spark.unsafe.types.CalendarInterval /** A dummy command for testing unsupported operations. */ case class DummyCommand() extends Command @@ -36,6 +37,11 @@ case class DummyCommand() extends Command class UnsupportedOperationsSuite extends SparkFunSuite { val attribute = AttributeReference("a", IntegerType, nullable = true)() + val watermarkMetadata = new MetadataBuilder() + .withMetadata(attribute.metadata) + .putLong(EventTimeWatermark.delayKey, 1000L) + .build() + val attributeWithWatermark = attribute.withMetadata(watermarkMetadata) val batchRelation = LocalRelation(attribute) val streamRelation = new TestStreamingRelation(attribute) @@ -98,6 +104,27 @@ class UnsupportedOperationsSuite extends SparkFunSuite { outputMode = Update, expectedMsgs = Seq("multiple streaming aggregations")) + assertSupportedInStreamingPlan( + "aggregate - streaming aggregations in update mode", + Aggregate(Nil, aggExprs("d"), streamRelation), + outputMode = Update) + + assertSupportedInStreamingPlan( + "aggregate - streaming aggregations in complete mode", + Aggregate(Nil, aggExprs("d"), streamRelation), + outputMode = Complete) + + assertSupportedInStreamingPlan( + "aggregate - streaming aggregations with watermark in append mode", + Aggregate(Seq(attributeWithWatermark), aggExprs("d"), streamRelation), + outputMode = Append) + + assertNotSupportedInStreamingPlan( + "aggregate - streaming aggregations without watermark in append mode", + Aggregate(Nil, aggExprs("d"), streamRelation), + outputMode = Append, + expectedMsgs = Seq("streaming aggregations", "without watermark")) + // Aggregation: Distinct aggregates not supported on streaming relation val distinctAggExprs = Seq(Count("*").toAggregateExpression(isDistinct = true).as("c")) assertSupportedInStreamingPlan( @@ -129,6 +156,33 @@ class UnsupportedOperationsSuite extends SparkFunSuite { outputMode = Complete, expectedMsgs = Seq("(map/flatMap)GroupsWithState")) + assertSupportedInStreamingPlan( + "mapGroupsWithState - mapGroupsWithState on batch relation inside streaming relation", + MapGroupsWithState(null, att, att, Seq(att), Seq(att), att, att, Seq(att), batchRelation), + outputMode = Append + ) + + // Deduplicate + assertSupportedInStreamingPlan( + "Deduplicate - Deduplicate on streaming relation before aggregation", + Aggregate( + Seq(attributeWithWatermark), + aggExprs("c"), + Deduplicate(Seq(att), streamRelation, streaming = true)), + outputMode = Append) + + assertNotSupportedInStreamingPlan( + "Deduplicate - Deduplicate on streaming relation after aggregation", + Deduplicate(Seq(att), Aggregate(Nil, aggExprs("c"), streamRelation), streaming = true), + outputMode = Complete, + expectedMsgs = 
Seq("dropDuplicates")) + + assertSupportedInStreamingPlan( + "Deduplicate - Deduplicate on batch relation inside a streaming query", + Deduplicate(Seq(att), batchRelation, streaming = false), + outputMode = Append + ) + // Inner joins: Stream-stream not supported testBinaryOperationInStreamingPlan( "inner join", diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala index db73f03c8bb7..a755231962be 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala @@ -433,37 +433,15 @@ class SessionCatalogSuite extends PlanTest { sessionCatalog.createTempView("tbl1", tempTable1, overrideIfExists = false) sessionCatalog.setCurrentDatabase("db2") // If we explicitly specify the database, we'll look up the relation in that database - assert(sessionCatalog.lookupRelation(TableIdentifier("tbl1", Some("db2"))) - == SubqueryAlias("tbl1", SimpleCatalogRelation(metastoreTable1), None)) + assert(sessionCatalog.lookupRelation(TableIdentifier("tbl1", Some("db2"))).children.head + .asInstanceOf[CatalogRelation].tableMeta == metastoreTable1) // Otherwise, we'll first look up a temporary table with the same name assert(sessionCatalog.lookupRelation(TableIdentifier("tbl1")) == SubqueryAlias("tbl1", tempTable1, None)) // Then, if that does not exist, look up the relation in the current database sessionCatalog.dropTable(TableIdentifier("tbl1"), ignoreIfNotExists = false, purge = false) - assert(sessionCatalog.lookupRelation(TableIdentifier("tbl1")) - == SubqueryAlias("tbl1", SimpleCatalogRelation(metastoreTable1), None)) - } - - test("lookup table relation with alias") { - val catalog = new SessionCatalog(newBasicCatalog()) - val alias = "monster" - val tableMetadata = catalog.getTableMetadata(TableIdentifier("tbl1", Some("db2"))) - val relation = SubqueryAlias("tbl1", SimpleCatalogRelation(tableMetadata), None) - val relationWithAlias = - SubqueryAlias(alias, - SimpleCatalogRelation(tableMetadata), None) - assert(catalog.lookupRelation( - TableIdentifier("tbl1", Some("db2")), alias = None) == relation) - assert(catalog.lookupRelation( - TableIdentifier("tbl1", Some("db2")), alias = Some(alias)) == relationWithAlias) - } - - test("lookup view with view name in alias") { - val catalog = new SessionCatalog(newBasicCatalog()) - val tmpView = Range(1, 10, 2, 10) - catalog.createTempView("vw1", tmpView, overrideIfExists = false) - val plan = catalog.lookupRelation(TableIdentifier("vw1"), Option("range")) - assert(plan == SubqueryAlias("range", tmpView, None)) + assert(sessionCatalog.lookupRelation(TableIdentifier("tbl1")).children.head + .asInstanceOf[CatalogRelation].tableMeta == metastoreTable1) } test("look up view relation") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala index 032629265269..0cb3a79eee67 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala @@ -19,16 +19,20 @@ package org.apache.spark.sql.catalyst.expressions import java.nio.charset.StandardCharsets +import scala.collection.mutable.ArrayBuffer + 
import org.apache.commons.codec.digest.DigestUtils import org.apache.spark.SparkFunSuite import org.apache.spark.sql.{RandomDataGenerator, Row} import org.apache.spark.sql.catalyst.encoders.{ExamplePointUDT, RowEncoder} import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection -import org.apache.spark.sql.types._ +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData} +import org.apache.spark.sql.types.{ArrayType, StructType, _} import org.apache.spark.unsafe.types.UTF8String class HashExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { + val random = new scala.util.Random test("md5") { checkEvaluation(Md5(Literal("ABC".getBytes(StandardCharsets.UTF_8))), @@ -71,6 +75,247 @@ class HashExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkConsistencyBetweenInterpretedAndCodegen(Crc32, BinaryType) } + + def checkHiveHash(input: Any, dataType: DataType, expected: Long): Unit = { + // Note : All expected hashes need to be computed using Hive 1.2.1 + val actual = HiveHashFunction.hash(input, dataType, seed = 0) + + withClue(s"hash mismatch for input = `$input` of type `$dataType`.") { + assert(actual == expected) + } + } + + def checkHiveHashForIntegralType(dataType: DataType): Unit = { + // corner cases + checkHiveHash(null, dataType, 0) + checkHiveHash(1, dataType, 1) + checkHiveHash(0, dataType, 0) + checkHiveHash(-1, dataType, -1) + checkHiveHash(Int.MaxValue, dataType, Int.MaxValue) + checkHiveHash(Int.MinValue, dataType, Int.MinValue) + + // random values + for (_ <- 0 until 10) { + val input = random.nextInt() + checkHiveHash(input, dataType, input) + } + } + + test("hive-hash for null") { + checkHiveHash(null, NullType, 0) + } + + test("hive-hash for boolean") { + checkHiveHash(true, BooleanType, 1) + checkHiveHash(false, BooleanType, 0) + } + + test("hive-hash for byte") { + checkHiveHashForIntegralType(ByteType) + } + + test("hive-hash for short") { + checkHiveHashForIntegralType(ShortType) + } + + test("hive-hash for int") { + checkHiveHashForIntegralType(IntegerType) + } + + test("hive-hash for long") { + checkHiveHash(1L, LongType, 1L) + checkHiveHash(0L, LongType, 0L) + checkHiveHash(-1L, LongType, 0L) + checkHiveHash(Long.MaxValue, LongType, -2147483648) + // Hive's fails to parse this.. 
but the hashing function itself can handle this input + checkHiveHash(Long.MinValue, LongType, -2147483648) + + for (_ <- 0 until 10) { + val input = random.nextLong() + checkHiveHash(input, LongType, ((input >>> 32) ^ input).toInt) + } + } + + test("hive-hash for float") { + checkHiveHash(0F, FloatType, 0) + checkHiveHash(0.0F, FloatType, 0) + checkHiveHash(1.1F, FloatType, 1066192077L) + checkHiveHash(-1.1F, FloatType, -1081291571) + checkHiveHash(99999999.99999999999F, FloatType, 1287568416L) + checkHiveHash(Float.MaxValue, FloatType, 2139095039) + checkHiveHash(Float.MinValue, FloatType, -8388609) + } + + test("hive-hash for double") { + checkHiveHash(0, DoubleType, 0) + checkHiveHash(0.0, DoubleType, 0) + checkHiveHash(1.1, DoubleType, -1503133693) + checkHiveHash(-1.1, DoubleType, 644349955) + checkHiveHash(1000000000.000001, DoubleType, 1104006509) + checkHiveHash(1000000000.0000000000000000000000001, DoubleType, 1104006501) + checkHiveHash(9999999999999999999.9999999999999999999, DoubleType, 594568676) + checkHiveHash(Double.MaxValue, DoubleType, -2146435072) + checkHiveHash(Double.MinValue, DoubleType, 1048576) + } + + test("hive-hash for string") { + checkHiveHash(UTF8String.fromString("apache spark"), StringType, 1142704523L) + checkHiveHash(UTF8String.fromString("!@#$%^&*()_+=-"), StringType, -613724358L) + checkHiveHash(UTF8String.fromString("abcdefghijklmnopqrstuvwxyz"), StringType, 958031277L) + checkHiveHash(UTF8String.fromString("AbCdEfGhIjKlMnOpQrStUvWxYz012"), StringType, -648013852L) + // scalastyle:off nonascii + checkHiveHash(UTF8String.fromString("数据砖头"), StringType, -898686242L) + checkHiveHash(UTF8String.fromString("नमस्ते"), StringType, 2006045948L) + // scalastyle:on nonascii + } + + test("hive-hash for array") { + // empty array + checkHiveHash( + input = new GenericArrayData(Array[Int]()), + dataType = ArrayType(IntegerType, containsNull = false), + expected = 0) + + // basic case + checkHiveHash( + input = new GenericArrayData(Array(1, 10000, Int.MaxValue)), + dataType = ArrayType(IntegerType, containsNull = false), + expected = -2147172688L) + + // with negative values + checkHiveHash( + input = new GenericArrayData(Array(-1L, 0L, 999L, Int.MinValue.toLong)), + dataType = ArrayType(LongType, containsNull = false), + expected = -2147452680L) + + // with nulls only + val arrayTypeWithNull = ArrayType(IntegerType, containsNull = true) + checkHiveHash( + input = new GenericArrayData(Array(null, null)), + dataType = arrayTypeWithNull, + expected = 0) + + // mix with null + checkHiveHash( + input = new GenericArrayData(Array(-12221, 89, null, 767)), + dataType = arrayTypeWithNull, + expected = -363989515) + + // nested with array + checkHiveHash( + input = new GenericArrayData( + Array( + new GenericArrayData(Array(1234L, -9L, 67L)), + new GenericArrayData(Array(null, null)), + new GenericArrayData(Array(55L, -100L, -2147452680L)) + )), + dataType = ArrayType(ArrayType(LongType)), + expected = -1007531064) + + // nested with map + checkHiveHash( + input = new GenericArrayData( + Array( + new ArrayBasedMapData( + new GenericArrayData(Array(-99, 1234)), + new GenericArrayData(Array(UTF8String.fromString("sql"), null))), + new ArrayBasedMapData( + new GenericArrayData(Array(67)), + new GenericArrayData(Array(UTF8String.fromString("apache spark")))) + )), + dataType = ArrayType(MapType(IntegerType, StringType)), + expected = 1139205955) + } + + test("hive-hash for map") { + val mapType = MapType(IntegerType, StringType) + + // empty map + checkHiveHash( + input = new 
ArrayBasedMapData(new GenericArrayData(Array()), new GenericArrayData(Array())), + dataType = mapType, + expected = 0) + + // basic case + checkHiveHash( + input = new ArrayBasedMapData( + new GenericArrayData(Array(1, 2)), + new GenericArrayData(Array(UTF8String.fromString("foo"), UTF8String.fromString("bar")))), + dataType = mapType, + expected = 198872) + + // with null value + checkHiveHash( + input = new ArrayBasedMapData( + new GenericArrayData(Array(55, -99)), + new GenericArrayData(Array(UTF8String.fromString("apache spark"), null))), + dataType = mapType, + expected = 1142704473) + + // nesting (only values can be nested as keys have to be primitive datatype) + val nestedMapType = MapType(IntegerType, MapType(IntegerType, StringType)) + checkHiveHash( + input = new ArrayBasedMapData( + new GenericArrayData(Array(1, -100)), + new GenericArrayData( + Array( + new ArrayBasedMapData( + new GenericArrayData(Array(-99, 1234)), + new GenericArrayData(Array(UTF8String.fromString("sql"), null))), + new ArrayBasedMapData( + new GenericArrayData(Array(67)), + new GenericArrayData(Array(UTF8String.fromString("apache spark")))) + ))), + dataType = nestedMapType, + expected = -1142817416) + } + + test("hive-hash for struct") { + // basic + val row = new GenericInternalRow(Array[Any](1, 2, 3)) + checkHiveHash( + input = row, + dataType = + new StructType() + .add("col1", IntegerType) + .add("col2", IntegerType) + .add("col3", IntegerType), + expected = 1026) + + // mix of several datatypes + val structType = new StructType() + .add("null", NullType) + .add("boolean", BooleanType) + .add("byte", ByteType) + .add("short", ShortType) + .add("int", IntegerType) + .add("long", LongType) + .add("arrayOfString", arrayOfString) + .add("mapOfString", mapOfString) + + val rowValues = new ArrayBuffer[Any]() + rowValues += null + rowValues += true + rowValues += 1 + rowValues += 2 + rowValues += Int.MaxValue + rowValues += Long.MinValue + rowValues += new GenericArrayData(Array( + UTF8String.fromString("apache spark"), + UTF8String.fromString("hello world") + )) + rowValues += new ArrayBasedMapData( + new GenericArrayData(Array(UTF8String.fromString("project"), UTF8String.fromString("meta"))), + new GenericArrayData(Array(UTF8String.fromString("apache spark"), null)) + ) + + val row2 = new GenericInternalRow(rowValues.toArray) + checkHiveHash( + input = row2, + dataType = structType, + expected = -2119012447) + } + private val structOfString = new StructType().add("str", StringType) private val structOfUDT = new StructType().add("udt", new ExamplePointUDT, false) private val arrayOfString = ArrayType(StringType) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala index 1533fe5f90ee..2420ba513f28 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala @@ -21,7 +21,6 @@ import org.apache.spark.SparkException import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult._ -import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.ArrayData import org.apache.spark.sql.types._ @@ -39,12 +38,12 @@ class 
PercentileSuite extends SparkFunSuite { val agg = new Percentile(BoundReference(0, IntegerType, true), Literal(0.5)) // Check empty serialize and deserialize - val buffer = new OpenHashMap[Number, Long]() + val buffer = new OpenHashMap[AnyRef, Long]() assert(compareEquals(agg.deserialize(agg.serialize(buffer)), buffer)) // Check non-empty buffer serializa and deserialize. data.foreach { key => - buffer.changeValue(key, 1L, _ + 1L) + buffer.changeValue(new Integer(key), 1L, _ + 1L) } assert(compareEquals(agg.deserialize(agg.serialize(buffer)), buffer)) } @@ -58,25 +57,25 @@ class PercentileSuite extends SparkFunSuite { val agg = new Percentile(childExpression, percentageExpression) // Test with rows without frequency - val rows = (1 to count).map( x => Seq(x)) - runTest( agg, rows, expectedPercentiles) + val rows = (1 to count).map(x => Seq(x)) + runTest(agg, rows, expectedPercentiles) // Test with row with frequency. Second and third columns are frequency in Int and Long val countForFrequencyTest = 1000 - val rowsWithFrequency = (1 to countForFrequencyTest).map( x => Seq(x, x):+ x.toLong) + val rowsWithFrequency = (1 to countForFrequencyTest).map(x => Seq(x, x):+ x.toLong) val expectedPercentilesWithFrquency = Seq(1.0, 500.0, 707.0, 866.0, 1000.0) val frequencyExpressionInt = BoundReference(1, IntegerType, nullable = false) val aggInt = new Percentile(childExpression, percentageExpression, frequencyExpressionInt) - runTest( aggInt, rowsWithFrequency, expectedPercentilesWithFrquency) + runTest(aggInt, rowsWithFrequency, expectedPercentilesWithFrquency) val frequencyExpressionLong = BoundReference(2, LongType, nullable = false) val aggLong = new Percentile(childExpression, percentageExpression, frequencyExpressionLong) - runTest( aggLong, rowsWithFrequency, expectedPercentilesWithFrquency) + runTest(aggLong, rowsWithFrequency, expectedPercentilesWithFrquency) // Run test with Flatten data - val flattenRows = (1 to countForFrequencyTest).flatMap( current => - (1 to current).map( y => current )).map( Seq(_)) + val flattenRows = (1 to countForFrequencyTest).flatMap(current => + (1 to current).map(y => current )).map(Seq(_)) runTest(agg, flattenRows, expectedPercentilesWithFrquency) } @@ -153,7 +152,7 @@ class PercentileSuite extends SparkFunSuite { } val validFrequencyTypes = Seq(ByteType, ShortType, IntegerType, LongType) - for ( dataType <- validDataTypes; + for (dataType <- validDataTypes; frequencyType <- validFrequencyTypes) { val child = AttributeReference("a", dataType)() val frq = AttributeReference("frq", frequencyType)() @@ -176,7 +175,7 @@ class PercentileSuite extends SparkFunSuite { StringType, DateType, TimestampType, CalendarIntervalType, NullType) - for( dataType <- invalidDataTypes; + for(dataType <- invalidDataTypes; frequencyType <- validFrequencyTypes) { val child = AttributeReference("a", dataType)() val frq = AttributeReference("frq", frequencyType)() @@ -186,7 +185,7 @@ class PercentileSuite extends SparkFunSuite { s"'`a`' is of ${dataType.simpleString} type.")) } - for( dataType <- validDataTypes; + for(dataType <- validDataTypes; frequencyType <- invalidFrequencyDataTypes) { val child = AttributeReference("a", dataType)() val frq = AttributeReference("frq", frequencyType)() @@ -294,11 +293,11 @@ class PercentileSuite extends SparkFunSuite { agg.update(buffer, InternalRow(1, -5)) agg.eval(buffer) } - assert( caught.getMessage.startsWith("Negative values found in ")) + assert(caught.getMessage.startsWith("Negative values found in ")) } private def compareEquals( - left: 
OpenHashMap[Number, Long], right: OpenHashMap[Number, Long]): Boolean = { + left: OpenHashMap[AnyRef, Long], right: OpenHashMap[AnyRef, Long]): Boolean = { left.size == right.size && left.forall { case (key, count) => right.apply(key) == count } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FoldablePropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FoldablePropagationSuite.scala index 82756f545a8c..d128315b6886 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FoldablePropagationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FoldablePropagationSuite.scala @@ -130,6 +130,20 @@ class FoldablePropagationSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + test("Propagate in inner join") { + val ta = testRelation.select('a, Literal(1).as('tag)) + .union(testRelation.select('a, Literal(2).as('tag))) + .subquery('ta) + val tb = testRelation.select('a, Literal(1).as('tag)) + .union(testRelation.select('a, Literal(2).as('tag))) + .subquery('tb) + val query = ta.join(tb, Inner, + Some("ta.a".attr === "tb.a".attr && "ta.tag".attr === "tb.tag".attr)) + val optimized = Optimize.execute(query.analyze) + val correctAnswer = query.analyze + comparePlans(optimized, correctAnswer) + } + test("Propagate in expand") { val c1 = Literal(1).as('a) val c2 = Literal(2).as('b) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala index f23e262f286b..e68423f85c92 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions.Alias +import org.apache.spark.sql.catalyst.expressions.aggregate.First import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftSemi, PlanTest} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.RuleExecutor @@ -30,7 +32,8 @@ class ReplaceOperatorSuite extends PlanTest { Batch("Replace Operators", FixedPoint(100), ReplaceDistinctWithAggregate, ReplaceExceptWithAntiJoin, - ReplaceIntersectWithSemiJoin) :: Nil + ReplaceIntersectWithSemiJoin, + ReplaceDeduplicateWithAggregate) :: Nil } test("replace Intersect with Left-semi Join") { @@ -71,4 +74,32 @@ class ReplaceOperatorSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + + test("replace batch Deduplicate with Aggregate") { + val input = LocalRelation('a.int, 'b.int) + val attrA = input.output(0) + val attrB = input.output(1) + val query = Deduplicate(Seq(attrA), input, streaming = false) // dropDuplicates("a") + val optimized = Optimize.execute(query.analyze) + + val correctAnswer = + Aggregate( + Seq(attrA), + Seq( + attrA, + Alias(new First(attrB).toAggregateExpression(), attrB.name)(attrB.exprId) + ), + input) + + comparePlans(optimized, correctAnswer) + } + + test("don't replace streaming Deduplicate") { + val input = LocalRelation('a.int, 'b.int) + val attrA = input.output(0) + val query = Deduplicate(Seq(attrA), input, streaming = true) // dropDuplicates("a") + val optimized = Optimize.execute(query.analyze) + + 
comparePlans(optimized, query) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index 2c1425242620..67d5d2202b68 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -152,10 +152,7 @@ class PlanParserSuite extends PlanTest { val orderSortDistrClusterClauses = Seq( ("", basePlan), (" order by a, b desc", basePlan.orderBy('a.asc, 'b.desc)), - (" sort by a, b desc", basePlan.sortBy('a.asc, 'b.desc)), - (" distribute by a, b", basePlan.distribute('a, 'b)()), - (" distribute by a sort by b", basePlan.distribute('a)().sortBy('b.asc)), - (" cluster by a, b", basePlan.distribute('a, 'b)().sortBy('a.asc, 'b.asc)) + (" sort by a, b desc", basePlan.sortBy('a.asc, 'b.desc)) ) orderSortDistrClusterClauses.foreach { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala index 7d46011b410e..170c469197e7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala @@ -25,8 +25,8 @@ class TableIdentifierParserSuite extends SparkFunSuite { // Add "$elem$", "$value$" & "$key$" val hiveNonReservedKeyword = Array("add", "admin", "after", "analyze", "archive", "asc", "before", "bucket", "buckets", "cascade", "change", "cluster", "clustered", "clusterstatus", "collection", - "columns", "comment", "compact", "compactions", "compute", "concatenate", "continue", "data", - "day", "databases", "datetime", "dbproperties", "deferred", "defined", "delimited", + "columns", "comment", "compact", "compactions", "compute", "concatenate", "continue", "cost", + "data", "day", "databases", "datetime", "dbproperties", "deferred", "defined", "delimited", "dependency", "desc", "directories", "directory", "disable", "distribute", "enable", "escaped", "exclusive", "explain", "export", "fields", "file", "fileformat", "first", "format", "formatted", "functions", "hold_ddltime", "hour", "idxproperties", "ignore", "index", diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala new file mode 100644 index 000000000000..8be74ced7bb7 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala @@ -0,0 +1,386 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.statsEstimation + +import java.sql.Date + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils._ +import org.apache.spark.sql.types._ + +/** + * In this test suite, we test predicates containing the following operators: + * =, <, <=, >, >=, AND, OR, IS NULL, IS NOT NULL, IN, NOT IN + */ +class FilterEstimationSuite extends StatsEstimationTestBase { + + // Suppose our test table has 10 rows and 6 columns. + // First column cint has values: 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 + // Hence, distinctCount:10, min:1, max:10, nullCount:0, avgLen:4, maxLen:4 + val arInt = AttributeReference("cint", IntegerType)() + val childColStatInt = ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4) + + // only 2 values + val arBool = AttributeReference("cbool", BooleanType)() + val childColStatBool = ColumnStat(distinctCount = 2, min = Some(false), max = Some(true), + nullCount = 0, avgLen = 1, maxLen = 1) + + // Second column cdate has 10 values from 2017-01-01 through 2017-01-10. + val dMin = Date.valueOf("2017-01-01") + val dMax = Date.valueOf("2017-01-10") + val arDate = AttributeReference("cdate", DateType)() + val childColStatDate = ColumnStat(distinctCount = 10, min = Some(dMin), max = Some(dMax), + nullCount = 0, avgLen = 4, maxLen = 4) + + // Fourth column cdecimal has 4 values from 0.20 through 0.80 at increment of 0.20. + val decMin = new java.math.BigDecimal("0.200000000000000000") + val decMax = new java.math.BigDecimal("0.800000000000000000") + val arDecimal = AttributeReference("cdecimal", DecimalType(18, 18))() + val childColStatDecimal = ColumnStat(distinctCount = 4, min = Some(decMin), max = Some(decMax), + nullCount = 0, avgLen = 8, maxLen = 8) + + // Fifth column cdouble has 10 double values: 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0 + val arDouble = AttributeReference("cdouble", DoubleType)() + val childColStatDouble = ColumnStat(distinctCount = 10, min = Some(1.0), max = Some(10.0), + nullCount = 0, avgLen = 8, maxLen = 8) + + // Sixth column cstring has 10 String values: + // "A0", "A1", "A2", "A3", "A4", "A5", "A6", "A7", "A8", "A9" + val arString = AttributeReference("cstring", StringType)() + val childColStatString = ColumnStat(distinctCount = 10, min = None, max = None, + nullCount = 0, avgLen = 2, maxLen = 2) + + test("cint = 2") { + validateEstimatedStats( + arInt, + Filter(EqualTo(arInt, Literal(2)), childStatsTestPlan(Seq(arInt), 10L)), + ColumnStat(distinctCount = 1, min = Some(2), max = Some(2), + nullCount = 0, avgLen = 4, maxLen = 4), + 1) + } + + test("cint <=> 2") { + validateEstimatedStats( + arInt, + Filter(EqualNullSafe(arInt, Literal(2)), childStatsTestPlan(Seq(arInt), 10L)), + ColumnStat(distinctCount = 1, min = Some(2), max = Some(2), + nullCount = 0, avgLen = 4, maxLen = 4), + 1) + } + + test("cint = 0") { + // This is an out-of-range case since 0 is outside the range [min, max] + validateEstimatedStats( + arInt, + Filter(EqualTo(arInt, Literal(0)), childStatsTestPlan(Seq(arInt), 10L)), + ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4), + 0) + } + + test("cint < 3") { + validateEstimatedStats( + arInt, + Filter(LessThan(arInt, Literal(3)), childStatsTestPlan(Seq(arInt), 
10L)), + ColumnStat(distinctCount = 2, min = Some(1), max = Some(3), + nullCount = 0, avgLen = 4, maxLen = 4), + 3) + } + + test("cint < 0") { + // This is a corner case since literal 0 is smaller than min. + validateEstimatedStats( + arInt, + Filter(LessThan(arInt, Literal(0)), childStatsTestPlan(Seq(arInt), 10L)), + ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4), + 0) + } + + test("cint <= 3") { + validateEstimatedStats( + arInt, + Filter(LessThanOrEqual(arInt, Literal(3)), childStatsTestPlan(Seq(arInt), 10L)), + ColumnStat(distinctCount = 2, min = Some(1), max = Some(3), + nullCount = 0, avgLen = 4, maxLen = 4), + 3) + } + + test("cint > 6") { + validateEstimatedStats( + arInt, + Filter(GreaterThan(arInt, Literal(6)), childStatsTestPlan(Seq(arInt), 10L)), + ColumnStat(distinctCount = 4, min = Some(6), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4), + 5) + } + + test("cint > 10") { + // This is a corner case since max value is 10. + validateEstimatedStats( + arInt, + Filter(GreaterThan(arInt, Literal(10)), childStatsTestPlan(Seq(arInt), 10L)), + ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4), + 0) + } + + test("cint >= 6") { + validateEstimatedStats( + arInt, + Filter(GreaterThanOrEqual(arInt, Literal(6)), childStatsTestPlan(Seq(arInt), 10L)), + ColumnStat(distinctCount = 4, min = Some(6), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4), + 5) + } + + test("cint IS NULL") { + validateEstimatedStats( + arInt, + Filter(IsNull(arInt), childStatsTestPlan(Seq(arInt), 10L)), + ColumnStat(distinctCount = 0, min = None, max = None, + nullCount = 0, avgLen = 4, maxLen = 4), + 0) + } + + test("cint IS NOT NULL") { + validateEstimatedStats( + arInt, + Filter(IsNotNull(arInt), childStatsTestPlan(Seq(arInt), 10L)), + ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4), + 10) + } + + test("cint > 3 AND cint <= 6") { + val condition = And(GreaterThan(arInt, Literal(3)), LessThanOrEqual(arInt, Literal(6))) + validateEstimatedStats( + arInt, + Filter(condition, childStatsTestPlan(Seq(arInt), 10L)), + ColumnStat(distinctCount = 3, min = Some(3), max = Some(6), + nullCount = 0, avgLen = 4, maxLen = 4), + 4) + } + + test("cint = 3 OR cint = 6") { + val condition = Or(EqualTo(arInt, Literal(3)), EqualTo(arInt, Literal(6))) + validateEstimatedStats( + arInt, + Filter(condition, childStatsTestPlan(Seq(arInt), 10L)), + ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4), + 2) + } + + test("cint IN (3, 4, 5)") { + validateEstimatedStats( + arInt, + Filter(InSet(arInt, Set(3, 4, 5)), childStatsTestPlan(Seq(arInt), 10L)), + ColumnStat(distinctCount = 3, min = Some(3), max = Some(5), + nullCount = 0, avgLen = 4, maxLen = 4), + 3) + } + + test("cint NOT IN (3, 4, 5)") { + validateEstimatedStats( + arInt, + Filter(Not(InSet(arInt, Set(3, 4, 5))), childStatsTestPlan(Seq(arInt), 10L)), + ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4), + 7) + } + + test("cbool = true") { + validateEstimatedStats( + arBool, + Filter(EqualTo(arBool, Literal(true)), childStatsTestPlan(Seq(arBool), 10L)), + ColumnStat(distinctCount = 1, min = Some(true), max = Some(true), + nullCount = 0, avgLen = 1, maxLen = 1), + 5) + } + + test("cbool > false") { + // bool comparison is not supported yet, so stats remain same. 
+ validateEstimatedStats( + arBool, + Filter(GreaterThan(arBool, Literal(false)), childStatsTestPlan(Seq(arBool), 10L)), + ColumnStat(distinctCount = 2, min = Some(false), max = Some(true), + nullCount = 0, avgLen = 1, maxLen = 1), + 10) + } + + test("cdate = cast('2017-01-02' AS DATE)") { + val d20170102 = Date.valueOf("2017-01-02") + validateEstimatedStats( + arDate, + Filter(EqualTo(arDate, Literal(d20170102)), + childStatsTestPlan(Seq(arDate), 10L)), + ColumnStat(distinctCount = 1, min = Some(d20170102), max = Some(d20170102), + nullCount = 0, avgLen = 4, maxLen = 4), + 1) + } + + test("cdate < cast('2017-01-03' AS DATE)") { + val d20170103 = Date.valueOf("2017-01-03") + validateEstimatedStats( + arDate, + Filter(LessThan(arDate, Literal(d20170103)), + childStatsTestPlan(Seq(arDate), 10L)), + ColumnStat(distinctCount = 2, min = Some(dMin), max = Some(d20170103), + nullCount = 0, avgLen = 4, maxLen = 4), + 3) + } + + test("""cdate IN ( cast('2017-01-03' AS DATE), + cast('2017-01-04' AS DATE), cast('2017-01-05' AS DATE) )""") { + val d20170103 = Date.valueOf("2017-01-03") + val d20170104 = Date.valueOf("2017-01-04") + val d20170105 = Date.valueOf("2017-01-05") + validateEstimatedStats( + arDate, + Filter(In(arDate, Seq(Literal(d20170103), Literal(d20170104), Literal(d20170105))), + childStatsTestPlan(Seq(arDate), 10L)), + ColumnStat(distinctCount = 3, min = Some(d20170103), max = Some(d20170105), + nullCount = 0, avgLen = 4, maxLen = 4), + 3) + } + + test("cdecimal = 0.400000000000000000") { + val dec_0_40 = new java.math.BigDecimal("0.400000000000000000") + validateEstimatedStats( + arDecimal, + Filter(EqualTo(arDecimal, Literal(dec_0_40)), + childStatsTestPlan(Seq(arDecimal), 4L)), + ColumnStat(distinctCount = 1, min = Some(dec_0_40), max = Some(dec_0_40), + nullCount = 0, avgLen = 8, maxLen = 8), + 1) + } + + test("cdecimal < 0.60 ") { + val dec_0_60 = new java.math.BigDecimal("0.600000000000000000") + validateEstimatedStats( + arDecimal, + Filter(LessThan(arDecimal, Literal(dec_0_60)), + childStatsTestPlan(Seq(arDecimal), 4L)), + ColumnStat(distinctCount = 3, min = Some(decMin), max = Some(dec_0_60), + nullCount = 0, avgLen = 8, maxLen = 8), + 3) + } + + test("cdouble < 3.0") { + validateEstimatedStats( + arDouble, + Filter(LessThan(arDouble, Literal(3.0)), childStatsTestPlan(Seq(arDouble), 10L)), + ColumnStat(distinctCount = 2, min = Some(1.0), max = Some(3.0), + nullCount = 0, avgLen = 8, maxLen = 8), + 3) + } + + test("cstring = 'A2'") { + validateEstimatedStats( + arString, + Filter(EqualTo(arString, Literal("A2")), childStatsTestPlan(Seq(arString), 10L)), + ColumnStat(distinctCount = 1, min = None, max = None, + nullCount = 0, avgLen = 2, maxLen = 2), + 1) + } + + // There is no min/max statistics for String type. We estimate 10 rows returned. + test("cstring < 'A2'") { + validateEstimatedStats( + arString, + Filter(LessThan(arString, Literal("A2")), childStatsTestPlan(Seq(arString), 10L)), + ColumnStat(distinctCount = 10, min = None, max = None, + nullCount = 0, avgLen = 2, maxLen = 2), + 10) + } + + // This is a corner test case. We want to test if we can handle the case when the number of + // valid values in IN clause is greater than the number of distinct values for a given column. + // For example, column has only 2 distinct values 1 and 6. + // The predicate is: column IN (1, 2, 3, 4, 5). 
+ test("cint IN (1, 2, 3, 4, 5)") { + val cornerChildColStatInt = ColumnStat(distinctCount = 2, min = Some(1), max = Some(6), + nullCount = 0, avgLen = 4, maxLen = 4) + val cornerChildStatsTestplan = StatsTestPlan( + outputList = Seq(arInt), + rowCount = 2L, + attributeStats = AttributeMap(Seq(arInt -> cornerChildColStatInt)) + ) + validateEstimatedStats( + arInt, + Filter(InSet(arInt, Set(1, 2, 3, 4, 5)), cornerChildStatsTestplan), + ColumnStat(distinctCount = 2, min = Some(1), max = Some(5), + nullCount = 0, avgLen = 4, maxLen = 4), + 2) + } + + private def childStatsTestPlan(outList: Seq[Attribute], tableRowCount: BigInt): StatsTestPlan = { + StatsTestPlan( + outputList = outList, + rowCount = tableRowCount, + attributeStats = AttributeMap(Seq( + arInt -> childColStatInt, + arBool -> childColStatBool, + arDate -> childColStatDate, + arDecimal -> childColStatDecimal, + arDouble -> childColStatDouble, + arString -> childColStatString + )) + ) + } + + private def validateEstimatedStats( + ar: AttributeReference, + filterNode: Filter, + expectedColStats: ColumnStat, + rowCount: Int): Unit = { + + val expectedAttrStats = toAttributeMap(Seq(ar.name -> expectedColStats), filterNode) + val expectedSizeInBytes = getOutputSize(filterNode.output, rowCount, expectedAttrStats) + + val filteredStats = filterNode.stats(conf) + assert(filteredStats.sizeInBytes == expectedSizeInBytes) + assert(filteredStats.rowCount.get == rowCount) + assert(filteredStats.attributeStats(ar) == expectedColStats) + + // If the filter has a binary operator (including those nested inside + // AND/OR/NOT), swap the sides of the attribte and the literal, reverse the + // operator, and then check again. + val rewrittenFilter = filterNode transformExpressionsDown { + case EqualTo(ar: AttributeReference, l: Literal) => + EqualTo(l, ar) + + case LessThan(ar: AttributeReference, l: Literal) => + GreaterThan(l, ar) + case LessThanOrEqual(ar: AttributeReference, l: Literal) => + GreaterThanOrEqual(l, ar) + + case GreaterThan(ar: AttributeReference, l: Literal) => + LessThan(l, ar) + case GreaterThanOrEqual(ar: AttributeReference, l: Literal) => + LessThanOrEqual(l, ar) + } + + if (rewrittenFilter != filterNode) { + validateEstimatedStats(ar, rewrittenFilter, expectedColStats, rowCount) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index cb9493a57564..41470ae6aae1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.datasources.jdbc._ import org.apache.spark.sql.execution.datasources.json.JsonInferSchema -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{StringType, StructType} import org.apache.spark.unsafe.types.UTF8String /** @@ -263,8 +263,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { /** * Loads a JSON file and returns the results as a `DataFrame`. * - * Both JSON (one record per file) and JSON Lines - * (newline-delimited JSON) are supported and can be selected with the `wholeFile` option. + * JSON Lines (newline-delimited JSON) is supported by + * default. For JSON (one record per file), set the `wholeFile` option to true. 
   *
   * This function goes through the input once to determine the input schema. If you know the
   * schema in advance, use the version that specifies the schema to avoid the extra scan.
@@ -286,8 +286,11 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
   * during parsing.
   *   <ul>
   *     <li>`PERMISSIVE` : sets other fields to `null` when it meets a corrupted record, and puts
-   *     the malformed string into a new field configured by `columnNameOfCorruptRecord`. When
-   *     a schema is set by user, it sets `null` for extra fields.</li>
+   *     the malformed string into a field configured by `columnNameOfCorruptRecord`. To keep
+   *     corrupt records, a user can set a string type field named `columnNameOfCorruptRecord`
+   *     in a user-defined schema. If a schema does not have the field, it drops corrupt records
+   *     during parsing. When inferring a schema, it implicitly adds a `columnNameOfCorruptRecord`
+   *     field in the output schema.</li>
   *     <li>`DROPMALFORMED` : ignores the whole corrupted records.</li>
   *     <li>`FAILFAST` : throws an exception when it meets corrupted records.</li>
   *   </ul>
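To illustrate the `PERMISSIVE` behaviour documented above, corrupt records can be kept by declaring the corrupt-record column as a nullable string field in a user-defined schema. A minimal sketch, assuming a `SparkSession` named `spark`, a hypothetical `path/to/people.json`, and the default corrupt-record column name `_corrupt_record`:

import org.apache.spark.sql.types.{IntegerType, StringType, StructType}

// Schema with an explicit corrupt-record field (must be string type and nullable).
val schema = new StructType()
  .add("name", StringType)
  .add("age", IntegerType)
  .add("_corrupt_record", StringType)

val people = spark.read
  .schema(schema)
  .option("mode", "PERMISSIVE")
  .option("wholeFile", "true")  // one JSON document per file instead of JSON Lines
  .json("path/to/people.json")

Rows that fail to parse then come back with `null` data fields and the raw text in `_corrupt_record`.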
@@ -323,6 +326,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { * @param jsonRDD input RDD with one JSON object per record * @since 1.4.0 */ + @deprecated("Use json(Dataset[String]) instead.", "2.2.0") def json(jsonRDD: JavaRDD[String]): DataFrame = json(jsonRDD.rdd) /** @@ -335,7 +339,22 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { * @param jsonRDD input RDD with one JSON object per record * @since 1.4.0 */ + @deprecated("Use json(Dataset[String]) instead.", "2.2.0") def json(jsonRDD: RDD[String]): DataFrame = { + json(sparkSession.createDataset(jsonRDD)(Encoders.STRING)) + } + + /** + * Loads a `Dataset[String]` storing JSON objects (JSON Lines + * text format or newline-delimited JSON) and returns the result as a `DataFrame`. + * + * Unless the schema is specified using `schema` function, this function goes through the + * input once to determine the input schema. + * + * @param jsonDataset input Dataset with one JSON object per record + * @since 2.2.0 + */ + def json(jsonDataset: Dataset[String]): DataFrame = { val parsedOptions = new JSONOptions( extraOptions.toMap, sparkSession.sessionState.conf.sessionLocalTimeZone, @@ -344,12 +363,21 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { val schema = userSpecifiedSchema.getOrElse { JsonInferSchema.infer( - jsonRDD, + jsonDataset.rdd, parsedOptions, createParser) } - val parsed = jsonRDD.mapPartitions { iter => + // Check a field requirement for corrupt records here to throw an exception in a driver side + schema.getFieldIndex(parsedOptions.columnNameOfCorruptRecord).foreach { corruptFieldIndex => + val f = schema(corruptFieldIndex) + if (f.dataType != StringType || !f.nullable) { + throw new AnalysisException( + "The field for corrupt records must be string type and nullable") + } + } + + val parsed = jsonDataset.rdd.mapPartitions { iter => val parser = new JacksonParser(schema, parsedOptions) iter.flatMap(parser.parse(_, createParser, UTF8String.fromString)) } @@ -422,12 +450,20 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { *
   *   <li>`mode` (default `PERMISSIVE`): allows a mode for dealing with corrupt records
   *      during parsing.
   *     <ul>
-   *       <li>`PERMISSIVE` : sets other fields to `null` when it meets a corrupted record. When
-   *       a schema is set by user, it sets `null` for extra fields.</li>
+   *       <li>`PERMISSIVE` : sets other fields to `null` when it meets a corrupted record, and puts
+   *       the malformed string into a field configured by `columnNameOfCorruptRecord`. To keep
+   *       corrupt records, a user can set a string type field named `columnNameOfCorruptRecord`
+   *       in a user-defined schema. If a schema does not have the field, it drops corrupt records
+   *       during parsing. When the length of parsed CSV tokens is shorter than the expected length
+   *       of the schema, it sets `null` for the extra fields.</li>
   *       <li>`DROPMALFORMED` : ignores the whole corrupted records.</li>
   *       <li>`FAILFAST` : throws an exception when it meets corrupted records.</li>
   *     </ul>
   *   </li>
+   *   <li>`columnNameOfCorruptRecord` (default is the value specified in
+   *     `spark.sql.columnNameOfCorruptRecord`): allows renaming the new field holding the
+   *     malformed string created by `PERMISSIVE` mode. This overrides
+   *     `spark.sql.columnNameOfCorruptRecord`.</li>
+   *   <li>`wholeFile` (default `false`): parse one record, which may span multiple lines.</li>
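The CSV reader documented above admits the same pattern, with the corrupt-record column renamed per read through the `columnNameOfCorruptRecord` option. A minimal sketch, assuming a `SparkSession` named `spark` and a hypothetical `path/to/data.csv`:

import org.apache.spark.sql.types.{IntegerType, StringType, StructType}

// The corrupt-record field name must match the option value below.
val csvSchema = new StructType()
  .add("id", IntegerType)
  .add("name", StringType)
  .add("_malformed", StringType)

val records = spark.read
  .schema(csvSchema)
  .option("header", "true")
  .option("mode", "PERMISSIVE")
  .option("columnNameOfCorruptRecord", "_malformed")  // per-read override of the SQL conf
  .csv("path/to/data.csv")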
  • * * @since 2.0.0 */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 393925161fc7..49e85dc7b13f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -349,8 +349,8 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { // Get all input data source or hive relations of the query. val srcRelations = df.logicalPlan.collect { case LogicalRelation(src: BaseRelation, _, _) => src - case relation: CatalogRelation if DDLUtils.isHiveTable(relation.catalogTable) => - relation.catalogTable.identifier + case relation: CatalogRelation if DDLUtils.isHiveTable(relation.tableMeta) => + relation.tableMeta.identifier } val tableRelation = df.sparkSession.table(tableIdentWithDB).queryExecution.analyzed @@ -360,8 +360,8 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { throw new AnalysisException( s"Cannot overwrite table $tableName that is also being read from") // check hive table relation when overwrite mode - case relation: CatalogRelation if DDLUtils.isHiveTable(relation.catalogTable) - && srcRelations.contains(relation.catalogTable.identifier) => + case relation: CatalogRelation if DDLUtils.isHiveTable(relation.tableMeta) + && srcRelations.contains(relation.tableMeta.identifier) => throw new AnalysisException( s"Cannot overwrite table $tableName that is also being read from") case _ => // OK diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 38a24cc8ed8c..1b0462359607 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -47,7 +47,7 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, PartitioningCollection} import org.apache.spark.sql.catalyst.util.{usePrettyExpression, DateTimeUtils} import org.apache.spark.sql.execution._ -import org.apache.spark.sql.execution.command.{CreateViewCommand, ExplainCommand, GlobalTempView, LocalTempView} +import org.apache.spark.sql.execution.command._ import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.execution.python.EvaluatePython import org.apache.spark.sql.streaming.DataStreamWriter @@ -175,19 +175,13 @@ class Dataset[T] private[sql]( } @transient private[sql] val logicalPlan: LogicalPlan = { - def hasSideEffects(plan: LogicalPlan): Boolean = plan match { - case _: Command | - _: InsertIntoTable => true - case _ => false - } - + // For various commands (like DDL) and queries with side effects, we force query execution + // to happen right away to let these side effects take place eagerly. queryExecution.analyzed match { - // For various commands (like DDL) and queries with side effects, we force query execution - // to happen right away to let these side effects take place eagerly. 
- case p if hasSideEffects(p) => - LogicalRDD(queryExecution.analyzed.output, queryExecution.toRdd)(sparkSession) - case Union(children) if children.forall(hasSideEffects) => - LogicalRDD(queryExecution.analyzed.output, queryExecution.toRdd)(sparkSession) + case c: Command => + LocalRelation(c.output, queryExecution.executedPlan.executeCollect()) + case u @ Union(children) if children.forall(_.isInstanceOf[Command]) => + LocalRelation(u.output, queryExecution.executedPlan.executeCollect()) case _ => queryExecution.analyzed } @@ -557,7 +551,8 @@ class Dataset[T] private[sql]( * Spark will use this watermark for several purposes: * - To know when a given time window aggregation can be finalized and thus can be emitted when * using output modes that do not allow updates. - * - To minimize the amount of state that we need to keep for on-going aggregations. + * - To minimize the amount of state that we need to keep for on-going aggregations, + * `mapGroupsWithState` and `dropDuplicates` operators. * * The current watermark is computed by looking at the `MAX(eventTime)` seen across * all of the partitions in the query minus a user specified `delayThreshold`. Due to the cost @@ -1981,6 +1976,12 @@ class Dataset[T] private[sql]( * Returns a new Dataset that contains only the unique rows from this Dataset. * This is an alias for `distinct`. * + * For a static batch [[Dataset]], it just drops duplicate rows. For a streaming [[Dataset]], it + * will keep all data across triggers as intermediate state to drop duplicates rows. You can use + * [[withWatermark]] to limit how late the duplicate data can be and system will accordingly limit + * the state. In addition, too late data older than watermark will be dropped to avoid any + * possibility of duplicates. + * * @group typedrel * @since 2.0.0 */ @@ -1990,13 +1991,19 @@ class Dataset[T] private[sql]( * (Scala-specific) Returns a new Dataset with duplicate rows removed, considering only * the subset of columns. * + * For a static batch [[Dataset]], it just drops duplicate rows. For a streaming [[Dataset]], it + * will keep all data across triggers as intermediate state to drop duplicates rows. You can use + * [[withWatermark]] to limit how late the duplicate data can be and system will accordingly limit + * the state. In addition, too late data older than watermark will be dropped to avoid any + * possibility of duplicates. + * * @group typedrel * @since 2.0.0 */ def dropDuplicates(colNames: Seq[String]): Dataset[T] = withTypedPlan { val resolver = sparkSession.sessionState.analyzer.resolver val allColumns = queryExecution.analyzed.output - val groupCols = colNames.flatMap { colName => + val groupCols = colNames.toSet.toSeq.flatMap { (colName: String) => // It is possibly there are more than one columns with the same name, // so we call filter instead of find. val cols = allColumns.filter(col => resolver(col.name, colName)) @@ -2006,21 +2013,19 @@ class Dataset[T] private[sql]( } cols } - val groupColExprIds = groupCols.map(_.exprId) - val aggCols = logicalPlan.output.map { attr => - if (groupColExprIds.contains(attr.exprId)) { - attr - } else { - Alias(new First(attr).toAggregateExpression(), attr.name)() - } - } - Aggregate(groupCols, aggCols, logicalPlan) + Deduplicate(groupCols, logicalPlan, isStreaming) } /** * Returns a new Dataset with duplicate rows removed, considering only * the subset of columns. * + * For a static batch [[Dataset]], it just drops duplicate rows. 
For a streaming [[Dataset]], it + * will keep all data across triggers as intermediate state to drop duplicates rows. You can use + * [[withWatermark]] to limit how late the duplicate data can be and system will accordingly limit + * the state. In addition, too late data older than watermark will be dropped to avoid any + * possibility of duplicates. + * * @group typedrel * @since 2.0.0 */ @@ -2030,6 +2035,12 @@ class Dataset[T] private[sql]( * Returns a new [[Dataset]] with duplicate rows removed, considering only * the subset of columns. * + * For a static batch [[Dataset]], it just drops duplicate rows. For a streaming [[Dataset]], it + * will keep all data across triggers as intermediate state to drop duplicates rows. You can use + * [[withWatermark]] to limit how late the duplicate data can be and system will accordingly limit + * the state. In addition, too late data older than watermark will be dropped to avoid any + * possibility of duplicates. + * * @group typedrel * @since 2.0.0 */ @@ -2410,7 +2421,7 @@ class Dataset[T] private[sql]( */ @scala.annotation.varargs def repartition(numPartitions: Int, partitionExprs: Column*): Dataset[T] = withTypedPlan { - RepartitionByExpression(partitionExprs.map(_.expr), logicalPlan, Some(numPartitions)) + RepartitionByExpression(partitionExprs.map(_.expr), logicalPlan, numPartitions) } /** @@ -2425,7 +2436,8 @@ class Dataset[T] private[sql]( */ @scala.annotation.varargs def repartition(partitionExprs: Column*): Dataset[T] = withTypedPlan { - RepartitionByExpression(partitionExprs.map(_.expr), logicalPlan, numPartitions = None) + RepartitionByExpression( + partitionExprs.map(_.expr), logicalPlan, sparkSession.sessionState.conf.numShufflePartitions) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index dbe55090ea11..234ef2dffc6b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -1090,14 +1090,14 @@ object SQLContext { */ private[sql] def beansToRows( data: Iterator[_], - beanInfo: BeanInfo, + beanClass: Class[_], attrs: Seq[AttributeReference]): Iterator[InternalRow] = { val extractors = - beanInfo.getPropertyDescriptors.filterNot(_.getName == "class").map(_.getReadMethod) + JavaTypeInference.getJavaBeanReadableProperties(beanClass).map(_.getReadMethod) val methodsToConverts = extractors.zip(attrs).map { case (e, attr) => (e, CatalystTypeConverters.createToCatalystConverter(attr.dataType)) } - data.map{ element => + data.map { element => new GenericInternalRow( methodsToConverts.map { case (e, convert) => convert(e.invoke(element)) } ): InternalRow diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index 1975a56cafe8..afc1827e7eec 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql -import java.beans.Introspector import java.io.Closeable import java.util.concurrent.atomic.AtomicReference @@ -95,18 +94,28 @@ class SparkSession private( /** * State shared across sessions, including the `SparkContext`, cached data, listener, * and a catalog that interacts with external systems. + * + * This is internal to Spark and there is no guarantee on interface stability. 
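The scaladoc added to the `dropDuplicates` variants above describes streaming deduplication bounded by a watermark. A hedged sketch of the user-facing pattern, assuming an active SparkSession `spark`; the source, column names, and thresholds are illustrative only:

```
import org.apache.spark.sql.functions.expr

// A toy streaming source; the rate source emits (timestamp, value) rows.
val events = spark.readStream.format("rate").option("rowsPerSecond", "10").load()
  .withColumn("eventId", expr("value % 100"))

// Keep deduplication state only for events that are at most 10 minutes late.
val deduped = events
  .withWatermark("timestamp", "10 minutes")
  .dropDuplicates("eventId", "timestamp")

val query = deduped.writeStream.format("console").outputMode("append").start()
```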
+ * + * @since 2.2.0 */ + @InterfaceStability.Unstable @transient - private[sql] lazy val sharedState: SharedState = { + lazy val sharedState: SharedState = { existingSharedState.getOrElse(new SharedState(sparkContext)) } /** * State isolated across sessions, including SQL configurations, temporary tables, registered * functions, and everything else that accepts a [[org.apache.spark.sql.internal.SQLConf]]. + * + * This is internal to Spark and there is no guarantee on interface stability. + * + * @since 2.2.0 */ + @InterfaceStability.Unstable @transient - private[sql] lazy val sessionState: SessionState = { + lazy val sessionState: SessionState = { SparkSession.reflect[SessionState, SparkSession]( SparkSession.sessionStateClassName(sparkContext.conf), self) @@ -337,8 +346,7 @@ class SparkSession private( val className = beanClass.getName val rowRdd = rdd.mapPartitions { iter => // BeanInfo is not serializable so we must rediscover it remotely for each partition. - val localBeanInfo = Introspector.getBeanInfo(Utils.classForName(className)) - SQLContext.beansToRows(iter, localBeanInfo, attributeSeq) + SQLContext.beansToRows(iter, Utils.classForName(className), attributeSeq) } Dataset.ofRows(self, LogicalRDD(attributeSeq, rowRdd)(self)) } @@ -364,8 +372,7 @@ class SparkSession private( */ def createDataFrame(data: java.util.List[_], beanClass: Class[_]): DataFrame = { val attrSeq = getSchema(beanClass) - val beanInfo = Introspector.getBeanInfo(beanClass) - val rows = SQLContext.beansToRows(data.asScala.iterator, beanInfo, attrSeq) + val rows = SQLContext.beansToRows(data.asScala.iterator, beanClass, attrSeq) Dataset.ofRows(self, LocalRelation(attrSeq, rows.toSeq)) } @@ -613,7 +620,6 @@ class SparkSession private( * * @since 2.1.0 */ - @InterfaceStability.Stable def time[T](f: => T): T = { val start = System.nanoTime() val ret = f @@ -928,9 +934,19 @@ object SparkSession { defaultSession.set(null) } - private[sql] def getActiveSession: Option[SparkSession] = Option(activeThreadSession.get) + /** + * Returns the active SparkSession for the current thread, returned by the builder. + * + * @since 2.2.0 + */ + def getActiveSession: Option[SparkSession] = Option(activeThreadSession.get) - private[sql] def getDefaultSession: Option[SparkSession] = Option(defaultSession.get) + /** + * Returns the default SparkSession that is returned by the builder. + * + * @since 2.2.0 + */ + def getDefaultSession: Option[SparkSession] = Option(defaultSession.get) /** A global SQL listener used for the SQL UI. 
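Since `getActiveSession` and `getDefaultSession` become public above, a small usage sketch (purely illustrative):

```
import org.apache.spark.sql.SparkSession

// Prefer the session bound to the current thread, fall back to the default one.
val maybeSession: Option[SparkSession] =
  SparkSession.getActiveSession.orElse(SparkSession.getDefaultSession)

maybeSession match {
  case Some(session) => println(s"Using Spark ${session.version}")
  case None          => println("No SparkSession has been created yet")
}
```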
*/ private[sql] val sqlListener = new AtomicReference[SQLListener]() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala index e56c33e4b512..a4c5bf756cd5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala @@ -47,12 +47,14 @@ private[sql] object SQLUtils extends Logging { jsc: JavaSparkContext, sparkConfigMap: JMap[Object, Object], enableHiveSupport: Boolean): SparkSession = { - val spark = if (SparkSession.hiveClassesArePresent && enableHiveSupport) { + val spark = if (SparkSession.hiveClassesArePresent && enableHiveSupport + && jsc.sc.conf.get(CATALOG_IMPLEMENTATION.key, "hive").toLowerCase == "hive") { SparkSession.builder().sparkContext(withHiveExternalCatalog(jsc.sc)).getOrCreate() } else { if (enableHiveSupport) { logWarning("SparkR: enableHiveSupport is requested for SparkSession but " + - "Spark is not built with Hive; falling back to without Hive support.") + s"Spark is not built with Hive or ${CATALOG_IMPLEMENTATION.key} is not set to 'hive', " + + "falling back to without Hive support.") } SparkSession.builder().sparkContext(jsc.sc).getOrCreate() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala index 4ca134700857..80138510dc9e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala @@ -168,15 +168,16 @@ class CacheManager extends Logging { (fs, path.makeQualified(fs.getUri, fs.getWorkingDirectory)) } - cachedData.foreach { - case data if data.plan.find(lookupAndRefresh(_, fs, qualifiedPath)).isDefined => - val dataIndex = cachedData.indexWhere(cd => data.plan.sameResult(cd.plan)) - if (dataIndex >= 0) { - data.cachedRepresentation.cachedColumnBuffers.unpersist(blocking = true) - cachedData.remove(dataIndex) - } - sparkSession.sharedState.cacheManager.cacheQuery(Dataset.ofRows(sparkSession, data.plan)) - case _ => // Do Nothing + cachedData.filter { + case data if data.plan.find(lookupAndRefresh(_, fs, qualifiedPath)).isDefined => true + case _ => false + }.foreach { data => + val dataIndex = cachedData.indexWhere(cd => data.plan.sameResult(cd.plan)) + if (dataIndex >= 0) { + data.cachedRepresentation.cachedColumnBuffers.unpersist(blocking = true) + cachedData.remove(dataIndex) + } + sparkSession.sharedState.cacheManager.cacheQuery(Dataset.ofRows(sparkSession, data.plan)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala index b8ac070e3a95..aa578f4d2313 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation} import 
org.apache.spark.sql.internal.SQLConf @@ -102,12 +102,14 @@ case class OptimizeMetadataOnlyQuery( LocalRelation(partAttrs, partitionData.map(_.values)) case relation: CatalogRelation => - val partAttrs = getPartitionAttrs(relation.catalogTable.partitionColumnNames, relation) - val partitionData = catalog.listPartitions(relation.catalogTable.identifier).map { p => + val partAttrs = getPartitionAttrs(relation.tableMeta.partitionColumnNames, relation) + val caseInsensitiveProperties = + CaseInsensitiveMap(relation.tableMeta.storage.properties) + val timeZoneId = caseInsensitiveProperties.get("timeZone") + .getOrElse(conf.sessionLocalTimeZone) + val partitionData = catalog.listPartitions(relation.tableMeta.identifier).map { p => InternalRow.fromSeq(partAttrs.map { attr => - // TODO: use correct timezone for partition values. - Cast(Literal(p.spec(attr.name)), attr.dataType, - Option(DateTimeUtils.defaultTimeZone().getID)).eval() + Cast(Literal(p.spec(attr.name)), attr.dataType, Option(timeZoneId)).eval() }) } LocalRelation(partAttrs, partitionData) @@ -135,8 +137,8 @@ case class OptimizeMetadataOnlyQuery( val partAttrs = getPartitionAttrs(fsRelation.partitionSchema.map(_.name), l) Some(AttributeSet(partAttrs), l) - case relation: CatalogRelation if relation.catalogTable.partitionColumnNames.nonEmpty => - val partAttrs = getPartitionAttrs(relation.catalogTable.partitionColumnNames, relation) + case relation: CatalogRelation if relation.tableMeta.partitionColumnNames.nonEmpty => + val partAttrs = getPartitionAttrs(relation.tableMeta.partitionColumnNames, relation) Some(AttributeSet(partAttrs), relation) case p @ Project(projectList, child) if projectList.forall(_.deterministic) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index 9d046c0766aa..6ec2f4d84086 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -125,8 +125,6 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) { // SHOW TABLES in Hive only output table names, while ours outputs database, table name, isTemp. 
case command: ExecutedCommandExec if command.cmd.isInstanceOf[ShowTablesCommand] => command.executeCollect().map(_.getString(1)) - case command: ExecutedCommandExec => - command.executeCollect().map(_.getString(0)) case other => val result: Seq[Seq[Any]] = other.executeCollectPublic().map(_.toSeq).toSeq // We need the types so we can output struct field names @@ -197,7 +195,11 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) { """.stripMargin.trim } - override def toString: String = { + override def toString: String = completeString(appendStats = false) + + def toStringWithStats: String = completeString(appendStats = true) + + private def completeString(appendStats: Boolean): String = { def output = Utils.truncatedString( analyzed.output.map(o => s"${o.name}: ${o.dataType.simpleString}"), ", ") val analyzedPlan = Seq( @@ -205,12 +207,20 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) { stringOrError(analyzed.treeString(verbose = true)) ).filter(_.nonEmpty).mkString("\n") + val optimizedPlanString = if (appendStats) { + // trigger to compute stats for logical plans + optimizedPlan.stats(sparkSession.sessionState.conf) + optimizedPlan.treeString(verbose = true, addSuffix = true) + } else { + optimizedPlan.treeString(verbose = true) + } + s"""== Parsed Logical Plan == |${stringOrError(logical.treeString(verbose = true))} |== Analyzed Logical Plan == |$analyzedPlan |== Optimized Logical Plan == - |${stringOrError(optimizedPlan.treeString(verbose = true))} + |${stringOrError(optimizedPlanString)} |== Physical Plan == |${stringOrError(executedPlan.treeString(verbose = true))} """.stripMargin.trim diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index d50800235264..65df68868939 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -22,16 +22,17 @@ import scala.collection.JavaConverters._ import org.antlr.v4.runtime.{ParserRuleContext, Token} import org.antlr.v4.runtime.tree.TerminalNode -import org.apache.spark.sql.{AnalysisException, SaveMode} +import org.apache.spark.sql.SaveMode import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} import org.apache.spark.sql.catalyst.catalog._ +import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.parser._ import org.apache.spark.sql.catalyst.parser.SqlBaseParser._ -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OneRowRelation, ScriptInputOutputSchema} +import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.command._ import org.apache.spark.sql.execution.datasources.{CreateTable, _} import org.apache.spark.sql.internal.{HiveSerDe, SQLConf, VariableSubstitution} -import org.apache.spark.sql.types.{StructField, StructType} +import org.apache.spark.sql.types.StructType /** * Concrete parser for Spark SQL statements. @@ -282,7 +283,11 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { if (statement == null) { null // This is enough since ParseException will raise later. 
} else if (isExplainableStatement(statement)) { - ExplainCommand(statement, extended = ctx.EXTENDED != null, codegen = ctx.CODEGEN != null) + ExplainCommand( + logicalPlan = statement, + extended = ctx.EXTENDED != null, + codegen = ctx.CODEGEN != null, + cost = ctx.COST != null) } else { ExplainCommand(OneRowRelation) } @@ -1441,4 +1446,14 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { reader, writer, schemaLess) } + + /** + * Create a clause for DISTRIBUTE BY. + */ + override protected def withRepartitionByExpression( + ctx: QueryOrganizationContext, + expressions: Seq[Expression], + query: LogicalPlan): LogicalPlan = { + RepartitionByExpression(expressions, query, conf.numShufflePartitions) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 557181ebd959..20bf4925dbec 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -18,18 +18,18 @@ package org.apache.spark.sql.execution import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{SaveMode, Strategy} +import org.apache.spark.sql.Strategy import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.First import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.plans.logical.{BroadcastHint, EventTimeWatermark, LogicalPlan, MapGroupsWithState} +import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec} import org.apache.spark.sql.execution.command._ -import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.exchange.ShuffleExchange import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight} import org.apache.spark.sql.execution.streaming._ @@ -244,6 +244,18 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } } + /** + * Used to plan the streaming deduplicate operator. + */ + object StreamingDeduplicationStrategy extends Strategy { + override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case Deduplicate(keys, child, true) => + StreamingDeduplicateExec(keys, planLater(child)) :: Nil + + case _ => Nil + } + } + /** * Used to plan the aggregate operator for expressions based on the AggregateFunction2 interface. */ @@ -332,8 +344,6 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { // Can we automate these 'pass through' operations? 
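Two user-visible pieces touched above are the `EXPLAIN` flag that triggers `toStringWithStats` and the `DISTRIBUTE BY` clause now planned with the configured shuffle partition count. A hedged sketch, assuming an active SparkSession `spark` and a registered view `events`:

```
// Show the parsed/analyzed/optimized/physical plans, with statistics appended to the optimized plan.
spark.sql("EXPLAIN COST SELECT key, count(*) FROM events GROUP BY key").show(truncate = false)

// DISTRIBUTE BY now repartitions into spark.sql.shuffle.partitions partitions.
spark.conf.set("spark.sql.shuffle.partitions", "8")
val distributed = spark.sql("SELECT * FROM events DISTRIBUTE BY key")
println(distributed.rdd.getNumPartitions)
```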
object BasicOperators extends Strategy { - def numPartitions: Int = self.numPartitions - def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case r: RunnableCommand => ExecutedCommandExec(r) :: Nil @@ -414,9 +424,9 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { execution.RDDScanExec(Nil, singleRowRdd, "OneRowRelation") :: Nil case r: logical.Range => execution.RangeExec(r) :: Nil - case logical.RepartitionByExpression(expressions, child, nPartitions) => + case logical.RepartitionByExpression(expressions, child, numPartitions) => exchange.ShuffleExchange(HashPartitioning( - expressions, nPartitions.getOrElse(numPartitions)), planLater(child)) :: Nil + expressions, numPartitions), planLater(child)) :: Nil case ExternalRDD(outputObjAttr, rdd) => ExternalRDDScanExec(outputObjAttr, rdd) :: Nil case r: LogicalRDD => RDDScanExec(r.output, r.rdd, "ExistingRDD", r.outputPartitioning, r.outputOrdering) :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index 2ead8f6baae6..c58474eba05d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -254,7 +254,8 @@ case class InputAdapter(child: SparkPlan) extends UnaryExecNode with CodegenSupp lastChildren: Seq[Boolean], builder: StringBuilder, verbose: Boolean, - prefix: String = ""): StringBuilder = { + prefix: String = "", + addSuffix: Boolean = false): StringBuilder = { child.generateTreeString(depth, lastChildren, builder, verbose, "") } } @@ -428,7 +429,8 @@ case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with Co lastChildren: Seq[Boolean], builder: StringBuilder, verbose: Boolean, - prefix: String = ""): StringBuilder = { + prefix: String = "", + addSuffix: Boolean = false): StringBuilder = { child.generateTreeString(depth, lastChildren, builder, verbose, "*") } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala index d024a3673d4b..b89014ed8ef5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala @@ -17,15 +17,12 @@ package org.apache.spark.sql.execution.command -import org.apache.spark.internal.Logging import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases -import org.apache.spark.sql.catalyst.catalog.{CatalogRelation, CatalogStatistics, CatalogTable} +import org.apache.spark.sql.catalyst.catalog.{CatalogStatistics, CatalogTableType} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.execution.datasources.LogicalRelation /** @@ -40,60 +37,40 @@ case class AnalyzeColumnCommand( val sessionState = sparkSession.sessionState val db = tableIdent.database.getOrElse(sessionState.catalog.getCurrentDatabase) val tableIdentWithDB = TableIdentifier(tableIdent.table, Some(db)) - val relation = - EliminateSubqueryAliases(sparkSession.table(tableIdentWithDB).queryExecution.analyzed) - - // Compute total size - val 
(catalogTable: CatalogTable, sizeInBytes: Long) = relation match { - case catalogRel: CatalogRelation => - // This is a Hive serde format table - (catalogRel.catalogTable, - AnalyzeTableCommand.calculateTotalSize(sessionState, catalogRel.catalogTable)) - - case logicalRel: LogicalRelation if logicalRel.catalogTable.isDefined => - // This is a data source format table - (logicalRel.catalogTable.get, - AnalyzeTableCommand.calculateTotalSize(sessionState, logicalRel.catalogTable.get)) - - case otherRelation => - throw new AnalysisException("ANALYZE TABLE is not supported for " + - s"${otherRelation.nodeName}.") + val tableMeta = sessionState.catalog.getTableMetadata(tableIdentWithDB) + if (tableMeta.tableType == CatalogTableType.VIEW) { + throw new AnalysisException("ANALYZE TABLE is not supported on views.") } + val sizeInBytes = AnalyzeTableCommand.calculateTotalSize(sessionState, tableMeta) // Compute stats for each column - val (rowCount, newColStats) = - AnalyzeColumnCommand.computeColumnStats(sparkSession, tableIdent.table, relation, columnNames) + val (rowCount, newColStats) = computeColumnStats(sparkSession, tableIdentWithDB, columnNames) // We also update table-level stats in order to keep them consistent with column-level stats. val statistics = CatalogStatistics( sizeInBytes = sizeInBytes, rowCount = Some(rowCount), // Newly computed column stats should override the existing ones. - colStats = catalogTable.stats.map(_.colStats).getOrElse(Map.empty) ++ newColStats) + colStats = tableMeta.stats.map(_.colStats).getOrElse(Map.empty) ++ newColStats) - sessionState.catalog.alterTable(catalogTable.copy(stats = Some(statistics))) + sessionState.catalog.alterTable(tableMeta.copy(stats = Some(statistics))) // Refresh the cached data source table in the catalog. sessionState.catalog.refreshTable(tableIdentWithDB) Seq.empty[Row] } -} - -object AnalyzeColumnCommand extends Logging { /** * Compute stats for the given columns. * @return (row count, map from column name to ColumnStats) - * - * This is visible for testing. */ - def computeColumnStats( + private def computeColumnStats( sparkSession: SparkSession, - tableName: String, - relation: LogicalPlan, + tableIdent: TableIdentifier, columnNames: Seq[String]): (Long, Map[String, ColumnStat]) = { + val relation = sparkSession.table(tableIdent).logicalPlan // Resolve the column names and dedup using AttributeSet val resolver = sparkSession.sessionState.conf.resolver val attributesToAnalyze = AttributeSet(columnNames.map { col => @@ -105,7 +82,7 @@ object AnalyzeColumnCommand extends Logging { attributesToAnalyze.foreach { attr => if (!ColumnStat.supportsType(attr.dataType)) { throw new AnalysisException( - s"Column ${attr.name} in table $tableName is of type ${attr.dataType}, " + + s"Column ${attr.name} in table $tableIdent is of type ${attr.dataType}, " + "and Spark does not support statistics collection on this column type.") } } @@ -116,7 +93,7 @@ object AnalyzeColumnCommand extends Logging { // The layout of each struct follows the layout of the ColumnStats. 
val ndvMaxErr = sparkSession.sessionState.conf.ndvMaxError val expressions = Count(Literal(1)).toAggregateExpression() +: - attributesToAnalyze.map(ColumnStat.statExprs(_, ndvMaxErr)) + attributesToAnalyze.map(ColumnStat.statExprs(_, ndvMaxErr)) val namedExpressions = expressions.map(e => Alias(e, e.toString)()) val statsRow = Dataset.ofRows(sparkSession, Aggregate(Nil, namedExpressions, relation)).head() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala index 30b6cc7617cb..d2ea0cdf61aa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala @@ -22,11 +22,9 @@ import scala.util.control.NonFatal import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.spark.internal.Logging -import org.apache.spark.sql.{AnalysisException, Dataset, Row, SparkSession} +import org.apache.spark.sql.{AnalysisException, Row, SparkSession} import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases -import org.apache.spark.sql.catalyst.catalog.{CatalogRelation, CatalogStatistics, CatalogTable} -import org.apache.spark.sql.execution.datasources.LogicalRelation +import org.apache.spark.sql.catalyst.catalog.{CatalogStatistics, CatalogTable, CatalogTableType} import org.apache.spark.sql.internal.SessionState @@ -41,53 +39,39 @@ case class AnalyzeTableCommand( val sessionState = sparkSession.sessionState val db = tableIdent.database.getOrElse(sessionState.catalog.getCurrentDatabase) val tableIdentWithDB = TableIdentifier(tableIdent.table, Some(db)) - val relation = - EliminateSubqueryAliases(sparkSession.table(tableIdentWithDB).queryExecution.analyzed) - - relation match { - case relation: CatalogRelation => - updateTableStats(relation.catalogTable, - AnalyzeTableCommand.calculateTotalSize(sessionState, relation.catalogTable)) - - // data source tables have been converted into LogicalRelations - case logicalRel: LogicalRelation if logicalRel.catalogTable.isDefined => - updateTableStats(logicalRel.catalogTable.get, - AnalyzeTableCommand.calculateTotalSize(sessionState, logicalRel.catalogTable.get)) - - case otherRelation => - throw new AnalysisException("ANALYZE TABLE is not supported for " + - s"${otherRelation.nodeName}.") + val tableMeta = sessionState.catalog.getTableMetadata(tableIdentWithDB) + if (tableMeta.tableType == CatalogTableType.VIEW) { + throw new AnalysisException("ANALYZE TABLE is not supported on views.") } + val newTotalSize = AnalyzeTableCommand.calculateTotalSize(sessionState, tableMeta) - def updateTableStats(catalogTable: CatalogTable, newTotalSize: Long): Unit = { - val oldTotalSize = catalogTable.stats.map(_.sizeInBytes.toLong).getOrElse(0L) - val oldRowCount = catalogTable.stats.flatMap(_.rowCount.map(_.toLong)).getOrElse(-1L) - var newStats: Option[CatalogStatistics] = None - if (newTotalSize > 0 && newTotalSize != oldTotalSize) { - newStats = Some(CatalogStatistics(sizeInBytes = newTotalSize)) - } - // We only set rowCount when noscan is false, because otherwise: - // 1. when total size is not changed, we don't need to alter the table; - // 2. when total size is changed, `oldRowCount` becomes invalid. - // This is to make sure that we only record the right statistics. 
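For reference, these are the SQL commands backed by `AnalyzeTableCommand` and `AnalyzeColumnCommand` above; a hedged sketch assuming an active SparkSession `spark` and a persisted table `sales` with a `price` column:

```
// Size-only statistics (NOSCAN skips the row count).
spark.sql("ANALYZE TABLE sales COMPUTE STATISTICS NOSCAN")

// Size plus row count.
spark.sql("ANALYZE TABLE sales COMPUTE STATISTICS")

// Per-column statistics, merged into the table-level stats as described above.
spark.sql("ANALYZE TABLE sales COMPUTE STATISTICS FOR COLUMNS price")
```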
- if (!noscan) { - val newRowCount = Dataset.ofRows(sparkSession, relation).count() - if (newRowCount >= 0 && newRowCount != oldRowCount) { - newStats = if (newStats.isDefined) { - newStats.map(_.copy(rowCount = Some(BigInt(newRowCount)))) - } else { - Some(CatalogStatistics( - sizeInBytes = oldTotalSize, rowCount = Some(BigInt(newRowCount)))) - } + val oldTotalSize = tableMeta.stats.map(_.sizeInBytes.toLong).getOrElse(0L) + val oldRowCount = tableMeta.stats.flatMap(_.rowCount.map(_.toLong)).getOrElse(-1L) + var newStats: Option[CatalogStatistics] = None + if (newTotalSize > 0 && newTotalSize != oldTotalSize) { + newStats = Some(CatalogStatistics(sizeInBytes = newTotalSize)) + } + // We only set rowCount when noscan is false, because otherwise: + // 1. when total size is not changed, we don't need to alter the table; + // 2. when total size is changed, `oldRowCount` becomes invalid. + // This is to make sure that we only record the right statistics. + if (!noscan) { + val newRowCount = sparkSession.table(tableIdentWithDB).count() + if (newRowCount >= 0 && newRowCount != oldRowCount) { + newStats = if (newStats.isDefined) { + newStats.map(_.copy(rowCount = Some(BigInt(newRowCount)))) + } else { + Some(CatalogStatistics( + sizeInBytes = oldTotalSize, rowCount = Some(BigInt(newRowCount)))) } } - // Update the metastore if the above statistics of the table are different from those - // recorded in the metastore. - if (newStats.isDefined) { - sessionState.catalog.alterTable(catalogTable.copy(stats = newStats)) - // Refresh the cached data source table in the catalog. - sessionState.catalog.refreshTable(tableIdentWithDB) - } + } + // Update the metastore if the above statistics of the table are different from those + // recorded in the metastore. + if (newStats.isDefined) { + sessionState.catalog.alterTable(tableMeta.copy(stats = newStats)) + // Refresh the cached data source table in the catalog. + sessionState.catalog.refreshTable(tableIdentWithDB) } Seq.empty[Row] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala index 58f507119325..5de45b159684 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala @@ -88,11 +88,13 @@ case class ExecutedCommandExec(cmd: RunnableCommand) extends SparkPlan { * @param logicalPlan plan to explain * @param extended whether to do extended explain or not * @param codegen whether to output generated code from whole-stage codegen or not + * @param cost whether to show cost information for operators. 
*/ case class ExplainCommand( logicalPlan: LogicalPlan, extended: Boolean = false, - codegen: Boolean = false) + codegen: Boolean = false, + cost: Boolean = false) extends RunnableCommand { override val output: Seq[Attribute] = @@ -113,6 +115,8 @@ case class ExplainCommand( codegenString(queryExecution.executedPlan) } else if (extended) { queryExecution.toString + } else if (cost) { + queryExecution.toStringWithStats } else { queryExecution.simpleString } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala index 5abd57947650..d835b521166a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala @@ -141,7 +141,7 @@ case class CreateDataSourceTableAsSelectCommand( } saveDataIntoTable( - sparkSession, table, table.storage.locationUri, query, mode, tableExists = true) + sparkSession, table, table.storage.locationUri, query, SaveMode.Append, tableExists = true) } else { assert(table.schema.isEmpty) @@ -151,7 +151,7 @@ case class CreateDataSourceTableAsSelectCommand( table.storage.locationUri } val result = saveDataIntoTable( - sparkSession, table, tableLocation, query, mode, tableExists = false) + sparkSession, table, tableLocation, query, SaveMode.Overwrite, tableExists = false) val newTable = table.copy( storage = table.storage.copy(locationUri = tableLocation), // We will use the schema of resolved.relation as the schema of the table (instead of diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala index d646a215c38c..3e80916104bd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala @@ -269,8 +269,8 @@ case class LoadDataCommand( } else { // Follow Hive's behavior: // If no schema or authority is provided with non-local inpath, - // we will use hadoop configuration "fs.default.name". - val defaultFSConf = sparkSession.sessionState.newHadoopConf().get("fs.default.name") + // we will use hadoop configuration "fs.defaultFS". 
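The switch above from the deprecated `fs.default.name` key to `fs.defaultFS` can be checked directly against the Hadoop configuration; a tiny illustrative sketch, assuming an active SparkSession `spark`:

```
val hadoopConf = spark.sparkContext.hadoopConfiguration
// Both keys resolve to the same value on current Hadoop, but only fs.defaultFS is non-deprecated.
println(hadoopConf.get("fs.defaultFS"))
```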
+ val defaultFSConf = sparkSession.sessionState.newHadoopConf().get("fs.defaultFS") val defaultFS = if (defaultFSConf == null) { new URI("") } else { @@ -312,7 +312,6 @@ case class LoadDataCommand( loadPath.toString, partition.get, isOverwrite, - holdDDLTime = false, inheritTableSpecs = true, isSrcLocal = isLocal) } else { @@ -320,7 +319,6 @@ case class LoadDataCommand( targetTable.identifier, loadPath.toString, isOverwrite, - holdDDLTime = false, isSrcLocal = isLocal) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CatalogFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CatalogFileIndex.scala index 1235a4b12f1d..2068811661fe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CatalogFileIndex.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CatalogFileIndex.scala @@ -72,7 +72,8 @@ class CatalogFileIndex( val path = new Path(p.location) val fs = path.getFileSystem(hadoopConf) PartitionPath( - p.toRow(partitionSchema), path.makeQualified(fs.getUri, fs.getWorkingDirectory)) + p.toRow(partitionSchema, sparkSession.sessionState.conf.sessionLocalTimeZone), + path.makeQualified(fs.getUri, fs.getWorkingDirectory)) } val partitionSpec = PartitionSpec(partitionSchema, partitions) new PrunedInMemoryFileIndex( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CodecStreams.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CodecStreams.scala index 0762d1b7daae..54549f698aca 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CodecStreams.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CodecStreams.scala @@ -27,6 +27,8 @@ import org.apache.hadoop.mapreduce.JobContext import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat import org.apache.hadoop.util.ReflectionUtils +import org.apache.spark.TaskContext + object CodecStreams { private def getDecompressionCodec(config: Configuration, file: Path): Option[CompressionCodec] = { val compressionCodecs = new CompressionCodecFactory(config) @@ -42,6 +44,16 @@ object CodecStreams { .getOrElse(inputStream) } + /** + * Creates an input stream from the string path and add a closure for the input stream to be + * closed on task completion. + */ + def createInputStreamWithCloseResource(config: Configuration, path: String): InputStream = { + val inputStream = createInputStream(config, new Path(path)) + Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => inputStream.close())) + inputStream + } + private def getCompressionCodec( context: JobContext, file: Option[Path] = None): Option[CompressionCodec] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index f3536548fc88..2871de14a747 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -106,10 +106,13 @@ case class DataSource( * be any further inference in any triggers. * * @param format the file format object for this DataSource + * @param fileStatusCache the shared cache for file statuses to speed up listing * @return A pair of the data schema (excluding partition columns) and the schema of the partition * columns. 
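The `createInputStreamWithCloseResource` helper added above ties a stream's lifetime to the running task. A hedged sketch of the same pattern applied to an arbitrary resource; the helper name is ours, not Spark's:

```
import java.io.Closeable
import org.apache.spark.TaskContext

// Close the resource when the current task finishes, if we are running inside a task.
def closeOnTaskCompletion[T <: Closeable](resource: T): T = {
  Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => resource.close()))
  resource
}
```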
*/ - private def getOrInferFileFormatSchema(format: FileFormat): (StructType, StructType) = { + private def getOrInferFileFormatSchema( + format: FileFormat, + fileStatusCache: FileStatusCache = NoopCache): (StructType, StructType) = { // the operations below are expensive therefore try not to do them if we don't need to, e.g., // in streaming mode, we have already inferred and registered partition columns, we will // never have to materialize the lazy val below @@ -122,7 +125,7 @@ case class DataSource( val qualified = hdfsPath.makeQualified(fs.getUri, fs.getWorkingDirectory) SparkHadoopUtil.get.globPathIfNecessary(qualified) }.toArray - new InMemoryFileIndex(sparkSession, globbedPaths, options, None) + new InMemoryFileIndex(sparkSession, globbedPaths, options, None, fileStatusCache) } val partitionSchema = if (partitionColumns.isEmpty) { // Try to infer partitioning, because no DataSource in the read path provides the partitioning @@ -279,28 +282,6 @@ case class DataSource( } } - /** - * Returns true if there is a single path that has a metadata log indicating which files should - * be read. - */ - def hasMetadata(path: Seq[String]): Boolean = { - path match { - case Seq(singlePath) => - try { - val hdfsPath = new Path(singlePath) - val fs = hdfsPath.getFileSystem(sparkSession.sessionState.newHadoopConf()) - val metadataPath = new Path(hdfsPath, FileStreamSink.metadataDir) - val res = fs.exists(metadataPath) - res - } catch { - case NonFatal(e) => - logWarning(s"Error while looking for metadata directory.") - false - } - case _ => false - } - } - /** * Create a resolved [[BaseRelation]] that can be used to read data from or write data into this * [[DataSource]] @@ -331,7 +312,9 @@ case class DataSource( // We are reading from the results of a streaming query. Load files from the metadata log // instead of listing them using HDFS APIs. 
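As background for the comment above about reading the results of a streaming query, a hedged sketch of how such an output directory (and its `_spark_metadata` log) comes to exist; all paths are placeholders and `spark` is an active SparkSession:

```
// A toy streaming job writing parquet files plus a _spark_metadata log under /tmp/out.
val stream = spark.readStream.format("rate").load()
val query = stream.writeStream
  .format("parquet")
  .option("checkpointLocation", "/tmp/chk")
  .start("/tmp/out")

query.processAllAvailable()
query.stop()

// A later batch read of the same path is served from the metadata log when it exists.
val batch = spark.read.parquet("/tmp/out")
```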
case (format: FileFormat, _) - if hasMetadata(caseInsensitiveOptions.get("path").toSeq ++ paths) => + if FileStreamSink.hasMetadata( + caseInsensitiveOptions.get("path").toSeq ++ paths, + sparkSession.sessionState.newHadoopConf()) => val basePath = new Path((caseInsensitiveOptions.get("path").toSeq ++ paths).head) val fileCatalog = new MetadataLogFileIndex(sparkSession, basePath) val dataSchema = userSpecifiedSchema.orElse { @@ -373,8 +356,9 @@ case class DataSource( } globPath }.toArray - + createHadoopRelation(format, globbedPaths) + case _ => throw new AnalysisException( s"$className is not a valid Spark SQL Data Source.") @@ -390,7 +374,9 @@ case class DataSource( */ def createHadoopRelation(format: FileFormat, globPaths: Array[Path]): BaseRelation = { - val (dataSchema, partitionSchema) = getOrInferFileFormatSchema(format) + val fileStatusCache = FileStatusCache.getOrCreate(sparkSession) + val (dataSchema, partitionSchema) = getOrInferFileFormatSchema(format, fileStatusCache) + val fileCatalog = if (sparkSession.sqlContext.conf.manageFilesourcePartitions && catalogTable.isDefined && catalogTable.get.tracksPartitionsInCatalog) { val defaultTableSize = sparkSession.sessionState.conf.defaultSizeInBytes diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index f4292320e4bf..f694a0d6d724 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.{CatalystConf, CatalystTypeConverters, InternalRow, QualifiedTableName, TableIdentifier} import org.apache.spark.sql.catalyst.CatalystTypeConverters.convertToScala import org.apache.spark.sql.catalyst.analysis._ -import org.apache.spark.sql.catalyst.catalog.{CatalogTable, SimpleCatalogRelation} +import org.apache.spark.sql.catalyst.catalog.CatalogRelation import org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning.PhysicalOperation @@ -208,16 +208,17 @@ case class DataSourceAnalysis(conf: CatalystConf) extends Rule[LogicalPlan] { /** - * Replaces [[SimpleCatalogRelation]] with data source table if its table provider is not hive. + * Replaces [[CatalogRelation]] with data source table if its table provider is not hive. */ class FindDataSourceTable(sparkSession: SparkSession) extends Rule[LogicalPlan] { - private def readDataSourceTable(table: CatalogTable): LogicalPlan = { + private def readDataSourceTable(r: CatalogRelation): LogicalPlan = { + val table = r.tableMeta val qualifiedTableName = QualifiedTableName(table.database, table.identifier.table) val cache = sparkSession.sessionState.catalog.tableRelationCache val withHiveSupport = sparkSession.sparkContext.conf.get(StaticSQLConf.CATALOG_IMPLEMENTATION) == "hive" - cache.get(qualifiedTableName, new Callable[LogicalPlan]() { + val plan = cache.get(qualifiedTableName, new Callable[LogicalPlan]() { override def call(): LogicalPlan = { val pathOption = table.storage.locationUri.map("path" -> _) val dataSource = @@ -233,19 +234,25 @@ class FindDataSourceTable(sparkSession: SparkSession) extends Rule[LogicalPlan] // TODO: improve `InMemoryCatalog` and remove this limitation. 
catalogTable = if (withHiveSupport) Some(table) else None) - LogicalRelation(dataSource.resolveRelation(checkFilesExist = false), + LogicalRelation( + dataSource.resolveRelation(checkFilesExist = false), catalogTable = Some(table)) } - }) + }).asInstanceOf[LogicalRelation] + + // It's possible that the table schema is empty and need to be inferred at runtime. We should + // not specify expected outputs for this case. + val expectedOutputs = if (r.output.isEmpty) None else Some(r.output) + plan.copy(expectedOutputAttributes = expectedOutputs) } override def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case i @ InsertIntoTable(s: SimpleCatalogRelation, _, _, _, _) - if DDLUtils.isDatasourceTable(s.metadata) => - i.copy(table = readDataSourceTable(s.metadata)) + case i @ InsertIntoTable(r: CatalogRelation, _, _, _, _) + if DDLUtils.isDatasourceTable(r.tableMeta) => + i.copy(table = readDataSourceTable(r)) - case s: SimpleCatalogRelation if DDLUtils.isDatasourceTable(s.metadata) => - readDataSourceTable(s.metadata) + case r: CatalogRelation if DDLUtils.isDatasourceTable(r.tableMeta) => + readDataSourceTable(r) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala index be13cbc51a9d..950e5ca0d621 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala @@ -37,11 +37,10 @@ import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.execution.{QueryExecution, SQLExecution, UnsafeKVExternalSorter} -import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} +import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} +import org.apache.spark.sql.execution.{QueryExecution, SortExec, SQLExecution} +import org.apache.spark.sql.types.{StringType, StructType} import org.apache.spark.util.{SerializableConfiguration, Utils} -import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter /** A helper object for writing FileFormat data out to a location. 
*/ @@ -64,12 +63,13 @@ object FileFormatWriter extends Logging { val serializableHadoopConf: SerializableConfiguration, val outputWriterFactory: OutputWriterFactory, val allColumns: Seq[Attribute], - val partitionColumns: Seq[Attribute], val dataColumns: Seq[Attribute], - val bucketSpec: Option[BucketSpec], + val partitionColumns: Seq[Attribute], + val bucketIdExpression: Option[Expression], val path: String, val customPartitionLocations: Map[TablePartitionSpec, String], - val maxRecordsPerFile: Long) + val maxRecordsPerFile: Long, + val timeZoneId: String) extends Serializable { assert(AttributeSet(allColumns) == AttributeSet(partitionColumns ++ dataColumns), @@ -108,34 +108,72 @@ object FileFormatWriter extends Logging { job.setOutputValueClass(classOf[InternalRow]) FileOutputFormat.setOutputPath(job, new Path(outputSpec.outputPath)) + val allColumns = queryExecution.logical.output val partitionSet = AttributeSet(partitionColumns) val dataColumns = queryExecution.logical.output.filterNot(partitionSet.contains) + val bucketIdExpression = bucketSpec.map { spec => + val bucketColumns = spec.bucketColumnNames.map(c => dataColumns.find(_.name == c).get) + // Use `HashPartitioning.partitionIdExpression` as our bucket id expression, so that we can + // guarantee the data distribution is same between shuffle and bucketed data source, which + // enables us to only shuffle one side when join a bucketed table and a normal one. + HashPartitioning(bucketColumns, spec.numBuckets).partitionIdExpression + } + val sortColumns = bucketSpec.toSeq.flatMap { + spec => spec.sortColumnNames.map(c => dataColumns.find(_.name == c).get) + } + + val caseInsensitiveOptions = CaseInsensitiveMap(options) + // Note: prepareWrite has side effect. It sets "job". val outputWriterFactory = - fileFormat.prepareWrite(sparkSession, job, options, dataColumns.toStructType) + fileFormat.prepareWrite(sparkSession, job, caseInsensitiveOptions, dataColumns.toStructType) val description = new WriteJobDescription( uuid = UUID.randomUUID().toString, serializableHadoopConf = new SerializableConfiguration(job.getConfiguration), outputWriterFactory = outputWriterFactory, - allColumns = queryExecution.logical.output, - partitionColumns = partitionColumns, + allColumns = allColumns, dataColumns = dataColumns, - bucketSpec = bucketSpec, + partitionColumns = partitionColumns, + bucketIdExpression = bucketIdExpression, path = outputSpec.outputPath, customPartitionLocations = outputSpec.customPartitionLocations, - maxRecordsPerFile = options.get("maxRecordsPerFile").map(_.toLong) - .getOrElse(sparkSession.sessionState.conf.maxRecordsPerFile) + maxRecordsPerFile = caseInsensitiveOptions.get("maxRecordsPerFile").map(_.toLong) + .getOrElse(sparkSession.sessionState.conf.maxRecordsPerFile), + timeZoneId = caseInsensitiveOptions.get("timeZone") + .getOrElse(sparkSession.sessionState.conf.sessionLocalTimeZone) ) + // We should first sort by partition columns, then bucket id, and finally sorting columns. 
+ val requiredOrdering = partitionColumns ++ bucketIdExpression ++ sortColumns + // the sort order doesn't matter + val actualOrdering = queryExecution.executedPlan.outputOrdering.map(_.child) + val orderingMatched = if (requiredOrdering.length > actualOrdering.length) { + false + } else { + requiredOrdering.zip(actualOrdering).forall { + case (requiredOrder, childOutputOrder) => + requiredOrder.semanticEquals(childOutputOrder) + } + } + SQLExecution.withNewExecutionId(sparkSession, queryExecution) { // This call shouldn't be put into the `try` block below because it only initializes and // prepares the job, any exception thrown from here shouldn't cause abortJob() to be called. committer.setupJob(job) try { - val ret = sparkSession.sparkContext.runJob(queryExecution.toRdd, + val rdd = if (orderingMatched) { + queryExecution.toRdd + } else { + SortExec( + requiredOrdering.map(SortOrder(_, Ascending)), + global = false, + child = queryExecution.executedPlan).execute() + } + + val ret = sparkSession.sparkContext.runJob(rdd, (taskContext: TaskContext, iter: Iterator[InternalRow]) => { executeTask( description = description, @@ -177,11 +215,11 @@ object FileFormatWriter extends Logging { val taskAttemptContext: TaskAttemptContext = { // Set up the configuration object val hadoopConf = description.serializableHadoopConf.value - hadoopConf.set("mapred.job.id", jobId.toString) - hadoopConf.set("mapred.tip.id", taskAttemptId.getTaskID.toString) - hadoopConf.set("mapred.task.id", taskAttemptId.toString) - hadoopConf.setBoolean("mapred.task.is.map", true) - hadoopConf.setInt("mapred.task.partition", 0) + hadoopConf.set("mapreduce.job.id", jobId.toString) + hadoopConf.set("mapreduce.task.id", taskAttemptId.getTaskID.toString) + hadoopConf.set("mapreduce.task.attempt.id", taskAttemptId.toString) + hadoopConf.setBoolean("mapreduce.task.ismap", true) + hadoopConf.setInt("mapreduce.task.partition", 0) new TaskAttemptContextImpl(hadoopConf, taskAttemptId) } @@ -189,7 +227,7 @@ object FileFormatWriter extends Logging { committer.setupTask(taskAttemptContext) val writeTask = - if (description.partitionColumns.isEmpty && description.bucketSpec.isEmpty) { + if (description.partitionColumns.isEmpty && description.bucketIdExpression.isEmpty) { new SingleDirectoryWriteTask(description, taskAttemptContext, committer) } else { new DynamicPartitionWriteTask(description, taskAttemptContext, committer) @@ -287,36 +325,20 @@ object FileFormatWriter extends Logging { * multiple directories (partitions) or files (bucketing). 
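The ordering requirement computed above (partition columns, then bucket id, then sort columns) comes into play for writes such as the following hedged sketch; `df` is an assumed DataFrame and the table and column names are illustrative:

```
// df is assumed to have columns dt, user_id and some payload columns.
df.write
  .partitionBy("dt")          // first level: partition directories
  .bucketBy(8, "user_id")     // second level: bucket id within each partition
  .sortBy("user_id")          // finally: sort order inside each bucket file
  .saveAsTable("events_bucketed")
```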
*/ private class DynamicPartitionWriteTask( - description: WriteJobDescription, + desc: WriteJobDescription, taskAttemptContext: TaskAttemptContext, committer: FileCommitProtocol) extends ExecuteWriteTask { // currentWriter is initialized whenever we see a new key private var currentWriter: OutputWriter = _ - private val bucketColumns: Seq[Attribute] = description.bucketSpec.toSeq.flatMap { - spec => spec.bucketColumnNames.map(c => description.allColumns.find(_.name == c).get) - } - - private val sortColumns: Seq[Attribute] = description.bucketSpec.toSeq.flatMap { - spec => spec.sortColumnNames.map(c => description.allColumns.find(_.name == c).get) - } - - private def bucketIdExpression: Option[Expression] = description.bucketSpec.map { spec => - // Use `HashPartitioning.partitionIdExpression` as our bucket id expression, so that we can - // guarantee the data distribution is same between shuffle and bucketed data source, which - // enables us to only shuffle one side when join a bucketed table and a normal one. - HashPartitioning(bucketColumns, spec.numBuckets).partitionIdExpression - } - - /** Expressions that given a partition key build a string like: col1=val/col2=val/... */ - private def partitionStringExpression: Seq[Expression] = { - description.partitionColumns.zipWithIndex.flatMap { case (c, i) => - // TODO: use correct timezone for partition values. + /** Expressions that given partition columns build a path string like: col1=val/col2=val/... */ + private def partitionPathExpression: Seq[Expression] = { + desc.partitionColumns.zipWithIndex.flatMap { case (c, i) => val escaped = ScalaUDF( ExternalCatalogUtils.escapePathName _, StringType, - Seq(Cast(c, StringType, Option(DateTimeUtils.defaultTimeZone().getID))), + Seq(Cast(c, StringType, Option(desc.timeZoneId))), Seq(StringType)) val str = If(IsNull(c), Literal(ExternalCatalogUtils.DEFAULT_PARTITION_NAME), escaped) val partitionName = Literal(c.name + "=") :: str :: Nil @@ -325,35 +347,46 @@ object FileFormatWriter extends Logging { } /** - * Open and returns a new OutputWriter given a partition key and optional bucket id. + * Opens a new OutputWriter given a partition key and optional bucket id. * If bucket id is specified, we will append it to the end of the file name, but before the * file extension, e.g. part-r-00009-ea518ad4-455a-4431-b471-d24e03814677-00002.gz.parquet * - * @param key vaues for fields consisting of partition keys for the current row - * @param partString a function that projects the partition values into a string + * @param partColsAndBucketId a row consisting of partition columns and a bucket id for the + * current row. + * @param getPartitionPath a function that projects the partition values into a path string. * @param fileCounter the number of files that have been written in the past for this specific * partition. This is used to limit the max number of records written for a * single file. The value should start from 0. + * @param updatedPartitions the set of updated partition paths, we should add the new partition + * path of this writer to it. 
*/ private def newOutputWriter( - key: InternalRow, partString: UnsafeProjection, fileCounter: Int): Unit = { - val partDir = - if (description.partitionColumns.isEmpty) None else Option(partString(key).getString(0)) + partColsAndBucketId: InternalRow, + getPartitionPath: UnsafeProjection, + fileCounter: Int, + updatedPartitions: mutable.Set[String]): Unit = { + val partDir = if (desc.partitionColumns.isEmpty) { + None + } else { + Option(getPartitionPath(partColsAndBucketId).getString(0)) + } + partDir.foreach(updatedPartitions.add) - // If the bucket spec is defined, the bucket column is right after the partition columns - val bucketId = if (description.bucketSpec.isDefined) { - BucketingUtils.bucketIdToString(key.getInt(description.partitionColumns.length)) + // If the bucketId expression is defined, the bucketId column is right after the partition + // columns. + val bucketId = if (desc.bucketIdExpression.isDefined) { + BucketingUtils.bucketIdToString(partColsAndBucketId.getInt(desc.partitionColumns.length)) } else { "" } // This must be in a form that matches our bucketing format. See BucketingUtils. val ext = f"$bucketId.c$fileCounter%03d" + - description.outputWriterFactory.getFileExtension(taskAttemptContext) + desc.outputWriterFactory.getFileExtension(taskAttemptContext) val customPath = partDir match { case Some(dir) => - description.customPartitionLocations.get(PartitioningUtils.parsePathFragment(dir)) + desc.customPartitionLocations.get(PartitioningUtils.parsePathFragment(dir)) case _ => None } @@ -363,80 +396,42 @@ object FileFormatWriter extends Logging { committer.newTaskTempFile(taskAttemptContext, partDir, ext) } - currentWriter = description.outputWriterFactory.newInstance( + currentWriter = desc.outputWriterFactory.newInstance( path = path, - dataSchema = description.dataColumns.toStructType, + dataSchema = desc.dataColumns.toStructType, context = taskAttemptContext) } override def execute(iter: Iterator[InternalRow]): Set[String] = { - // We should first sort by partition columns, then bucket id, and finally sorting columns. - val sortingExpressions: Seq[Expression] = - description.partitionColumns ++ bucketIdExpression ++ sortColumns - val getSortingKey = UnsafeProjection.create(sortingExpressions, description.allColumns) - - val sortingKeySchema = StructType(sortingExpressions.map { - case a: Attribute => StructField(a.name, a.dataType, a.nullable) - // The sorting expressions are all `Attribute` except bucket id. - case _ => StructField("bucketId", IntegerType, nullable = false) - }) - - // Returns the data columns to be written given an input row - val getOutputRow = UnsafeProjection.create( - description.dataColumns, description.allColumns) - - // Returns the partition path given a partition key. - val getPartitionStringFunc = UnsafeProjection.create( - Seq(Concat(partitionStringExpression)), description.partitionColumns) - - // Sorts the data before write, so that we only need one writer at the same time. 
- val sorter = new UnsafeKVExternalSorter( - sortingKeySchema, - StructType.fromAttributes(description.dataColumns), - SparkEnv.get.blockManager, - SparkEnv.get.serializerManager, - TaskContext.get().taskMemoryManager().pageSizeBytes, - SparkEnv.get.conf.getLong("spark.shuffle.spill.numElementsForceSpillThreshold", - UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD)) - - while (iter.hasNext) { - val currentRow = iter.next() - sorter.insertKV(getSortingKey(currentRow), getOutputRow(currentRow)) - } + val getPartitionColsAndBucketId = UnsafeProjection.create( + desc.partitionColumns ++ desc.bucketIdExpression, desc.allColumns) - val getBucketingKey: InternalRow => InternalRow = if (sortColumns.isEmpty) { - identity - } else { - UnsafeProjection.create(sortingExpressions.dropRight(sortColumns.length).zipWithIndex.map { - case (expr, ordinal) => BoundReference(ordinal, expr.dataType, expr.nullable) - }) - } + // Generates the partition path given the row generated by `getPartitionColsAndBucketId`. + val getPartPath = UnsafeProjection.create( + Seq(Concat(partitionPathExpression)), desc.partitionColumns) - val sortedIterator = sorter.sortedIterator() + // Returns the data columns to be written given an input row + val getOutputRow = UnsafeProjection.create(desc.dataColumns, desc.allColumns) // If anything below fails, we should abort the task. var recordsInFile: Long = 0L var fileCounter = 0 - var currentKey: UnsafeRow = null + var currentPartColsAndBucketId: UnsafeRow = null val updatedPartitions = mutable.Set[String]() - while (sortedIterator.next()) { - val nextKey = getBucketingKey(sortedIterator.getKey).asInstanceOf[UnsafeRow] - if (currentKey != nextKey) { - // See a new key - write to a new partition (new file). - currentKey = nextKey.copy() - logDebug(s"Writing partition: $currentKey") + for (row <- iter) { + val nextPartColsAndBucketId = getPartitionColsAndBucketId(row) + if (currentPartColsAndBucketId != nextPartColsAndBucketId) { + // See a new partition or bucket - write to a new partition dir (or a new bucket file). + currentPartColsAndBucketId = nextPartColsAndBucketId.copy() + logDebug(s"Writing partition: $currentPartColsAndBucketId") recordsInFile = 0 fileCounter = 0 releaseResources() - newOutputWriter(currentKey, getPartitionStringFunc, fileCounter) - val partitionPath = getPartitionStringFunc(currentKey).getString(0) - if (partitionPath.nonEmpty) { - updatedPartitions.add(partitionPath) - } - } else if (description.maxRecordsPerFile > 0 && - recordsInFile >= description.maxRecordsPerFile) { + newOutputWriter(currentPartColsAndBucketId, getPartPath, fileCounter, updatedPartitions) + } else if (desc.maxRecordsPerFile > 0 && + recordsInFile >= desc.maxRecordsPerFile) { // Exceeded the threshold in terms of the number of records per file. // Create a new file by increasing the file counter. 
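The rewritten `execute` above relies on its input already being sorted by (partition columns, bucket id), so it only has to roll to a new writer when the projected key changes, or when `maxRecordsPerFile` is exceeded. A generic sketch of that pattern, with hypothetical `open`/`write`/`close` callbacks standing in for the `OutputWriter` handling:

```scala
def writeSorted[K, R](
    rows: Iterator[R],
    key: R => K,
    maxRecordsPerFile: Long)(
    open: K => Unit, write: R => Unit, close: () => Unit): Unit = {
  var currentKey: Option[K] = None
  var recordsInFile = 0L
  for (row <- rows) {
    val k = key(row)
    if (!currentKey.contains(k)) {
      // New partition/bucket: close the previous writer and open a fresh one.
      if (currentKey.isDefined) close()
      open(k)
      currentKey = Some(k)
      recordsInFile = 0
    } else if (maxRecordsPerFile > 0 && recordsInFile >= maxRecordsPerFile) {
      // Same key, but the per-file record limit was reached: roll to a new file.
      close()
      open(k)
      recordsInFile = 0
    }
    write(row)
    recordsInFile += 1
  }
  if (currentKey.isDefined) close()
}
```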
recordsInFile = 0 @@ -445,10 +440,10 @@ object FileFormatWriter extends Logging { s"File counter $fileCounter is beyond max value $MAX_FILE_COUNTER") releaseResources() - newOutputWriter(currentKey, getPartitionStringFunc, fileCounter) + newOutputWriter(currentPartColsAndBucketId, getPartPath, fileCounter, updatedPartitions) } - currentWriter.write(sortedIterator.getValue) + currentWriter.write(getOutputRow(row)) recordsInFile += 1 } releaseResources() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala index 7531f0ae02e7..ee4d0863d977 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala @@ -66,8 +66,8 @@ class InMemoryFileIndex( } override def refresh(): Unit = { - refresh0() fileStatusCache.invalidateAll() + refresh0() } private def refresh0(): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala index 652bcc833193..19b51d4d9530 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala @@ -147,7 +147,10 @@ case class InsertIntoHadoopFsRelationCommand( refreshFunction = refreshPartitionsCallback, options = options) + // refresh cached files in FileIndex fileIndex.foreach(_.refresh()) + // refresh data cache if table is cached + sparkSession.catalog.refreshByPath(outputPath.toString) } else { logInfo("Skipping insertion into a relation that already exists.") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala index 75f87a5503b8..c8097a7fabc2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala @@ -30,7 +30,7 @@ import org.apache.spark.metrics.source.HiveCatalogMetrics import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.{expressions, InternalRow} import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} import org.apache.spark.sql.types.{StringType, StructType} import org.apache.spark.util.SerializableConfiguration @@ -125,22 +125,27 @@ abstract class PartitioningAwareFileIndex( val leafDirs = leafDirToChildrenFiles.filter { case (_, files) => files.exists(f => isDataPath(f.getPath)) }.keys.toSeq + + val caseInsensitiveOptions = CaseInsensitiveMap(parameters) + val timeZoneId = caseInsensitiveOptions.get("timeZone") + .getOrElse(sparkSession.sessionState.conf.sessionLocalTimeZone) + userPartitionSchema match { case Some(userProvidedSchema) if userProvidedSchema.nonEmpty => val spec = PartitioningUtils.parsePartitions( leafDirs, typeInference = false, - basePaths = basePaths) + basePaths = basePaths, + timeZoneId = timeZoneId) // Without auto inference, all of value in the 
`row` should be null or in StringType, // we need to cast into the data type that user specified. def castPartitionValuesToUserSchema(row: InternalRow) = { InternalRow((0 until row.numFields).map { i => - // TODO: use correct timezone for partition values. Cast( Literal.create(row.getUTF8String(i), StringType), userProvidedSchema.fields(i).dataType, - Option(DateTimeUtils.defaultTimeZone().getID)).eval() + Option(timeZoneId)).eval() }: _*) } @@ -151,7 +156,8 @@ abstract class PartitioningAwareFileIndex( PartitioningUtils.parsePartitions( leafDirs, typeInference = sparkSession.sessionState.conf.partitionColumnTypeInferenceEnabled, - basePaths = basePaths) + basePaths = basePaths, + timeZoneId = timeZoneId) } } @@ -300,7 +306,7 @@ object PartitioningAwareFileIndex extends Logging { sparkSession: SparkSession): Seq[(Path, Seq[FileStatus])] = { // Short-circuits parallel listing when serial listing is likely to be faster. - if (paths.size < sparkSession.sessionState.conf.parallelPartitionDiscoveryThreshold) { + if (paths.size <= sparkSession.sessionState.conf.parallelPartitionDiscoveryThreshold) { return paths.map { path => (path, listLeafFiles(path, hadoopConf, filter, Some(sparkSession))) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala index bad59961ace1..09876bbc2f85 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.datasources import java.lang.{Double => JDouble, Long => JLong} import java.math.{BigDecimal => JBigDecimal} -import java.sql.{Date => JDate, Timestamp => JTimestamp} +import java.util.TimeZone import scala.collection.mutable.ArrayBuffer import scala.util.Try @@ -31,7 +31,9 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.Resolver import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.{Cast, Literal} +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String // TODO: We should tighten up visibility of the classes here once we clean up Hive coupling. @@ -91,10 +93,19 @@ object PartitioningUtils { private[datasources] def parsePartitions( paths: Seq[Path], typeInference: Boolean, - basePaths: Set[Path]): PartitionSpec = { + basePaths: Set[Path], + timeZoneId: String): PartitionSpec = { + parsePartitions(paths, typeInference, basePaths, TimeZone.getTimeZone(timeZoneId)) + } + + private[datasources] def parsePartitions( + paths: Seq[Path], + typeInference: Boolean, + basePaths: Set[Path], + timeZone: TimeZone): PartitionSpec = { // First, we need to parse every partition's path and see if we can find partition values. 
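A standalone illustration (plain JDK formatting, standing in for the `DateTimeUtils` helpers) of why `parsePartitions` now threads a `TimeZone` through: the same partition directory value denotes different instants depending on the session time zone used to parse it.

```scala
import java.text.SimpleDateFormat
import java.util.TimeZone

object PartitionTimeZoneSketch {
  private def parseTs(raw: String, tz: String): Long = {
    val fmt = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss")
    fmt.setTimeZone(TimeZone.getTimeZone(tz))
    fmt.parse(raw).getTime
  }

  def main(args: Array[String]): Unit = {
    // e.g. unescaped from a directory such as ".../ts=2016-12-31 23%3A59%3A59/"
    val raw = "2016-12-31 23:59:59"
    val diff = parseTs(raw, "America/Los_Angeles") - parseTs(raw, "UTC")
    assert(diff == 8L * 3600 * 1000)  // the 8-hour PST offset, in milliseconds
  }
}
```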
val (partitionValues, optDiscoveredBasePaths) = paths.map { path => - parsePartition(path, typeInference, basePaths) + parsePartition(path, typeInference, basePaths, timeZone) }.unzip // We create pairs of (path -> path's partition value) here @@ -173,7 +184,8 @@ object PartitioningUtils { private[datasources] def parsePartition( path: Path, typeInference: Boolean, - basePaths: Set[Path]): (Option[PartitionValues], Option[Path]) = { + basePaths: Set[Path], + timeZone: TimeZone): (Option[PartitionValues], Option[Path]) = { val columns = ArrayBuffer.empty[(String, Literal)] // Old Hadoop versions don't have `Path.isRoot` var finished = path.getParent == null @@ -194,7 +206,7 @@ object PartitioningUtils { // Let's say currentPath is a path of "/table/a=1/", currentPath.getName will give us a=1. // Once we get the string, we try to parse it and find the partition column and value. val maybeColumn = - parsePartitionColumn(currentPath.getName, typeInference) + parsePartitionColumn(currentPath.getName, typeInference, timeZone) maybeColumn.foreach(columns += _) // Now, we determine if we should stop. @@ -226,7 +238,8 @@ object PartitioningUtils { private def parsePartitionColumn( columnSpec: String, - typeInference: Boolean): Option[(String, Literal)] = { + typeInference: Boolean, + timeZone: TimeZone): Option[(String, Literal)] = { val equalSignIndex = columnSpec.indexOf('=') if (equalSignIndex == -1) { None @@ -237,7 +250,7 @@ object PartitioningUtils { val rawColumnValue = columnSpec.drop(equalSignIndex + 1) assert(rawColumnValue.nonEmpty, s"Empty partition column value in '$columnSpec'") - val literal = inferPartitionColumnValue(rawColumnValue, typeInference) + val literal = inferPartitionColumnValue(rawColumnValue, typeInference, timeZone) Some(columnName -> literal) } } @@ -370,7 +383,8 @@ object PartitioningUtils { */ private[datasources] def inferPartitionColumnValue( raw: String, - typeInference: Boolean): Literal = { + typeInference: Boolean, + timeZone: TimeZone): Literal = { val decimalTry = Try { // `BigDecimal` conversion can fail when the `field` is not a form of number. val bigDecimal = new JBigDecimal(raw) @@ -390,8 +404,16 @@ object PartitioningUtils { // Then falls back to fractional types .orElse(Try(Literal.create(JDouble.parseDouble(raw), DoubleType))) // Then falls back to date/timestamp types - .orElse(Try(Literal(JDate.valueOf(raw)))) - .orElse(Try(Literal(JTimestamp.valueOf(unescapePathName(raw))))) + .orElse(Try( + Literal.create( + DateTimeUtils.getThreadLocalTimestampFormat(timeZone) + .parse(unescapePathName(raw)).getTime * 1000L, + TimestampType))) + .orElse(Try( + Literal.create( + DateTimeUtils.millisToDays( + DateTimeUtils.getThreadLocalDateFormat.parse(raw).getTime), + DateType))) // Then falls back to string .getOrElse { if (raw == DEFAULT_PARTITION_NAME) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala new file mode 100644 index 000000000000..73e6abc6dad3 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala @@ -0,0 +1,239 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.csv + +import java.io.InputStream +import java.nio.charset.{Charset, StandardCharsets} + +import com.univocity.parsers.csv.{CsvParser, CsvParserSettings} +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileStatus, Path} +import org.apache.hadoop.io.{LongWritable, Text} +import org.apache.hadoop.mapred.TextInputFormat +import org.apache.hadoop.mapreduce.Job +import org.apache.hadoop.mapreduce.lib.input.FileInputFormat + +import org.apache.spark.TaskContext +import org.apache.spark.input.{PortableDataStream, StreamInputFormat} +import org.apache.spark.rdd.{BinaryFileRDD, RDD} +import org.apache.spark.sql.{Dataset, Encoders, SparkSession} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.execution.datasources._ +import org.apache.spark.sql.execution.datasources.text.TextFileFormat +import org.apache.spark.sql.types.StructType + +/** + * Common functions for parsing CSV files + */ +abstract class CSVDataSource extends Serializable { + def isSplitable: Boolean + + /** + * Parse a [[PartitionedFile]] into [[InternalRow]] instances. + */ + def readFile( + conf: Configuration, + file: PartitionedFile, + parser: UnivocityParser, + parsedOptions: CSVOptions): Iterator[InternalRow] + + /** + * Infers the schema from `inputPaths` files. + */ + def infer( + sparkSession: SparkSession, + inputPaths: Seq[FileStatus], + parsedOptions: CSVOptions): Option[StructType] + + /** + * Generates a header from the given row which is null-safe and duplicate-safe. + */ + protected def makeSafeHeader( + row: Array[String], + caseSensitive: Boolean, + options: CSVOptions): Array[String] = { + if (options.headerFlag) { + val duplicates = { + val headerNames = row.filter(_ != null) + .map(name => if (caseSensitive) name else name.toLowerCase) + headerNames.diff(headerNames.distinct).distinct + } + + row.zipWithIndex.map { case (value, index) => + if (value == null || value.isEmpty || value == options.nullValue) { + // When there are empty strings or the values set in `nullValue`, put the + // index as the suffix. + s"_c$index" + } else if (!caseSensitive && duplicates.contains(value.toLowerCase)) { + // When there are case-insensitive duplicates, put the index as the suffix. + s"$value$index" + } else if (duplicates.contains(value)) { + // When there are duplicates, put the index as the suffix. + s"$value$index" + } else { + value + } + } + } else { + row.zipWithIndex.map { case (_, index) => + // Uses default column names, "_c#" where # is its position of fields + // when header option is disabled. 
+ s"_c$index" + } + } + } +} + +object CSVDataSource { + def apply(options: CSVOptions): CSVDataSource = { + if (options.wholeFile) { + WholeFileCSVDataSource + } else { + TextInputCSVDataSource + } + } +} + +object TextInputCSVDataSource extends CSVDataSource { + override val isSplitable: Boolean = true + + override def readFile( + conf: Configuration, + file: PartitionedFile, + parser: UnivocityParser, + parsedOptions: CSVOptions): Iterator[InternalRow] = { + val lines = { + val linesReader = new HadoopFileLinesReader(file, conf) + Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => linesReader.close())) + linesReader.map { line => + new String(line.getBytes, 0, line.getLength, parsedOptions.charset) + } + } + + val shouldDropHeader = parsedOptions.headerFlag && file.start == 0 + UnivocityParser.parseIterator(lines, shouldDropHeader, parser) + } + + override def infer( + sparkSession: SparkSession, + inputPaths: Seq[FileStatus], + parsedOptions: CSVOptions): Option[StructType] = { + val csv: Dataset[String] = createBaseDataset(sparkSession, inputPaths, parsedOptions) + val firstLine: String = CSVUtils.filterCommentAndEmpty(csv, parsedOptions).first() + val firstRow = new CsvParser(parsedOptions.asParserSettings).parseLine(firstLine) + val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis + val header = makeSafeHeader(firstRow, caseSensitive, parsedOptions) + val tokenRDD = csv.rdd.mapPartitions { iter => + val filteredLines = CSVUtils.filterCommentAndEmpty(iter, parsedOptions) + val linesWithoutHeader = + CSVUtils.filterHeaderLine(filteredLines, firstLine, parsedOptions) + val parser = new CsvParser(parsedOptions.asParserSettings) + linesWithoutHeader.map(parser.parseLine) + } + + Some(CSVInferSchema.infer(tokenRDD, header, parsedOptions)) + } + + private def createBaseDataset( + sparkSession: SparkSession, + inputPaths: Seq[FileStatus], + options: CSVOptions): Dataset[String] = { + val paths = inputPaths.map(_.getPath.toString) + if (Charset.forName(options.charset) == StandardCharsets.UTF_8) { + sparkSession.baseRelationToDataFrame( + DataSource.apply( + sparkSession, + paths = paths, + className = classOf[TextFileFormat].getName + ).resolveRelation(checkFilesExist = false)) + .select("value").as[String](Encoders.STRING) + } else { + val charset = options.charset + val rdd = sparkSession.sparkContext + .hadoopFile[LongWritable, Text, TextInputFormat](paths.mkString(",")) + .mapPartitions(_.map(pair => new String(pair._2.getBytes, 0, pair._2.getLength, charset))) + sparkSession.createDataset(rdd)(Encoders.STRING) + } + } +} + +object WholeFileCSVDataSource extends CSVDataSource { + override val isSplitable: Boolean = false + + override def readFile( + conf: Configuration, + file: PartitionedFile, + parser: UnivocityParser, + parsedOptions: CSVOptions): Iterator[InternalRow] = { + UnivocityParser.parseStream( + CodecStreams.createInputStreamWithCloseResource(conf, file.filePath), + parsedOptions.headerFlag, + parser) + } + + override def infer( + sparkSession: SparkSession, + inputPaths: Seq[FileStatus], + parsedOptions: CSVOptions): Option[StructType] = { + val csv: RDD[PortableDataStream] = createBaseRdd(sparkSession, inputPaths, parsedOptions) + val maybeFirstRow: Option[Array[String]] = csv.flatMap { lines => + UnivocityParser.tokenizeStream( + CodecStreams.createInputStreamWithCloseResource(lines.getConfiguration, lines.getPath()), + false, + new CsvParser(parsedOptions.asParserSettings)) + }.take(1).headOption + + if (maybeFirstRow.isDefined) { + 
val firstRow = maybeFirstRow.get + val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis + val header = makeSafeHeader(firstRow, caseSensitive, parsedOptions) + val tokenRDD = csv.flatMap { lines => + UnivocityParser.tokenizeStream( + CodecStreams.createInputStreamWithCloseResource(lines.getConfiguration, lines.getPath()), + parsedOptions.headerFlag, + new CsvParser(parsedOptions.asParserSettings)) + } + Some(CSVInferSchema.infer(tokenRDD, header, parsedOptions)) + } else { + // If the first row could not be read, just return the empty schema. + Some(StructType(Nil)) + } + } + + private def createBaseRdd( + sparkSession: SparkSession, + inputPaths: Seq[FileStatus], + options: CSVOptions): RDD[PortableDataStream] = { + val paths = inputPaths.map(_.getPath) + val name = paths.mkString(",") + val job = Job.getInstance(sparkSession.sessionState.newHadoopConf()) + FileInputFormat.setInputPaths(job, paths: _*) + val conf = job.getConfiguration + + val rdd = new BinaryFileRDD( + sparkSession.sparkContext, + classOf[StreamInputFormat], + classOf[String], + classOf[PortableDataStream], + conf, + sparkSession.sparkContext.defaultMinPartitions) + + // Only returns `PortableDataStream`s without paths. + rdd.setName(s"CSVFile: $name").values + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala index f0075cf4e7e4..c912b295bdc1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala @@ -17,21 +17,15 @@ package org.apache.spark.sql.execution.datasources.csv -import java.nio.charset.{Charset, StandardCharsets} - import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, Path} -import org.apache.hadoop.io.{LongWritable, Text} -import org.apache.hadoop.mapred.TextInputFormat import org.apache.hadoop.mapreduce._ -import org.apache.spark.TaskContext import org.apache.spark.internal.Logging -import org.apache.spark.sql.{Dataset, Encoders, SparkSession} +import org.apache.spark.sql.{AnalysisException, SparkSession} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CompressionCodecs} +import org.apache.spark.sql.catalyst.util.CompressionCodecs import org.apache.spark.sql.execution.datasources._ -import org.apache.spark.sql.execution.datasources.text.TextFileFormat import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ import org.apache.spark.util.SerializableConfiguration @@ -43,22 +37,26 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister { override def shortName(): String = "csv" - override def toString: String = "CSV" - - override def hashCode(): Int = getClass.hashCode() - - override def equals(other: Any): Boolean = other.isInstanceOf[CSVFileFormat] + override def isSplitable( + sparkSession: SparkSession, + options: Map[String, String], + path: Path): Boolean = { + val parsedOptions = + new CSVOptions(options, sparkSession.sessionState.conf.sessionLocalTimeZone) + val csvDataSource = CSVDataSource(parsedOptions) + csvDataSource.isSplitable && super.isSplitable(sparkSession, options, path) + } override def inferSchema( sparkSession: SparkSession, options: Map[String, String], files: Seq[FileStatus]): Option[StructType] = { require(files.nonEmpty, "Cannot infer schema 
from an empty set of files") + + val parsedOptions = + new CSVOptions(options, sparkSession.sessionState.conf.sessionLocalTimeZone) - val csvOptions = new CSVOptions(options, sparkSession.sessionState.conf.sessionLocalTimeZone) - val lines: Dataset[String] = createBaseDataset(sparkSession, csvOptions, files) - val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis - Some(CSVInferSchema.infer(lines, caseSensitive, csvOptions)) + CSVDataSource(parsedOptions).infer(sparkSession, files, parsedOptions) } override def prepareWrite( @@ -95,58 +93,35 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister { filters: Seq[Filter], options: Map[String, String], hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = { - val csvOptions = new CSVOptions(options, sparkSession.sessionState.conf.sessionLocalTimeZone) - + CSVUtils.verifySchema(dataSchema) val broadcastedHadoopConf = sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) - (file: PartitionedFile) => { - val lines = { - val conf = broadcastedHadoopConf.value.value - val linesReader = new HadoopFileLinesReader(file, conf) - Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => linesReader.close())) - linesReader.map { line => - new String(line.getBytes, 0, line.getLength, csvOptions.charset) - } + val parsedOptions = new CSVOptions( + options, + sparkSession.sessionState.conf.sessionLocalTimeZone, + sparkSession.sessionState.conf.columnNameOfCorruptRecord) + + // Check a field requirement for corrupt records here to throw an exception in a driver side + dataSchema.getFieldIndex(parsedOptions.columnNameOfCorruptRecord).foreach { corruptFieldIndex => + val f = dataSchema(corruptFieldIndex) + if (f.dataType != StringType || !f.nullable) { + throw new AnalysisException( + "The field for corrupt records must be string type and nullable") } - - val linesWithoutHeader = if (csvOptions.headerFlag && file.start == 0) { - // Note that if there are only comments in the first block, the header would probably - // be not dropped. - CSVUtils.dropHeaderLine(lines, csvOptions) - } else { - lines - } - - val filteredLines = CSVUtils.filterCommentAndEmpty(linesWithoutHeader, csvOptions) - val parser = new UnivocityParser(dataSchema, requiredSchema, csvOptions) - filteredLines.flatMap(parser.parse) } - } - - private def createBaseDataset( - sparkSession: SparkSession, - options: CSVOptions, - inputPaths: Seq[FileStatus]): Dataset[String] = { - val pathValues = inputPaths.map(_.getPath().toString) - if (Charset.forName(options.charset) == StandardCharsets.UTF_8) { - // Fix for SPARK-19340. 
resolveRelation replaces with createHadoopRelation - // to avoid pattern resolution for already resolved file path - sparkSession.baseRelationToDataFrame( - DataSource.apply( - sparkSession, - paths = pathValues, - className = classOf[TextFileFormat].getName - ).createHadoopRelation(new TextFileFormat, inputPaths.map(_.getPath).toArray)) - .select("value").as[String](Encoders.STRING) - } else { - val charset = options.charset - val rdd = sparkSession.sparkContext - .hadoopFile[LongWritable, Text, TextInputFormat](pathValues.mkString(",")) - .mapPartitions(_.map(pair => new String(pair._2.getBytes, 0, pair._2.getLength, charset))) - sparkSession.createDataset(rdd)(Encoders.STRING) + (file: PartitionedFile) => { + val conf = broadcastedHadoopConf.value.value + val parser = new UnivocityParser(dataSchema, requiredSchema, parsedOptions) + CSVDataSource(parsedOptions).readFile(conf, file, parser, parsedOptions) } } + + override def toString: String = "CSV" + + override def hashCode(): Int = getClass.hashCode() + + override def equals(other: Any): Boolean = other.isInstanceOf[CSVFileFormat] } private[csv] class CsvOutputWriter( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala index 3fa30fe2401e..b64d71bb4eef 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala @@ -21,11 +21,9 @@ import java.math.BigDecimal import scala.util.control.Exception._ -import com.univocity.parsers.csv.CsvParser - +import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.analysis.TypeCoercion import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.Dataset import org.apache.spark.sql.types._ private[csv] object CSVInferSchema { @@ -37,24 +35,13 @@ private[csv] object CSVInferSchema { * 3. Replace any null types with string type */ def infer( - csv: Dataset[String], - caseSensitive: Boolean, + tokenRDD: RDD[Array[String]], + header: Array[String], options: CSVOptions): StructType = { - val firstLine: String = CSVUtils.filterCommentAndEmpty(csv, options).first() - val firstRow = new CsvParser(options.asParserSettings).parseLine(firstLine) - val header = makeSafeHeader(firstRow, caseSensitive, options) - val fields = if (options.inferSchemaFlag) { - val tokenRdd = csv.rdd.mapPartitions { iter => - val filteredLines = CSVUtils.filterCommentAndEmpty(iter, options) - val linesWithoutHeader = CSVUtils.filterHeaderLine(filteredLines, firstLine, options) - val parser = new CsvParser(options.asParserSettings) - linesWithoutHeader.map(parser.parseLine) - } - val startType: Array[DataType] = Array.fill[DataType](header.length)(NullType) val rootTypes: Array[DataType] = - tokenRdd.aggregate(startType)(inferRowType(options), mergeRowTypes) + tokenRDD.aggregate(startType)(inferRowType(options), mergeRowTypes) header.zip(rootTypes).map { case (thisHeader, rootType) => val dType = rootType match { @@ -71,44 +58,6 @@ private[csv] object CSVInferSchema { StructType(fields) } - /** - * Generates a header from the given row which is null-safe and duplicate-safe. 
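A hedged, user-facing sketch of what the reader path above now supports: a corrupt-record column (which, per the driver-side check, must be a nullable string field) and the `wholeFile` option that routes parsing through `WholeFileCSVDataSource`. The session, schema and path are illustrative only.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types._

val spark = SparkSession.builder().appName("csv-options-sketch").getOrCreate()

// The corrupt-record field must be string type and nullable, or an AnalysisException
// is thrown on the driver before any task runs.
val schema = new StructType()
  .add("a", IntegerType)
  .add("b", IntegerType)
  .add("_corrupt_record", StringType, nullable = true)

val df = spark.read
  .schema(schema)
  .option("header", "true")
  .option("mode", "PERMISSIVE")                            // keep malformed lines
  .option("columnNameOfCorruptRecord", "_corrupt_record")  // where the raw line is stored
  .option("wholeFile", "true")                             // records may span multiple lines
  .csv("/path/to/data.csv")                                // hypothetical path
```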
- */ - private def makeSafeHeader( - row: Array[String], - caseSensitive: Boolean, - options: CSVOptions): Array[String] = { - if (options.headerFlag) { - val duplicates = { - val headerNames = row.filter(_ != null) - .map(name => if (caseSensitive) name else name.toLowerCase) - headerNames.diff(headerNames.distinct).distinct - } - - row.zipWithIndex.map { case (value, index) => - if (value == null || value.isEmpty || value == options.nullValue) { - // When there are empty strings or the values set in `nullValue`, put the - // index as the suffix. - s"_c$index" - } else if (!caseSensitive && duplicates.contains(value.toLowerCase)) { - // When there are case-insensitive duplicates, put the index as the suffix. - s"$value$index" - } else if (duplicates.contains(value)) { - // When there are duplicates, put the index as the suffix. - s"$value$index" - } else { - value - } - } - } else { - row.zipWithIndex.map { case (_, index) => - // Uses default column names, "_c#" where # is its position of fields - // when header option is disabled. - s"_c$index" - } - } - } - private def inferRowType(options: CSVOptions) (rowSoFar: Array[DataType], next: Array[String]): Array[DataType] = { var i = 0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala index b7fbaa4f44a6..50503385ad6d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala @@ -27,11 +27,20 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CompressionCodecs, ParseModes} private[csv] class CSVOptions( - @transient private val parameters: CaseInsensitiveMap[String], defaultTimeZoneId: String) + @transient private val parameters: CaseInsensitiveMap[String], + defaultTimeZoneId: String, + defaultColumnNameOfCorruptRecord: String) extends Logging with Serializable { - def this(parameters: Map[String, String], defaultTimeZoneId: String) = - this(CaseInsensitiveMap(parameters), defaultTimeZoneId) + def this( + parameters: Map[String, String], + defaultTimeZoneId: String, + defaultColumnNameOfCorruptRecord: String = "") = { + this( + CaseInsensitiveMap(parameters), + defaultTimeZoneId, + defaultColumnNameOfCorruptRecord) + } private def getChar(paramName: String, default: Char): Char = { val paramValue = parameters.get(paramName) @@ -95,6 +104,9 @@ private[csv] class CSVOptions( val dropMalformed = ParseModes.isDropMalformedMode(parseMode) val permissive = ParseModes.isPermissiveMode(parseMode) + val columnNameOfCorruptRecord = + parameters.getOrElse("columnNameOfCorruptRecord", defaultColumnNameOfCorruptRecord) + val nullValue = parameters.getOrElse("nullValue", "") val nanValue = parameters.getOrElse("nanValue", "NaN") @@ -118,6 +130,8 @@ private[csv] class CSVOptions( FastDateFormat.getInstance( parameters.getOrElse("timestampFormat", "yyyy-MM-dd'T'HH:mm:ss.SSSZZ"), timeZone, Locale.US) + val wholeFile = parameters.get("wholeFile").map(_.toBoolean).getOrElse(false) + val maxColumns = getInt("maxColumns", 20480) val maxCharsPerColumn = getInt("maxCharsPerColumn", -1) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala index 2e409b3f5fbf..3b3b87e4354d 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution.datasources.csv +import java.io.InputStream import java.math.BigDecimal import java.text.NumberFormat import java.util.Locale @@ -36,7 +37,7 @@ import org.apache.spark.unsafe.types.UTF8String private[csv] class UnivocityParser( schema: StructType, requiredSchema: StructType, - options: CSVOptions) extends Logging { + private val options: CSVOptions) extends Logging { require(requiredSchema.toSet.subsetOf(schema.toSet), "requiredSchema should be the subset of schema.") @@ -45,24 +46,85 @@ private[csv] class UnivocityParser( // A `ValueConverter` is responsible for converting the given value to a desired type. private type ValueConverter = String => Any - private val valueConverters = - schema.map(f => makeConverter(f.name, f.dataType, f.nullable, options)).toArray + private val corruptFieldIndex = schema.getFieldIndex(options.columnNameOfCorruptRecord) + corruptFieldIndex.foreach { corrFieldIndex => + require(schema(corrFieldIndex).dataType == StringType) + require(schema(corrFieldIndex).nullable) + } + + private val dataSchema = StructType(schema.filter(_.name != options.columnNameOfCorruptRecord)) - private val parser = new CsvParser(options.asParserSettings) + private val tokenizer = new CsvParser(options.asParserSettings) private var numMalformedRecords = 0 private val row = new GenericInternalRow(requiredSchema.length) - private val indexArr: Array[Int] = { - val fields = if (options.dropMalformed) { - // If `dropMalformed` is enabled, then it needs to parse all the values - // so that we can decide which row is malformed. - requiredSchema ++ schema.filterNot(requiredSchema.contains(_)) - } else { - requiredSchema - } - fields.map(schema.indexOf(_: StructField)).toArray + // In `PERMISSIVE` parse mode, we should be able to put the raw malformed row into the field + // specified in `columnNameOfCorruptRecord`. The raw input is retrieved by this method. + private def getCurrentInput(): String = tokenizer.getContext.currentParsedContent().stripLineEnd + + // This parser loads an `tokenIndexArr`-th position value in input tokens, + // then put the value in `row(rowIndexArr)`. + // + // For example, let's say there is CSV data as below: + // + // a,b,c + // 1,2,A + // + // Also, let's say `columnNameOfCorruptRecord` is set to "_unparsed", `header` is `true` + // by user and the user selects "c", "b", "_unparsed" and "a" fields. In this case, we need + // to map those values below: + // + // required schema - ["c", "b", "_unparsed", "a"] + // CSV data schema - ["a", "b", "c"] + // required CSV data schema - ["c", "b", "a"] + // + // with the input tokens, + // + // input tokens - [1, 2, "A"] + // + // Each input token is placed in each output row's position by mapping these. In this case, + // + // output row - ["A", 2, null, 1] + // + // In more details, + // - `valueConverters`, input tokens - CSV data schema + // `valueConverters` keeps the positions of input token indices (by its index) to each + // value's converter (by its value) in an order of CSV data schema. In this case, + // [string->int, string->int, string->string]. 
+ // + // - `tokenIndexArr`, input tokens - required CSV data schema + // `tokenIndexArr` keeps the positions of input token indices (by its index) to reordered + // fields given the required CSV data schema (by its value). In this case, [2, 1, 0]. + // + // - `rowIndexArr`, input tokens - required schema + // `rowIndexArr` keeps the positions of input token indices (by its index) to reordered + // field indices given the required schema (by its value). In this case, [0, 1, 3]. + private val valueConverters: Array[ValueConverter] = + dataSchema.map(f => makeConverter(f.name, f.dataType, f.nullable, options)).toArray + + // Only used to create both `tokenIndexArr` and `rowIndexArr`. This variable means + // the fields that we should try to convert. + private val reorderedFields = if (options.dropMalformed) { + // If `dropMalformed` is enabled, then it needs to parse all the values + // so that we can decide which row is malformed. + requiredSchema ++ schema.filterNot(requiredSchema.contains(_)) + } else { + requiredSchema + } + + private val tokenIndexArr: Array[Int] = { + reorderedFields + .filter(_.name != options.columnNameOfCorruptRecord) + .map(f => dataSchema.indexOf(f)).toArray + } + + private val rowIndexArr: Array[Int] = if (corruptFieldIndex.isDefined) { + val corrFieldIndex = corruptFieldIndex.get + reorderedFields.indices.filter(_ != corrFieldIndex).toArray + } else { + reorderedFields.indices.toArray } /** @@ -148,6 +210,7 @@ private[csv] class UnivocityParser( case udt: UserDefinedType[_] => (datum: String) => makeConverter(name, udt.sqlType, nullable, options) + // We don't actually hit this exception though, we keep it for understandability case _ => throw new RuntimeException(s"Unsupported type: ${dataType.typeName}") } @@ -167,21 +230,23 @@ private[csv] class UnivocityParser( } /** - * Parses a single CSV record (in the form of an array of strings in which - * each element represents a column) and turns it into either one resulting row or no row (if the + * Parses a single CSV string and turns it into either one resulting row or no row (if the * the record is malformed). */ - def parse(input: String): Option[InternalRow] = { - convertWithParseMode(parser.parseLine(input)) { tokens => + def parse(input: String): Option[InternalRow] = convert(tokenizer.parseLine(input)) + + private def convert(tokens: Array[String]): Option[InternalRow] = { + convertWithParseMode(tokens) { tokens => var i: Int = 0 - while (i < indexArr.length) { - val pos = indexArr(i) + while (i < tokenIndexArr.length) { // It anyway needs to try to parse since it decides if this row is malformed // or not after trying to cast in `DROPMALFORMED` mode even if the casted // value is not stored in the row. 
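A plain-Scala walk-through of the index bookkeeping described in the comment above, using its own example (CSV schema [a, b, c], required schema [c, b, _unparsed, a]); the real parser additionally runs each token through `valueConverters` before storing it in the row.

```scala
object TokenIndexSketch {
  def main(args: Array[String]): Unit = {
    val csvSchema      = Seq("a", "b", "c")
    val requiredSchema = Seq("c", "b", "_unparsed", "a")
    val corruptCol     = "_unparsed"

    // dropMalformed is assumed off, so reorderedFields == requiredSchema here.
    val tokenIndexArr = requiredSchema.filter(_ != corruptCol).map(f => csvSchema.indexOf(f))
    val rowIndexArr   = requiredSchema.indices.filter(i => requiredSchema(i) != corruptCol)
    assert(tokenIndexArr == Seq(2, 1, 0))
    assert(rowIndexArr == Seq(0, 1, 3))

    val tokens = Array("1", "2", "A")
    val row = new Array[Any](requiredSchema.length)
    for (i <- tokenIndexArr.indices) {
      row(rowIndexArr(i)) = tokens(tokenIndexArr(i))
    }
    // row is now ["A", "2", null, "1"]; the converters would turn "2" and "1" into ints.
    println(row.mkString("[", ", ", "]"))
  }
}
```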
- val value = valueConverters(pos).apply(tokens(pos)) + val from = tokenIndexArr(i) + val to = rowIndexArr(i) + val value = valueConverters(from).apply(tokens(from)) if (i < requiredSchema.length) { - row(i) = value + row(to) = value } i += 1 } @@ -191,7 +256,7 @@ private[csv] class UnivocityParser( private def convertWithParseMode( tokens: Array[String])(convert: Array[String] => InternalRow): Option[InternalRow] = { - if (options.dropMalformed && schema.length != tokens.length) { + if (options.dropMalformed && dataSchema.length != tokens.length) { if (numMalformedRecords < options.maxMalformedLogPerPartition) { logWarning(s"Dropping malformed line: ${tokens.mkString(options.delimiter.toString)}") } @@ -202,14 +267,24 @@ private[csv] class UnivocityParser( } numMalformedRecords += 1 None - } else if (options.failFast && schema.length != tokens.length) { + } else if (options.failFast && dataSchema.length != tokens.length) { throw new RuntimeException(s"Malformed line in FAILFAST mode: " + s"${tokens.mkString(options.delimiter.toString)}") } else { - val checkedTokens = if (options.permissive && schema.length > tokens.length) { - tokens ++ new Array[String](schema.length - tokens.length) - } else if (options.permissive && schema.length < tokens.length) { - tokens.take(schema.length) + // If a length of parsed tokens is not equal to expected one, it makes the length the same + // with the expected. If the length is shorter, it adds extra tokens in the tail. + // If longer, it drops extra tokens. + // + // TODO: Revisit this; if a length of tokens does not match an expected length in the schema, + // we probably need to treat it as a malformed record. + // See an URL below for related discussions: + // https://github.com/apache/spark/pull/16928#discussion_r102657214 + val checkedTokens = if (options.permissive && dataSchema.length != tokens.length) { + if (dataSchema.length > tokens.length) { + tokens ++ new Array[String](dataSchema.length - tokens.length) + } else { + tokens.take(dataSchema.length) + } } else { tokens } @@ -217,6 +292,10 @@ private[csv] class UnivocityParser( try { Some(convert(checkedTokens)) } catch { + case NonFatal(e) if options.permissive => + val row = new GenericInternalRow(requiredSchema.length) + corruptFieldIndex.foreach(row(_) = UTF8String.fromString(getCurrentInput())) + Some(row) case NonFatal(e) if options.dropMalformed => if (numMalformedRecords < options.maxMalformedLogPerPartition) { logWarning("Parse exception. " + @@ -233,3 +312,75 @@ private[csv] class UnivocityParser( } } } + +private[csv] object UnivocityParser { + + /** + * Parses a stream that contains CSV strings and turns it into an iterator of tokens. + */ + def tokenizeStream( + inputStream: InputStream, + shouldDropHeader: Boolean, + tokenizer: CsvParser): Iterator[Array[String]] = { + convertStream(inputStream, shouldDropHeader, tokenizer)(tokens => tokens) + } + + /** + * Parses a stream that contains CSV strings and turns it into an iterator of rows. 
+ */ + def parseStream( + inputStream: InputStream, + shouldDropHeader: Boolean, + parser: UnivocityParser): Iterator[InternalRow] = { + val tokenizer = parser.tokenizer + convertStream(inputStream, shouldDropHeader, tokenizer) { tokens => + parser.convert(tokens) + }.flatten + } + + private def convertStream[T]( + inputStream: InputStream, + shouldDropHeader: Boolean, + tokenizer: CsvParser)(convert: Array[String] => T) = new Iterator[T] { + tokenizer.beginParsing(inputStream) + private var nextRecord = { + if (shouldDropHeader) { + tokenizer.parseNext() + } + tokenizer.parseNext() + } + + override def hasNext: Boolean = nextRecord != null + + override def next(): T = { + if (!hasNext) { + throw new NoSuchElementException("End of stream") + } + val curRecord = convert(nextRecord) + nextRecord = tokenizer.parseNext() + curRecord + } + } + + /** + * Parses an iterator that contains CSV strings and turns it into an iterator of rows. + */ + def parseIterator( + lines: Iterator[String], + shouldDropHeader: Boolean, + parser: UnivocityParser): Iterator[InternalRow] = { + val options = parser.options + + val linesWithoutHeader = if (shouldDropHeader) { + // Note that if there are only comments in the first block, the header would probably + // be not dropped. + CSVUtils.dropHeaderLine(lines, options) + } else { + lines + } + + val filteredLines: Iterator[String] = + CSVUtils.filterCommentAndEmpty(linesWithoutHeader, options) + filteredLines.flatMap(line => parser.parse(line)) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala index 3e984effcb8d..18843bfc307b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala @@ -17,14 +17,12 @@ package org.apache.spark.sql.execution.datasources.json -import java.io.InputStream - import scala.reflect.ClassTag import com.fasterxml.jackson.core.{JsonFactory, JsonParser} import com.google.common.io.ByteStreams import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FileStatus, Path} +import org.apache.hadoop.fs.FileStatus import org.apache.hadoop.io.{LongWritable, Text} import org.apache.hadoop.mapreduce.Job import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat, TextInputFormat} @@ -186,16 +184,10 @@ object WholeFileJsonDataSource extends JsonDataSource[PortableDataStream] { } } - private def createInputStream(config: Configuration, path: String): InputStream = { - val inputStream = CodecStreams.createInputStream(config, new Path(path)) - Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => inputStream.close())) - inputStream - } - override def createParser(jsonFactory: JsonFactory, record: PortableDataStream): JsonParser = { CreateJacksonParser.inputStream( jsonFactory, - createInputStream(record.getConfiguration, record.getPath())) + CodecStreams.createInputStreamWithCloseResource(record.getConfiguration, record.getPath())) } override def readFile( @@ -203,13 +195,15 @@ object WholeFileJsonDataSource extends JsonDataSource[PortableDataStream] { file: PartitionedFile, parser: JacksonParser): Iterator[InternalRow] = { def partitionedFileString(ignored: Any): UTF8String = { - Utils.tryWithResource(createInputStream(conf, file.filePath)) { inputStream => + Utils.tryWithResource { + 
CodecStreams.createInputStreamWithCloseResource(conf, file.filePath) + } { inputStream => UTF8String.fromBytes(ByteStreams.toByteArray(inputStream)) } } parser.parse( - createInputStream(conf, file.filePath), + CodecStreams.createInputStreamWithCloseResource(conf, file.filePath), CreateJacksonParser.inputStream, partitionedFileString).toIterator } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala index 2cbf4ea7beac..902fee5a7e3f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala @@ -22,13 +22,13 @@ import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext} import org.apache.spark.internal.Logging -import org.apache.spark.sql.{Row, SparkSession} +import org.apache.spark.sql.{AnalysisException, SparkSession} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.json.{JacksonGenerator, JacksonParser, JSONOptions} import org.apache.spark.sql.catalyst.util.CompressionCodecs import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.sources._ -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{StringType, StructType} import org.apache.spark.util.SerializableConfiguration class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister { @@ -102,6 +102,15 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister { sparkSession.sessionState.conf.sessionLocalTimeZone, sparkSession.sessionState.conf.columnNameOfCorruptRecord) + // Check a field requirement for corrupt records here to throw an exception in a driver side + dataSchema.getFieldIndex(parsedOptions.columnNameOfCorruptRecord).foreach { corruptFieldIndex => + val f = dataSchema(corruptFieldIndex) + if (f.dataType != StringType || !f.nullable) { + throw new AnalysisException( + "The field for corrupt records must be string type and nullable") + } + } + (file: PartitionedFile) => { val parser = new JacksonParser(requiredSchema, parsedOptions) JsonDataSource(parsedOptions).readFile( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala index 1c3e7c6d5223..4d781b96abac 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala @@ -52,8 +52,7 @@ class ResolveSQLOnFile(sparkSession: SparkSession) extends Rule[LogicalPlan] { throw new AnalysisException("Unsupported data source type for direct query on files: " + s"${u.tableIdentifier.database.get}") } - val plan = LogicalRelation(dataSource.resolveRelation()) - u.alias.map(a => SubqueryAlias(a, plan, None)).getOrElse(plan) + LogicalRelation(dataSource.resolveRelation()) } catch { case _: ClassNotFoundException => u case e: Exception => @@ -380,7 +379,7 @@ case class PreprocessTableInsertion(conf: SQLConf) extends Rule[LogicalPlan] { case i @ InsertIntoTable(table, _, query, _, _) if table.resolved && query.resolved => table match { case relation: CatalogRelation => - val metadata = relation.catalogTable + val metadata = relation.tableMeta preprocess(i, metadata.identifier.quotedString, 
metadata.partitionColumnNames) case LogicalRelation(h: HadoopFsRelation, _, catalogTable) => val tblName = catalogTable.map(_.identifier.quotedString).getOrElse("unknown") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala index 0dbe2a71ed3b..07ec4e9429e4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala @@ -17,6 +17,9 @@ package org.apache.spark.sql.execution.streaming +import scala.util.control.NonFatal + +import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.spark.internal.Logging @@ -25,9 +28,31 @@ import org.apache.spark.sql.{DataFrame, SparkSession} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.datasources.{FileFormat, FileFormatWriter} -object FileStreamSink { +object FileStreamSink extends Logging { // The name of the subdirectory that is used to store metadata about which files are valid. val metadataDir = "_spark_metadata" + + /** + * Returns true if there is a single path that has a metadata log indicating which files should + * be read. + */ + def hasMetadata(path: Seq[String], hadoopConf: Configuration): Boolean = { + path match { + case Seq(singlePath) => + try { + val hdfsPath = new Path(singlePath) + val fs = hdfsPath.getFileSystem(hadoopConf) + val metadataPath = new Path(hdfsPath, metadataDir) + val res = fs.exists(metadataPath) + res + } catch { + case NonFatal(e) => + logWarning(s"Error while looking for metadata directory.") + false + } + case _ => false + } + } } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala index 39c0b4979687..6a7263ca45d8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.streaming import scala.collection.JavaConverters._ -import org.apache.hadoop.fs.Path +import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging @@ -43,8 +43,10 @@ class FileStreamSource( private val sourceOptions = new FileStreamOptions(options) + private val hadoopConf = sparkSession.sessionState.newHadoopConf() + private val qualifiedBasePath: Path = { - val fs = new Path(path).getFileSystem(sparkSession.sessionState.newHadoopConf()) + val fs = new Path(path).getFileSystem(hadoopConf) fs.makeQualified(new Path(path)) // can contains glob patterns } @@ -157,14 +159,65 @@ class FileStreamSource( checkFilesExist = false))) } + /** + * If the source has a metadata log indicating which files should be read, then we should use it. 
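The detection step added in `FileStreamSink.hasMetadata` above boils down to a single existence check; a minimal standalone sketch (Hadoop `FileSystem` API, hypothetical helper name):

```scala
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path

// Returns true if `dir` looks like the output of a FileStreamSink, i.e. it contains the
// "_spark_metadata" log directory. Mirrors FileStreamSink.hasMetadata for a single,
// non-glob path; there, any exception is logged and treated as "no metadata".
def looksLikeSinkOutput(dir: String, hadoopConf: Configuration): Boolean = {
  val base = new Path(dir)
  val fs = base.getFileSystem(hadoopConf)
  fs.exists(new Path(base, "_spark_metadata"))
}
```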
+ * Only when user gives a non-glob path that will we figure out whether the source has some + * metadata log + * + * None means we don't know at the moment + * Some(true) means we know for sure the source DOES have metadata + * Some(false) means we know for sure the source DOSE NOT have metadata + */ + @volatile private[sql] var sourceHasMetadata: Option[Boolean] = + if (SparkHadoopUtil.get.isGlobPath(new Path(path))) Some(false) else None + + private def allFilesUsingInMemoryFileIndex() = { + val globbedPaths = SparkHadoopUtil.get.globPathIfNecessary(qualifiedBasePath) + val fileIndex = new InMemoryFileIndex(sparkSession, globbedPaths, options, Some(new StructType)) + fileIndex.allFiles() + } + + private def allFilesUsingMetadataLogFileIndex() = { + // Note if `sourceHasMetadata` holds, then `qualifiedBasePath` is guaranteed to be a + // non-glob path + new MetadataLogFileIndex(sparkSession, qualifiedBasePath).allFiles() + } + /** * Returns a list of files found, sorted by their timestamp. */ private def fetchAllFiles(): Seq[(String, Long)] = { val startTime = System.nanoTime - val globbedPaths = SparkHadoopUtil.get.globPathIfNecessary(qualifiedBasePath) - val catalog = new InMemoryFileIndex(sparkSession, globbedPaths, options, Some(new StructType)) - val files = catalog.allFiles().sortBy(_.getModificationTime)(fileSortOrder).map { status => + + var allFiles: Seq[FileStatus] = null + sourceHasMetadata match { + case None => + if (FileStreamSink.hasMetadata(Seq(path), hadoopConf)) { + sourceHasMetadata = Some(true) + allFiles = allFilesUsingMetadataLogFileIndex() + } else { + allFiles = allFilesUsingInMemoryFileIndex() + if (allFiles.isEmpty) { + // we still cannot decide + } else { + // decide what to use for future rounds + // double check whether source has metadata, preventing the extreme corner case that + // metadata log and data files are only generated after the previous + // `FileStreamSink.hasMetadata` check + if (FileStreamSink.hasMetadata(Seq(path), hadoopConf)) { + sourceHasMetadata = Some(true) + allFiles = allFilesUsingMetadataLogFileIndex() + } else { + sourceHasMetadata = Some(false) + // `allFiles` have already been fetched using InMemoryFileIndex in this round + } + } + } + case Some(true) => allFiles = allFilesUsingMetadataLogFileIndex() + case Some(false) => allFiles = allFilesUsingInMemoryFileIndex() + } + + val files = allFiles.sortBy(_.getModificationTime)(fileSortOrder).map { status => (status.getPath.toUri.toString, status.getModificationTime) } val endTime = System.nanoTime diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index a3e108b29eda..ffdcd9b19d05 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -45,6 +45,7 @@ class IncrementalExecution( sparkSession.sessionState.planner.StatefulAggregationStrategy +: sparkSession.sessionState.planner.MapGroupsWithStateStrategy +: sparkSession.sessionState.planner.StreamingRelationStrategy +: + sparkSession.sessionState.planner.StreamingDeduplicationStrategy +: sparkSession.sessionState.experimentalMethods.extraStrategies // Modified planner with stateful operations. 
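A hedged usage sketch of the operator that `StreamingDeduplicationStrategy` plans: streaming `dropDuplicates`, here bounded by a watermark so per-key state can eventually be dropped. It assumes an active session named `spark`; the socket source, host/port and column names are illustrative.

```scala
val lines = spark.readStream
  .format("socket")
  .option("host", "localhost")
  .option("port", "9999")
  .option("includeTimestamp", "true")   // adds a "timestamp" column next to "value"
  .load()

val deduped = lines
  .withWatermark("timestamp", "10 minutes")
  .dropDuplicates("value", "timestamp")

val query = deduped.writeStream
  .outputMode("append")
  .format("console")
  .start()
```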
@@ -93,6 +94,15 @@ class IncrementalExecution( keys, Some(stateId), child) :: Nil)) + case StreamingDeduplicateExec(keys, child, None, None) => + val stateId = + OperatorStateId(checkpointLocation, operatorId.getAndIncrement(), currentBatchId) + + StreamingDeduplicateExec( + keys, + child, + Some(stateId), + Some(currentEventTimeWatermark)) case MapGroupsWithStateExec( f, kDeser, vDeser, group, data, output, None, stateDeser, stateSer, child) => val stateId = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index 4bd6431cbe11..70912d13ae45 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution.streaming +import java.io.{InterruptedIOException, IOException} import java.util.UUID import java.util.concurrent.{CountDownLatch, TimeUnit} import java.util.concurrent.atomic.AtomicReference @@ -37,6 +38,12 @@ import org.apache.spark.sql.execution.command.StreamingExplainCommand import org.apache.spark.sql.streaming._ import org.apache.spark.util.{Clock, UninterruptibleThread, Utils} +/** States for [[StreamExecution]]'s lifecycle. */ +trait State +case object INITIALIZING extends State +case object ACTIVE extends State +case object TERMINATED extends State + /** * Manages the execution of a streaming Spark SQL query that is occurring in a separate thread. * Unlike a standard query, a streaming query executes repeatedly each time new data arrives at any @@ -298,7 +305,14 @@ class StreamExecution( // `stop()` is already called. Let `finally` finish the cleanup. } } catch { - case _: InterruptedException if state.get == TERMINATED => // interrupted by stop() + case _: InterruptedException | _: InterruptedIOException if state.get == TERMINATED => + // interrupted by stop() + updateStatusMessage("Stopped") + case e: IOException if e.getMessage != null + && e.getMessage.startsWith(classOf[InterruptedException].getName) + && state.get == TERMINATED => + // This is a workaround for HADOOP-12074: `Shell.runCommand` converts `InterruptedException` + // to `new IOException(ie.toString())` before Hadoop 2.8. updateStatusMessage("Stopped") case e: Throwable => streamDeathCause = new StreamingQueryException( @@ -321,6 +335,7 @@ class StreamExecution( initializationLatch.countDown() try { + stopSources() state.set(TERMINATED) currentStatus = status.copy(isTriggerActive = false, isDataAvailable = false) @@ -558,6 +573,18 @@ class StreamExecution( sparkSession.streams.postListenerEvent(event) } + /** Stops all streaming sources safely. */ + private def stopSources(): Unit = { + uniqueSources.foreach { source => + try { + source.stop() + } catch { + case NonFatal(e) => + logWarning(s"Failed to stop streaming source: $source. Resources may have leaked.", e) + } + } + } + /** * Signals to the thread executing micro-batches that it should stop running after the next * batch. This method blocks until the thread stops running. 
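The `stopSources` helper above follows a best-effort shutdown pattern; a generic sketch under illustrative names (not Spark API):

```scala
import scala.util.control.NonFatal

// Stop every resource, logging rather than rethrowing non-fatal failures so that one
// misbehaving source cannot prevent the remaining ones from being closed.
def stopAll(resources: Seq[AutoCloseable], warn: (String, Throwable) => Unit): Unit = {
  resources.foreach { r =>
    try {
      r.close()
    } catch {
      case NonFatal(e) => warn(s"Failed to stop $r. Resources may have leaked.", e)
    }
  }
}
```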
@@ -570,7 +597,6 @@ class StreamExecution( microBatchThread.interrupt() microBatchThread.join() } - uniqueSources.foreach(_.stop()) logInfo(s"Query $prettyIdString was stopped") } @@ -709,10 +735,6 @@ class StreamExecution( } } - trait State - case object INITIALIZING extends State - case object ACTIVE extends State - case object TERMINATED extends State } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingQueryListenerBus.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingQueryListenerBus.scala index a2153d27e9fe..4207013c3f75 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingQueryListenerBus.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingQueryListenerBus.scala @@ -75,6 +75,19 @@ class StreamingQueryListenerBus(sparkListenerBus: LiveListenerBus) } } + /** + * Override the parent `postToAll` to remove the query id from `activeQueryRunIds` after all + * the listeners process `QueryTerminatedEvent`. (SPARK-19594) + */ + override def postToAll(event: Event): Unit = { + super.postToAll(event) + event match { + case t: QueryTerminatedEvent => + activeQueryRunIds.synchronized { activeQueryRunIds -= t.runId } + case _ => + } + } + override def onOtherEvent(event: SparkListenerEvent): Unit = { event match { case e: StreamingQueryListener.Event => @@ -112,7 +125,6 @@ class StreamingQueryListenerBus(sparkListenerBus: LiveListenerBus) case queryTerminated: QueryTerminatedEvent => if (shouldReport(queryTerminated.runId)) { listener.onQueryTerminated(queryTerminated) - activeQueryRunIds.synchronized { activeQueryRunIds -= queryTerminated.runId } } case _ => } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/socket.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/socket.scala index 900d92bc0d95..58bff27a05bf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/socket.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/socket.scala @@ -46,8 +46,8 @@ object TextSocketSource { * support for fault recovery and keeping all of the text read in memory forever. 
*/ class TextSocketSource(host: String, port: Int, includeTimestamp: Boolean, sqlContext: SQLContext) - extends Source with Logging -{ + extends Source with Logging { + @GuardedBy("this") private var socket: Socket = null @@ -168,6 +168,8 @@ class TextSocketSource(host: String, port: Int, includeTimestamp: Boolean, sqlCo socket = null } } + + override def toString: String = s"TextSocketSource[host: $host, port: $port]" } class TextSocketSourceProvider extends StreamSourceProvider with DataSourceRegister with Logging { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala index 61eb601a18c3..ab1204a750fa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala @@ -274,7 +274,18 @@ private[state] class HDFSBackedStateStoreProvider( private def commitUpdates(newVersion: Long, map: MapType, tempDeltaFile: Path): Path = { synchronized { val finalDeltaFile = deltaFile(newVersion) - if (!fs.rename(tempDeltaFile, finalDeltaFile)) { + + // scalastyle:off + // Renaming a file atop an existing one fails on HDFS + // (http://hadoop.apache.org/docs/stable/hadoop-project-dist/hadoop-common/filesystem/filesystem.html). + // Hence we should either skip the rename step or delete the target file. Because deleting the + // target file will break speculation, skipping the rename step is the only choice. It's still + // semantically correct because Structured Streaming requires rerunning a batch should + // generate the same output. (SPARK-19677) + // scalastyle:on + if (fs.exists(finalDeltaFile)) { + fs.delete(tempDeltaFile, true) + } else if (!fs.rename(tempDeltaFile, finalDeltaFile)) { throw new IOException(s"Failed to rename $tempDeltaFile to $finalDeltaFile") } loadedMaps.put(newVersion, map) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index 129245257459..d92529748b6a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -25,12 +25,11 @@ import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjecti import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, LogicalKeyedState} import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution, Partitioning} import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ -import org.apache.spark.sql.execution import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.execution.streaming.state._ import org.apache.spark.sql.streaming.OutputMode -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{DataType, NullType, StructType} import org.apache.spark.util.CompletionIterator @@ -68,6 +67,40 @@ trait StateStoreWriter extends StatefulOperator { "numUpdatedStateRows" -> SQLMetrics.createMetric(sparkContext, "number of updated state rows")) } +/** An operator that supports watermark. */ +trait WatermarkSupport extends SparkPlan { + + /** The keys that may have a watermark attribute. 
*/ + def keyExpressions: Seq[Attribute] + + /** The watermark value. */ + def eventTimeWatermark: Option[Long] + + /** Generate a predicate that matches data older than the watermark */ + lazy val watermarkPredicate: Option[Predicate] = { + val optionalWatermarkAttribute = + keyExpressions.find(_.metadata.contains(EventTimeWatermark.delayKey)) + + optionalWatermarkAttribute.map { watermarkAttribute => + // If we are evicting based on a window, use the end of the window. Otherwise just + // use the attribute itself. + val evictionExpression = + if (watermarkAttribute.dataType.isInstanceOf[StructType]) { + LessThanOrEqual( + GetStructField(watermarkAttribute, 1), + Literal(eventTimeWatermark.get * 1000)) + } else { + LessThanOrEqual( + watermarkAttribute, + Literal(eventTimeWatermark.get * 1000)) + } + + logInfo(s"Filtering state store on: $evictionExpression") + newPredicate(evictionExpression, keyExpressions) + } + } +} + /** * For each input tuple, the key is calculated and the value from the [[StateStore]] is added * to the stream (in addition to the input tuple) if present. @@ -76,7 +109,7 @@ case class StateStoreRestoreExec( keyExpressions: Seq[Attribute], stateId: Option[OperatorStateId], child: SparkPlan) - extends execution.UnaryExecNode with StateStoreReader { + extends UnaryExecNode with StateStoreReader { override protected def doExecute(): RDD[InternalRow] = { val numOutputRows = longMetric("numOutputRows") @@ -113,31 +146,7 @@ case class StateStoreSaveExec( outputMode: Option[OutputMode] = None, eventTimeWatermark: Option[Long] = None, child: SparkPlan) - extends execution.UnaryExecNode with StateStoreWriter { - - /** Generate a predicate that matches data older than the watermark */ - private lazy val watermarkPredicate: Option[Predicate] = { - val optionalWatermarkAttribute = - keyExpressions.find(_.metadata.contains(EventTimeWatermark.delayKey)) - - optionalWatermarkAttribute.map { watermarkAttribute => - // If we are evicting based on a window, use the end of the window. Otherwise just - // use the attribute itself. - val evictionExpression = - if (watermarkAttribute.dataType.isInstanceOf[StructType]) { - LessThanOrEqual( - GetStructField(watermarkAttribute, 1), - Literal(eventTimeWatermark.get * 1000)) - } else { - LessThanOrEqual( - watermarkAttribute, - Literal(eventTimeWatermark.get * 1000)) - } - - logInfo(s"Filtering state store on: $evictionExpression") - newPredicate(evictionExpression, keyExpressions) - } - } + extends UnaryExecNode with StateStoreWriter with WatermarkSupport { override protected def doExecute(): RDD[InternalRow] = { metrics // force lazy init at driver @@ -146,8 +155,8 @@ case class StateStoreSaveExec( child.execute().mapPartitionsWithStateStore( getStateId.checkpointLocation, - operatorId = getStateId.operatorId, - storeVersion = getStateId.batchId, + getStateId.operatorId, + getStateId.batchId, keyExpressions.toStructType, child.output.toStructType, sqlContext.sessionState, @@ -262,8 +271,8 @@ case class MapGroupsWithStateExec( override protected def doExecute(): RDD[InternalRow] = { child.execute().mapPartitionsWithStateStore[InternalRow]( getStateId.checkpointLocation, - operatorId = getStateId.operatorId, - storeVersion = getStateId.batchId, + getStateId.operatorId, + getStateId.batchId, groupingAttributes.toStructType, child.output.toStructType, sqlContext.sessionState, @@ -321,3 +330,70 @@ case class MapGroupsWithStateExec( } } } + + +/** Physical operator for executing streaming Deduplicate. 
*/ +case class StreamingDeduplicateExec( + keyExpressions: Seq[Attribute], + child: SparkPlan, + stateId: Option[OperatorStateId] = None, + eventTimeWatermark: Option[Long] = None) + extends UnaryExecNode with StateStoreWriter with WatermarkSupport { + + /** Distribute by grouping attributes */ + override def requiredChildDistribution: Seq[Distribution] = + ClusteredDistribution(keyExpressions) :: Nil + + override protected def doExecute(): RDD[InternalRow] = { + metrics // force lazy init at driver + + child.execute().mapPartitionsWithStateStore( + getStateId.checkpointLocation, + getStateId.operatorId, + getStateId.batchId, + keyExpressions.toStructType, + child.output.toStructType, + sqlContext.sessionState, + Some(sqlContext.streams.stateStoreCoordinator)) { (store, iter) => + val getKey = GenerateUnsafeProjection.generate(keyExpressions, child.output) + val numOutputRows = longMetric("numOutputRows") + val numTotalStateRows = longMetric("numTotalStateRows") + val numUpdatedStateRows = longMetric("numUpdatedStateRows") + + val baseIterator = watermarkPredicate match { + case Some(predicate) => iter.filter((row: InternalRow) => !predicate.eval(row)) + case None => iter + } + + val result = baseIterator.filter { r => + val row = r.asInstanceOf[UnsafeRow] + val key = getKey(row) + val value = store.get(key) + if (value.isEmpty) { + store.put(key.copy(), StreamingDeduplicateExec.EMPTY_ROW) + numUpdatedStateRows += 1 + numOutputRows += 1 + true + } else { + // Drop duplicated rows + false + } + } + + CompletionIterator[InternalRow, Iterator[InternalRow]](result, { + watermarkPredicate.foreach(f => store.remove(f.eval _)) + store.commit() + numTotalStateRows += store.numKeys() + }) + } + } + + override def output: Seq[Attribute] = child.output + + override def outputPartitioning: Partitioning = child.outputPartitioning +} + +object StreamingDeduplicateExec { + private val EMPTY_ROW = + UnsafeProjection.create(Array[DataType](NullType)).apply(InternalRow.apply(null)) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala index 5daf21595d8a..12d3bc9281f3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala @@ -343,10 +343,13 @@ class SQLListener(conf: SparkConf) extends SparkListener with Logging { accumulatorUpdate <- taskMetrics.accumulatorUpdates) yield { (accumulatorUpdate._1, accumulatorUpdate._2) } - }.filter { case (id, _) => executionUIData.accumulatorMetrics.contains(id) } + } val driverUpdates = executionUIData.driverAccumUpdates.toSeq - mergeAccumulatorUpdates(accumulatorUpdates ++ driverUpdates, accumulatorId => + val totalUpdates = (accumulatorUpdates ++ driverUpdates).filter { + case (id, _) => executionUIData.accumulatorMetrics.contains(id) + } + mergeAccumulatorUpdates(totalUpdates, accumulatorId => executionUIData.accumulatorMetrics(accumulatorId).metricType) case None => // This execution has been dropped diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index dc0f13040693..461dfe3a66e1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -402,11 +402,13 @@ object SQLConf { val PARALLEL_PARTITION_DISCOVERY_THRESHOLD = 
buildConf("spark.sql.sources.parallelPartitionDiscovery.threshold") - .doc("The maximum number of files allowed for listing files at driver side. If the number " + - "of detected files exceeds this value during partition discovery, it tries to list the " + + .doc("The maximum number of paths allowed for listing files at driver side. If the number " + + "of detected paths exceeds this value during partition discovery, it tries to list the " + "files with another Spark distributed job. This applies to Parquet, ORC, CSV, JSON and " + "LibSVM data sources.") .intConf + .checkValue(parallel => parallel >= 0, "The maximum number of paths allowed for listing " + + "files at driver side must not be negative") .createWithDefault(32) val PARALLEL_PARTITION_DISCOVERY_PARALLELISM = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala index 8de95fe64e66..bce84de45c3d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala @@ -39,12 +39,15 @@ private[sql] class SharedState(val sparkContext: SparkContext) extends Logging { // Load hive-site.xml into hadoopConf and determine the warehouse path we want to use, based on // the config from both hive and Spark SQL. Finally set the warehouse config value to sparkConf. - val warehousePath = { + val warehousePath: String = { val configFile = Utils.getContextOrSparkClassLoader.getResource("hive-site.xml") if (configFile != null) { + logInfo(s"loading hive config file: $configFile") sparkContext.hadoopConfiguration.addResource(configFile) } + // hive.metastore.warehouse.dir only stay in hadoopConf + sparkContext.conf.remove("hive.metastore.warehouse.dir") // Set the Hive metastore warehouse path to the one we use val hiveWarehouseDir = sparkContext.hadoopConfiguration.get("hive.metastore.warehouse.dir") if (hiveWarehouseDir != null && !sparkContext.conf.contains(WAREHOUSE_PATH.key)) { @@ -61,10 +64,11 @@ private[sql] class SharedState(val sparkContext: SparkContext) extends Logging { // When neither spark.sql.warehouse.dir nor hive.metastore.warehouse.dir is set, // we will set hive.metastore.warehouse.dir to the default value of spark.sql.warehouse.dir. val sparkWarehouseDir = sparkContext.conf.get(WAREHOUSE_PATH) - sparkContext.conf.set("hive.metastore.warehouse.dir", sparkWarehouseDir) + logInfo(s"Setting hive.metastore.warehouse.dir ('$hiveWarehouseDir') to the value of " + + s"${WAREHOUSE_PATH.key} ('$sparkWarehouseDir').") + sparkContext.hadoopConfiguration.set("hive.metastore.warehouse.dir", sparkWarehouseDir) sparkWarehouseDir } - } logInfo(s"Warehouse path is '$warehousePath'.") @@ -103,7 +107,7 @@ private[sql] class SharedState(val sparkContext: SparkContext) extends Logging { /** * A manager for global temporary views. */ - val globalTempViewManager = { + val globalTempViewManager: GlobalTempViewManager = { // System preserved database should not exists in metastore. However it's hard to guarantee it // for every session, because case-sensitivity differs. Here we always lowercase it to make our // life easier. 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index 99943944f3c6..aed8074a64d5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -143,8 +143,8 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo /** * Loads a JSON file stream and returns the results as a `DataFrame`. * - * Both JSON (one record per file) and JSON Lines - * (newline-delimited JSON) are supported and can be selected with the `wholeFile` option. + * JSON Lines (newline-delimited JSON) is supported by + * default. For JSON (one record per file), set the `wholeFile` option to true. * * This function goes through the input once to determine the input schema. If you know the * schema in advance, use the version that specifies the schema to avoid the extra scan. @@ -168,8 +168,11 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo * during parsing. *
   *   <ul>
   *     <li>`PERMISSIVE` : sets other fields to `null` when it meets a corrupted record, and puts
-  *     the malformed string into a new field configured by `columnNameOfCorruptRecord`. When
-  *     a schema is set by user, it sets `null` for extra fields.</li>
+  *     the malformed string into a field configured by `columnNameOfCorruptRecord`. To keep
+  *     corrupt records, a user can set a string type field named `columnNameOfCorruptRecord`
+  *     in a user-defined schema. If a schema does not have the field, it drops corrupt records
+  *     during parsing. When inferring a schema, it implicitly adds a `columnNameOfCorruptRecord`
+  *     field in an output schema.</li>
   *     <li>`DROPMALFORMED` : ignores the whole corrupted records.</li>
   *     <li>`FAILFAST` : throws an exception when it meets corrupted records.</li>
   *   </ul>
@@ -245,12 +248,20 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo
   * <li>`mode` (default `PERMISSIVE`): allows a mode for dealing with corrupt records
   *    during parsing.
   *   <ul>
-  *     <li>`PERMISSIVE` : sets other fields to `null` when it meets a corrupted record. When
-  *       a schema is set by user, it sets `null` for extra fields.</li>
+  *     <li>`PERMISSIVE` : sets other fields to `null` when it meets a corrupted record, and puts
+  *     the malformed string into a field configured by `columnNameOfCorruptRecord`. To keep
+  *     corrupt records, a user can set a string type field named `columnNameOfCorruptRecord`
+  *     in a user-defined schema. If a schema does not have the field, it drops corrupt records
+  *     during parsing. When the length of parsed CSV tokens is shorter than the expected length
+  *     of a schema, it sets `null` for extra fields.</li>
   *     <li>`DROPMALFORMED` : ignores the whole corrupted records.</li>
   *     <li>`FAILFAST` : throws an exception when it meets corrupted records.</li>
   *   </ul>
   * </li>
+  * <li>`columnNameOfCorruptRecord` (default is the value specified in
+  * `spark.sql.columnNameOfCorruptRecord`): allows renaming the new field that holds the malformed
+  * string created by `PERMISSIVE` mode. This overrides `spark.sql.columnNameOfCorruptRecord`.</li>
+  * <li>`wholeFile` (default `false`): parse one record, which may span multiple lines.</li>
  • * * * @since 2.0.0 diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java index bf8ff61eae39..eb4d76c6ab03 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java @@ -31,6 +31,7 @@ import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Encoders; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; import org.apache.spark.sql.SparkSession; @@ -146,13 +147,13 @@ public void dataFrameRDDOperations() { @Test public void applySchemaToJSON() { - JavaRDD jsonRDD = jsc.parallelize(Arrays.asList( + Dataset jsonDS = spark.createDataset(Arrays.asList( "{\"string\":\"this is a simple string.\", \"integer\":10, \"long\":21474836470, " + "\"bigInteger\":92233720368547758070, \"double\":1.7976931348623157E308, " + "\"boolean\":true, \"null\":null}", "{\"string\":\"this is another simple string.\", \"integer\":11, \"long\":21474836469, " + "\"bigInteger\":92233720368547758069, \"double\":1.7976931348623157E305, " + - "\"boolean\":false, \"null\":null}")); + "\"boolean\":false, \"null\":null}"), Encoders.STRING()); List fields = new ArrayList<>(7); fields.add(DataTypes.createStructField("bigInteger", DataTypes.createDecimalType(20, 0), true)); @@ -183,14 +184,14 @@ public void applySchemaToJSON() { null, "this is another simple string.")); - Dataset df1 = spark.read().json(jsonRDD); + Dataset df1 = spark.read().json(jsonDS); StructType actualSchema1 = df1.schema(); Assert.assertEquals(expectedSchema, actualSchema1); df1.createOrReplaceTempView("jsonTable1"); List actual1 = spark.sql("select * from jsonTable1").collectAsList(); Assert.assertEquals(expectedResult, actual1); - Dataset df2 = spark.read().schema(expectedSchema).json(jsonRDD); + Dataset df2 = spark.read().schema(expectedSchema).json(jsonDS); StructType actualSchema2 = df2.schema(); Assert.assertEquals(expectedSchema, actualSchema2); df2.createOrReplaceTempView("jsonTable2"); diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index c3b94a44c2e9..be8d95d0d912 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -397,4 +397,30 @@ public void testBloomFilter() { Assert.assertTrue(filter4.mightContain(i * 3)); } } + + public static class BeanWithoutGetter implements Serializable { + private String a; + + public void setA(String a) { + this.a = a; + } + } + + @Test + public void testBeanWithoutGetter() { + BeanWithoutGetter bean = new BeanWithoutGetter(); + List data = Arrays.asList(bean); + Dataset df = spark.createDataFrame(data, BeanWithoutGetter.class); + Assert.assertEquals(df.schema().length(), 0); + Assert.assertEquals(df.collectAsList().size(), 1); + } + + @Test + public void testJsonRDDToDataFrame() { + // This is a test for the deprecated API in SPARK-15615. 
+ JavaRDD rdd = jsc.parallelize(Arrays.asList("{\"a\": 2}")); + Dataset df = spark.read().json(rdd); + Assert.assertEquals(1L, df.count()); + Assert.assertEquals(2L, df.collectAsList().get(0).getLong(0)); + } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetAggregatorSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetAggregatorSuite.java index d3769a74b978..539976d5af46 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetAggregatorSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetAggregatorSuite.java @@ -88,7 +88,7 @@ public Encoder outputEncoder() { @Test public void testTypedAggregationAverage() { KeyValueGroupedDataset> grouped = generateGroupedDataset(); - Dataset> agged = grouped.agg(typed.avg(value -> (double)(value._2() * 2))); + Dataset> agged = grouped.agg(typed.avg(value -> value._2() * 2.0)); Assert.assertEquals( Arrays.asList(new Tuple2<>("a", 3.0), new Tuple2<>("b", 6.0)), agged.collectAsList()); diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index 577672ca8e08..e3b0e37ccab0 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -110,7 +110,8 @@ public void testCommonOperation() { Assert.assertEquals(Arrays.asList("hello"), filtered.collectAsList()); - Dataset mapped = ds.map((MapFunction) v -> v.length(), Encoders.INT()); + Dataset mapped = + ds.map((MapFunction) String::length, Encoders.INT()); Assert.assertEquals(Arrays.asList(5, 5), mapped.collectAsList()); Dataset parMapped = ds.mapPartitions((MapPartitionsFunction) it -> { @@ -157,17 +158,17 @@ public void testReduce() { public void testGroupBy() { List data = Arrays.asList("a", "foo", "bar"); Dataset ds = spark.createDataset(data, Encoders.STRING()); - KeyValueGroupedDataset grouped = ds.groupByKey( - (MapFunction) v -> v.length(), - Encoders.INT()); + KeyValueGroupedDataset grouped = + ds.groupByKey((MapFunction) String::length, Encoders.INT()); - Dataset mapped = grouped.mapGroups((MapGroupsFunction) (key, values) -> { - StringBuilder sb = new StringBuilder(key.toString()); - while (values.hasNext()) { - sb.append(values.next()); - } - return sb.toString(); - }, Encoders.STRING()); + Dataset mapped = grouped.mapGroups( + (MapGroupsFunction) (key, values) -> { + StringBuilder sb = new StringBuilder(key.toString()); + while (values.hasNext()) { + sb.append(values.next()); + } + return sb.toString(); + }, Encoders.STRING()); Assert.assertEquals(asSet("1a", "3foobar"), toSet(mapped.collectAsList())); @@ -209,7 +210,8 @@ public void testGroupBy() { Assert.assertEquals(asSet("1a", "3foobar"), toSet(flatMapped2.collectAsList())); - Dataset> reduced = grouped.reduceGroups((ReduceFunction) (v1, v2) -> v1 + v2); + Dataset> reduced = + grouped.reduceGroups((ReduceFunction) (v1, v2) -> v1 + v2); Assert.assertEquals( asSet(tuple2(1, "a"), tuple2(3, "foobar")), @@ -1276,4 +1278,15 @@ public void test() { spark.createDataset(data, Encoders.bean(NestedComplicatedJavaBean.class)); ds.collectAsList(); } + + public static class EmptyBean implements Serializable {} + + @Test + public void testEmptyBean() { + EmptyBean bean = new EmptyBean(); + List data = Arrays.asList(bean); + Dataset df = spark.createDataset(data, Encoders.bean(EmptyBean.class)); + Assert.assertEquals(df.schema().length(), 0); + 
Assert.assertEquals(df.collectAsList().size(), 1); + } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaSaveLoadSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaSaveLoadSuite.java index 6941c86dfcd4..127d272579a6 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaSaveLoadSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaSaveLoadSuite.java @@ -29,8 +29,6 @@ import org.junit.Before; import org.junit.Test; -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.sql.*; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.StructField; @@ -40,7 +38,6 @@ public class JavaSaveLoadSuite { private transient SparkSession spark; - private transient JavaSparkContext jsc; File path; Dataset df; @@ -58,7 +55,6 @@ public void setUp() throws IOException { .master("local[*]") .appName("testing") .getOrCreate(); - jsc = new JavaSparkContext(spark.sparkContext()); path = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "datasource").getCanonicalFile(); @@ -70,8 +66,8 @@ public void setUp() throws IOException { for (int i = 0; i < 10; i++) { jsonObjects.add("{\"a\":" + i + ", \"b\":\"str" + i + "\"}"); } - JavaRDD rdd = jsc.parallelize(jsonObjects); - df = spark.read().json(rdd); + Dataset ds = spark.createDataset(jsonObjects, Encoders.STRING()); + df = spark.read().json(ds); df.createOrReplaceTempView("jsonTable"); } diff --git a/sql/core/src/test/resources/sql-tests/inputs/columnresolution-negative.sql b/sql/core/src/test/resources/sql-tests/inputs/columnresolution-negative.sql new file mode 100644 index 000000000000..1caa45c66749 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/columnresolution-negative.sql @@ -0,0 +1,36 @@ +-- Negative testcases for column resolution +CREATE DATABASE mydb1; +USE mydb1; +CREATE TABLE t1 USING parquet AS SELECT 1 AS i1; + +CREATE DATABASE mydb2; +USE mydb2; +CREATE TABLE t1 USING parquet AS SELECT 20 AS i1; + +-- Negative tests: column resolution scenarios with ambiguous cases in join queries +SET spark.sql.crossJoin.enabled = true; +USE mydb1; +SELECT i1 FROM t1, mydb1.t1; +SELECT t1.i1 FROM t1, mydb1.t1; +SELECT mydb1.t1.i1 FROM t1, mydb1.t1; +SELECT i1 FROM t1, mydb2.t1; +SELECT t1.i1 FROM t1, mydb2.t1; +USE mydb2; +SELECT i1 FROM t1, mydb1.t1; +SELECT t1.i1 FROM t1, mydb1.t1; +SELECT i1 FROM t1, mydb2.t1; +SELECT t1.i1 FROM t1, mydb2.t1; +SELECT db1.t1.i1 FROM t1, mydb2.t1; +SET spark.sql.crossJoin.enabled = false; + +-- Negative tests +USE mydb1; +SELECT mydb1.t1 FROM t1; +SELECT t1.x.y.* FROM t1; +SELECT t1 FROM mydb1.t1; +USE mydb2; +SELECT mydb1.t1.i1 FROM t1; + +-- reset +DROP DATABASE mydb1 CASCADE; +DROP DATABASE mydb2 CASCADE; diff --git a/sql/core/src/test/resources/sql-tests/inputs/columnresolution-views.sql b/sql/core/src/test/resources/sql-tests/inputs/columnresolution-views.sql new file mode 100644 index 000000000000..d3f928751757 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/columnresolution-views.sql @@ -0,0 +1,25 @@ +-- Tests for qualified column names for the view code-path +-- Test scenario with Temporary view +CREATE OR REPLACE TEMPORARY VIEW view1 AS SELECT 2 AS i1; +SELECT view1.* FROM view1; +SELECT * FROM view1; +SELECT view1.i1 FROM view1; +SELECT i1 FROM view1; +SELECT a.i1 FROM view1 AS a; +SELECT i1 FROM view1 AS a; +-- cleanup +DROP VIEW view1; + +-- Test scenario with Global Temp view +CREATE OR REPLACE GLOBAL TEMPORARY VIEW view1 as SELECT 1 
as i1; +SELECT * FROM global_temp.view1; +-- TODO: Support this scenario +SELECT global_temp.view1.* FROM global_temp.view1; +SELECT i1 FROM global_temp.view1; +-- TODO: Support this scenario +SELECT global_temp.view1.i1 FROM global_temp.view1; +SELECT view1.i1 FROM global_temp.view1; +SELECT a.i1 FROM global_temp.view1 AS a; +SELECT i1 FROM global_temp.view1 AS a; +-- cleanup +DROP VIEW global_temp.view1; diff --git a/sql/core/src/test/resources/sql-tests/inputs/columnresolution.sql b/sql/core/src/test/resources/sql-tests/inputs/columnresolution.sql new file mode 100644 index 000000000000..79e90ad3de91 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/columnresolution.sql @@ -0,0 +1,88 @@ +-- Tests covering different scenarios with qualified column names +-- Scenario: column resolution scenarios with datasource table +CREATE DATABASE mydb1; +USE mydb1; +CREATE TABLE t1 USING parquet AS SELECT 1 AS i1; + +CREATE DATABASE mydb2; +USE mydb2; +CREATE TABLE t1 USING parquet AS SELECT 20 AS i1; + +USE mydb1; +SELECT i1 FROM t1; +SELECT i1 FROM mydb1.t1; +SELECT t1.i1 FROM t1; +SELECT t1.i1 FROM mydb1.t1; + +-- TODO: Support this scenario +SELECT mydb1.t1.i1 FROM t1; +-- TODO: Support this scenario +SELECT mydb1.t1.i1 FROM mydb1.t1; + +USE mydb2; +SELECT i1 FROM t1; +SELECT i1 FROM mydb1.t1; +SELECT t1.i1 FROM t1; +SELECT t1.i1 FROM mydb1.t1; +-- TODO: Support this scenario +SELECT mydb1.t1.i1 FROM mydb1.t1; + +-- Scenario: resolve fully qualified table name in star expansion +USE mydb1; +SELECT t1.* FROM t1; +SELECT mydb1.t1.* FROM mydb1.t1; +SELECT t1.* FROM mydb1.t1; +USE mydb2; +SELECT t1.* FROM t1; +-- TODO: Support this scenario +SELECT mydb1.t1.* FROM mydb1.t1; +SELECT t1.* FROM mydb1.t1; +SELECT a.* FROM mydb1.t1 AS a; + +-- Scenario: resolve in case of subquery + +USE mydb1; +CREATE TABLE t3 USING parquet AS SELECT * FROM VALUES (4,1), (3,1) AS t3(c1, c2); +CREATE TABLE t4 USING parquet AS SELECT * FROM VALUES (4,1), (2,1) AS t4(c2, c3); + +SELECT * FROM t3 WHERE c1 IN (SELECT c2 FROM t4 WHERE t4.c3 = t3.c2); + +-- TODO: Support this scenario +SELECT * FROM mydb1.t3 WHERE c1 IN + (SELECT mydb1.t4.c2 FROM mydb1.t4 WHERE mydb1.t4.c3 = mydb1.t3.c2); + +-- Scenario: column resolution scenarios in join queries +SET spark.sql.crossJoin.enabled = true; + +-- TODO: Support this scenario +SELECT mydb1.t1.i1 FROM t1, mydb2.t1; + +-- TODO: Support this scenario +SELECT mydb1.t1.i1 FROM mydb1.t1, mydb2.t1; + +USE mydb2; +-- TODO: Support this scenario +SELECT mydb1.t1.i1 FROM t1, mydb1.t1; +SET spark.sql.crossJoin.enabled = false; + +-- Scenario: Table with struct column +USE mydb1; +CREATE TABLE t5(i1 INT, t5 STRUCT) USING parquet; +INSERT INTO t5 VALUES(1, (2, 3)); +SELECT t5.i1 FROM t5; +SELECT t5.t5.i1 FROM t5; +SELECT t5.t5.i1 FROM mydb1.t5; +SELECT t5.i1 FROM mydb1.t5; +SELECT t5.* FROM mydb1.t5; +SELECT t5.t5.* FROM mydb1.t5; +-- TODO: Support this scenario +SELECT mydb1.t5.t5.i1 FROM mydb1.t5; +-- TODO: Support this scenario +SELECT mydb1.t5.t5.i2 FROM mydb1.t5; +-- TODO: Support this scenario +SELECT mydb1.t5.* FROM mydb1.t5; + +-- Cleanup and Reset +USE default; +DROP DATABASE mydb1 CASCADE; +DROP DATABASE mydb2 CASCADE; diff --git a/sql/core/src/test/resources/sql-tests/inputs/inline-table.sql b/sql/core/src/test/resources/sql-tests/inputs/inline-table.sql index 5107fa4d5553..b3ec956cd178 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/inline-table.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/inline-table.sql @@ -46,3 +46,6 @@ select * from values 
("one", random_not_exist_func(1)), ("two", 2) as data(a, b) -- error reporting: aggregate expression select * from values ("one", count(1)), ("two", 2) as data(a, b); + +-- string to timestamp +select * from values (timestamp('1991-12-06 00:00:00.0'), array(timestamp('1991-12-06 01:00:00.0'), timestamp('1991-12-06 12:00:00.0'))) as data(a, b); diff --git a/sql/core/src/test/resources/sql-tests/inputs/inner-join.sql b/sql/core/src/test/resources/sql-tests/inputs/inner-join.sql new file mode 100644 index 000000000000..38739cb95058 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/inner-join.sql @@ -0,0 +1,17 @@ +CREATE TEMPORARY VIEW t1 AS SELECT * FROM VALUES (1) AS GROUPING(a); +CREATE TEMPORARY VIEW t2 AS SELECT * FROM VALUES (1) AS GROUPING(a); +CREATE TEMPORARY VIEW t3 AS SELECT * FROM VALUES (1), (1) AS GROUPING(a); +CREATE TEMPORARY VIEW t4 AS SELECT * FROM VALUES (1), (1) AS GROUPING(a); + +CREATE TEMPORARY VIEW ta AS +SELECT a, 'a' AS tag FROM t1 +UNION ALL +SELECT a, 'b' AS tag FROM t2; + +CREATE TEMPORARY VIEW tb AS +SELECT a, 'a' AS tag FROM t3 +UNION ALL +SELECT a, 'b' AS tag FROM t4; + +-- SPARK-19766 Constant alias columns in INNER JOIN should not be folded by FoldablePropagation rule +SELECT tb.* FROM ta INNER JOIN tb ON ta.a = tb.a AND ta.tag = tb.tag; diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-joins.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-joins.sql index b10c41929cda..880175fd7add 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-joins.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-joins.sql @@ -79,7 +79,7 @@ GROUP BY t1a, t3a, t3b, t3c -ORDER BY t1a DESC; +ORDER BY t1a DESC, t3b DESC; -- TC 01.03 SELECT Count(DISTINCT(t1a)) diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-set-operations.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-set-operations.sql index 6b9e8bf2f362..5c371d2305ac 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-set-operations.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-set-operations.sql @@ -287,7 +287,7 @@ WHERE t1a IN (SELECT t3a WHERE t1b > 6) AS t5) GROUP BY t1a, t1b, t1c, t1d HAVING t1c IS NOT NULL AND t1b IS NOT NULL -ORDER BY t1c DESC; +ORDER BY t1c DESC, t1a DESC; -- TC 01.08 SELECT t1a, @@ -351,7 +351,7 @@ WHERE t1b IN FROM t1 WHERE t1b > 6) AS t4 WHERE t2b = t1b) -ORDER BY t1c DESC NULLS last; +ORDER BY t1c DESC NULLS last, t1a DESC; -- TC 01.11 SELECT * @@ -468,5 +468,5 @@ HAVING t1b NOT IN EXCEPT SELECT t3b FROM t3) -ORDER BY t1c DESC NULLS LAST; +ORDER BY t1c DESC NULLS LAST, t1i; diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-joins.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-joins.sql index 505366b7acd4..e09b91f18de0 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-joins.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-joins.sql @@ -85,7 +85,7 @@ AND t1b != t3b AND t1d = t2d GROUP BY t1a, t1b, t1c, t3a, t3b, t3c HAVING count(distinct(t3a)) >= 1 -ORDER BY t1a; +ORDER BY t1a, t3b; -- TC 01.03 SELECT t1a, diff --git a/sql/core/src/test/resources/sql-tests/results/change-column.sql.out b/sql/core/src/test/resources/sql-tests/results/change-column.sql.out index 59eb56920cdc..ba8bc936f0c7 100644 --- 
a/sql/core/src/test/resources/sql-tests/results/change-column.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/change-column.sql.out @@ -196,7 +196,7 @@ SET spark.sql.caseSensitive=false -- !query 19 schema struct -- !query 19 output -spark.sql.caseSensitive +spark.sql.caseSensitive false -- !query 20 @@ -212,7 +212,7 @@ SET spark.sql.caseSensitive=true -- !query 21 schema struct -- !query 21 output -spark.sql.caseSensitive +spark.sql.caseSensitive true -- !query 22 diff --git a/sql/core/src/test/resources/sql-tests/results/columnresolution-negative.sql.out b/sql/core/src/test/resources/sql-tests/results/columnresolution-negative.sql.out new file mode 100644 index 000000000000..60bd8e9cc99d --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/columnresolution-negative.sql.out @@ -0,0 +1,240 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 28 + + +-- !query 0 +CREATE DATABASE mydb1 +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +USE mydb1 +-- !query 1 schema +struct<> +-- !query 1 output + + + +-- !query 2 +CREATE TABLE t1 USING parquet AS SELECT 1 AS i1 +-- !query 2 schema +struct<> +-- !query 2 output + + + +-- !query 3 +CREATE DATABASE mydb2 +-- !query 3 schema +struct<> +-- !query 3 output + + + +-- !query 4 +USE mydb2 +-- !query 4 schema +struct<> +-- !query 4 output + + + +-- !query 5 +CREATE TABLE t1 USING parquet AS SELECT 20 AS i1 +-- !query 5 schema +struct<> +-- !query 5 output + + + +-- !query 6 +SET spark.sql.crossJoin.enabled = true +-- !query 6 schema +struct +-- !query 6 output +spark.sql.crossJoin.enabled true + + +-- !query 7 +USE mydb1 +-- !query 7 schema +struct<> +-- !query 7 output + + + +-- !query 8 +SELECT i1 FROM t1, mydb1.t1 +-- !query 8 schema +struct<> +-- !query 8 output +org.apache.spark.sql.AnalysisException +Reference 'i1' is ambiguous, could be: i1#x, i1#x.; line 1 pos 7 + + +-- !query 9 +SELECT t1.i1 FROM t1, mydb1.t1 +-- !query 9 schema +struct<> +-- !query 9 output +org.apache.spark.sql.AnalysisException +Reference 't1.i1' is ambiguous, could be: i1#x, i1#x.; line 1 pos 7 + + +-- !query 10 +SELECT mydb1.t1.i1 FROM t1, mydb1.t1 +-- !query 10 schema +struct<> +-- !query 10 output +org.apache.spark.sql.AnalysisException +cannot resolve '`mydb1.t1.i1`' given input columns: [i1, i1]; line 1 pos 7 + + +-- !query 11 +SELECT i1 FROM t1, mydb2.t1 +-- !query 11 schema +struct<> +-- !query 11 output +org.apache.spark.sql.AnalysisException +Reference 'i1' is ambiguous, could be: i1#x, i1#x.; line 1 pos 7 + + +-- !query 12 +SELECT t1.i1 FROM t1, mydb2.t1 +-- !query 12 schema +struct<> +-- !query 12 output +org.apache.spark.sql.AnalysisException +Reference 't1.i1' is ambiguous, could be: i1#x, i1#x.; line 1 pos 7 + + +-- !query 13 +USE mydb2 +-- !query 13 schema +struct<> +-- !query 13 output + + + +-- !query 14 +SELECT i1 FROM t1, mydb1.t1 +-- !query 14 schema +struct<> +-- !query 14 output +org.apache.spark.sql.AnalysisException +Reference 'i1' is ambiguous, could be: i1#x, i1#x.; line 1 pos 7 + + +-- !query 15 +SELECT t1.i1 FROM t1, mydb1.t1 +-- !query 15 schema +struct<> +-- !query 15 output +org.apache.spark.sql.AnalysisException +Reference 't1.i1' is ambiguous, could be: i1#x, i1#x.; line 1 pos 7 + + +-- !query 16 +SELECT i1 FROM t1, mydb2.t1 +-- !query 16 schema +struct<> +-- !query 16 output +org.apache.spark.sql.AnalysisException +Reference 'i1' is ambiguous, could be: i1#x, i1#x.; line 1 pos 7 + + +-- !query 17 +SELECT t1.i1 FROM t1, mydb2.t1 +-- !query 17 schema +struct<> +-- !query 17 
output +org.apache.spark.sql.AnalysisException +Reference 't1.i1' is ambiguous, could be: i1#x, i1#x.; line 1 pos 7 + + +-- !query 18 +SELECT db1.t1.i1 FROM t1, mydb2.t1 +-- !query 18 schema +struct<> +-- !query 18 output +org.apache.spark.sql.AnalysisException +cannot resolve '`db1.t1.i1`' given input columns: [i1, i1]; line 1 pos 7 + + +-- !query 19 +SET spark.sql.crossJoin.enabled = false +-- !query 19 schema +struct +-- !query 19 output +spark.sql.crossJoin.enabled false + + +-- !query 20 +USE mydb1 +-- !query 20 schema +struct<> +-- !query 20 output + + + +-- !query 21 +SELECT mydb1.t1 FROM t1 +-- !query 21 schema +struct<> +-- !query 21 output +org.apache.spark.sql.AnalysisException +cannot resolve '`mydb1.t1`' given input columns: [i1]; line 1 pos 7 + + +-- !query 22 +SELECT t1.x.y.* FROM t1 +-- !query 22 schema +struct<> +-- !query 22 output +org.apache.spark.sql.AnalysisException +cannot resolve 't1.x.y.*' give input columns 'i1'; + + +-- !query 23 +SELECT t1 FROM mydb1.t1 +-- !query 23 schema +struct<> +-- !query 23 output +org.apache.spark.sql.AnalysisException +cannot resolve '`t1`' given input columns: [i1]; line 1 pos 7 + + +-- !query 24 +USE mydb2 +-- !query 24 schema +struct<> +-- !query 24 output + + + +-- !query 25 +SELECT mydb1.t1.i1 FROM t1 +-- !query 25 schema +struct<> +-- !query 25 output +org.apache.spark.sql.AnalysisException +cannot resolve '`mydb1.t1.i1`' given input columns: [i1]; line 1 pos 7 + + +-- !query 26 +DROP DATABASE mydb1 CASCADE +-- !query 26 schema +struct<> +-- !query 26 output + + + +-- !query 27 +DROP DATABASE mydb2 CASCADE +-- !query 27 schema +struct<> +-- !query 27 output + diff --git a/sql/core/src/test/resources/sql-tests/results/columnresolution-views.sql.out b/sql/core/src/test/resources/sql-tests/results/columnresolution-views.sql.out new file mode 100644 index 000000000000..616421d6f2b2 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/columnresolution-views.sql.out @@ -0,0 +1,140 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 17 + + +-- !query 0 +CREATE OR REPLACE TEMPORARY VIEW view1 AS SELECT 2 AS i1 +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +SELECT view1.* FROM view1 +-- !query 1 schema +struct +-- !query 1 output +2 + + +-- !query 2 +SELECT * FROM view1 +-- !query 2 schema +struct +-- !query 2 output +2 + + +-- !query 3 +SELECT view1.i1 FROM view1 +-- !query 3 schema +struct +-- !query 3 output +2 + + +-- !query 4 +SELECT i1 FROM view1 +-- !query 4 schema +struct +-- !query 4 output +2 + + +-- !query 5 +SELECT a.i1 FROM view1 AS a +-- !query 5 schema +struct +-- !query 5 output +2 + + +-- !query 6 +SELECT i1 FROM view1 AS a +-- !query 6 schema +struct +-- !query 6 output +2 + + +-- !query 7 +DROP VIEW view1 +-- !query 7 schema +struct<> +-- !query 7 output + + + +-- !query 8 +CREATE OR REPLACE GLOBAL TEMPORARY VIEW view1 as SELECT 1 as i1 +-- !query 8 schema +struct<> +-- !query 8 output + + + +-- !query 9 +SELECT * FROM global_temp.view1 +-- !query 9 schema +struct +-- !query 9 output +1 + + +-- !query 10 +SELECT global_temp.view1.* FROM global_temp.view1 +-- !query 10 schema +struct<> +-- !query 10 output +org.apache.spark.sql.AnalysisException +cannot resolve 'global_temp.view1.*' give input columns 'i1'; + + +-- !query 11 +SELECT i1 FROM global_temp.view1 +-- !query 11 schema +struct +-- !query 11 output +1 + + +-- !query 12 +SELECT global_temp.view1.i1 FROM global_temp.view1 +-- !query 12 schema +struct<> +-- !query 12 output 
+org.apache.spark.sql.AnalysisException +cannot resolve '`global_temp.view1.i1`' given input columns: [i1]; line 1 pos 7 + + +-- !query 13 +SELECT view1.i1 FROM global_temp.view1 +-- !query 13 schema +struct +-- !query 13 output +1 + + +-- !query 14 +SELECT a.i1 FROM global_temp.view1 AS a +-- !query 14 schema +struct +-- !query 14 output +1 + + +-- !query 15 +SELECT i1 FROM global_temp.view1 AS a +-- !query 15 schema +struct +-- !query 15 output +1 + + +-- !query 16 +DROP VIEW global_temp.view1 +-- !query 16 schema +struct<> +-- !query 16 output + diff --git a/sql/core/src/test/resources/sql-tests/results/columnresolution.sql.out b/sql/core/src/test/resources/sql-tests/results/columnresolution.sql.out new file mode 100644 index 000000000000..764cad0e3943 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/columnresolution.sql.out @@ -0,0 +1,447 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 54 + + +-- !query 0 +CREATE DATABASE mydb1 +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +USE mydb1 +-- !query 1 schema +struct<> +-- !query 1 output + + + +-- !query 2 +CREATE TABLE t1 USING parquet AS SELECT 1 AS i1 +-- !query 2 schema +struct<> +-- !query 2 output + + + +-- !query 3 +CREATE DATABASE mydb2 +-- !query 3 schema +struct<> +-- !query 3 output + + + +-- !query 4 +USE mydb2 +-- !query 4 schema +struct<> +-- !query 4 output + + + +-- !query 5 +CREATE TABLE t1 USING parquet AS SELECT 20 AS i1 +-- !query 5 schema +struct<> +-- !query 5 output + + + +-- !query 6 +USE mydb1 +-- !query 6 schema +struct<> +-- !query 6 output + + + +-- !query 7 +SELECT i1 FROM t1 +-- !query 7 schema +struct +-- !query 7 output +1 + + +-- !query 8 +SELECT i1 FROM mydb1.t1 +-- !query 8 schema +struct +-- !query 8 output +1 + + +-- !query 9 +SELECT t1.i1 FROM t1 +-- !query 9 schema +struct +-- !query 9 output +1 + + +-- !query 10 +SELECT t1.i1 FROM mydb1.t1 +-- !query 10 schema +struct +-- !query 10 output +1 + + +-- !query 11 +SELECT mydb1.t1.i1 FROM t1 +-- !query 11 schema +struct<> +-- !query 11 output +org.apache.spark.sql.AnalysisException +cannot resolve '`mydb1.t1.i1`' given input columns: [i1]; line 1 pos 7 + + +-- !query 12 +SELECT mydb1.t1.i1 FROM mydb1.t1 +-- !query 12 schema +struct<> +-- !query 12 output +org.apache.spark.sql.AnalysisException +cannot resolve '`mydb1.t1.i1`' given input columns: [i1]; line 1 pos 7 + + +-- !query 13 +USE mydb2 +-- !query 13 schema +struct<> +-- !query 13 output + + + +-- !query 14 +SELECT i1 FROM t1 +-- !query 14 schema +struct +-- !query 14 output +20 + + +-- !query 15 +SELECT i1 FROM mydb1.t1 +-- !query 15 schema +struct +-- !query 15 output +1 + + +-- !query 16 +SELECT t1.i1 FROM t1 +-- !query 16 schema +struct +-- !query 16 output +20 + + +-- !query 17 +SELECT t1.i1 FROM mydb1.t1 +-- !query 17 schema +struct +-- !query 17 output +1 + + +-- !query 18 +SELECT mydb1.t1.i1 FROM mydb1.t1 +-- !query 18 schema +struct<> +-- !query 18 output +org.apache.spark.sql.AnalysisException +cannot resolve '`mydb1.t1.i1`' given input columns: [i1]; line 1 pos 7 + + +-- !query 19 +USE mydb1 +-- !query 19 schema +struct<> +-- !query 19 output + + + +-- !query 20 +SELECT t1.* FROM t1 +-- !query 20 schema +struct +-- !query 20 output +1 + + +-- !query 21 +SELECT mydb1.t1.* FROM mydb1.t1 +-- !query 21 schema +struct<> +-- !query 21 output +org.apache.spark.sql.AnalysisException +cannot resolve 'mydb1.t1.*' give input columns 'i1'; + + +-- !query 22 +SELECT t1.* FROM mydb1.t1 +-- !query 22 schema +struct +-- !query 22 output +1 
+ + +-- !query 23 +USE mydb2 +-- !query 23 schema +struct<> +-- !query 23 output + + + +-- !query 24 +SELECT t1.* FROM t1 +-- !query 24 schema +struct +-- !query 24 output +20 + + +-- !query 25 +SELECT mydb1.t1.* FROM mydb1.t1 +-- !query 25 schema +struct<> +-- !query 25 output +org.apache.spark.sql.AnalysisException +cannot resolve 'mydb1.t1.*' give input columns 'i1'; + + +-- !query 26 +SELECT t1.* FROM mydb1.t1 +-- !query 26 schema +struct +-- !query 26 output +1 + + +-- !query 27 +SELECT a.* FROM mydb1.t1 AS a +-- !query 27 schema +struct +-- !query 27 output +1 + + +-- !query 28 +USE mydb1 +-- !query 28 schema +struct<> +-- !query 28 output + + + +-- !query 29 +CREATE TABLE t3 USING parquet AS SELECT * FROM VALUES (4,1), (3,1) AS t3(c1, c2) +-- !query 29 schema +struct<> +-- !query 29 output + + + +-- !query 30 +CREATE TABLE t4 USING parquet AS SELECT * FROM VALUES (4,1), (2,1) AS t4(c2, c3) +-- !query 30 schema +struct<> +-- !query 30 output + + + +-- !query 31 +SELECT * FROM t3 WHERE c1 IN (SELECT c2 FROM t4 WHERE t4.c3 = t3.c2) +-- !query 31 schema +struct +-- !query 31 output +4 1 + + +-- !query 32 +SELECT * FROM mydb1.t3 WHERE c1 IN + (SELECT mydb1.t4.c2 FROM mydb1.t4 WHERE mydb1.t4.c3 = mydb1.t3.c2) +-- !query 32 schema +struct<> +-- !query 32 output +org.apache.spark.sql.AnalysisException +cannot resolve '`mydb1.t4.c3`' given input columns: [c2, c3]; line 2 pos 42 + + +-- !query 33 +SET spark.sql.crossJoin.enabled = true +-- !query 33 schema +struct +-- !query 33 output +spark.sql.crossJoin.enabled true + + +-- !query 34 +SELECT mydb1.t1.i1 FROM t1, mydb2.t1 +-- !query 34 schema +struct<> +-- !query 34 output +org.apache.spark.sql.AnalysisException +cannot resolve '`mydb1.t1.i1`' given input columns: [i1, i1]; line 1 pos 7 + + +-- !query 35 +SELECT mydb1.t1.i1 FROM mydb1.t1, mydb2.t1 +-- !query 35 schema +struct<> +-- !query 35 output +org.apache.spark.sql.AnalysisException +cannot resolve '`mydb1.t1.i1`' given input columns: [i1, i1]; line 1 pos 7 + + +-- !query 36 +USE mydb2 +-- !query 36 schema +struct<> +-- !query 36 output + + + +-- !query 37 +SELECT mydb1.t1.i1 FROM t1, mydb1.t1 +-- !query 37 schema +struct<> +-- !query 37 output +org.apache.spark.sql.AnalysisException +cannot resolve '`mydb1.t1.i1`' given input columns: [i1, i1]; line 1 pos 7 + + +-- !query 38 +SET spark.sql.crossJoin.enabled = false +-- !query 38 schema +struct +-- !query 38 output +spark.sql.crossJoin.enabled false + + +-- !query 39 +USE mydb1 +-- !query 39 schema +struct<> +-- !query 39 output + + + +-- !query 40 +CREATE TABLE t5(i1 INT, t5 STRUCT) USING parquet +-- !query 40 schema +struct<> +-- !query 40 output + + + +-- !query 41 +INSERT INTO t5 VALUES(1, (2, 3)) +-- !query 41 schema +struct<> +-- !query 41 output + + + +-- !query 42 +SELECT t5.i1 FROM t5 +-- !query 42 schema +struct +-- !query 42 output +1 + + +-- !query 43 +SELECT t5.t5.i1 FROM t5 +-- !query 43 schema +struct +-- !query 43 output +2 + + +-- !query 44 +SELECT t5.t5.i1 FROM mydb1.t5 +-- !query 44 schema +struct +-- !query 44 output +2 + + +-- !query 45 +SELECT t5.i1 FROM mydb1.t5 +-- !query 45 schema +struct +-- !query 45 output +1 + + +-- !query 46 +SELECT t5.* FROM mydb1.t5 +-- !query 46 schema +struct> +-- !query 46 output +1 {"i1":2,"i2":3} + + +-- !query 47 +SELECT t5.t5.* FROM mydb1.t5 +-- !query 47 schema +struct +-- !query 47 output +2 3 + + +-- !query 48 +SELECT mydb1.t5.t5.i1 FROM mydb1.t5 +-- !query 48 schema +struct<> +-- !query 48 output +org.apache.spark.sql.AnalysisException +cannot resolve '`mydb1.t5.t5.i1`' given 
input columns: [i1, t5]; line 1 pos 7 + + +-- !query 49 +SELECT mydb1.t5.t5.i2 FROM mydb1.t5 +-- !query 49 schema +struct<> +-- !query 49 output +org.apache.spark.sql.AnalysisException +cannot resolve '`mydb1.t5.t5.i2`' given input columns: [i1, t5]; line 1 pos 7 + + +-- !query 50 +SELECT mydb1.t5.* FROM mydb1.t5 +-- !query 50 schema +struct<> +-- !query 50 output +org.apache.spark.sql.AnalysisException +cannot resolve 'mydb1.t5.*' give input columns 'i1, t5'; + + +-- !query 51 +USE default +-- !query 51 schema +struct<> +-- !query 51 output + + + +-- !query 52 +DROP DATABASE mydb1 CASCADE +-- !query 52 schema +struct<> +-- !query 52 output + + + +-- !query 53 +DROP DATABASE mydb2 CASCADE +-- !query 53 schema +struct<> +-- !query 53 output + diff --git a/sql/core/src/test/resources/sql-tests/results/group-by-ordinal.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by-ordinal.sql.out index c64520ff93c8..c0930bbde69a 100644 --- a/sql/core/src/test/resources/sql-tests/results/group-by-ordinal.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/group-by-ordinal.sql.out @@ -177,7 +177,7 @@ set spark.sql.groupByOrdinal=false -- !query 17 schema struct -- !query 17 output -spark.sql.groupByOrdinal +spark.sql.groupByOrdinal false -- !query 18 diff --git a/sql/core/src/test/resources/sql-tests/results/inline-table.sql.out b/sql/core/src/test/resources/sql-tests/results/inline-table.sql.out index de6f01b8de77..4e80f0bda551 100644 --- a/sql/core/src/test/resources/sql-tests/results/inline-table.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/inline-table.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 16 +-- Number of queries: 17 -- !query 0 @@ -143,3 +143,11 @@ struct<> -- !query 15 output org.apache.spark.sql.AnalysisException cannot evaluate expression count(1) in inline table definition; line 1 pos 29 + + +-- !query 16 +select * from values (timestamp('1991-12-06 00:00:00.0'), array(timestamp('1991-12-06 01:00:00.0'), timestamp('1991-12-06 12:00:00.0'))) as data(a, b) +-- !query 16 schema +struct> +-- !query 16 output +1991-12-06 00:00:00 [1991-12-06 01:00:00.0,1991-12-06 12:00:00.0] diff --git a/sql/core/src/test/resources/sql-tests/results/inner-join.sql.out b/sql/core/src/test/resources/sql-tests/results/inner-join.sql.out new file mode 100644 index 000000000000..8d56ebe9fd3b --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/inner-join.sql.out @@ -0,0 +1,67 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 7 + + +-- !query 0 +CREATE TEMPORARY VIEW t1 AS SELECT * FROM VALUES (1) AS GROUPING(a) +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +CREATE TEMPORARY VIEW t2 AS SELECT * FROM VALUES (1) AS GROUPING(a) +-- !query 1 schema +struct<> +-- !query 1 output + + + +-- !query 2 +CREATE TEMPORARY VIEW t3 AS SELECT * FROM VALUES (1), (1) AS GROUPING(a) +-- !query 2 schema +struct<> +-- !query 2 output + + + +-- !query 3 +CREATE TEMPORARY VIEW t4 AS SELECT * FROM VALUES (1), (1) AS GROUPING(a) +-- !query 3 schema +struct<> +-- !query 3 output + + + +-- !query 4 +CREATE TEMPORARY VIEW ta AS +SELECT a, 'a' AS tag FROM t1 +UNION ALL +SELECT a, 'b' AS tag FROM t2 +-- !query 4 schema +struct<> +-- !query 4 output + + + +-- !query 5 +CREATE TEMPORARY VIEW tb AS +SELECT a, 'a' AS tag FROM t3 +UNION ALL +SELECT a, 'b' AS tag FROM t4 +-- !query 5 schema +struct<> +-- !query 5 output + + + +-- !query 6 +SELECT tb.* FROM ta INNER JOIN tb ON ta.a = tb.a AND ta.tag = 
tb.tag +-- !query 6 schema +struct +-- !query 6 output +1 a +1 a +1 b +1 b diff --git a/sql/core/src/test/resources/sql-tests/results/order-by-ordinal.sql.out b/sql/core/src/test/resources/sql-tests/results/order-by-ordinal.sql.out index 03a4e72d0fa3..cc47cc67c87c 100644 --- a/sql/core/src/test/resources/sql-tests/results/order-by-ordinal.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/order-by-ordinal.sql.out @@ -114,7 +114,7 @@ set spark.sql.orderByOrdinal=false -- !query 9 schema struct -- !query 9 output -spark.sql.orderByOrdinal +spark.sql.orderByOrdinal false -- !query 10 diff --git a/sql/core/src/test/resources/sql-tests/results/outer-join.sql.out b/sql/core/src/test/resources/sql-tests/results/outer-join.sql.out index cc50b9444bb4..5db3bae5d037 100644 --- a/sql/core/src/test/resources/sql-tests/results/outer-join.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/outer-join.sql.out @@ -63,7 +63,7 @@ set spark.sql.crossJoin.enabled = true -- !query 5 schema struct -- !query 5 output -spark.sql.crossJoin.enabled +spark.sql.crossJoin.enabled true -- !query 6 @@ -85,4 +85,4 @@ set spark.sql.crossJoin.enabled = false -- !query 7 schema struct -- !query 7 output -spark.sql.crossJoin.enabled +spark.sql.crossJoin.enabled false diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/in-joins.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/in-joins.sql.out index 7258bcfc6ab7..ab6a11a2b7ef 100644 --- a/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/in-joins.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/in-joins.sql.out @@ -102,7 +102,7 @@ GROUP BY t1a, t3a, t3b, t3c -ORDER BY t1a DESC +ORDER BY t1a DESC, t3b DESC -- !query 4 schema struct -- !query 4 output diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/in-set-operations.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/in-set-operations.sql.out index 878bc755ef5f..e06f9206d340 100644 --- a/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/in-set-operations.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/in-set-operations.sql.out @@ -353,7 +353,7 @@ WHERE t1a IN (SELECT t3a WHERE t1b > 6) AS t5) GROUP BY t1a, t1b, t1c, t1d HAVING t1c IS NOT NULL AND t1b IS NOT NULL -ORDER BY t1c DESC +ORDER BY t1c DESC, t1a DESC -- !query 9 schema struct -- !query 9 output @@ -445,7 +445,7 @@ WHERE t1b IN FROM t1 WHERE t1b > 6) AS t4 WHERE t2b = t1b) -ORDER BY t1c DESC NULLS last +ORDER BY t1c DESC NULLS last, t1a DESC -- !query 12 schema struct -- !query 12 output @@ -580,16 +580,16 @@ HAVING t1b NOT IN EXCEPT SELECT t3b FROM t3) -ORDER BY t1c DESC NULLS LAST +ORDER BY t1c DESC NULLS LAST, t1i -- !query 15 schema struct -- !query 15 output -1 8 16 2014-05-05 1 8 16 2014-05-04 +1 8 16 2014-05-05 1 16 12 2014-06-04 1 16 12 2014-07-04 1 6 8 2014-04-04 +1 10 NULL 2014-05-04 1 10 NULL 2014-08-04 1 10 NULL 2014-09-04 1 10 NULL 2015-05-04 -1 10 NULL 2014-05-04 diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/not-in-joins.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/not-in-joins.sql.out index db01fa455735..bae5d00cc863 100644 --- a/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/not-in-joins.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/not-in-joins.sql.out @@ -112,12 +112,12 @@ AND t1b != t3b AND t1d = t2d GROUP BY 
t1a, t1b, t1c, t3a, t3b, t3c HAVING count(distinct(t3a)) >= 1 -ORDER BY t1a +ORDER BY t1a, t3b -- !query 4 schema struct -- !query 4 output -val1c 8 16 1 10 12 val1c 8 16 1 6 12 +val1c 8 16 1 10 12 val1c 8 16 1 17 16 diff --git a/sql/core/src/test/resources/test-data/value-malformed.csv b/sql/core/src/test/resources/test-data/value-malformed.csv new file mode 100644 index 000000000000..8945ed73d2e8 --- /dev/null +++ b/sql/core/src/test/resources/test-data/value-malformed.csv @@ -0,0 +1,2 @@ +0,2013-111-11 12:13:14 +1,1983-08-04 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index 1af1a3652971..2a0e088437fd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -634,4 +634,20 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext assert(getNumInMemoryRelations(cachedPlan2) == 4) } } + + test("refreshByPath should refresh all cached plans with the specified path") { + withTempDir { dir => + val path = dir.getCanonicalPath() + + spark.range(10).write.mode("overwrite").parquet(path) + spark.read.parquet(path).cache() + spark.read.parquet(path).filter($"id" > 4).cache() + assert(spark.read.parquet(path).filter($"id" > 4).count() == 5) + + spark.range(20).write.mode("overwrite").parquet(path) + spark.catalog.refreshByPath(path) + assert(spark.read.parquet(path).count() == 20) + assert(spark.read.parquet(path).filter($"id" > 4).count() == 15) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index e6338ab7cd80..19c2d5532d08 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -914,15 +914,13 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } test("SPARK-7551: support backticks for DataFrame attribute resolution") { - val df = spark.read.json(sparkContext.makeRDD( - """{"a.b": {"c": {"d..e": {"f": 1}}}}""" :: Nil)) + val df = spark.read.json(Seq("""{"a.b": {"c": {"d..e": {"f": 1}}}}""").toDS()) checkAnswer( df.select(df("`a.b`.c.`d..e`.`f`")), Row(1) ) - val df2 = spark.read.json(sparkContext.makeRDD( - """{"a b": {"c": {"d e": {"f": 1}}}}""" :: Nil)) + val df2 = spark.read.json(Seq("""{"a b": {"c": {"d e": {"f": 1}}}}""").toDS()) checkAnswer( df2.select(df2("`a b`.c.d e.f")), Row(1) @@ -1110,8 +1108,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } test("SPARK-9323: DataFrame.orderBy should support nested column name") { - val df = spark.read.json(sparkContext.makeRDD( - """{"a": {"b": 1}}""" :: Nil)) + val df = spark.read.json(Seq("""{"a": {"b": 1}}""").toDS()) checkAnswer(df.orderBy("a.b"), Row(Row(1))) } @@ -1164,8 +1161,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } test("SPARK-10316: respect non-deterministic expressions in PhysicalOperation") { - val input = spark.read.json(spark.sparkContext.makeRDD( - (1 to 10).map(i => s"""{"id": $i}"""))) + val input = spark.read.json((1 to 10).map(i => s"""{"id": $i}""").toDS()) val df = input.select($"id", rand(0).as('r)) df.as("a").join(df.filter($"r" < 0.5).as("b"), $"a.id" === $"b.id").collect().foreach { row => @@ -1702,4 +1698,9 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { val df = Seq(123L -> "123", 19157170390056973L -> 
"19157170390056971").toDF("i", "j") checkAnswer(df.select($"i" === $"j"), Row(true) :: Row(false) :: Nil) } + + test("SPARK-19691 Calculating percentile of decimal column fails with ClassCastException") { + val df = spark.range(1).selectExpr("CAST(id as DECIMAL) as x").selectExpr("percentile(x, 0.5)") + checkAnswer(df, Row(BigDecimal(0.0)) :: Nil) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index f780fc0ec013..2e006735d123 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -364,8 +364,8 @@ class JoinSuite extends QueryTest with SharedSQLContext { upperCaseData.where('N <= 4).createOrReplaceTempView("`left`") upperCaseData.where('N >= 3).createOrReplaceTempView("`right`") - val left = UnresolvedRelation(TableIdentifier("left"), None) - val right = UnresolvedRelation(TableIdentifier("right"), None) + val left = UnresolvedRelation(TableIdentifier("left")) + val right = UnresolvedRelation(TableIdentifier("right")) checkAnswer( left.join(right, $"left.N" === $"right.N", "full"), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 40d0ce099217..468ea0551298 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -20,8 +20,10 @@ package org.apache.spark.sql import java.io.File import java.math.MathContext import java.sql.Timestamp +import java.util.concurrent.atomic.AtomicBoolean import org.apache.spark.{AccumulatorSuite, SparkException} +import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart} import org.apache.spark.sql.catalyst.util.StringUtils import org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, CartesianProductExec, SortMergeJoinExec} @@ -209,8 +211,8 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } test("grouping on nested fields") { - spark.read.json(sparkContext.parallelize( - """{"nested": {"attribute": 1}, "value": 2}""" :: Nil)) + spark.read + .json(Seq("""{"nested": {"attribute": 1}, "value": 2}""").toDS()) .createOrReplaceTempView("rows") checkAnswer( @@ -227,9 +229,8 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } test("SPARK-6201 IN type conversion") { - spark.read.json( - sparkContext.parallelize( - Seq("{\"a\": \"1\"}}", "{\"a\": \"2\"}}", "{\"a\": \"3\"}}"))) + spark.read + .json(Seq("{\"a\": \"1\"}}", "{\"a\": \"2\"}}", "{\"a\": \"3\"}}").toDS()) .createOrReplaceTempView("d") checkAnswer( @@ -238,9 +239,8 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } test("SPARK-11226 Skip empty line in json file") { - spark.read.json( - sparkContext.parallelize( - Seq("{\"a\": \"1\"}}", "{\"a\": \"2\"}}", "{\"a\": \"3\"}}", ""))) + spark.read + .json(Seq("{\"a\": \"1\"}}", "{\"a\": \"2\"}}", "{\"a\": \"3\"}}", "").toDS()) .createOrReplaceTempView("d") checkAnswer( @@ -1212,8 +1212,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } test("SPARK-3483 Special chars in column names") { - val data = sparkContext.parallelize( - Seq("""{"key?number1": "value1", "key.number2": "value2"}""")) + val data = Seq("""{"key?number1": "value1", "key.number2": "value2"}""").toDS() spark.read.json(data).createOrReplaceTempView("records") sql("SELECT `key?number1`, `key.number2` 
FROM records") } @@ -1255,13 +1254,13 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } test("SPARK-4322 Grouping field with struct field as sub expression") { - spark.read.json(sparkContext.makeRDD("""{"a": {"b": [{"c": 1}]}}""" :: Nil)) + spark.read.json(Seq("""{"a": {"b": [{"c": 1}]}}""").toDS()) .createOrReplaceTempView("data") checkAnswer(sql("SELECT a.b[0].c FROM data GROUP BY a.b[0].c"), Row(1)) spark.catalog.dropTempView("data") - spark.read.json( - sparkContext.makeRDD("""{"a": {"b": 1}}""" :: Nil)).createOrReplaceTempView("data") + spark.read.json(Seq("""{"a": {"b": 1}}""").toDS()) + .createOrReplaceTempView("data") checkAnswer(sql("SELECT a.b + 1 FROM data GROUP BY a.b + 1"), Row(2)) spark.catalog.dropTempView("data") } @@ -1309,8 +1308,8 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } test("SPARK-6145: ORDER BY test for nested fields") { - spark.read.json(sparkContext.makeRDD( - """{"a": {"b": 1, "a": {"a": 1}}, "c": [{"d": 1}]}""" :: Nil)) + spark.read + .json(Seq("""{"a": {"b": 1, "a": {"a": 1}}, "c": [{"d": 1}]}""").toDS()) .createOrReplaceTempView("nestedOrder") checkAnswer(sql("SELECT 1 FROM nestedOrder ORDER BY a.b"), Row(1)) @@ -1323,7 +1322,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("SPARK-6145: special cases") { spark.read - .json(sparkContext.makeRDD("""{"a": {"b": [1]}, "b": [{"a": 1}], "_c0": {"a": 1}}""" :: Nil)) + .json(Seq("""{"a": {"b": [1]}, "b": [{"a": 1}], "_c0": {"a": 1}}""").toDS()) .createOrReplaceTempView("t") checkAnswer(sql("SELECT a.b[0] FROM t ORDER BY _c0.a"), Row(1)) @@ -1331,8 +1330,8 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } test("SPARK-6898: complete support for special chars in column names") { - spark.read.json(sparkContext.makeRDD( - """{"a": {"c.b": 1}, "b.$q": [{"a@!.q": 1}], "q.w": {"w.i&": [1]}}""" :: Nil)) + spark.read + .json(Seq("""{"a": {"c.b": 1}, "b.$q": [{"a@!.q": 1}], "q.w": {"w.i&": [1]}}""").toDS()) .createOrReplaceTempView("t") checkAnswer(sql("SELECT a.`c.b`, `b.$q`[0].`a@!.q`, `q.w`.`w.i&`[0] FROM t"), Row(1, 1, 1)) @@ -1435,8 +1434,9 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("SPARK-7067: order by queries for complex ExtractValue chain") { withTempView("t") { - spark.read.json(sparkContext.makeRDD( - """{"a": {"b": [{"c": 1}]}, "b": [{"d": 1}]}""" :: Nil)).createOrReplaceTempView("t") + spark.read + .json(Seq("""{"a": {"b": [{"c": 1}]}, "b": [{"d": 1}]}""").toDS()) + .createOrReplaceTempView("t") checkAnswer(sql("SELECT a.b FROM t ORDER BY b[0].d"), Row(Seq(Row(1)))) } } @@ -2107,8 +2107,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { |"a": [{"count": 3}], "b": [{"e": "test", "count": 1}]}}}' | """.stripMargin - val rdd = sparkContext.parallelize(Array(json)) - spark.read.json(rdd).write.mode("overwrite").parquet(dir.toString) + spark.read.json(Seq(json).toDS()).write.mode("overwrite").parquet(dir.toString) spark.read.parquet(dir.toString).collect() } } @@ -2564,4 +2563,27 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { checkAnswer(sql(badQuery), Row(1) :: Nil) } + test("SPARK-19650: An action on a Command should not trigger a Spark job") { + // Create a listener that checks if new jobs have started. + val jobStarted = new AtomicBoolean(false) + val listener = new SparkListener { + override def onJobStart(jobStart: SparkListenerJobStart): Unit = { + jobStarted.set(true) + } + } + + // Make sure no spurious job starts are pending in the listener bus. 
+ sparkContext.listenerBus.waitUntilEmpty(500) + sparkContext.addSparkListener(listener) + try { + // Execute the command. + sql("show databases").head() + + // Make sure we have seen all events triggered by DataFrame.show() + sparkContext.listenerBus.waitUntilEmpty(500) + } finally { + sparkContext.removeSparkListener(listener) + } + assert(!jobStarted.get(), "Command should not trigger a Spark job.") + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala index 91aecca537fb..68ababcd1102 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala @@ -98,7 +98,9 @@ class SQLQueryTestSuite extends QueryTest with SharedSQLContext { /** List of test cases to ignore, in lower cases. */ private val blackList = Set( - "blacklist.sql" // Do NOT remove this one. It is here to test the blacklist functionality. + "blacklist.sql", // Do NOT remove this one. It is here to test the blacklist functionality. + ".DS_Store" // A meta-file that may be created on Mac by Finder App. + // We should ignore this file from processing. ) // Create all the test cases. @@ -121,7 +123,7 @@ class SQLQueryTestSuite extends QueryTest with SharedSQLContext { } private def createScalaTestCase(testCase: TestCase): Unit = { - if (blackList.contains(testCase.name.toLowerCase)) { + if (blackList.exists(t => testCase.name.toLowerCase.contains(t.toLowerCase))) { // Create a test case to ignore this case. ignore(testCase.name) { /* Do nothing */ } } else { @@ -226,12 +228,12 @@ class SQLQueryTestSuite extends QueryTest with SharedSQLContext { if (isSorted(df.queryExecution.analyzed)) (schema, answer) else (schema, answer.sorted) } catch { - case a: AnalysisException if a.plan.nonEmpty => + case a: AnalysisException => // Do not output the logical plan tree which contains expression IDs. // Also implement a crude way of masking expression IDs in the error message // with a generic pattern "###". - (StructType(Seq.empty), - Seq(a.getClass.getName, a.getSimpleMessage.replaceAll("#\\d+", "#x"))) + val msg = if (a.plan.nonEmpty) a.getSimpleMessage else a.getMessage + (StructType(Seq.empty), Seq(a.getClass.getName, msg.replaceAll("#\\d+", "#x"))) case NonFatal(e) => // If there is an exception, put the exception class followed by the message. 
(StructType(Seq.empty), Seq(e.getClass.getName, e.getMessage)) @@ -241,7 +243,9 @@ class SQLQueryTestSuite extends QueryTest with SharedSQLContext { private def listTestCases(): Seq[TestCase] = { listFilesRecursively(new File(inputFilePath)).map { file => val resultFile = file.getAbsolutePath.replace(inputFilePath, goldenFilePath) + ".out" - TestCase(file.getName, file.getAbsolutePath, resultFile) + val absPath = file.getAbsolutePath + val testCaseName = absPath.stripPrefix(inputFilePath).stripPrefix(File.separator) + TestCase(testCaseName, absPath, resultFile) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala index bd1ce8aa3eb1..bbb31dbc8f3d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala @@ -170,6 +170,27 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared } checkColStats(df, mutable.LinkedHashMap(expectedColStats: _*)) } + + test("number format in statistics") { + val numbers = Seq( + BigInt(0) -> ("0.0 B", "0"), + BigInt(100) -> ("100.0 B", "100"), + BigInt(2047) -> ("2047.0 B", "2.05E+3"), + BigInt(2048) -> ("2.0 KB", "2.05E+3"), + BigInt(3333333) -> ("3.2 MB", "3.33E+6"), + BigInt(4444444444L) -> ("4.1 GB", "4.44E+9"), + BigInt(5555555555555L) -> ("5.1 TB", "5.56E+12"), + BigInt(6666666666666666L) -> ("5.9 PB", "6.67E+15"), + BigInt(1L << 10 ) * (1L << 60) -> ("1024.0 EB", "1.18E+21"), + BigInt(1L << 11) * (1L << 60) -> ("2.36E+21 B", "2.36E+21") + ) + numbers.foreach { case (input, (expectedSize, expectedRows)) => + val stats = Statistics(sizeInBytes = input, rowCount = Some(input)) + val expectedString = s"sizeInBytes=$expectedSize, rowCount=$expectedRows," + + s" isBroadcastable=${stats.isBroadcastable}" + assert(stats.simpleString == expectedString) + } + } } @@ -285,7 +306,7 @@ abstract class StatisticsCollectionTestBase extends QueryTest with SQLTestUtils // Analyze only one column. 
sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS FOR COLUMNS c1") val (relation, catalogTable) = spark.table(tableName).queryExecution.analyzed.collect { - case catalogRel: CatalogRelation => (catalogRel, catalogRel.catalogTable) + case catalogRel: CatalogRelation => (catalogRel, catalogRel.tableMeta) case logicalRel: LogicalRelation => (logicalRel, logicalRel.catalogTable.get) }.head val emptyColStat = ColumnStat(0, None, None, 0, 4, 4) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala index c7a77daacab7..b096a6db8517 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala @@ -221,8 +221,7 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetT StructField("vec", new UDT.MyDenseVectorUDT, false) )) - val stringRDD = sparkContext.parallelize(data) - val jsonRDD = spark.read.schema(schema).json(stringRDD) + val jsonRDD = spark.read.schema(schema).json(data.toDS()) checkAnswer( jsonRDD, Row(1, new UDT.MyDenseVector(Array(1.1, 2.2, 3.3, 4.4))) :: @@ -242,8 +241,7 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetT StructField("vec", new UDT.MyDenseVectorUDT, false) )) - val stringRDD = sparkContext.parallelize(data) - val jsonDataset = spark.read.schema(schema).json(stringRDD) + val jsonDataset = spark.read.schema(schema).json(data.toDS()) .as[(Int, UDT.MyDenseVector)] checkDataset( jsonDataset, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala index 15e490fb30a2..bb6c486e880a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala @@ -19,10 +19,12 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.SaveMode import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation, UnresolvedStar} import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, CatalogTable, CatalogTableType} +import org.apache.spark.sql.catalyst.expressions.{Ascending, SortOrder} import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.plans.PlanTest -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, RepartitionByExpression, Sort} import org.apache.spark.sql.execution.command._ import org.apache.spark.sql.execution.datasources.CreateTable import org.apache.spark.sql.internal.{HiveSerDe, SQLConf} @@ -36,7 +38,8 @@ import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructType */ class SparkSqlParserSuite extends PlanTest { - private lazy val parser = new SparkSqlParser(new SQLConf) + val newConf = new SQLConf + private lazy val parser = new SparkSqlParser(newConf) /** * Normalizes plans: @@ -251,4 +254,29 @@ class SparkSqlParserSuite extends PlanTest { assertEqual("ANALYZE TABLE t COMPUTE STATISTICS FOR COLUMNS key, value", AnalyzeColumnCommand(TableIdentifier("t"), Seq("key", "value"))) } + + test("query organization") { + // Test all valid combinations of order by/sort by/distribute by/cluster by/limit/windows + val baseSql = 
"select * from t" + val basePlan = + Project(Seq(UnresolvedStar(None)), UnresolvedRelation(TableIdentifier("t"))) + + assertEqual(s"$baseSql distribute by a, b", + RepartitionByExpression(UnresolvedAttribute("a") :: UnresolvedAttribute("b") :: Nil, + basePlan, + numPartitions = newConf.numShufflePartitions)) + assertEqual(s"$baseSql distribute by a sort by b", + Sort(SortOrder(UnresolvedAttribute("b"), Ascending) :: Nil, + global = false, + RepartitionByExpression(UnresolvedAttribute("a") :: Nil, + basePlan, + numPartitions = newConf.numShufflePartitions))) + assertEqual(s"$baseSql cluster by a, b", + Sort(SortOrder(UnresolvedAttribute("a"), Ascending) :: + SortOrder(UnresolvedAttribute("b"), Ascending) :: Nil, + global = false, + RepartitionByExpression(UnresolvedAttribute("a") :: UnresolvedAttribute("b") :: Nil, + basePlan, + numPartitions = newConf.numShufflePartitions))) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/TPCDSQueryBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/TPCDSQueryBenchmark.scala index 3988d9750b58..239822b72034 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/TPCDSQueryBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/TPCDSQueryBenchmark.scala @@ -73,13 +73,13 @@ object TPCDSQueryBenchmark { // per-row processing time for those cases. val queryRelations = scala.collection.mutable.HashSet[String]() spark.sql(queryString).queryExecution.logical.map { - case ur @ UnresolvedRelation(t: TableIdentifier, _) => + case ur @ UnresolvedRelation(t: TableIdentifier) => queryRelations.add(t.table) case lp: LogicalPlan => lp.expressions.foreach { _ foreach { case subquery: SubqueryExpression => subquery.plan.foreach { - case ur @ UnresolvedRelation(t: TableIdentifier, _) => + case ur @ UnresolvedRelation(t: TableIdentifier) => queryRelations.add(t.table) case _ => } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/WideSchemaBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/WideSchemaBenchmark.scala index d2704b3d3f37..a42891e55a18 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/WideSchemaBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/WideSchemaBenchmark.scala @@ -140,7 +140,7 @@ class WideSchemaBenchmark extends SparkFunSuite with BeforeAndAfterEach { } datum += "}" datum = s"""{"a": {"b": {"c": $datum, "d": $datum}, "e": $datum}}""" - val df = sparkSession.read.json(sparkSession.range(numRows).map(_ => datum).rdd).cache() + val df = sparkSession.read.json(sparkSession.range(numRows).map(_ => datum)).cache() df.count() // force caching addCases(benchmark, df, s"$width wide x $numRows rows", "a.b.c.value_1") } @@ -157,7 +157,7 @@ class WideSchemaBenchmark extends SparkFunSuite with BeforeAndAfterEach { datum = "{\"value\": " + datum + "}" selector = selector + ".value" } - val df = sparkSession.read.json(sparkSession.range(numRows).map(_ => datum).rdd).cache() + val df = sparkSession.read.json(sparkSession.range(numRows).map(_ => datum)).cache() df.count() // force caching addCases(benchmark, df, s"$depth deep x $numRows rows", selector) } @@ -180,7 +180,7 @@ class WideSchemaBenchmark extends SparkFunSuite with BeforeAndAfterEach { } // TODO(ekl) seems like the json parsing is actually the majority of the time, perhaps // we should benchmark that too separately. 
- val df = sparkSession.read.json(sparkSession.range(numRows).map(_ => datum).rdd).cache() + val df = sparkSession.read.json(sparkSession.range(numRows).map(_ => datum)).cache() df.count() // force caching addCases(benchmark, df, s"$numNodes x $depth deep x $numRows rows", selector) } @@ -200,7 +200,7 @@ class WideSchemaBenchmark extends SparkFunSuite with BeforeAndAfterEach { } } datum += "]}" - val df = sparkSession.read.json(sparkSession.range(numRows).map(_ => datum).rdd).cache() + val df = sparkSession.read.json(sparkSession.range(numRows).map(_ => datum)).cache() df.count() // force caching addCases(benchmark, df, s"$width wide x $numRows rows", "value[0]") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index e1a3b247fd4f..8b8cd0fdf4db 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -1520,7 +1520,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { val e = intercept[AnalysisException] { sql("CREATE TABLE tab1 USING json") }.getMessage assert(e.contains("Unable to infer schema for JSON. It must be specified manually")) - sql(s"CREATE TABLE tab2 using json location '${tempDir.getCanonicalPath}'") + sql(s"CREATE TABLE tab2 using json location '${tempDir.toURI}'") checkAnswer(spark.table("tab2"), Row("a", "b")) } } @@ -1814,7 +1814,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { val defaultTablePath = spark.sessionState.catalog .getTableMetadata(TableIdentifier("tbl")).storage.locationUri.get - sql(s"ALTER TABLE tbl SET LOCATION '${dir.getCanonicalPath}'") + sql(s"ALTER TABLE tbl SET LOCATION '${dir.toURI}'") spark.catalog.refreshTable("tbl") // SET LOCATION won't move data from previous table path to new table path. 
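     // After repointing the table at the new, empty directory the table reads as empty,
     // while the rows written earlier stay under the previous default location on disk.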
assert(spark.table("tbl").count() == 0) @@ -1843,11 +1843,10 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { |OPTIONS(path "$dir") """.stripMargin) val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) - val expectedPath = dir.getAbsolutePath.stripSuffix("/") - assert(table.location.stripSuffix("/") == expectedPath) + assert(table.location == dir.getAbsolutePath) dir.delete - val tableLocFile = new File(table.location.stripPrefix("file:")) + val tableLocFile = new File(table.location) assert(!tableLocFile.exists) spark.sql("INSERT INTO TABLE t SELECT 'c', 1") assert(tableLocFile.exists) @@ -1859,8 +1858,8 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { assert(tableLocFile.exists) checkAnswer(spark.table("t"), Row("c", 1) :: Nil) - val newDir = dir.getAbsolutePath.stripSuffix("/") + "/x" - val newDirFile = new File(newDir) + val newDirFile = new File(dir, "x") + val newDir = newDirFile.toURI.toString spark.sql(s"ALTER TABLE t SET LOCATION '$newDir'") spark.sessionState.catalog.refreshTable(TableIdentifier("t")) @@ -1886,8 +1885,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { |LOCATION "$dir" """.stripMargin) val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) - val expectedPath = dir.getAbsolutePath.stripSuffix("/") - assert(table.location.stripSuffix("/") == expectedPath) + assert(table.location == dir.getAbsolutePath) spark.sql("INSERT INTO TABLE t PARTITION(a=1, b=2) SELECT 3, 4") checkAnswer(spark.table("t"), Row(3, 4, 1, 2) :: Nil) @@ -1913,18 +1911,18 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { |OPTIONS(path "$dir") """.stripMargin) val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) - val expectedPath = dir.getAbsolutePath.stripSuffix("/") - assert(table.location.stripSuffix("/") == expectedPath) + assert(table.location == dir.getAbsolutePath) dir.delete() checkAnswer(spark.table("t"), Nil) - val newDir = dir.getAbsolutePath.stripSuffix("/") + "/x" + val newDirFile = new File(dir, "x") + val newDir = newDirFile.toURI.toString spark.sql(s"ALTER TABLE t SET LOCATION '$newDir'") val table1 = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) assert(table1.location == newDir) - assert(!new File(newDir).exists()) + assert(!newDirFile.exists()) checkAnswer(spark.table("t"), Nil) } } @@ -1951,4 +1949,51 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { } } } + + Seq(true, false).foreach { shouldDelete => + val tcName = if (shouldDelete) "non-existent" else "existed" + test(s"CTAS for external data source table with a $tcName location") { + withTable("t", "t1") { + withTempDir { + dir => + if (shouldDelete) { + dir.delete() + } + spark.sql( + s""" + |CREATE TABLE t + |USING parquet + |LOCATION '$dir' + |AS SELECT 3 as a, 4 as b, 1 as c, 2 as d + """.stripMargin) + val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) + assert(table.location == dir.getAbsolutePath) + + checkAnswer(spark.table("t"), Row(3, 4, 1, 2)) + } + // partition table + withTempDir { + dir => + if (shouldDelete) { + dir.delete() + } + spark.sql( + s""" + |CREATE TABLE t1 + |USING parquet + |PARTITIONED BY(a, b) + |LOCATION '$dir' + |AS SELECT 3 as a, 4 as b, 1 as c, 2 as d + """.stripMargin) + val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t1")) + assert(table.location == dir.getAbsolutePath) + + val 
partDir = new File(dir, "a=3") + assert(partDir.exists()) + + checkAnswer(spark.table("t1"), Row(1, 2, 3, 4)) + } + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala index 2b4c9f3ed327..7ea406492757 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala @@ -27,6 +27,7 @@ import org.apache.hadoop.fs.{FileStatus, Path, RawLocalFileSystem} import org.apache.spark.metrics.source.HiveCatalogMetrics import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext class FileIndexSuite extends SharedSQLContext { @@ -178,6 +179,47 @@ class FileIndexSuite extends SharedSQLContext { assert(catalog2.allFiles().nonEmpty) } } + + test("InMemoryFileIndex with empty rootPaths when PARALLEL_PARTITION_DISCOVERY_THRESHOLD" + + "is a nonpositive number") { + withSQLConf(SQLConf.PARALLEL_PARTITION_DISCOVERY_THRESHOLD.key -> "0") { + new InMemoryFileIndex(spark, Seq.empty, Map.empty, None) + } + + val e = intercept[IllegalArgumentException] { + withSQLConf(SQLConf.PARALLEL_PARTITION_DISCOVERY_THRESHOLD.key -> "-1") { + new InMemoryFileIndex(spark, Seq.empty, Map.empty, None) + } + }.getMessage + assert(e.contains("The maximum number of paths allowed for listing files at " + + "driver side must not be negative")) + } + + test("refresh for InMemoryFileIndex with FileStatusCache") { + withTempDir { dir => + val fileStatusCache = FileStatusCache.getOrCreate(spark) + val dirPath = new Path(dir.getAbsolutePath) + val fs = dirPath.getFileSystem(spark.sessionState.newHadoopConf()) + val catalog = + new InMemoryFileIndex(spark, Seq(dirPath), Map.empty, None, fileStatusCache) { + def leafFilePaths: Seq[Path] = leafFiles.keys.toSeq + def leafDirPaths: Seq[Path] = leafDirToChildrenFiles.keys.toSeq + } + + val file = new File(dir, "text.txt") + stringToFile(file, "text") + assert(catalog.leafDirPaths.isEmpty) + assert(catalog.leafFilePaths.isEmpty) + + catalog.refresh() + + assert(catalog.leafFilePaths.size == 1) + assert(catalog.leafFilePaths.head == fs.makeQualified(new Path(file.getAbsolutePath))) + + assert(catalog.leafDirPaths.size == 1) + assert(catalog.leafDirPaths.head == fs.makeQualified(dirPath)) + } + } } class FakeParentPathFileSystem extends RawLocalFileSystem { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index b0cc5933601c..90a12933e5f8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -24,11 +24,12 @@ import java.text.SimpleDateFormat import java.util.Locale import org.apache.commons.lang3.time.FastDateFormat -import org.apache.hadoop.io.SequenceFile.CompressionType import org.apache.hadoop.io.compress.GzipCodec +import org.apache.hadoop.io.SequenceFile.CompressionType import org.apache.spark.SparkException -import org.apache.spark.sql.{DataFrame, QueryTest, Row, UDT} +import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row, UDT} +import org.apache.spark.sql.functions.{col, regexp_replace} import org.apache.spark.sql.internal.SQLConf import 
org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} import org.apache.spark.sql.types._ @@ -53,8 +54,10 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { private val numbersFile = "test-data/numbers.csv" private val datesFile = "test-data/dates.csv" private val unescapedQuotesFile = "test-data/unescaped-quotes.csv" + private val valueMalformedFile = "test-data/value-malformed.csv" private val filenameSpecialChr = "filename19340*.csv" + private def testFile(fileName: String): String = { Thread.currentThread().getContextClassLoader.getResource(fileName).toString } @@ -243,12 +246,15 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("test for DROPMALFORMED parsing mode") { - val cars = spark.read - .format("csv") - .options(Map("header" -> "true", "mode" -> "dropmalformed")) - .load(testFile(carsFile)) + Seq(false, true).foreach { wholeFile => + val cars = spark.read + .format("csv") + .option("wholeFile", wholeFile) + .options(Map("header" -> "true", "mode" -> "dropmalformed")) + .load(testFile(carsFile)) - assert(cars.select("year").collect().size === 2) + assert(cars.select("year").collect().size === 2) + } } test("test for blank column names on read and select columns") { @@ -263,14 +269,17 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("test for FAILFAST parsing mode") { - val exception = intercept[SparkException]{ - spark.read - .format("csv") - .options(Map("header" -> "true", "mode" -> "failfast")) - .load(testFile(carsFile)).collect() - } + Seq(false, true).foreach { wholeFile => + val exception = intercept[SparkException] { + spark.read + .format("csv") + .option("wholeFile", wholeFile) + .options(Map("header" -> "true", "mode" -> "failfast")) + .load(testFile(carsFile)).collect() + } - assert(exception.getMessage.contains("Malformed line in FAILFAST mode: 2015,Chevy,Volt")) + assert(exception.getMessage.contains("Malformed line in FAILFAST mode: 2015,Chevy,Volt")) + } } test("test for tokens more than the fields in the schema") { @@ -701,12 +710,12 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { }.getMessage assert(msg.contains("CSV data source does not support array data type")) - msg = intercept[SparkException] { + msg = intercept[UnsupportedOperationException] { val schema = StructType(StructField("a", new UDT.MyDenseVectorUDT(), true) :: Nil) spark.range(1).write.csv(csvDir) spark.read.schema(schema).csv(csvDir).collect() - }.getCause.getMessage - assert(msg.contains("Unsupported type: array")) + }.getMessage + assert(msg.contains("CSV data source does not support array data type.")) } } @@ -735,10 +744,11 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { .save(iso8601timestampsPath) // This will load back the timestamps as string. + val stringSchema = StructType(StructField("date", StringType, true) :: Nil) val iso8601Timestamps = spark.read .format("csv") + .schema(stringSchema) .option("header", "true") - .option("inferSchema", "false") .load(iso8601timestampsPath) val iso8501 = FastDateFormat.getInstance("yyyy-MM-dd'T'HH:mm:ss.SSSZZ", Locale.US) @@ -768,10 +778,11 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { .save(iso8601datesPath) // This will load back the dates as string. 
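+    // An explicit StringType schema replaces `inferSchema = false` here and in the similar
+    // blocks around it, so the written values are loaded back as plain strings for comparison.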
+ val stringSchema = StructType(StructField("date", StringType, true) :: Nil) val iso8601dates = spark.read .format("csv") + .schema(stringSchema) .option("header", "true") - .option("inferSchema", "false") .load(iso8601datesPath) val iso8501 = FastDateFormat.getInstance("yyyy-MM-dd", Locale.US) @@ -826,10 +837,11 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { .save(datesWithFormatPath) // This will load back the dates as string. + val stringSchema = StructType(StructField("date", StringType, true) :: Nil) val stringDatesWithFormat = spark.read .format("csv") + .schema(stringSchema) .option("header", "true") - .option("inferSchema", "false") .load(datesWithFormatPath) val expectedStringDatesWithFormat = Seq( Row("2015/08/26"), @@ -857,10 +869,11 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { .save(timestampsWithFormatPath) // This will load back the timestamps as string. + val stringSchema = StructType(StructField("date", StringType, true) :: Nil) val stringTimestampsWithFormat = spark.read .format("csv") + .schema(stringSchema) .option("header", "true") - .option("inferSchema", "false") .load(timestampsWithFormatPath) val expectedStringTimestampsWithFormat = Seq( Row("2015/08/26 18:00"), @@ -889,10 +902,11 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { .save(timestampsWithFormatPath) // This will load back the timestamps as string. + val stringSchema = StructType(StructField("date", StringType, true) :: Nil) val stringTimestampsWithFormat = spark.read .format("csv") + .schema(stringSchema) .option("header", "true") - .option("inferSchema", "false") .load(timestampsWithFormatPath) val expectedStringTimestampsWithFormat = Seq( Row("2015/08/27 01:00"), @@ -960,6 +974,125 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } } + + test("SPARK-18699 put malformed records in a `columnNameOfCorruptRecord` field") { + Seq(false, true).foreach { wholeFile => + val schema = new StructType().add("a", IntegerType).add("b", TimestampType) + val df1 = spark + .read + .option("mode", "PERMISSIVE") + .option("wholeFile", wholeFile) + .schema(schema) + .csv(testFile(valueMalformedFile)) + checkAnswer(df1, + Row(null, null) :: + Row(1, java.sql.Date.valueOf("1983-08-04")) :: + Nil) + + // If `schema` has `columnNameOfCorruptRecord`, it should handle corrupt records + val columnNameOfCorruptRecord = "_unparsed" + val schemaWithCorrField1 = schema.add(columnNameOfCorruptRecord, StringType) + val df2 = spark + .read + .option("mode", "PERMISSIVE") + .option("columnNameOfCorruptRecord", columnNameOfCorruptRecord) + .option("wholeFile", wholeFile) + .schema(schemaWithCorrField1) + .csv(testFile(valueMalformedFile)) + checkAnswer(df2, + Row(null, null, "0,2013-111-11 12:13:14") :: + Row(1, java.sql.Date.valueOf("1983-08-04"), null) :: + Nil) + + // We put a `columnNameOfCorruptRecord` field in the middle of a schema + val schemaWithCorrField2 = new StructType() + .add("a", IntegerType) + .add(columnNameOfCorruptRecord, StringType) + .add("b", TimestampType) + val df3 = spark + .read + .option("mode", "PERMISSIVE") + .option("columnNameOfCorruptRecord", columnNameOfCorruptRecord) + .option("wholeFile", wholeFile) + .schema(schemaWithCorrField2) + .csv(testFile(valueMalformedFile)) + checkAnswer(df3, + Row(null, "0,2013-111-11 12:13:14", null) :: + Row(1, null, java.sql.Date.valueOf("1983-08-04")) :: + Nil) + + val errMsg = intercept[AnalysisException] { + spark + .read + .option("mode", 
"PERMISSIVE") + .option("columnNameOfCorruptRecord", columnNameOfCorruptRecord) + .option("wholeFile", wholeFile) + .schema(schema.add(columnNameOfCorruptRecord, IntegerType)) + .csv(testFile(valueMalformedFile)) + .collect + }.getMessage + assert(errMsg.startsWith("The field for corrupt records must be string type and nullable")) + } + } + + test("SPARK-19610: Parse normal multi-line CSV files") { + val primitiveFieldAndType = Seq( + """" + |string","integer + | + | + |","long + | + |","bigInteger",double,boolean,null""".stripMargin, + """"this is a + |simple + |string."," + | + |10"," + |21474836470","92233720368547758070"," + | + |1.7976931348623157E308",true,""".stripMargin) + + withTempPath { path => + primitiveFieldAndType.toDF("value").coalesce(1).write.text(path.getAbsolutePath) + + val df = spark.read + .option("header", true) + .option("wholeFile", true) + .csv(path.getAbsolutePath) + + // Check if headers have new lines in the names. + val actualFields = df.schema.fieldNames.toSeq + val expectedFields = + Seq("\nstring", "integer\n\n\n", "long\n\n", "bigInteger", "double", "boolean", "null") + assert(actualFields === expectedFields) + + // Check if the rows have new lines in the values. + val expected = Row( + "this is a\nsimple\nstring.", + "\n\n10", + "\n21474836470", + "92233720368547758070", + "\n\n1.7976931348623157E308", + "true", + null) + checkAnswer(df, expected) + } + } + + test("Empty file produces empty dataframe with empty schema - wholeFile option") { + withTempPath { path => + path.createNewFile() + + val df = spark.read.format("csv") + .option("header", true) + .option("wholeFile", true) + .load(path.getAbsolutePath) + + assert(df.schema === spark.emptyDataFrame.schema) + checkAnswer(df, spark.emptyDataFrame) + } + test("SPARK-19340 special characters in csv file name") { val csvDF = spark.read .option("header", "false") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonParsingOptionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonParsingOptionsSuite.scala index 0b72da5f3759..6e2b4f0df595 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonParsingOptionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonParsingOptionsSuite.scala @@ -25,19 +25,18 @@ import org.apache.spark.sql.test.SharedSQLContext * Test cases for various [[JSONOptions]]. 
*/ class JsonParsingOptionsSuite extends QueryTest with SharedSQLContext { + import testImplicits._ test("allowComments off") { val str = """{'name': /* hello */ 'Reynold Xin'}""" - val rdd = spark.sparkContext.parallelize(Seq(str)) - val df = spark.read.json(rdd) + val df = spark.read.json(Seq(str).toDS()) assert(df.schema.head.name == "_corrupt_record") } test("allowComments on") { val str = """{'name': /* hello */ 'Reynold Xin'}""" - val rdd = spark.sparkContext.parallelize(Seq(str)) - val df = spark.read.option("allowComments", "true").json(rdd) + val df = spark.read.option("allowComments", "true").json(Seq(str).toDS()) assert(df.schema.head.name == "name") assert(df.first().getString(0) == "Reynold Xin") @@ -45,16 +44,14 @@ class JsonParsingOptionsSuite extends QueryTest with SharedSQLContext { test("allowSingleQuotes off") { val str = """{'name': 'Reynold Xin'}""" - val rdd = spark.sparkContext.parallelize(Seq(str)) - val df = spark.read.option("allowSingleQuotes", "false").json(rdd) + val df = spark.read.option("allowSingleQuotes", "false").json(Seq(str).toDS()) assert(df.schema.head.name == "_corrupt_record") } test("allowSingleQuotes on") { val str = """{'name': 'Reynold Xin'}""" - val rdd = spark.sparkContext.parallelize(Seq(str)) - val df = spark.read.json(rdd) + val df = spark.read.json(Seq(str).toDS()) assert(df.schema.head.name == "name") assert(df.first().getString(0) == "Reynold Xin") @@ -62,16 +59,14 @@ class JsonParsingOptionsSuite extends QueryTest with SharedSQLContext { test("allowUnquotedFieldNames off") { val str = """{name: 'Reynold Xin'}""" - val rdd = spark.sparkContext.parallelize(Seq(str)) - val df = spark.read.json(rdd) + val df = spark.read.json(Seq(str).toDS()) assert(df.schema.head.name == "_corrupt_record") } test("allowUnquotedFieldNames on") { val str = """{name: 'Reynold Xin'}""" - val rdd = spark.sparkContext.parallelize(Seq(str)) - val df = spark.read.option("allowUnquotedFieldNames", "true").json(rdd) + val df = spark.read.option("allowUnquotedFieldNames", "true").json(Seq(str).toDS()) assert(df.schema.head.name == "name") assert(df.first().getString(0) == "Reynold Xin") @@ -79,16 +74,14 @@ class JsonParsingOptionsSuite extends QueryTest with SharedSQLContext { test("allowNumericLeadingZeros off") { val str = """{"age": 0018}""" - val rdd = spark.sparkContext.parallelize(Seq(str)) - val df = spark.read.json(rdd) + val df = spark.read.json(Seq(str).toDS()) assert(df.schema.head.name == "_corrupt_record") } test("allowNumericLeadingZeros on") { val str = """{"age": 0018}""" - val rdd = spark.sparkContext.parallelize(Seq(str)) - val df = spark.read.option("allowNumericLeadingZeros", "true").json(rdd) + val df = spark.read.option("allowNumericLeadingZeros", "true").json(Seq(str).toDS()) assert(df.schema.head.name == "age") assert(df.first().getLong(0) == 18) @@ -98,16 +91,14 @@ class JsonParsingOptionsSuite extends QueryTest with SharedSQLContext { // JsonParser.Feature.ALLOW_NON_NUMERIC_NUMBERS. 
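  // (Both cases below stay `ignore`d: `allowNonNumericNumbers` maps onto that Jackson feature.)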
ignore("allowNonNumericNumbers off") { val str = """{"age": NaN}""" - val rdd = spark.sparkContext.parallelize(Seq(str)) - val df = spark.read.json(rdd) + val df = spark.read.json(Seq(str).toDS()) assert(df.schema.head.name == "_corrupt_record") } ignore("allowNonNumericNumbers on") { val str = """{"age": NaN}""" - val rdd = spark.sparkContext.parallelize(Seq(str)) - val df = spark.read.option("allowNonNumericNumbers", "true").json(rdd) + val df = spark.read.option("allowNonNumericNumbers", "true").json(Seq(str).toDS()) assert(df.schema.head.name == "age") assert(df.first().getDouble(0).isNaN) @@ -115,16 +106,14 @@ class JsonParsingOptionsSuite extends QueryTest with SharedSQLContext { test("allowBackslashEscapingAnyCharacter off") { val str = """{"name": "Cazen Lee", "price": "\$10"}""" - val rdd = spark.sparkContext.parallelize(Seq(str)) - val df = spark.read.option("allowBackslashEscapingAnyCharacter", "false").json(rdd) + val df = spark.read.option("allowBackslashEscapingAnyCharacter", "false").json(Seq(str).toDS()) assert(df.schema.head.name == "_corrupt_record") } test("allowBackslashEscapingAnyCharacter on") { val str = """{"name": "Cazen Lee", "price": "\$10"}""" - val rdd = spark.sparkContext.parallelize(Seq(str)) - val df = spark.read.option("allowBackslashEscapingAnyCharacter", "true").json(rdd) + val df = spark.read.option("allowBackslashEscapingAnyCharacter", "true").json(Seq(str).toDS()) assert(df.schema.head.name == "name") assert(df.schema.last.name == "price") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 05aa2ab2ce2d..0aaf148dac25 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -590,7 +590,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { val dir = Utils.createTempDir() dir.delete() val path = dir.getCanonicalPath - primitiveFieldAndType.map(record => record.replaceAll("\n", " ")).saveAsTextFile(path) + primitiveFieldAndType.map(record => record.replaceAll("\n", " ")).write.text(path) val jsonDF = spark.read.json(path) val expectedSchema = StructType( @@ -622,7 +622,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { val dir = Utils.createTempDir() dir.delete() val path = dir.getCanonicalPath - primitiveFieldAndType.map(record => record.replaceAll("\n", " ")).saveAsTextFile(path) + primitiveFieldAndType.map(record => record.replaceAll("\n", " ")).write.text(path) val jsonDF = spark.read.option("primitivesAsString", "true").json(path) val expectedSchema = StructType( @@ -777,9 +777,9 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } test("Find compatible types even if inferred DecimalType is not capable of other IntegralType") { - val mixedIntegerAndDoubleRecords = sparkContext.parallelize( - """{"a": 3, "b": 1.1}""" :: - s"""{"a": 3.1, "b": 0.${"0" * 38}1}""" :: Nil) + val mixedIntegerAndDoubleRecords = Seq( + """{"a": 3, "b": 1.1}""", + s"""{"a": 3.1, "b": 0.${"0" * 38}1}""").toDS() val jsonDF = spark.read .option("prefersDecimal", "true") .json(mixedIntegerAndDoubleRecords) @@ -828,7 +828,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { val mergedJsonDF = spark.read .option("prefersDecimal", "true") - .json(floatingValueRecords ++ bigIntegerRecords) + 
.json(floatingValueRecords.union(bigIntegerRecords)) val expectedMergedSchema = StructType( StructField("a", DoubleType, true) :: @@ -846,7 +846,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { val dir = Utils.createTempDir() dir.delete() val path = dir.toURI.toString - primitiveFieldAndType.map(record => record.replaceAll("\n", " ")).saveAsTextFile(path) + primitiveFieldAndType.map(record => record.replaceAll("\n", " ")).write.text(path) sql( s""" @@ -873,7 +873,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { val dir = Utils.createTempDir() dir.delete() val path = dir.getCanonicalPath - primitiveFieldAndType.map(record => record.replaceAll("\n", " ")).saveAsTextFile(path) + primitiveFieldAndType.map(record => record.replaceAll("\n", " ")).write.text(path) val schema = StructType( StructField("bigInteger", DecimalType.SYSTEM_DEFAULT, true) :: @@ -1263,7 +1263,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { assert(result2(3) === "{\"f1\":{\"f11\":4,\"f12\":true},\"f2\":{\"D4\":2147483644}}") val jsonDF = spark.read.json(primitiveFieldAndType) - val primTable = spark.read.json(jsonDF.toJSON.rdd) + val primTable = spark.read.json(jsonDF.toJSON) primTable.createOrReplaceTempView("primitiveTable") checkAnswer( sql("select * from primitiveTable"), @@ -1276,7 +1276,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { ) val complexJsonDF = spark.read.json(complexFieldAndType1) - val compTable = spark.read.json(complexJsonDF.toJSON.rdd) + val compTable = spark.read.json(complexJsonDF.toJSON) compTable.createOrReplaceTempView("complexTable") // Access elements of a primitive array. checkAnswer( @@ -1364,10 +1364,10 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { }) } - test("SPARK-6245 JsonRDD.inferSchema on empty RDD") { + test("SPARK-6245 JsonInferSchema.infer on empty RDD") { // This is really a test that it doesn't throw an exception val emptySchema = JsonInferSchema.infer( - empty, + empty.rdd, new JSONOptions(Map.empty[String, String], "GMT"), CreateJacksonParser.string) assert(StructType(Seq()) === emptySchema) @@ -1394,7 +1394,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { test("SPARK-8093 Erase empty structs") { val emptySchema = JsonInferSchema.infer( - emptyRecords, + emptyRecords.rdd, new JSONOptions(Map.empty[String, String], "GMT"), CreateJacksonParser.string) assert(StructType(Seq()) === emptySchema) @@ -1592,7 +1592,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { val dir = Utils.createTempDir() dir.delete() val path = dir.getCanonicalPath - arrayAndStructRecords.map(record => record.replaceAll("\n", " ")).saveAsTextFile(path) + arrayAndStructRecords.map(record => record.replaceAll("\n", " ")).write.text(path) val schema = StructType( @@ -1609,7 +1609,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { val dir = Utils.createTempDir() dir.delete() val path = dir.getCanonicalPath - primitiveFieldAndType.map(record => record.replaceAll("\n", " ")).saveAsTextFile(path) + primitiveFieldAndType.map(record => record.replaceAll("\n", " ")).write.text(path) val jsonDF = spark.read.json(path) val jsonDir = new File(dir, "json").getCanonicalPath @@ -1645,7 +1645,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { dir.delete() val path = dir.getCanonicalPath - primitiveFieldAndType.map(record => record.replaceAll("\n", 
" ")).saveAsTextFile(path) + primitiveFieldAndType.map(record => record.replaceAll("\n", " ")).write.text(path) val jsonDF = spark.read.json(path) val jsonDir = new File(dir, "json").getCanonicalPath @@ -1693,8 +1693,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { val json = s""" |{"a": [{$nested}], "b": [{$nested}]} """.stripMargin - val rdd = spark.sparkContext.makeRDD(Seq(json)) - val df = spark.read.json(rdd) + val df = spark.read.json(Seq(json).toDS()) assert(df.schema.size === 2) df.collect() } @@ -1794,8 +1793,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } test("SPARK-18433: Improve DataSource option keys to be more case-insensitive") { - val records = sparkContext - .parallelize("""{"a": 3, "b": 1.1}""" :: """{"a": 3.1, "b": 0.000001}""" :: Nil) + val records = Seq("""{"a": 3, "b": 1.1}""", """{"a": 3.1, "b": 0.000001}""").toDS() val schema = StructType( StructField("a", DecimalType(21, 1), true) :: @@ -1944,4 +1942,35 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { assert(exceptionTwo.getMessage.contains("Malformed line in FAILFAST mode")) } } + + test("Throw an exception if a `columnNameOfCorruptRecord` field violates requirements") { + val columnNameOfCorruptRecord = "_unparsed" + val schema = StructType( + StructField(columnNameOfCorruptRecord, IntegerType, true) :: + StructField("a", StringType, true) :: + StructField("b", StringType, true) :: + StructField("c", StringType, true) :: Nil) + val errMsg = intercept[AnalysisException] { + spark.read + .option("mode", "PERMISSIVE") + .option("columnNameOfCorruptRecord", columnNameOfCorruptRecord) + .schema(schema) + .json(corruptRecords) + }.getMessage + assert(errMsg.startsWith("The field for corrupt records must be string type and nullable")) + + withTempPath { dir => + val path = dir.getCanonicalPath + corruptRecords.toDF("value").write.text(path) + val errMsg = intercept[AnalysisException] { + spark.read + .option("mode", "PERMISSIVE") + .option("columnNameOfCorruptRecord", columnNameOfCorruptRecord) + .schema(schema) + .json(path) + .collect + }.getMessage + assert(errMsg.startsWith("The field for corrupt records must be string type and nullable")) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala index a400940db924..13084ba4a7f0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala @@ -17,14 +17,13 @@ package org.apache.spark.sql.execution.datasources.json -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.{Dataset, Encoders, SparkSession} private[json] trait TestJsonData { protected def spark: SparkSession - def primitiveFieldAndType: RDD[String] = - spark.sparkContext.parallelize( + def primitiveFieldAndType: Dataset[String] = + spark.createDataset(spark.sparkContext.parallelize( """{"string":"this is a simple string.", "integer":10, "long":21474836470, @@ -32,10 +31,10 @@ private[json] trait TestJsonData { "double":1.7976931348623157E308, "boolean":true, "null":null - }""" :: Nil) + }""" :: Nil))(Encoders.STRING) - def primitiveFieldValueTypeConflict: RDD[String] = - spark.sparkContext.parallelize( + def primitiveFieldValueTypeConflict: Dataset[String] = + 
spark.createDataset(spark.sparkContext.parallelize( """{"num_num_1":11, "num_num_2":null, "num_num_3": 1.1, "num_bool":true, "num_str":13.1, "str_bool":"str1"}""" :: """{"num_num_1":null, "num_num_2":21474836470.9, "num_num_3": null, @@ -44,16 +43,17 @@ private[json] trait TestJsonData { "num_bool":false, "num_str":"str1", "str_bool":false}""" :: """{"num_num_1":21474836570, "num_num_2":1.1, "num_num_3": 21474836470, "num_bool":null, "num_str":92233720368547758070, "str_bool":null}""" :: Nil) + )(Encoders.STRING) - def jsonNullStruct: RDD[String] = - spark.sparkContext.parallelize( + def jsonNullStruct: Dataset[String] = + spark.createDataset(spark.sparkContext.parallelize( """{"nullstr":"","ip":"27.31.100.29","headers":{"Host":"1.abc.com","Charset":"UTF-8"}}""" :: """{"nullstr":"","ip":"27.31.100.29","headers":{}}""" :: """{"nullstr":"","ip":"27.31.100.29","headers":""}""" :: - """{"nullstr":null,"ip":"27.31.100.29","headers":null}""" :: Nil) + """{"nullstr":null,"ip":"27.31.100.29","headers":null}""" :: Nil))(Encoders.STRING) - def complexFieldValueTypeConflict: RDD[String] = - spark.sparkContext.parallelize( + def complexFieldValueTypeConflict: Dataset[String] = + spark.createDataset(spark.sparkContext.parallelize( """{"num_struct":11, "str_array":[1, 2, 3], "array":[], "struct_array":[], "struct": {}}""" :: """{"num_struct":{"field":false}, "str_array":null, @@ -62,24 +62,25 @@ private[json] trait TestJsonData { "array":[4, 5, 6], "struct_array":[7, 8, 9], "struct": {"field":null}}""" :: """{"num_struct":{}, "str_array":["str1", "str2", 33], "array":[7], "struct_array":{"field": true}, "struct": {"field": "str"}}""" :: Nil) + )(Encoders.STRING) - def arrayElementTypeConflict: RDD[String] = - spark.sparkContext.parallelize( + def arrayElementTypeConflict: Dataset[String] = + spark.createDataset(spark.sparkContext.parallelize( """{"array1": [1, 1.1, true, null, [], {}, [2,3,4], {"field":"str"}], "array2": [{"field":214748364700}, {"field":1}]}""" :: """{"array3": [{"field":"str"}, {"field":1}]}""" :: - """{"array3": [1, 2, 3]}""" :: Nil) + """{"array3": [1, 2, 3]}""" :: Nil))(Encoders.STRING) - def missingFields: RDD[String] = - spark.sparkContext.parallelize( + def missingFields: Dataset[String] = + spark.createDataset(spark.sparkContext.parallelize( """{"a":true}""" :: """{"b":21474836470}""" :: """{"c":[33, 44]}""" :: """{"d":{"field":true}}""" :: - """{"e":"str"}""" :: Nil) + """{"e":"str"}""" :: Nil))(Encoders.STRING) - def complexFieldAndType1: RDD[String] = - spark.sparkContext.parallelize( + def complexFieldAndType1: Dataset[String] = + spark.createDataset(spark.sparkContext.parallelize( """{"struct":{"field1": true, "field2": 92233720368547758070}, "structWithArrayFields":{"field1":[4, 5, 6], "field2":["str1", "str2"]}, "arrayOfString":["str1", "str2"], @@ -92,10 +93,10 @@ private[json] trait TestJsonData { "arrayOfStruct":[{"field1": true, "field2": "str1"}, {"field1": false}, {"field3": null}], "arrayOfArray1":[[1, 2, 3], ["str1", "str2"]], "arrayOfArray2":[[1, 2, 3], [1.1, 2.1, 3.1]] - }""" :: Nil) + }""" :: Nil))(Encoders.STRING) - def complexFieldAndType2: RDD[String] = - spark.sparkContext.parallelize( + def complexFieldAndType2: Dataset[String] = + spark.createDataset(spark.sparkContext.parallelize( """{"arrayOfStruct":[{"field1": true, "field2": "str1"}, {"field1": false}, {"field3": null}], "complexArrayOfStruct": [ { @@ -146,89 +147,90 @@ private[json] trait TestJsonData { {"inner3": [[{"inner4": 2}]]} ] ]] - }""" :: Nil) + }""" :: Nil))(Encoders.STRING) - def 
mapType1: RDD[String] = - spark.sparkContext.parallelize( + def mapType1: Dataset[String] = + spark.createDataset(spark.sparkContext.parallelize( """{"map": {"a": 1}}""" :: """{"map": {"b": 2}}""" :: """{"map": {"c": 3}}""" :: """{"map": {"c": 1, "d": 4}}""" :: - """{"map": {"e": null}}""" :: Nil) + """{"map": {"e": null}}""" :: Nil))(Encoders.STRING) - def mapType2: RDD[String] = - spark.sparkContext.parallelize( + def mapType2: Dataset[String] = + spark.createDataset(spark.sparkContext.parallelize( """{"map": {"a": {"field1": [1, 2, 3, null]}}}""" :: """{"map": {"b": {"field2": 2}}}""" :: """{"map": {"c": {"field1": [], "field2": 4}}}""" :: """{"map": {"c": {"field2": 3}, "d": {"field1": [null]}}}""" :: """{"map": {"e": null}}""" :: - """{"map": {"f": {"field1": null}}}""" :: Nil) + """{"map": {"f": {"field1": null}}}""" :: Nil))(Encoders.STRING) - def nullsInArrays: RDD[String] = - spark.sparkContext.parallelize( + def nullsInArrays: Dataset[String] = + spark.createDataset(spark.sparkContext.parallelize( """{"field1":[[null], [[["Test"]]]]}""" :: """{"field2":[null, [{"Test":1}]]}""" :: """{"field3":[[null], [{"Test":"2"}]]}""" :: - """{"field4":[[null, [1,2,3]]]}""" :: Nil) + """{"field4":[[null, [1,2,3]]]}""" :: Nil))(Encoders.STRING) - def jsonArray: RDD[String] = - spark.sparkContext.parallelize( + def jsonArray: Dataset[String] = + spark.createDataset(spark.sparkContext.parallelize( """[{"a":"str_a_1"}]""" :: """[{"a":"str_a_2"}, {"b":"str_b_3"}]""" :: """{"b":"str_b_4", "a":"str_a_4", "c":"str_c_4"}""" :: - """[]""" :: Nil) + """[]""" :: Nil))(Encoders.STRING) - def corruptRecords: RDD[String] = - spark.sparkContext.parallelize( + def corruptRecords: Dataset[String] = + spark.createDataset(spark.sparkContext.parallelize( """{""" :: """""" :: """{"a":1, b:2}""" :: """{"a":{, b:3}""" :: """{"b":"str_b_4", "a":"str_a_4", "c":"str_c_4"}""" :: - """]""" :: Nil) + """]""" :: Nil))(Encoders.STRING) - def additionalCorruptRecords: RDD[String] = - spark.sparkContext.parallelize( + def additionalCorruptRecords: Dataset[String] = + spark.createDataset(spark.sparkContext.parallelize( """{"dummy":"test"}""" :: """[1,2,3]""" :: """":"test", "a":1}""" :: """42""" :: - """ ","ian":"test"}""" :: Nil) + """ ","ian":"test"}""" :: Nil))(Encoders.STRING) - def emptyRecords: RDD[String] = - spark.sparkContext.parallelize( + def emptyRecords: Dataset[String] = + spark.createDataset(spark.sparkContext.parallelize( """{""" :: """""" :: """{"a": {}}""" :: """{"a": {"b": {}}}""" :: """{"b": [{"c": {}}]}""" :: - """]""" :: Nil) + """]""" :: Nil))(Encoders.STRING) - def timestampAsLong: RDD[String] = - spark.sparkContext.parallelize( - """{"ts":1451732645}""" :: Nil) + def timestampAsLong: Dataset[String] = + spark.createDataset(spark.sparkContext.parallelize( + """{"ts":1451732645}""" :: Nil))(Encoders.STRING) - def arrayAndStructRecords: RDD[String] = - spark.sparkContext.parallelize( + def arrayAndStructRecords: Dataset[String] = + spark.createDataset(spark.sparkContext.parallelize( """{"a": {"b": 1}}""" :: - """{"a": []}""" :: Nil) + """{"a": []}""" :: Nil))(Encoders.STRING) - def floatingValueRecords: RDD[String] = - spark.sparkContext.parallelize( - s"""{"a": 0.${"0" * 38}1, "b": 0.01}""" :: Nil) + def floatingValueRecords: Dataset[String] = + spark.createDataset(spark.sparkContext.parallelize( + s"""{"a": 0.${"0" * 38}1, "b": 0.01}""" :: Nil))(Encoders.STRING) - def bigIntegerRecords: RDD[String] = - spark.sparkContext.parallelize( - s"""{"a": 1${"0" * 38}, "b": 92233720368547758070}""" :: Nil) + def 
bigIntegerRecords: Dataset[String] = + spark.createDataset(spark.sparkContext.parallelize( + s"""{"a": 1${"0" * 38}, "b": 92233720368547758070}""" :: Nil))(Encoders.STRING) - def datesRecords: RDD[String] = - spark.sparkContext.parallelize( + def datesRecords: Dataset[String] = + spark.createDataset(spark.sparkContext.parallelize( """{"date": "26/08/2015 18:00"}""" :: """{"date": "27/10/2014 18:30"}""" :: - """{"date": "28/01/2016 20:00"}""" :: Nil) + """{"date": "28/01/2016 20:00"}""" :: Nil))(Encoders.STRING) - lazy val singleRow: RDD[String] = spark.sparkContext.parallelize("""{"a":123}""" :: Nil) + lazy val singleRow: Dataset[String] = + spark.createDataset(spark.sparkContext.parallelize("""{"a":123}""" :: Nil))(Encoders.STRING) - def empty: RDD[String] = spark.sparkContext.parallelize(Seq[String]()) + def empty: Dataset[String] = spark.emptyDataset(Encoders.STRING) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala index 420cff878fa0..88cb8a0bad21 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.datasources.parquet import java.io.File import java.math.BigInteger import java.sql.{Date, Timestamp} +import java.util.{Calendar, TimeZone} import scala.collection.mutable.ArrayBuffer @@ -51,9 +52,12 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha val defaultPartitionName = ExternalCatalogUtils.DEFAULT_PARTITION_NAME + val timeZone = TimeZone.getDefault() + val timeZoneId = timeZone.getID + test("column type inference") { - def check(raw: String, literal: Literal): Unit = { - assert(inferPartitionColumnValue(raw, true) === literal) + def check(raw: String, literal: Literal, timeZone: TimeZone = timeZone): Unit = { + assert(inferPartitionColumnValue(raw, true, timeZone) === literal) } check("10", Literal.create(10, IntegerType)) @@ -66,6 +70,14 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha check("1990-02-24", Literal.create(Date.valueOf("1990-02-24"), DateType)) check("1990-02-24 12:00:30", Literal.create(Timestamp.valueOf("1990-02-24 12:00:30"), TimestampType)) + + val c = Calendar.getInstance(TimeZone.getTimeZone("GMT")) + c.set(1990, 1, 24, 12, 0, 30) + c.set(Calendar.MILLISECOND, 0) + check("1990-02-24 12:00:30", + Literal.create(new Timestamp(c.getTimeInMillis), TimestampType), + TimeZone.getTimeZone("GMT")) + check(defaultPartitionName, Literal.create(null, NullType)) } @@ -77,7 +89,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha "hdfs://host:9000/path/a=10.5/b=hello") var exception = intercept[AssertionError] { - parsePartitions(paths.map(new Path(_)), true, Set.empty[Path]) + parsePartitions(paths.map(new Path(_)), true, Set.empty[Path], timeZoneId) } assert(exception.getMessage().contains("Conflicting directory structures detected")) @@ -90,7 +102,8 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha parsePartitions( paths.map(new Path(_)), true, - Set(new Path("hdfs://host:9000/path/"))) + Set(new Path("hdfs://host:9000/path/")), + timeZoneId) // Valid paths = Seq( @@ -102,7 +115,8 @@ class 
ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha parsePartitions( paths.map(new Path(_)), true, - Set(new Path("hdfs://host:9000/path/something=true/table"))) + Set(new Path("hdfs://host:9000/path/something=true/table")), + timeZoneId) // Valid paths = Seq( @@ -114,7 +128,8 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha parsePartitions( paths.map(new Path(_)), true, - Set(new Path("hdfs://host:9000/path/table=true"))) + Set(new Path("hdfs://host:9000/path/table=true")), + timeZoneId) // Invalid paths = Seq( @@ -126,7 +141,8 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha parsePartitions( paths.map(new Path(_)), true, - Set(new Path("hdfs://host:9000/path/"))) + Set(new Path("hdfs://host:9000/path/")), + timeZoneId) } assert(exception.getMessage().contains("Conflicting directory structures detected")) @@ -145,20 +161,21 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha parsePartitions( paths.map(new Path(_)), true, - Set(new Path("hdfs://host:9000/tmp/tables/"))) + Set(new Path("hdfs://host:9000/tmp/tables/")), + timeZoneId) } assert(exception.getMessage().contains("Conflicting directory structures detected")) } test("parse partition") { def check(path: String, expected: Option[PartitionValues]): Unit = { - val actual = parsePartition(new Path(path), true, Set.empty[Path])._1 + val actual = parsePartition(new Path(path), true, Set.empty[Path], timeZone)._1 assert(expected === actual) } def checkThrows[T <: Throwable: Manifest](path: String, expected: String): Unit = { val message = intercept[T] { - parsePartition(new Path(path), true, Set.empty[Path]) + parsePartition(new Path(path), true, Set.empty[Path], timeZone) }.getMessage assert(message.contains(expected)) @@ -201,7 +218,8 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha val partitionSpec1: Option[PartitionValues] = parsePartition( path = new Path("file://path/a=10"), typeInference = true, - basePaths = Set(new Path("file://path/a=10")))._1 + basePaths = Set(new Path("file://path/a=10")), + timeZone = timeZone)._1 assert(partitionSpec1.isEmpty) @@ -209,7 +227,8 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha val partitionSpec2: Option[PartitionValues] = parsePartition( path = new Path("file://path/a=10"), typeInference = true, - basePaths = Set(new Path("file://path")))._1 + basePaths = Set(new Path("file://path")), + timeZone = timeZone)._1 assert(partitionSpec2 == Option(PartitionValues( @@ -226,7 +245,8 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha parsePartitions( paths.map(new Path(_)), true, - rootPaths) + rootPaths, + timeZoneId) assert(actualSpec === spec) } @@ -307,7 +327,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha test("parse partitions with type inference disabled") { def check(paths: Seq[String], spec: PartitionSpec): Unit = { val actualSpec = - parsePartitions(paths.map(new Path(_)), false, Set.empty[Path]) + parsePartitions(paths.map(new Path(_)), false, Set.empty[Path], timeZoneId) assert(actualSpec === spec) } @@ -686,6 +706,13 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha val fields = schema.map(f => Column(f.name).cast(f.dataType)) checkAnswer(spark.read.load(dir.toString).select(fields: _*), row) } + + withTempPath { dir => + df.write.option("timeZone", "GMT") + 
.format("parquet").partitionBy(partitionColumns.map(_.name): _*).save(dir.toString) + val fields = schema.map(f => Column(f.name).cast(f.dataType)) + checkAnswer(spark.read.option("timeZone", "GMT").load(dir.toString).select(fields: _*), row) + } } test("Various inferred partition value types") { @@ -720,6 +747,13 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha val fields = schema.map(f => Column(f.name)) checkAnswer(spark.read.load(dir.toString).select(fields: _*), row) } + + withTempPath { dir => + df.write.option("timeZone", "GMT") + .format("parquet").partitionBy(partitionColumns.map(_.name): _*).save(dir.toString) + val fields = schema.map(f => Column(f.name)) + checkAnswer(spark.read.option("timeZone", "GMT").load(dir.toString).select(fields: _*), row) + } } test("SPARK-8037: Ignores files whose name starts with dot") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala index d7d7176c48a3..200e356c72fd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala @@ -77,8 +77,6 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext val df = spark.read.parquet(path).cache() assert(df.count() == 1000) spark.range(10).write.mode("overwrite").parquet(path) - assert(df.count() == 1000) - spark.catalog.refreshByPath(path) assert(df.count() == 10) assert(spark.read.parquet(path).count() == 10) } @@ -91,8 +89,6 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext val df = spark.read.parquet(path).cache() assert(df.count() == 1000) spark.range(10).write.mode("append").parquet(path) - assert(df.count() == 1000) - spark.catalog.refreshByPath(path) assert(df.count() == 1010) assert(spark.read.parquet(path).count() == 1010) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala index 6b38b6a09721..e848f74e3159 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala @@ -20,9 +20,11 @@ package org.apache.spark.sql.execution.streaming.state import java.io.{File, IOException} import java.net.URI +import scala.collection.JavaConverters._ import scala.collection.mutable import scala.util.Random +import org.apache.commons.io.FileUtils import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, Path, RawLocalFileSystem} import org.scalatest.{BeforeAndAfter, PrivateMethodTester} @@ -210,13 +212,6 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth assert(store1.commit() === 2) assert(rowsToSet(store1.iterator()) === Set("a" -> 1, "b" -> 1)) assert(getDataFromFiles(provider) === Set("a" -> 1, "b" -> 1)) - - // Overwrite the version with other data - val store2 = provider.getStore(1) - put(store2, "c", 1) - assert(store2.commit() === 2) - assert(rowsToSet(store2.iterator()) === Set("a" -> 1, "c" -> 1)) - assert(getDataFromFiles(provider) === Set("a" -> 1, "c" -> 1)) } test("snapshotting") { @@ -292,6 +287,20 @@ class StateStoreSuite extends SparkFunSuite 
with BeforeAndAfter with PrivateMeth assert(getDataFromFiles(provider, 19) === Set("a" -> 19)) } + test("SPARK-19677: Committing a delta file atop an existing one should not fail on HDFS") { + val conf = new Configuration() + conf.set("fs.fake.impl", classOf[RenameLikeHDFSFileSystem].getName) + conf.set("fs.default.name", "fake:///") + + val provider = newStoreProvider(hadoopConf = conf) + provider.getStore(0).commit() + provider.getStore(0).commit() + + // Verify we don't leak temp files + val tempFiles = FileUtils.listFiles(new File(provider.id.checkpointLocation), + null, true).asScala.filter(_.getName.startsWith("temp-")) + assert(tempFiles.isEmpty) + } test("corrupted file handling") { val provider = newStoreProvider(minDeltasForSnapshot = 5) @@ -681,6 +690,21 @@ private[state] object StateStoreSuite { } } +/** + * Fake FileSystem that simulates HDFS rename semantic, i.e. renaming a file atop an existing + * one should return false. + * See hadoop.apache.org/docs/stable/hadoop-project-dist/hadoop-common/filesystem/filesystem.html + */ +class RenameLikeHDFSFileSystem extends RawLocalFileSystem { + override def rename(src: Path, dst: Path): Boolean = { + if (exists(dst)) { + return false + } else { + return super.rename(src, dst) + } + } +} + /** * Fake FileSystem to test that the StateStore throws an exception while committing the * delta file, when `fs.rename` returns `false`. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala index 8aea112897fb..e41c00ecec27 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala @@ -147,6 +147,11 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext with JsonTest checkAnswer(listener.getExecutionMetrics(0), accumulatorUpdates.mapValues(_ * 2)) + // Driver accumulator updates that don't belong to this execution should be filtered and no + // exception will be thrown. 
+ listener.onOtherEvent(SparkListenerDriverAccumUpdates(0, Seq((999L, 2L)))) + checkAnswer(listener.getExecutionMetrics(0), accumulatorUpdates.mapValues(_ * 2)) + listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate("", Seq( // (task id, stage id, stage attempt, accum updates) (0L, 0, 0, createTaskMetrics(accumulatorUpdates).accumulators().map(makeInfo)), diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala similarity index 59% rename from sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala index d9ddcbd57ca8..9b65419dba23 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala @@ -29,17 +29,25 @@ import org.apache.spark.sql.execution.datasources.DataSourceStrategy import org.apache.spark.sql.execution.exchange.ShuffleExchange import org.apache.spark.sql.execution.joins.SortMergeJoinExec import org.apache.spark.sql.functions._ -import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION +import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} import org.apache.spark.util.Utils import org.apache.spark.util.collection.BitSet -class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { +class BucketedReadWithoutHiveSupportSuite extends BucketedReadSuite with SharedSQLContext { + protected override def beforeAll(): Unit = { + super.beforeAll() + assume(spark.sparkContext.conf.get(CATALOG_IMPLEMENTATION) == "in-memory") + } +} + + +abstract class BucketedReadSuite extends QueryTest with SQLTestUtils { import testImplicits._ - private val df = (0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k") - private val nullDF = (for { + private lazy val df = (0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k") + private lazy val nullDF = (for { i <- 0 to 50 s <- Seq(null, "a", "b", "c", "d", "e", "f", null, "g") } yield (i % 5, s, i % 13)).toDF("i", "j", "k") @@ -224,8 +232,16 @@ class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSinglet } } - private val df1 = (0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k").as("df1") - private val df2 = (0 until 50).map(i => (i % 7, i % 11, i.toString)).toDF("i", "j", "k").as("df2") + private lazy val df1 = + (0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k").as("df1") + private lazy val df2 = + (0 until 50).map(i => (i % 7, i % 11, i.toString)).toDF("i", "j", "k").as("df2") + + case class BucketedTableTestSpec( + bucketSpec: Option[BucketSpec], + numPartitions: Int = 10, + expectedShuffle: Boolean = true, + expectedSort: Boolean = true) /** * A helper method to test the bucket read functionality using join. It will save `df1` and `df2` @@ -234,14 +250,15 @@ class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSinglet * exists as user expected according to the `shuffleLeft` and `shuffleRight`. 
*/ private def testBucketing( - bucketSpecLeft: Option[BucketSpec], - bucketSpecRight: Option[BucketSpec], + bucketedTableTestSpecLeft: BucketedTableTestSpec, + bucketedTableTestSpecRight: BucketedTableTestSpec, joinType: String = "inner", - joinCondition: (DataFrame, DataFrame) => Column, - shuffleLeft: Boolean, - shuffleRight: Boolean, - sortLeft: Boolean = true, - sortRight: Boolean = true): Unit = { + joinCondition: (DataFrame, DataFrame) => Column): Unit = { + val BucketedTableTestSpec(bucketSpecLeft, numPartitionsLeft, shuffleLeft, sortLeft) = + bucketedTableTestSpecLeft + val BucketedTableTestSpec(bucketSpecRight, numPartitionsRight, shuffleRight, sortRight) = + bucketedTableTestSpecRight + withTable("bucketed_table1", "bucketed_table2") { def withBucket( writer: DataFrameWriter[Row], @@ -263,8 +280,10 @@ class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSinglet }.getOrElse(writer) } - withBucket(df1.write.format("parquet"), bucketSpecLeft).saveAsTable("bucketed_table1") - withBucket(df2.write.format("parquet"), bucketSpecRight).saveAsTable("bucketed_table2") + withBucket(df1.repartition(numPartitionsLeft).write.format("parquet"), bucketSpecLeft) + .saveAsTable("bucketed_table1") + withBucket(df2.repartition(numPartitionsRight).write.format("parquet"), bucketSpecRight) + .saveAsTable("bucketed_table2") withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "0", SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false") { @@ -291,10 +310,10 @@ class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSinglet // check existence of sort assert( joinOperator.left.find(_.isInstanceOf[SortExec]).isDefined == sortLeft, - s"expected sort in plan to be $shuffleLeft but found\n${joinOperator.left}") + s"expected sort in the left child to be $sortLeft but found\n${joinOperator.left}") assert( joinOperator.right.find(_.isInstanceOf[SortExec]).isDefined == sortRight, - s"expected sort in plan to be $shuffleRight but found\n${joinOperator.right}") + s"expected sort in the right child to be $sortRight but found\n${joinOperator.right}") } } } @@ -305,138 +324,174 @@ class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSinglet test("avoid shuffle when join 2 bucketed tables") { val bucketSpec = Some(BucketSpec(8, Seq("i", "j"), Nil)) + val bucketedTableTestSpecLeft = BucketedTableTestSpec(bucketSpec, expectedShuffle = false) + val bucketedTableTestSpecRight = BucketedTableTestSpec(bucketSpec, expectedShuffle = false) testBucketing( - bucketSpecLeft = bucketSpec, - bucketSpecRight = bucketSpec, - joinCondition = joinCondition(Seq("i", "j")), - shuffleLeft = false, - shuffleRight = false + bucketedTableTestSpecLeft = bucketedTableTestSpecLeft, + bucketedTableTestSpecRight = bucketedTableTestSpecRight, + joinCondition = joinCondition(Seq("i", "j")) ) } // Enable it after fix https://issues.apache.org/jira/browse/SPARK-12704 ignore("avoid shuffle when join keys are a super-set of bucket keys") { val bucketSpec = Some(BucketSpec(8, Seq("i"), Nil)) + val bucketedTableTestSpecLeft = BucketedTableTestSpec(bucketSpec, expectedShuffle = false) + val bucketedTableTestSpecRight = BucketedTableTestSpec(bucketSpec, expectedShuffle = false) testBucketing( - bucketSpecLeft = bucketSpec, - bucketSpecRight = bucketSpec, - joinCondition = joinCondition(Seq("i", "j")), - shuffleLeft = false, - shuffleRight = false + bucketedTableTestSpecLeft = bucketedTableTestSpecLeft, + bucketedTableTestSpecRight = bucketedTableTestSpecRight, + joinCondition = joinCondition(Seq("i", 
"j")) ) } test("only shuffle one side when join bucketed table and non-bucketed table") { val bucketSpec = Some(BucketSpec(8, Seq("i", "j"), Nil)) + val bucketedTableTestSpecLeft = BucketedTableTestSpec(bucketSpec, expectedShuffle = false) + val bucketedTableTestSpecRight = BucketedTableTestSpec(None, expectedShuffle = true) testBucketing( - bucketSpecLeft = bucketSpec, - bucketSpecRight = None, - joinCondition = joinCondition(Seq("i", "j")), - shuffleLeft = false, - shuffleRight = true + bucketedTableTestSpecLeft = bucketedTableTestSpecLeft, + bucketedTableTestSpecRight = bucketedTableTestSpecRight, + joinCondition = joinCondition(Seq("i", "j")) ) } test("only shuffle one side when 2 bucketed tables have different bucket number") { - val bucketSpec1 = Some(BucketSpec(8, Seq("i", "j"), Nil)) - val bucketSpec2 = Some(BucketSpec(5, Seq("i", "j"), Nil)) + val bucketSpecLeft = Some(BucketSpec(8, Seq("i", "j"), Nil)) + val bucketSpecRight = Some(BucketSpec(5, Seq("i", "j"), Nil)) + val bucketedTableTestSpecLeft = BucketedTableTestSpec(bucketSpecLeft, expectedShuffle = false) + val bucketedTableTestSpecRight = BucketedTableTestSpec(bucketSpecRight, expectedShuffle = true) testBucketing( - bucketSpecLeft = bucketSpec1, - bucketSpecRight = bucketSpec2, - joinCondition = joinCondition(Seq("i", "j")), - shuffleLeft = false, - shuffleRight = true + bucketedTableTestSpecLeft = bucketedTableTestSpecLeft, + bucketedTableTestSpecRight = bucketedTableTestSpecRight, + joinCondition = joinCondition(Seq("i", "j")) ) } test("only shuffle one side when 2 bucketed tables have different bucket keys") { - val bucketSpec1 = Some(BucketSpec(8, Seq("i"), Nil)) - val bucketSpec2 = Some(BucketSpec(8, Seq("j"), Nil)) + val bucketSpecLeft = Some(BucketSpec(8, Seq("i"), Nil)) + val bucketSpecRight = Some(BucketSpec(8, Seq("j"), Nil)) + val bucketedTableTestSpecLeft = BucketedTableTestSpec(bucketSpecLeft, expectedShuffle = false) + val bucketedTableTestSpecRight = BucketedTableTestSpec(bucketSpecRight, expectedShuffle = true) testBucketing( - bucketSpecLeft = bucketSpec1, - bucketSpecRight = bucketSpec2, - joinCondition = joinCondition(Seq("i")), - shuffleLeft = false, - shuffleRight = true + bucketedTableTestSpecLeft = bucketedTableTestSpecLeft, + bucketedTableTestSpecRight = bucketedTableTestSpecRight, + joinCondition = joinCondition(Seq("i")) ) } test("shuffle when join keys are not equal to bucket keys") { val bucketSpec = Some(BucketSpec(8, Seq("i"), Nil)) + val bucketedTableTestSpecLeft = BucketedTableTestSpec(bucketSpec, expectedShuffle = true) + val bucketedTableTestSpecRight = BucketedTableTestSpec(bucketSpec, expectedShuffle = true) testBucketing( - bucketSpecLeft = bucketSpec, - bucketSpecRight = bucketSpec, - joinCondition = joinCondition(Seq("j")), - shuffleLeft = true, - shuffleRight = true + bucketedTableTestSpecLeft = bucketedTableTestSpecLeft, + bucketedTableTestSpecRight = bucketedTableTestSpecRight, + joinCondition = joinCondition(Seq("j")) ) } test("shuffle when join 2 bucketed tables with bucketing disabled") { val bucketSpec = Some(BucketSpec(8, Seq("i", "j"), Nil)) + val bucketedTableTestSpecLeft = BucketedTableTestSpec(bucketSpec, expectedShuffle = true) + val bucketedTableTestSpecRight = BucketedTableTestSpec(bucketSpec, expectedShuffle = true) withSQLConf(SQLConf.BUCKETING_ENABLED.key -> "false") { testBucketing( - bucketSpecLeft = bucketSpec, - bucketSpecRight = bucketSpec, - joinCondition = joinCondition(Seq("i", "j")), - shuffleLeft = true, - shuffleRight = true + bucketedTableTestSpecLeft = 
bucketedTableTestSpecLeft, + bucketedTableTestSpecRight = bucketedTableTestSpecRight, + joinCondition = joinCondition(Seq("i", "j")) ) } } - test("avoid shuffle and sort when bucket and sort columns are join keys") { + test("check sort and shuffle when bucket and sort columns are join keys") { + // In case of bucketing, its possible to have multiple files belonging to the + // same bucket in a given relation. Each of these files are locally sorted + // but those files combined together are not globally sorted. Given that, + // the RDD partition will not be sorted even if the relation has sort columns set + // Therefore, we still need to keep the Sort in both sides. val bucketSpec = Some(BucketSpec(8, Seq("i", "j"), Seq("i", "j"))) + + val bucketedTableTestSpecLeft1 = BucketedTableTestSpec( + bucketSpec, numPartitions = 50, expectedShuffle = false, expectedSort = true) + val bucketedTableTestSpecRight1 = BucketedTableTestSpec( + bucketSpec, numPartitions = 1, expectedShuffle = false, expectedSort = false) testBucketing( - bucketSpecLeft = bucketSpec, - bucketSpecRight = bucketSpec, - joinCondition = joinCondition(Seq("i", "j")), - shuffleLeft = false, - shuffleRight = false, - sortLeft = false, - sortRight = false + bucketedTableTestSpecLeft = bucketedTableTestSpecLeft1, + bucketedTableTestSpecRight = bucketedTableTestSpecRight1, + joinCondition = joinCondition(Seq("i", "j")) + ) + + val bucketedTableTestSpecLeft2 = BucketedTableTestSpec( + bucketSpec, numPartitions = 1, expectedShuffle = false, expectedSort = false) + val bucketedTableTestSpecRight2 = BucketedTableTestSpec( + bucketSpec, numPartitions = 50, expectedShuffle = false, expectedSort = true) + testBucketing( + bucketedTableTestSpecLeft = bucketedTableTestSpecLeft2, + bucketedTableTestSpecRight = bucketedTableTestSpecRight2, + joinCondition = joinCondition(Seq("i", "j")) + ) + + val bucketedTableTestSpecLeft3 = BucketedTableTestSpec( + bucketSpec, numPartitions = 50, expectedShuffle = false, expectedSort = true) + val bucketedTableTestSpecRight3 = BucketedTableTestSpec( + bucketSpec, numPartitions = 50, expectedShuffle = false, expectedSort = true) + testBucketing( + bucketedTableTestSpecLeft = bucketedTableTestSpecLeft3, + bucketedTableTestSpecRight = bucketedTableTestSpecRight3, + joinCondition = joinCondition(Seq("i", "j")) + ) + + val bucketedTableTestSpecLeft4 = BucketedTableTestSpec( + bucketSpec, numPartitions = 1, expectedShuffle = false, expectedSort = false) + val bucketedTableTestSpecRight4 = BucketedTableTestSpec( + bucketSpec, numPartitions = 1, expectedShuffle = false, expectedSort = false) + testBucketing( + bucketedTableTestSpecLeft = bucketedTableTestSpecLeft4, + bucketedTableTestSpecRight = bucketedTableTestSpecRight4, + joinCondition = joinCondition(Seq("i", "j")) ) } test("avoid shuffle and sort when sort columns are a super set of join keys") { - val bucketSpec1 = Some(BucketSpec(8, Seq("i"), Seq("i", "j"))) - val bucketSpec2 = Some(BucketSpec(8, Seq("i"), Seq("i", "k"))) + val bucketSpecLeft = Some(BucketSpec(8, Seq("i"), Seq("i", "j"))) + val bucketSpecRight = Some(BucketSpec(8, Seq("i"), Seq("i", "k"))) + val bucketedTableTestSpecLeft = BucketedTableTestSpec( + bucketSpecLeft, numPartitions = 1, expectedShuffle = false, expectedSort = false) + val bucketedTableTestSpecRight = BucketedTableTestSpec( + bucketSpecRight, numPartitions = 1, expectedShuffle = false, expectedSort = false) testBucketing( - bucketSpecLeft = bucketSpec1, - bucketSpecRight = bucketSpec2, - joinCondition = joinCondition(Seq("i")), - 
shuffleLeft = false, - shuffleRight = false, - sortLeft = false, - sortRight = false + bucketedTableTestSpecLeft = bucketedTableTestSpecLeft, + bucketedTableTestSpecRight = bucketedTableTestSpecRight, + joinCondition = joinCondition(Seq("i")) ) } test("only sort one side when sort columns are different") { - val bucketSpec1 = Some(BucketSpec(8, Seq("i", "j"), Seq("i", "j"))) - val bucketSpec2 = Some(BucketSpec(8, Seq("i", "j"), Seq("k"))) + val bucketSpecLeft = Some(BucketSpec(8, Seq("i", "j"), Seq("i", "j"))) + val bucketSpecRight = Some(BucketSpec(8, Seq("i", "j"), Seq("k"))) + val bucketedTableTestSpecLeft = BucketedTableTestSpec( + bucketSpecLeft, numPartitions = 1, expectedShuffle = false, expectedSort = false) + val bucketedTableTestSpecRight = BucketedTableTestSpec( + bucketSpecRight, numPartitions = 1, expectedShuffle = false, expectedSort = true) testBucketing( - bucketSpecLeft = bucketSpec1, - bucketSpecRight = bucketSpec2, - joinCondition = joinCondition(Seq("i", "j")), - shuffleLeft = false, - shuffleRight = false, - sortLeft = false, - sortRight = true + bucketedTableTestSpecLeft = bucketedTableTestSpecLeft, + bucketedTableTestSpecRight = bucketedTableTestSpecRight, + joinCondition = joinCondition(Seq("i", "j")) ) } test("only sort one side when sort columns are same but their ordering is different") { - val bucketSpec1 = Some(BucketSpec(8, Seq("i", "j"), Seq("i", "j"))) - val bucketSpec2 = Some(BucketSpec(8, Seq("i", "j"), Seq("j", "i"))) + val bucketSpecLeft = Some(BucketSpec(8, Seq("i", "j"), Seq("i", "j"))) + val bucketSpecRight = Some(BucketSpec(8, Seq("i", "j"), Seq("j", "i"))) + val bucketedTableTestSpecLeft = BucketedTableTestSpec( + bucketSpecLeft, numPartitions = 1, expectedShuffle = false, expectedSort = false) + val bucketedTableTestSpecRight = BucketedTableTestSpec( + bucketSpecRight, numPartitions = 1, expectedShuffle = false, expectedSort = true) testBucketing( - bucketSpecLeft = bucketSpec1, - bucketSpecRight = bucketSpec2, - joinCondition = joinCondition(Seq("i", "j")), - shuffleLeft = false, - shuffleRight = false, - sortLeft = false, - sortRight = true + bucketedTableTestSpecLeft = bucketedTableTestSpecLeft, + bucketedTableTestSpecRight = bucketedTableTestSpecRight, + joinCondition = joinCondition(Seq("i", "j")) ) } @@ -470,27 +525,27 @@ class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSinglet test("SPARK-17698 Join predicates should not contain filter clauses") { val bucketSpec = Some(BucketSpec(8, Seq("i"), Seq("i"))) + val bucketedTableTestSpecLeft = BucketedTableTestSpec( + bucketSpec, numPartitions = 1, expectedShuffle = false, expectedSort = false) + val bucketedTableTestSpecRight = BucketedTableTestSpec( + bucketSpec, numPartitions = 1, expectedShuffle = false, expectedSort = false) testBucketing( - bucketSpecLeft = bucketSpec, - bucketSpecRight = bucketSpec, + bucketedTableTestSpecLeft = bucketedTableTestSpecLeft, + bucketedTableTestSpecRight = bucketedTableTestSpecRight, joinType = "fullouter", joinCondition = (left: DataFrame, right: DataFrame) => { val joinPredicates = Seq("i").map(col => left(col) === right(col)).reduce(_ && _) val filterLeft = left("i") === Literal("1") val filterRight = right("i") === Literal("1") joinPredicates && filterLeft && filterRight - }, - shuffleLeft = false, - shuffleRight = false, - sortLeft = false, - sortRight = false + } ) } test("error if there exists any malformed bucket files") { withTable("bucketed_table") { df1.write.format("parquet").bucketBy(8, "i").saveAsTable("bucketed_table") - 
val warehouseFilePath = new URI(hiveContext.sparkSession.getWarehousePath).getPath + val warehouseFilePath = new URI(spark.sessionState.conf.warehousePath).getPath val tableDir = new File(warehouseFilePath, "bucketed_table") Utils.deleteRecursively(tableDir) df1.write.parquet(tableDir.getAbsolutePath) @@ -508,9 +563,9 @@ class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSinglet withTable("bucketed_table") { df1.write.format("parquet").bucketBy(8, "i").saveAsTable("bucketed_table") - checkAnswer(hiveContext.table("bucketed_table").select("j"), df1.select("j")) + checkAnswer(spark.table("bucketed_table").select("j"), df1.select("j")) - checkAnswer(hiveContext.table("bucketed_table").groupBy("j").agg(max("k")), + checkAnswer(spark.table("bucketed_table").groupBy("j").agg(max("k")), df1.groupBy("j").agg(max("k"))) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala similarity index 88% rename from sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala index 61cef2a8008f..9082261af7b0 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala @@ -20,19 +20,29 @@ package org.apache.spark.sql.sources import java.io.File import java.net.URI -import org.apache.spark.SparkException import org.apache.spark.sql.{AnalysisException, QueryTest} import org.apache.spark.sql.catalyst.expressions.UnsafeProjection import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning import org.apache.spark.sql.execution.datasources.BucketingUtils import org.apache.spark.sql.functions._ -import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION +import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} -class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { +class BucketedWriteWithoutHiveSupportSuite extends BucketedWriteSuite with SharedSQLContext { + protected override def beforeAll(): Unit = { + super.beforeAll() + assume(spark.sparkContext.conf.get(CATALOG_IMPLEMENTATION) == "in-memory") + } + + override protected def fileFormatsToTest: Seq[String] = Seq("parquet", "json") +} + +abstract class BucketedWriteSuite extends QueryTest with SQLTestUtils { import testImplicits._ + protected def fileFormatsToTest: Seq[String] + test("bucketed by non-existing column") { val df = Seq(1 -> "a", 2 -> "b").toDF("i", "j") intercept[AnalysisException](df.write.bucketBy(2, "k").saveAsTable("tt")) @@ -76,11 +86,13 @@ class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingle assert(e.getMessage == "'insertInto' does not support bucketing right now;") } - private val df = (0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k") + private lazy val df = { + (0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k") + } def tableDir: File = { val identifier = spark.sessionState.sqlParser.parseTableIdentifier("bucketed_table") - new File(URI.create(hiveContext.sessionState.catalog.hiveDefaultTableFilePath(identifier))) + new File(URI.create(spark.sessionState.catalog.defaultTablePath(identifier))) } /** @@ -141,7 
+153,7 @@ class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingle } test("write bucketed data") { - for (source <- Seq("parquet", "json", "orc")) { + for (source <- fileFormatsToTest) { withTable("bucketed_table") { df.write .format(source) @@ -157,7 +169,7 @@ class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingle } test("write bucketed data with sortBy") { - for (source <- Seq("parquet", "json", "orc")) { + for (source <- fileFormatsToTest) { withTable("bucketed_table") { df.write .format(source) @@ -190,7 +202,7 @@ class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingle } test("write bucketed data without partitionBy") { - for (source <- Seq("parquet", "json", "orc")) { + for (source <- fileFormatsToTest) { withTable("bucketed_table") { df.write .format(source) @@ -203,7 +215,7 @@ class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingle } test("write bucketed data without partitionBy with sortBy") { - for (source <- Seq("parquet", "json", "orc")) { + for (source <- fileFormatsToTest) { withTable("bucketed_table") { df.write .format(source) @@ -219,7 +231,7 @@ class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingle test("write bucketed data with bucketing disabled") { // The configuration BUCKETING_ENABLED does not affect the writing path withSQLConf(SQLConf.BUCKETING_ENABLED.key -> "false") { - for (source <- Seq("parquet", "json", "orc")) { + for (source <- fileFormatsToTest) { withTable("bucketed_table") { df.write .format(source) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala index 4a42f8ea79cf..916a01ee0ca8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala @@ -33,14 +33,15 @@ class CreateTableAsSelectSuite extends DataSourceTest with SharedSQLContext with BeforeAndAfterEach { + import testImplicits._ protected override lazy val sql = spark.sql _ private var path: File = null override def beforeAll(): Unit = { super.beforeAll() - val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}""")) - spark.read.json(rdd).createOrReplaceTempView("jt") + val ds = (1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}""").toDS() + spark.read.json(ds).createOrReplaceTempView("jt") } override def afterAll(): Unit = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala index 4fc2f81f540b..2eae66dda88d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala @@ -24,14 +24,16 @@ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.util.Utils class InsertSuite extends DataSourceTest with SharedSQLContext { + import testImplicits._ + protected override lazy val sql = spark.sql _ private var path: File = null override def beforeAll(): Unit = { super.beforeAll() path = Utils.createTempDir() - val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str$i"}""")) - spark.read.json(rdd).createOrReplaceTempView("jt") + val ds = (1 to 10).map(i => s"""{"a":$i, "b":"str$i"}""").toDS() + spark.read.json(ds).createOrReplaceTempView("jt") sql( s""" |CREATE TEMPORARY 
VIEW jsonTable (a int, b string) @@ -129,7 +131,7 @@ class InsertSuite extends DataSourceTest with SharedSQLContext { // Writing the table to less part files. val rdd1 = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str$i"}"""), 5) - spark.read.json(rdd1).createOrReplaceTempView("jt1") + spark.read.json(rdd1.toDS()).createOrReplaceTempView("jt1") sql( s""" |INSERT OVERWRITE TABLE jsonTable SELECT a, b FROM jt1 @@ -141,7 +143,7 @@ class InsertSuite extends DataSourceTest with SharedSQLContext { // Writing the table to more part files. val rdd2 = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str$i"}"""), 10) - spark.read.json(rdd2).createOrReplaceTempView("jt2") + spark.read.json(rdd2.toDS()).createOrReplaceTempView("jt2") sql( s""" |INSERT OVERWRITE TABLE jsonTable SELECT a, b FROM jt2 @@ -279,15 +281,15 @@ class InsertSuite extends DataSourceTest with SharedSQLContext { """.stripMargin) // jsonTable should be recached. assertCached(sql("SELECT * FROM jsonTable")) - // TODO we need to invalidate the cached data in InsertIntoHadoopFsRelation -// // The cached data is the new data. -// checkAnswer( -// sql("SELECT a, b FROM jsonTable"), -// sql("SELECT a * 2, b FROM jt").collect()) -// -// // Verify uncaching -// spark.catalog.uncacheTable("jsonTable") -// assertCached(sql("SELECT * FROM jsonTable"), 0) + + // The cached data is the new data. + checkAnswer( + sql("SELECT a, b FROM jsonTable"), + sql("SELECT a * 2, b FROM jt").collect()) + + // Verify uncaching + spark.catalog.uncacheTable("jsonTable") + assertCached(sql("SELECT * FROM jsonTable"), 0) } test("it's not allowed to insert into a relation that is not an InsertableRelation") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala index bf7fabe33266..f251290583c5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala @@ -18,11 +18,13 @@ package org.apache.spark.sql.sources import java.io.File +import java.sql.Timestamp import org.apache.hadoop.mapreduce.TaskAttemptContext import org.apache.spark.internal.Logging import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils import org.apache.spark.sql.execution.datasources.SQLHadoopMapReduceCommitProtocol import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf @@ -124,6 +126,39 @@ class PartitionedWriteSuite extends QueryTest with SharedSQLContext { } } + test("timeZone setting in dynamic partition writes") { + def checkPartitionValues(file: File, expected: String): Unit = { + val dir = file.getParentFile() + val value = ExternalCatalogUtils.unescapePathName( + dir.getName.substring(dir.getName.indexOf("=") + 1)) + assert(value == expected) + } + val ts = Timestamp.valueOf("2016-12-01 00:00:00") + val df = Seq((1, ts)).toDF("i", "ts") + withTempPath { f => + df.write.partitionBy("ts").parquet(f.getAbsolutePath) + val files = recursiveList(f).filter(_.getAbsolutePath.endsWith("parquet")) + assert(files.length == 1) + checkPartitionValues(files.head, "2016-12-01 00:00:00") + } + withTempPath { f => + df.write.option("timeZone", "GMT").partitionBy("ts").parquet(f.getAbsolutePath) + val files = recursiveList(f).filter(_.getAbsolutePath.endsWith("parquet")) + assert(files.length == 1) + // use timeZone option "GMT" to format partition 
value. + checkPartitionValues(files.head, "2016-12-01 08:00:00") + } + withTempPath { f => + withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> "GMT") { + df.write.partitionBy("ts").parquet(f.getAbsolutePath) + val files = recursiveList(f).filter(_.getAbsolutePath.endsWith("parquet")) + assert(files.length == 1) + // if there isn't timeZone option, then use session local timezone. + checkPartitionValues(files.head, "2016-12-01 08:00:00") + } + } + } + /** Lists files recursively. */ private def recursiveList(f: File): Array[File] = { require(f.isDirectory) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala index b1756c27fae0..773d34dfaf9a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala @@ -28,6 +28,8 @@ import org.apache.spark.sql.types._ import org.apache.spark.util.Utils class SaveLoadSuite extends DataSourceTest with SharedSQLContext with BeforeAndAfter { + import testImplicits._ + protected override lazy val sql = spark.sql _ private var originalDefaultSource: String = null private var path: File = null @@ -40,8 +42,8 @@ class SaveLoadSuite extends DataSourceTest with SharedSQLContext with BeforeAndA path = Utils.createTempDir() path.delete() - val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}""")) - df = spark.read.json(rdd) + val ds = (1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}""").toDS() + df = spark.read.json(ds) df.createOrReplaceTempView("jsonTable") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DeduplicateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DeduplicateSuite.scala new file mode 100644 index 000000000000..7ea716231e5d --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DeduplicateSuite.scala @@ -0,0 +1,252 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.streaming + +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ +import org.apache.spark.sql.execution.streaming.MemoryStream +import org.apache.spark.sql.execution.streaming.state.StateStore +import org.apache.spark.sql.functions._ + +class DeduplicateSuite extends StateStoreMetricsTest with BeforeAndAfterAll { + + import testImplicits._ + + override def afterAll(): Unit = { + super.afterAll() + StateStore.stop() + } + + test("deduplicate with all columns") { + val inputData = MemoryStream[String] + val result = inputData.toDS().dropDuplicates() + + testStream(result, Append)( + AddData(inputData, "a"), + CheckLastBatch("a"), + assertNumStateRows(total = 1, updated = 1), + AddData(inputData, "a"), + CheckLastBatch(), + assertNumStateRows(total = 1, updated = 0), + AddData(inputData, "b"), + CheckLastBatch("b"), + assertNumStateRows(total = 2, updated = 1) + ) + } + + test("deduplicate with some columns") { + val inputData = MemoryStream[(String, Int)] + val result = inputData.toDS().dropDuplicates("_1") + + testStream(result, Append)( + AddData(inputData, "a" -> 1), + CheckLastBatch("a" -> 1), + assertNumStateRows(total = 1, updated = 1), + AddData(inputData, "a" -> 2), // Dropped + CheckLastBatch(), + assertNumStateRows(total = 1, updated = 0), + AddData(inputData, "b" -> 1), + CheckLastBatch("b" -> 1), + assertNumStateRows(total = 2, updated = 1) + ) + } + + test("multiple deduplicates") { + val inputData = MemoryStream[(String, Int)] + val result = inputData.toDS().dropDuplicates().dropDuplicates("_1") + + testStream(result, Append)( + AddData(inputData, "a" -> 1), + CheckLastBatch("a" -> 1), + assertNumStateRows(total = Seq(1L, 1L), updated = Seq(1L, 1L)), + + AddData(inputData, "a" -> 2), // Dropped from the second `dropDuplicates` + CheckLastBatch(), + assertNumStateRows(total = Seq(1L, 2L), updated = Seq(0L, 1L)), + + AddData(inputData, "b" -> 1), + CheckLastBatch("b" -> 1), + assertNumStateRows(total = Seq(2L, 3L), updated = Seq(1L, 1L)) + ) + } + + test("deduplicate with watermark") { + val inputData = MemoryStream[Int] + val result = inputData.toDS() + .withColumn("eventTime", $"value".cast("timestamp")) + .withWatermark("eventTime", "10 seconds") + .dropDuplicates() + .select($"eventTime".cast("long").as[Long]) + + testStream(result, Append)( + AddData(inputData, (1 to 5).flatMap(_ => (10 to 15)): _*), + CheckLastBatch(10 to 15: _*), + assertNumStateRows(total = 6, updated = 6), + + AddData(inputData, 25), // Advance watermark to 15 seconds + CheckLastBatch(25), + assertNumStateRows(total = 7, updated = 1), + + AddData(inputData, 25), // Drop states less than watermark + CheckLastBatch(), + assertNumStateRows(total = 1, updated = 0), + + AddData(inputData, 10), // Should not emit anything as data less than watermark + CheckLastBatch(), + assertNumStateRows(total = 1, updated = 0), + + AddData(inputData, 45), // Advance watermark to 35 seconds + CheckLastBatch(45), + assertNumStateRows(total = 2, updated = 1), + + AddData(inputData, 45), // Drop states less than watermark + CheckLastBatch(), + assertNumStateRows(total = 1, updated = 0) + ) + } + + test("deduplicate with aggregate - append mode") { + val inputData = MemoryStream[Int] + val windowedaggregate = inputData.toDS() + .withColumn("eventTime", $"value".cast("timestamp")) + .withWatermark("eventTime", "10 seconds") + .dropDuplicates() + .withWatermark("eventTime", "10 seconds") + .groupBy(window($"eventTime", "5 seconds") as 
'window) + .agg(count("*") as 'count) + .select($"window".getField("start").cast("long").as[Long], $"count".as[Long]) + + testStream(windowedaggregate)( + AddData(inputData, (1 to 5).flatMap(_ => (10 to 15)): _*), + CheckLastBatch(), + // states in aggregate in [10, 14), [15, 20) (2 windows) + // states in deduplicate is 10 to 15 + assertNumStateRows(total = Seq(2L, 6L), updated = Seq(2L, 6L)), + + AddData(inputData, 25), // Advance watermark to 15 seconds + CheckLastBatch(), + // states in aggregate in [10, 14), [15, 20) and [25, 30) (3 windows) + // states in deduplicate is 10 to 15 and 25 + assertNumStateRows(total = Seq(3L, 7L), updated = Seq(1L, 1L)), + + AddData(inputData, 25), // Emit items less than watermark and drop their state + CheckLastBatch((10 -> 5)), // 5 items (10 to 14) after deduplicate + // states in aggregate in [15, 20) and [25, 30) (2 windows, note aggregate uses the end of + // window to evict items, so [15, 20) is still in the state store) + // states in deduplicate is 25 + assertNumStateRows(total = Seq(2L, 1L), updated = Seq(0L, 0L)), + + AddData(inputData, 10), // Should not emit anything as data less than watermark + CheckLastBatch(), + assertNumStateRows(total = Seq(2L, 1L), updated = Seq(0L, 0L)), + + AddData(inputData, 40), // Advance watermark to 30 seconds + CheckLastBatch(), + // states in aggregate in [15, 20), [25, 30) and [40, 45) + // states in deduplicate is 25 and 40, + assertNumStateRows(total = Seq(3L, 2L), updated = Seq(1L, 1L)), + + AddData(inputData, 40), // Emit items less than watermark and drop their state + CheckLastBatch((15 -> 1), (25 -> 1)), + // states in aggregate in [40, 45) + // states in deduplicate is 40, + assertNumStateRows(total = Seq(1L, 1L), updated = Seq(0L, 0L)) + ) + } + + test("deduplicate with aggregate - update mode") { + val inputData = MemoryStream[(String, Int)] + val result = inputData.toDS() + .select($"_1" as "str", $"_2" as "num") + .dropDuplicates() + .groupBy("str") + .agg(sum("num")) + .as[(String, Long)] + + testStream(result, Update)( + AddData(inputData, "a" -> 1), + CheckLastBatch("a" -> 1L), + assertNumStateRows(total = Seq(1L, 1L), updated = Seq(1L, 1L)), + AddData(inputData, "a" -> 1), // Dropped + CheckLastBatch(), + assertNumStateRows(total = Seq(1L, 1L), updated = Seq(0L, 0L)), + AddData(inputData, "a" -> 2), + CheckLastBatch("a" -> 3L), + assertNumStateRows(total = Seq(1L, 2L), updated = Seq(1L, 1L)), + AddData(inputData, "b" -> 1), + CheckLastBatch("b" -> 1L), + assertNumStateRows(total = Seq(2L, 3L), updated = Seq(1L, 1L)) + ) + } + + test("deduplicate with aggregate - complete mode") { + val inputData = MemoryStream[(String, Int)] + val result = inputData.toDS() + .select($"_1" as "str", $"_2" as "num") + .dropDuplicates() + .groupBy("str") + .agg(sum("num")) + .as[(String, Long)] + + testStream(result, Complete)( + AddData(inputData, "a" -> 1), + CheckLastBatch("a" -> 1L), + assertNumStateRows(total = Seq(1L, 1L), updated = Seq(1L, 1L)), + AddData(inputData, "a" -> 1), // Dropped + CheckLastBatch("a" -> 1L), + assertNumStateRows(total = Seq(1L, 1L), updated = Seq(0L, 0L)), + AddData(inputData, "a" -> 2), + CheckLastBatch("a" -> 3L), + assertNumStateRows(total = Seq(1L, 2L), updated = Seq(1L, 1L)), + AddData(inputData, "b" -> 1), + CheckLastBatch("a" -> 3L, "b" -> 1L), + assertNumStateRows(total = Seq(2L, 3L), updated = Seq(1L, 1L)) + ) + } + + test("deduplicate with file sink") { + withTempDir { output => + withTempDir { checkpointDir => + val outputPath = output.getAbsolutePath + val inputData = 
MemoryStream[String] + val result = inputData.toDS().dropDuplicates() + val q = result.writeStream + .format("parquet") + .outputMode(Append) + .option("checkpointLocation", checkpointDir.getPath) + .start(outputPath) + try { + inputData.addData("a") + q.processAllAvailable() + checkDataset(spark.read.parquet(outputPath).as[String], "a") + + inputData.addData("a") // Dropped + q.processAllAvailable() + checkDataset(spark.read.parquet(outputPath).as[String], "a") + + inputData.addData("b") + q.processAllAvailable() + checkDataset(spark.read.parquet(outputPath).as[String], "a", "b") + } finally { + q.stop() + } + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala index 5110d89c85b1..1586850c77fc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala @@ -52,10 +52,7 @@ abstract class FileStreamSourceTest query.nonEmpty, "Cannot add data when there is no query for finding the active file stream source") - val sources = query.get.logicalPlan.collect { - case StreamingExecutionRelation(source, _) if source.isInstanceOf[FileStreamSource] => - source.asInstanceOf[FileStreamSource] - } + val sources = getSourcesFromStreamingQuery(query.get) if (sources.isEmpty) { throw new Exception( "Could not find file source in the StreamExecution logical plan to add data to") @@ -134,6 +131,14 @@ abstract class FileStreamSourceTest }.head } + protected def getSourcesFromStreamingQuery(query: StreamExecution): Seq[FileStreamSource] = { + query.logicalPlan.collect { + case StreamingExecutionRelation(source, _) if source.isInstanceOf[FileStreamSource] => + source.asInstanceOf[FileStreamSource] + } + } + + protected def withTempDirs(body: (File, File) => Unit) { val src = Utils.createTempDir(namePrefix = "streaming.src") val tmp = Utils.createTempDir(namePrefix = "streaming.tmp") @@ -388,9 +393,7 @@ class FileStreamSourceSuite extends FileStreamSourceTest { CheckAnswer("a", "b", "c", "d"), AssertOnQuery("seen files should contain only one entry") { streamExecution => - val source = streamExecution.logicalPlan.collect { case e: StreamingExecutionRelation => - e.source.asInstanceOf[FileStreamSource] - }.head + val source = getSourcesFromStreamingQuery(streamExecution).head assert(source.seenFiles.size == 1) true } @@ -662,6 +665,101 @@ class FileStreamSourceSuite extends FileStreamSourceTest { } } + test("read data from outputs of another streaming query") { + withSQLConf(SQLConf.FILE_SINK_LOG_COMPACT_INTERVAL.key -> "3") { + withTempDirs { case (outputDir, checkpointDir) => + // q1 is a streaming query that reads from memory and writes to text files + val q1Source = MemoryStream[String] + val q1 = + q1Source + .toDF() + .writeStream + .option("checkpointLocation", checkpointDir.getCanonicalPath) + .format("text") + .start(outputDir.getCanonicalPath) + + // q2 is a streaming query that reads q1's text outputs + val q2 = + createFileStream("text", outputDir.getCanonicalPath).filter($"value" contains "keep") + + def q1AddData(data: String*): StreamAction = + Execute { _ => + q1Source.addData(data) + q1.processAllAvailable() + } + def q2ProcessAllAvailable(): StreamAction = Execute { q2 => q2.processAllAvailable() } + + testStream(q2)( + // batch 0 + q1AddData("drop1", "keep2"), + q2ProcessAllAvailable(), + CheckAnswer("keep2"), + + // batch 1 + Assert 
{ + // create a text file that won't be on q1's sink log + // thus even if its content contains "keep", it should NOT appear in q2's answer + val shouldNotKeep = new File(outputDir, "should_not_keep.txt") + stringToFile(shouldNotKeep, "should_not_keep!!!") + shouldNotKeep.exists() + }, + q1AddData("keep3"), + q2ProcessAllAvailable(), + CheckAnswer("keep2", "keep3"), + + // batch 2: check that things work well when the sink log gets compacted + q1AddData("keep4"), + Assert { + // compact interval is 3, so file "2.compact" should exist + new File(outputDir, s"${FileStreamSink.metadataDir}/2.compact").exists() + }, + q2ProcessAllAvailable(), + CheckAnswer("keep2", "keep3", "keep4"), + + Execute { _ => q1.stop() } + ) + } + } + } + + test("start before another streaming query, and read its output") { + withTempDirs { case (outputDir, checkpointDir) => + // q1 is a streaming query that reads from memory and writes to text files + val q1Source = MemoryStream[String] + // define q1, but don't start it for now + val q1Write = + q1Source + .toDF() + .writeStream + .option("checkpointLocation", checkpointDir.getCanonicalPath) + .format("text") + var q1: StreamingQuery = null + + val q2 = createFileStream("text", outputDir.getCanonicalPath).filter($"value" contains "keep") + + testStream(q2)( + AssertOnQuery { q2 => + val fileSource = getSourcesFromStreamingQuery(q2).head + // q1 has not started yet, verify that q2 doesn't know whether q1 has metadata + fileSource.sourceHasMetadata === None + }, + Execute { _ => + q1 = q1Write.start(outputDir.getCanonicalPath) + q1Source.addData("drop1", "keep2") + q1.processAllAvailable() + }, + AssertOnQuery { q2 => + q2.processAllAvailable() + val fileSource = getSourcesFromStreamingQuery(q2).head + // q1 has started, verify that q2 knows q1 has metadata by now + fileSource.sourceHasMetadata === Some(true) + }, + CheckAnswer("keep2"), + Execute { _ => q1.stop() } + ) + } + } + test("when schema inference is turned on, should read partition data") { def createFile(content: String, src: File, tmp: File): Unit = { val tempFile = Utils.tempFileWith(new File(tmp, "text")) @@ -755,10 +853,7 @@ class FileStreamSourceSuite extends FileStreamSourceTest { .streamingQuery q.processAllAvailable() val memorySink = q.sink.asInstanceOf[MemorySink] - val fileSource = q.logicalPlan.collect { - case StreamingExecutionRelation(source, _) if source.isInstanceOf[FileStreamSource] => - source.asInstanceOf[FileStreamSource] - }.head + val fileSource = getSourcesFromStreamingQuery(q).head /** Check the data read in the last batch */ def checkLastBatchData(data: Int*): Unit = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala index 0524898b15ea..6cf4d51f9933 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.execution.streaming.state.StateStore /** Class to check custom state types */ case class RunningCount(count: Long) -class MapGroupsWithStateSuite extends StreamTest with BeforeAndAfterAll { +class MapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAfterAll { import testImplicits._ @@ -321,13 +321,6 @@ class MapGroupsWithStateSuite extends StreamTest with BeforeAndAfterAll { CheckLastBatch(("a", 3L)) // task should not fail, and should show correct count 
) } - - private def assertNumStateRows(total: Long, updated: Long): AssertOnQuery = AssertOnQuery { q => - val progressWithData = q.recentProgress.filter(_.numInputRows > 0).lastOption.get - assert(progressWithData.stateOperators(0).numRowsTotal === total, "incorrect total rows") - assert(progressWithData.stateOperators(0).numRowsUpdated === updated, "incorrect updates rows") - true - } } object MapGroupsWithStateSuite { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StateStoreMetricsTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StateStoreMetricsTest.scala new file mode 100644 index 000000000000..894786c50e23 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StateStoreMetricsTest.scala @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming + +trait StateStoreMetricsTest extends StreamTest { + + def assertNumStateRows(total: Seq[Long], updated: Seq[Long]): AssertOnQuery = + AssertOnQuery { q => + val progressWithData = q.recentProgress.filter(_.numInputRows > 0).lastOption.get + assert( + progressWithData.stateOperators.map(_.numRowsTotal) === total, + "incorrect total rows") + assert( + progressWithData.stateOperators.map(_.numRowsUpdated) === updated, + "incorrect updates rows") + true + } + + def assertNumStateRows(total: Long, updated: Long): AssertOnQuery = + assertNumStateRows(Seq(total), Seq(updated)) +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala index 0296a2ade345..6dfcd8baba20 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala @@ -17,6 +17,9 @@ package org.apache.spark.sql.streaming +import java.io.{InterruptedIOException, IOException} +import java.util.concurrent.{CountDownLatch, TimeoutException, TimeUnit} + import scala.reflect.ClassTag import scala.util.control.ControlThrowable @@ -338,7 +341,7 @@ class StreamSuite extends StreamTest { .writeStream .format("memory") .queryName("testquery") - .outputMode("complete") + .outputMode("append") .start() try { query.processAllAvailable() @@ -350,13 +353,45 @@ class StreamSuite extends StreamTest { } } } -} -/** - * A fake StreamSourceProvider thats creates a fake Source that cannot be reused. - */ -class FakeDefaultSource extends StreamSourceProvider { + test("handle IOException when the streaming thread is interrupted (pre Hadoop 2.8)") { + // This test uses a fake source to throw the same IOException as pre Hadoop 2.8 when the + // streaming thread is interrupted. We should handle it properly by not failing the query. 
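Stepping back to the StateStoreMetricsTest trait introduced above: it factors the state-metrics assertion out of MapGroupsWithStateSuite so that MapGroupsWithStateSuite and StreamingAggregationSuite can share it. A minimal usage sketch follows; the suite name, query, and expected counts are illustrative and not part of this patch.

```scala
// Hypothetical example suite (not in this patch) exercising the new overloads.
package org.apache.spark.sql.streaming

import org.apache.spark.sql.execution.streaming.MemoryStream
import org.apache.spark.sql.functions.count

class StateMetricsExampleSuite extends StateStoreMetricsTest {
  import testImplicits._

  test("state rows are tracked per batch") {
    val inputData = MemoryStream[Int]
    val aggregated = inputData.toDS().groupBy("value").agg(count("*")).as[(Int, Long)]

    testStream(aggregated, OutputMode.Update)(
      AddData(inputData, 3),
      CheckLastBatch((3, 1L)),
      // One stateful operator, so the plain Long overload is enough.
      assertNumStateRows(total = 1, updated = 1),
      AddData(inputData, 3, 2),
      CheckLastBatch((3, 2L), (2, 1L)),
      // The Seq overload takes one entry per stateful operator in the query.
      assertNumStateRows(total = Seq(2L), updated = Seq(2L))
    )
  }
}
```

The Seq-based overload is what lets queries with several stateful operators assert per-operator totals, which the single-operator helper removed above could not express.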
+ ThrowingIOExceptionLikeHadoop12074.createSourceLatch = new CountDownLatch(1) + val query = spark + .readStream + .format(classOf[ThrowingIOExceptionLikeHadoop12074].getName) + .load() + .writeStream + .format("console") + .start() + assert(ThrowingIOExceptionLikeHadoop12074.createSourceLatch + .await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS), + "ThrowingIOExceptionLikeHadoop12074.createSource wasn't called before timeout") + query.stop() + assert(query.exception.isEmpty) + } + test("handle InterruptedIOException when the streaming thread is interrupted (Hadoop 2.8+)") { + // This test uses a fake source to throw the same InterruptedIOException as Hadoop 2.8+ when the + // streaming thread is interrupted. We should handle it properly by not failing the query. + ThrowingInterruptedIOException.createSourceLatch = new CountDownLatch(1) + val query = spark + .readStream + .format(classOf[ThrowingInterruptedIOException].getName) + .load() + .writeStream + .format("console") + .start() + assert(ThrowingInterruptedIOException.createSourceLatch + .await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS), + "ThrowingInterruptedIOException.createSource wasn't called before timeout") + query.stop() + assert(query.exception.isEmpty) + } +} + +abstract class FakeSource extends StreamSourceProvider { private val fakeSchema = StructType(StructField("a", IntegerType) :: Nil) override def sourceSchema( @@ -364,6 +399,10 @@ class FakeDefaultSource extends StreamSourceProvider { schema: Option[StructType], providerName: String, parameters: Map[String, String]): (String, StructType) = ("fakeSource", fakeSchema) +} + +/** A fake StreamSourceProvider that creates a fake Source that cannot be reused. */ +class FakeDefaultSource extends FakeSource { override def createSource( spark: SQLContext, @@ -395,3 +434,63 @@ class FakeDefaultSource extends StreamSourceProvider { } } } + +/** A fake source that throws the same IOException like pre Hadoop 2.8 when it's interrupted. */ +class ThrowingIOExceptionLikeHadoop12074 extends FakeSource { + import ThrowingIOExceptionLikeHadoop12074._ + + override def createSource( + spark: SQLContext, + metadataPath: String, + schema: Option[StructType], + providerName: String, + parameters: Map[String, String]): Source = { + createSourceLatch.countDown() + try { + Thread.sleep(30000) + throw new TimeoutException("sleep was not interrupted in 30 seconds") + } catch { + case ie: InterruptedException => + throw new IOException(ie.toString) + } + } +} + +object ThrowingIOExceptionLikeHadoop12074 { + /** + * A latch to allow the user to wait until [[ThrowingIOExceptionLikeHadoop12074.createSource]] is + * called. + */ + @volatile var createSourceLatch: CountDownLatch = null +} + +/** A fake source that throws InterruptedIOException like Hadoop 2.8+ when it's interrupted. */ +class ThrowingInterruptedIOException extends FakeSource { + import ThrowingInterruptedIOException._ + + override def createSource( + spark: SQLContext, + metadataPath: String, + schema: Option[StructType], + providerName: String, + parameters: Map[String, String]): Source = { + createSourceLatch.countDown() + try { + Thread.sleep(30000) + throw new TimeoutException("sleep was not interrupted in 30 seconds") + } catch { + case ie: InterruptedException => + val iie = new InterruptedIOException(ie.toString) + iie.initCause(ie) + throw iie + } + } +} + +object ThrowingInterruptedIOException { + /** + * A latch to allow the user to wait until [[ThrowingInterruptedIOException.createSource]] is + * called. 
+ */ + @volatile var createSourceLatch: CountDownLatch = null +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index af2f31a34d8d..60e2375a9817 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -208,6 +208,12 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts { } } + /** Execute arbitrary code */ + object Execute { + def apply(func: StreamExecution => Any): AssertOnQuery = + AssertOnQuery(query => { func(query); true }) + } + class StreamManualClock(time: Long = 0L) extends ManualClock(time) with Serializable { private var waitStartTime: Option[Long] = None @@ -472,7 +478,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts { case a: AssertOnQuery => verify(currentStream != null || lastStream != null, - "cannot assert when not stream has been started") + "cannot assert when no stream has been started") val streamToAssert = Option(currentStream).getOrElse(lastStream) verify(a.condition(streamToAssert), s"Assert on query failed: ${a.message}") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala index eca2647dea52..0c8015672bab 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala @@ -35,7 +35,7 @@ object FailureSinglton { var firstTime = true } -class StreamingAggregationSuite extends StreamTest with BeforeAndAfterAll { +class StreamingAggregationSuite extends StateStoreMetricsTest with BeforeAndAfterAll { override def afterAll(): Unit = { super.afterAll() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala index 4596aa1d348e..eb09b9ffcfc5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala @@ -133,6 +133,31 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter { } } + test("SPARK-19594: all of listeners should receive QueryTerminatedEvent") { + val df = MemoryStream[Int].toDS().as[Long] + val listeners = (1 to 5).map(_ => new EventCollector) + try { + listeners.foreach(listener => spark.streams.addListener(listener)) + testStream(df, OutputMode.Append)( + StartStream(), + StopStream, + AssertOnQuery { query => + eventually(Timeout(streamingTimeout)) { + listeners.foreach(listener => assert(listener.terminationEvent !== null)) + listeners.foreach(listener => assert(listener.terminationEvent.id === query.id)) + listeners.foreach(listener => assert(listener.terminationEvent.runId === query.runId)) + listeners.foreach(listener => assert(listener.terminationEvent.exception === None)) + } + listeners.foreach(listener => listener.checkAsyncErrors()) + listeners.foreach(listener => listener.reset()) + true + } + ) + } finally { + listeners.foreach(spark.streams.removeListener) + } + } + test("adding and removing listener") { def isListenerActive(listener: EventCollector): Boolean = { listener.reset() diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala index 1525ad5fd517..a0a2b2b4c9b3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala @@ -20,10 +20,12 @@ package org.apache.spark.sql.streaming import java.util.concurrent.CountDownLatch import org.apache.commons.lang3.RandomStringUtils +import org.mockito.Mockito._ import org.scalactic.TolerantNumerics import org.scalatest.concurrent.Eventually._ import org.scalatest.BeforeAndAfter import org.scalatest.concurrent.PatienceConfiguration.Timeout +import org.scalatest.mock.MockitoSugar import org.apache.spark.internal.Logging import org.apache.spark.sql.{DataFrame, Dataset} @@ -32,11 +34,11 @@ import org.apache.spark.SparkException import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.streaming.util.BlockingSource +import org.apache.spark.sql.streaming.util.{BlockingSource, MockSourceProvider} import org.apache.spark.util.ManualClock -class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging { +class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging with MockitoSugar { import AwaitTerminationTester._ import testImplicits._ @@ -481,6 +483,75 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging { } } + test("StreamExecution should call stop() on sources when a stream is stopped") { + var calledStop = false + val source = new Source { + override def stop(): Unit = { + calledStop = true + } + override def getOffset: Option[Offset] = None + override def getBatch(start: Option[Offset], end: Offset): DataFrame = { + spark.emptyDataFrame + } + override def schema: StructType = MockSourceProvider.fakeSchema + } + + MockSourceProvider.withMockSources(source) { + val df = spark.readStream + .format("org.apache.spark.sql.streaming.util.MockSourceProvider") + .load() + + testStream(df)(StopStream) + + assert(calledStop, "Did not call stop on source for stopped stream") + } + } + + testQuietly("SPARK-19774: StreamExecution should call stop() on sources when a stream fails") { + var calledStop = false + val source1 = new Source { + override def stop(): Unit = { + throw new RuntimeException("Oh no!") + } + override def getOffset: Option[Offset] = Some(LongOffset(1)) + override def getBatch(start: Option[Offset], end: Offset): DataFrame = { + spark.range(2).toDF(MockSourceProvider.fakeSchema.fieldNames: _*) + } + override def schema: StructType = MockSourceProvider.fakeSchema + } + val source2 = new Source { + override def stop(): Unit = { + calledStop = true + } + override def getOffset: Option[Offset] = None + override def getBatch(start: Option[Offset], end: Offset): DataFrame = { + spark.emptyDataFrame + } + override def schema: StructType = MockSourceProvider.fakeSchema + } + + MockSourceProvider.withMockSources(source1, source2) { + val df1 = spark.readStream + .format("org.apache.spark.sql.streaming.util.MockSourceProvider") + .load() + .as[Int] + + val df2 = spark.readStream + .format("org.apache.spark.sql.streaming.util.MockSourceProvider") + .load() + .as[Int] + + testStream(df1.union(df2).map(i => i / 0))( + AssertOnQuery { sq => + intercept[StreamingQueryException](sq.processAllAvailable()) + sq.exception.isDefined && !sq.isActive 
+ } + ) + + assert(calledStop, "Did not call stop on source for stopped stream") + } + } + /** Create a streaming DF that only execute one batch in which it returns the given static DF */ private def createSingleTriggerStreamingDF(triggerDF: DataFrame): DataFrame = { require(!triggerDF.isStreaming) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/util/MockSourceProvider.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/util/MockSourceProvider.scala new file mode 100644 index 000000000000..0bf05381a7f3 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/util/MockSourceProvider.scala @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming.util + +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.execution.streaming.Source +import org.apache.spark.sql.sources.StreamSourceProvider +import org.apache.spark.sql.types.{IntegerType, StructField, StructType} + +/** + * A StreamSourceProvider that provides mocked Sources for unit testing. Example usage: + * + * {{{ + * MockSourceProvider.withMockSources(source1, source2) { + * val df1 = spark.readStream + * .format("org.apache.spark.sql.streaming.util.MockSourceProvider") + * .load() + * + * val df2 = spark.readStream + * .format("org.apache.spark.sql.streaming.util.MockSourceProvider") + * .load() + * + * df1.union(df2) + * ... + * } + * }}} + */ +class MockSourceProvider extends StreamSourceProvider { + override def sourceSchema( + spark: SQLContext, + schema: Option[StructType], + providerName: String, + parameters: Map[String, String]): (String, StructType) = { + ("dummySource", MockSourceProvider.fakeSchema) + } + + override def createSource( + spark: SQLContext, + metadataPath: String, + schema: Option[StructType], + providerName: String, + parameters: Map[String, String]): Source = { + MockSourceProvider.sourceProviderFunction() + } +} + +object MockSourceProvider { + // Function to generate sources. May provide multiple sources if the user implements such a + // function. 
+ private var sourceProviderFunction: () => Source = _ + + final val fakeSchema = StructType(StructField("a", IntegerType) :: Nil) + + def withMockSources(source: Source, otherSources: Source*)(f: => Unit): Unit = { + var i = 0 + val sources = source +: otherSources + sourceProviderFunction = () => { + val source = sources(i % sources.length) + i += 1 + source + } + try { + f + } finally { + sourceProviderFunction = null + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala index 2f247ca3e8b7..8ab6db175da5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala @@ -35,7 +35,7 @@ private[sql] class TestSparkSession(sc: SparkContext) extends SparkSession(sc) { } @transient - protected[sql] override lazy val sessionState: SessionState = new SessionState(self) { + override lazy val sessionState: SessionState = new SessionState(self) { override lazy val conf: SQLConf = { new SQLConf { clear() diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala index 78a309497ab5..c0b299411e94 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala @@ -40,6 +40,7 @@ private[hive] object SparkSQLEnv extends Logging { val maybeAppName = sparkConf .getOption("spark.app.name") .filterNot(_ == classOf[SparkSQLCLIDriver].getName) + .filterNot(_ == classOf[HiveThriftServer2].getName) sparkConf .setAppName(maybeAppName.getOrElse(s"SparkSQL::${Utils.localHostName()}")) diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala index 7ba5790c2979..c7d953a731b9 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala @@ -95,9 +95,9 @@ class HiveWindowFunctionQuerySuite extends HiveComparisonTest with BeforeAndAfte // This is used to generate golden files. sql("set hive.plan.serialization.format=kryo") // Explicitly set fs to local fs. - sql(s"set fs.default.name=file://$testTempDir/") + sql(s"set fs.defaultFS=file://$testTempDir/") // Ask Hive to run jobs in-process as a single map and reduce task. - sql("set mapred.job.tracker=local") + sql("set mapreduce.jobtracker.address=local") } override def afterAll() { @@ -764,9 +764,9 @@ class HiveWindowFunctionQueryFileSuite // This is used to generate golden files. // sql("set hive.plan.serialization.format=kryo") // Explicitly set fs to local fs. - // sql(s"set fs.default.name=file://$testTempDir/") + // sql(s"set fs.defaultFS=file://$testTempDir/") // Ask Hive to run jobs in-process as a single map and reduce task. 
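For reference, this patch consistently moves from the deprecated Hadoop 1.x property names to their Hadoop 2.x equivalents: fs.default.name becomes fs.defaultFS, mapred.job.tracker becomes mapreduce.jobtracker.address, and (in TableReader further down) mapred.map.tasks becomes mapreduce.job.maps. A small sketch of the new names set programmatically on a Hadoop Configuration; the values shown are placeholders only.

```scala
import org.apache.hadoop.conf.Configuration

val hadoopConf = new Configuration()
// Hadoop 2.x name for the default filesystem (was fs.default.name).
hadoopConf.set("fs.defaultFS", "file:///tmp/spark-test-warehouse/")
// Hadoop 2.x name for the job tracker address (was mapred.job.tracker).
hadoopConf.set("mapreduce.jobtracker.address", "local")
// Hadoop 2.x hint for the number of map tasks (was mapred.map.tasks).
hadoopConf.setInt("mapreduce.job.maps", 1)
```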
- // sql("set mapred.job.tracker=local") + // sql("set mapreduce.jobtracker.address=local") } override def afterAll() { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index ea4825614785..43d9c2bec682 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -38,7 +38,7 @@ import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils.escapePathName import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.ColumnStat -import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap +import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.PartitioningUtils import org.apache.spark.sql.hive.client.HiveClient @@ -736,14 +736,12 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat table: String, loadPath: String, isOverwrite: Boolean, - holdDDLTime: Boolean, isSrcLocal: Boolean): Unit = withClient { requireTableExists(db, table) client.loadTable( loadPath, s"$db.$table", isOverwrite, - holdDDLTime, isSrcLocal) } @@ -753,7 +751,6 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat loadPath: String, partition: TablePartitionSpec, isOverwrite: Boolean, - holdDDLTime: Boolean, inheritTableSpecs: Boolean, isSrcLocal: Boolean): Unit = withClient { requireTableExists(db, table) @@ -773,7 +770,6 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat table, orderedPartitionSpec, isOverwrite, - holdDDLTime, inheritTableSpecs, isSrcLocal) } @@ -784,8 +780,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat loadPath: String, partition: TablePartitionSpec, replace: Boolean, - numDP: Int, - holdDDLTime: Boolean): Unit = withClient { + numDP: Int): Unit = withClient { requireTableExists(db, table) val orderedPartitionSpec = new util.LinkedHashMap[String, String]() @@ -803,8 +798,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat table, orderedPartitionSpec, replace, - numDP, - holdDDLTime) + numDP) } // -------------------------------------------------------------------------- @@ -1014,7 +1008,8 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat override def listPartitionsByFilter( db: String, table: String, - predicates: Seq[Expression]): Seq[CatalogTablePartition] = withClient { + predicates: Seq[Expression], + defaultTimeZoneId: String): Seq[CatalogTablePartition] = withClient { val rawTable = getRawTable(db, table) val catalogTable = restoreTableMetadata(rawTable) val partitionColumnNames = catalogTable.partitionColumnNames.toSet @@ -1040,7 +1035,9 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat val index = partitionSchema.indexWhere(_.name == att.name) BoundReference(index, partitionSchema(index).dataType, nullable = true) }) - clientPrunedPartitions.filter { p => boundPredicate(p.toRow(partitionSchema)) } + clientPrunedPartitions.filter { p => + boundPredicate(p.toRow(partitionSchema, defaultTimeZoneId)) + } } else { client.getPartitions(catalogTable).map { part => part.copy(spec = restorePartitionSpec(part.spec, partColNameMap)) diff 
--git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 677da0dbdc65..151a69aebf1d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.hive +import java.net.URI + import com.google.common.util.concurrent.Striped import org.apache.hadoop.fs.Path @@ -26,6 +28,7 @@ import org.apache.spark.sql.catalyst.{QualifiedTableName, TableIdentifier} import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ +import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat, ParquetOptions} import org.apache.spark.sql.hive.orc.OrcFileFormat @@ -71,10 +74,8 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log private def getCached( tableIdentifier: QualifiedTableName, pathsInMetastore: Seq[Path], - metastoreRelation: MetastoreRelation, schemaInMetastore: StructType, expectedFileFormat: Class[_ <: FileFormat], - expectedBucketSpec: Option[BucketSpec], partitionSchema: Option[StructType]): Option[LogicalRelation] = { tableRelationCache.getIfPresent(tableIdentifier) match { @@ -89,7 +90,9 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log val useCached = relation.location.rootPaths.toSet == pathsInMetastore.toSet && logical.schema.sameType(schemaInMetastore) && - relation.bucketSpec == expectedBucketSpec && + // We don't support hive bucketed tables. This function `getCached` is only used for + // converting supported Hive tables to data source tables. + relation.bucketSpec.isEmpty && relation.partitionSchema == partitionSchema.getOrElse(StructType(Nil)) if (useCached) { @@ -100,52 +103,48 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log None } case _ => - logWarning( - s"${metastoreRelation.databaseName}.${metastoreRelation.tableName} " + - s"should be stored as $expectedFileFormat. However, we are getting " + - s"a ${relation.fileFormat} from the metastore cache. This cached " + - s"entry will be invalidated.") + logWarning(s"Table $tableIdentifier should be stored as $expectedFileFormat. " + + s"However, we are getting a ${relation.fileFormat} from the metastore cache. " + + "This cached entry will be invalidated.") tableRelationCache.invalidate(tableIdentifier) None } case other => - logWarning( - s"${metastoreRelation.databaseName}.${metastoreRelation.tableName} should be stored " + - s"as $expectedFileFormat. However, we are getting a $other from the metastore cache. " + - s"This cached entry will be invalidated.") + logWarning(s"Table $tableIdentifier should be stored as $expectedFileFormat. " + + s"However, we are getting a $other from the metastore cache. 
" + + "This cached entry will be invalidated.") tableRelationCache.invalidate(tableIdentifier) None } } private def convertToLogicalRelation( - metastoreRelation: MetastoreRelation, + relation: CatalogRelation, options: Map[String, String], - defaultSource: FileFormat, fileFormatClass: Class[_ <: FileFormat], fileType: String): LogicalRelation = { - val metastoreSchema = StructType.fromAttributes(metastoreRelation.output) + val metastoreSchema = relation.tableMeta.schema val tableIdentifier = - QualifiedTableName(metastoreRelation.databaseName, metastoreRelation.tableName) - val bucketSpec = None // We don't support hive bucketed tables, only ones we write out. + QualifiedTableName(relation.tableMeta.database, relation.tableMeta.identifier.table) val lazyPruningEnabled = sparkSession.sqlContext.conf.manageFilesourcePartitions - val result = if (metastoreRelation.hiveQlTable.isPartitioned) { - val partitionSchema = StructType.fromAttributes(metastoreRelation.partitionKeys) - + val tablePath = new Path(new URI(relation.tableMeta.location)) + val result = if (relation.isPartitioned) { + val partitionSchema = relation.tableMeta.partitionSchema val rootPaths: Seq[Path] = if (lazyPruningEnabled) { - Seq(metastoreRelation.hiveQlTable.getDataLocation) + Seq(tablePath) } else { // By convention (for example, see CatalogFileIndex), the definition of a // partitioned table's paths depends on whether that table has any actual partitions. // Partitioned tables without partitions use the location of the table's base path. // Partitioned tables with partitions use the locations of those partitions' data // locations,_omitting_ the table's base path. - val paths = metastoreRelation.getHiveQlPartitions().map { p => - new Path(p.getLocation) - } + val paths = sparkSession.sharedState.externalCatalog + .listPartitions(tableIdentifier.database, tableIdentifier.name) + .map(p => new Path(new URI(p.storage.locationUri.get))) + if (paths.isEmpty) { - Seq(metastoreRelation.hiveQlTable.getDataLocation) + Seq(tablePath) } else { paths } @@ -155,39 +154,31 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log val cached = getCached( tableIdentifier, rootPaths, - metastoreRelation, metastoreSchema, fileFormatClass, - bucketSpec, Some(partitionSchema)) val logicalRelation = cached.getOrElse { - val sizeInBytes = - metastoreRelation.stats(sparkSession.sessionState.conf).sizeInBytes.toLong + val sizeInBytes = relation.stats(sparkSession.sessionState.conf).sizeInBytes.toLong val fileIndex = { - val index = new CatalogFileIndex( - sparkSession, metastoreRelation.catalogTable, sizeInBytes) + val index = new CatalogFileIndex(sparkSession, relation.tableMeta, sizeInBytes) if (lazyPruningEnabled) { index } else { index.filterPartitions(Nil) // materialize all the partitions in memory } } - val partitionSchemaColumnNames = partitionSchema.map(_.name.toLowerCase).toSet - val dataSchema = - StructType(metastoreSchema - .filterNot(field => partitionSchemaColumnNames.contains(field.name.toLowerCase))) - val relation = HadoopFsRelation( + val fsRelation = HadoopFsRelation( location = fileIndex, partitionSchema = partitionSchema, - dataSchema = dataSchema, - bucketSpec = bucketSpec, - fileFormat = defaultSource, + dataSchema = relation.tableMeta.dataSchema, + // We don't support hive bucketed tables, only ones we write out. 
+ bucketSpec = None, + fileFormat = fileFormatClass.newInstance(), options = options)(sparkSession = sparkSession) - val created = LogicalRelation(relation, - catalogTable = Some(metastoreRelation.catalogTable)) + val created = LogicalRelation(fsRelation, catalogTable = Some(relation.tableMeta)) tableRelationCache.put(tableIdentifier, created) created } @@ -195,14 +186,13 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log logicalRelation }) } else { - val rootPath = metastoreRelation.hiveQlTable.getDataLocation + val rootPath = tablePath withTableCreationLock(tableIdentifier, { - val cached = getCached(tableIdentifier, + val cached = getCached( + tableIdentifier, Seq(rootPath), - metastoreRelation, metastoreSchema, fileFormatClass, - bucketSpec, None) val logicalRelation = cached.getOrElse { val created = @@ -210,11 +200,12 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log DataSource( sparkSession = sparkSession, paths = rootPath.toString :: Nil, - userSpecifiedSchema = Some(metastoreRelation.schema), - bucketSpec = bucketSpec, + userSpecifiedSchema = Some(metastoreSchema), + // We don't support hive bucketed tables, only ones we write out. + bucketSpec = None, options = options, className = fileType).resolveRelation(), - catalogTable = Some(metastoreRelation.catalogTable)) + catalogTable = Some(relation.tableMeta)) tableRelationCache.put(tableIdentifier, created) created @@ -223,7 +214,7 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log logicalRelation }) } - result.copy(expectedOutputAttributes = Some(metastoreRelation.output)) + result.copy(expectedOutputAttributes = Some(relation.output)) } /** @@ -231,33 +222,32 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log * data source relations for better performance. */ object ParquetConversions extends Rule[LogicalPlan] { - private def shouldConvertMetastoreParquet(relation: MetastoreRelation): Boolean = { - relation.tableDesc.getSerdeClassName.toLowerCase.contains("parquet") && + private def shouldConvertMetastoreParquet(relation: CatalogRelation): Boolean = { + relation.tableMeta.storage.serde.getOrElse("").toLowerCase.contains("parquet") && sessionState.convertMetastoreParquet } - private def convertToParquetRelation(relation: MetastoreRelation): LogicalRelation = { - val defaultSource = new ParquetFileFormat() + private def convertToParquetRelation(relation: CatalogRelation): LogicalRelation = { val fileFormatClass = classOf[ParquetFileFormat] - val mergeSchema = sessionState.convertMetastoreParquetWithSchemaMerging val options = Map(ParquetOptions.MERGE_SCHEMA -> mergeSchema.toString) - convertToLogicalRelation(relation, options, defaultSource, fileFormatClass, "parquet") + convertToLogicalRelation(relation, options, fileFormatClass, "parquet") } override def apply(plan: LogicalPlan): LogicalPlan = { plan transformUp { // Write path - case InsertIntoTable(r: MetastoreRelation, partition, query, overwrite, ifNotExists) + case InsertIntoTable(r: CatalogRelation, partition, query, overwrite, ifNotExists) // Inserting into partitioned table is not supported in Parquet data source (yet). 
- if query.resolved && !r.hiveQlTable.isPartitioned && shouldConvertMetastoreParquet(r) => + if query.resolved && DDLUtils.isHiveTable(r.tableMeta) && + !r.isPartitioned && shouldConvertMetastoreParquet(r) => InsertIntoTable(convertToParquetRelation(r), partition, query, overwrite, ifNotExists) // Read path - case relation: MetastoreRelation if shouldConvertMetastoreParquet(relation) => - val parquetRelation = convertToParquetRelation(relation) - SubqueryAlias(relation.tableName, parquetRelation, None) + case relation: CatalogRelation if DDLUtils.isHiveTable(relation.tableMeta) && + shouldConvertMetastoreParquet(relation) => + convertToParquetRelation(relation) } } } @@ -267,31 +257,31 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log * for better performance. */ object OrcConversions extends Rule[LogicalPlan] { - private def shouldConvertMetastoreOrc(relation: MetastoreRelation): Boolean = { - relation.tableDesc.getSerdeClassName.toLowerCase.contains("orc") && + private def shouldConvertMetastoreOrc(relation: CatalogRelation): Boolean = { + relation.tableMeta.storage.serde.getOrElse("").toLowerCase.contains("orc") && sessionState.convertMetastoreOrc } - private def convertToOrcRelation(relation: MetastoreRelation): LogicalRelation = { - val defaultSource = new OrcFileFormat() + private def convertToOrcRelation(relation: CatalogRelation): LogicalRelation = { val fileFormatClass = classOf[OrcFileFormat] val options = Map[String, String]() - convertToLogicalRelation(relation, options, defaultSource, fileFormatClass, "orc") + convertToLogicalRelation(relation, options, fileFormatClass, "orc") } override def apply(plan: LogicalPlan): LogicalPlan = { plan transformUp { // Write path - case InsertIntoTable(r: MetastoreRelation, partition, query, overwrite, ifNotExists) + case InsertIntoTable(r: CatalogRelation, partition, query, overwrite, ifNotExists) // Inserting into partitioned table is not supported in Orc data source (yet). 
- if query.resolved && !r.hiveQlTable.isPartitioned && shouldConvertMetastoreOrc(r) => + if query.resolved && DDLUtils.isHiveTable(r.tableMeta) && + !r.isPartitioned && shouldConvertMetastoreOrc(r) => InsertIntoTable(convertToOrcRelation(r), partition, query, overwrite, ifNotExists) // Read path - case relation: MetastoreRelation if shouldConvertMetastoreOrc(relation) => - val orcRelation = convertToOrcRelation(relation) - SubqueryAlias(relation.tableName, orcRelation, None) + case relation: CatalogRelation if DDLUtils.isHiveTable(relation.tableMeta) && + shouldConvertMetastoreOrc(relation) => + convertToOrcRelation(relation) } } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala index 273cf85df33a..5a08a6bc66f6 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala @@ -62,10 +62,10 @@ private[hive] class HiveSessionState(sparkSession: SparkSession) override val extendedResolutionRules = new ResolveHiveSerdeTable(sparkSession) :: new FindDataSourceTable(sparkSession) :: - new FindHiveSerdeTable(sparkSession) :: new ResolveSQLOnFile(sparkSession) :: Nil override val postHocResolutionRules = + new DetermineTableStats(sparkSession) :: catalog.ParquetConversions :: catalog.OrcConversions :: PreprocessTableCreation(sparkSession) :: diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index f45532cc3845..624cfa206eeb 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -17,8 +17,14 @@ package org.apache.spark.sql.hive +import java.io.IOException +import java.net.URI + +import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.hadoop.hive.common.StatsSetupConst + import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable, SimpleCatalogRelation} +import org.apache.spark.sql.catalyst.catalog.{CatalogRelation, CatalogStatistics, CatalogStorageFormat, CatalogTable} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, LogicalPlan, ScriptTransformation} @@ -91,18 +97,56 @@ class ResolveHiveSerdeTable(session: SparkSession) extends Rule[LogicalPlan] { // Infers the schema, if empty, because the schema could be determined by Hive // serde. - val catalogTable = if (query.isEmpty) { - val withSchema = HiveUtils.inferSchema(withStorage) - if (withSchema.schema.length <= 0) { + val withSchema = if (query.isEmpty) { + val inferred = HiveUtils.inferSchema(withStorage) + if (inferred.schema.length <= 0) { throw new AnalysisException("Unable to infer the schema. 
" + - s"The schema specification is required to create the table ${withSchema.identifier}.") + s"The schema specification is required to create the table ${inferred.identifier}.") } - withSchema + inferred } else { withStorage } - c.copy(tableDesc = catalogTable) + c.copy(tableDesc = withSchema) + } +} + +class DetermineTableStats(session: SparkSession) extends Rule[LogicalPlan] { + override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + case relation: CatalogRelation + if DDLUtils.isHiveTable(relation.tableMeta) && relation.tableMeta.stats.isEmpty => + val table = relation.tableMeta + // TODO: check if this estimate is valid for tables after partition pruning. + // NOTE: getting `totalSize` directly from params is kind of hacky, but this should be + // relatively cheap if parameters for the table are populated into the metastore. + // Besides `totalSize`, there are also `numFiles`, `numRows`, `rawDataSize` keys + // (see StatsSetupConst in Hive) that we can look at in the future. + // When table is external,`totalSize` is always zero, which will influence join strategy + // so when `totalSize` is zero, use `rawDataSize` instead. + val totalSize = table.properties.get(StatsSetupConst.TOTAL_SIZE).map(_.toLong) + val rawDataSize = table.properties.get(StatsSetupConst.RAW_DATA_SIZE).map(_.toLong) + val sizeInBytes = if (totalSize.isDefined && totalSize.get > 0) { + totalSize.get + } else if (rawDataSize.isDefined && rawDataSize.get > 0) { + rawDataSize.get + } else if (session.sessionState.conf.fallBackToHdfsForStatsEnabled) { + try { + val hadoopConf = session.sessionState.newHadoopConf() + val tablePath = new Path(new URI(table.location)) + val fs: FileSystem = tablePath.getFileSystem(hadoopConf) + fs.getContentSummary(tablePath).getLength + } catch { + case e: IOException => + logWarning("Failed to get table size from hdfs.", e) + session.sessionState.conf.defaultSizeInBytes + } + } else { + session.sessionState.conf.defaultSizeInBytes + } + + val withStats = table.copy(stats = Some(CatalogStatistics(sizeInBytes = BigInt(sizeInBytes)))) + relation.copy(tableMeta = withStats) } } @@ -114,8 +158,9 @@ class ResolveHiveSerdeTable(session: SparkSession) extends Rule[LogicalPlan] { */ object HiveAnalysis extends Rule[LogicalPlan] { override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { - case InsertIntoTable(table: MetastoreRelation, partSpec, query, overwrite, ifNotExists) => - InsertIntoHiveTable(table, partSpec, query, overwrite, ifNotExists) + case InsertIntoTable(relation: CatalogRelation, partSpec, query, overwrite, ifNotExists) + if DDLUtils.isHiveTable(relation.tableMeta) => + InsertIntoHiveTable(relation.tableMeta, partSpec, query, overwrite, ifNotExists) case CreateTable(tableDesc, mode, None) if DDLUtils.isHiveTable(tableDesc) => CreateTableCommand(tableDesc, ignoreIfExists = mode == SaveMode.Ignore) @@ -125,21 +170,6 @@ object HiveAnalysis extends Rule[LogicalPlan] { } } -/** - * Replaces `SimpleCatalogRelation` with [[MetastoreRelation]] if its table provider is hive. 
- */ -class FindHiveSerdeTable(session: SparkSession) extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case i @ InsertIntoTable(s: SimpleCatalogRelation, _, _, _, _) - if DDLUtils.isHiveTable(s.metadata) => - i.copy(table = - MetastoreRelation(s.metadata.database, s.metadata.identifier.table)(s.metadata, session)) - - case s: SimpleCatalogRelation if DDLUtils.isHiveTable(s.metadata) => - MetastoreRelation(s.metadata.database, s.metadata.identifier.table)(s.metadata, session) - } -} - private[hive] trait HiveStrategies { // Possibly being too clever with types here... or not clever enough. self: SparkPlanner => @@ -161,10 +191,10 @@ private[hive] trait HiveStrategies { */ object HiveTableScans extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case PhysicalOperation(projectList, predicates, relation: MetastoreRelation) => + case PhysicalOperation(projectList, predicates, relation: CatalogRelation) => // Filter out all predicates that only deal with partition keys, these are given to the // hive table scan operator to be used for partition pruning. - val partitionKeyIds = AttributeSet(relation.partitionKeys) + val partitionKeyIds = AttributeSet(relation.partitionCols) val (pruningPredicates, otherPredicates) = predicates.partition { predicate => !predicate.references.isEmpty && predicate.references.subsetOf(partitionKeyIds) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/MetastoreRelation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/MetastoreRelation.scala deleted file mode 100644 index 97b120758ba4..000000000000 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/MetastoreRelation.scala +++ /dev/null @@ -1,179 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.hive - -import java.io.IOException - -import com.google.common.base.Objects -import org.apache.hadoop.fs.FileSystem -import org.apache.hadoop.hive.common.StatsSetupConst -import org.apache.hadoop.hive.ql.metadata.{Partition, Table => HiveTable} -import org.apache.hadoop.hive.ql.plan.TableDesc - -import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.catalyst.CatalystConf -import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation -import org.apache.spark.sql.catalyst.catalog._ -import org.apache.spark.sql.catalyst.expressions.{AttributeMap, AttributeReference, Expression} -import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} -import org.apache.spark.sql.execution.FileRelation -import org.apache.spark.sql.hive.client.HiveClientImpl -import org.apache.spark.sql.types.StructField - - -private[hive] case class MetastoreRelation( - databaseName: String, - tableName: String) - (val catalogTable: CatalogTable, - @transient private val sparkSession: SparkSession) - extends LeafNode with MultiInstanceRelation with FileRelation with CatalogRelation { - - override def equals(other: Any): Boolean = other match { - case relation: MetastoreRelation => - databaseName == relation.databaseName && - tableName == relation.tableName && - output == relation.output - case _ => false - } - - override def hashCode(): Int = { - Objects.hashCode(databaseName, tableName, output) - } - - override protected def otherCopyArgs: Seq[AnyRef] = catalogTable :: sparkSession :: Nil - - @transient val hiveQlTable: HiveTable = HiveClientImpl.toHiveTable(catalogTable) - - @transient override def computeStats(conf: CatalystConf): Statistics = { - catalogTable.stats.map(_.toPlanStats(output)).getOrElse(Statistics( - sizeInBytes = { - val totalSize = hiveQlTable.getParameters.get(StatsSetupConst.TOTAL_SIZE) - val rawDataSize = hiveQlTable.getParameters.get(StatsSetupConst.RAW_DATA_SIZE) - // TODO: check if this estimate is valid for tables after partition pruning. - // NOTE: getting `totalSize` directly from params is kind of hacky, but this should be - // relatively cheap if parameters for the table are populated into the metastore. - // Besides `totalSize`, there are also `numFiles`, `numRows`, `rawDataSize` keys - // (see StatsSetupConst in Hive) that we can look at in the future. - BigInt( - // When table is external,`totalSize` is always zero, which will influence join strategy - // so when `totalSize` is zero, use `rawDataSize` instead - // when `rawDataSize` is also zero, use `HiveExternalCatalog.STATISTICS_TOTAL_SIZE`, - // which is generated by analyze command. 
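The size estimation being deleted here (continued just below) follows the same order that the new DetermineTableStats rule earlier in this patch preserves: totalSize, then rawDataSize, then an optional HDFS content summary. That HDFS fallback is consulted only when fallBackToHdfsForStatsEnabled is set; a minimal sketch of enabling it for a Hive-enabled session, where the session setup is illustrative and the exact conf key is an assumption to verify against SQLConf:

```scala
import org.apache.spark.sql.SparkSession

// Illustrative session: enable the HDFS content-summary fallback used when
// the Hive metastore has neither totalSize nor rawDataSize for a table.
val spark = SparkSession.builder()
  .appName("stats-fallback-example")
  .enableHiveSupport()
  .config("spark.sql.statistics.fallBackToHdfs", "true")
  .getOrCreate()
```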
- if (totalSize != null && totalSize.toLong > 0L) { - totalSize.toLong - } else if (rawDataSize != null && rawDataSize.toLong > 0) { - rawDataSize.toLong - } else if (sparkSession.sessionState.conf.fallBackToHdfsForStatsEnabled) { - try { - val hadoopConf = sparkSession.sessionState.newHadoopConf() - val fs: FileSystem = hiveQlTable.getPath.getFileSystem(hadoopConf) - fs.getContentSummary(hiveQlTable.getPath).getLength - } catch { - case e: IOException => - logWarning("Failed to get table size from hdfs.", e) - sparkSession.sessionState.conf.defaultSizeInBytes - } - } else { - sparkSession.sessionState.conf.defaultSizeInBytes - }) - } - )) - } - - // When metastore partition pruning is turned off, we cache the list of all partitions to - // mimic the behavior of Spark < 1.5 - private lazy val allPartitions: Seq[CatalogTablePartition] = { - sparkSession.sharedState.externalCatalog.listPartitions( - catalogTable.database, - catalogTable.identifier.table) - } - - def getHiveQlPartitions(predicates: Seq[Expression] = Nil): Seq[Partition] = { - val rawPartitions = if (sparkSession.sessionState.conf.metastorePartitionPruning) { - sparkSession.sharedState.externalCatalog.listPartitionsByFilter( - catalogTable.database, - catalogTable.identifier.table, - predicates) - } else { - allPartitions - } - - rawPartitions.map(HiveClientImpl.toHivePartition(_, hiveQlTable)) - } - - /** Only compare database and tablename, not alias. */ - override def sameResult(plan: LogicalPlan): Boolean = { - plan.canonicalized match { - case mr: MetastoreRelation => - mr.databaseName == databaseName && mr.tableName == tableName - case _ => false - } - } - - val tableDesc = new TableDesc( - hiveQlTable.getInputFormatClass, - // The class of table should be org.apache.hadoop.hive.ql.metadata.Table because - // getOutputFormatClass will use HiveFileFormatUtils.getOutputFormatSubstitute to - // substitute some output formats, e.g. substituting SequenceFileOutputFormat to - // HiveSequenceFileOutputFormat. - hiveQlTable.getOutputFormatClass, - hiveQlTable.getMetadata - ) - - implicit class SchemaAttribute(f: StructField) { - def toAttribute: AttributeReference = AttributeReference( - f.name, - f.dataType, - // Since data can be dumped in randomly with no validation, everything is nullable. - nullable = true - )(qualifier = Some(tableName)) - } - - /** PartitionKey attributes */ - val partitionKeys = catalogTable.partitionSchema.map(_.toAttribute) - - /** Non-partitionKey attributes */ - val dataColKeys = catalogTable.schema - .filter { c => !catalogTable.partitionColumnNames.contains(c.name) } - .map(_.toAttribute) - - val output = dataColKeys ++ partitionKeys - - /** An attribute map that can be used to lookup original attributes based on expression id. */ - val attributeMap = AttributeMap(output.map(o => (o, o))) - - /** An attribute map for determining the ordinal for non-partition columns. 
*/ - val columnOrdinals = AttributeMap(dataColKeys.zipWithIndex) - - override def inputFiles: Array[String] = { - val partLocations = allPartitions - .flatMap(_.storage.locationUri) - .toArray - if (partLocations.nonEmpty) { - partLocations - } else { - Array( - catalogTable.storage.locationUri.getOrElse( - sys.error(s"Could not get the location of ${catalogTable.qualifiedName}."))) - } - } - - override def newInstance(): MetastoreRelation = { - MetastoreRelation(databaseName, tableName)(catalogTable, sparkSession) - } -} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala index b4b63032ab26..16c1103dd1ea 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala @@ -61,19 +61,22 @@ private[hive] sealed trait TableReader { private[hive] class HadoopTableReader( @transient private val attributes: Seq[Attribute], - @transient private val relation: MetastoreRelation, + @transient private val partitionKeys: Seq[Attribute], + @transient private val tableDesc: TableDesc, @transient private val sparkSession: SparkSession, hadoopConf: Configuration) extends TableReader with Logging { - // Hadoop honors "mapred.map.tasks" as hint, but will ignore when mapred.job.tracker is "local". - // https://hadoop.apache.org/docs/r1.0.4/mapred-default.html + // Hadoop honors "mapreduce.job.maps" as hint, + // but will ignore when mapreduce.jobtracker.address is "local". + // https://hadoop.apache.org/docs/r2.6.5/hadoop-mapreduce-client/hadoop-mapreduce-client-core/ + // mapred-default.xml // // In order keep consistency with Hive, we will let it be 0 in local mode also. private val _minSplitsPerRDD = if (sparkSession.sparkContext.isLocal) { 0 // will splitted based on block by default. } else { - math.max(hadoopConf.getInt("mapred.map.tasks", 1), + math.max(hadoopConf.getInt("mapreduce.job.maps", 1), sparkSession.sparkContext.defaultMinPartitions) } @@ -86,7 +89,7 @@ class HadoopTableReader( override def makeRDDForTable(hiveTable: HiveTable): RDD[InternalRow] = makeRDDForTable( hiveTable, - Utils.classForName(relation.tableDesc.getSerdeClassName).asInstanceOf[Class[Deserializer]], + Utils.classForName(tableDesc.getSerdeClassName).asInstanceOf[Class[Deserializer]], filterOpt = None) /** @@ -108,7 +111,7 @@ class HadoopTableReader( // Create local references to member variables, so that the entire `this` object won't be // serialized in the closure below. 
- val tableDesc = relation.tableDesc + val localTableDesc = tableDesc val broadcastedHadoopConf = _broadcastedHadoopConf val tablePath = hiveTable.getPath @@ -117,7 +120,7 @@ class HadoopTableReader( // logDebug("Table input: %s".format(tablePath)) val ifc = hiveTable.getInputFormatClass .asInstanceOf[java.lang.Class[InputFormat[Writable, Writable]]] - val hadoopRDD = createHadoopRdd(tableDesc, inputPathStr, ifc) + val hadoopRDD = createHadoopRdd(localTableDesc, inputPathStr, ifc) val attrsWithIndex = attributes.zipWithIndex val mutableRow = new SpecificInternalRow(attributes.map(_.dataType)) @@ -125,7 +128,7 @@ class HadoopTableReader( val deserializedHadoopRDD = hadoopRDD.mapPartitions { iter => val hconf = broadcastedHadoopConf.value.value val deserializer = deserializerClass.newInstance() - deserializer.initialize(hconf, tableDesc.getProperties) + deserializer.initialize(hconf, localTableDesc.getProperties) HadoopTableReader.fillObject(iter, deserializer, attrsWithIndex, mutableRow, deserializer) } @@ -210,8 +213,6 @@ class HadoopTableReader( partCols.map(col => new String(partSpec.get(col))).toArray } - // Create local references so that the outer object isn't serialized. - val tableDesc = relation.tableDesc val broadcastedHiveConf = _broadcastedHadoopConf val localDeserializer = partDeserializer val mutableRow = new SpecificInternalRow(attributes.map(_.dataType)) @@ -220,12 +221,12 @@ class HadoopTableReader( // Attached indices indicate the position of each attribute in the output schema. val (partitionKeyAttrs, nonPartitionKeyAttrs) = attributes.zipWithIndex.partition { case (attr, _) => - relation.partitionKeys.contains(attr) + partitionKeys.contains(attr) } def fillPartitionKeys(rawPartValues: Array[String], row: InternalRow): Unit = { partitionKeyAttrs.foreach { case (attr, ordinal) => - val partOrdinal = relation.partitionKeys.indexOf(attr) + val partOrdinal = partitionKeys.indexOf(attr) row(ordinal) = Cast(Literal(rawPartValues(partOrdinal)), attr.dataType).eval(null) } } @@ -233,9 +234,11 @@ class HadoopTableReader( // Fill all partition keys to the given MutableRow object fillPartitionKeys(partValues, mutableRow) - val tableProperties = relation.tableDesc.getProperties + val tableProperties = tableDesc.getProperties - createHadoopRdd(tableDesc, inputPathStr, ifc).mapPartitions { iter => + // Create local references so that the outer object isn't serialized. + val localTableDesc = tableDesc + createHadoopRdd(localTableDesc, inputPathStr, ifc).mapPartitions { iter => val hconf = broadcastedHiveConf.value.value val deserializer = localDeserializer.newInstance() // SPARK-13709: For SerDes like AvroSerDe, some essential information (e.g. 
Avro schema @@ -249,8 +252,8 @@ class HadoopTableReader( } deserializer.initialize(hconf, props) // get the table deserializer - val tableSerDe = tableDesc.getDeserializerClass.newInstance() - tableSerDe.initialize(hconf, tableDesc.getProperties) + val tableSerDe = localTableDesc.getDeserializerClass.newInstance() + tableSerDe.initialize(hconf, localTableDesc.getProperties) // fill the non partition key attributes HadoopTableReader.fillObject(iter, deserializer, nonPartitionKeyAttrs, diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClient.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClient.scala index 8bdcf3111d8e..16a80f9fff45 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClient.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClient.scala @@ -208,7 +208,6 @@ private[hive] trait HiveClient { tableName: String, partSpec: java.util.LinkedHashMap[String, String], // Hive relies on LinkedHashMap ordering replace: Boolean, - holdDDLTime: Boolean, inheritTableSpecs: Boolean, isSrcLocal: Boolean): Unit @@ -217,7 +216,6 @@ private[hive] trait HiveClient { loadPath: String, // TODO URI tableName: String, replace: Boolean, - holdDDLTime: Boolean, isSrcLocal: Boolean): Unit /** Loads new dynamic partitions into an existing table. */ @@ -227,8 +225,7 @@ private[hive] trait HiveClient { tableName: String, partSpec: java.util.LinkedHashMap[String, String], // Hive relies on LinkedHashMap ordering replace: Boolean, - numDP: Int, - holdDDLTime: Boolean): Unit + numDP: Int): Unit /** Create a function in an existing database. */ def createFunction(db: String, func: CatalogFunction): Unit diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index dc9c3ff33542..8f98c8f44703 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -96,6 +96,7 @@ private[hive] class HiveClientImpl( case hive.v1_0 => new Shim_v1_0() case hive.v1_1 => new Shim_v1_1() case hive.v1_2 => new Shim_v1_2() + case hive.v2_0 => new Shim_v2_0() } // Create an internal session state for this HiveClientImpl. 
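[Editor's note, not part of the patch] The hunk above registers the new Shim_v2_0 in HiveClientImpl's version dispatch; combined with the sealed HiveVersion hierarchy and the allSupportedHiveVersions set introduced later in this diff, every supported metastore version has to be matched explicitly. Below is a minimal, self-contained Scala sketch of that sealed-version-plus-shim dispatch pattern; all names here (MyVersion, Shim_1_2, Shim_2_0, newShim) are illustrative stand-ins, not Spark's actual classes.

// Sketch only: mirrors the pattern in HiveClientImpl / client/package.scala.
sealed abstract class MyVersion(val fullVersion: String)
case object V1_2 extends MyVersion("1.2.1")
case object V2_0 extends MyVersion("2.0.1")

class Shim_1_2 { def name: String = "Shim_1_2" }
class Shim_2_0 extends Shim_1_2 { override def name: String = "Shim_2_0" }

// Because MyVersion is sealed, the compiler can warn if a future version is added
// without a corresponding case here -- the same safety the patch gains by sealing
// HiveVersion and asserting over allSupportedHiveVersions in InsertIntoHiveTable.
def newShim(v: MyVersion): Shim_1_2 = v match {
  case V1_2 => new Shim_1_2
  case V2_0 => new Shim_2_0
}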
@@ -106,10 +107,6 @@ private[hive] class HiveClientImpl( // Set up kerberos credentials for UserGroupInformation.loginUser within // current class loader - // Instead of using the spark conf of the current spark context, a new - // instance of SparkConf is needed for the original value of spark.yarn.keytab - // and spark.yarn.principal set in SparkSubmit, as yarn.Client resets the - // keytab configuration for the link name in distributed cache if (sparkConf.contains("spark.yarn.principal") && sparkConf.contains("spark.yarn.keytab")) { val principalName = sparkConf.get("spark.yarn.principal") val keytabFileName = sparkConf.get("spark.yarn.keytab") @@ -668,7 +665,6 @@ private[hive] class HiveClientImpl( tableName: String, partSpec: java.util.LinkedHashMap[String, String], replace: Boolean, - holdDDLTime: Boolean, inheritTableSpecs: Boolean, isSrcLocal: Boolean): Unit = withHiveState { val hiveTable = client.getTable(dbName, tableName, true /* throw exception */) @@ -678,7 +674,6 @@ private[hive] class HiveClientImpl( s"$dbName.$tableName", partSpec, replace, - holdDDLTime, inheritTableSpecs, isSkewedStoreAsSubdir = hiveTable.isStoredAsSubDirectories, isSrcLocal = isSrcLocal) @@ -688,14 +683,12 @@ private[hive] class HiveClientImpl( loadPath: String, // TODO URI tableName: String, replace: Boolean, - holdDDLTime: Boolean, isSrcLocal: Boolean): Unit = withHiveState { shim.loadTable( client, new Path(loadPath), tableName, replace, - holdDDLTime, isSrcLocal) } @@ -705,8 +698,7 @@ private[hive] class HiveClientImpl( tableName: String, partSpec: java.util.LinkedHashMap[String, String], replace: Boolean, - numDP: Int, - holdDDLTime: Boolean): Unit = withHiveState { + numDP: Int): Unit = withHiveState { val hiveTable = client.getTable(dbName, tableName, true /* throw exception */) shim.loadDynamicPartitions( client, @@ -715,7 +707,6 @@ private[hive] class HiveClientImpl( partSpec, replace, numDP, - holdDDLTime, listBucketingEnabled = hiveTable.isStoredAsSubDirectories) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala index b052f1e7e43f..7280748361d6 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala @@ -96,7 +96,6 @@ private[client] sealed abstract class Shim { tableName: String, partSpec: JMap[String, String], replace: Boolean, - holdDDLTime: Boolean, inheritTableSpecs: Boolean, isSkewedStoreAsSubdir: Boolean, isSrcLocal: Boolean): Unit @@ -106,7 +105,6 @@ private[client] sealed abstract class Shim { loadPath: Path, tableName: String, replace: Boolean, - holdDDLTime: Boolean, isSrcLocal: Boolean): Unit def loadDynamicPartitions( @@ -116,7 +114,6 @@ private[client] sealed abstract class Shim { partSpec: JMap[String, String], replace: Boolean, numDP: Int, - holdDDLTime: Boolean, listBucketingEnabled: Boolean): Unit def createFunction(hive: Hive, db: String, func: CatalogFunction): Unit @@ -332,12 +329,11 @@ private[client] class Shim_v0_12 extends Shim with Logging { tableName: String, partSpec: JMap[String, String], replace: Boolean, - holdDDLTime: Boolean, inheritTableSpecs: Boolean, isSkewedStoreAsSubdir: Boolean, isSrcLocal: Boolean): Unit = { loadPartitionMethod.invoke(hive, loadPath, tableName, partSpec, replace: JBoolean, - holdDDLTime: JBoolean, inheritTableSpecs: JBoolean, isSkewedStoreAsSubdir: JBoolean) + JBoolean.FALSE, inheritTableSpecs: JBoolean, 
isSkewedStoreAsSubdir: JBoolean) } override def loadTable( @@ -345,9 +341,8 @@ private[client] class Shim_v0_12 extends Shim with Logging { loadPath: Path, tableName: String, replace: Boolean, - holdDDLTime: Boolean, isSrcLocal: Boolean): Unit = { - loadTableMethod.invoke(hive, loadPath, tableName, replace: JBoolean, holdDDLTime: JBoolean) + loadTableMethod.invoke(hive, loadPath, tableName, replace: JBoolean, JBoolean.FALSE) } override def loadDynamicPartitions( @@ -357,10 +352,9 @@ private[client] class Shim_v0_12 extends Shim with Logging { partSpec: JMap[String, String], replace: Boolean, numDP: Int, - holdDDLTime: Boolean, listBucketingEnabled: Boolean): Unit = { loadDynamicPartitionsMethod.invoke(hive, loadPath, tableName, partSpec, replace: JBoolean, - numDP: JInteger, holdDDLTime: JBoolean, listBucketingEnabled: JBoolean) + numDP: JInteger, JBoolean.FALSE, listBucketingEnabled: JBoolean) } override def dropIndex(hive: Hive, dbName: String, tableName: String, indexName: String): Unit = { @@ -703,12 +697,11 @@ private[client] class Shim_v0_14 extends Shim_v0_13 { tableName: String, partSpec: JMap[String, String], replace: Boolean, - holdDDLTime: Boolean, inheritTableSpecs: Boolean, isSkewedStoreAsSubdir: Boolean, isSrcLocal: Boolean): Unit = { loadPartitionMethod.invoke(hive, loadPath, tableName, partSpec, replace: JBoolean, - holdDDLTime: JBoolean, inheritTableSpecs: JBoolean, isSkewedStoreAsSubdir: JBoolean, + JBoolean.FALSE, inheritTableSpecs: JBoolean, isSkewedStoreAsSubdir: JBoolean, isSrcLocal: JBoolean, JBoolean.FALSE) } @@ -717,9 +710,8 @@ private[client] class Shim_v0_14 extends Shim_v0_13 { loadPath: Path, tableName: String, replace: Boolean, - holdDDLTime: Boolean, isSrcLocal: Boolean): Unit = { - loadTableMethod.invoke(hive, loadPath, tableName, replace: JBoolean, holdDDLTime: JBoolean, + loadTableMethod.invoke(hive, loadPath, tableName, replace: JBoolean, JBoolean.FALSE, isSrcLocal: JBoolean, JBoolean.FALSE, JBoolean.FALSE) } @@ -730,10 +722,9 @@ private[client] class Shim_v0_14 extends Shim_v0_13 { partSpec: JMap[String, String], replace: Boolean, numDP: Int, - holdDDLTime: Boolean, listBucketingEnabled: Boolean): Unit = { loadDynamicPartitionsMethod.invoke(hive, loadPath, tableName, partSpec, replace: JBoolean, - numDP: JInteger, holdDDLTime: JBoolean, listBucketingEnabled: JBoolean, JBoolean.FALSE) + numDP: JInteger, JBoolean.FALSE, listBucketingEnabled: JBoolean, JBoolean.FALSE) } override def dropTable( @@ -818,10 +809,9 @@ private[client] class Shim_v1_2 extends Shim_v1_1 { partSpec: JMap[String, String], replace: Boolean, numDP: Int, - holdDDLTime: Boolean, listBucketingEnabled: Boolean): Unit = { loadDynamicPartitionsMethod.invoke(hive, loadPath, tableName, partSpec, replace: JBoolean, - numDP: JInteger, holdDDLTime: JBoolean, listBucketingEnabled: JBoolean, JBoolean.FALSE, + numDP: JInteger, JBoolean.FALSE, listBucketingEnabled: JBoolean, JBoolean.FALSE, 0L: JLong) } @@ -843,3 +833,77 @@ private[client] class Shim_v1_2 extends Shim_v1_1 { } } + +private[client] class Shim_v2_0 extends Shim_v1_2 { + private lazy val loadPartitionMethod = + findMethod( + classOf[Hive], + "loadPartition", + classOf[Path], + classOf[String], + classOf[JMap[String, String]], + JBoolean.TYPE, + JBoolean.TYPE, + JBoolean.TYPE, + JBoolean.TYPE, + JBoolean.TYPE) + private lazy val loadTableMethod = + findMethod( + classOf[Hive], + "loadTable", + classOf[Path], + classOf[String], + JBoolean.TYPE, + JBoolean.TYPE, + JBoolean.TYPE, + JBoolean.TYPE) + private lazy val 
loadDynamicPartitionsMethod = + findMethod( + classOf[Hive], + "loadDynamicPartitions", + classOf[Path], + classOf[String], + classOf[JMap[String, String]], + JBoolean.TYPE, + JInteger.TYPE, + JBoolean.TYPE, + JBoolean.TYPE, + JLong.TYPE) + + override def loadPartition( + hive: Hive, + loadPath: Path, + tableName: String, + partSpec: JMap[String, String], + replace: Boolean, + inheritTableSpecs: Boolean, + isSkewedStoreAsSubdir: Boolean, + isSrcLocal: Boolean): Unit = { + loadPartitionMethod.invoke(hive, loadPath, tableName, partSpec, replace: JBoolean, + inheritTableSpecs: JBoolean, isSkewedStoreAsSubdir: JBoolean, + isSrcLocal: JBoolean, JBoolean.FALSE) + } + + override def loadTable( + hive: Hive, + loadPath: Path, + tableName: String, + replace: Boolean, + isSrcLocal: Boolean): Unit = { + loadTableMethod.invoke(hive, loadPath, tableName, replace: JBoolean, isSrcLocal: JBoolean, + JBoolean.FALSE, JBoolean.FALSE) + } + + override def loadDynamicPartitions( + hive: Hive, + loadPath: Path, + tableName: String, + partSpec: JMap[String, String], + replace: Boolean, + numDP: Int, + listBucketingEnabled: Boolean): Unit = { + loadDynamicPartitionsMethod.invoke(hive, loadPath, tableName, partSpec, replace: JBoolean, + numDP: JInteger, listBucketingEnabled: JBoolean, JBoolean.FALSE, 0L: JLong) + } + +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala index d2487a2c034c..6f69a4adf29d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala @@ -94,6 +94,7 @@ private[hive] object IsolatedClientLoader extends Logging { case "1.0" | "1.0.0" => hive.v1_0 case "1.1" | "1.1.0" => hive.v1_1 case "1.2" | "1.2.0" | "1.2.1" => hive.v1_2 + case "2.0" | "2.0.0" | "2.0.1" => hive.v2_0 } private def downloadVersion( diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala index 4e2193b6abc3..790ad74e6639 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.hive /** Support for interacting with different versions of the HiveMetastoreClient */ package object client { - private[hive] abstract class HiveVersion( + private[hive] sealed abstract class HiveVersion( val fullVersion: String, val extraDeps: Seq[String] = Nil, val exclusions: Seq[String] = Nil) @@ -62,6 +62,12 @@ package object client { "org.pentaho:pentaho-aggdesigner-algorithm", "net.hydromatic:linq4j", "net.hydromatic:quidem")) + + case object v2_0 extends HiveVersion("2.0.1", + exclusions = Seq("org.apache.curator:*", + "org.pentaho:pentaho-aggdesigner-algorithm")) + + val allSupportedHiveVersions = Set(v12, v13, v14, v1_0, v1_1, v1_2, v2_0) } // scalastyle:on diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala index 140c352fa6f8..28f074849c0f 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala @@ -21,6 +21,7 @@ import scala.collection.JavaConverters._ import 
org.apache.hadoop.conf.Configuration import org.apache.hadoop.hive.ql.metadata.{Partition => HivePartition} +import org.apache.hadoop.hive.ql.plan.TableDesc import org.apache.hadoop.hive.serde.serdeConstants import org.apache.hadoop.hive.serde2.objectinspector._ import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorCopyOption @@ -29,10 +30,12 @@ import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils import org.apache.spark.rdd.RDD import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.catalog.CatalogRelation import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.hive._ +import org.apache.spark.sql.hive.client.HiveClientImpl import org.apache.spark.sql.types.{BooleanType, DataType} import org.apache.spark.util.Utils @@ -46,12 +49,12 @@ import org.apache.spark.util.Utils private[hive] case class HiveTableScanExec( requestedAttributes: Seq[Attribute], - relation: MetastoreRelation, + relation: CatalogRelation, partitionPruningPred: Seq[Expression])( @transient private val sparkSession: SparkSession) extends LeafExecNode { - require(partitionPruningPred.isEmpty || relation.hiveQlTable.isPartitioned, + require(partitionPruningPred.isEmpty || relation.isPartitioned, "Partition pruning predicates only supported for partitioned tables.") override lazy val metrics = Map( @@ -60,42 +63,54 @@ case class HiveTableScanExec( override def producedAttributes: AttributeSet = outputSet ++ AttributeSet(partitionPruningPred.flatMap(_.references)) - // Retrieve the original attributes based on expression ID so that capitalization matches. - val attributes = requestedAttributes.map(relation.attributeMap) + private val originalAttributes = AttributeMap(relation.output.map(a => a -> a)) + + override val output: Seq[Attribute] = { + // Retrieve the original attributes based on expression ID so that capitalization matches. + requestedAttributes.map(originalAttributes) + } // Bind all partition key attribute references in the partition pruning predicate for later // evaluation. 
- private[this] val boundPruningPred = partitionPruningPred.reduceLeftOption(And).map { pred => + private val boundPruningPred = partitionPruningPred.reduceLeftOption(And).map { pred => require( pred.dataType == BooleanType, s"Data type of predicate $pred must be BooleanType rather than ${pred.dataType}.") - BindReferences.bindReference(pred, relation.partitionKeys) + BindReferences.bindReference(pred, relation.partitionCols) } // Create a local copy of hadoopConf,so that scan specific modifications should not impact // other queries - @transient - private[this] val hadoopConf = sparkSession.sessionState.newHadoopConf() + @transient private val hadoopConf = sparkSession.sessionState.newHadoopConf() + + @transient private val hiveQlTable = HiveClientImpl.toHiveTable(relation.tableMeta) + @transient private val tableDesc = new TableDesc( + hiveQlTable.getInputFormatClass, + hiveQlTable.getOutputFormatClass, + hiveQlTable.getMetadata) // append columns ids and names before broadcast addColumnMetadataToConf(hadoopConf) - @transient - private[this] val hadoopReader = - new HadoopTableReader(attributes, relation, sparkSession, hadoopConf) + @transient private val hadoopReader = new HadoopTableReader( + output, + relation.partitionCols, + tableDesc, + sparkSession, + hadoopConf) - private[this] def castFromString(value: String, dataType: DataType) = { + private def castFromString(value: String, dataType: DataType) = { Cast(Literal(value), dataType).eval(null) } private def addColumnMetadataToConf(hiveConf: Configuration) { // Specifies needed column IDs for those non-partitioning columns. - val neededColumnIDs = attributes.flatMap(relation.columnOrdinals.get).map(o => o: Integer) + val columnOrdinals = AttributeMap(relation.dataCols.zipWithIndex) + val neededColumnIDs = output.flatMap(columnOrdinals.get).map(o => o: Integer) - HiveShim.appendReadColumns(hiveConf, neededColumnIDs, attributes.map(_.name)) + HiveShim.appendReadColumns(hiveConf, neededColumnIDs, output.map(_.name)) - val tableDesc = relation.tableDesc val deserializer = tableDesc.getDeserializerClass.newInstance deserializer.initialize(hiveConf, tableDesc.getProperties) @@ -113,7 +128,7 @@ case class HiveTableScanExec( .mkString(",") hiveConf.set(serdeConstants.LIST_COLUMN_TYPES, columnTypeNames) - hiveConf.set(serdeConstants.LIST_COLUMNS, relation.dataColKeys.map(_.name).mkString(",")) + hiveConf.set(serdeConstants.LIST_COLUMNS, relation.dataCols.map(_.name).mkString(",")) } /** @@ -126,7 +141,7 @@ case class HiveTableScanExec( boundPruningPred match { case None => partitions case Some(shouldKeep) => partitions.filter { part => - val dataTypes = relation.partitionKeys.map(_.dataType) + val dataTypes = relation.partitionCols.map(_.dataType) val castedValues = part.getValues.asScala.zip(dataTypes) .map { case (value, dataType) => castFromString(value, dataType) } @@ -138,27 +153,36 @@ case class HiveTableScanExec( } } + // exposed for tests + @transient lazy val rawPartitions = { + val prunedPartitions = if (sparkSession.sessionState.conf.metastorePartitionPruning) { + // Retrieve the original attributes based on expression ID so that capitalization matches. 
+ val normalizedFilters = partitionPruningPred.map(_.transform { + case a: AttributeReference => originalAttributes(a) + }) + sparkSession.sharedState.externalCatalog.listPartitionsByFilter( + relation.tableMeta.database, + relation.tableMeta.identifier.table, + normalizedFilters, + sparkSession.sessionState.conf.sessionLocalTimeZone) + } else { + sparkSession.sharedState.externalCatalog.listPartitions( + relation.tableMeta.database, + relation.tableMeta.identifier.table) + } + prunedPartitions.map(HiveClientImpl.toHivePartition(_, hiveQlTable)) + } + protected override def doExecute(): RDD[InternalRow] = { // Using dummyCallSite, as getCallSite can turn out to be expensive with // with multiple partitions. - val rdd = if (!relation.hiveQlTable.isPartitioned) { + val rdd = if (!relation.isPartitioned) { Utils.withDummyCallSite(sqlContext.sparkContext) { - hadoopReader.makeRDDForTable(relation.hiveQlTable) + hadoopReader.makeRDDForTable(hiveQlTable) } } else { - // The attribute name of predicate could be different than the one in schema in case of - // case insensitive, we should change them to match the one in schema, so we do not need to - // worry about case sensitivity anymore. - val normalizedFilters = partitionPruningPred.map { e => - e transform { - case a: AttributeReference => - a.withName(relation.output.find(_.semanticEquals(a)).get.name) - } - } - Utils.withDummyCallSite(sqlContext.sparkContext) { - hadoopReader.makeRDDForPartitionedTable( - prunePartitions(relation.getHiveQlPartitions(normalizedFilters))) + hadoopReader.makeRDDForPartitionedTable(prunePartitions(rawPartitions)) } } val numOutputRows = longMetric("numOutputRows") @@ -174,8 +198,6 @@ case class HiveTableScanExec( } } - override def output: Seq[Attribute] = attributes - override def sameResult(plan: SparkPlan): Boolean = plan match { case other: HiveTableScanExec => val thisPredicates = partitionPruningPred.map(cleanExpression) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index 3e654d8eeb35..3c57ee4c8b8f 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -29,16 +29,18 @@ import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.hive.common.FileUtils import org.apache.hadoop.hive.ql.exec.TaskRunner import org.apache.hadoop.hive.ql.ErrorMsg +import org.apache.hadoop.hive.ql.plan.TableDesc import org.apache.spark.internal.io.FileCommitProtocol import org.apache.spark.sql.{AnalysisException, Dataset, Row, SparkSession} +import org.apache.spark.sql.catalyst.catalog.CatalogTable import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.command.RunnableCommand import org.apache.spark.sql.execution.datasources.FileFormatWriter import org.apache.spark.sql.hive._ import org.apache.spark.sql.hive.HiveShim.{ShimFileSinkDesc => FileSinkDesc} -import org.apache.spark.sql.hive.client.HiveVersion +import org.apache.spark.sql.hive.client.{HiveClientImpl, HiveVersion} import org.apache.spark.SparkException @@ -52,9 +54,7 @@ import org.apache.spark.SparkException * In the future we should converge the write path for Hive with the normal data source write path, * as defined in `org.apache.spark.sql.execution.datasources.FileFormatWriter`. 
* - * @param table the logical plan representing the table. In the future this should be a - * `org.apache.spark.sql.catalyst.catalog.CatalogTable` once we converge Hive tables - * and data source tables. + * @param table the metadata of the table. * @param partition a map from the partition key to the partition value (optional). If the partition * value is optional, dynamic partition insert will be performed. * As an example, `INSERT INTO tbl PARTITION (a=1, b=2) AS ...` would have @@ -74,7 +74,7 @@ import org.apache.spark.SparkException * @param ifNotExists If true, only write if the table or partition does not exist. */ case class InsertIntoHiveTable( - table: MetastoreRelation, + table: CatalogTable, partition: Map[String, Option[String]], query: LogicalPlan, overwrite: Boolean, @@ -148,9 +148,16 @@ case class InsertIntoHiveTable( // We have to follow the Hive behavior here, to avoid troubles. For example, if we create // staging directory under the table director for Hive prior to 1.1, the staging directory will // be removed by Hive when Hive is trying to empty the table directory. - if (hiveVersion == v12 || hiveVersion == v13 || hiveVersion == v14 || hiveVersion == v1_0) { + val hiveVersionsUsingOldExternalTempPath: Set[HiveVersion] = Set(v12, v13, v14, v1_0) + val hiveVersionsUsingNewExternalTempPath: Set[HiveVersion] = Set(v1_1, v1_2, v2_0) + + // Ensure all the supported versions are considered here. + assert(hiveVersionsUsingNewExternalTempPath ++ hiveVersionsUsingOldExternalTempPath == + allSupportedHiveVersions) + + if (hiveVersionsUsingOldExternalTempPath.contains(hiveVersion)) { oldVersionExternalTempPath(path, hadoopConf, scratchDir) - } else if (hiveVersion == v1_1 || hiveVersion == v1_2) { + } else if (hiveVersionsUsingNewExternalTempPath.contains(hiveVersion)) { newVersionExternalTempPath(path, hadoopConf, stagingDir) } else { throw new IllegalStateException("Unsupported hive version: " + hiveVersion.fullVersion) @@ -218,23 +225,35 @@ case class InsertIntoHiveTable( val stagingDir = hadoopConf.get("hive.exec.stagingdir", ".hive-staging") val scratchDir = hadoopConf.get("hive.exec.scratchdir", "/tmp/hive") + val hiveQlTable = HiveClientImpl.toHiveTable(table) // Have to pass the TableDesc object to RDD.mapPartitions and then instantiate new serializer // instances within the closure, since Serializer is not serializable while TableDesc is. - val tableDesc = table.tableDesc - val tableLocation = table.hiveQlTable.getDataLocation + val tableDesc = new TableDesc( + hiveQlTable.getInputFormatClass, + // The class of table should be org.apache.hadoop.hive.ql.metadata.Table because + // getOutputFormatClass will use HiveFileFormatUtils.getOutputFormatSubstitute to + // substitute some output formats, e.g. substituting SequenceFileOutputFormat to + // HiveSequenceFileOutputFormat. + hiveQlTable.getOutputFormatClass, + hiveQlTable.getMetadata + ) + val tableLocation = hiveQlTable.getDataLocation val tmpLocation = getExternalTmpPath(tableLocation, hiveVersion, hadoopConf, stagingDir, scratchDir) val fileSinkConf = new FileSinkDesc(tmpLocation.toString, tableDesc, false) val isCompressed = hadoopConf.get("hive.exec.compress.output", "false").toBoolean if (isCompressed) { - // Please note that isCompressed, "mapred.output.compress", "mapred.output.compression.codec", - // and "mapred.output.compression.type" have no impact on ORC because it uses table properties - // to store compression information. 
- hadoopConf.set("mapred.output.compress", "true") + // Please note that isCompressed, "mapreduce.output.fileoutputformat.compress", + // "mapreduce.output.fileoutputformat.compress.codec", and + // "mapreduce.output.fileoutputformat.compress.type" + // have no impact on ORC because it uses table properties to store compression information. + hadoopConf.set("mapreduce.output.fileoutputformat.compress", "true") fileSinkConf.setCompressed(true) - fileSinkConf.setCompressCodec(hadoopConf.get("mapred.output.compression.codec")) - fileSinkConf.setCompressType(hadoopConf.get("mapred.output.compression.type")) + fileSinkConf.setCompressCodec(hadoopConf + .get("mapreduce.output.fileoutputformat.compress.codec")) + fileSinkConf.setCompressType(hadoopConf + .get("mapreduce.output.fileoutputformat.compress.type")) } val numDynamicPartitions = partition.values.count(_.isEmpty) @@ -251,9 +270,9 @@ case class InsertIntoHiveTable( // By this time, the partition map must match the table's partition columns if (partitionColumnNames.toSet != partition.keySet) { throw new SparkException( - s"""Requested partitioning does not match the ${table.tableName} table: + s"""Requested partitioning does not match the ${table.identifier.table} table: |Requested partitions: ${partition.keys.mkString(",")} - |Table partitions: ${table.partitionKeys.map(_.name).mkString(",")}""".stripMargin) + |Table partitions: ${table.partitionColumnNames.mkString(",")}""".stripMargin) } // Validate partition spec if there exist any dynamic partitions @@ -301,20 +320,15 @@ case class InsertIntoHiveTable( refreshFunction = _ => (), options = Map.empty) - // TODO: Correctly set holdDDLTime. - // In most of the time, we should have holdDDLTime = false. - // holdDDLTime will be true when TOK_HOLD_DDLTIME presents in the query as a hint. - val holdDDLTime = false if (partition.nonEmpty) { if (numDynamicPartitions > 0) { externalCatalog.loadDynamicPartitions( - db = table.catalogTable.database, - table = table.catalogTable.identifier.table, + db = table.database, + table = table.identifier.table, tmpLocation.toString, partitionSpec, overwrite, - numDynamicPartitions, - holdDDLTime = holdDDLTime) + numDynamicPartitions) } else { // scalastyle:off // ifNotExists is only valid with static partition, refer to @@ -322,8 +336,8 @@ case class InsertIntoHiveTable( // scalastyle:on val oldPart = externalCatalog.getPartitionOption( - table.catalogTable.database, - table.catalogTable.identifier.table, + table.database, + table.identifier.table, partitionSpec) var doHiveOverwrite = overwrite @@ -352,23 +366,21 @@ case class InsertIntoHiveTable( // which is currently considered as a Hive native command. val inheritTableSpecs = true externalCatalog.loadPartition( - table.catalogTable.database, - table.catalogTable.identifier.table, + table.database, + table.identifier.table, tmpLocation.toString, partitionSpec, isOverwrite = doHiveOverwrite, - holdDDLTime = holdDDLTime, inheritTableSpecs = inheritTableSpecs, isSrcLocal = false) } } } else { externalCatalog.loadTable( - table.catalogTable.database, - table.catalogTable.identifier.table, + table.database, + table.identifier.table, tmpLocation.toString, // TODO: URI overwrite, - holdDDLTime, isSrcLocal = false) } @@ -382,8 +394,8 @@ case class InsertIntoHiveTable( } // Invalidate the cache. 
- sparkSession.sharedState.cacheManager.invalidateCache(table) - sparkSession.sessionState.catalog.refreshTable(table.catalogTable.identifier) + sparkSession.catalog.uncacheTable(table.qualifiedName) + sparkSession.sessionState.catalog.refreshTable(table.identifier) // It would be nice to just return the childRdd unchanged so insert operations could be chained, // however for now we return an empty list to simplify compatibility checks with hive, which diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index 3267c237c865..efc2f0098454 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -439,7 +439,7 @@ private[hive] class TestHiveSparkSession( foreach { udfName => FunctionRegistry.unregisterTemporaryUDF(udfName) } // Some tests corrupt this value on purpose, which breaks the RESET call below. - sessionState.conf.setConfString("fs.default.name", new File(".").toURI.toString) + sessionState.conf.setConfString("fs.defaultFS", new File(".").toURI.toString) // It is important that we RESET first as broken hooks that might have been set could break // other sql exec here. sessionState.metadataHive.runSqlHive("RESET") @@ -483,7 +483,7 @@ private[hive] class TestHiveQueryExecution( // Make sure any test tables referenced are loaded. val referencedTables = describedTables ++ - logical.collect { case UnresolvedRelation(tableIdent, _) => tableIdent.table } + logical.collect { case UnresolvedRelation(tableIdent) => tableIdent.table } val referencedTestTables = referencedTables.filter(sparkSession.testTables.contains) logDebug(s"Query references test tables: ${referencedTestTables.mkString(", ")}") referencedTestTables.foreach(sparkSession.loadTestTable) diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaDataFrameSuite.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaDataFrameSuite.java index f664d5a4cdad..aefc9cc77da8 100644 --- a/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaDataFrameSuite.java +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaDataFrameSuite.java @@ -26,7 +26,6 @@ import org.junit.Before; import org.junit.Test; -import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.sql.*; import org.apache.spark.sql.expressions.Window; import org.apache.spark.sql.expressions.UserDefinedAggregateFunction; @@ -35,7 +34,6 @@ import org.apache.spark.sql.hive.aggregate.MyDoubleSum; public class JavaDataFrameSuite { - private transient JavaSparkContext sc; private transient SQLContext hc; Dataset df; @@ -50,13 +48,11 @@ private static void checkAnswer(Dataset actual, List expected) { @Before public void setUp() throws IOException { hc = TestHive$.MODULE$; - sc = new JavaSparkContext(hc.sparkContext()); - List jsonObjects = new ArrayList<>(10); for (int i = 0; i < 10; i++) { jsonObjects.add("{\"key\":" + i + ", \"value\":\"str" + i + "\"}"); } - df = hc.read().json(sc.parallelize(jsonObjects)); + df = hc.read().json(hc.createDataset(jsonObjects, Encoders.STRING())); df.createOrReplaceTempView("window_table"); } diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java index 061c7431a636..0b157a45e6e0 100644 --- a/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java +++ 
b/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java @@ -31,9 +31,9 @@ import org.junit.Before; import org.junit.Test; -import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Encoders; import org.apache.spark.sql.QueryTest$; import org.apache.spark.sql.Row; import org.apache.spark.sql.SQLContext; @@ -81,8 +81,8 @@ public void setUp() throws IOException { for (int i = 0; i < 10; i++) { jsonObjects.add("{\"a\":" + i + ", \"b\":\"str" + i + "\"}"); } - JavaRDD rdd = sc.parallelize(jsonObjects); - df = sqlContext.read().json(rdd); + Dataset ds = sqlContext.createDataset(jsonObjects, Encoders.STRING()); + df = sqlContext.read().json(ds); df.createOrReplaceTempView("jsonTable"); } diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFRawList.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFRawList.java index 6adb1657bf25..8211cbf16f7b 100644 --- a/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFRawList.java +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFRawList.java @@ -25,6 +25,7 @@ * UDF that returns a raw (non-parameterized) java List. */ public class UDFRawList extends UDF { + @SuppressWarnings("rawtypes") public List evaluate(Object o) { return Collections.singletonList("data1"); } diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFRawMap.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFRawMap.java index 4731b6eee85c..58c81f9945d7 100644 --- a/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFRawMap.java +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFRawMap.java @@ -25,6 +25,7 @@ * UDF that returns a raw (non-parameterized) java Map. 
*/ public class UDFRawMap extends UDF { + @SuppressWarnings("rawtypes") public Map evaluate(Object o) { return Collections.singletonMap("a", "1"); } diff --git a/sql/hive/src/test/resources/golden/auto_join14_hadoop20-2-db1cd54a4cb36de2087605f32e41824f b/sql/hive/src/test/resources/golden/auto_join14_hadoop20-2-2b9ccaa793eae0e73bf76335d3d6880 similarity index 100% rename from sql/hive/src/test/resources/golden/auto_join14_hadoop20-2-db1cd54a4cb36de2087605f32e41824f rename to sql/hive/src/test/resources/golden/auto_join14_hadoop20-2-2b9ccaa793eae0e73bf76335d3d6880 diff --git a/sql/hive/src/test/resources/golden/combine1-2-c95dc367df88c9e5cf77157f29ba2daf b/sql/hive/src/test/resources/golden/combine1-2-6142f47d3fcdd4323162014d5eb35e07 similarity index 100% rename from sql/hive/src/test/resources/golden/combine1-2-c95dc367df88c9e5cf77157f29ba2daf rename to sql/hive/src/test/resources/golden/combine1-2-6142f47d3fcdd4323162014d5eb35e07 diff --git a/sql/hive/src/test/resources/golden/combine1-3-6e53a3ac93113f20db3a12f1dcf30e86 b/sql/hive/src/test/resources/golden/combine1-3-10266e3d5dd4c841c0d65030b1edba7c similarity index 100% rename from sql/hive/src/test/resources/golden/combine1-3-6e53a3ac93113f20db3a12f1dcf30e86 rename to sql/hive/src/test/resources/golden/combine1-3-10266e3d5dd4c841c0d65030b1edba7c diff --git a/sql/hive/src/test/resources/golden/combine1-4-84967075baa3e56fff2a23f8ab9ba076 b/sql/hive/src/test/resources/golden/combine1-4-9cbd6d400fb6c3cd09010e3dbd76601 similarity index 100% rename from sql/hive/src/test/resources/golden/combine1-4-84967075baa3e56fff2a23f8ab9ba076 rename to sql/hive/src/test/resources/golden/combine1-4-9cbd6d400fb6c3cd09010e3dbd76601 diff --git a/sql/hive/src/test/resources/golden/combine1-5-2ee5d706fe3a3bcc38b795f6e94970ea b/sql/hive/src/test/resources/golden/combine1-5-1ba2d6f3bb3348da3fee7fab4f283f34 similarity index 100% rename from sql/hive/src/test/resources/golden/combine1-5-2ee5d706fe3a3bcc38b795f6e94970ea rename to sql/hive/src/test/resources/golden/combine1-5-1ba2d6f3bb3348da3fee7fab4f283f34 diff --git a/sql/hive/src/test/resources/golden/combine2-2-c95dc367df88c9e5cf77157f29ba2daf b/sql/hive/src/test/resources/golden/combine2-2-6142f47d3fcdd4323162014d5eb35e07 similarity index 100% rename from sql/hive/src/test/resources/golden/combine2-2-c95dc367df88c9e5cf77157f29ba2daf rename to sql/hive/src/test/resources/golden/combine2-2-6142f47d3fcdd4323162014d5eb35e07 diff --git a/sql/hive/src/test/resources/golden/combine2-3-6e53a3ac93113f20db3a12f1dcf30e86 b/sql/hive/src/test/resources/golden/combine2-3-10266e3d5dd4c841c0d65030b1edba7c similarity index 100% rename from sql/hive/src/test/resources/golden/combine2-3-6e53a3ac93113f20db3a12f1dcf30e86 rename to sql/hive/src/test/resources/golden/combine2-3-10266e3d5dd4c841c0d65030b1edba7c diff --git a/sql/hive/src/test/resources/golden/combine2-4-84967075baa3e56fff2a23f8ab9ba076 b/sql/hive/src/test/resources/golden/combine2-4-9cbd6d400fb6c3cd09010e3dbd76601 similarity index 100% rename from sql/hive/src/test/resources/golden/combine2-4-84967075baa3e56fff2a23f8ab9ba076 rename to sql/hive/src/test/resources/golden/combine2-4-9cbd6d400fb6c3cd09010e3dbd76601 diff --git a/sql/hive/src/test/resources/golden/combine2-5-2ee5d706fe3a3bcc38b795f6e94970ea b/sql/hive/src/test/resources/golden/combine2-5-1ba2d6f3bb3348da3fee7fab4f283f34 similarity index 100% rename from sql/hive/src/test/resources/golden/combine2-5-2ee5d706fe3a3bcc38b795f6e94970ea rename to 
sql/hive/src/test/resources/golden/combine2-5-1ba2d6f3bb3348da3fee7fab4f283f34 diff --git a/sql/hive/src/test/resources/golden/groupby1-3-d57ed4bbfee1ffaffaeba0a4be84c31d b/sql/hive/src/test/resources/golden/groupby1-3-c8478dac3497697b4375ee35118a5c3e similarity index 100% rename from sql/hive/src/test/resources/golden/groupby1-3-d57ed4bbfee1ffaffaeba0a4be84c31d rename to sql/hive/src/test/resources/golden/groupby1-3-c8478dac3497697b4375ee35118a5c3e diff --git a/sql/hive/src/test/resources/golden/groupby1-5-dd7bf298b8c921355edd8665c6b0c168 b/sql/hive/src/test/resources/golden/groupby1-5-c9cee6382b64bd3d71177527961b8be2 similarity index 100% rename from sql/hive/src/test/resources/golden/groupby1-5-dd7bf298b8c921355edd8665c6b0c168 rename to sql/hive/src/test/resources/golden/groupby1-5-c9cee6382b64bd3d71177527961b8be2 diff --git a/sql/hive/src/test/resources/golden/groupby1_limit-0-83c59d378571a6e487aa20217bd87817 b/sql/hive/src/test/resources/golden/groupby1_limit-0-be2c0b32a02a1154bfdee1a52530f387 similarity index 100% rename from sql/hive/src/test/resources/golden/groupby1_limit-0-83c59d378571a6e487aa20217bd87817 rename to sql/hive/src/test/resources/golden/groupby1_limit-0-be2c0b32a02a1154bfdee1a52530f387 diff --git a/sql/hive/src/test/resources/golden/groupby1_map-2-83c59d378571a6e487aa20217bd87817 b/sql/hive/src/test/resources/golden/groupby1_map-2-be2c0b32a02a1154bfdee1a52530f387 similarity index 100% rename from sql/hive/src/test/resources/golden/groupby1_map-2-83c59d378571a6e487aa20217bd87817 rename to sql/hive/src/test/resources/golden/groupby1_map-2-be2c0b32a02a1154bfdee1a52530f387 diff --git a/sql/hive/src/test/resources/golden/groupby1_map_skew-2-83c59d378571a6e487aa20217bd87817 b/sql/hive/src/test/resources/golden/groupby1_map_skew-2-be2c0b32a02a1154bfdee1a52530f387 similarity index 100% rename from sql/hive/src/test/resources/golden/groupby1_map_skew-2-83c59d378571a6e487aa20217bd87817 rename to sql/hive/src/test/resources/golden/groupby1_map_skew-2-be2c0b32a02a1154bfdee1a52530f387 diff --git a/sql/hive/src/test/resources/golden/groupby1_noskew-2-83c59d378571a6e487aa20217bd87817 b/sql/hive/src/test/resources/golden/groupby1_noskew-2-be2c0b32a02a1154bfdee1a52530f387 similarity index 100% rename from sql/hive/src/test/resources/golden/groupby1_noskew-2-83c59d378571a6e487aa20217bd87817 rename to sql/hive/src/test/resources/golden/groupby1_noskew-2-be2c0b32a02a1154bfdee1a52530f387 diff --git a/sql/hive/src/test/resources/golden/groupby2_limit-0-83c59d378571a6e487aa20217bd87817 b/sql/hive/src/test/resources/golden/groupby2_limit-0-be2c0b32a02a1154bfdee1a52530f387 similarity index 100% rename from sql/hive/src/test/resources/golden/groupby2_limit-0-83c59d378571a6e487aa20217bd87817 rename to sql/hive/src/test/resources/golden/groupby2_limit-0-be2c0b32a02a1154bfdee1a52530f387 diff --git a/sql/hive/src/test/resources/golden/groupby2_map-2-83c59d378571a6e487aa20217bd87817 b/sql/hive/src/test/resources/golden/groupby2_map-2-be2c0b32a02a1154bfdee1a52530f387 similarity index 100% rename from sql/hive/src/test/resources/golden/groupby2_map-2-83c59d378571a6e487aa20217bd87817 rename to sql/hive/src/test/resources/golden/groupby2_map-2-be2c0b32a02a1154bfdee1a52530f387 diff --git a/sql/hive/src/test/resources/golden/groupby2_map_skew-2-83c59d378571a6e487aa20217bd87817 b/sql/hive/src/test/resources/golden/groupby2_map_skew-2-be2c0b32a02a1154bfdee1a52530f387 similarity index 100% rename from sql/hive/src/test/resources/golden/groupby2_map_skew-2-83c59d378571a6e487aa20217bd87817 rename to 
sql/hive/src/test/resources/golden/groupby2_map_skew-2-be2c0b32a02a1154bfdee1a52530f387 diff --git a/sql/hive/src/test/resources/golden/groupby2_noskew-2-83c59d378571a6e487aa20217bd87817 b/sql/hive/src/test/resources/golden/groupby2_noskew-2-be2c0b32a02a1154bfdee1a52530f387 similarity index 100% rename from sql/hive/src/test/resources/golden/groupby2_noskew-2-83c59d378571a6e487aa20217bd87817 rename to sql/hive/src/test/resources/golden/groupby2_noskew-2-be2c0b32a02a1154bfdee1a52530f387 diff --git a/sql/hive/src/test/resources/golden/groupby4_map-2-83c59d378571a6e487aa20217bd87817 b/sql/hive/src/test/resources/golden/groupby4_map-2-be2c0b32a02a1154bfdee1a52530f387 similarity index 100% rename from sql/hive/src/test/resources/golden/groupby4_map-2-83c59d378571a6e487aa20217bd87817 rename to sql/hive/src/test/resources/golden/groupby4_map-2-be2c0b32a02a1154bfdee1a52530f387 diff --git a/sql/hive/src/test/resources/golden/groupby4_map_skew-2-83c59d378571a6e487aa20217bd87817 b/sql/hive/src/test/resources/golden/groupby4_map_skew-2-be2c0b32a02a1154bfdee1a52530f387 similarity index 100% rename from sql/hive/src/test/resources/golden/groupby4_map_skew-2-83c59d378571a6e487aa20217bd87817 rename to sql/hive/src/test/resources/golden/groupby4_map_skew-2-be2c0b32a02a1154bfdee1a52530f387 diff --git a/sql/hive/src/test/resources/golden/groupby4_noskew-2-83c59d378571a6e487aa20217bd87817 b/sql/hive/src/test/resources/golden/groupby4_noskew-2-be2c0b32a02a1154bfdee1a52530f387 similarity index 100% rename from sql/hive/src/test/resources/golden/groupby4_noskew-2-83c59d378571a6e487aa20217bd87817 rename to sql/hive/src/test/resources/golden/groupby4_noskew-2-be2c0b32a02a1154bfdee1a52530f387 diff --git a/sql/hive/src/test/resources/golden/groupby5_map-2-83c59d378571a6e487aa20217bd87817 b/sql/hive/src/test/resources/golden/groupby5_map-2-be2c0b32a02a1154bfdee1a52530f387 similarity index 100% rename from sql/hive/src/test/resources/golden/groupby5_map-2-83c59d378571a6e487aa20217bd87817 rename to sql/hive/src/test/resources/golden/groupby5_map-2-be2c0b32a02a1154bfdee1a52530f387 diff --git a/sql/hive/src/test/resources/golden/groupby5_map_skew-2-83c59d378571a6e487aa20217bd87817 b/sql/hive/src/test/resources/golden/groupby5_map_skew-2-be2c0b32a02a1154bfdee1a52530f387 similarity index 100% rename from sql/hive/src/test/resources/golden/groupby5_map_skew-2-83c59d378571a6e487aa20217bd87817 rename to sql/hive/src/test/resources/golden/groupby5_map_skew-2-be2c0b32a02a1154bfdee1a52530f387 diff --git a/sql/hive/src/test/resources/golden/groupby5_noskew-2-83c59d378571a6e487aa20217bd87817 b/sql/hive/src/test/resources/golden/groupby5_noskew-2-be2c0b32a02a1154bfdee1a52530f387 similarity index 100% rename from sql/hive/src/test/resources/golden/groupby5_noskew-2-83c59d378571a6e487aa20217bd87817 rename to sql/hive/src/test/resources/golden/groupby5_noskew-2-be2c0b32a02a1154bfdee1a52530f387 diff --git a/sql/hive/src/test/resources/golden/groupby6_map-2-83c59d378571a6e487aa20217bd87817 b/sql/hive/src/test/resources/golden/groupby6_map-2-be2c0b32a02a1154bfdee1a52530f387 similarity index 100% rename from sql/hive/src/test/resources/golden/groupby6_map-2-83c59d378571a6e487aa20217bd87817 rename to sql/hive/src/test/resources/golden/groupby6_map-2-be2c0b32a02a1154bfdee1a52530f387 diff --git a/sql/hive/src/test/resources/golden/groupby6_map_skew-2-83c59d378571a6e487aa20217bd87817 b/sql/hive/src/test/resources/golden/groupby6_map_skew-2-be2c0b32a02a1154bfdee1a52530f387 similarity index 100% rename from 
sql/hive/src/test/resources/golden/groupby6_map_skew-2-83c59d378571a6e487aa20217bd87817 rename to sql/hive/src/test/resources/golden/groupby6_map_skew-2-be2c0b32a02a1154bfdee1a52530f387 diff --git a/sql/hive/src/test/resources/golden/groupby6_noskew-2-83c59d378571a6e487aa20217bd87817 b/sql/hive/src/test/resources/golden/groupby6_noskew-2-be2c0b32a02a1154bfdee1a52530f387 similarity index 100% rename from sql/hive/src/test/resources/golden/groupby6_noskew-2-83c59d378571a6e487aa20217bd87817 rename to sql/hive/src/test/resources/golden/groupby6_noskew-2-be2c0b32a02a1154bfdee1a52530f387 diff --git a/sql/hive/src/test/resources/golden/groupby7_map-3-83c59d378571a6e487aa20217bd87817 b/sql/hive/src/test/resources/golden/groupby7_map-3-be2c0b32a02a1154bfdee1a52530f387 similarity index 100% rename from sql/hive/src/test/resources/golden/groupby7_map-3-83c59d378571a6e487aa20217bd87817 rename to sql/hive/src/test/resources/golden/groupby7_map-3-be2c0b32a02a1154bfdee1a52530f387 diff --git a/sql/hive/src/test/resources/golden/groupby7_map_multi_single_reducer-2-83c59d378571a6e487aa20217bd87817 b/sql/hive/src/test/resources/golden/groupby7_map_multi_single_reducer-2-be2c0b32a02a1154bfdee1a52530f387 similarity index 100% rename from sql/hive/src/test/resources/golden/groupby7_map_multi_single_reducer-2-83c59d378571a6e487aa20217bd87817 rename to sql/hive/src/test/resources/golden/groupby7_map_multi_single_reducer-2-be2c0b32a02a1154bfdee1a52530f387 diff --git a/sql/hive/src/test/resources/golden/groupby7_map_skew-2-83c59d378571a6e487aa20217bd87817 b/sql/hive/src/test/resources/golden/groupby7_map_skew-2-be2c0b32a02a1154bfdee1a52530f387 similarity index 100% rename from sql/hive/src/test/resources/golden/groupby7_map_skew-2-83c59d378571a6e487aa20217bd87817 rename to sql/hive/src/test/resources/golden/groupby7_map_skew-2-be2c0b32a02a1154bfdee1a52530f387 diff --git a/sql/hive/src/test/resources/golden/groupby7_noskew-3-83c59d378571a6e487aa20217bd87817 b/sql/hive/src/test/resources/golden/groupby7_noskew-3-be2c0b32a02a1154bfdee1a52530f387 similarity index 100% rename from sql/hive/src/test/resources/golden/groupby7_noskew-3-83c59d378571a6e487aa20217bd87817 rename to sql/hive/src/test/resources/golden/groupby7_noskew-3-be2c0b32a02a1154bfdee1a52530f387 diff --git a/sql/hive/src/test/resources/golden/groupby7_noskew_multi_single_reducer-2-83c59d378571a6e487aa20217bd87817 b/sql/hive/src/test/resources/golden/groupby7_noskew_multi_single_reducer-2-be2c0b32a02a1154bfdee1a52530f387 similarity index 100% rename from sql/hive/src/test/resources/golden/groupby7_noskew_multi_single_reducer-2-83c59d378571a6e487aa20217bd87817 rename to sql/hive/src/test/resources/golden/groupby7_noskew_multi_single_reducer-2-be2c0b32a02a1154bfdee1a52530f387 diff --git a/sql/hive/src/test/resources/golden/groupby8_map-2-83c59d378571a6e487aa20217bd87817 b/sql/hive/src/test/resources/golden/groupby8_map-2-be2c0b32a02a1154bfdee1a52530f387 similarity index 100% rename from sql/hive/src/test/resources/golden/groupby8_map-2-83c59d378571a6e487aa20217bd87817 rename to sql/hive/src/test/resources/golden/groupby8_map-2-be2c0b32a02a1154bfdee1a52530f387 diff --git a/sql/hive/src/test/resources/golden/groupby8_map_skew-2-83c59d378571a6e487aa20217bd87817 b/sql/hive/src/test/resources/golden/groupby8_map_skew-2-be2c0b32a02a1154bfdee1a52530f387 similarity index 100% rename from sql/hive/src/test/resources/golden/groupby8_map_skew-2-83c59d378571a6e487aa20217bd87817 rename to 
sql/hive/src/test/resources/golden/groupby8_map_skew-2-be2c0b32a02a1154bfdee1a52530f387 diff --git a/sql/hive/src/test/resources/golden/groupby8_noskew-2-83c59d378571a6e487aa20217bd87817 b/sql/hive/src/test/resources/golden/groupby8_noskew-2-be2c0b32a02a1154bfdee1a52530f387 similarity index 100% rename from sql/hive/src/test/resources/golden/groupby8_noskew-2-83c59d378571a6e487aa20217bd87817 rename to sql/hive/src/test/resources/golden/groupby8_noskew-2-be2c0b32a02a1154bfdee1a52530f387 diff --git a/sql/hive/src/test/resources/golden/groupby_map_ppr-2-83c59d378571a6e487aa20217bd87817 b/sql/hive/src/test/resources/golden/groupby_map_ppr-2-be2c0b32a02a1154bfdee1a52530f387 similarity index 100% rename from sql/hive/src/test/resources/golden/groupby_map_ppr-2-83c59d378571a6e487aa20217bd87817 rename to sql/hive/src/test/resources/golden/groupby_map_ppr-2-be2c0b32a02a1154bfdee1a52530f387 diff --git a/sql/hive/src/test/resources/golden/input12_hadoop20-0-db1cd54a4cb36de2087605f32e41824f b/sql/hive/src/test/resources/golden/input12_hadoop20-0-2b9ccaa793eae0e73bf76335d3d6880 similarity index 100% rename from sql/hive/src/test/resources/golden/input12_hadoop20-0-db1cd54a4cb36de2087605f32e41824f rename to sql/hive/src/test/resources/golden/input12_hadoop20-0-2b9ccaa793eae0e73bf76335d3d6880 diff --git a/sql/hive/src/test/resources/golden/input_testsequencefile-0-68975193b30cb34102b380e647d8d5f4 b/sql/hive/src/test/resources/golden/input_testsequencefile-0-dd959af1968381d0ed90178d349b01a7 similarity index 100% rename from sql/hive/src/test/resources/golden/input_testsequencefile-0-68975193b30cb34102b380e647d8d5f4 rename to sql/hive/src/test/resources/golden/input_testsequencefile-0-dd959af1968381d0ed90178d349b01a7 diff --git a/sql/hive/src/test/resources/golden/input_testsequencefile-1-1c0f3be2d837dee49312e0a80440447e b/sql/hive/src/test/resources/golden/input_testsequencefile-1-ddbb8d5e5dc0988bda96ac2b4aec8f94 similarity index 100% rename from sql/hive/src/test/resources/golden/input_testsequencefile-1-1c0f3be2d837dee49312e0a80440447e rename to sql/hive/src/test/resources/golden/input_testsequencefile-1-ddbb8d5e5dc0988bda96ac2b4aec8f94 diff --git a/sql/hive/src/test/resources/golden/input_testsequencefile-5-3708198aac609695b22e19e89306034c b/sql/hive/src/test/resources/golden/input_testsequencefile-5-25715870c569b0f8c3d483e3a38b3199 similarity index 100% rename from sql/hive/src/test/resources/golden/input_testsequencefile-5-3708198aac609695b22e19e89306034c rename to sql/hive/src/test/resources/golden/input_testsequencefile-5-25715870c569b0f8c3d483e3a38b3199 diff --git a/sql/hive/src/test/resources/golden/join14_hadoop20-1-db1cd54a4cb36de2087605f32e41824f b/sql/hive/src/test/resources/golden/join14_hadoop20-1-2b9ccaa793eae0e73bf76335d3d6880 similarity index 100% rename from sql/hive/src/test/resources/golden/join14_hadoop20-1-db1cd54a4cb36de2087605f32e41824f rename to sql/hive/src/test/resources/golden/join14_hadoop20-1-2b9ccaa793eae0e73bf76335d3d6880 diff --git a/sql/hive/src/test/resources/golden/leftsemijoin_mr-7-8e9c2969b999557363e40f9ebb3f6d7c b/sql/hive/src/test/resources/golden/leftsemijoin_mr-7-6b9861b999092f1ea4fa1fd27a666af6 similarity index 100% rename from sql/hive/src/test/resources/golden/leftsemijoin_mr-7-8e9c2969b999557363e40f9ebb3f6d7c rename to sql/hive/src/test/resources/golden/leftsemijoin_mr-7-6b9861b999092f1ea4fa1fd27a666af6 diff --git a/sql/hive/src/test/resources/golden/merge2-2-c95dc367df88c9e5cf77157f29ba2daf 
b/sql/hive/src/test/resources/golden/merge2-2-6142f47d3fcdd4323162014d5eb35e07 similarity index 100% rename from sql/hive/src/test/resources/golden/merge2-2-c95dc367df88c9e5cf77157f29ba2daf rename to sql/hive/src/test/resources/golden/merge2-2-6142f47d3fcdd4323162014d5eb35e07 diff --git a/sql/hive/src/test/resources/golden/merge2-3-6e53a3ac93113f20db3a12f1dcf30e86 b/sql/hive/src/test/resources/golden/merge2-3-10266e3d5dd4c841c0d65030b1edba7c similarity index 100% rename from sql/hive/src/test/resources/golden/merge2-3-6e53a3ac93113f20db3a12f1dcf30e86 rename to sql/hive/src/test/resources/golden/merge2-3-10266e3d5dd4c841c0d65030b1edba7c diff --git a/sql/hive/src/test/resources/golden/merge2-4-84967075baa3e56fff2a23f8ab9ba076 b/sql/hive/src/test/resources/golden/merge2-4-9cbd6d400fb6c3cd09010e3dbd76601 similarity index 100% rename from sql/hive/src/test/resources/golden/merge2-4-84967075baa3e56fff2a23f8ab9ba076 rename to sql/hive/src/test/resources/golden/merge2-4-9cbd6d400fb6c3cd09010e3dbd76601 diff --git a/sql/hive/src/test/resources/golden/merge2-5-2ee5d706fe3a3bcc38b795f6e94970ea b/sql/hive/src/test/resources/golden/merge2-5-1ba2d6f3bb3348da3fee7fab4f283f34 similarity index 100% rename from sql/hive/src/test/resources/golden/merge2-5-2ee5d706fe3a3bcc38b795f6e94970ea rename to sql/hive/src/test/resources/golden/merge2-5-1ba2d6f3bb3348da3fee7fab4f283f34 diff --git a/sql/hive/src/test/resources/golden/parallel-0-23a4feaede17467a8cc26e4d86ec30f9 b/sql/hive/src/test/resources/golden/parallel-0-6dc30e2de057022e63bd2a645fbec4c2 similarity index 100% rename from sql/hive/src/test/resources/golden/parallel-0-23a4feaede17467a8cc26e4d86ec30f9 rename to sql/hive/src/test/resources/golden/parallel-0-6dc30e2de057022e63bd2a645fbec4c2 diff --git a/sql/hive/src/test/resources/golden/rcfile_lazydecompress-11-3708198aac609695b22e19e89306034c b/sql/hive/src/test/resources/golden/rcfile_lazydecompress-11-25715870c569b0f8c3d483e3a38b3199 similarity index 100% rename from sql/hive/src/test/resources/golden/rcfile_lazydecompress-11-3708198aac609695b22e19e89306034c rename to sql/hive/src/test/resources/golden/rcfile_lazydecompress-11-25715870c569b0f8c3d483e3a38b3199 diff --git a/sql/hive/src/test/resources/golden/rcfile_lazydecompress-5-68975193b30cb34102b380e647d8d5f4 b/sql/hive/src/test/resources/golden/rcfile_lazydecompress-5-dd959af1968381d0ed90178d349b01a7 similarity index 100% rename from sql/hive/src/test/resources/golden/rcfile_lazydecompress-5-68975193b30cb34102b380e647d8d5f4 rename to sql/hive/src/test/resources/golden/rcfile_lazydecompress-5-dd959af1968381d0ed90178d349b01a7 diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/auto_join14_hadoop20.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/auto_join14_hadoop20.q index 235b7c1b3fcd..6a9a20f3207b 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/auto_join14_hadoop20.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/auto_join14_hadoop20.q @@ -5,7 +5,7 @@ set hive.auto.convert.join = true; CREATE TABLE dest1(c1 INT, c2 STRING) STORED AS TEXTFILE; -set mapred.job.tracker=localhost:58; +set mapreduce.jobtracker.address=localhost:58; set hive.exec.mode.local.auto=true; explain diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/bucket5.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/bucket5.q index 877f8a50a0e3..87f6eca4dd4e 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/bucket5.q +++ 
b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/bucket5.q @@ -4,7 +4,7 @@ set hive.enforce.sorting = true; set hive.exec.reducers.max = 1; set hive.merge.mapfiles = true; set hive.merge.mapredfiles = true; -set mapred.reduce.tasks = 2; +set mapreduce.job.reduces = 2; -- Tests that when a multi insert inserts into a bucketed table and a table which is not bucketed -- the bucketed table is not merged and the table which is not bucketed is diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/bucket_num_reducers.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/bucket_num_reducers.q index 37ae6cc7adea..84fe3919d7a6 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/bucket_num_reducers.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/bucket_num_reducers.q @@ -1,6 +1,6 @@ set hive.enforce.bucketing = true; set hive.exec.mode.local.auto=false; -set mapred.reduce.tasks = 10; +set mapreduce.job.reduces = 10; -- This test sets number of mapred tasks to 10 for a database with 50 buckets, -- and uses a post-hook to confirm that 10 tasks were created diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/bucketizedhiveinputformat.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/bucketizedhiveinputformat.q index d2e12e82d4a2..ae72f98fa424 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/bucketizedhiveinputformat.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/bucketizedhiveinputformat.q @@ -1,5 +1,5 @@ set hive.input.format=org.apache.hadoop.hive.ql.io.BucketizedHiveInputFormat; -set mapred.min.split.size = 64; +set mapreduce.input.fileinputformat.split.minsize = 64; CREATE TABLE T1(name STRING) STORED AS TEXTFILE; diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/combine1.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/combine1.q index 86abf0996057..5ecfc2172478 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/combine1.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/combine1.q @@ -1,11 +1,11 @@ set hive.exec.compress.output = true; set hive.input.format=org.apache.hadoop.hive.ql.io.CombineHiveInputFormat; -set mapred.min.split.size=256; -set mapred.min.split.size.per.node=256; -set mapred.min.split.size.per.rack=256; -set mapred.max.split.size=256; +set mapreduce.input.fileinputformat.split.minsize=256; +set mapreduce.input.fileinputformat.split.minsize.per.node=256; +set mapreduce.input.fileinputformat.split.minsize.per.rack=256; +set mapreduce.input.fileinputformat.split.maxsize=256; -set mapred.output.compression.codec=org.apache.hadoop.io.compress.GzipCodec; +set mapreduce.output.fileoutputformat.compress.codec=org.apache.hadoop.io.compress.GzipCodec; create table combine1_1(key string, value string) stored as textfile; diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/combine2.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/combine2.q index cfd9856f0868..acd0dd5e5bc9 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/combine2.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/combine2.q @@ -1,10 +1,10 @@ USE default; set hive.input.format=org.apache.hadoop.hive.ql.io.CombineHiveInputFormat; -set mapred.min.split.size=256; -set mapred.min.split.size.per.node=256; -set mapred.min.split.size.per.rack=256; -set 
mapred.max.split.size=256; +set mapreduce.input.fileinputformat.split.minsize=256; +set mapreduce.input.fileinputformat.split.minsize.per.node=256; +set mapreduce.input.fileinputformat.split.minsize.per.rack=256; +set mapreduce.input.fileinputformat.split.maxsize=256; set hive.exec.dynamic.partition=true; set hive.exec.dynamic.partition.mode=nonstrict; set mapred.cache.shared.enabled=false; @@ -18,7 +18,7 @@ set hive.merge.smallfiles.avgsize=0; create table combine2(key string) partitioned by (value string); -- EXCLUDE_HADOOP_MAJOR_VERSIONS(0.20, 0.20S) --- This test sets mapred.max.split.size=256 and hive.merge.smallfiles.avgsize=0 +-- This test sets mapreduce.input.fileinputformat.split.maxsize=256 and hive.merge.smallfiles.avgsize=0 -- in an attempt to force the generation of multiple splits and multiple output files. -- However, Hadoop 0.20 is incapable of generating splits smaller than the block size -- when using CombineFileInputFormat, so only one split is generated. This has a diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/combine2_hadoop20.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/combine2_hadoop20.q index 8f9a59d49753..597d3ae479b9 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/combine2_hadoop20.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/combine2_hadoop20.q @@ -1,10 +1,10 @@ USE default; set hive.input.format=org.apache.hadoop.hive.ql.io.CombineHiveInputFormat; -set mapred.min.split.size=256; -set mapred.min.split.size.per.node=256; -set mapred.min.split.size.per.rack=256; -set mapred.max.split.size=256; +set mapreduce.input.fileinputformat.split.minsize=256; +set mapreduce.input.fileinputformat.split.minsize.per.node=256; +set mapreduce.input.fileinputformat.split.minsize.per.rack=256; +set mapreduce.input.fileinputformat.split.maxsize=256; set hive.exec.dynamic.partition=true; set hive.exec.dynamic.partition.mode=nonstrict; set mapred.cache.shared.enabled=false; @@ -17,7 +17,7 @@ set hive.merge.smallfiles.avgsize=0; create table combine2(key string) partitioned by (value string); -- INCLUDE_HADOOP_MAJOR_VERSIONS(0.20, 0.20S) --- This test sets mapred.max.split.size=256 and hive.merge.smallfiles.avgsize=0 +-- This test sets mapreduce.input.fileinputformat.split.maxsize=256 and hive.merge.smallfiles.avgsize=0 -- in an attempt to force the generation of multiple splits and multiple output files. -- However, Hadoop 0.20 is incapable of generating splits smaller than the block size -- when using CombineFileInputFormat, so only one split is generated. 
This has a diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/combine2_win.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/combine2_win.q index f6090bb99b29..4f7174a1b636 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/combine2_win.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/combine2_win.q @@ -1,8 +1,8 @@ set hive.input.format=org.apache.hadoop.hive.ql.io.CombineHiveInputFormat; -set mapred.min.split.size=256; -set mapred.min.split.size.per.node=256; -set mapred.min.split.size.per.rack=256; -set mapred.max.split.size=256; +set mapreduce.input.fileinputformat.split.minsize=256; +set mapreduce.input.fileinputformat.split.minsize.per.node=256; +set mapreduce.input.fileinputformat.split.minsize.per.rack=256; +set mapreduce.input.fileinputformat.split.maxsize=256; set hive.exec.dynamic.partition=true; set hive.exec.dynamic.partition.mode=nonstrict; set mapred.cache.shared.enabled=false; diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/combine3.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/combine3.q index c9afc91bb456..35dd442027b4 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/combine3.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/combine3.q @@ -1,9 +1,9 @@ set hive.exec.compress.output = true; set hive.input.format=org.apache.hadoop.hive.ql.io.CombineHiveInputFormat; -set mapred.min.split.size=256; -set mapred.min.split.size.per.node=256; -set mapred.min.split.size.per.rack=256; -set mapred.max.split.size=256; +set mapreduce.input.fileinputformat.split.minsize=256; +set mapreduce.input.fileinputformat.split.minsize.per.node=256; +set mapreduce.input.fileinputformat.split.minsize.per.rack=256; +set mapreduce.input.fileinputformat.split.maxsize=256; drop table combine_3_srcpart_seq_rc; diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/create_1.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/create_1.q index f348e5902263..5e51d11864dd 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/create_1.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/create_1.q @@ -1,4 +1,4 @@ -set fs.default.name=invalidscheme:///; +set fs.defaultFS=invalidscheme:///; CREATE TABLE table1 (a STRING, b STRING) STORED AS TEXTFILE; DESCRIBE table1; diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/ctas_hadoop20.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/ctas_hadoop20.q index f39689de03a5..979c9072303c 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/ctas_hadoop20.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/ctas_hadoop20.q @@ -49,7 +49,7 @@ describe formatted nzhang_CTAS4; explain extended create table nzhang_ctas5 row format delimited fields terminated by ',' lines terminated by '\012' stored as textfile as select key, value from src sort by key, value limit 10; -set mapred.job.tracker=localhost:58; +set mapreduce.jobtracker.address=localhost:58; set hive.exec.mode.local.auto=true; create table nzhang_ctas5 row format delimited fields terminated by ',' lines terminated by '\012' stored as textfile as select key, value from src sort by key, value limit 10; diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby1.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby1.q index 
1275eab281f4..0d75857e54e5 100755 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby1.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby1.q @@ -3,12 +3,12 @@ set hive.groupby.skewindata=true; CREATE TABLE dest_g1(key INT, value DOUBLE) STORED AS TEXTFILE; -set fs.default.name=invalidscheme:///; +set fs.defaultFS=invalidscheme:///; EXPLAIN FROM src INSERT OVERWRITE TABLE dest_g1 SELECT src.key, sum(substr(src.value,5)) GROUP BY src.key; -set fs.default.name=file:///; +set fs.defaultFS=file:///; FROM src INSERT OVERWRITE TABLE dest_g1 SELECT src.key, sum(substr(src.value,5)) GROUP BY src.key; diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby1_limit.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby1_limit.q index 55133332a866..bbb2859a9d45 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby1_limit.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby1_limit.q @@ -1,4 +1,4 @@ -set mapred.reduce.tasks=31; +set mapreduce.job.reduces=31; CREATE TABLE dest1(key INT, value DOUBLE) STORED AS TEXTFILE; diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby1_map.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby1_map.q index dde37dfd4714..7883d948d067 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby1_map.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby1_map.q @@ -1,6 +1,6 @@ set hive.map.aggr=true; set hive.groupby.skewindata=false; -set mapred.reduce.tasks=31; +set mapreduce.job.reduces=31; CREATE TABLE dest1(key INT, value DOUBLE) STORED AS TEXTFILE; diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby1_map_skew.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby1_map_skew.q index f346cb7e9014..a5ac3762ce79 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby1_map_skew.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby1_map_skew.q @@ -1,6 +1,6 @@ set hive.map.aggr=true; set hive.groupby.skewindata=true; -set mapred.reduce.tasks=31; +set mapreduce.job.reduces=31; CREATE TABLE dest1(key INT, value DOUBLE) STORED AS TEXTFILE; diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby1_noskew.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby1_noskew.q index c587b5f658f6..6341eefb5043 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby1_noskew.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby1_noskew.q @@ -1,6 +1,6 @@ set hive.map.aggr=false; set hive.groupby.skewindata=false; -set mapred.reduce.tasks=31; +set mapreduce.job.reduces=31; CREATE TABLE dest_g1(key INT, value DOUBLE) STORED AS TEXTFILE; diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby2_limit.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby2_limit.q index 30499248cac1..df4693446d6c 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby2_limit.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby2_limit.q @@ -1,4 +1,4 @@ -set mapred.reduce.tasks=31; +set mapreduce.job.reduces=31; EXPLAIN SELECT src.key, sum(substr(src.value,5)) FROM src GROUP BY src.key ORDER BY src.key LIMIT 5; diff --git 
a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby2_map.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby2_map.q index 794ec758e9ed..7b6e175c2df0 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby2_map.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby2_map.q @@ -1,6 +1,6 @@ set hive.map.aggr=true; set hive.groupby.skewindata=false; -set mapred.reduce.tasks=31; +set mapreduce.job.reduces=31; CREATE TABLE dest1(key STRING, c1 INT, c2 STRING) STORED AS TEXTFILE; diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby2_map_multi_distinct.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby2_map_multi_distinct.q index 55d1a34b3c92..3aeae0d5c33d 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby2_map_multi_distinct.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby2_map_multi_distinct.q @@ -1,6 +1,6 @@ set hive.map.aggr=true; set hive.groupby.skewindata=false; -set mapred.reduce.tasks=31; +set mapreduce.job.reduces=31; CREATE TABLE dest1(key STRING, c1 INT, c2 STRING, c3 INT, c4 INT) STORED AS TEXTFILE; diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby2_map_skew.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby2_map_skew.q index 39a2a178e3a5..998156d05f99 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby2_map_skew.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby2_map_skew.q @@ -1,6 +1,6 @@ set hive.map.aggr=true; set hive.groupby.skewindata=true; -set mapred.reduce.tasks=31; +set mapreduce.job.reduces=31; CREATE TABLE dest1(key STRING, c1 INT, c2 STRING) STORED AS TEXTFILE; diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby2_noskew.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby2_noskew.q index 6d7cb61e2d44..fab4f5d097f1 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby2_noskew.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby2_noskew.q @@ -1,6 +1,6 @@ set hive.map.aggr=false; set hive.groupby.skewindata=false; -set mapred.reduce.tasks=31; +set mapreduce.job.reduces=31; CREATE TABLE dest_g2(key STRING, c1 INT, c2 STRING) STORED AS TEXTFILE; diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby2_noskew_multi_distinct.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby2_noskew_multi_distinct.q index b2450c9ea04e..9ef556cdc583 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby2_noskew_multi_distinct.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby2_noskew_multi_distinct.q @@ -1,6 +1,6 @@ set hive.map.aggr=false; set hive.groupby.skewindata=false; -set mapred.reduce.tasks=31; +set mapreduce.job.reduces=31; CREATE TABLE dest_g2(key STRING, c1 INT, c2 STRING, c3 INT, c4 INT) STORED AS TEXTFILE; diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby3_map.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby3_map.q index 7ecc71dfab64..36ba5d89c0f7 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby3_map.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby3_map.q @@ -1,6 +1,6 @@ set hive.map.aggr=true; set 
hive.groupby.skewindata=false; -set mapred.reduce.tasks=31; +set mapreduce.job.reduces=31; CREATE TABLE dest1(c1 DOUBLE, c2 DOUBLE, c3 DOUBLE, c4 DOUBLE, c5 DOUBLE, c6 DOUBLE, c7 DOUBLE, c8 DOUBLE, c9 DOUBLE) STORED AS TEXTFILE; diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby3_map_multi_distinct.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby3_map_multi_distinct.q index 50243beca9ef..6f0a9635a284 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby3_map_multi_distinct.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby3_map_multi_distinct.q @@ -1,6 +1,6 @@ set hive.map.aggr=true; set hive.groupby.skewindata=false; -set mapred.reduce.tasks=31; +set mapreduce.job.reduces=31; CREATE TABLE dest1(c1 DOUBLE, c2 DOUBLE, c3 DOUBLE, c4 DOUBLE, c5 DOUBLE, c6 DOUBLE, c7 DOUBLE, c8 DOUBLE, c9 DOUBLE, c10 DOUBLE, c11 DOUBLE) STORED AS TEXTFILE; diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby3_map_skew.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby3_map_skew.q index 07d10c2d741d..64a49e2525ed 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby3_map_skew.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby3_map_skew.q @@ -1,6 +1,6 @@ set hive.map.aggr=true; set hive.groupby.skewindata=true; -set mapred.reduce.tasks=31; +set mapreduce.job.reduces=31; CREATE TABLE dest1(c1 DOUBLE, c2 DOUBLE, c3 DOUBLE, c4 DOUBLE, c5 DOUBLE, c6 DOUBLE, c7 DOUBLE, c8 DOUBLE, c9 DOUBLE) STORED AS TEXTFILE; diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby3_noskew.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby3_noskew.q index d33f12c5744e..4fd98efd6ef4 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby3_noskew.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby3_noskew.q @@ -1,7 +1,7 @@ set hive.map.aggr=false; set hive.groupby.skewindata=false; -set mapred.reduce.tasks=31; +set mapreduce.job.reduces=31; CREATE TABLE dest1(c1 DOUBLE, c2 DOUBLE, c3 DOUBLE, c4 DOUBLE, c5 DOUBLE, c6 DOUBLE, c7 DOUBLE, c8 DOUBLE, c9 DOUBLE) STORED AS TEXTFILE; diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby3_noskew_multi_distinct.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby3_noskew_multi_distinct.q index 86d8986f1df7..85ee8ac43e52 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby3_noskew_multi_distinct.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby3_noskew_multi_distinct.q @@ -1,7 +1,7 @@ set hive.map.aggr=false; set hive.groupby.skewindata=false; -set mapred.reduce.tasks=31; +set mapreduce.job.reduces=31; CREATE TABLE dest1(c1 DOUBLE, c2 DOUBLE, c3 DOUBLE, c4 DOUBLE, c5 DOUBLE, c6 DOUBLE, c7 DOUBLE, c8 DOUBLE, c9 DOUBLE, c10 DOUBLE, c11 DOUBLE) STORED AS TEXTFILE; diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby4_map.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby4_map.q index 8ecce23eb832..d71721875bbf 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby4_map.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby4_map.q @@ -1,6 +1,6 @@ set hive.map.aggr=true; set hive.groupby.skewindata=false; -set mapred.reduce.tasks=31; +set 
mapreduce.job.reduces=31; CREATE TABLE dest1(key INT) STORED AS TEXTFILE; diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby4_map_skew.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby4_map_skew.q index eb2001c6b21b..d1ecba143d62 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby4_map_skew.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby4_map_skew.q @@ -1,6 +1,6 @@ set hive.map.aggr=true; set hive.groupby.skewindata=true; -set mapred.reduce.tasks=31; +set mapreduce.job.reduces=31; CREATE TABLE dest1(key INT) STORED AS TEXTFILE; diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby4_noskew.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby4_noskew.q index a1ebf90aadfe..63530c262c14 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby4_noskew.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby4_noskew.q @@ -1,7 +1,7 @@ set hive.map.aggr=false; set hive.groupby.skewindata=false; -set mapred.reduce.tasks=31; +set mapreduce.job.reduces=31; CREATE TABLE dest1(c1 STRING) STORED AS TEXTFILE; diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby5_map.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby5_map.q index 4fd6445d7927..4418bbffec7a 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby5_map.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby5_map.q @@ -1,6 +1,6 @@ set hive.map.aggr=true; set hive.groupby.skewindata=false; -set mapred.reduce.tasks=31; +set mapreduce.job.reduces=31; CREATE TABLE dest1(key INT) STORED AS TEXTFILE; diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby5_map_skew.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby5_map_skew.q index eccd45dd5b42..ef20dacf0599 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby5_map_skew.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby5_map_skew.q @@ -1,6 +1,6 @@ set hive.map.aggr=true; set hive.groupby.skewindata=true; -set mapred.reduce.tasks=31; +set mapreduce.job.reduces=31; CREATE TABLE dest1(key INT) STORED AS TEXTFILE; diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby5_noskew.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby5_noskew.q index e96568b398d8..17b322b890ff 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby5_noskew.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby5_noskew.q @@ -1,7 +1,7 @@ set hive.map.aggr=false; set hive.groupby.skewindata=false; -set mapred.reduce.tasks=31; +set mapreduce.job.reduces=31; CREATE TABLE dest1(key INT, value STRING) STORED AS TEXTFILE; diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby6_map.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby6_map.q index ced122fae3f5..bef0eeee0e89 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby6_map.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby6_map.q @@ -1,6 +1,6 @@ set hive.map.aggr=true; set hive.groupby.skewindata=false; -set mapred.reduce.tasks=31; +set mapreduce.job.reduces=31; CREATE TABLE dest1(c1 STRING) STORED AS TEXTFILE; diff --git 
a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby6_map_skew.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby6_map_skew.q index 0d3727b05285..ee93b218ac78 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby6_map_skew.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby6_map_skew.q @@ -1,6 +1,6 @@ set hive.map.aggr=true; set hive.groupby.skewindata=true; -set mapred.reduce.tasks=31; +set mapreduce.job.reduces=31; CREATE TABLE dest1(c1 STRING) STORED AS TEXTFILE; diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby6_noskew.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby6_noskew.q index 466c13222f29..72fff08decf0 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby6_noskew.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby6_noskew.q @@ -1,7 +1,7 @@ set hive.map.aggr=false; set hive.groupby.skewindata=false; -set mapred.reduce.tasks=31; +set mapreduce.job.reduces=31; CREATE TABLE dest1(c1 STRING) STORED AS TEXTFILE; diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby7_map.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby7_map.q index 2b8c5db41ea9..75149b140415 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby7_map.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby7_map.q @@ -1,7 +1,7 @@ set hive.map.aggr=true; set hive.multigroupby.singlereducer=false; set hive.groupby.skewindata=false; -set mapred.reduce.tasks=31; +set mapreduce.job.reduces=31; CREATE TABLE DEST1(key INT, value STRING) STORED AS TEXTFILE; CREATE TABLE DEST2(key INT, value STRING) STORED AS TEXTFILE; diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby7_map_multi_single_reducer.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby7_map_multi_single_reducer.q index 5895ed459984..7c7829aac2d6 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby7_map_multi_single_reducer.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby7_map_multi_single_reducer.q @@ -1,6 +1,6 @@ set hive.map.aggr=true; set hive.groupby.skewindata=false; -set mapred.reduce.tasks=31; +set mapreduce.job.reduces=31; CREATE TABLE DEST1(key INT, value STRING) STORED AS TEXTFILE; CREATE TABLE DEST2(key INT, value STRING) STORED AS TEXTFILE; diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby7_map_skew.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby7_map_skew.q index ee6d7bf83084..905986d417df 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby7_map_skew.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby7_map_skew.q @@ -1,6 +1,6 @@ set hive.map.aggr=true; set hive.groupby.skewindata=true; -set mapred.reduce.tasks=31; +set mapreduce.job.reduces=31; CREATE TABLE DEST1(key INT, value STRING) STORED AS TEXTFILE; CREATE TABLE DEST2(key INT, value STRING) STORED AS TEXTFILE; diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby7_noskew.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby7_noskew.q index 8c2308e5d75c..1f63453672a4 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby7_noskew.q +++ 
b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby7_noskew.q @@ -1,7 +1,7 @@ set hive.map.aggr=false; set hive.multigroupby.singlereducer=false; set hive.groupby.skewindata=false; -set mapred.reduce.tasks=31; +set mapreduce.job.reduces=31; CREATE TABLE DEST1(key INT, value STRING) STORED AS TEXTFILE; CREATE TABLE DEST2(key INT, value STRING) STORED AS TEXTFILE; diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby7_noskew_multi_single_reducer.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby7_noskew_multi_single_reducer.q index e673cc61622c..2ce57e98072f 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby7_noskew_multi_single_reducer.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby7_noskew_multi_single_reducer.q @@ -1,6 +1,6 @@ set hive.map.aggr=false; set hive.groupby.skewindata=false; -set mapred.reduce.tasks=31; +set mapreduce.job.reduces=31; CREATE TABLE DEST1(key INT, value STRING) STORED AS TEXTFILE; CREATE TABLE DEST2(key INT, value STRING) STORED AS TEXTFILE; diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby8_map.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby8_map.q index 0252e993363a..9def7d64721e 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby8_map.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby8_map.q @@ -1,6 +1,6 @@ set hive.map.aggr=true; set hive.groupby.skewindata=false; -set mapred.reduce.tasks=31; +set mapreduce.job.reduces=31; CREATE TABLE DEST1(key INT, value STRING) STORED AS TEXTFILE; CREATE TABLE DEST2(key INT, value STRING) STORED AS TEXTFILE; diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby8_map_skew.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby8_map_skew.q index b5e1f63a4525..788bc683697d 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby8_map_skew.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby8_map_skew.q @@ -1,6 +1,6 @@ set hive.map.aggr=true; set hive.groupby.skewindata=true; -set mapred.reduce.tasks=31; +set mapreduce.job.reduces=31; CREATE TABLE DEST1(key INT, value STRING) STORED AS TEXTFILE; CREATE TABLE DEST2(key INT, value STRING) STORED AS TEXTFILE; diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby8_noskew.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby8_noskew.q index da85504ca18c..17885c56b3f1 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby8_noskew.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby8_noskew.q @@ -1,7 +1,7 @@ set hive.map.aggr=false; set hive.groupby.skewindata=false; -set mapred.reduce.tasks=31; +set mapreduce.job.reduces=31; CREATE TABLE DEST1(key INT, value STRING) STORED AS TEXTFILE; CREATE TABLE DEST2(key INT, value STRING) STORED AS TEXTFILE; diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby_map_ppr.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby_map_ppr.q index 4a199365cf96..9cb98aa909e1 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby_map_ppr.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby_map_ppr.q @@ -1,6 +1,6 @@ set hive.map.aggr=true; set hive.groupby.skewindata=false; -set 
mapred.reduce.tasks=31; +set mapreduce.job.reduces=31; CREATE TABLE dest1(key STRING, c1 INT, c2 STRING) STORED AS TEXTFILE; diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby_map_ppr_multi_distinct.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby_map_ppr_multi_distinct.q index cb3ee8291861..841df75af18b 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby_map_ppr_multi_distinct.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby_map_ppr_multi_distinct.q @@ -1,6 +1,6 @@ set hive.map.aggr=true; set hive.groupby.skewindata=false; -set mapred.reduce.tasks=31; +set mapreduce.job.reduces=31; CREATE TABLE dest1(key STRING, c1 INT, c2 STRING, C3 INT, c4 INT) STORED AS TEXTFILE; diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby_sort_1.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby_sort_1.q index 7401a9ca1d9b..cdf4bb1cac9d 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby_sort_1.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby_sort_1.q @@ -248,7 +248,7 @@ SELECT * FROM outputTbl4 ORDER BY key1, key2, key3; set hive.map.aggr=true; set hive.multigroupby.singlereducer=false; -set mapred.reduce.tasks=31; +set mapreduce.job.reduces=31; CREATE TABLE DEST1(key INT, cnt INT); CREATE TABLE DEST2(key INT, val STRING, cnt INT); diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby_sort_skew_1.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby_sort_skew_1.q index db0faa04da0e..1c23fad76eff 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby_sort_skew_1.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby_sort_skew_1.q @@ -249,7 +249,7 @@ SELECT * FROM outputTbl4 ORDER BY key1, key2, key3; set hive.map.aggr=true; set hive.multigroupby.singlereducer=false; -set mapred.reduce.tasks=31; +set mapreduce.job.reduces=31; CREATE TABLE DEST1(key INT, cnt INT); CREATE TABLE DEST2(key INT, val STRING, cnt INT); diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/hook_context_cs.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/hook_context_cs.q index 94ba14802f01..996c9d99f0b9 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/hook_context_cs.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/hook_context_cs.q @@ -5,7 +5,7 @@ ALTER TABLE vcsc ADD partition (ds='dummy') location '${system:test.tmp.dir}/Ver set hive.exec.pre.hooks=org.apache.hadoop.hive.ql.hooks.VerifyContentSummaryCacheHook; SELECT a.c, b.c FROM vcsc a JOIN vcsc b ON a.ds = 'dummy' AND b.ds = 'dummy' AND a.c = b.c; -set mapred.job.tracker=local; +set mapreduce.jobtracker.address=local; set hive.exec.pre.hooks = ; set hive.exec.post.hooks=org.apache.hadoop.hive.ql.hooks.VerifyContentSummaryCacheHook; SELECT a.c, b.c FROM vcsc a JOIN vcsc b ON a.ds = 'dummy' AND b.ds = 'dummy' AND a.c = b.c; diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/infer_bucket_sort_dyn_part.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/infer_bucket_sort_dyn_part.q index 728b8cc4a949..5d3c6c43c640 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/infer_bucket_sort_dyn_part.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/infer_bucket_sort_dyn_part.q @@ 
-63,7 +63,7 @@ set hive.merge.mapredfiles=true; set hive.merge.smallfiles.avgsize=200; set hive.exec.compress.output=false; set hive.exec.dynamic.partition=true; -set mapred.reduce.tasks=2; +set mapreduce.job.reduces=2; -- Tests dynamic partitions where bucketing/sorting can be inferred, but some partitions are -- merged and some are moved. Currently neither should be bucketed or sorted, in the future, diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/infer_bucket_sort_merge.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/infer_bucket_sort_merge.q index 41c1a13980cf..aa49b0dc64c4 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/infer_bucket_sort_merge.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/infer_bucket_sort_merge.q @@ -1,7 +1,7 @@ set hive.exec.infer.bucket.sort=true; set hive.exec.infer.bucket.sort.num.buckets.power.two=true; set hive.merge.mapredfiles=true; -set mapred.reduce.tasks=2; +set mapreduce.job.reduces=2; -- This tests inferring how data is bucketed/sorted from the operators in the reducer -- and populating that information in partitions' metadata. In particular, those cases diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/infer_bucket_sort_num_buckets.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/infer_bucket_sort_num_buckets.q index 2255bdb34913..3a454f77bc4d 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/infer_bucket_sort_num_buckets.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/infer_bucket_sort_num_buckets.q @@ -1,7 +1,7 @@ set hive.exec.infer.bucket.sort=true; set hive.merge.mapfiles=false; set hive.merge.mapredfiles=false; -set mapred.reduce.tasks=2; +set mapreduce.job.reduces=2; CREATE TABLE test_table (key INT, value STRING) PARTITIONED BY (ds STRING, hr STRING); diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/input12_hadoop20.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/input12_hadoop20.q index 318cd378db13..31e99e8d9464 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/input12_hadoop20.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/input12_hadoop20.q @@ -1,4 +1,4 @@ -set mapred.job.tracker=localhost:58; +set mapreduce.jobtracker.address=localhost:58; set hive.exec.mode.local.auto=true; -- INCLUDE_HADOOP_MAJOR_VERSIONS(0.20, 0.20S) diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/input39_hadoop20.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/input39_hadoop20.q index 29e9fae1da9e..362c164176a9 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/input39_hadoop20.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/input39_hadoop20.q @@ -15,7 +15,7 @@ select key, value from src; set hive.test.mode=true; set hive.mapred.mode=strict; -set mapred.job.tracker=localhost:58; +set mapreduce.jobtracker.address=localhost:58; set hive.exec.mode.local.auto=true; explain @@ -24,7 +24,7 @@ select count(1) from t1 join t2 on t1.key=t2.key where t1.ds='1' and t2.ds='1'; select count(1) from t1 join t2 on t1.key=t2.key where t1.ds='1' and t2.ds='1'; set hive.test.mode=false; -set mapred.job.tracker; +set mapreduce.jobtracker.address; diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/input_testsequencefile.q 
b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/input_testsequencefile.q index d9926888cef9..2b16c5cd0864 100755 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/input_testsequencefile.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/input_testsequencefile.q @@ -1,5 +1,5 @@ -set mapred.output.compress=true; -set mapred.output.compression.type=BLOCK; +set mapreduce.output.fileoutputformat.compress=true; +set mapreduce.output.fileoutputformat.compress.type=BLOCK; CREATE TABLE dest4_sequencefile(key INT, value STRING) STORED AS SEQUENCEFILE; @@ -10,5 +10,5 @@ INSERT OVERWRITE TABLE dest4_sequencefile SELECT src.key, src.value; FROM src INSERT OVERWRITE TABLE dest4_sequencefile SELECT src.key, src.value; -set mapred.output.compress=false; +set mapreduce.output.fileoutputformat.compress=false; SELECT dest4_sequencefile.* FROM dest4_sequencefile; diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/join14_hadoop20.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/join14_hadoop20.q index a12ef1afb055..b3d75b63bd40 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/join14_hadoop20.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/join14_hadoop20.q @@ -2,7 +2,7 @@ CREATE TABLE dest1(c1 INT, c2 STRING) STORED AS TEXTFILE; -set mapred.job.tracker=localhost:58; +set mapreduce.jobtracker.address=localhost:58; set hive.exec.mode.local.auto=true; EXPLAIN diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/leftsemijoin_mr.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/leftsemijoin_mr.q index c9ebe0e8fad1..d98247b63d34 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/leftsemijoin_mr.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/leftsemijoin_mr.q @@ -9,7 +9,7 @@ SELECT * FROM T1; SELECT * FROM T2; set hive.auto.convert.join=false; -set mapred.reduce.tasks=2; +set mapreduce.job.reduces=2; set hive.join.emit.interval=100; diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/merge2.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/merge2.q index 8b77bd2fe19b..9189e7c0d1af 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/merge2.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/merge2.q @@ -1,9 +1,9 @@ set hive.merge.mapfiles=true; set hive.merge.mapredfiles=true; -set mapred.min.split.size=256; -set mapred.min.split.size.per.node=256; -set mapred.min.split.size.per.rack=256; -set mapred.max.split.size=256; +set mapreduce.input.fileinputformat.split.minsize=256; +set mapreduce.input.fileinputformat.split.minsize.per.node=256; +set mapreduce.input.fileinputformat.split.minsize.per.rack=256; +set mapreduce.input.fileinputformat.split.maxsize=256; create table test1(key int, val int); diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/orc_createas1.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/orc_createas1.q index 872692567b37..dcb2a853bae5 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/orc_createas1.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/orc_createas1.q @@ -1,5 +1,5 @@ -set mapred.max.split.size=100; -set mapred.min.split.size=1; +set mapreduce.input.fileinputformat.split.maxsize=100; +set mapreduce.input.fileinputformat.split.minsize=1; DROP TABLE orc_createas1a; DROP TABLE 
orc_createas1b; diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/orc_ppd_char.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/orc_ppd_char.q index 1f5f54ae19ee..93f8f519cf21 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/orc_ppd_char.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/orc_ppd_char.q @@ -1,6 +1,6 @@ SET hive.input.format=org.apache.hadoop.hive.ql.io.HiveInputFormat; -SET mapred.min.split.size=1000; -SET mapred.max.split.size=5000; +SET mapreduce.input.fileinputformat.split.minsize=1000; +SET mapreduce.input.fileinputformat.split.maxsize=5000; create table newtypesorc(c char(10), v varchar(10), d decimal(5,3), da date) stored as orc tblproperties("orc.stripe.size"="16777216"); diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/orc_ppd_date.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/orc_ppd_date.q index c34be867e484..3a74de82a472 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/orc_ppd_date.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/orc_ppd_date.q @@ -1,6 +1,6 @@ SET hive.input.format=org.apache.hadoop.hive.ql.io.HiveInputFormat; -SET mapred.min.split.size=1000; -SET mapred.max.split.size=5000; +SET mapreduce.input.fileinputformat.split.minsize=1000; +SET mapreduce.input.fileinputformat.split.maxsize=5000; create table newtypesorc(c char(10), v varchar(10), d decimal(5,3), da date) stored as orc tblproperties("orc.stripe.size"="16777216"); diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/orc_ppd_decimal.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/orc_ppd_decimal.q index a93590eacca0..82f68a9ae56b 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/orc_ppd_decimal.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/orc_ppd_decimal.q @@ -1,6 +1,6 @@ SET hive.input.format=org.apache.hadoop.hive.ql.io.HiveInputFormat; -SET mapred.min.split.size=1000; -SET mapred.max.split.size=5000; +SET mapreduce.input.fileinputformat.split.minsize=1000; +SET mapreduce.input.fileinputformat.split.maxsize=5000; create table newtypesorc(c char(10), v varchar(10), d decimal(5,3), da date) stored as orc tblproperties("orc.stripe.size"="16777216"); diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/orc_ppd_varchar.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/orc_ppd_varchar.q index 0fecc664e46d..99f58cd73f79 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/orc_ppd_varchar.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/orc_ppd_varchar.q @@ -1,6 +1,6 @@ SET hive.input.format=org.apache.hadoop.hive.ql.io.HiveInputFormat; -SET mapred.min.split.size=1000; -SET mapred.max.split.size=5000; +SET mapreduce.input.fileinputformat.split.minsize=1000; +SET mapreduce.input.fileinputformat.split.maxsize=5000; create table newtypesorc(c char(10), v varchar(10), d decimal(5,3), da date) stored as orc tblproperties("orc.stripe.size"="16777216"); diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/orc_split_elimination.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/orc_split_elimination.q index 54eb23e776b8..9aa868f9d2f0 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/orc_split_elimination.q +++ 
b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/orc_split_elimination.q @@ -3,8 +3,8 @@ create table orc_split_elim (userid bigint, string1 string, subtype double, deci load data local inpath '../../data/files/orc_split_elim.orc' into table orc_split_elim; SET hive.input.format=org.apache.hadoop.hive.ql.io.HiveInputFormat; -SET mapred.min.split.size=1000; -SET mapred.max.split.size=5000; +SET mapreduce.input.fileinputformat.split.minsize=1000; +SET mapreduce.input.fileinputformat.split.maxsize=5000; SET hive.optimize.index.filter=false; -- The above table will have 5 splits with the followings stats diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/parallel.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/parallel.q index 03edeaadeef5..3ac60306551e 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/parallel.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/parallel.q @@ -1,4 +1,4 @@ -set mapred.job.name='test_parallel'; +set mapreduce.job.name='test_parallel'; set hive.exec.parallel=true; set hive.input.format=org.apache.hadoop.hive.ql.io.HiveInputFormat; diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/parallel_orderby.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/parallel_orderby.q index 73c394064484..777771f22763 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/parallel_orderby.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/parallel_orderby.q @@ -2,7 +2,7 @@ create table src5 (key string, value string); load data local inpath '../../data/files/kv5.txt' into table src5; load data local inpath '../../data/files/kv5.txt' into table src5; -set mapred.reduce.tasks = 4; +set mapreduce.job.reduces = 4; set hive.optimize.sampling.orderby=true; set hive.optimize.sampling.orderby.percent=0.66f; diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/rcfile_createas1.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/rcfile_createas1.q index f36203724c15..14e13c56b1db 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/rcfile_createas1.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/rcfile_createas1.q @@ -1,6 +1,6 @@ set hive.merge.rcfile.block.level=true; -set mapred.max.split.size=100; -set mapred.min.split.size=1; +set mapreduce.input.fileinputformat.split.maxsize=100; +set mapreduce.input.fileinputformat.split.minsize=1; DROP TABLE rcfile_createas1a; DROP TABLE rcfile_createas1b; diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/rcfile_lazydecompress.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/rcfile_lazydecompress.q index 7f55d10bd645..43a15a06f870 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/rcfile_lazydecompress.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/rcfile_lazydecompress.q @@ -10,7 +10,7 @@ SELECT key, value FROM rcfileTableLazyDecompress where key > 238 and key < 400 O SELECT key, count(1) FROM rcfileTableLazyDecompress where key > 238 group by key ORDER BY key ASC; -set mapred.output.compress=true; +set mapreduce.output.fileoutputformat.compress=true; set hive.exec.compress.output=true; FROM src @@ -22,6 +22,6 @@ SELECT key, value FROM rcfileTableLazyDecompress where key > 238 and key < 400 O SELECT key, count(1) FROM rcfileTableLazyDecompress where key > 238 group by key ORDER BY key ASC; -set 
mapred.output.compress=false; +set mapreduce.output.fileoutputformat.compress=false; set hive.exec.compress.output=false; diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/rcfile_merge1.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/rcfile_merge1.q index 1f6f1bd251c2..25071579cb04 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/rcfile_merge1.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/rcfile_merge1.q @@ -1,6 +1,6 @@ set hive.merge.rcfile.block.level=false; set hive.exec.dynamic.partition=true; -set mapred.max.split.size=100; +set mapreduce.input.fileinputformat.split.maxsize=100; set mapref.min.split.size=1; DROP TABLE rcfile_merge1; diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/rcfile_merge2.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/rcfile_merge2.q index 215d5ebc4a25..15ffb90bf627 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/rcfile_merge2.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/rcfile_merge2.q @@ -1,7 +1,7 @@ set hive.merge.rcfile.block.level=true; set hive.exec.dynamic.partition=true; -set mapred.max.split.size=100; -set mapred.min.split.size=1; +set mapreduce.input.fileinputformat.split.maxsize=100; +set mapreduce.input.fileinputformat.split.minsize=1; DROP TABLE rcfile_merge2a; diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/rcfile_merge3.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/rcfile_merge3.q index 39fbd2564664..787ab4a8d7fa 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/rcfile_merge3.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/rcfile_merge3.q @@ -1,6 +1,6 @@ set hive.merge.rcfile.block.level=true; -set mapred.max.split.size=100; -set mapred.min.split.size=1; +set mapreduce.input.fileinputformat.split.maxsize=100; +set mapreduce.input.fileinputformat.split.minsize=1; DROP TABLE rcfile_merge3a; DROP TABLE rcfile_merge3b; diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/rcfile_merge4.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/rcfile_merge4.q index fe6df28566cf..77ac381c65bb 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/rcfile_merge4.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/rcfile_merge4.q @@ -1,6 +1,6 @@ set hive.merge.rcfile.block.level=true; -set mapred.max.split.size=100; -set mapred.min.split.size=1; +set mapreduce.input.fileinputformat.split.maxsize=100; +set mapreduce.input.fileinputformat.split.minsize=1; DROP TABLE rcfile_merge3a; DROP TABLE rcfile_merge3b; diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/sample_islocalmode_hook.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/sample_islocalmode_hook.q index 12f2bcd46ec8..bf12ba5ed8e6 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/sample_islocalmode_hook.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/sample_islocalmode_hook.q @@ -1,8 +1,8 @@ set hive.input.format=org.apache.hadoop.hive.ql.io.CombineHiveInputFormat; -set mapred.max.split.size=300; -set mapred.min.split.size=300; -set mapred.min.split.size.per.node=300; -set mapred.min.split.size.per.rack=300; +set mapreduce.input.fileinputformat.split.maxsize=300; +set mapreduce.input.fileinputformat.split.minsize=300; +set 
mapreduce.input.fileinputformat.split.minsize.per.node=300; +set mapreduce.input.fileinputformat.split.minsize.per.rack=300; set hive.exec.mode.local.auto=true; set hive.merge.smallfiles.avgsize=1; diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/sample_islocalmode_hook_hadoop20.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/sample_islocalmode_hook_hadoop20.q index 484e1fa617d8..5d1bd184d2ad 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/sample_islocalmode_hook_hadoop20.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/sample_islocalmode_hook_hadoop20.q @@ -1,15 +1,15 @@ USE default; set hive.input.format=org.apache.hadoop.hive.ql.io.CombineHiveInputFormat; -set mapred.max.split.size=300; -set mapred.min.split.size=300; -set mapred.min.split.size.per.node=300; -set mapred.min.split.size.per.rack=300; +set mapreduce.input.fileinputformat.split.maxsize=300; +set mapreduce.input.fileinputformat.split.minsize=300; +set mapreduce.input.fileinputformat.split.minsize.per.node=300; +set mapreduce.input.fileinputformat.split.minsize.per.rack=300; set hive.exec.mode.local.auto=true; set hive.merge.smallfiles.avgsize=1; -- INCLUDE_HADOOP_MAJOR_VERSIONS(0.20, 0.20S) --- This test sets mapred.max.split.size=300 and hive.merge.smallfiles.avgsize=1 +-- This test sets mapreduce.input.fileinputformat.split.maxsize=300 and hive.merge.smallfiles.avgsize=1 -- in an attempt to force the generation of multiple splits and multiple output files. -- However, Hadoop 0.20 is incapable of generating splits smaller than the block size -- when using CombineFileInputFormat, so only one split is generated. This has a @@ -25,7 +25,7 @@ create table sih_src as select key, value from sih_i_part order by key, value; create table sih_src2 as select key, value from sih_src order by key, value; set hive.exec.post.hooks = org.apache.hadoop.hive.ql.hooks.VerifyIsLocalModeHook ; -set mapred.job.tracker=localhost:58; +set mapreduce.jobtracker.address=localhost:58; set hive.exec.mode.local.auto.input.files.max=1; -- Sample split, running locally limited by num tasks diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/split_sample.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/split_sample.q index 952eaf72f10c..eb774f15829b 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/split_sample.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/split_sample.q @@ -1,14 +1,14 @@ USE default; set hive.input.format=org.apache.hadoop.hive.ql.io.CombineHiveInputFormat; -set mapred.max.split.size=300; -set mapred.min.split.size=300; -set mapred.min.split.size.per.node=300; -set mapred.min.split.size.per.rack=300; +set mapreduce.input.fileinputformat.split.maxsize=300; +set mapreduce.input.fileinputformat.split.minsize=300; +set mapreduce.input.fileinputformat.split.minsize.per.node=300; +set mapreduce.input.fileinputformat.split.minsize.per.rack=300; set hive.merge.smallfiles.avgsize=1; -- INCLUDE_HADOOP_MAJOR_VERSIONS(0.20) --- This test sets mapred.max.split.size=300 and hive.merge.smallfiles.avgsize=1 +-- This test sets mapreduce.input.fileinputformat.split.maxsize=300 and hive.merge.smallfiles.avgsize=1 -- in an attempt to force the generation of multiple splits and multiple output files. -- However, Hadoop 0.20 is incapable of generating splits smaller than the block size -- when using CombineFileInputFormat, so only one split is generated. 
This has a @@ -72,10 +72,10 @@ select t1.key as k1, t2.key as k from ss_src1 tablesample(80 percent) t1 full ou -- shrink last split explain select count(1) from ss_src2 tablesample(1 percent); -set mapred.max.split.size=300000; -set mapred.min.split.size=300000; -set mapred.min.split.size.per.node=300000; -set mapred.min.split.size.per.rack=300000; +set mapreduce.input.fileinputformat.split.maxsize=300000; +set mapreduce.input.fileinputformat.split.minsize=300000; +set mapreduce.input.fileinputformat.split.minsize.per.node=300000; +set mapreduce.input.fileinputformat.split.minsize.per.rack=300000; select count(1) from ss_src2 tablesample(1 percent); select count(1) from ss_src2 tablesample(50 percent); diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/stats_partscan_1.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/stats_partscan_1.q index cdf92e44cf67..caf359c9e6b4 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/stats_partscan_1.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/stats_partscan_1.q @@ -2,13 +2,13 @@ set datanucleus.cache.collections=false; set hive.stats.autogather=false; set hive.exec.dynamic.partition=true; set hive.exec.dynamic.partition.mode=nonstrict; -set mapred.min.split.size=256; -set mapred.min.split.size.per.node=256; -set mapred.min.split.size.per.rack=256; -set mapred.max.split.size=256; +set mapreduce.input.fileinputformat.split.minsize=256; +set mapreduce.input.fileinputformat.split.minsize.per.node=256; +set mapreduce.input.fileinputformat.split.minsize.per.rack=256; +set mapreduce.input.fileinputformat.split.maxsize=256; -- INCLUDE_HADOOP_MAJOR_VERSIONS(0.20,0.20S) --- This test uses mapred.max.split.size/mapred.max.split.size for controlling +-- This test uses mapreduce.input.fileinputformat.split.maxsize/mapred.max.split.size for controlling -- number of input splits, which is not effective in hive 0.20. -- stats_partscan_1_23.q is the same test with this but has different result. diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/stats_partscan_1_23.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/stats_partscan_1_23.q index 1e5f360b20cb..07694891fd6f 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/stats_partscan_1_23.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/stats_partscan_1_23.q @@ -2,13 +2,13 @@ set datanucleus.cache.collections=false; set hive.stats.autogather=false; set hive.exec.dynamic.partition=true; set hive.exec.dynamic.partition.mode=nonstrict; -set mapred.min.split.size=256; -set mapred.min.split.size.per.node=256; -set mapred.min.split.size.per.rack=256; -set mapred.max.split.size=256; +set mapreduce.input.fileinputformat.split.minsize=256; +set mapreduce.input.fileinputformat.split.minsize.per.node=256; +set mapreduce.input.fileinputformat.split.minsize.per.rack=256; +set mapreduce.input.fileinputformat.split.maxsize=256; -- INCLUDE_HADOOP_MAJOR_VERSIONS(0.23) --- This test uses mapred.max.split.size/mapred.max.split.size for controlling +-- This test uses mapreduce.input.fileinputformat.split.maxsize/mapred.max.split.size for controlling -- number of input splits. -- stats_partscan_1.q is the same test with this but has different result. 
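The .q hunks above are a mechanical rename of deprecated Hadoop 1.x properties (mapred.*, fs.default.name) to their Hadoop 2.x replacements; the test semantics are unchanged. As a quick reference, the following is a minimal, illustrative Scala sketch of the old-to-new mapping collected from these hunks. The helper object and its rewrite method are hypothetical and not part of this patch, and the map is not an exhaustive list of Hadoop's deprecated keys, only the ones touched here.

    // Hypothetical helper: maps the deprecated property names seen in these test
    // files to the replacement names used by this patch.
    object DeprecatedMapredKeys {
      val renames: Map[String, String] = Map(
        "fs.default.name" -> "fs.defaultFS",
        "mapred.job.tracker" -> "mapreduce.jobtracker.address",
        "mapred.job.name" -> "mapreduce.job.name",
        "mapred.reduce.tasks" -> "mapreduce.job.reduces",
        "mapred.min.split.size" -> "mapreduce.input.fileinputformat.split.minsize",
        "mapred.min.split.size.per.node" -> "mapreduce.input.fileinputformat.split.minsize.per.node",
        "mapred.min.split.size.per.rack" -> "mapreduce.input.fileinputformat.split.minsize.per.rack",
        "mapred.max.split.size" -> "mapreduce.input.fileinputformat.split.maxsize",
        "mapred.output.compress" -> "mapreduce.output.fileoutputformat.compress",
        "mapred.output.compression.type" -> "mapreduce.output.fileoutputformat.compress.type",
        "mapred.output.compression.codec" -> "mapreduce.output.fileoutputformat.compress.codec")

      // Rewrites a Hive `set key=value;` line to the new property name, if any applies.
      // Longer keys are replaced first so that e.g. "mapred.output.compression.codec"
      // is not clobbered by the shorter "mapred.output.compress" rule.
      def rewrite(line: String): String =
        renames.toSeq.sortBy { case (oldKey, _) => -oldKey.length }
          .foldLeft(line) { case (acc, (oldKey, newKey)) => acc.replace(oldKey, newKey) }
    }

    // Example: DeprecatedMapredKeys.rewrite("set mapred.reduce.tasks=31;")
    // returns "set mapreduce.job.reduces=31;".
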
diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/udaf_context_ngrams.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/udaf_context_ngrams.q index f065385688a1..5b5d669a7c12 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/udaf_context_ngrams.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/udaf_context_ngrams.q @@ -1,6 +1,6 @@ CREATE TABLE kafka (contents STRING); LOAD DATA LOCAL INPATH '../../data/files/text-en.txt' INTO TABLE kafka; -set mapred.reduce.tasks=1; +set mapreduce.job.reduces=1; set hive.exec.reducers.max=1; SELECT context_ngrams(sentences(lower(contents)), array(null), 100, 1000).estfrequency FROM kafka; diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/udaf_ngrams.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/udaf_ngrams.q index 6a2fde52e42f..39e6e30ae694 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/udaf_ngrams.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/udaf_ngrams.q @@ -1,6 +1,6 @@ CREATE TABLE kafka (contents STRING); LOAD DATA LOCAL INPATH '../../data/files/text-en.txt' INTO TABLE kafka; -set mapred.reduce.tasks=1; +set mapreduce.job.reduces=1; set hive.exec.reducers.max=1; SELECT ngrams(sentences(lower(contents)), 1, 100, 1000).estfrequency FROM kafka; diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala index 3871b3d78588..8ccc2b7527f2 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala @@ -204,13 +204,8 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with TestHiveSingleto assertCached(table("refreshTable")) // Append new data. table("src").write.mode(SaveMode.Append).parquet(tempPath.toString) - // We are still using the old data. assertCached(table("refreshTable")) - checkAnswer( - table("refreshTable"), - table("src").collect()) - // Refresh the table. - sql("REFRESH TABLE refreshTable") + // We are using the new data. assertCached(table("refreshTable")) checkAnswer( @@ -249,13 +244,8 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with TestHiveSingleto assertCached(table("refreshTable")) // Append new data. table("src").write.mode(SaveMode.Append).parquet(tempPath.toString) - // We are still using the old data. assertCached(table("refreshTable")) - checkAnswer( - table("refreshTable"), - table("src").collect()) - // Refresh the table. - sql(s"REFRESH ${tempPath.toString}") + // We are using the new data. 
assertCached(table("refreshTable")) checkAnswer( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala index a60c210b04c8..4349f1aa23be 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala @@ -52,7 +52,7 @@ class HiveExternalCatalogSuite extends ExternalCatalogSuite { test("list partitions by filter") { val catalog = newBasicCatalog() - val selectedPartitions = catalog.listPartitionsByFilter("db2", "tbl2", Seq('a.int === 1)) + val selectedPartitions = catalog.listPartitionsByFilter("db2", "tbl2", Seq('a.int === 1), "GMT") assert(selectedPartitions.length == 1) assert(selectedPartitions.head.spec == part1.spec) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala index 71ce5a7c4a15..d6999af84eac 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala @@ -284,19 +284,6 @@ class InsertIntoHiveTableSuite extends QueryTest with TestHiveSingleton with Bef sql("DROP TABLE hiveTableWithStructValue") } - test("Reject partitioning that does not match table") { - withSQLConf(("hive.exec.dynamic.partition.mode", "nonstrict")) { - sql("CREATE TABLE partitioned (id bigint, data string) PARTITIONED BY (part string)") - val data = (1 to 10).map(i => (i, s"data-$i", if ((i % 2) == 0) "even" else "odd")) - .toDF("id", "data", "part") - - intercept[AnalysisException] { - // cannot partition by 2 fields when there is only one in the table definition - data.write.partitionBy("part", "data").insertInto("partitioned") - } - } - } - test("Test partition mode = strict") { withSQLConf(("hive.exec.dynamic.partition.mode", "strict")) { sql("CREATE TABLE partitioned (id bigint, data string) PARTITIONED BY (part string)") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index e951bbe1dcbf..03ea0c8c7768 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -511,9 +511,9 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv test("create external table") { withTempPath { tempPath => withTable("savedJsonTable", "createdJsonTable") { - val df = read.json(sparkContext.parallelize((1 to 10).map { i => + val df = read.json((1 to 10).map { i => s"""{ "a": $i, "b": "str$i" }""" - })) + }.toDS()) withSQLConf(SQLConf.DEFAULT_DATA_SOURCE_NAME.key -> "not a source name") { df.write diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreRelationSuite.scala deleted file mode 100644 index 91ff711445e8..000000000000 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreRelationSuite.scala +++ /dev/null @@ -1,55 +0,0 @@ -/* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. 
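The MetastoreDataSourcesSuite hunk above replaces read.json over an RDD[String] with the Dataset[String] overload. A self-contained, spark-shell-style sketch of that pattern:

    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder().getOrCreate()
    import spark.implicits._

    // Build the JSON records as a Dataset[String] instead of an RDD[String];
    // DataFrameReader.json accepts the Dataset overload directly.
    val jsonDS = (1 to 10).map(i => s"""{ "a": $i, "b": "str$i" }""").toDS()
    val df = spark.read.json(jsonDS)
    df.printSchema()
    df.show()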
-* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ - -package org.apache.spark.sql.hive - -import org.apache.spark.sql.{QueryTest, Row} -import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable, CatalogTableType} -import org.apache.spark.sql.hive.test.TestHiveSingleton -import org.apache.spark.sql.test.SQLTestUtils -import org.apache.spark.sql.types.{IntegerType, StructField, StructType} - -class MetastoreRelationSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { - test("makeCopy and toJSON should work") { - val table = CatalogTable( - identifier = TableIdentifier("test", Some("db")), - tableType = CatalogTableType.VIEW, - storage = CatalogStorageFormat.empty, - schema = StructType(StructField("a", IntegerType, true) :: Nil)) - val relation = MetastoreRelation("db", "test")(table, null) - - // No exception should be thrown - relation.makeCopy(Array("db", "test")) - // No exception should be thrown - relation.toJSON - } - - test("SPARK-17409: Do Not Optimize Query in CTAS (Hive Serde Table) More Than Once") { - withTable("bar") { - withTempView("foo") { - sql("select 0 as id").createOrReplaceTempView("foo") - // If we optimize the query in CTAS more than once, the following saveAsTable will fail - // with the error: `GROUP BY position 0 is not in select list (valid range is [1, 1])` - sql("CREATE TABLE bar AS SELECT * FROM foo group by id") - checkAnswer(spark.table("bar"), Row(0) :: Nil) - val tableMetadata = spark.sessionState.catalog.getTableMetadata(TableIdentifier("bar")) - assert(tableMetadata.provider == Some("hive"), "the expected table is a Hive serde table") - } - } - } -} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionedTablePerfStatsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionedTablePerfStatsSuite.scala index b792a168a4f9..50506197b313 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionedTablePerfStatsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionedTablePerfStatsSuite.scala @@ -411,4 +411,15 @@ class PartitionedTablePerfStatsSuite } } } + + test("resolveRelation for a FileFormat DataSource without userSchema scan filesystem only once") { + withTempDir { dir => + import spark.implicits._ + Seq(1).toDF("a").write.mode("overwrite").save(dir.getAbsolutePath) + HiveCatalogMetrics.reset() + spark.read.parquet(dir.getAbsolutePath) + assert(HiveCatalogMetrics.METRIC_FILES_DISCOVERED.getCount() == 1) + assert(HiveCatalogMetrics.METRIC_FILE_CACHE_HITS.getCount() == 1) + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index e2fcd2fd41fa..962998ea6fb6 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -23,7 +23,7 @@ import scala.reflect.ClassTag 
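The new PartitionedTablePerfStatsSuite test above checks that resolving a FileFormat relation without a user-supplied schema lists the filesystem only once, via HiveCatalogMetrics counters. The sketch below mirrors that check; HiveCatalogMetrics is a Spark-internal object (assumed here to live at org.apache.spark.metrics.source.HiveCatalogMetrics), so this only compiles against Spark's own source tree, and the /tmp path is hypothetical:

    import org.apache.spark.metrics.source.HiveCatalogMetrics
    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder().enableHiveSupport().getOrCreate()
    import spark.implicits._

    val path = "/tmp/file_index_demo"       // hypothetical scratch directory
    Seq(1).toDF("a").write.mode("overwrite").save(path)

    HiveCatalogMetrics.reset()
    spark.read.parquet(path)                // resolveRelation should list files once
    assert(HiveCatalogMetrics.METRIC_FILES_DISCOVERED.getCount() == 1)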
import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.catalog.CatalogStatistics +import org.apache.spark.sql.catalyst.catalog.{CatalogRelation, CatalogStatistics} import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.execution.joins._ @@ -33,52 +33,46 @@ import org.apache.spark.sql.types._ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleton { - test("MetastoreRelations fallback to HDFS for size estimation") { - val enableFallBackToHdfsForStats = spark.sessionState.conf.fallBackToHdfsForStatsEnabled - try { - withTempDir { tempDir => - - // EXTERNAL OpenCSVSerde table pointing to LOCATION - - val file1 = new File(tempDir + "/data1") - val writer1 = new PrintWriter(file1) - writer1.write("1,2") - writer1.close() - - val file2 = new File(tempDir + "/data2") - val writer2 = new PrintWriter(file2) - writer2.write("1,2") - writer2.close() - - sql( - s"""CREATE EXTERNAL TABLE csv_table(page_id INT, impressions INT) - ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.OpenCSVSerde' - WITH SERDEPROPERTIES ( - \"separatorChar\" = \",\", - \"quoteChar\" = \"\\\"\", - \"escapeChar\" = \"\\\\\") - LOCATION '${tempDir.toURI}' - """) - - spark.conf.set(SQLConf.ENABLE_FALL_BACK_TO_HDFS_FOR_STATS.key, true) - - val relation = spark.table("csv_table").queryExecution.analyzed.children.head - .asInstanceOf[MetastoreRelation] - - val properties = relation.hiveQlTable.getParameters - assert(properties.get("totalSize").toLong <= 0, "external table totalSize must be <= 0") - assert(properties.get("rawDataSize").toLong <= 0, "external table rawDataSize must be <= 0") - - val sizeInBytes = relation.stats(conf).sizeInBytes - assert(sizeInBytes === BigInt(file1.length() + file2.length())) + test("Hive serde tables should fallback to HDFS for size estimation") { + withSQLConf(SQLConf.ENABLE_FALL_BACK_TO_HDFS_FOR_STATS.key -> "true") { + withTable("csv_table") { + withTempDir { tempDir => + // EXTERNAL OpenCSVSerde table pointing to LOCATION + val file1 = new File(tempDir + "/data1") + val writer1 = new PrintWriter(file1) + writer1.write("1,2") + writer1.close() + + val file2 = new File(tempDir + "/data2") + val writer2 = new PrintWriter(file2) + writer2.write("1,2") + writer2.close() + + sql( + s""" + |CREATE EXTERNAL TABLE csv_table(page_id INT, impressions INT) + |ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.OpenCSVSerde' + |WITH SERDEPROPERTIES ( + |\"separatorChar\" = \",\", + |\"quoteChar\" = \"\\\"\", + |\"escapeChar\" = \"\\\\\") + |LOCATION '${tempDir.toURI}'""".stripMargin) + + val relation = spark.table("csv_table").queryExecution.analyzed.children.head + .asInstanceOf[CatalogRelation] + + val properties = relation.tableMeta.properties + assert(properties("totalSize").toLong <= 0, "external table totalSize must be <= 0") + assert(properties("rawDataSize").toLong <= 0, "external table rawDataSize must be <= 0") + + val sizeInBytes = relation.stats(conf).sizeInBytes + assert(sizeInBytes === BigInt(file1.length() + file2.length())) + } } - } finally { - spark.conf.set(SQLConf.ENABLE_FALL_BACK_TO_HDFS_FOR_STATS.key, enableFallBackToHdfsForStats) - sql("DROP TABLE csv_table ") } } - test("analyze MetastoreRelations") { + test("analyze Hive serde tables") { def queryTotalSize(tableName: String): BigInt = spark.table(tableName).queryExecution.analyzed.stats(conf).sizeInBytes @@ -152,9 +146,11 @@ class StatisticsSuite extends 
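The rewritten StatisticsSuite test above enables SQLConf.ENABLE_FALL_BACK_TO_HDFS_FOR_STATS so that an external Hive serde table whose metastore totalSize is missing or non-positive gets its size estimated from the files on disk. A sketch of toggling the flag and reading the estimate from the analyzed plan; the string key is assumed to be spark.sql.statistics.fallBackToHdfs, and CatalogRelation plus sessionState are Spark-internal, as in the suite:

    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.catalyst.catalog.CatalogRelation

    val spark = SparkSession.builder().enableHiveSupport().getOrCreate()
    spark.conf.set("spark.sql.statistics.fallBackToHdfs", "true")

    val plan = spark.table("csv_table").queryExecution.analyzed
    val sizes = plan.collect {
      case rel: CatalogRelation => rel.stats(spark.sessionState.conf).sizeInBytes
    }
    println(s"estimated sizeInBytes = ${sizes.headOption.getOrElse(BigInt(0))}")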
StatisticsCollectionTestBase with TestHiveSingleto } private def checkTableStats( - stats: Option[CatalogStatistics], + tableName: String, hasSizeInBytes: Boolean, - expectedRowCounts: Option[Int]): Unit = { + expectedRowCounts: Option[Int]): Option[CatalogStatistics] = { + val stats = spark.sessionState.catalog.getTableMetadata(TableIdentifier(tableName)).stats + if (hasSizeInBytes || expectedRowCounts.nonEmpty) { assert(stats.isDefined) assert(stats.get.sizeInBytes > 0) @@ -162,26 +158,8 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto } else { assert(stats.isEmpty) } - } - private def checkTableStats( - tableName: String, - isDataSourceTable: Boolean, - hasSizeInBytes: Boolean, - expectedRowCounts: Option[Int]): Option[CatalogStatistics] = { - val df = sql(s"SELECT * FROM $tableName") - val stats = df.queryExecution.analyzed.collect { - case rel: MetastoreRelation => - checkTableStats(rel.catalogTable.stats, hasSizeInBytes, expectedRowCounts) - assert(!isDataSourceTable, "Expected a Hive serde table, but got a data source table") - rel.catalogTable.stats - case rel: LogicalRelation => - checkTableStats(rel.catalogTable.get.stats, hasSizeInBytes, expectedRowCounts) - assert(isDataSourceTable, "Expected a data source table, but got a Hive serde table") - rel.catalogTable.get.stats - } - assert(stats.size == 1) - stats.head + stats } test("test table-level statistics for hive tables created in HiveExternalCatalog") { @@ -192,25 +170,23 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto sql(s"CREATE TABLE $textTable (key STRING, value STRING) STORED AS TEXTFILE") checkTableStats( textTable, - isDataSourceTable = false, hasSizeInBytes = false, expectedRowCounts = None) sql(s"INSERT INTO TABLE $textTable SELECT * FROM src") checkTableStats( textTable, - isDataSourceTable = false, hasSizeInBytes = false, expectedRowCounts = None) // noscan won't count the number of rows sql(s"ANALYZE TABLE $textTable COMPUTE STATISTICS noscan") - val fetchedStats1 = checkTableStats( - textTable, isDataSourceTable = false, hasSizeInBytes = true, expectedRowCounts = None) + val fetchedStats1 = + checkTableStats(textTable, hasSizeInBytes = true, expectedRowCounts = None) // without noscan, we count the number of rows sql(s"ANALYZE TABLE $textTable COMPUTE STATISTICS") - val fetchedStats2 = checkTableStats( - textTable, isDataSourceTable = false, hasSizeInBytes = true, expectedRowCounts = Some(500)) + val fetchedStats2 = + checkTableStats(textTable, hasSizeInBytes = true, expectedRowCounts = Some(500)) assert(fetchedStats1.get.sizeInBytes == fetchedStats2.get.sizeInBytes) } } @@ -221,25 +197,25 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto sql(s"CREATE TABLE $textTable (key STRING, value STRING) STORED AS TEXTFILE") sql(s"INSERT INTO TABLE $textTable SELECT * FROM src") sql(s"ANALYZE TABLE $textTable COMPUTE STATISTICS") - val fetchedStats1 = checkTableStats( - textTable, isDataSourceTable = false, hasSizeInBytes = true, expectedRowCounts = Some(500)) + val fetchedStats1 = + checkTableStats(textTable, hasSizeInBytes = true, expectedRowCounts = Some(500)) sql(s"ANALYZE TABLE $textTable COMPUTE STATISTICS noscan") // when the total size is not changed, the old row count is kept - val fetchedStats2 = checkTableStats( - textTable, isDataSourceTable = false, hasSizeInBytes = true, expectedRowCounts = Some(500)) + val fetchedStats2 = + checkTableStats(textTable, hasSizeInBytes = true, expectedRowCounts = Some(500)) 
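The refactored checkTableStats helper above now reads statistics from the catalog's table metadata rather than collecting them from the query plan. A spark-shell-style sketch of the same pattern around ANALYZE TABLE; getTableMetadata and CatalogStatistics are Spark-internal, as in the suite:

    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.catalyst.TableIdentifier

    val spark = SparkSession.builder().enableHiveSupport().getOrCreate()

    spark.sql("CREATE TABLE IF NOT EXISTS textTable (key STRING, value STRING) STORED AS TEXTFILE")
    spark.sql("ANALYZE TABLE textTable COMPUTE STATISTICS noscan")   // size only
    spark.sql("ANALYZE TABLE textTable COMPUTE STATISTICS")          // size and row count

    val stats = spark.sessionState.catalog
      .getTableMetadata(TableIdentifier("textTable")).stats
    stats.foreach(s => println(s"sizeInBytes=${s.sizeInBytes}, rowCount=${s.rowCount}"))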
assert(fetchedStats1 == fetchedStats2) sql(s"INSERT INTO TABLE $textTable SELECT * FROM src") sql(s"ANALYZE TABLE $textTable COMPUTE STATISTICS noscan") // update total size and remove the old and invalid row count - val fetchedStats3 = checkTableStats( - textTable, isDataSourceTable = false, hasSizeInBytes = true, expectedRowCounts = None) + val fetchedStats3 = + checkTableStats(textTable, hasSizeInBytes = true, expectedRowCounts = None) assert(fetchedStats3.get.sizeInBytes > fetchedStats2.get.sizeInBytes) } } - test("test statistics of LogicalRelation converted from MetastoreRelation") { + test("test statistics of LogicalRelation converted from Hive serde tables") { val parquetTable = "parquetTable" val orcTable = "orcTable" withTable(parquetTable, orcTable) { @@ -251,21 +227,14 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto // the default value for `spark.sql.hive.convertMetastoreParquet` is true, here we just set it // for robustness withSQLConf("spark.sql.hive.convertMetastoreParquet" -> "true") { - checkTableStats( - parquetTable, isDataSourceTable = true, hasSizeInBytes = false, expectedRowCounts = None) + checkTableStats(parquetTable, hasSizeInBytes = false, expectedRowCounts = None) sql(s"ANALYZE TABLE $parquetTable COMPUTE STATISTICS") - checkTableStats( - parquetTable, - isDataSourceTable = true, - hasSizeInBytes = true, - expectedRowCounts = Some(500)) + checkTableStats(parquetTable, hasSizeInBytes = true, expectedRowCounts = Some(500)) } withSQLConf("spark.sql.hive.convertMetastoreOrc" -> "true") { - checkTableStats( - orcTable, isDataSourceTable = true, hasSizeInBytes = false, expectedRowCounts = None) + checkTableStats(orcTable, hasSizeInBytes = false, expectedRowCounts = None) sql(s"ANALYZE TABLE $orcTable COMPUTE STATISTICS") - checkTableStats( - orcTable, isDataSourceTable = true, hasSizeInBytes = true, expectedRowCounts = Some(500)) + checkTableStats(orcTable, hasSizeInBytes = true, expectedRowCounts = Some(500)) } } } @@ -385,27 +354,23 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto // Add a filter to avoid creating too many partitions sql(s"INSERT INTO TABLE $parquetTable SELECT * FROM src WHERE key < 10") - checkTableStats( - parquetTable, isDataSourceTable = true, hasSizeInBytes = false, expectedRowCounts = None) + checkTableStats(parquetTable, hasSizeInBytes = false, expectedRowCounts = None) // noscan won't count the number of rows sql(s"ANALYZE TABLE $parquetTable COMPUTE STATISTICS noscan") - val fetchedStats1 = checkTableStats( - parquetTable, isDataSourceTable = true, hasSizeInBytes = true, expectedRowCounts = None) + val fetchedStats1 = + checkTableStats(parquetTable, hasSizeInBytes = true, expectedRowCounts = None) sql(s"INSERT INTO TABLE $parquetTable SELECT * FROM src WHERE key < 10") sql(s"ANALYZE TABLE $parquetTable COMPUTE STATISTICS noscan") - val fetchedStats2 = checkTableStats( - parquetTable, isDataSourceTable = true, hasSizeInBytes = true, expectedRowCounts = None) + val fetchedStats2 = + checkTableStats(parquetTable, hasSizeInBytes = true, expectedRowCounts = None) assert(fetchedStats2.get.sizeInBytes > fetchedStats1.get.sizeInBytes) // without noscan, we count the number of rows sql(s"ANALYZE TABLE $parquetTable COMPUTE STATISTICS") - val fetchedStats3 = checkTableStats( - parquetTable, - isDataSourceTable = true, - hasSizeInBytes = true, - expectedRowCounts = Some(20)) + val fetchedStats3 = + checkTableStats(parquetTable, hasSizeInBytes = true, expectedRowCounts = Some(20)) 
assert(fetchedStats3.get.sizeInBytes == fetchedStats2.get.sizeInBytes) } } @@ -426,11 +391,7 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto val dfNoCols = spark.createDataFrame(rddNoCols, StructType(Seq.empty)) dfNoCols.write.format("json").saveAsTable(table_no_cols) sql(s"ANALYZE TABLE $table_no_cols COMPUTE STATISTICS") - checkTableStats( - table_no_cols, - isDataSourceTable = true, - hasSizeInBytes = true, - expectedRowCounts = Some(10)) + checkTableStats(table_no_cols, hasSizeInBytes = true, expectedRowCounts = Some(10)) } } @@ -478,10 +439,10 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto assert(statsAfterUpdate.rowCount == Some(2)) } - test("estimates the size of a test MetastoreRelation") { + test("estimates the size of a test Hive serde tables") { val df = sql("""SELECT * FROM src""") - val sizes = df.queryExecution.analyzed.collect { case mr: MetastoreRelation => - mr.stats(conf).sizeInBytes + val sizes = df.queryExecution.analyzed.collect { + case relation: CatalogRelation => relation.stats(conf).sizeInBytes } assert(sizes.size === 1, s"Size wrong for:\n ${df.queryExecution}") assert(sizes(0).equals(BigInt(5812)), @@ -533,7 +494,7 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto after() } - /** Tests for MetastoreRelation */ + /** Tests for Hive serde tables */ val metastoreQuery = """SELECT * FROM src a JOIN src b ON a.key = 238 AND a.key = b.key""" val metastoreAnswer = Seq.fill(4)(Row(238, "val_238", 238, "val_238")) mkTest( @@ -541,7 +502,7 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto () => (), metastoreQuery, metastoreAnswer, - implicitly[ClassTag[MetastoreRelation]] + implicitly[ClassTag[CatalogRelation]] ) } @@ -555,9 +516,7 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto // Assert src has a size smaller than the threshold. 
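The surrounding tests rely on the estimated sizeInBytes being compared against spark.sql.autoBroadcastJoinThreshold when the planner decides whether to broadcast one side of a join. A sketch of observing that decision, assuming the Hive test table src exists as in the suite; BroadcastHashJoinExec is an execution-internal class:

    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec

    val spark = SparkSession.builder().enableHiveSupport().getOrCreate()
    spark.conf.set("spark.sql.autoBroadcastJoinThreshold", 10L * 1024 * 1024)

    val df = spark.sql("SELECT * FROM src a JOIN src b ON a.key = 238 AND a.key = b.key")
    val usesBroadcast = df.queryExecution.sparkPlan.collect {
      case j: BroadcastHashJoinExec => j
    }.nonEmpty
    println(s"broadcast hash join chosen: $usesBroadcast")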
val sizes = df.queryExecution.analyzed.collect { - case r if implicitly[ClassTag[MetastoreRelation]].runtimeClass - .isAssignableFrom(r.getClass) => - r.stats(conf).sizeInBytes + case relation: CatalogRelation => relation.stats(conf).sizeInBytes } assert(sizes.size === 2 && sizes(1) <= spark.sessionState.conf.autoBroadcastJoinThreshold && sizes(0) <= spark.sessionState.conf.autoBroadcastJoinThreshold, diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientBuilder.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientBuilder.scala index 591a968c8284..e85ea5a59427 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientBuilder.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientBuilder.scala @@ -35,22 +35,26 @@ private[client] class HiveClientBuilder { Some(new File(sys.props("java.io.tmpdir"), "hive-ivy-cache").getAbsolutePath)) } - private def buildConf() = { + private def buildConf(extraConf: Map[String, String]) = { lazy val warehousePath = Utils.createTempDir() lazy val metastorePath = Utils.createTempDir() metastorePath.delete() - Map( + extraConf ++ Map( "javax.jdo.option.ConnectionURL" -> s"jdbc:derby:;databaseName=$metastorePath;create=true", "hive.metastore.warehouse.dir" -> warehousePath.toString) } - def buildClient(version: String, hadoopConf: Configuration): HiveClient = { + // for testing only + def buildClient( + version: String, + hadoopConf: Configuration, + extraConf: Map[String, String] = Map.empty): HiveClient = { IsolatedClientLoader.forVersion( hiveMetastoreVersion = version, hadoopVersion = VersionInfo.getVersion, sparkConf = sparkConf, hadoopConf = hadoopConf, - config = buildConf(), + config = buildConf(extraConf), ivyPath = ivyPath).createClient() } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala index fe14824cf096..d61d10bf869e 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala @@ -88,7 +88,7 @@ class VersionsSuite extends QueryTest with SQLTestUtils with TestHiveSingleton w assert(getNestedMessages(e) contains "Unknown column 'A0.OWNER_NAME' in 'field list'") } - private val versions = Seq("0.12", "0.13", "0.14", "1.0", "1.1", "1.2") + private val versions = Seq("0.12", "0.13", "0.14", "1.0", "1.1", "1.2", "2.0") private var client: HiveClient = null @@ -98,7 +98,12 @@ class VersionsSuite extends QueryTest with SQLTestUtils with TestHiveSingleton w System.gc() // Hack to avoid SEGV on some JVM versions. 
val hadoopConf = new Configuration() hadoopConf.set("test", "success") - client = buildClient(version, hadoopConf) + // Hive changed the default of datanucleus.schema.autoCreateAll from true to false since 2.0 + // For details, see the JIRA HIVE-6113 + if (version == "2.0") { + hadoopConf.set("datanucleus.schema.autoCreateAll", "true") + } + client = buildClient(version, hadoopConf, HiveUtils.hiveClientConfigurations(hadoopConf)) } def table(database: String, tableName: String): CatalogTable = { @@ -175,7 +180,6 @@ class VersionsSuite extends QueryTest with SQLTestUtils with TestHiveSingleton w emptyDir, tableName = "src", replace = false, - holdDDLTime = false, isSrcLocal = false) } @@ -313,7 +317,6 @@ class VersionsSuite extends QueryTest with SQLTestUtils with TestHiveSingleton w "src_part", partSpec, replace = false, - holdDDLTime = false, inheritTableSpecs = false, isSrcLocal = false) } @@ -329,8 +332,7 @@ class VersionsSuite extends QueryTest with SQLTestUtils with TestHiveSingleton w "src_part", partSpec, replace = false, - numDP = 1, - holdDDLTime = false) + numDP = 1) } test(s"$version: renamePartitions") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala index f3151d52f20a..536ca8fd9d45 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala @@ -385,7 +385,7 @@ abstract class HiveComparisonTest // also print out the query plans and results for those. val computedTablesMessages: String = try { val tablesRead = new TestHiveQueryExecution(query).executedPlan.collect { - case ts: HiveTableScanExec => ts.relation.tableName + case ts: HiveTableScanExec => ts.relation.tableMeta.identifier }.toSet TestHive.reset() @@ -393,7 +393,7 @@ abstract class HiveComparisonTest executions.foreach(_.toRdd) val tablesGenerated = queryList.zip(executions).flatMap { case (q, e) => e.analyzed.collect { - case i: InsertIntoHiveTable if tablesRead contains i.table.tableName => + case i: InsertIntoHiveTable if tablesRead contains i.table.identifier => (q, e, i) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index c04b9ee0f2cd..81ae5b7bdb67 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -1570,7 +1570,7 @@ class HiveDDLSuite val dataPath = new File(new File(path, "d=1"), "b=1").getCanonicalPath Seq(1 -> 1).toDF("a", "c").write.save(dataPath) - sql(s"CREATE TABLE t3 USING parquet LOCATION '${path.getCanonicalPath}'") + sql(s"CREATE TABLE t3 USING parquet LOCATION '${path.toURI}'") assert(getTableColumns("t3") == Seq("a", "c", "d", "b")) } @@ -1587,4 +1587,103 @@ class HiveDDLSuite } } } + + Seq(true, false).foreach { shouldDelete => + val tcName = if (shouldDelete) "non-existent" else "existed" + test(s"CTAS for external data source table with a $tcName location") { + withTable("t", "t1") { + withTempDir { + dir => + if (shouldDelete) { + dir.delete() + } + spark.sql( + s""" + |CREATE TABLE t + |USING parquet + |LOCATION '$dir' + |AS SELECT 3 as a, 4 as b, 1 as c, 2 as d + """.stripMargin) + + val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) + 
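The HiveDDLSuite hunk above switches the LOCATION clause from getCanonicalPath to toURI. A small sketch of the user-facing pattern: passing a file URI avoids ambiguity when the path contains spaces or other characters that would need escaping (the /tmp directory here is hypothetical):

    import java.io.File
    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder().enableHiveSupport().getOrCreate()

    val path = new File("/tmp/t3 data")     // note the space in the path
    spark.sql(s"CREATE TABLE t3 (a INT, c INT) USING parquet LOCATION '${path.toURI}'")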
assert(table.location == dir.getAbsolutePath) + + checkAnswer(spark.table("t"), Row(3, 4, 1, 2)) + } + // partition table + withTempDir { + dir => + if (shouldDelete) { + dir.delete() + } + spark.sql( + s""" + |CREATE TABLE t1 + |USING parquet + |PARTITIONED BY(a, b) + |LOCATION '$dir' + |AS SELECT 3 as a, 4 as b, 1 as c, 2 as d + """.stripMargin) + + val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t1")) + assert(table.location == dir.getAbsolutePath) + + val partDir = new File(dir, "a=3") + assert(partDir.exists()) + + checkAnswer(spark.table("t1"), Row(1, 2, 3, 4)) + } + } + } + + test(s"CTAS for external hive table with a $tcName location") { + withTable("t", "t1") { + withSQLConf("hive.exec.dynamic.partition.mode" -> "nonstrict") { + withTempDir { + dir => + if (shouldDelete) { + dir.delete() + } + spark.sql( + s""" + |CREATE TABLE t + |USING hive + |LOCATION '$dir' + |AS SELECT 3 as a, 4 as b, 1 as c, 2 as d + """.stripMargin) + val dirPath = new Path(dir.getAbsolutePath) + val fs = dirPath.getFileSystem(spark.sessionState.newHadoopConf()) + val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) + assert(new Path(table.location) == fs.makeQualified(dirPath)) + + checkAnswer(spark.table("t"), Row(3, 4, 1, 2)) + } + // partition table + withTempDir { + dir => + if (shouldDelete) { + dir.delete() + } + spark.sql( + s""" + |CREATE TABLE t1 + |USING hive + |PARTITIONED BY(a, b) + |LOCATION '$dir' + |AS SELECT 3 as a, 4 as b, 1 as c, 2 as d + """.stripMargin) + val dirPath = new Path(dir.getAbsolutePath) + val fs = dirPath.getFileSystem(spark.sessionState.newHadoopConf()) + val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t1")) + assert(new Path(table.location) == fs.makeQualified(dirPath)) + + val partDir = new File(dir, "a=3") + assert(partDir.exists()) + + checkAnswer(spark.table("t1"), Row(1, 2, 3, 4)) + } + } + } + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala index f9751e3d5f2e..8a37bc3665d3 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala @@ -26,6 +26,20 @@ import org.apache.spark.sql.test.SQLTestUtils * A set of tests that validates support for Hive Explain command. 
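The new CTAS tests above cover CREATE TABLE ... USING ... LOCATION ... AS SELECT for both existing and non-existent locations. A minimal user-level sketch of the non-partitioned case, with a hypothetical /tmp directory instead of withTempDir:

    import java.io.File
    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder().enableHiveSupport().getOrCreate()

    val dir = new File("/tmp/ctas_location_demo")
    spark.sql(
      s"""
         |CREATE TABLE t
         |USING parquet
         |LOCATION '${dir.toURI}'
         |AS SELECT 3 AS a, 4 AS b, 1 AS c, 2 AS d
       """.stripMargin)

    spark.table("t").show()
    assert(dir.listFiles().exists(_.getName.endsWith(".parquet")))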
*/ class HiveExplainSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { + import testImplicits._ + + test("show cost in explain command") { + // Only has sizeInBytes before ANALYZE command + checkKeywordsExist(sql("EXPLAIN COST SELECT * FROM src "), "sizeInBytes") + checkKeywordsNotExist(sql("EXPLAIN COST SELECT * FROM src "), "rowCount") + + // Has both sizeInBytes and rowCount after ANALYZE command + sql("ANALYZE TABLE src COMPUTE STATISTICS") + checkKeywordsExist(sql("EXPLAIN COST SELECT * FROM src "), "sizeInBytes", "rowCount") + + // No cost information + checkKeywordsNotExist(sql("EXPLAIN SELECT * FROM src "), "sizeInBytes", "rowCount") + } test("explain extended command") { checkKeywordsExist(sql(" explain select * from src where key=123 "), @@ -79,8 +93,8 @@ class HiveExplainSuite extends QueryTest with SQLTestUtils with TestHiveSingleto test("SPARK-17409: The EXPLAIN output of CTAS only shows the analyzed plan") { withTempView("jt") { - val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str$i"}""")) - spark.read.json(rdd).createOrReplaceTempView("jt") + val ds = (1 to 10).map(i => s"""{"a":$i, "b":"str$i"}""").toDS() + spark.read.json(ds).createOrReplaceTempView("jt") val outputs = sql( s""" |EXPLAIN EXTENDED diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala index b2f19d775395..ce92fbf34942 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala @@ -31,15 +31,15 @@ case class Data(a: Int, B: Int, n: Nested, nestedArray: Seq[Nested]) class HiveResolutionSuite extends HiveComparisonTest { test("SPARK-3698: case insensitive test for nested data") { - read.json(sparkContext.makeRDD( - """{"a": [{"a": {"a": 1}}]}""" :: Nil)).createOrReplaceTempView("nested") + read.json(Seq("""{"a": [{"a": {"a": 1}}]}""").toDS()) + .createOrReplaceTempView("nested") // This should be successfully analyzed sql("SELECT a[0].A.A from nested").queryExecution.analyzed } test("SPARK-5278: check ambiguous reference to fields") { - read.json(sparkContext.makeRDD( - """{"a": [{"b": 1, "B": 2}]}""" :: Nil)).createOrReplaceTempView("nested") + read.json(Seq("""{"a": [{"b": 1, "B": 2}]}""").toDS()) + .createOrReplaceTempView("nested") // there are 2 filed matching field name "b", we should report Ambiguous reference error val exception = intercept[AnalysisException] { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala index 5c460d25f372..90e037e29279 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.hive.execution import org.apache.spark.sql.Row -import org.apache.spark.sql.hive.MetastoreRelation import org.apache.spark.sql.hive.test.{TestHive, TestHiveSingleton} import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.sql.hive.test.TestHive.implicits._ @@ -95,8 +94,7 @@ class HiveTableScanSuite extends HiveComparisonTest with SQLTestUtils with TestH private def checkNumScannedPartitions(stmt: String, expectedNumParts: Int): Unit = { val plan = sql(stmt).queryExecution.sparkPlan val 
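The new test above exercises EXPLAIN COST, which prints sizeInBytes for the optimized plan and, once the table has been analyzed, the row count as well. A spark-shell-style sketch, assuming the Hive test table src exists:

    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder().enableHiveSupport().getOrCreate()

    spark.sql("EXPLAIN COST SELECT * FROM src").show(truncate = false)   // sizeInBytes only
    spark.sql("ANALYZE TABLE src COMPUTE STATISTICS")
    spark.sql("EXPLAIN COST SELECT * FROM src").show(truncate = false)   // sizeInBytes and rowCount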
numPartitions = plan.collectFirst { - case p: HiveTableScanExec => - p.relation.getHiveQlPartitions(p.partitionPruningPred).length + case p: HiveTableScanExec => p.rawPartitions.length }.getOrElse(0) assert(numPartitions == expectedNumParts) } @@ -170,11 +168,11 @@ class HiveTableScanSuite extends HiveComparisonTest with SQLTestUtils with TestH s""" |SELECT * FROM $table """.stripMargin).queryExecution.sparkPlan - val relation = plan.collectFirst { - case p: HiveTableScanExec => p.relation + val scan = plan.collectFirst { + case p: HiveTableScanExec => p }.get - val tableCols = relation.hiveQlTable.getCols - relation.getHiveQlPartitions().foreach(p => assert(p.getCols.size == tableCols.size)) + val numDataCols = scan.relation.dataCols.length + scan.rawPartitions.foreach(p => assert(p.getCols.size == numDataCols)) } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala index 24df73b40ea0..d535bef4cc78 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala @@ -153,8 +153,8 @@ class PruningSuite extends HiveComparisonTest with BeforeAndAfter { val (actualScannedColumns, actualPartValues) = plan.collect { case p @ HiveTableScanExec(columns, relation, _) => val columnNames = columns.map(_.name) - val partValues = if (relation.catalogTable.partitionColumnNames.nonEmpty) { - p.prunePartitions(relation.getHiveQlPartitions()).map(_.getValues) + val partValues = if (relation.isPartitioned) { + p.prunePartitions(p.rawPartitions).map(_.getValues) } else { Seq.empty } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index faed8b504649..ef2d451e6b2d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -28,12 +28,12 @@ import org.apache.spark.TestUtils import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, FunctionRegistry, NoSuchPartitionException} -import org.apache.spark.sql.catalyst.catalog.CatalogTableType +import org.apache.spark.sql.catalyst.catalog.{CatalogRelation, CatalogTableType} import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias} import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation} import org.apache.spark.sql.functions._ -import org.apache.spark.sql.hive.{HiveUtils, MetastoreRelation} +import org.apache.spark.sql.hive.HiveUtils import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SQLTestUtils @@ -526,7 +526,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { case LogicalRelation(r: HadoopFsRelation, _, _) => if (!isDataSourceTable) { fail( - s"${classOf[MetastoreRelation].getCanonicalName} is expected, but found " + + s"${classOf[CatalogRelation].getCanonicalName} is expected, but found " + s"${HadoopFsRelation.getClass.getCanonicalName}.") } userSpecifiedLocation match { @@ -536,15 +536,15 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton 
{ } assert(catalogTable.provider.get === format) - case r: MetastoreRelation => + case r: CatalogRelation => if (isDataSourceTable) { fail( s"${HadoopFsRelation.getClass.getCanonicalName} is expected, but found " + - s"${classOf[MetastoreRelation].getCanonicalName}.") + s"${classOf[CatalogRelation].getCanonicalName}.") } userSpecifiedLocation match { case Some(location) => - assert(r.catalogTable.location === location) + assert(r.tableMeta.location === location) case None => // OK. } // Also make sure that the format and serde are as desired. @@ -973,30 +973,30 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } test("SPARK-4296 Grouping field with Hive UDF as sub expression") { - val rdd = sparkContext.makeRDD( """{"a": "str", "b":"1", "c":"1970-01-01 00:00:00"}""" :: Nil) - read.json(rdd).createOrReplaceTempView("data") + val ds = Seq("""{"a": "str", "b":"1", "c":"1970-01-01 00:00:00"}""").toDS() + read.json(ds).createOrReplaceTempView("data") checkAnswer( sql("SELECT concat(a, '-', b), year(c) FROM data GROUP BY concat(a, '-', b), year(c)"), Row("str-1", 1970)) dropTempTable("data") - read.json(rdd).createOrReplaceTempView("data") + read.json(ds).createOrReplaceTempView("data") checkAnswer(sql("SELECT year(c) + 1 FROM data GROUP BY year(c) + 1"), Row(1971)) dropTempTable("data") } test("resolve udtf in projection #1") { - val rdd = sparkContext.makeRDD((1 to 5).map(i => s"""{"a":[$i, ${i + 1}]}""")) - read.json(rdd).createOrReplaceTempView("data") + val ds = (1 to 5).map(i => s"""{"a":[$i, ${i + 1}]}""").toDS() + read.json(ds).createOrReplaceTempView("data") val df = sql("SELECT explode(a) AS val FROM data") val col = df("val") } test("resolve udtf in projection #2") { - val rdd = sparkContext.makeRDD((1 to 2).map(i => s"""{"a":[$i, ${i + 1}]}""")) - read.json(rdd).createOrReplaceTempView("data") + val ds = (1 to 2).map(i => s"""{"a":[$i, ${i + 1}]}""").toDS() + read.json(ds).createOrReplaceTempView("data") checkAnswer(sql("SELECT explode(map(1, 1)) FROM data LIMIT 1"), Row(1, 1) :: Nil) checkAnswer(sql("SELECT explode(map(1, 1)) as (k1, k2) FROM data LIMIT 1"), Row(1, 1) :: Nil) intercept[AnalysisException] { @@ -1010,8 +1010,8 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { // TGF with non-TGF in project is allowed in Spark SQL, but not in Hive test("TGF with non-TGF in projection") { - val rdd = sparkContext.makeRDD( """{"a": "1", "b":"1"}""" :: Nil) - read.json(rdd).createOrReplaceTempView("data") + val ds = Seq("""{"a": "1", "b":"1"}""").toDS() + read.json(ds).createOrReplaceTempView("data") checkAnswer( sql("SELECT explode(map(a, b)) as (k1, k2), a, b FROM data"), Row("1", "1", "1", "1") :: Nil) @@ -1024,13 +1024,13 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { // is not in a valid state (cannot be executed). Because of this bug, the analysis rule of // PreInsertionCasts will actually start to work before ImplicitGenerate and then // generates an invalid query plan. 
- val rdd = sparkContext.makeRDD((1 to 5).map(i => s"""{"a":[$i, ${i + 1}]}""")) - read.json(rdd).createOrReplaceTempView("data") + val ds = (1 to 5).map(i => s"""{"a":[$i, ${i + 1}]}""").toDS() + read.json(ds).createOrReplaceTempView("data") withSQLConf(SQLConf.CONVERT_CTAS.key -> "false") { sql("CREATE TABLE explodeTest (key bigInt)") table("explodeTest").queryExecution.analyzed match { - case SubqueryAlias(_, r: MetastoreRelation, _) => // OK + case SubqueryAlias(_, r: CatalogRelation, _) => // OK case _ => fail("To correctly test the fix of SPARK-5875, explodeTest should be a MetastoreRelation") } @@ -1262,9 +1262,8 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } test("SPARK-9371: fix the support for special chars in column names for hive context") { - read.json(sparkContext.makeRDD( - """{"a": {"c.b": 1}, "b.$q": [{"a@!.q": 1}], "q.w": {"w.i&": [1]}}""" :: Nil)) - .createOrReplaceTempView("t") + val ds = Seq("""{"a": {"c.b": 1}, "b.$q": [{"a@!.q": 1}], "q.w": {"w.i&": [1]}}""").toDS() + read.json(ds).createOrReplaceTempView("t") checkAnswer(sql("SELECT a.`c.b`, `b.$q`[0].`a@!.q`, `q.w`.`w.i&`[0] FROM t"), Row(1, 1, 1)) } @@ -2044,4 +2043,18 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } } } + + test("SPARK-17409: Do Not Optimize Query in CTAS (Hive Serde Table) More Than Once") { + withTable("bar") { + withTempView("foo") { + sql("select 0 as id").createOrReplaceTempView("foo") + // If we optimize the query in CTAS more than once, the following saveAsTable will fail + // with the error: `GROUP BY position 0 is not in select list (valid range is [1, 1])` + sql("SELECT * FROM foo group by id").toDF().write.format("hive").saveAsTable("bar") + checkAnswer(spark.table("bar"), Row(0) :: Nil) + val tableMetadata = spark.sessionState.catalog.getTableMetadata(TableIdentifier("bar")) + assert(tableMetadata.provider == Some("hive"), "the expected table is a Hive serde table") + } + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala index 9fa1fb931d76..38a5477796a4 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala @@ -26,8 +26,9 @@ import org.scalatest.BeforeAndAfterAll import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.catalog.CatalogRelation import org.apache.spark.sql.execution.datasources.{LogicalRelation, RecordReaderIterator} -import org.apache.spark.sql.hive.{HiveUtils, MetastoreRelation} +import org.apache.spark.sql.hive.HiveUtils import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.sql.hive.test.TestHive.implicits._ import org.apache.spark.sql.internal.SQLConf @@ -473,7 +474,7 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { } } else { queryExecution.analyzed.collectFirst { - case _: MetastoreRelation => () + case _: CatalogRelation => () }.getOrElse { fail(s"Expecting no conversion from orc to data sources, " + s"but got:\n$queryExecution") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala index 59ea8916efae..11dda5425cf9 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala +++ 
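The SPARK-9371 test above shows the backtick quoting needed when column or field names contain dots and other special characters. A self-contained sketch of the same queries over a Dataset[String] of JSON:

    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder().getOrCreate()
    import spark.implicits._

    val ds = Seq("""{"a": {"c.b": 1}, "b.$q": [{"a@!.q": 1}], "q.w": {"w.i&": [1]}}""").toDS()
    spark.read.json(ds).createOrReplaceTempView("t")

    // Field names with dots or symbols must be quoted with backticks.
    spark.sql("SELECT a.`c.b`, `b.$q`[0].`a@!.q`, `q.w`.`w.i&`[0] FROM t").show()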
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala @@ -162,13 +162,16 @@ abstract class OrcSuite extends QueryTest with TestHiveSingleton with BeforeAndA |CREATE EXTERNAL TABLE hive_orc( | a STRING, | b CHAR(10), - | c VARCHAR(10)) + | c VARCHAR(10), + | d ARRAY) |STORED AS orc""".stripMargin) // Hive throws an exception if I assign the location in the create table statement. hiveClient.runSqlHive( s"ALTER TABLE hive_orc SET LOCATION '$uri'") hiveClient.runSqlHive( - "INSERT INTO TABLE hive_orc SELECT 'a', 'b', 'c' FROM (SELECT 1) t") + """INSERT INTO TABLE hive_orc + |SELECT 'a', 'b', 'c', ARRAY(CAST('d' AS CHAR(3))) + |FROM (SELECT 1) t""".stripMargin) // We create a different table in Spark using the same schema which points to // the same location. @@ -177,10 +180,11 @@ abstract class OrcSuite extends QueryTest with TestHiveSingleton with BeforeAndA |CREATE EXTERNAL TABLE spark_orc( | a STRING, | b CHAR(10), - | c VARCHAR(10)) + | c VARCHAR(10), + | d ARRAY) |STORED AS orc |LOCATION '$uri'""".stripMargin) - val result = Row("a", "b ", "c") + val result = Row("a", "b ", "c", Seq("d ")) checkAnswer(spark.table("hive_orc"), result) checkAnswer(spark.table("spark_orc"), result) } finally { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala index 1a1b2571b67b..3512c4a89031 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala @@ -21,8 +21,8 @@ import java.io.File import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.catalog.CatalogRelation import org.apache.spark.sql.execution.DataSourceScanExec -import org.apache.spark.sql.execution.command.ExecutedCommandExec import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.hive.execution.HiveTableScanExec import org.apache.spark.sql.hive.test.TestHiveSingleton @@ -806,7 +806,7 @@ class ParquetSourceSuite extends ParquetPartitioningTest { } } else { queryExecution.analyzed.collectFirst { - case _: MetastoreRelation => + case _: CatalogRelation => }.getOrElse { fail(s"Expecting no conversion from parquet to data sources, " + s"but got:\n$queryExecution") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadWithHiveSupportSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadWithHiveSupportSuite.scala new file mode 100644 index 000000000000..f277f99805a4 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadWithHiveSupportSuite.scala @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.sources + +import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION + +class BucketedReadWithHiveSupportSuite extends BucketedReadSuite with TestHiveSingleton { + protected override def beforeAll(): Unit = { + super.beforeAll() + assume(spark.sparkContext.conf.get(CATALOG_IMPLEMENTATION) == "hive") + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteWithHiveSupportSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteWithHiveSupportSuite.scala new file mode 100644 index 000000000000..454e2f65d5d8 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteWithHiveSupportSuite.scala @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources + +import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION + +class BucketedWriteWithHiveSupportSuite extends BucketedWriteSuite with TestHiveSingleton { + protected override def beforeAll(): Unit = { + super.beforeAll() + assume(spark.sparkContext.conf.get(CATALOG_IMPLEMENTATION) == "hive") + } + + override protected def fileFormatsToTest: Seq[String] = Seq("parquet", "orc") +} diff --git a/streaming/src/main/scala/org/apache/spark/status/api/v1/streaming/ApiStreamingApp.scala b/streaming/src/main/scala/org/apache/spark/status/api/v1/streaming/ApiStreamingApp.scala index e64830a9459b..aea75d5a9c8d 100644 --- a/streaming/src/main/scala/org/apache/spark/status/api/v1/streaming/ApiStreamingApp.scala +++ b/streaming/src/main/scala/org/apache/spark/status/api/v1/streaming/ApiStreamingApp.scala @@ -19,14 +19,14 @@ package org.apache.spark.status.api.v1.streaming import javax.ws.rs.{Path, PathParam} -import org.apache.spark.status.api.v1.UIRootFromServletContext +import org.apache.spark.status.api.v1.ApiRequestContext @Path("/v1") -private[v1] class ApiStreamingApp extends UIRootFromServletContext { +private[v1] class ApiStreamingApp extends ApiRequestContext { @Path("applications/{appId}/streaming") def getStreamingRoot(@PathParam("appId") appId: String): ApiStreamingRootResource = { - uiRoot.withSparkUI(appId, None) { ui => + withSparkUI(appId, None) { ui => new ApiStreamingRootResource(ui) } } @@ -35,7 +35,7 @@ private[v1] class ApiStreamingApp extends UIRootFromServletContext { def getStreamingRoot( @PathParam("appId") appId: String, @PathParam("attemptId") attemptId: String): ApiStreamingRootResource = { - uiRoot.withSparkUI(appId, Some(attemptId)) { ui => + withSparkUI(appId, Some(attemptId)) { ui => new ApiStreamingRootResource(ui) } } diff --git 
a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala index 0a4c141e5be3..a34f6c73fea8 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala @@ -435,13 +435,12 @@ class StreamingContext private[streaming] ( conf.setInt(FixedLengthBinaryInputFormat.RECORD_LENGTH_PROPERTY, recordLength) val br = fileStream[LongWritable, BytesWritable, FixedLengthBinaryInputFormat]( directory, FileInputDStream.defaultFilter: Path => Boolean, newFilesOnly = true, conf) - val data = br.map { case (k, v) => - val bytes = v.getBytes + br.map { case (k, v) => + val bytes = v.copyBytes() require(bytes.length == recordLength, "Byte array does not have correct length. " + s"${bytes.length} did not equal recordLength: $recordLength") bytes } - data } /** diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaMapWithStateSuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaMapWithStateSuite.java index cb8ed83e5a49..b1367b8f2aed 100644 --- a/streaming/src/test/java/org/apache/spark/streaming/JavaMapWithStateSuite.java +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaMapWithStateSuite.java @@ -145,8 +145,8 @@ private void testOperation( List>> expectedStateSnapshots) { int numBatches = expectedOutputs.size(); JavaDStream inputStream = JavaTestUtils.attachTestInputStream(ssc, input, 2); - JavaMapWithStateDStream mapWithStateDStream = - JavaPairDStream.fromJavaDStream(inputStream.map(x -> new Tuple2<>(x, 1))).mapWithState(mapWithStateSpec); + JavaMapWithStateDStream mapWithStateDStream = JavaPairDStream.fromJavaDStream( + inputStream.map(x -> new Tuple2<>(x, 1))).mapWithState(mapWithStateSpec); List> collectedOutputs = Collections.synchronizedList(new ArrayList>()); diff --git a/streaming/src/test/java/test/org/apache/spark/streaming/Java8APISuite.java b/streaming/src/test/java/test/org/apache/spark/streaming/Java8APISuite.java index 9948a4074cdc..80513de4ee11 100644 --- a/streaming/src/test/java/test/org/apache/spark/streaming/Java8APISuite.java +++ b/streaming/src/test/java/test/org/apache/spark/streaming/Java8APISuite.java @@ -20,10 +20,13 @@ import java.io.Serializable; import java.util.*; +import org.apache.spark.api.java.function.Function3; +import org.apache.spark.api.java.function.Function4; import org.apache.spark.streaming.Duration; import org.apache.spark.streaming.Durations; import org.apache.spark.streaming.JavaTestUtils; import org.apache.spark.streaming.LocalJavaStreamingContext; +import org.apache.spark.streaming.State; import org.apache.spark.streaming.StateSpec; import org.apache.spark.streaming.Time; import scala.Tuple2; @@ -142,8 +145,8 @@ public void testReduceByWindow() { Arrays.asList(24)); JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); - JavaDStream reducedWindowed = stream.reduceByWindow((x, y) -> x + y, - (x, y) -> x - y, new Duration(2000), new Duration(1000)); + JavaDStream reducedWindowed = stream.reduceByWindow( + (x, y) -> x + y, (x, y) -> x - y, new Duration(2000), new Duration(1000)); JavaTestUtils.attachTestOutputStream(reducedWindowed); List> result = JavaTestUtils.runStreams(ssc, 4, 4); @@ -850,36 +853,44 @@ public void testMapWithStateAPI() { JavaPairRDD initialRDD = null; JavaPairDStream wordsDstream = null; + Function4, State, Optional> mapFn = + (time, key, value, state) -> { + // Use all State's methods 
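The StreamingContext hunk above replaces BytesWritable.getBytes with copyBytes() inside binaryRecordsStream: getBytes exposes the writable's backing buffer, which can be longer than the logical record, while copyBytes() returns a fresh array of exactly getLength bytes. A small standalone illustration:

    import org.apache.hadoop.io.BytesWritable

    val record = new BytesWritable(Array[Byte](1, 2, 3))
    record.setCapacity(10)          // backing buffer grows beyond the 3 data bytes

    val raw = record.getBytes       // length of the backing buffer (10 here)
    val exact = record.copyBytes()  // always exactly getLength bytes (3 here)

    println(s"getBytes length = ${raw.length}, copyBytes length = ${exact.length}")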
here + state.exists(); + state.get(); + state.isTimingOut(); + state.remove(); + state.update(true); + return Optional.of(2.0); + }; + JavaMapWithStateDStream stateDstream = - wordsDstream.mapWithState( - StateSpec.function((time, key, value, state) -> { - // Use all State's methods here - state.exists(); - state.get(); - state.isTimingOut(); - state.remove(); - state.update(true); - return Optional.of(2.0); - }).initialState(initialRDD) - .numPartitions(10) - .partitioner(new HashPartitioner(10)) - .timeout(Durations.seconds(10))); + wordsDstream.mapWithState( + StateSpec.function(mapFn) + .initialState(initialRDD) + .numPartitions(10) + .partitioner(new HashPartitioner(10)) + .timeout(Durations.seconds(10))); JavaPairDStream emittedRecords = stateDstream.stateSnapshots(); + Function3, State, Double> mapFn2 = + (key, value, state) -> { + state.exists(); + state.get(); + state.isTimingOut(); + state.remove(); + state.update(true); + return 2.0; + }; + JavaMapWithStateDStream stateDstream2 = - wordsDstream.mapWithState( - StateSpec.function((key, value, state) -> { - state.exists(); - state.get(); - state.isTimingOut(); - state.remove(); - state.update(true); - return 2.0; - }).initialState(initialRDD) - .numPartitions(10) - .partitioner(new HashPartitioner(10)) - .timeout(Durations.seconds(10))); + wordsDstream.mapWithState( + StateSpec.function(mapFn2) + .initialState(initialRDD) + .numPartitions(10) + .partitioner(new HashPartitioner(10)) + .timeout(Durations.seconds(10))); JavaPairDStream mappedDStream = stateDstream2.stateSnapshots(); } diff --git a/streaming/src/test/java/test/org/apache/spark/streaming/JavaAPISuite.java b/streaming/src/test/java/test/org/apache/spark/streaming/JavaAPISuite.java index b966cbdca076..96f8d9593d63 100644 --- a/streaming/src/test/java/test/org/apache/spark/streaming/JavaAPISuite.java +++ b/streaming/src/test/java/test/org/apache/spark/streaming/JavaAPISuite.java @@ -29,7 +29,6 @@ import org.apache.spark.streaming.Seconds; import org.apache.spark.streaming.StreamingContextState; import org.apache.spark.streaming.StreamingContextSuite; -import org.apache.spark.streaming.Time; import scala.Tuple2; import org.apache.hadoop.conf.Configuration; @@ -608,7 +607,8 @@ public void testFlatMap() { Arrays.asList("a","t","h","l","e","t","i","c","s")); JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); - JavaDStream flatMapped = stream.flatMap(x -> Arrays.asList(x.split("(?!^)")).iterator()); + JavaDStream flatMapped = + stream.flatMap(x -> Arrays.asList(x.split("(?!^)")).iterator()); JavaTestUtils.attachTestOutputStream(flatMapped); List> result = JavaTestUtils.runStreams(ssc, 3, 3); @@ -1314,7 +1314,8 @@ public void testMapValues() { ssc, inputData, 1); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); - JavaPairDStream mapped = pairStream.mapValues(s -> s.toUpperCase(Locale.ENGLISH)); + JavaPairDStream mapped = + pairStream.mapValues(s -> s.toUpperCase(Locale.ENGLISH)); JavaTestUtils.attachTestOutputStream(mapped); List>> result = JavaTestUtils.runStreams(ssc, 2, 2); diff --git a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala index 6fb50a405271..b5d36a36513a 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala @@ -84,7 +84,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { // 
Verify whether all the elements received are as expected // (whether the elements were received one in each interval is not verified) - val output: Array[String] = outputQueue.asScala.flatMap(x => x).toArray + val output = outputQueue.asScala.flatten.toArray assert(output.length === expectedOutput.size) for (i <- output.indices) { assert(output(i) === expectedOutput(i)) @@ -155,14 +155,15 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { // not enough to trigger a batch clock.advance(batchDuration.milliseconds / 2) - val input = Seq(1, 2, 3, 4, 5) - input.foreach { i => + val numCopies = 3 + val input = Array[Byte](1, 2, 3, 4, 5) + for (i <- 0 until numCopies) { Thread.sleep(batchDuration.milliseconds) val file = new File(testDir, i.toString) - Files.write(Array[Byte](i.toByte), file) + Files.write(input.map(b => (b + i).toByte), file) assert(file.setLastModified(clock.getTimeMillis())) assert(file.lastModified === clock.getTimeMillis()) - logInfo("Created file " + file) + logInfo(s"Created file $file") // Advance the clock after creating the file to avoid a race when // setting its modification time clock.advance(batchDuration.milliseconds) @@ -170,10 +171,10 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { assert(batchCounter.getNumCompletedBatches === i) } } - - val expectedOutput = input.map(i => i.toByte) - val obtainedOutput = outputQueue.asScala.flatten.toList.map(i => i(0).toByte) - assert(obtainedOutput.toSeq === expectedOutput) + val obtainedOutput = outputQueue.asScala.map(_.flatten).toSeq + for (i <- obtainedOutput.indices) { + assert(obtainedOutput(i) === input.map(b => (b + i).toByte)) + } } } finally { if (testDir != null) Utils.deleteRecursively(testDir) @@ -258,7 +259,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { val testReceiver = new MultiThreadTestReceiver(numThreads, numRecordsPerThread) MultiThreadTestReceiver.haveAllThreadsFinished = false val outputQueue = new ConcurrentLinkedQueue[Seq[Long]] - def output: Iterable[Long] = outputQueue.asScala.flatMap(x => x) + def output: Iterable[Long] = outputQueue.asScala.flatten // set up the network stream using the test receiver withStreamingContext(new StreamingContext(conf, batchDuration)) { ssc =>