diff --git a/LICENSE b/LICENSE index d68609cc28733..7950dd6ceb6db 100644 --- a/LICENSE +++ b/LICENSE @@ -263,7 +263,7 @@ The text of each license is also included at licenses/LICENSE-[project].txt. (New BSD license) Protocol Buffer Java API (org.spark-project.protobuf:protobuf-java:2.4.1-shaded - http://code.google.com/p/protobuf) (The BSD License) Fortran to Java ARPACK (net.sourceforge.f2j:arpack_combined_all:0.1 - http://f2j.sourceforge.net) (The BSD License) xmlenc Library (xmlenc:xmlenc:0.52 - http://xmlenc.sourceforge.net) - (The New BSD License) Py4J (net.sf.py4j:py4j:0.10.3 - http://py4j.sourceforge.net/) + (The New BSD License) Py4J (net.sf.py4j:py4j:0.10.4 - http://py4j.sourceforge.net/) (Two-clause BSD-style license) JUnit-Interface (com.novocode:junit-interface:0.10 - http://github.com/szeiger/junit-interface/) (BSD licence) sbt and sbt-launch-lib.bash (BSD 3 Clause) d3.min.js (https://github.com/mbostock/d3/blob/master/LICENSE) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 267a38c21530b..9cd6269f9a8f7 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -3,7 +3,7 @@ importFrom("methods", "setGeneric", "setMethod", "setOldClass") importFrom("methods", "is", "new", "signature", "show") importFrom("stats", "gaussian", "setNames") -importFrom("utils", "download.file", "packageVersion", "untar") +importFrom("utils", "download.file", "object.size", "packageVersion", "untar") # Disable native libraries till we figure out how to package it # See SPARKR-7839 @@ -43,7 +43,9 @@ exportMethods("glm", "spark.isoreg", "spark.gaussianMixture", "spark.als", - "spark.kstest") + "spark.kstest", + "spark.logit", + "spark.randomForest") # Job group lifecycle management methods export("setJobGroup", @@ -71,6 +73,7 @@ exportMethods("arrange", "covar_samp", "covar_pop", "createOrReplaceTempView", + "crossJoin", "crosstab", "dapply", "dapplyCollect", @@ -123,6 +126,7 @@ exportMethods("arrange", "selectExpr", "show", "showDF", + "storageLevel", "subset", "summarize", "summary", @@ -347,7 +351,9 @@ export("as.DataFrame", "uncacheTable", "print.summary.GeneralizedLinearRegressionModel", "read.ml", - "print.summary.KSTest") + "print.summary.KSTest", + "print.summary.RandomForestRegressionModel", + "print.summary.RandomForestClassificationModel") export("structField", "structField.jobj", @@ -372,6 +378,8 @@ S3method(print, structField) S3method(print, structType) S3method(print, summary.GeneralizedLinearRegressionModel) S3method(print, summary.KSTest) +S3method(print, summary.RandomForestRegressionModel) +S3method(print, summary.RandomForestClassificationModel) S3method(structField, character) S3method(structField, jobj) S3method(structType, jobj) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 801d2ed4e7500..1df8bbf9fe604 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -365,7 +365,7 @@ setMethod("colnames<-", # Check if the column names have . in it if (any(regexec(".", value, fixed = TRUE)[[1]][1] != -1)) { - stop("Colum names cannot contain the '.' symbol.") + stop("Column names cannot contain the '.' symbol.") } sdf <- callJMethod(x@sdf, "toDF", as.list(value)) @@ -633,7 +633,7 @@ setMethod("persist", #' @param ... further arguments to be passed to or from other methods. #' #' @family SparkDataFrame functions -#' @rdname unpersist-methods +#' @rdname unpersist #' @aliases unpersist,SparkDataFrame-method #' @name unpersist #' @export @@ -654,6 +654,32 @@ setMethod("unpersist", x }) +#' StorageLevel +#' +#' Get storagelevel of this SparkDataFrame. 
+#' +#' @param x the SparkDataFrame to get the storageLevel. +#' +#' @family SparkDataFrame functions +#' @rdname storageLevel +#' @aliases storageLevel,SparkDataFrame-method +#' @name storageLevel +#' @export +#' @examples +#'\dontrun{ +#' sparkR.session() +#' path <- "path/to/file.json" +#' df <- read.json(path) +#' persist(df, "MEMORY_AND_DISK") +#' storageLevel(df) +#'} +#' @note storageLevel since 2.1.0 +setMethod("storageLevel", + signature(x = "SparkDataFrame"), + function(x) { + storageLevelToString(callJMethod(x@sdf, "storageLevel")) + }) + #' Repartition #' #' The following options for repartition are possible: @@ -735,7 +761,8 @@ setMethod("toJSON", #' Save the contents of SparkDataFrame as a JSON file #' -#' Save the contents of a SparkDataFrame as a JSON file (one object per line). Files written out +#' Save the contents of a SparkDataFrame as a JSON file (\href{http://jsonlines.org/}{ +#' JSON Lines text format or newline-delimited JSON}). Files written out #' with this method can be read back in as a SparkDataFrame using read.json(). #' #' @param x A SparkDataFrame @@ -2271,12 +2298,13 @@ setMethod("dropDuplicates", #' Join #' -#' Join two SparkDataFrames based on the given join expression. +#' Joins two SparkDataFrames based on the given join expression. #' #' @param x A SparkDataFrame #' @param y A SparkDataFrame #' @param joinExpr (Optional) The expression used to perform the join. joinExpr must be a -#' Column expression. If joinExpr is omitted, join() will perform a Cartesian join +#' Column expression. If joinExpr is omitted, the default, inner join is attempted and an error is +#' thrown if it would be a Cartesian Product. For Cartesian join, use crossJoin instead. #' @param joinType The type of join to perform. The following join types are available: #' 'inner', 'outer', 'full', 'fullouter', leftouter', 'left_outer', 'left', #' 'right_outer', 'rightouter', 'right', and 'leftsemi'. The default joinType is "inner". @@ -2285,23 +2313,24 @@ setMethod("dropDuplicates", #' @aliases join,SparkDataFrame,SparkDataFrame-method #' @rdname join #' @name join -#' @seealso \link{merge} +#' @seealso \link{merge} \link{crossJoin} #' @export #' @examples #'\dontrun{ #' sparkR.session() #' df1 <- read.json(path) #' df2 <- read.json(path2) -#' join(df1, df2) # Performs a Cartesian #' join(df1, df2, df1$col1 == df2$col2) # Performs an inner join based on expression #' join(df1, df2, df1$col1 == df2$col2, "right_outer") +#' join(df1, df2) # Attempts an inner join #' } #' @note join since 1.4.0 setMethod("join", signature(x = "SparkDataFrame", y = "SparkDataFrame"), function(x, y, joinExpr = NULL, joinType = NULL) { if (is.null(joinExpr)) { - sdf <- callJMethod(x@sdf, "crossJoin", y@sdf) + # this may not fail until the planner checks for Cartesian join later on. + sdf <- callJMethod(x@sdf, "join", y@sdf) } else { if (class(joinExpr) != "Column") stop("joinExpr must be a Column") if (is.null(joinType)) { @@ -2322,22 +2351,52 @@ setMethod("join", dataFrame(sdf) }) +#' CrossJoin +#' +#' Returns Cartesian Product on two SparkDataFrames. +#' +#' @param x A SparkDataFrame +#' @param y A SparkDataFrame +#' @return A SparkDataFrame containing the result of the join operation. 
+#' @family SparkDataFrame functions +#' @aliases crossJoin,SparkDataFrame,SparkDataFrame-method +#' @rdname crossJoin +#' @name crossJoin +#' @seealso \link{merge} \link{join} +#' @export +#' @examples +#'\dontrun{ +#' sparkR.session() +#' df1 <- read.json(path) +#' df2 <- read.json(path2) +#' crossJoin(df1, df2) # Performs a Cartesian +#' } +#' @note crossJoin since 2.1.0 +setMethod("crossJoin", + signature(x = "SparkDataFrame", y = "SparkDataFrame"), + function(x, y) { + sdf <- callJMethod(x@sdf, "crossJoin", y@sdf) + dataFrame(sdf) + }) + #' Merges two data frames #' #' @name merge -#' @param x the first data frame to be joined -#' @param y the second data frame to be joined +#' @param x the first data frame to be joined. +#' @param y the second data frame to be joined. #' @param by a character vector specifying the join columns. If by is not #' specified, the common column names in \code{x} and \code{y} will be used. +#' If by or both by.x and by.y are explicitly set to NULL or of length 0, the Cartesian +#' Product of x and y will be returned. #' @param by.x a character vector specifying the joining columns for x. #' @param by.y a character vector specifying the joining columns for y. #' @param all a boolean value setting \code{all.x} and \code{all.y} #' if any of them are unset. #' @param all.x a boolean value indicating whether all the rows in x should -#' be including in the join +#' be including in the join. #' @param all.y a boolean value indicating whether all the rows in y should -#' be including in the join -#' @param sort a logical argument indicating whether the resulting columns should be sorted +#' be including in the join. +#' @param sort a logical argument indicating whether the resulting columns should be sorted. #' @param suffixes a string vector of length 2 used to make colnames of #' \code{x} and \code{y} unique. #' The first element is appended to each colname of \code{x}. 
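
The hunks above change what join() does when no join expression is given: instead of silently producing a Cartesian product, it now attempts an inner join and the planner raises an error if that would be a Cartesian product; crossJoin() (and merge(..., by = NULL)) is the explicit way to request one, and the new storageLevel() accessor reports the level set by persist(). A minimal SparkR sketch of the changed behavior, assuming a running Spark session with default settings; df1/df2 are small illustrative data frames, not taken from the patch:

library(SparkR)
sparkR.session()

df1 <- createDataFrame(data.frame(id = c(1, 2, 3), v = c("a", "b", "c")))
df2 <- createDataFrame(data.frame(id = c(1, 2, 4), w = c(10, 20, 40)))

# Inner join on an explicit expression: behavior unchanged
inner <- join(df1, df2, df1$id == df2$id)

# join() without an expression now attempts an inner join; with the default
# spark.sql.crossJoin.enabled=false the planner rejects the resulting
# Cartesian product when the query is executed
attempt <- tryCatch(count(join(df1, df2)), error = function(e) conditionMessage(e))

# Explicit Cartesian product: crossJoin(), or merge() with by = NULL
cross1 <- crossJoin(df1, df2)        # Cartesian product: 3 x 3 = 9 rows
cross2 <- merge(df1, df2, by = NULL)

# storageLevel() reports the level requested via persist()
persist(df1, "MEMORY_AND_DISK")
storageLevel(df1)  # "MEMORY_AND_DISK - StorageLevel(disk, memory, deserialized, 1 replicas)"
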
@@ -2351,20 +2410,21 @@ setMethod("join", #' @family SparkDataFrame functions #' @aliases merge,SparkDataFrame,SparkDataFrame-method #' @rdname merge -#' @seealso \link{join} +#' @seealso \link{join} \link{crossJoin} #' @export #' @examples #'\dontrun{ #' sparkR.session() #' df1 <- read.json(path) #' df2 <- read.json(path2) -#' merge(df1, df2) # Performs a Cartesian +#' merge(df1, df2) # Performs an inner join by common columns #' merge(df1, df2, by = "col1") # Performs an inner join based on expression #' merge(df1, df2, by.x = "col1", by.y = "col2", all.y = TRUE) #' merge(df1, df2, by.x = "col1", by.y = "col2", all.x = TRUE) #' merge(df1, df2, by.x = "col1", by.y = "col2", all.x = TRUE, all.y = TRUE) #' merge(df1, df2, by.x = "col1", by.y = "col2", all = TRUE, sort = FALSE) #' merge(df1, df2, by = "col1", all = TRUE, suffixes = c("-X", "-Y")) +#' merge(df1, df2, by = NULL) # Performs a Cartesian join #' } #' @note merge since 1.5.0 setMethod("merge", @@ -2401,7 +2461,7 @@ setMethod("merge", joinY <- by } else { # if by or both by.x and by.y have length 0, use Cartesian Product - joinRes <- join(x, y) + joinRes <- crossJoin(x, y) return (joinRes) } diff --git a/R/pkg/R/RDD.R b/R/pkg/R/RDD.R index 6cd0704003f1a..0f1162fec1df9 100644 --- a/R/pkg/R/RDD.R +++ b/R/pkg/R/RDD.R @@ -261,7 +261,7 @@ setMethod("persistRDD", #' cache(rdd) # rdd@@env$isCached == TRUE #' unpersistRDD(rdd) # rdd@@env$isCached == FALSE #'} -#' @rdname unpersist-methods +#' @rdname unpersist #' @aliases unpersist,RDD-method #' @noRd setMethod("unpersistRDD", diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R index 0d6a229e63455..216ca51666ba8 100644 --- a/R/pkg/R/SQLContext.R +++ b/R/pkg/R/SQLContext.R @@ -324,7 +324,8 @@ setMethod("toDF", signature(x = "RDD"), #' Create a SparkDataFrame from a JSON file. #' -#' Loads a JSON file (one object per line), returning the result as a SparkDataFrame +#' Loads a JSON file (\href{http://jsonlines.org/}{JSON Lines text format or newline-delimited JSON} +#' ), returning the result as a SparkDataFrame #' It goes through the entire dataset once to determine the schema. #' #' @param path Path of file to read. A vector of multiple paths is allowed. diff --git a/R/pkg/R/backend.R b/R/pkg/R/backend.R index 03e70bb2cb82e..0a789e6c379d6 100644 --- a/R/pkg/R/backend.R +++ b/R/pkg/R/backend.R @@ -108,13 +108,27 @@ invokeJava <- function(isStatic, objId, methodName, ...) { conn <- get(".sparkRCon", .sparkREnv) writeBin(requestMessage, conn) - # TODO: check the status code to output error information returnStatus <- readInt(conn) + handleErrors(returnStatus, conn) + + # Backend will send +1 as keep alive value to prevent various connection timeouts + # on very long running jobs. See spark.r.heartBeatInterval + while (returnStatus == 1) { + returnStatus <- readInt(conn) + handleErrors(returnStatus, conn) + } + + readObject(conn) +} + +# Helper function to check for returned errors and print appropriate error message to user +handleErrors <- function(returnStatus, conn) { if (length(returnStatus) == 0) { stop("No status is returned. Java SparkR backend might have failed.") } - if (returnStatus != 0) { + + # 0 is success and +1 is reserved for heartbeats. Other negative values indicate errors. 
+ if (returnStatus < 0) { stop(readString(conn)) } - readObject(conn) } diff --git a/R/pkg/R/client.R b/R/pkg/R/client.R index 2d341d836c133..9d82814211bc5 100644 --- a/R/pkg/R/client.R +++ b/R/pkg/R/client.R @@ -19,7 +19,7 @@ # Creates a SparkR client connection object # if one doesn't already exist -connectBackend <- function(hostname, port, timeout = 6000) { +connectBackend <- function(hostname, port, timeout) { if (exists(".sparkRcon", envir = .sparkREnv)) { if (isOpen(.sparkREnv[[".sparkRCon"]])) { cat("SparkRBackend client connection already exists\n") diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 810aea9017743..0271b26a10a90 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -468,6 +468,10 @@ setGeneric("createOrReplaceTempView", standardGeneric("createOrReplaceTempView") }) +# @rdname crossJoin +# @export +setGeneric("crossJoin", function(x, y) { standardGeneric("crossJoin") }) + #' @rdname dapply #' @export setGeneric("dapply", function(x, func, schema) { standardGeneric("dapply") }) @@ -687,6 +691,10 @@ setGeneric("selectExpr", function(x, expr, ...) { standardGeneric("selectExpr") #' @export setGeneric("showDF", function(x, ...) { standardGeneric("showDF") }) +# @rdname storageLevel +# @export +setGeneric("storageLevel", function(x) { standardGeneric("storageLevel") }) + #' @rdname subset #' @export setGeneric("subset", function(x, ...) { standardGeneric("subset") }) @@ -711,7 +719,7 @@ setGeneric("union", function(x, y) { standardGeneric("union") }) #' @export setGeneric("unionAll", function(x, y) { standardGeneric("unionAll") }) -#' @rdname unpersist-methods +#' @rdname unpersist #' @export setGeneric("unpersist", function(x, ...) { standardGeneric("unpersist") }) @@ -1302,9 +1310,11 @@ setGeneric("window", function(x, ...) { standardGeneric("window") }) #' @export setGeneric("year", function(x) { standardGeneric("year") }) -#' @rdname spark.glm +###################### Spark.ML Methods ########################## + +#' @rdname fitted #' @export -setGeneric("spark.glm", function(data, formula, ...) { standardGeneric("spark.glm") }) +setGeneric("fitted") #' @param x,y For \code{glm}: logical values indicating whether the response vector #' and model matrix used in the fitting process should be returned as @@ -1324,13 +1334,38 @@ setGeneric("predict", function(object, ...) { standardGeneric("predict") }) #' @export setGeneric("rbind", signature = "...") +#' @rdname spark.als +#' @export +setGeneric("spark.als", function(data, ...) { standardGeneric("spark.als") }) + +#' @rdname spark.gaussianMixture +#' @export +setGeneric("spark.gaussianMixture", + function(data, formula, ...) { standardGeneric("spark.gaussianMixture") }) + +#' @rdname spark.glm +#' @export +setGeneric("spark.glm", function(data, formula, ...) { standardGeneric("spark.glm") }) + +#' @rdname spark.isoreg +#' @export +setGeneric("spark.isoreg", function(data, formula, ...) { standardGeneric("spark.isoreg") }) + #' @rdname spark.kmeans #' @export setGeneric("spark.kmeans", function(data, formula, ...) { standardGeneric("spark.kmeans") }) -#' @rdname fitted +#' @rdname spark.kstest #' @export -setGeneric("fitted") +setGeneric("spark.kstest", function(data, ...) { standardGeneric("spark.kstest") }) + +#' @rdname spark.lda +#' @export +setGeneric("spark.lda", function(data, ...) { standardGeneric("spark.lda") }) + +#' @rdname spark.logit +#' @export +setGeneric("spark.logit", function(data, formula, ...) 
{ standardGeneric("spark.logit") }) #' @rdname spark.mlp #' @export @@ -1340,13 +1375,14 @@ setGeneric("spark.mlp", function(data, ...) { standardGeneric("spark.mlp") }) #' @export setGeneric("spark.naiveBayes", function(data, formula, ...) { standardGeneric("spark.naiveBayes") }) -#' @rdname spark.survreg +#' @rdname spark.randomForest #' @export -setGeneric("spark.survreg", function(data, formula) { standardGeneric("spark.survreg") }) +setGeneric("spark.randomForest", + function(data, formula, ...) { standardGeneric("spark.randomForest") }) -#' @rdname spark.lda +#' @rdname spark.survreg #' @export -setGeneric("spark.lda", function(data, ...) { standardGeneric("spark.lda") }) +setGeneric("spark.survreg", function(data, formula) { standardGeneric("spark.survreg") }) #' @rdname spark.lda #' @export @@ -1356,16 +1392,6 @@ setGeneric("spark.posterior", function(object, newData) { standardGeneric("spark #' @export setGeneric("spark.perplexity", function(object, data) { standardGeneric("spark.perplexity") }) -#' @rdname spark.isoreg -#' @export -setGeneric("spark.isoreg", function(data, formula, ...) { standardGeneric("spark.isoreg") }) - -#' @rdname spark.gaussianMixture -#' @export -setGeneric("spark.gaussianMixture", - function(data, formula, ...) { - standardGeneric("spark.gaussianMixture") - }) #' @param object a fitted ML model object. #' @param path the directory where the model is saved. @@ -1373,11 +1399,3 @@ setGeneric("spark.gaussianMixture", #' @rdname write.ml #' @export setGeneric("write.ml", function(object, path, ...) { standardGeneric("write.ml") }) - -#' @rdname spark.als -#' @export -setGeneric("spark.als", function(data, ...) { standardGeneric("spark.als") }) - -#' @rdname spark.kstest -#' @export -setGeneric("spark.kstest", function(data, ...) { standardGeneric("spark.kstest") }) diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R index b901307f8f409..7a220b8d53a2f 100644 --- a/R/pkg/R/mllib.R +++ b/R/pkg/R/mllib.R @@ -95,6 +95,27 @@ setClass("ALSModel", representation(jobj = "jobj")) #' @note KSTest since 2.1.0 setClass("KSTest", representation(jobj = "jobj")) +#' S4 class that represents an LogisticRegressionModel +#' +#' @param jobj a Java object reference to the backing Scala LogisticRegressionModel +#' @export +#' @note LogisticRegressionModel since 2.1.0 +setClass("LogisticRegressionModel", representation(jobj = "jobj")) + +#' S4 class that represents a RandomForestRegressionModel +#' +#' @param jobj a Java object reference to the backing Scala RandomForestRegressionModel +#' @export +#' @note RandomForestRegressionModel since 2.1.0 +setClass("RandomForestRegressionModel", representation(jobj = "jobj")) + +#' S4 class that represents a RandomForestClassificationModel +#' +#' @param jobj a Java object reference to the backing Scala RandomForestClassificationModel +#' @export +#' @note RandomForestClassificationModel since 2.1.0 +setClass("RandomForestClassificationModel", representation(jobj = "jobj")) + #' Saves the MLlib model to the input path #' #' Saves the MLlib model to the input path. 
For more information, see the specific @@ -104,7 +125,8 @@ setClass("KSTest", representation(jobj = "jobj")) #' @export #' @seealso \link{spark.glm}, \link{glm}, #' @seealso \link{spark.als}, \link{spark.gaussianMixture}, \link{spark.isoreg}, \link{spark.kmeans}, -#' @seealso \link{spark.lda}, \link{spark.mlp}, \link{spark.naiveBayes}, \link{spark.survreg} +#' @seealso \link{spark.lda}, \link{spark.logit}, \link{spark.mlp}, \link{spark.naiveBayes}, +#' @seealso \link{spark.randomForest}, \link{spark.survreg}, #' @seealso \link{read.ml} NULL @@ -117,7 +139,8 @@ NULL #' @export #' @seealso \link{spark.glm}, \link{glm}, #' @seealso \link{spark.als}, \link{spark.gaussianMixture}, \link{spark.isoreg}, \link{spark.kmeans}, -#' @seealso \link{spark.mlp}, \link{spark.naiveBayes}, \link{spark.survreg} +#' @seealso \link{spark.logit}, \link{spark.mlp}, \link{spark.naiveBayes}, +#' @seealso \link{spark.randomForest}, \link{spark.survreg} NULL write_internal <- function(object, path, overwrite = FALSE) { @@ -647,6 +670,165 @@ setMethod("predict", signature(object = "KMeansModel"), predict_internal(object, newData) }) +#' Logistic Regression Model +#' +#' Fits an logistic regression model against a Spark DataFrame. It supports "binomial": Binary logistic regression +#' with pivoting; "multinomial": Multinomial logistic (softmax) regression without pivoting, similar to glmnet. +#' Users can print, make predictions on the produced model and save the model to the input path. +#' +#' @param data SparkDataFrame for training +#' @param formula A symbolic description of the model to be fitted. Currently only a few formula +#' operators are supported, including '~', '.', ':', '+', and '-'. +#' @param regParam the regularization parameter. Default is 0.0. +#' @param elasticNetParam the ElasticNet mixing parameter. For alpha = 0.0, the penalty is an L2 penalty. +#' For alpha = 1.0, it is an L1 penalty. For 0.0 < alpha < 1.0, the penalty is a combination +#' of L1 and L2. Default is 0.0 which is an L2 penalty. +#' @param maxIter maximum iteration number. +#' @param tol convergence tolerance of iterations. +#' @param fitIntercept whether to fit an intercept term. Default is TRUE. +#' @param family the name of family which is a description of the label distribution to be used in the model. +#' Supported options: Default is "auto". +#' \itemize{ +#' \item{"auto": Automatically select the family based on the number of classes: +#' If number of classes == 1 || number of classes == 2, set to "binomial". +#' Else, set to "multinomial".} +#' \item{"binomial": Binary logistic regression with pivoting.} +#' \item{"multinomial": Multinomial logistic (softmax) regression without pivoting.} +#' } +#' @param standardization whether to standardize the training features before fitting the model. The coefficients +#' of models will be always returned on the original scale, so it will be transparent for +#' users. Note that with/without standardization, the models should be always converged +#' to the same solution when no regularization is applied. Default is TRUE, same as glmnet. +#' @param thresholds in binary classification, in range [0, 1]. If the estimated probability of class label 1 +#' is > threshold, then predict 1, else 0. A high threshold encourages the model to predict 0 +#' more often; a low threshold encourages the model to predict 1 more often. Note: Setting this with +#' threshold p is equivalent to setting thresholds c(1-p, p). 
In multiclass (or binary) classification to adjust the probability of +#' predicting each class. Array must have length equal to the number of classes, with values > 0, +#' excepting that at most one value may be 0. The class with largest value p/t is predicted, where p +#' is the original probability of that class and t is the class's threshold. Default is 0.5. +#' @param weightCol The weight column name. +#' @param aggregationDepth depth for treeAggregate (>= 2). If the dimensions of features or the number of partitions +#' are large, this param could be adjusted to a larger size. Default is 2. +#' @param probabilityCol column name for predicted class conditional probabilities. Default is "probability". +#' @param ... additional arguments passed to the method. +#' @return \code{spark.logit} returns a fitted logistic regression model +#' @rdname spark.logit +#' @aliases spark.logit,SparkDataFrame,formula-method +#' @name spark.logit +#' @export +#' @examples +#' \dontrun{ +#' sparkR.session() +#' # binary logistic regression +#' label <- c(1.0, 1.0, 1.0, 0.0, 0.0) +#' feature <- c(1.1419053, 0.9194079, -0.9498666, -1.1069903, 0.2809776) +#' binary_data <- as.data.frame(cbind(label, feature)) +#' binary_df <- createDataFrame(binary_data) +#' blr_model <- spark.logit(binary_df, label ~ feature, thresholds = 1.0) +#' blr_predict <- collect(select(predict(blr_model, binary_df), "prediction")) +#' +#' # summary of binary logistic regression +#' blr_summary <- summary(blr_model) +#' blr_fmeasure <- collect(select(blr_summary$fMeasureByThreshold, "threshold", "F-Measure")) +#' # save fitted model to input path +#' path <- "path/to/model" +#' write.ml(blr_model, path) +#' +#' # can also read back the saved model and predict +#' # Note that summary deos not work on loaded model +#' savedModel <- read.ml(path) +#' blr_predict2 <- collect(select(predict(savedModel, binary_df), "prediction")) +#' +#' # multinomial logistic regression +#' +#' label <- c(0.0, 1.0, 2.0, 0.0, 0.0) +#' feature1 <- c(4.845940, 5.64480, 7.430381, 6.464263, 5.555667) +#' feature2 <- c(2.941319, 2.614812, 2.162451, 3.339474, 2.970987) +#' feature3 <- c(1.322733, 1.348044, 3.861237, 9.686976, 3.447130) +#' feature4 <- c(1.3246388, 0.5510444, 0.9225810, 1.2147881, 1.6020842) +#' data <- as.data.frame(cbind(label, feature1, feature2, feature3, feature4)) +#' df <- createDataFrame(data) +#' +#' # Note that summary of multinomial logistic regression is not implemented yet +#' model <- spark.logit(df, label ~ ., family = "multinomial", thresholds = c(0, 1, 1)) +#' predict1 <- collect(select(predict(model, df), "prediction")) +#' } +#' @note spark.logit since 2.1.0 +setMethod("spark.logit", signature(data = "SparkDataFrame", formula = "formula"), + function(data, formula, regParam = 0.0, elasticNetParam = 0.0, maxIter = 100, + tol = 1E-6, fitIntercept = TRUE, family = "auto", standardization = TRUE, + thresholds = 0.5, weightCol = NULL, aggregationDepth = 2, + probabilityCol = "probability") { + formula <- paste0(deparse(formula), collapse = "") + + if (is.null(weightCol)) { + weightCol <- "" + } + + jobj <- callJStatic("org.apache.spark.ml.r.LogisticRegressionWrapper", "fit", + data@sdf, formula, as.numeric(regParam), + as.numeric(elasticNetParam), as.integer(maxIter), + as.numeric(tol), as.logical(fitIntercept), + as.character(family), as.logical(standardization), + as.array(thresholds), as.character(weightCol), + as.integer(aggregationDepth), as.character(probabilityCol)) + new("LogisticRegressionModel", jobj = jobj) + }) + +# 
Predicted values based on an LogisticRegressionModel model + +#' @param newData a SparkDataFrame for testing. +#' @return \code{predict} returns the predicted values based on an LogisticRegressionModel. +#' @rdname spark.logit +#' @aliases predict,LogisticRegressionModel,SparkDataFrame-method +#' @export +#' @note predict(LogisticRegressionModel) since 2.1.0 +setMethod("predict", signature(object = "LogisticRegressionModel"), + function(object, newData) { + predict_internal(object, newData) + }) + +# Get the summary of an LogisticRegressionModel + +#' @param object an LogisticRegressionModel fitted by \code{spark.logit} +#' @return \code{summary} returns the Binary Logistic regression results of a given model as lists. Note that +#' Multinomial logistic regression summary is not available now. +#' @rdname spark.logit +#' @aliases summary,LogisticRegressionModel-method +#' @export +#' @note summary(LogisticRegressionModel) since 2.1.0 +setMethod("summary", signature(object = "LogisticRegressionModel"), + function(object) { + jobj <- object@jobj + is.loaded <- callJMethod(jobj, "isLoaded") + + if (is.loaded) { + stop("Loaded model doesn't have training summary.") + } + + roc <- dataFrame(callJMethod(jobj, "roc")) + + areaUnderROC <- callJMethod(jobj, "areaUnderROC") + + pr <- dataFrame(callJMethod(jobj, "pr")) + + fMeasureByThreshold <- dataFrame(callJMethod(jobj, "fMeasureByThreshold")) + + precisionByThreshold <- dataFrame(callJMethod(jobj, "precisionByThreshold")) + + recallByThreshold <- dataFrame(callJMethod(jobj, "recallByThreshold")) + + totalIterations <- callJMethod(jobj, "totalIterations") + + objectiveHistory <- callJMethod(jobj, "objectiveHistory") + + list(roc = roc, areaUnderROC = areaUnderROC, pr = pr, + fMeasureByThreshold = fMeasureByThreshold, + precisionByThreshold = precisionByThreshold, + recallByThreshold = recallByThreshold, + totalIterations = totalIterations, objectiveHistory = objectiveHistory) + }) + #' Multilayer Perceptron Classification Model #' #' \code{spark.mlp} fits a multi-layer perceptron neural network model against a SparkDataFrame. @@ -665,6 +847,8 @@ setMethod("predict", signature(object = "KMeansModel"), #' @param tol convergence tolerance of iterations. #' @param stepSize stepSize parameter. #' @param seed seed parameter for weights initialization. +#' @param initialWeights initialWeights parameter for weights initialization, it should be a +#' numeric vector. #' @param ... additional arguments passed to the method. #' @return \code{spark.mlp} returns a fitted Multilayer Perceptron Classification Model. 
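
For reference while reading the summary() method above: on a binomial model fitted in the current session, summary() returns a plain R list with the components built there (roc, areaUnderROC, pr, fMeasureByThreshold, precisionByThreshold, recallByThreshold, totalIterations, objectiveHistory); it errors for models loaded via read.ml() and for the multinomial family, whose summary is not implemented yet. A small sketch under those assumptions; the rounded label/feature values below are illustrative, not the exact vectors from the examples:

library(SparkR)
sparkR.session()

train <- createDataFrame(data.frame(
  label = c(1, 1, 1, 0, 0),
  feature = c(1.14, 0.92, -0.95, -1.11, 0.28)))

model <- spark.logit(train, label ~ feature)
s <- summary(model)

s$areaUnderROC                        # scalar metric
s$totalIterations                     # number of training iterations
head(collect(s$roc))                  # SparkDataFrame of ROC points
head(collect(s$fMeasureByThreshold))  # columns "threshold" and "F-Measure"
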
#' @rdname spark.mlp @@ -677,8 +861,9 @@ setMethod("predict", signature(object = "KMeansModel"), #' df <- read.df("data/mllib/sample_multiclass_classification_data.txt", source = "libsvm") #' #' # fit a Multilayer Perceptron Classification Model -#' model <- spark.mlp(df, blockSize = 128, layers = c(4, 5, 4, 3), solver = "l-bfgs", -#' maxIter = 100, tol = 0.5, stepSize = 1, seed = 1) +#' model <- spark.mlp(df, blockSize = 128, layers = c(4, 3), solver = "l-bfgs", +#' maxIter = 100, tol = 0.5, stepSize = 1, seed = 1, +#' initialWeights = c(0, 0, 0, 0, 0, 5, 5, 5, 5, 5, 9, 9, 9, 9, 9)) #' #' # get the summary of the model #' summary(model) @@ -695,7 +880,7 @@ setMethod("predict", signature(object = "KMeansModel"), #' @note spark.mlp since 2.1.0 setMethod("spark.mlp", signature(data = "SparkDataFrame"), function(data, layers, blockSize = 128, solver = "l-bfgs", maxIter = 100, - tol = 1E-6, stepSize = 0.03, seed = NULL) { + tol = 1E-6, stepSize = 0.03, seed = NULL, initialWeights = NULL) { if (is.null(layers)) { stop ("layers must be a integer vector with length > 1.") } @@ -706,10 +891,13 @@ setMethod("spark.mlp", signature(data = "SparkDataFrame"), if (!is.null(seed)) { seed <- as.character(as.integer(seed)) } + if (!is.null(initialWeights)) { + initialWeights <- as.array(as.numeric(na.omit(initialWeights))) + } jobj <- callJStatic("org.apache.spark.ml.r.MultilayerPerceptronClassifierWrapper", "fit", data@sdf, as.integer(blockSize), as.array(layers), as.character(solver), as.integer(maxIter), as.numeric(tol), - as.numeric(stepSize), seed) + as.numeric(stepSize), seed, initialWeights) new("MultilayerPerceptronClassificationModel", jobj = jobj) }) @@ -882,6 +1070,21 @@ setMethod("write.ml", signature(object = "IsotonicRegressionModel", path = "char write_internal(object, path, overwrite) }) +# Save fitted LogisticRegressionModel to the input path + +#' @param path The directory where the model is saved +#' @param overwrite Overwrites or not if the output path already exists. Default is FALSE +#' which means throw exception if the output path exists. +#' +#' @rdname spark.logit +#' @aliases write.ml,LogisticRegressionModel,character-method +#' @export +#' @note write.ml(LogisticRegression, character) since 2.1.0 +setMethod("write.ml", signature(object = "LogisticRegressionModel", path = "character"), + function(object, path, overwrite = FALSE) { + write_internal(object, path, overwrite) + }) + # Save fitted MLlib model to the input path #' @param path the directory where the model is saved. @@ -932,6 +1135,12 @@ read.ml <- function(path) { new("GaussianMixtureModel", jobj = jobj) } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.ALSWrapper")) { new("ALSModel", jobj = jobj) + } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.LogisticRegressionWrapper")) { + new("LogisticRegressionModel", jobj = jobj) + } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.RandomForestRegressorWrapper")) { + new("RandomForestRegressionModel", jobj = jobj) + } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.RandomForestClassifierWrapper")) { + new("RandomForestClassificationModel", jobj = jobj) } else { stop("Unsupported model: ", jobj) } @@ -1427,3 +1636,232 @@ print.summary.KSTest <- function(x, ...) { cat(summaryStr, "\n") invisible(x) } + +#' Random Forest Model for Regression and Classification +#' +#' \code{spark.randomForest} fits a Random Forest Regression model or Classification model on +#' a SparkDataFrame. 
Users can call \code{summary} to get a summary of the fitted Random Forest +#' model, \code{predict} to make predictions on new data, and \code{write.ml}/\code{read.ml} to +#' save/load fitted models. +#' For more details, see +#' \href{http://spark.apache.org/docs/latest/ml-classification-regression.html}{Random Forest} +#' +#' @param data a SparkDataFrame for training. +#' @param formula a symbolic description of the model to be fitted. Currently only a few formula +#' operators are supported, including '~', ':', '+', and '-'. +#' @param type type of model, one of "regression" or "classification", to fit +#' @param maxDepth Maximum depth of the tree (>= 0). (default = 5) +#' @param maxBins Maximum number of bins used for discretizing continuous features and for choosing +#' how to split on features at each node. More bins give higher granularity. Must be +#' >= 2 and >= number of categories in any categorical feature. (default = 32) +#' @param numTrees Number of trees to train (>= 1). +#' @param impurity Criterion used for information gain calculation. +#' For regression, must be "variance". For classification, must be one of +#' "entropy" and "gini". (default = gini) +#' @param minInstancesPerNode Minimum number of instances each child must have after split. +#' @param minInfoGain Minimum information gain for a split to be considered at a tree node. +#' @param checkpointInterval Param for set checkpoint interval (>= 1) or disable checkpoint (-1). +#' @param featureSubsetStrategy The number of features to consider for splits at each tree node. +#' Supported options: "auto", "all", "onethird", "sqrt", "log2", (0.0-1.0], [1-n]. +#' @param seed integer seed for random number generation. +#' @param subsamplingRate Fraction of the training data used for learning each decision tree, in +#' range (0, 1]. (default = 1.0) +#' @param probabilityCol column name for predicted class conditional probabilities, only for +#' classification. (default = "probability") +#' @param maxMemoryInMB Maximum memory in MB allocated to histogram aggregation. +#' @param cacheNodeIds If FALSE, the algorithm will pass trees to executors to match instances with +#' nodes. +#' @param ... additional arguments passed to the method. +#' @aliases spark.randomForest,SparkDataFrame,formula-method +#' @return \code{spark.randomForest} returns a fitted Random Forest model. 
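
As a companion to the @examples below (and to the summary.randomForest helper further down in this file), summary() on a fitted model returns a plain R list mirroring the wrapper's accessors: formula, numFeatures, features, featureImportances (rendered as a string), numTrees and treeWeights. A short sketch reusing the longley data and parameters from the package tests:

library(SparkR)
sparkR.session()

df <- suppressWarnings(createDataFrame(longley))
model <- spark.randomForest(df, Employed ~ ., "regression",
                            maxDepth = 5, maxBins = 16, numTrees = 20, seed = 123)

s <- summary(model)
s$numFeatures          # number of features after formula expansion
s$numTrees             # 20
s$featureImportances   # feature importances as a single string
unlist(s$treeWeights)  # per-tree weights
print(s)               # pretty-printed via print.summary.RandomForestRegressionModel
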
+#' @rdname spark.randomForest +#' @name spark.randomForest +#' @export +#' @examples +#' \dontrun{ +#' # fit a Random Forest Regression Model +#' df <- createDataFrame(longley) +#' model <- spark.randomForest(df, Employed ~ ., type = "regression", maxDepth = 5, maxBins = 16) +#' +#' # get the summary of the model +#' summary(model) +#' +#' # make predictions +#' predictions <- predict(model, df) +#' +#' # save and load the model +#' path <- "path/to/model" +#' write.ml(model, path) +#' savedModel <- read.ml(path) +#' summary(savedModel) +#' +#' # fit a Random Forest Classification Model +#' df <- createDataFrame(iris) +#' model <- spark.randomForest(df, Species ~ Petal_Length + Petal_Width, "classification") +#' } +#' @note spark.randomForest since 2.1.0 +setMethod("spark.randomForest", signature(data = "SparkDataFrame", formula = "formula"), + function(data, formula, type = c("regression", "classification"), + maxDepth = 5, maxBins = 32, numTrees = 20, impurity = NULL, + minInstancesPerNode = 1, minInfoGain = 0.0, checkpointInterval = 10, + featureSubsetStrategy = "auto", seed = NULL, subsamplingRate = 1.0, + probabilityCol = "probability", maxMemoryInMB = 256, cacheNodeIds = FALSE) { + type <- match.arg(type) + formula <- paste(deparse(formula), collapse = "") + if (!is.null(seed)) { + seed <- as.character(as.integer(seed)) + } + switch(type, + regression = { + if (is.null(impurity)) impurity <- "variance" + impurity <- match.arg(impurity, "variance") + jobj <- callJStatic("org.apache.spark.ml.r.RandomForestRegressorWrapper", + "fit", data@sdf, formula, as.integer(maxDepth), + as.integer(maxBins), as.integer(numTrees), + impurity, as.integer(minInstancesPerNode), + as.numeric(minInfoGain), as.integer(checkpointInterval), + as.character(featureSubsetStrategy), seed, + as.numeric(subsamplingRate), + as.integer(maxMemoryInMB), as.logical(cacheNodeIds)) + new("RandomForestRegressionModel", jobj = jobj) + }, + classification = { + if (is.null(impurity)) impurity <- "gini" + impurity <- match.arg(impurity, c("gini", "entropy")) + jobj <- callJStatic("org.apache.spark.ml.r.RandomForestClassifierWrapper", + "fit", data@sdf, formula, as.integer(maxDepth), + as.integer(maxBins), as.integer(numTrees), + impurity, as.integer(minInstancesPerNode), + as.numeric(minInfoGain), as.integer(checkpointInterval), + as.character(featureSubsetStrategy), seed, + as.numeric(subsamplingRate), as.character(probabilityCol), + as.integer(maxMemoryInMB), as.logical(cacheNodeIds)) + new("RandomForestClassificationModel", jobj = jobj) + } + ) + }) + +# Makes predictions from a Random Forest Regression model or Classification model + +#' @param newData a SparkDataFrame for testing. +#' @return \code{predict} returns a SparkDataFrame containing predicted labeled in a column named +#' "prediction" +#' @rdname spark.randomForest +#' @aliases predict,RandomForestRegressionModel-method +#' @export +#' @note predict(randomForestRegressionModel) since 2.1.0 +setMethod("predict", signature(object = "RandomForestRegressionModel"), + function(object, newData) { + predict_internal(object, newData) + }) + +#' @rdname spark.randomForest +#' @aliases predict,RandomForestClassificationModel-method +#' @export +#' @note predict(randomForestClassificationModel) since 2.1.0 +setMethod("predict", signature(object = "RandomForestClassificationModel"), + function(object, newData) { + predict_internal(object, newData) + }) + +# Save the Random Forest Regression or Classification model to the input path. 
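
A short sketch of the save/load round trip implemented by the write.ml methods below and the read.ml dispatch added earlier in this file: write.ml refuses to overwrite an existing path unless overwrite = TRUE, and read.ml returns the matching S4 model class. The tempfile pattern mirrors the tests; numTrees/seed values here are illustrative:

library(SparkR)
sparkR.session()

df <- suppressWarnings(createDataFrame(longley))
model <- spark.randomForest(df, Employed ~ ., "regression", numTrees = 5, seed = 1)

modelPath <- tempfile(pattern = "spark-randomForestRegression", fileext = ".tmp")
write.ml(model, modelPath)                   # first save succeeds
# write.ml(model, modelPath)                 # would error: path exists, overwrite defaults to FALSE
write.ml(model, modelPath, overwrite = TRUE)

restored <- read.ml(modelPath)               # dispatches to RandomForestRegressionModel
preds <- collect(predict(restored, df))
unlink(modelPath, recursive = TRUE)
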
+ +#' @param object A fitted Random Forest regression model or classification model +#' @param path The directory where the model is saved +#' @param overwrite Overwrites or not if the output path already exists. Default is FALSE +#' which means throw exception if the output path exists. +#' +#' @aliases write.ml,RandomForestRegressionModel,character-method +#' @rdname spark.randomForest +#' @export +#' @note write.ml(RandomForestRegressionModel, character) since 2.1.0 +setMethod("write.ml", signature(object = "RandomForestRegressionModel", path = "character"), + function(object, path, overwrite = FALSE) { + write_internal(object, path, overwrite) + }) + +#' @aliases write.ml,RandomForestClassificationModel,character-method +#' @rdname spark.randomForest +#' @export +#' @note write.ml(RandomForestClassificationModel, character) since 2.1.0 +setMethod("write.ml", signature(object = "RandomForestClassificationModel", path = "character"), + function(object, path, overwrite = FALSE) { + write_internal(object, path, overwrite) + }) + +# Get the summary of an RandomForestRegressionModel model +summary.randomForest <- function(model) { + jobj <- model@jobj + formula <- callJMethod(jobj, "formula") + numFeatures <- callJMethod(jobj, "numFeatures") + features <- callJMethod(jobj, "features") + featureImportances <- callJMethod(callJMethod(jobj, "featureImportances"), "toString") + numTrees <- callJMethod(jobj, "numTrees") + treeWeights <- callJMethod(jobj, "treeWeights") + list(formula = formula, + numFeatures = numFeatures, + features = features, + featureImportances = featureImportances, + numTrees = numTrees, + treeWeights = treeWeights, + jobj = jobj) +} + +#' @return \code{summary} returns the model's features as lists, depth and number of nodes +#' or number of classes. +#' @rdname spark.randomForest +#' @aliases summary,RandomForestRegressionModel-method +#' @export +#' @note summary(RandomForestRegressionModel) since 2.1.0 +setMethod("summary", signature(object = "RandomForestRegressionModel"), + function(object) { + ans <- summary.randomForest(object) + class(ans) <- "summary.RandomForestRegressionModel" + ans + }) + +# Get the summary of an RandomForestClassificationModel model + +#' @rdname spark.randomForest +#' @aliases summary,RandomForestClassificationModel-method +#' @export +#' @note summary(RandomForestClassificationModel) since 2.1.0 +setMethod("summary", signature(object = "RandomForestClassificationModel"), + function(object) { + ans <- summary.randomForest(object) + class(ans) <- "summary.RandomForestClassificationModel" + ans + }) + +# Prints the summary of Random Forest Regression Model +print.summary.randomForest <- function(x) { + jobj <- x$jobj + cat("Formula: ", x$formula) + cat("\nNumber of features: ", x$numFeatures) + cat("\nFeatures: ", unlist(x$features)) + cat("\nFeature importances: ", x$featureImportances) + cat("\nNumber of trees: ", x$numTrees) + cat("\nTree weights: ", unlist(x$treeWeights)) + + summaryStr <- callJMethod(jobj, "summary") + cat("\n", summaryStr, "\n") + invisible(x) +} + +#' @param x summary object of Random Forest regression model or classification model +#' returned by \code{summary}. +#' @rdname spark.randomForest +#' @export +#' @note print.summary.RandomForestRegressionModel since 2.1.0 +print.summary.RandomForestRegressionModel <- function(x, ...) 
{ + print.summary.randomForest(x) +} + +# Prints the summary of Random Forest Classification Model + +#' @rdname spark.randomForest +#' @export +#' @note print.summary.RandomForestClassificationModel since 2.1.0 +print.summary.RandomForestClassificationModel <- function(x, ...) { + print.summary.randomForest(x) +} diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R index cc6d591bb2f4c..6b4a2f2fdc85c 100644 --- a/R/pkg/R/sparkR.R +++ b/R/pkg/R/sparkR.R @@ -154,6 +154,7 @@ sparkR.sparkContext <- function( packages <- processSparkPackages(sparkPackages) existingPort <- Sys.getenv("EXISTING_SPARKR_BACKEND_PORT", "") + connectionTimeout <- as.numeric(Sys.getenv("SPARKR_BACKEND_CONNECTION_TIMEOUT", "6000")) if (existingPort != "") { if (length(packages) != 0) { warning(paste("sparkPackages has no effect when using spark-submit or sparkR shell", @@ -187,6 +188,7 @@ sparkR.sparkContext <- function( backendPort <- readInt(f) monitorPort <- readInt(f) rLibPath <- readString(f) + connectionTimeout <- readInt(f) close(f) file.remove(path) if (length(backendPort) == 0 || backendPort == 0 || @@ -194,7 +196,9 @@ sparkR.sparkContext <- function( length(rLibPath) != 1) { stop("JVM failed to launch") } - assign(".monitorConn", socketConnection(port = monitorPort), envir = .sparkREnv) + assign(".monitorConn", + socketConnection(port = monitorPort, timeout = connectionTimeout), + envir = .sparkREnv) assign(".backendLaunched", 1, envir = .sparkREnv) if (rLibPath != "") { assign(".libPath", rLibPath, envir = .sparkREnv) @@ -204,7 +208,7 @@ sparkR.sparkContext <- function( .sparkREnv$backendPort <- backendPort tryCatch({ - connectBackend("localhost", backendPort) + connectBackend("localhost", backendPort, timeout = connectionTimeout) }, error = function(err) { stop("Failed to connect JVM\n") diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R index fa8bb0f79ce80..c4e78cbb804d9 100644 --- a/R/pkg/R/utils.R +++ b/R/pkg/R/utils.R @@ -385,6 +385,47 @@ getStorageLevel <- function(newLevel = c("DISK_ONLY", "OFF_HEAP" = callJStatic(storageLevelClass, "OFF_HEAP")) } +storageLevelToString <- function(levelObj) { + useDisk <- callJMethod(levelObj, "useDisk") + useMemory <- callJMethod(levelObj, "useMemory") + useOffHeap <- callJMethod(levelObj, "useOffHeap") + deserialized <- callJMethod(levelObj, "deserialized") + replication <- callJMethod(levelObj, "replication") + shortName <- if (!useDisk && !useMemory && !useOffHeap && !deserialized && replication == 1) { + "NONE" + } else if (useDisk && !useMemory && !useOffHeap && !deserialized && replication == 1) { + "DISK_ONLY" + } else if (useDisk && !useMemory && !useOffHeap && !deserialized && replication == 2) { + "DISK_ONLY_2" + } else if (!useDisk && useMemory && !useOffHeap && deserialized && replication == 1) { + "MEMORY_ONLY" + } else if (!useDisk && useMemory && !useOffHeap && deserialized && replication == 2) { + "MEMORY_ONLY_2" + } else if (!useDisk && useMemory && !useOffHeap && !deserialized && replication == 1) { + "MEMORY_ONLY_SER" + } else if (!useDisk && useMemory && !useOffHeap && !deserialized && replication == 2) { + "MEMORY_ONLY_SER_2" + } else if (useDisk && useMemory && !useOffHeap && deserialized && replication == 1) { + "MEMORY_AND_DISK" + } else if (useDisk && useMemory && !useOffHeap && deserialized && replication == 2) { + "MEMORY_AND_DISK_2" + } else if (useDisk && useMemory && !useOffHeap && !deserialized && replication == 1) { + "MEMORY_AND_DISK_SER" + } else if (useDisk && useMemory && !useOffHeap && !deserialized && replication == 2) { + 
"MEMORY_AND_DISK_SER_2" + } else if (useDisk && useMemory && useOffHeap && !deserialized && replication == 1) { + "OFF_HEAP" + } else { + NULL + } + fullInfo <- callJMethod(levelObj, "toString") + if (is.null(shortName)) { + fullInfo + } else { + paste(shortName, "-", fullInfo) + } +} + # Utility function for functions where an argument needs to be integer but we want to allow # the user to type (for example) `5` instead of `5L` to avoid a confusing error message. numToInt <- function(num) { diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R index c99315726a22c..db98d0e45547e 100644 --- a/R/pkg/inst/tests/testthat/test_mllib.R +++ b/R/pkg/inst/tests/testthat/test_mllib.R @@ -410,6 +410,21 @@ test_that("spark.mlp", { model <- spark.mlp(df, layers = c(4, 5, 4, 3), maxIter = 10, seed = 10) mlpPredictions <- collect(select(predict(model, mlpTestDF), "prediction")) expect_equal(head(mlpPredictions$prediction, 12), c(1, 1, 1, 1, 2, 1, 2, 2, 1, 0, 0, 1)) + + # test initialWeights + model <- spark.mlp(df, layers = c(4, 3), maxIter = 2, initialWeights = + c(0, 0, 0, 0, 0, 5, 5, 5, 5, 5, 9, 9, 9, 9, 9)) + mlpPredictions <- collect(select(predict(model, mlpTestDF), "prediction")) + expect_equal(head(mlpPredictions$prediction, 12), c(1, 1, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1)) + + model <- spark.mlp(df, layers = c(4, 3), maxIter = 2, initialWeights = + c(0.0, 0.0, 0.0, 0.0, 0.0, 5.0, 5.0, 5.0, 5.0, 5.0, 9.0, 9.0, 9.0, 9.0, 9.0)) + mlpPredictions <- collect(select(predict(model, mlpTestDF), "prediction")) + expect_equal(head(mlpPredictions$prediction, 12), c(1, 1, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1)) + + model <- spark.mlp(df, layers = c(4, 3), maxIter = 2) + mlpPredictions <- collect(select(predict(model, mlpTestDF), "prediction")) + expect_equal(head(mlpPredictions$prediction, 12), c(1, 1, 1, 1, 0, 1, 0, 2, 1, 0, 0, 1)) }) test_that("spark.naiveBayes", { @@ -587,6 +602,61 @@ test_that("spark.isotonicRegression", { unlink(modelPath) }) +test_that("spark.logit", { + # test binary logistic regression + label <- c(1.0, 1.0, 1.0, 0.0, 0.0) + feature <- c(1.1419053, 0.9194079, -0.9498666, -1.1069903, 0.2809776) + binary_data <- as.data.frame(cbind(label, feature)) + binary_df <- createDataFrame(binary_data) + + blr_model <- spark.logit(binary_df, label ~ feature, thresholds = 1.0) + blr_predict <- collect(select(predict(blr_model, binary_df), "prediction")) + expect_equal(blr_predict$prediction, c(0, 0, 0, 0, 0)) + blr_model1 <- spark.logit(binary_df, label ~ feature, thresholds = 0.0) + blr_predict1 <- collect(select(predict(blr_model1, binary_df), "prediction")) + expect_equal(blr_predict1$prediction, c(1, 1, 1, 1, 1)) + + # test summary of binary logistic regression + blr_summary <- summary(blr_model) + blr_fmeasure <- collect(select(blr_summary$fMeasureByThreshold, "threshold", "F-Measure")) + expect_equal(blr_fmeasure$threshold, c(0.8221347, 0.7884005, 0.6674709, 0.3785437, 0.3434487), + tolerance = 1e-4) + expect_equal(blr_fmeasure$"F-Measure", c(0.5000000, 0.8000000, 0.6666667, 0.8571429, 0.7500000), + tolerance = 1e-4) + blr_precision <- collect(select(blr_summary$precisionByThreshold, "threshold", "precision")) + expect_equal(blr_precision$precision, c(1.0000000, 1.0000000, 0.6666667, 0.7500000, 0.6000000), + tolerance = 1e-4) + blr_recall <- collect(select(blr_summary$recallByThreshold, "threshold", "recall")) + expect_equal(blr_recall$recall, c(0.3333333, 0.6666667, 0.6666667, 1.0000000, 1.0000000), + tolerance = 1e-4) + + # test model save and read + modelPath <- 
tempfile(pattern = "spark-logisticRegression", fileext = ".tmp") + write.ml(blr_model, modelPath) + expect_error(write.ml(blr_model, modelPath)) + write.ml(blr_model, modelPath, overwrite = TRUE) + blr_model2 <- read.ml(modelPath) + blr_predict2 <- collect(select(predict(blr_model2, binary_df), "prediction")) + expect_equal(blr_predict$prediction, blr_predict2$prediction) + expect_error(summary(blr_model2)) + unlink(modelPath) + + # test multinomial logistic regression + label <- c(0.0, 1.0, 2.0, 0.0, 0.0) + feature1 <- c(4.845940, 5.64480, 7.430381, 6.464263, 5.555667) + feature2 <- c(2.941319, 2.614812, 2.162451, 3.339474, 2.970987) + feature3 <- c(1.322733, 1.348044, 3.861237, 9.686976, 3.447130) + feature4 <- c(1.3246388, 0.5510444, 0.9225810, 1.2147881, 1.6020842) + data <- as.data.frame(cbind(label, feature1, feature2, feature3, feature4)) + df <- createDataFrame(data) + + model <- spark.logit(df, label ~., family = "multinomial", thresholds = c(0, 1, 1)) + predict1 <- collect(select(predict(model, df), "prediction")) + expect_equal(predict1$prediction, c(0, 0, 0, 0, 0)) + # Summary of multinomial logistic regression is not implemented yet + expect_error(summary(model)) +}) + test_that("spark.gaussianMixture", { # R code to reproduce the result. # nolint start @@ -801,4 +871,72 @@ test_that("spark.kstest", { expect_match(capture.output(stats)[1], "Kolmogorov-Smirnov test summary:") }) +test_that("spark.randomForest Regression", { + data <- suppressWarnings(createDataFrame(longley)) + model <- spark.randomForest(data, Employed ~ ., "regression", maxDepth = 5, maxBins = 16, + numTrees = 1) + + predictions <- collect(predict(model, data)) + expect_equal(predictions$prediction, c(60.323, 61.122, 60.171, 61.187, + 63.221, 63.639, 64.989, 63.761, + 66.019, 67.857, 68.169, 66.513, + 68.655, 69.564, 69.331, 70.551), + tolerance = 1e-4) + + stats <- summary(model) + expect_equal(stats$numTrees, 1) + expect_error(capture.output(stats), NA) + expect_true(length(capture.output(stats)) > 6) + + model <- spark.randomForest(data, Employed ~ ., "regression", maxDepth = 5, maxBins = 16, + numTrees = 20, seed = 123) + predictions <- collect(predict(model, data)) + expect_equal(predictions$prediction, c(60.379, 61.096, 60.636, 62.258, + 63.736, 64.296, 64.868, 64.300, + 66.709, 67.697, 67.966, 67.252, + 68.866, 69.593, 69.195, 69.658), + tolerance = 1e-4) + stats <- summary(model) + expect_equal(stats$numTrees, 20) + + modelPath <- tempfile(pattern = "spark-randomForestRegression", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + stats2 <- summary(model2) + expect_equal(stats$formula, stats2$formula) + expect_equal(stats$numFeatures, stats2$numFeatures) + expect_equal(stats$features, stats2$features) + expect_equal(stats$featureImportances, stats2$featureImportances) + expect_equal(stats$numTrees, stats2$numTrees) + expect_equal(stats$treeWeights, stats2$treeWeights) + + unlink(modelPath) +}) + +test_that("spark.randomForest Classification", { + data <- suppressWarnings(createDataFrame(iris)) + model <- spark.randomForest(data, Species ~ Petal_Length + Petal_Width, "classification", + maxDepth = 5, maxBins = 16) + + stats <- summary(model) + expect_equal(stats$numFeatures, 2) + expect_equal(stats$numTrees, 20) + expect_error(capture.output(stats), NA) + expect_true(length(capture.output(stats)) > 6) + + modelPath <- tempfile(pattern = "spark-randomForestClassification", fileext = 
".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + stats2 <- summary(model2) + expect_equal(stats$depth, stats2$depth) + expect_equal(stats$numNodes, stats2$numNodes) + expect_equal(stats$numClasses, stats2$numClasses) + + unlink(modelPath) +}) + sparkR.session.stop() diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index af81d0586e0a6..9289db57b6d63 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -212,7 +212,7 @@ test_that("createDataFrame uses files for large objects", { # To simulate a large file scenario, we set spark.r.maxAllocationLimit to a smaller value conf <- callJMethod(sparkSession, "conf") callJMethod(conf, "set", "spark.r.maxAllocationLimit", "100") - df <- createDataFrame(iris) + df <- suppressWarnings(createDataFrame(iris)) # Resetting the conf back to default value callJMethod(conf, "set", "spark.r.maxAllocationLimit", toString(.Machine$integer.max / 10)) @@ -390,6 +390,19 @@ test_that("create DataFrame with different data types", { expect_equal(collect(df), data.frame(l, stringsAsFactors = FALSE)) }) +test_that("SPARK-17811: can create DataFrame containing NA as date and time", { + df <- data.frame( + id = 1:2, + time = c(as.POSIXlt("2016-01-10"), NA), + date = c(as.Date("2016-10-01"), NA)) + + DF <- collect(createDataFrame(df)) + expect_true(is.na(DF$date[2])) + expect_equal(DF$date[1], as.Date("2016-10-01")) + expect_true(is.na(DF$time[2])) + expect_equal(DF$time[1], as.POSIXlt("2016-01-10")) +}) + test_that("create DataFrame with complex types", { e <- new.env() assign("n", 3L, envir = e) @@ -783,7 +796,7 @@ test_that("multiple pipeline transformations result in an RDD with the correct v expect_false(collectRDD(second)[[3]]$testCol) }) -test_that("cache(), persist(), and unpersist() on a DataFrame", { +test_that("cache(), storageLevel(), persist(), and unpersist() on a DataFrame", { df <- read.json(jsonPath) expect_false(df@env$isCached) cache(df) @@ -795,6 +808,9 @@ test_that("cache(), persist(), and unpersist() on a DataFrame", { persist(df, "MEMORY_AND_DISK") expect_true(df@env$isCached) + expect_equal(storageLevel(df), + "MEMORY_AND_DISK - StorageLevel(disk, memory, deserialized, 1 replicas)") + unpersist(df) expect_false(df@env$isCached) @@ -832,7 +848,7 @@ test_that("names() colnames() set the column names", { expect_equal(names(df)[1], "col3") expect_error(colnames(df) <- c("sepal.length", "sepal_width"), - "Colum names cannot contain the '.' symbol.") + "Column names cannot contain the '.' 
symbol.") expect_error(colnames(df) <- c(1, 2), "Invalid column names.") expect_error(colnames(df) <- c("a"), "Column names must have the same length as the number of columns in the dataset.") @@ -1572,7 +1588,7 @@ test_that("filter() on a DataFrame", { #expect_true(is.ts(filter(1:100, rep(1, 3)))) # nolint }) -test_that("join() and merge() on a DataFrame", { +test_that("join(), crossJoin() and merge() on a DataFrame", { df <- read.json(jsonPath) mockLines2 <- c("{\"name\":\"Michael\", \"test\": \"yes\"}", @@ -1583,7 +1599,14 @@ test_that("join() and merge() on a DataFrame", { writeLines(mockLines2, jsonPath2) df2 <- read.json(jsonPath2) - joined <- join(df, df2) + # inner join, not cartesian join + expect_equal(count(where(join(df, df2), df$name == df2$name)), 3) + # cartesian join + expect_error(tryCatch(count(join(df, df2)), error = function(e) { stop(e) }), + paste0(".*(org.apache.spark.sql.AnalysisException: Detected cartesian product for", + " INNER join between logical plans).*")) + + joined <- crossJoin(df, df2) expect_equal(names(joined), c("age", "name", "name", "test")) expect_equal(count(joined), 12) expect_equal(names(collect(joined)), c("age", "name", "name", "test")) diff --git a/R/pkg/inst/worker/daemon.R b/R/pkg/inst/worker/daemon.R index b92e6be995ca9..3a318b71ea06d 100644 --- a/R/pkg/inst/worker/daemon.R +++ b/R/pkg/inst/worker/daemon.R @@ -18,6 +18,7 @@ # Worker daemon rLibDir <- Sys.getenv("SPARKR_RLIBDIR") +connectionTimeout <- as.integer(Sys.getenv("SPARKR_BACKEND_CONNECTION_TIMEOUT", "6000")) dirs <- strsplit(rLibDir, ",")[[1]] script <- file.path(dirs[[1]], "SparkR", "worker", "worker.R") @@ -26,7 +27,8 @@ script <- file.path(dirs[[1]], "SparkR", "worker", "worker.R") suppressPackageStartupMessages(library(SparkR)) port <- as.integer(Sys.getenv("SPARKR_WORKER_PORT")) -inputCon <- socketConnection(port = port, open = "rb", blocking = TRUE, timeout = 3600) +inputCon <- socketConnection( + port = port, open = "rb", blocking = TRUE, timeout = connectionTimeout) while (TRUE) { ready <- socketSelect(list(inputCon)) diff --git a/R/pkg/inst/worker/worker.R b/R/pkg/inst/worker/worker.R index cfe41ded200c2..03e7450147865 100644 --- a/R/pkg/inst/worker/worker.R +++ b/R/pkg/inst/worker/worker.R @@ -90,6 +90,7 @@ bootTime <- currentTimeSecs() bootElap <- elapsedSecs() rLibDir <- Sys.getenv("SPARKR_RLIBDIR") +connectionTimeout <- as.integer(Sys.getenv("SPARKR_BACKEND_CONNECTION_TIMEOUT", "6000")) dirs <- strsplit(rLibDir, ",")[[1]] # Set libPaths to include SparkR package as loadNamespace needs this # TODO: Figure out if we can avoid this by not loading any objects that require @@ -98,8 +99,10 @@ dirs <- strsplit(rLibDir, ",")[[1]] suppressPackageStartupMessages(library(SparkR)) port <- as.integer(Sys.getenv("SPARKR_WORKER_PORT")) -inputCon <- socketConnection(port = port, blocking = TRUE, open = "rb") -outputCon <- socketConnection(port = port, blocking = TRUE, open = "wb") +inputCon <- socketConnection( + port = port, blocking = TRUE, open = "rb", timeout = connectionTimeout) +outputCon <- socketConnection( + port = port, blocking = TRUE, open = "wb", timeout = connectionTimeout) # read the index of the current partition inside the RDD partition <- SparkR:::readInt(inputCon) diff --git a/R/run-tests.sh b/R/run-tests.sh index 1a1e8ab9ffe18..5e4dafaf76f3d 100755 --- a/R/run-tests.sh +++ b/R/run-tests.sh @@ -26,6 +26,8 @@ rm -f $LOGFILE SPARK_TESTING=1 $FWDIR/../bin/spark-submit --driver-java-options "-Dlog4j.configuration=file:$FWDIR/log4j.properties" --conf 
spark.hadoop.fs.default.name="file:///" $FWDIR/pkg/tests/run-all.R 2>&1 | tee -a $LOGFILE FAILED=$((PIPESTATUS[0]||$FAILED)) +NUM_TEST_WARNING="$(grep -c -e 'Warnings ----------------' $LOGFILE)" + # Also run the documentation tests for CRAN CRAN_CHECK_LOG_FILE=$FWDIR/cran-check.out rm -f $CRAN_CHECK_LOG_FILE @@ -37,10 +39,10 @@ NUM_CRAN_WARNING="$(grep -c WARNING$ $CRAN_CHECK_LOG_FILE)" NUM_CRAN_ERROR="$(grep -c ERROR$ $CRAN_CHECK_LOG_FILE)" NUM_CRAN_NOTES="$(grep -c NOTE$ $CRAN_CHECK_LOG_FILE)" -if [[ $FAILED != 0 ]]; then +if [[ $FAILED != 0 || $NUM_TEST_WARNING != 0 ]]; then cat $LOGFILE echo -en "\033[31m" # Red - echo "Had test failures; see logs." + echo "Had test warnings or failures; see logs." echo -en "\033[0m" # No color exit -1 else diff --git a/bin/pyspark b/bin/pyspark index 7590309b442ed..d6b3ab0a44321 100755 --- a/bin/pyspark +++ b/bin/pyspark @@ -57,7 +57,7 @@ export PYSPARK_PYTHON # Add the PySpark classes to the Python path: export PYTHONPATH="${SPARK_HOME}/python/:$PYTHONPATH" -export PYTHONPATH="${SPARK_HOME}/python/lib/py4j-0.10.3-src.zip:$PYTHONPATH" +export PYTHONPATH="${SPARK_HOME}/python/lib/py4j-0.10.4-src.zip:$PYTHONPATH" # Load the PySpark shell.py script when ./pyspark is used interactively: export OLD_PYTHONSTARTUP="$PYTHONSTARTUP" diff --git a/bin/pyspark2.cmd b/bin/pyspark2.cmd index 1217a4f2f97a2..f211c0873ad2f 100644 --- a/bin/pyspark2.cmd +++ b/bin/pyspark2.cmd @@ -30,7 +30,7 @@ if "x%PYSPARK_DRIVER_PYTHON%"=="x" ( ) set PYTHONPATH=%SPARK_HOME%\python;%PYTHONPATH% -set PYTHONPATH=%SPARK_HOME%\python\lib\py4j-0.10.3-src.zip;%PYTHONPATH% +set PYTHONPATH=%SPARK_HOME%\python\lib\py4j-0.10.4-src.zip;%PYTHONPATH% set OLD_PYTHONSTARTUP=%PYTHONSTARTUP% set PYTHONSTARTUP=%SPARK_HOME%\python\pyspark\shell.py diff --git a/conf/spark-env.sh.template b/conf/spark-env.sh.template index c750c72d19880..5c1e876ef9afc 100755 --- a/conf/spark-env.sh.template +++ b/conf/spark-env.sh.template @@ -63,3 +63,4 @@ # - SPARK_PID_DIR Where the pid file is stored. (Default: /tmp) # - SPARK_IDENT_STRING A string representing this instance of spark. (Default: $USER) # - SPARK_NICENESS The scheduling priority for daemons. (Default: 0) +# - SPARK_NO_DAEMONIZE Run the proposed command in the foreground. It will not output a PID file. diff --git a/core/pom.xml b/core/pom.xml index 205bbc588be09..eac99ab82a2e4 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -331,7 +331,7 @@ net.sf.py4j py4j - 0.10.3 + 0.10.4 org.apache.spark diff --git a/core/src/main/java/org/apache/spark/io/NioBufferedFileInputStream.java b/core/src/main/java/org/apache/spark/io/NioBufferedFileInputStream.java new file mode 100644 index 0000000000000..f6d1288cb263d --- /dev/null +++ b/core/src/main/java/org/apache/spark/io/NioBufferedFileInputStream.java @@ -0,0 +1,137 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.io; + +import org.apache.spark.storage.StorageUtils; + +import java.io.File; +import java.io.IOException; +import java.io.InputStream; +import java.nio.ByteBuffer; +import java.nio.channels.FileChannel; +import java.nio.file.StandardOpenOption; + +/** + * {@link InputStream} implementation which uses direct buffer + * to read a file to avoid extra copy of data between Java and + * native memory which happens when using {@link java.io.BufferedInputStream}. + * Unfortunately, this is not something already available in the JDK; + * {@link sun.nio.ch.ChannelInputStream} supports reading a file using nio, + * but does not support buffering. + */ +public final class NioBufferedFileInputStream extends InputStream { + + private static final int DEFAULT_BUFFER_SIZE_BYTES = 8192; + + private final ByteBuffer byteBuffer; + + private final FileChannel fileChannel; + + public NioBufferedFileInputStream(File file, int bufferSizeInBytes) throws IOException { + byteBuffer = ByteBuffer.allocateDirect(bufferSizeInBytes); + fileChannel = FileChannel.open(file.toPath(), StandardOpenOption.READ); + byteBuffer.flip(); + } + + public NioBufferedFileInputStream(File file) throws IOException { + this(file, DEFAULT_BUFFER_SIZE_BYTES); + } + + /** + * Checks whether data is left to be read from the input stream. + * @return true if data is left, false otherwise + * @throws IOException + */ + private boolean refill() throws IOException { + if (!byteBuffer.hasRemaining()) { + byteBuffer.clear(); + int nRead = 0; + while (nRead == 0) { + nRead = fileChannel.read(byteBuffer); + } + if (nRead < 0) { + return false; + } + byteBuffer.flip(); + } + return true; + } + + @Override + public synchronized int read() throws IOException { + if (!refill()) { + return -1; + } + return byteBuffer.get() & 0xFF; + } + + @Override + public synchronized int read(byte[] b, int offset, int len) throws IOException { + if (offset < 0 || len < 0 || offset + len < 0 || offset + len > b.length) { + throw new IndexOutOfBoundsException(); + } + if (!refill()) { + return -1; + } + len = Math.min(len, byteBuffer.remaining()); + byteBuffer.get(b, offset, len); + return len; + } + + @Override + public synchronized int available() throws IOException { + return byteBuffer.remaining(); + } + + @Override + public synchronized long skip(long n) throws IOException { + if (n <= 0L) { + return 0L; + } + if (byteBuffer.remaining() >= n) { + // The buffered content is enough to skip + byteBuffer.position(byteBuffer.position() + (int) n); + return n; + } + long skippedFromBuffer = byteBuffer.remaining(); + long toSkipFromFileChannel = n - skippedFromBuffer; + // Discard everything we have read in the buffer.
+ byteBuffer.position(0); + byteBuffer.flip(); + return skippedFromBuffer + skipFromFileChannel(toSkipFromFileChannel); + } + + private long skipFromFileChannel(long n) throws IOException { + long currentFilePosition = fileChannel.position(); + long size = fileChannel.size(); + if (n > size - currentFilePosition) { + fileChannel.position(size); + return size - currentFilePosition; + } else { + fileChannel.position(currentFilePosition + n); + return n; + } + } + + @Override + public synchronized void close() throws IOException { + fileChannel.close(); + StorageUtils.dispose(byteBuffer); + } + + @Override + protected void finalize() throws IOException { + close(); + } +} diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java index 7835017910232..dcae4a34c4b0b 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -143,9 +143,10 @@ private UnsafeExternalSorter( this.recordComparator = recordComparator; this.prefixComparator = prefixComparator; // Use getSizeAsKb (not bytes) to maintain backwards compatibility for units - // this.fileBufferSizeBytes = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024; + // this.fileBufferSizeBytes = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024 this.fileBufferSizeBytes = 32 * 1024; - // The spill metrics are stored in a new ShuffleWriteMetrics, and then discarded (this fixes SPARK-16827). + // The spill metrics are stored in a new ShuffleWriteMetrics, + // and then discarded (this fixes SPARK-16827). // TODO: Instead, separate spill metrics should be stored and reported (tracked in SPARK-3577). 
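The NioBufferedFileInputStream introduced above reads through a direct ByteBuffer, so buffered reads of spill files skip the extra heap copy made by java.io.BufferedInputStream. A minimal usage sketch in Scala, assuming only the two constructors shown in the patch; the file path and the int-length framing are illustrative, not part of the change:

    import java.io.{DataInputStream, File}
    import org.apache.spark.io.NioBufferedFileInputStream

    val spillFile = new File("/tmp/example-spill.bin")  // hypothetical path
    // Wrap the NIO-backed stream exactly where a BufferedInputStream would have gone.
    val in = new DataInputStream(new NioBufferedFileInputStream(spillFile, 1024 * 1024))
    try {
      val recordLength = in.readInt()  // assumes the writer framed records with an int length
      println(s"first record is $recordLength bytes")
    } finally {
      in.close()  // closes the channel and disposes the direct buffer
    }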
this.writeMetrics = new ShuffleWriteMetrics(); diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java index e6d9766c31574..a658e5eb47b78 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java @@ -23,6 +23,7 @@ import com.google.common.io.Closeables; import org.apache.spark.SparkEnv; +import org.apache.spark.io.NioBufferedFileInputStream; import org.apache.spark.serializer.SerializerManager; import org.apache.spark.storage.BlockId; import org.apache.spark.unsafe.Platform; @@ -69,8 +70,8 @@ public UnsafeSorterSpillReader( bufferSizeBytes = DEFAULT_BUFFER_SIZE_BYTES; } - final BufferedInputStream bs = - new BufferedInputStream(new FileInputStream(file), (int) bufferSizeBytes); + final InputStream bs = + new NioBufferedFileInputStream(file, (int) bufferSizeBytes); try { this.in = serializerManager.wrapStream(blockId, bs); this.din = new DataInputStream(this.in); diff --git a/core/src/main/resources/org/apache/spark/ui/static/timeline-view.js b/core/src/main/resources/org/apache/spark/ui/static/timeline-view.js index a6153ceda75e2..705a08f0293d3 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/timeline-view.js +++ b/core/src/main/resources/org/apache/spark/ui/static/timeline-view.js @@ -24,6 +24,7 @@ function drawApplicationTimeline(groupArray, eventObjArray, startTime, offset) { return a.value - b.value }, editable: false, + align: 'left', showCurrentTime: false, min: startTime, zoomable: false, @@ -99,6 +100,7 @@ function drawJobTimeline(groupArray, eventObjArray, startTime, offset) { return a.value - b.value; }, editable: false, + align: 'left', showCurrentTime: false, min: startTime, zoomable: false, diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala index 701097ace8974..c4e55b5e89027 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala @@ -32,7 +32,7 @@ private[spark] object PythonUtils { val pythonPath = new ArrayBuffer[String] for (sparkHome <- sys.env.get("SPARK_HOME")) { pythonPath += Seq(sparkHome, "python", "lib", "pyspark.zip").mkString(File.separator) - pythonPath += Seq(sparkHome, "python", "lib", "py4j-0.10.3-src.zip").mkString(File.separator) + pythonPath += Seq(sparkHome, "python", "lib", "py4j-0.10.4-src.zip").mkString(File.separator) } pythonPath ++= SparkContext.jarOfObject(this) pythonPath.mkString(File.pathSeparator) diff --git a/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala b/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala index 34cb7c61d7034..86965dbc2e778 100644 --- a/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala +++ b/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala @@ -144,7 +144,7 @@ object WriteInputFormatTestDataGenerator { // Create test data for ArrayWritable val data = Seq( - (1, Array()), + (1, Array.empty[Double]), (2, Array(3.0, 4.0, 5.0)), (3, Array(4.0, 5.0, 6.0)) ) diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackend.scala b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala index 
41d0a85ee3ad4..550746c552d02 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RBackend.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala @@ -22,12 +22,13 @@ import java.net.{InetAddress, InetSocketAddress, ServerSocket} import java.util.concurrent.TimeUnit import io.netty.bootstrap.ServerBootstrap -import io.netty.channel.{ChannelFuture, ChannelInitializer, EventLoopGroup} +import io.netty.channel.{ChannelFuture, ChannelInitializer, ChannelOption, EventLoopGroup} import io.netty.channel.nio.NioEventLoopGroup import io.netty.channel.socket.SocketChannel import io.netty.channel.socket.nio.NioServerSocketChannel import io.netty.handler.codec.LengthFieldBasedFrameDecoder import io.netty.handler.codec.bytes.{ByteArrayDecoder, ByteArrayEncoder} +import io.netty.handler.timeout.ReadTimeoutHandler import org.apache.spark.SparkConf import org.apache.spark.internal.Logging @@ -43,7 +44,10 @@ private[spark] class RBackend { def init(): Int = { val conf = new SparkConf() - bossGroup = new NioEventLoopGroup(conf.getInt("spark.r.numRBackendThreads", 2)) + val backendConnectionTimeout = conf.getInt( + "spark.r.backendConnectionTimeout", SparkRDefaults.DEFAULT_CONNECTION_TIMEOUT) + bossGroup = new NioEventLoopGroup( + conf.getInt("spark.r.numRBackendThreads", SparkRDefaults.DEFAULT_NUM_RBACKEND_THREADS)) val workerGroup = bossGroup val handler = new RBackendHandler(this) @@ -63,6 +67,7 @@ private[spark] class RBackend { // initialBytesToStrip = 4, i.e. strip out the length field itself new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 4, 0, 4)) .addLast("decoder", new ByteArrayDecoder()) + .addLast("readTimeoutHandler", new ReadTimeoutHandler(backendConnectionTimeout)) .addLast("handler", handler) } }) @@ -110,6 +115,11 @@ private[spark] object RBackend extends Logging { val boundPort = sparkRBackend.init() val serverSocket = new ServerSocket(0, 1, InetAddress.getByName("localhost")) val listenPort = serverSocket.getLocalPort() + // Connection timeout is set by socket client. 
To make it configurable we will pass the + // timeout value to client inside the temp file + val conf = new SparkConf() + val backendConnectionTimeout = conf.getInt( + "spark.r.backendConnectionTimeout", SparkRDefaults.DEFAULT_CONNECTION_TIMEOUT) // tell the R process via temporary file val path = args(0) @@ -118,6 +128,7 @@ private[spark] object RBackend extends Logging { dos.writeInt(boundPort) dos.writeInt(listenPort) SerDe.writeString(dos, RUtils.rPackages.getOrElse("")) + dos.writeInt(backendConnectionTimeout) dos.close() f.renameTo(new File(path)) diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala index 1422ef888fd4a..9f5afa29d6d22 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala @@ -18,16 +18,19 @@ package org.apache.spark.api.r import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream} +import java.util.concurrent.TimeUnit import scala.collection.mutable.HashMap import scala.language.existentials import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler} import io.netty.channel.ChannelHandler.Sharable +import io.netty.handler.timeout.ReadTimeoutException import org.apache.spark.api.r.SerDe._ import org.apache.spark.internal.Logging -import org.apache.spark.util.Utils +import org.apache.spark.SparkConf +import org.apache.spark.util.{ThreadUtils, Utils} /** * Handler for RBackend @@ -83,7 +86,29 @@ private[r] class RBackendHandler(server: RBackend) writeString(dos, s"Error: unknown method $methodName") } } else { + // To avoid timeouts when reading results in SparkR driver, we will be regularly sending + // heartbeat responses. We use special code +1 to signal the client that backend is + // alive and it should continue blocking for result. + val execService = ThreadUtils.newDaemonSingleThreadScheduledExecutor("SparkRKeepAliveThread") + val pingRunner = new Runnable { + override def run(): Unit = { + val pingBaos = new ByteArrayOutputStream() + val pingDaos = new DataOutputStream(pingBaos) + writeInt(pingDaos, +1) + ctx.write(pingBaos.toByteArray) + } + } + val conf = new SparkConf() + val heartBeatInterval = conf.getInt( + "spark.r.heartBeatInterval", SparkRDefaults.DEFAULT_HEARTBEAT_INTERVAL) + val backendConnectionTimeout = conf.getInt( + "spark.r.backendConnectionTimeout", SparkRDefaults.DEFAULT_CONNECTION_TIMEOUT) + val interval = Math.min(heartBeatInterval, backendConnectionTimeout - 1) + + execService.scheduleAtFixedRate(pingRunner, interval, interval, TimeUnit.SECONDS) handleMethodCall(isStatic, objId, methodName, numArgs, dis, dos) + execService.shutdown() + execService.awaitTermination(1, TimeUnit.SECONDS) } val reply = bos.toByteArray @@ -95,9 +120,15 @@ private[r] class RBackendHandler(server: RBackend) } override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = { - // Close the connection when an exception is raised. - cause.printStackTrace() - ctx.close() + cause match { + case timeout: ReadTimeoutException => + // Do nothing. We don't want to timeout on read + logWarning("Ignoring read timeout in RBackendHandler") + case _ => + // Close the connection when an exception is raised. 
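The RBackendHandler change above keeps the SparkR client from hitting its read timeout during long backend calls by scheduling a heartbeat on a single-threaded scheduler for the duration of handleMethodCall. The pattern in isolation, as a sketch with plain java.util.concurrent names; the helper name withKeepAlive is made up for illustration, and Spark itself uses ThreadUtils with a daemon thread factory:

    import java.util.concurrent.{Executors, TimeUnit}

    def withKeepAlive[T](intervalSeconds: Long)(ping: () => Unit)(body: => T): T = {
      val scheduler = Executors.newSingleThreadScheduledExecutor()
      val task = new Runnable { override def run(): Unit = ping() }
      // Fire the ping repeatedly while the long-running call executes.
      scheduler.scheduleAtFixedRate(task, intervalSeconds, intervalSeconds, TimeUnit.SECONDS)
      try {
        body
      } finally {
        scheduler.shutdown()
        scheduler.awaitTermination(1, TimeUnit.SECONDS)
      }
    }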
+ cause.printStackTrace() + ctx.close() + } } def handleMethodCall( diff --git a/core/src/main/scala/org/apache/spark/api/r/RRunner.scala b/core/src/main/scala/org/apache/spark/api/r/RRunner.scala index 496fdf851f7db..7ef64723d9593 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RRunner.scala @@ -333,6 +333,8 @@ private[r] object RRunner { var rCommand = sparkConf.get("spark.sparkr.r.command", "Rscript") rCommand = sparkConf.get("spark.r.command", rCommand) + val rConnectionTimeout = sparkConf.getInt( + "spark.r.backendConnectionTimeout", SparkRDefaults.DEFAULT_CONNECTION_TIMEOUT) val rOptions = "--vanilla" val rLibDir = RUtils.sparkRPackagePath(isDriver = false) val rExecScript = rLibDir(0) + "/SparkR/worker/" + script @@ -344,6 +346,7 @@ private[r] object RRunner { pb.environment().put("R_TESTS", "") pb.environment().put("SPARKR_RLIBDIR", rLibDir.mkString(",")) pb.environment().put("SPARKR_WORKER_PORT", port.toString) + pb.environment().put("SPARKR_BACKEND_CONNECTION_TIMEOUT", rConnectionTimeout.toString) pb.redirectErrorStream(true) // redirect stderr into stdout val proc = pb.start() val errThread = startStdoutThread(proc) diff --git a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala index e4932a4192d39..550e075a95129 100644 --- a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala +++ b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala @@ -125,15 +125,34 @@ private[spark] object SerDe { } def readDate(in: DataInputStream): Date = { - Date.valueOf(readString(in)) + try { + val inStr = readString(in) + if (inStr == "NA") { + null + } else { + Date.valueOf(inStr) + } + } catch { + // TODO: SPARK-18011 with some versions of R deserializing NA from R results in NASE + case _: NegativeArraySizeException => null + } } def readTime(in: DataInputStream): Timestamp = { - val seconds = in.readDouble() - val sec = Math.floor(seconds).toLong - val t = new Timestamp(sec * 1000L) - t.setNanos(((seconds - sec) * 1e9).toInt) - t + try { + val seconds = in.readDouble() + if (java.lang.Double.isNaN(seconds)) { + null + } else { + val sec = Math.floor(seconds).toLong + val t = new Timestamp(sec * 1000L) + t.setNanos(((seconds - sec) * 1e9).toInt) + t + } + } catch { + // TODO: SPARK-18011 with some versions of R deserializing NA from R results in NASE + case _: NegativeArraySizeException => null + } } def readBytesArr(in: DataInputStream): Array[Array[Byte]] = { diff --git a/core/src/main/scala/org/apache/spark/api/r/SparkRDefaults.scala b/core/src/main/scala/org/apache/spark/api/r/SparkRDefaults.scala new file mode 100644 index 0000000000000..af67cbbce4e51 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/api/r/SparkRDefaults.scala @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.r + +private[spark] object SparkRDefaults { + + // Default value for spark.r.backendConnectionTimeout config + val DEFAULT_CONNECTION_TIMEOUT: Int = 6000 + + // Default value for spark.r.heartBeatInterval config + val DEFAULT_HEARTBEAT_INTERVAL: Int = 100 + + // Default value for spark.r.numRBackendThreads config + val DEFAULT_NUM_RBACKEND_THREADS = 2 +} diff --git a/core/src/main/scala/org/apache/spark/deploy/RRunner.scala b/core/src/main/scala/org/apache/spark/deploy/RRunner.scala index d0466830b2177..6eb53a8252205 100644 --- a/core/src/main/scala/org/apache/spark/deploy/RRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/RRunner.scala @@ -25,7 +25,7 @@ import scala.collection.JavaConverters._ import org.apache.hadoop.fs.Path import org.apache.spark.{SparkException, SparkUserAppException} -import org.apache.spark.api.r.{RBackend, RUtils} +import org.apache.spark.api.r.{RBackend, RUtils, SparkRDefaults} import org.apache.spark.util.RedirectThread /** @@ -51,6 +51,10 @@ object RRunner { cmd } + // Connection timeout set by R process on its connection to RBackend in seconds. + val backendConnectionTimeout = sys.props.getOrElse( + "spark.r.backendConnectionTimeout", SparkRDefaults.DEFAULT_CONNECTION_TIMEOUT.toString) + // Check if the file path exists. // If not, change directory to current working directory for YARN cluster mode val rF = new File(rFile) @@ -81,6 +85,7 @@ object RRunner { val builder = new ProcessBuilder((Seq(rCommand, rFileNormalized) ++ otherArgs).asJava) val env = builder.environment() env.put("EXISTING_SPARKR_BACKEND_PORT", sparkRBackendPort.toString) + env.put("SPARKR_BACKEND_CONNECTION_TIMEOUT", backendConnectionTimeout) val rPackageDir = RUtils.sparkRPackagePath(isDriver = true) // Put the R package directories into an env variable of comma-separated paths env.put("SPARKR_PACKAGE_DIR", rPackageDir.mkString(",")) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala index ad7a0972ef9d1..06530ff836466 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala @@ -19,6 +19,8 @@ package org.apache.spark.deploy.history import java.util.zip.ZipOutputStream +import scala.xml.Node + import org.apache.spark.SparkException import org.apache.spark.ui.SparkUI @@ -114,4 +116,8 @@ private[history] abstract class ApplicationHistoryProvider { */ def getApplicationInfo(appId: String): Option[ApplicationHistoryInfo] + /** + * @return html text to display when the application list is empty + */ + def getEmptyListingHtml(): Seq[Node] = Seq.empty } diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index 3c2d169f3270e..dfc1aad64c818 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -23,6 +23,7 @@ import java.util.concurrent.{Executors, ExecutorService, TimeUnit} import java.util.zip.{ZipEntry, ZipOutputStream} import scala.collection.mutable +import scala.xml.Node import com.google.common.io.ByteStreams import 
com.google.common.util.concurrent.{MoreExecutors, ThreadFactoryBuilder} @@ -35,6 +36,7 @@ import org.apache.spark.{SecurityManager, SparkConf, SparkException} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging import org.apache.spark.scheduler._ +import org.apache.spark.scheduler.ReplayListenerBus._ import org.apache.spark.ui.SparkUI import org.apache.spark.util.{Clock, SystemClock, ThreadUtils, Utils} @@ -77,10 +79,6 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) import FsHistoryProvider._ - private val NOT_STARTED = "" - - private val SPARK_HISTORY_FS_NUM_REPLAY_THREADS = "spark.history.fs.numReplayThreads" - // Interval between safemode checks. private val SAFEMODE_CHECK_INTERVAL_S = conf.getTimeAsSeconds( "spark.history.fs.safemodeCheck.interval", "5s") @@ -240,11 +238,12 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) HistoryServer.getAttemptURI(appId, attempt.attemptId), attempt.startTime) // Do not call ui.bind() to avoid creating a new server for each application } - val appListener = new ApplicationEventListener() - replayBus.addListener(appListener) - val appAttemptInfo = replay(fs.getFileStatus(new Path(logDir, attempt.logPath)), - replayBus) - appAttemptInfo.map { info => + + val fileStatus = fs.getFileStatus(new Path(logDir, attempt.logPath)) + + val appListener = replay(fileStatus, isApplicationCompleted(fileStatus), replayBus) + + if (appListener.appId.isDefined) { val uiAclsEnabled = conf.getBoolean("spark.history.ui.acls.enable", false) ui.getSecurityManager.setAcls(uiAclsEnabled) // make sure to set admin acls before view acls so they are properly picked up @@ -253,8 +252,11 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) appListener.viewAcls.getOrElse("")) ui.getSecurityManager.setAdminAclsGroups(appListener.adminAclsGroups.getOrElse("")) ui.getSecurityManager.setViewAclsGroups(appListener.viewAclsGroups.getOrElse("")) - LoadedAppUI(ui, updateProbe(appId, attemptId, attempt.fileSize)) + Some(LoadedAppUI(ui, updateProbe(appId, attemptId, attempt.fileSize))) + } else { + None } + } } } catch { @@ -262,6 +264,17 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) } } + override def getEmptyListingHtml(): Seq[Node] = { +
    <p>
+      Did you specify the correct logging directory? Please verify your setting of
+      <span style="font-style:italic">spark.history.fs.logDirectory</span>
+      listed above and whether you have the permissions to access it.
+      <br/>
+      It is also possible that your application did not run to
+      completion or did not stop the SparkContext.
+    </p>
+ } + override def getConfig(): Map[String, String] = { val safeMode = if (isFsInSafeMode()) { Map("HDFS State" -> "In safe mode, application logs not available.") @@ -399,28 +412,54 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) } } - /** * Replay the log files in the list and merge the list of old applications with new ones */ private def mergeApplicationListing(fileStatus: FileStatus): Unit = { val newAttempts = try { - val bus = new ReplayListenerBus() - val res = replay(fileStatus, bus) - res match { - case Some(r) => logDebug(s"Application log ${r.logPath} loaded successfully: $r") - case None => logWarning(s"Failed to load application log ${fileStatus.getPath}. " + - "The application may have not started.") - } - res - } catch { - case e: Exception => - logError( - s"Exception encountered when attempting to load application log ${fileStatus.getPath}", - e) - None + val eventsFilter: ReplayEventsFilter = { eventString => + eventString.startsWith(APPL_START_EVENT_PREFIX) || + eventString.startsWith(APPL_END_EVENT_PREFIX) + } + + val logPath = fileStatus.getPath() + + val appCompleted = isApplicationCompleted(fileStatus) + + val appListener = replay(fileStatus, appCompleted, new ReplayListenerBus(), eventsFilter) + + // Without an app ID, new logs will render incorrectly in the listing page, so do not list or + // try to show their UI. + if (appListener.appId.isDefined) { + val attemptInfo = new FsApplicationAttemptInfo( + logPath.getName(), + appListener.appName.getOrElse(NOT_STARTED), + appListener.appId.getOrElse(logPath.getName()), + appListener.appAttemptId, + appListener.startTime.getOrElse(-1L), + appListener.endTime.getOrElse(-1L), + fileStatus.getModificationTime(), + appListener.sparkUser.getOrElse(NOT_STARTED), + appCompleted, + fileStatus.getLen() + ) + fileToAppInfo(logPath) = attemptInfo + logDebug(s"Application log ${attemptInfo.logPath} loaded successfully: $attemptInfo") + Some(attemptInfo) + } else { + logWarning(s"Failed to load application log ${fileStatus.getPath}. " + + "The application may have not started.") + None } + } catch { + case e: Exception => + logError( + s"Exception encountered when attempting to load application log ${fileStatus.getPath}", + e) + None + } + if (newAttempts.isEmpty) { return } @@ -552,12 +591,16 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) } /** - * Replays the events in the specified log file and returns information about the associated - * application. Return `None` if the application ID cannot be located. + * Replays the events in the specified log file on the supplied `ReplayListenerBus`. Returns + * an `ApplicationEventListener` instance with event data captured from the replay. + * `ReplayEventsFilter` determines what events are replayed and can therefore limit the + * data captured in the returned `ApplicationEventListener` instance. 
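For the listing pass described above, only application start and end events matter, so the replay is handed a ReplayEventsFilter that matches those JSON prefixes instead of deserializing every event. A sketch of such a filter, with the prefixes copied from the patch (the surrounding bus and stream wiring is elided, and the declarations would live inside a helper object):

    type ReplayEventsFilter = String => Boolean

    val appStartEndFilter: ReplayEventsFilter = { line =>
      line.startsWith("{\"Event\":\"SparkListenerApplicationStart\"") ||
        line.startsWith("{\"Event\":\"SparkListenerApplicationEnd\"")
    }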
*/ private def replay( eventLog: FileStatus, - bus: ReplayListenerBus): Option[FsApplicationAttemptInfo] = { + appCompleted: Boolean, + bus: ReplayListenerBus, + eventsFilter: ReplayEventsFilter = SELECT_ALL_FILTER): ApplicationEventListener = { val logPath = eventLog.getPath() logInfo(s"Replaying log path: $logPath") // Note that the eventLog may have *increased* in size since when we grabbed the filestatus, @@ -569,30 +612,9 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) val logInput = EventLoggingListener.openEventLog(logPath, fs) try { val appListener = new ApplicationEventListener - val appCompleted = isApplicationCompleted(eventLog) bus.addListener(appListener) - bus.replay(logInput, logPath.toString, !appCompleted) - - // Without an app ID, new logs will render incorrectly in the listing page, so do not list or - // try to show their UI. - if (appListener.appId.isDefined) { - val attemptInfo = new FsApplicationAttemptInfo( - logPath.getName(), - appListener.appName.getOrElse(NOT_STARTED), - appListener.appId.getOrElse(logPath.getName()), - appListener.appAttemptId, - appListener.startTime.getOrElse(-1L), - appListener.endTime.getOrElse(-1L), - eventLog.getModificationTime(), - appListener.sparkUser.getOrElse(NOT_STARTED), - appCompleted, - eventLog.getLen() - ) - fileToAppInfo(logPath) = attemptInfo - Some(attemptInfo) - } else { - None - } + bus.replay(logInput, logPath.toString, !appCompleted, eventsFilter) + appListener } finally { logInput.close() } @@ -677,6 +699,14 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) private[history] object FsHistoryProvider { val DEFAULT_LOG_DIR = "file:/tmp/spark-events" + + private val NOT_STARTED = "" + + private val SPARK_HISTORY_FS_NUM_REPLAY_THREADS = "spark.history.fs.numReplayThreads" + + private val APPL_START_EVENT_PREFIX = "{\"Event\":\"SparkListenerApplicationStart\"" + + private val APPL_END_EVENT_PREFIX = "{\"Event\":\"SparkListenerApplicationEnd\"" } /** diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala index 95b72224e0f94..96b9ecf43b14c 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala @@ -47,13 +47,7 @@ private[history] class HistoryPage(parent: HistoryServer) extends WebUIPage("") } else if (requestedIncomplete) {
              <h4>No incomplete applications found!</h4>
            } else {
-              <h4>No completed applications found!</h4> ++
-              <p>Did you specify the correct logging directory?
-                Please verify your setting of <span style="font-style:italic">
-                spark.history.fs.logDirectory</span> and whether you have the permissions to
-                access it.<br /> It is also possible that your application did not run to
-                completion or did not stop the SparkContext.
-              </p>
+              <h4>No completed applications found!</h4>
++ parent.emptyListingHtml } } diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala index 087c69e6489dd..3175b36b3e56f 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala @@ -22,6 +22,7 @@ import java.util.zip.ZipOutputStream import javax.servlet.http.{HttpServlet, HttpServletRequest, HttpServletResponse} import scala.util.control.NonFatal +import scala.xml.Node import org.eclipse.jetty.servlet.{ServletContextHandler, ServletHolder} @@ -193,6 +194,13 @@ class HistoryServer( provider.writeEventLogs(appId, attemptId, zipStream) } + /** + * @return html text to display when the application list is empty + */ + def emptyListingHtml(): Seq[Node] = { + provider.getEmptyListingHtml() + } + /** * Returns the provider configuration to show in the listing page. * diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala index 289b0b93b0e84..e878c10183f61 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala @@ -18,12 +18,12 @@ package org.apache.spark.deploy.worker import java.io._ +import java.net.URI import java.nio.charset.StandardCharsets import scala.collection.JavaConverters._ import com.google.common.io.Files -import org.apache.hadoop.fs.Path import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.deploy.{DriverDescription, SparkHadoopUtil} @@ -147,30 +147,24 @@ private[deploy] class DriverRunner( * Will throw an exception if there are errors downloading the jar. 
*/ private def downloadUserJar(driverDir: File): String = { - val jarPath = new Path(driverDesc.jarUrl) - val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf) - val destPath = new File(driverDir.getAbsolutePath, jarPath.getName) - val jarFileName = jarPath.getName + val jarFileName = new URI(driverDesc.jarUrl).getPath.split("/").last val localJarFile = new File(driverDir, jarFileName) - val localJarFilename = localJarFile.getAbsolutePath - if (!localJarFile.exists()) { // May already exist if running multiple workers on one node - logInfo(s"Copying user jar $jarPath to $destPath") + logInfo(s"Copying user jar ${driverDesc.jarUrl} to $localJarFile") Utils.fetchFile( driverDesc.jarUrl, driverDir, conf, securityManager, - hadoopConf, + SparkHadoopUtil.get.newConfiguration(conf), System.currentTimeMillis(), useCache = false) + if (!localJarFile.exists()) { // Verify copy succeeded + throw new IOException( + s"Can not find expected jar $jarFileName which should have been loaded in $driverDir") + } } - - if (!localJarFile.exists()) { // Verify copy succeeded - throw new Exception(s"Did not see expected jar $jarFileName in $driverDir") - } - - localJarFilename + localJarFile.getAbsolutePath } private[worker] def prepareAndRunDriver(): Int = { diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala index 3473c41b935fd..465c214362b25 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala @@ -22,6 +22,8 @@ import javax.servlet.http.HttpServletRequest import scala.xml.{Node, Unparsed} +import com.google.common.cache.{CacheBuilder, CacheLoader, LoadingCache} + import org.apache.spark.internal.Logging import org.apache.spark.ui.{UIUtils, WebUIPage} import org.apache.spark.util.Utils @@ -138,7 +140,8 @@ private[ui] class LogPage(parent: WorkerWebUI) extends WebUIPage("logPage") with val files = RollingFileAppender.getSortedRolledOverFiles(logDirectory, logType) logDebug(s"Sorted log files of type $logType in $logDirectory:\n${files.mkString("\n")}") - val totalLength = files.map { _.length }.sum + val fileLengths: Seq[Long] = files.map(Utils.getFileLength(_, worker.conf)) + val totalLength = fileLengths.sum val offset = offsetOption.getOrElse(totalLength - byteLength) val startIndex = { if (offset < 0) { @@ -151,7 +154,7 @@ private[ui] class LogPage(parent: WorkerWebUI) extends WebUIPage("logPage") with } val endIndex = math.min(startIndex + byteLength, totalLength) logDebug(s"Getting log from $startIndex to $endIndex") - val logText = Utils.offsetBytes(files, startIndex, endIndex) + val logText = Utils.offsetBytes(files, fileLengths, startIndex, endIndex) logDebug(s"Got log of length ${logText.length} bytes") (logText, startIndex, endIndex, totalLength) } catch { diff --git a/core/src/main/scala/org/apache/spark/metrics/source/StaticSources.scala b/core/src/main/scala/org/apache/spark/metrics/source/StaticSources.scala index 6bba259acc391..3f7cfd9d2c11f 100644 --- a/core/src/main/scala/org/apache/spark/metrics/source/StaticSources.scala +++ b/core/src/main/scala/org/apache/spark/metrics/source/StaticSources.scala @@ -26,7 +26,7 @@ private[spark] object StaticSources { * The set of all static sources. These sources may be reported to from any class, including * static classes, without requiring reference to a SparkEnv. 
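DriverRunner above now derives the local jar name from the URL path with java.net.URI instead of constructing a Hadoop Path. A sketch of that extraction; the URL is illustrative only:

    import java.net.URI

    val jarUrl = "hdfs://namenode:8020/user/apps/myapp/app.jar"  // hypothetical URL
    val jarFileName = new URI(jarUrl).getPath.split("/").last    // yields "app.jar"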
*/ - val allSources = Seq(CodegenMetrics) + val allSources = Seq(CodegenMetrics, HiveCatalogMetrics) } /** @@ -60,3 +60,42 @@ object CodegenMetrics extends Source { val METRIC_GENERATED_METHOD_BYTECODE_SIZE = metricRegistry.histogram(MetricRegistry.name("generatedMethodSize")) } + +/** + * :: Experimental :: + * Metrics for access to the hive external catalog. + */ +@Experimental +object HiveCatalogMetrics extends Source { + override val sourceName: String = "HiveExternalCatalog" + override val metricRegistry: MetricRegistry = new MetricRegistry() + + /** + * Tracks the total number of partition metadata entries fetched via the client api. + */ + val METRIC_PARTITIONS_FETCHED = metricRegistry.counter(MetricRegistry.name("partitionsFetched")) + + /** + * Tracks the total number of files discovered off of the filesystem by InMemoryFileIndex. + */ + val METRIC_FILES_DISCOVERED = metricRegistry.counter(MetricRegistry.name("filesDiscovered")) + + /** + * Tracks the total number of files served from the file status cache instead of discovered. + */ + val METRIC_FILE_CACHE_HITS = metricRegistry.counter(MetricRegistry.name("fileCacheHits")) + + /** + * Resets the values of all metrics to zero. This is useful in tests. + */ + def reset(): Unit = { + METRIC_PARTITIONS_FETCHED.dec(METRIC_PARTITIONS_FETCHED.getCount()) + METRIC_FILES_DISCOVERED.dec(METRIC_FILES_DISCOVERED.getCount()) + METRIC_FILE_CACHE_HITS.dec(METRIC_FILE_CACHE_HITS.getCount()) + } + + // clients can use these to avoid classloader issues with the codahale classes + def incrementFetchedPartitions(n: Int): Unit = METRIC_PARTITIONS_FETCHED.inc(n) + def incrementFilesDiscovered(n: Int): Unit = METRIC_FILES_DISCOVERED.inc(n) + def incrementFileCacheHits(n: Int): Unit = METRIC_FILE_CACHE_HITS.inc(n) +} diff --git a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala index 9c198a61f37af..2cba1febe8759 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala @@ -80,6 +80,10 @@ private[spark] class CoalescedRDD[T: ClassTag]( require(maxPartitions > 0 || maxPartitions == prev.partitions.length, s"Number of partitions ($maxPartitions) must be positive.") + if (partitionCoalescer.isDefined) { + require(partitionCoalescer.get.isInstanceOf[Serializable], + "The partition coalescer passed in must be serializable.") + } override def getPartitions: Array[Partition] = { val pc = partitionCoalescer.getOrElse(new DefaultPartitionCoalescer()) diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 6dc334ceb52ea..db535de9e9bb3 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -432,7 +432,8 @@ abstract class RDD[T: ClassTag]( * of partitions. This is useful if you have a small number of partitions, * say 100, potentially with a few partitions being abnormally large. Calling * coalesce(1000, shuffle = true) will result in 1000 partitions with the - * data distributed using a hash partitioner. + * data distributed using a hash partitioner. The optional partition coalescer + * passed in must be serializable. 
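HiveCatalogMetrics above exposes increment helpers so callers never touch the Codahale classes directly; reading a value back goes through the counters' getCount(). A brief usage sketch, with counts chosen purely for illustration:

    import org.apache.spark.metrics.source.HiveCatalogMetrics

    HiveCatalogMetrics.incrementFilesDiscovered(42)
    HiveCatalogMetrics.incrementFileCacheHits(7)
    println(HiveCatalogMetrics.METRIC_FILES_DISCOVERED.getCount())  // 42 if nothing else ran
    HiveCatalogMetrics.reset()  // zeroes all three counters, handy in tests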
*/ def coalesce(numPartitions: Int, shuffle: Boolean = false, partitionCoalescer: Option[PartitionCoalescer] = Option.empty) @@ -1278,7 +1279,7 @@ abstract class RDD[T: ClassTag]( def zipWithUniqueId(): RDD[(T, Long)] = withScope { val n = this.partitions.length.toLong this.mapPartitionsWithIndex { case (k, iter) => - iter.zipWithIndex.map { case (item, i) => + Utils.getIteratorZipWithIndex(iter, 0L).map { case (item, i) => (item, i * n + k) } } diff --git a/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala index b5738b9a95c36..b0e5ba0865c63 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala @@ -64,8 +64,7 @@ class ZippedWithIndexRDD[T: ClassTag](prev: RDD[T]) extends RDD[(T, Long)](prev) override def compute(splitIn: Partition, context: TaskContext): Iterator[(T, Long)] = { val split = splitIn.asInstanceOf[ZippedWithIndexRDDPartition] - firstParent[T].iterator(split.prev, context).zipWithIndex.map { x => - (x._1, split.startIndex + x._2) - } + val parentIter = firstParent[T].iterator(split.prev, context) + Utils.getIteratorZipWithIndex(parentIter, split.startIndex) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala index d32f5eb7bfe92..2424586431aa0 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala @@ -25,6 +25,7 @@ import com.fasterxml.jackson.core.JsonParseException import org.json4s.jackson.JsonMethods._ import org.apache.spark.internal.Logging +import org.apache.spark.scheduler.ReplayListenerBus._ import org.apache.spark.util.JsonProtocol /** @@ -43,30 +44,49 @@ private[spark] class ReplayListenerBus extends SparkListenerBus with Logging { * @param sourceName Filename (or other source identifier) from whence @logData is being read * @param maybeTruncated Indicate whether log file might be truncated (some abnormal situations * encountered, log file might not finished writing) or not + * @param eventsFilter Filter function to select JSON event strings in the log data stream that + * should be parsed and replayed. When not specified, all event strings in the log data + * are parsed and replayed. */ def replay( logData: InputStream, sourceName: String, - maybeTruncated: Boolean = false): Unit = { + maybeTruncated: Boolean = false, + eventsFilter: ReplayEventsFilter = SELECT_ALL_FILTER): Unit = { + var currentLine: String = null - var lineNumber: Int = 1 + var lineNumber: Int = 0 + try { - val lines = Source.fromInputStream(logData).getLines() - while (lines.hasNext) { - currentLine = lines.next() + val lineEntries = Source.fromInputStream(logData) + .getLines() + .zipWithIndex + .filter { case (line, _) => eventsFilter(line) } + + while (lineEntries.hasNext) { try { + val entry = lineEntries.next() + + currentLine = entry._1 + lineNumber = entry._2 + 1 + postToAll(JsonProtocol.sparkEventFromJson(parse(currentLine))) } catch { + case e: ClassNotFoundException if KNOWN_REMOVED_CLASSES.contains(e.getMessage) => + // Ignore events generated by Structured Streaming in Spark 2.0.0 and 2.0.1. + // It's safe since no place uses them. 
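zipWithUniqueId and ZippedWithIndexRDD above both switch to Utils.getIteratorZipWithIndex, which zips an iterator with Long indices starting at an arbitrary offset (Iterator.zipWithIndex only yields Int indices starting at 0). A minimal sketch of such a helper; the name zipWithIndexFrom is hypothetical and not Spark's utility itself:

    def zipWithIndexFrom[T](iter: Iterator[T], startIndex: Long): Iterator[(T, Long)] = {
      require(startIndex >= 0L, "startIndex should be >= 0")
      var index = startIndex - 1L
      iter.map { item =>
        index += 1L
        (item, index)
      }
    }

    // Per-partition use, mirroring zipWithUniqueId for partition k of an RDD with n partitions:
    // zipWithIndexFrom(iter, 0L).map { case (item, i) => (item, i * n + k) }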
+ logWarning(s"Dropped incompatible Structured Streaming log: $currentLine") case jpe: JsonParseException => // We can only ignore exception from last line of the file that might be truncated - if (!maybeTruncated || lines.hasNext) { + // the last entry may not be the very last line in the event log, but we treat it + // as such in a best effort to replay the given input + if (!maybeTruncated || lineEntries.hasNext) { throw jpe } else { logWarning(s"Got JsonParseException from log file $sourceName" + s" at line $lineNumber, the file might not have finished writing cleanly.") } } - lineNumber += 1 } } catch { case ioe: IOException => @@ -78,3 +98,21 @@ private[spark] class ReplayListenerBus extends SparkListenerBus with Logging { } } + + +private[spark] object ReplayListenerBus { + + type ReplayEventsFilter = (String) => Boolean + + // utility filter that selects all event logs during replay + val SELECT_ALL_FILTER: ReplayEventsFilter = { (eventString: String) => true } + + /** + * Classes that were removed. Structured Streaming doesn't use them any more. However, parsing + * old json may fail and we can just ignore these failures. + */ + val KNOWN_REMOVED_CLASSES = Set( + "org.apache.spark.sql.streaming.StreamingQueryListener$QueryProgress", + "org.apache.spark.sql.streaming.StreamingQueryListener$QueryTerminated" + ) +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala index 77fda6fcff959..366b92c5f2ada 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala @@ -23,6 +23,7 @@ import java.nio.ByteBuffer import scala.collection.mutable.ArrayBuffer import org.apache.spark.SparkEnv +import org.apache.spark.serializer.SerializerInstance import org.apache.spark.storage.BlockId import org.apache.spark.util.{AccumulatorV2, Utils} @@ -77,14 +78,14 @@ private[spark] class DirectTaskResult[T]( * * After the first time, `value()` is trivial and just returns the deserialized `valueObject`. */ - def value(): T = { + def value(resultSer: SerializerInstance = null): T = { if (valueObjectDeserialized) { valueObject } else { // This should not run when holding a lock because it may cost dozens of seconds for a large - // value. - val resultSer = SparkEnv.get.serializer.newInstance() - valueObject = resultSer.deserialize(valueBytes) + // value + val ser = if (resultSer == null) SparkEnv.get.serializer.newInstance() else resultSer + valueObject = ser.deserialize(valueBytes) valueObjectDeserialized = true valueObject } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala index 1c3fcbd4612a0..b1addc128e696 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala @@ -48,6 +48,12 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul } } + protected val taskResultSerializer = new ThreadLocal[SerializerInstance] { + override def initialValue(): SerializerInstance = { + sparkEnv.serializer.newInstance() + } + } + def enqueueSuccessfulTask( taskSetManager: TaskSetManager, tid: Long, @@ -63,7 +69,7 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul // deserialize "value" without holding any lock so that it won't block other threads. 
// We should call it here, so that when it's called again in // "TaskSetManager.handleSuccessfulTask", it does not need to deserialize the value. - directResult.value() + directResult.value(taskResultSerializer.get()) (directResult, serializedData.limit()) case IndirectTaskResult(blockId, size) => if (!taskSetManager.canFetchMoreResults(size)) { @@ -84,6 +90,8 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul } val deserializedResult = serializer.get().deserialize[DirectTaskResult[_]]( serializedTaskResult.get.toByteBuffer) + // force deserialization of referenced value + deserializedResult.value(taskResultSerializer.get()) sparkEnv.blockManager.master.removeBlock(blockId) (deserializedResult, size) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 9491bc7a0497e..b766e4148e496 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -79,7 +79,7 @@ private[spark] class TaskSetManager( var minShare = 0 var priority = taskSet.priority var stageId = taskSet.stageId - var name = "TaskSet_" + taskSet.stageId.toString + val name = "TaskSet_" + taskSet.id var parent: Pool = null var totalResultSize = 0L var calculatedTasks = 0 diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 0dae0e614e17d..10d55c87fb8de 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -386,15 +386,17 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp * Reset the state of CoarseGrainedSchedulerBackend to the initial state. Currently it will only * be called in the yarn-client mode when AM re-registers after a failure. * */ - protected def reset(): Unit = synchronized { - numPendingExecutors = 0 - executorsPendingToRemove.clear() + protected def reset(): Unit = { + val executors = synchronized { + numPendingExecutors = 0 + executorsPendingToRemove.clear() + Set() ++ executorDataMap.keys + } // Remove all the lingering executors that should be removed but not yet. The reason might be // because (1) disconnected event is not yet received; (2) executors die silently. 
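TaskResultGetter above deserializes each direct result on the fetching thread through a per-thread SerializerInstance, so the expensive work happens once, without holding any lock, and is not repeated later in TaskSetManager.handleSuccessfulTask. The thread-local piece in isolation, as a sketch mirroring the field added in the patch:

    import org.apache.spark.SparkEnv
    import org.apache.spark.serializer.SerializerInstance

    val taskResultSerializer = new ThreadLocal[SerializerInstance] {
      // Each result-getter thread lazily creates and then reuses its own instance.
      override def initialValue(): SerializerInstance = SparkEnv.get.serializer.newInstance()
    }

    // val ser = taskResultSerializer.get()  // safe to call concurrently from the pool threads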
- executorDataMap.toMap.foreach { case (eid, _) => - driverEndpoint.askWithRetry[Boolean]( - RemoveExecutor(eid, SlaveLost("Stale executor after cluster manager re-registered."))) + executors.foreach { eid => + removeExecutor(eid, SlaveLost("Stale executor after cluster manager re-registered.")) } } diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index 1fba552f70501..0d26281fe1076 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -27,6 +27,7 @@ import scala.reflect.ClassTag import com.esotericsoftware.kryo.{Kryo, KryoException, Serializer => KryoClassSerializer} import com.esotericsoftware.kryo.io.{Input => KryoInput, Output => KryoOutput} +import com.esotericsoftware.kryo.io.{UnsafeInput => KryoUnsafeInput, UnsafeOutput => KryoUnsafeOutput} import com.esotericsoftware.kryo.serializers.{JavaSerializer => KryoJavaSerializer} import com.twitter.chill.{AllScalaRegistrar, EmptyScalaKryoInstantiator} import org.apache.avro.generic.{GenericData, GenericRecord} @@ -78,8 +79,15 @@ class KryoSerializer(conf: SparkConf) .filter(!_.isEmpty) private val avroSchemas = conf.getAvroSchema + // whether to use unsafe based IO for serialization + private val useUnsafe = conf.getBoolean("spark.kryo.unsafe", false) - def newKryoOutput(): KryoOutput = new KryoOutput(bufferSize, math.max(bufferSize, maxBufferSize)) + def newKryoOutput(): KryoOutput = + if (useUnsafe) { + new KryoUnsafeOutput(bufferSize, math.max(bufferSize, maxBufferSize)) + } else { + new KryoOutput(bufferSize, math.max(bufferSize, maxBufferSize)) + } def newKryo(): Kryo = { val instantiator = new EmptyScalaKryoInstantiator @@ -172,7 +180,7 @@ class KryoSerializer(conf: SparkConf) } override def newInstance(): SerializerInstance = { - new KryoSerializerInstance(this) + new KryoSerializerInstance(this, useUnsafe) } private[spark] override lazy val supportsRelocationOfSerializedObjects: Boolean = { @@ -186,9 +194,12 @@ class KryoSerializer(conf: SparkConf) private[spark] class KryoSerializationStream( serInstance: KryoSerializerInstance, - outStream: OutputStream) extends SerializationStream { + outStream: OutputStream, + useUnsafe: Boolean) extends SerializationStream { + + private[this] var output: KryoOutput = + if (useUnsafe) new KryoUnsafeOutput(outStream) else new KryoOutput(outStream) - private[this] var output: KryoOutput = new KryoOutput(outStream) private[this] var kryo: Kryo = serInstance.borrowKryo() override def writeObject[T: ClassTag](t: T): SerializationStream = { @@ -219,9 +230,12 @@ class KryoSerializationStream( private[spark] class KryoDeserializationStream( serInstance: KryoSerializerInstance, - inStream: InputStream) extends DeserializationStream { + inStream: InputStream, + useUnsafe: Boolean) extends DeserializationStream { + + private[this] var input: KryoInput = + if (useUnsafe) new KryoUnsafeInput(inStream) else new KryoInput(inStream) - private[this] var input: KryoInput = new KryoInput(inStream) private[this] var kryo: Kryo = serInstance.borrowKryo() override def readObject[T: ClassTag](): T = { @@ -248,8 +262,8 @@ class KryoDeserializationStream( } } -private[spark] class KryoSerializerInstance(ks: KryoSerializer) extends SerializerInstance { - +private[spark] class KryoSerializerInstance(ks: KryoSerializer, useUnsafe: Boolean) + extends SerializerInstance { /** * A re-used [[Kryo]] instance. 
Methods will borrow this instance by calling `borrowKryo()`, do * their work, then release the instance by calling `releaseKryo()`. Logically, this is a caching @@ -288,7 +302,7 @@ private[spark] class KryoSerializerInstance(ks: KryoSerializer) extends Serializ // Make these lazy vals to avoid creating a buffer unless we use them. private lazy val output = ks.newKryoOutput() - private lazy val input = new KryoInput() + private lazy val input = if (useUnsafe) new KryoUnsafeInput() else new KryoInput() override def serialize[T: ClassTag](t: T): ByteBuffer = { output.clear() @@ -329,11 +343,11 @@ private[spark] class KryoSerializerInstance(ks: KryoSerializer) extends Serializ } override def serializeStream(s: OutputStream): SerializationStream = { - new KryoSerializationStream(this, s) + new KryoSerializationStream(this, s, useUnsafe) } override def deserializeStream(s: InputStream): DeserializationStream = { - new KryoDeserializationStream(this, s) + new KryoDeserializationStream(this, s, useUnsafe) } /** diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala index 8d6396bededa9..91858f0912b65 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala @@ -23,6 +23,7 @@ import com.google.common.io.ByteStreams import org.apache.spark.{SparkConf, SparkEnv} import org.apache.spark.internal.Logging +import org.apache.spark.io.NioBufferedFileInputStream import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.shuffle.IndexShuffleBlockResolver.NOOP_REDUCE_ID @@ -89,7 +90,7 @@ private[spark] class IndexShuffleBlockResolver( val lengths = new Array[Long](blocks) // Read the lengths of blocks val in = try { - new DataInputStream(new BufferedInputStream(new FileInputStream(index))) + new DataInputStream(new NioBufferedFileInputStream(index)) } catch { case e: IOException => return null diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala index ef71db89798f1..f631a047a707d 100644 --- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala @@ -58,14 +58,13 @@ private[spark] class SparkUI private ( val killEnabled = sc.map(_.conf.getBoolean("spark.ui.killEnabled", true)).getOrElse(false) - - val stagesTab = new StagesTab(this) - var appId: String = _ /** Initialize all components of the server. 
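The KryoSerializer changes above put the Unsafe-based Kryo streams behind a new boolean flag, spark.kryo.unsafe, which defaults to false. Opting in is purely configuration; a sketch, with an illustrative application name:

    import org.apache.spark.SparkConf

    val conf = new SparkConf()
      .setAppName("kryo-unsafe-example")  // illustrative
      .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
      .set("spark.kryo.unsafe", "true")   // selects KryoUnsafeInput/KryoUnsafeOutput internally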
*/ def initialize() { - attachTab(new JobsTab(this)) + val jobsTab = new JobsTab(this) + attachTab(jobsTab) + val stagesTab = new StagesTab(this) attachTab(stagesTab) attachTab(new StorageTab(this)) attachTab(new EnvironmentTab(this)) @@ -73,7 +72,9 @@ private[spark] class SparkUI private ( attachHandler(createStaticHandler(SparkUI.STATIC_RESOURCE_DIR, "/static")) attachHandler(createRedirectHandler("/", "/jobs/", basePath = basePath)) attachHandler(ApiRootResource.getServletHandler(this)) - // This should be POST only, but, the YARN AM proxy won't proxy POSTs + // These should be POST only, but, the YARN AM proxy won't proxy POSTs + attachHandler(createRedirectHandler( + "/jobs/job/kill", "/jobs/", jobsTab.handleKillRequest, httpMethods = Set("GET", "POST"))) attachHandler(createRedirectHandler( "/stages/stage/kill", "/stages/", stagesTab.handleKillRequest, httpMethods = Set("GET", "POST"))) diff --git a/core/src/main/scala/org/apache/spark/ui/WebUI.scala b/core/src/main/scala/org/apache/spark/ui/WebUI.scala index 4118fcf46b428..a05e0efb7a3e3 100644 --- a/core/src/main/scala/org/apache/spark/ui/WebUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/WebUI.scala @@ -147,7 +147,10 @@ private[spark] abstract class WebUI( } /** Return the url of web interface. Only valid after bind(). */ - def webUrl: String = s"http://$publicHostName:$boundPort" + def webUrl: String = { + val protocol = if (sslOptions.enabled) "https" else "http" + s"$protocol://$publicHostName:$boundPort" + } /** Return the actual port to which this server is bound. Only valid after bind(). */ def boundPort: Int = serverInfo.map(_.boundPort).getOrElse(-1) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala index f6713097b9349..173fc3cf31ce8 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala @@ -218,7 +218,8 @@ private[ui] class AllJobsPage(parent: JobsTab) extends WebUIPage("") { request: HttpServletRequest, tableHeaderId: String, jobTag: String, - jobs: Seq[JobUIData]): Seq[Node] = { + jobs: Seq[JobUIData], + killEnabled: Boolean): Seq[Node] = { val allParameters = request.getParameterMap.asScala.toMap val parameterOtherTable = allParameters.filterNot(_._1.startsWith(jobTag)) .map(para => para._1 + "=" + para._2(0)) @@ -264,6 +265,7 @@ private[ui] class AllJobsPage(parent: JobsTab) extends WebUIPage("") { parameterOtherTable, parent.jobProgresslistener.stageIdToInfo, parent.jobProgresslistener.stageIdToData, + killEnabled, currentTime, jobIdTitle, pageSize = jobPageSize, @@ -290,9 +292,12 @@ private[ui] class AllJobsPage(parent: JobsTab) extends WebUIPage("") { val completedJobs = listener.completedJobs.reverse.toSeq val failedJobs = listener.failedJobs.reverse.toSeq - val activeJobsTable = jobsTable(request, "active", "activeJob", activeJobs) - val completedJobsTable = jobsTable(request, "completed", "completedJob", completedJobs) - val failedJobsTable = jobsTable(request, "failed", "failedJob", failedJobs) + val activeJobsTable = + jobsTable(request, "active", "activeJob", activeJobs, killEnabled = parent.killEnabled) + val completedJobsTable = + jobsTable(request, "completed", "completedJob", completedJobs, killEnabled = false) + val failedJobsTable = + jobsTable(request, "failed", "failedJob", failedJobs, killEnabled = false) val shouldShowActiveJobs = activeJobs.nonEmpty val shouldShowCompletedJobs = completedJobs.nonEmpty @@ -483,6 
+488,7 @@ private[ui] class JobPagedTable( parameterOtherTable: Iterable[String], stageIdToInfo: HashMap[Int, StageInfo], stageIdToData: HashMap[(Int, Int), StageUIData], + killEnabled: Boolean, currentTime: Long, jobIdTitle: String, pageSize: Int, @@ -586,12 +592,30 @@ private[ui] class JobPagedTable( override def row(jobTableRow: JobTableRowData): Seq[Node] = { val job = jobTableRow.jobData + val killLink = if (killEnabled) { + val confirm = + s"if (window.confirm('Are you sure you want to kill job ${job.jobId} ?')) " + + "{ this.parentNode.submit(); return true; } else { return false; }" + // SPARK-6846 this should be POST-only but YARN AM won't proxy POST + /* + val killLinkUri = s"$basePathUri/jobs/job/kill/" +
+ + (kill) +
+ */ + val killLinkUri = s"$basePath/jobs/job/kill/?id=${job.jobId}" + (kill) + } else { + Seq.empty + } + {job.jobId} {job.jobGroup.map(id => s"($id)").getOrElse("")} - {jobTableRow.jobDescription} + {jobTableRow.jobDescription} {killLink} {jobTableRow.lastStageName} diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala index 7b00b558d591a..620c54c2dc0a5 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala @@ -17,6 +17,8 @@ package org.apache.spark.ui.jobs +import javax.servlet.http.HttpServletRequest + import org.apache.spark.scheduler.SchedulingMode import org.apache.spark.ui.{SparkUI, SparkUITab} @@ -35,4 +37,19 @@ private[ui] class JobsTab(parent: SparkUI) extends SparkUITab(parent, "jobs") { attachPage(new AllJobsPage(this)) attachPage(new JobPage(this)) + + def handleKillRequest(request: HttpServletRequest): Unit = { + if (killEnabled && parent.securityManager.checkModifyPermissions(request.getRemoteUser)) { + val jobId = Option(request.getParameter("id")).map(_.toInt) + jobId.foreach { id => + if (jobProgresslistener.activeJobs.contains(id)) { + sc.foreach(_.cancelJob(id)) + // Do a quick pause here to give Spark time to kill the job so it shows up as + // killed after the refresh. Note that this will block the serving thread so the + // time should be limited in duration. + Thread.sleep(100) + } + } + } + } } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala index 9b9b4681ba5db..c9d0431e2d2f7 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala @@ -353,12 +353,13 @@ private[ui] class StagePagedTable( val killLinkUri = s"$basePathUri/stages/stage/kill/"
- (kill)
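As an illustration of the job-kill endpoint wired up above, here is a sketch of the request the new (kill) link issues; it assumes a live SparkContext `sc` with the web UI enabled, `spark.ui.killEnabled=true`, and a running job with id 0 (the UISeleniumSuite change below exercises the same URL).

    import java.net.{HttpURLConnection, URL}

    // Sketch: hit /jobs/job/kill the same way the UI link does. `sc` is an assumed
    // active SparkContext whose UI is up; job 0 is assumed to be active.
    val base = sc.ui.get.appUIAddress.stripSuffix("/")
    val url = new URL(s"$base/jobs/job/kill/?id=0")
    val conn = url.openConnection().asInstanceOf[HttpURLConnection]
    conn.setRequestMethod("GET")  // SPARK-6846: GET is accepted only because the YARN AM proxy drops POSTs
    conn.connect()
    assert(conn.getResponseCode == 200)
    conn.disconnect()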
*/ - val killLinkUri = s"$basePathUri/stages/stage/kill/?id=${s.stageId}&terminate=true" + val killLinkUri = s"$basePathUri/stages/stage/kill/?id=${s.stageId}" (kill) + } else { + Seq.empty } val nameLinkUri = s"$basePathUri/stages/stage?id=${s.stageId}&attempt=${s.attemptId}" diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala index 573192ac17d45..c1f25114371f1 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala @@ -39,15 +39,16 @@ private[ui] class StagesTab(parent: SparkUI) extends SparkUITab(parent, "stages" def handleKillRequest(request: HttpServletRequest): Unit = { if (killEnabled && parent.securityManager.checkModifyPermissions(request.getRemoteUser)) { - val killFlag = Option(request.getParameter("terminate")).getOrElse("false").toBoolean - val stageId = Option(request.getParameter("id")).getOrElse("-1").toInt - if (stageId >= 0 && killFlag && progressListener.activeStages.contains(stageId)) { - sc.get.cancelStage(stageId) + val stageId = Option(request.getParameter("id")).map(_.toInt) + stageId.foreach { id => + if (progressListener.activeStages.contains(id)) { + sc.foreach(_.cancelStage(id)) + // Do a quick pause here to give Spark time to kill the stage so it shows up as + // killed after the refresh. Note that this will block the serving thread so the + // time should be limited in duration. + Thread.sleep(100) + } } - // Do a quick pause here to give Spark time to kill the stage so it shows up as - // killed after the refresh. Note that this will block the serving thread so the - // time should be limited in duration. - Thread.sleep(100) } } diff --git a/core/src/main/scala/org/apache/spark/util/ManualClock.scala b/core/src/main/scala/org/apache/spark/util/ManualClock.scala index 91a95871014f0..e7a65d74a440e 100644 --- a/core/src/main/scala/org/apache/spark/util/ManualClock.scala +++ b/core/src/main/scala/org/apache/spark/util/ManualClock.scala @@ -26,8 +26,6 @@ package org.apache.spark.util */ private[spark] class ManualClock(private var time: Long) extends Clock { - private var _isWaiting = false - /** * @return `ManualClock` with initial time 0 */ @@ -59,19 +57,9 @@ private[spark] class ManualClock(private var time: Long) extends Clock { * @return current time reported by the clock when waiting finishes */ def waitTillTime(targetTime: Long): Long = synchronized { - _isWaiting = true - try { - while (time < targetTime) { - wait(10) - } - getTimeMillis() - } finally { - _isWaiting = false + while (time < targetTime) { + wait(10) } + getTimeMillis() } - - /** - * Returns whether there is any thread being blocked in `waitTillTime`. - */ - def isWaiting: Boolean = synchronized { _isWaiting } } diff --git a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala index 5a6dbc830448a..d093e7bfc3dac 100644 --- a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala @@ -194,4 +194,25 @@ private[spark] object ThreadUtils { throw new SparkException("Exception thrown in awaitResult: ", t) } } + + /** + * Calls [[Awaitable.result]] directly to avoid using `ForkJoinPool`'s `BlockingContext`, wraps + * and re-throws any exceptions with nice stack track. + * + * Codes running in the user's thread may be in a thread of Scala ForkJoinPool. 
As concurrent + * executions in ForkJoinPool may see some [[ThreadLocal]] value unexpectedly, this method + * basically prevents ForkJoinPool from running other tasks in the current waiting thread. + */ + @throws(classOf[SparkException]) + def awaitResultInForkJoinSafely[T](awaitable: Awaitable[T], atMost: Duration): T = { + try { + // `awaitPermission` is not actually used anywhere so it's safe to pass in null here. + // See SPARK-13747. + val awaitPermission = null.asInstanceOf[scala.concurrent.CanAwait] + awaitable.result(Duration.Inf)(awaitPermission) + } catch { + case NonFatal(t) => + throw new SparkException("Exception thrown in awaitResult: ", t) + } + } } diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index ef832756ce3b7..6027b07c0fee8 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -27,6 +27,7 @@ import java.nio.file.{Files, Paths} import java.util.{Locale, Properties, Random, UUID} import java.util.concurrent._ import java.util.concurrent.atomic.AtomicBoolean +import java.util.zip.GZIPInputStream import javax.net.ssl.HttpsURLConnection import scala.annotation.tailrec @@ -38,6 +39,7 @@ import scala.reflect.ClassTag import scala.util.Try import scala.util.control.{ControlThrowable, NonFatal} +import com.google.common.cache.{CacheBuilder, CacheLoader, LoadingCache} import com.google.common.io.{ByteStreams, Files => GFiles} import com.google.common.net.InetAddresses import org.apache.commons.lang3.SystemUtils @@ -55,6 +57,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.internal.config.{DYN_ALLOCATION_INITIAL_EXECUTORS, DYN_ALLOCATION_MIN_EXECUTORS, EXECUTOR_INSTANCES} import org.apache.spark.network.util.JavaUtils import org.apache.spark.serializer.{DeserializationStream, SerializationStream, SerializerInstance} +import org.apache.spark.util.logging.RollingFileAppender /** CallSite represents a place in user code. It can have a short and a long form. */ private[spark] case class CallSite(shortForm: String, longForm: String) @@ -1440,14 +1443,72 @@ private[spark] object Utils extends Logging { CallSite(shortForm, longForm) } + private val UNCOMPRESSED_LOG_FILE_LENGTH_CACHE_SIZE_CONF = + "spark.worker.ui.compressedLogFileLengthCacheSize" + private val DEFAULT_UNCOMPRESSED_LOG_FILE_LENGTH_CACHE_SIZE = 100 + private var compressedLogFileLengthCache: LoadingCache[String, java.lang.Long] = null + private def getCompressedLogFileLengthCache( + sparkConf: SparkConf): LoadingCache[String, java.lang.Long] = this.synchronized { + if (compressedLogFileLengthCache == null) { + val compressedLogFileLengthCacheSize = sparkConf.getInt( + UNCOMPRESSED_LOG_FILE_LENGTH_CACHE_SIZE_CONF, + DEFAULT_UNCOMPRESSED_LOG_FILE_LENGTH_CACHE_SIZE) + compressedLogFileLengthCache = CacheBuilder.newBuilder() + .maximumSize(compressedLogFileLengthCacheSize) + .build[String, java.lang.Long](new CacheLoader[String, java.lang.Long]() { + override def load(path: String): java.lang.Long = { + Utils.getCompressedFileLength(new File(path)) + } + }) + } + compressedLogFileLengthCache + } + + /** + * Return the file length, if the file is compressed it returns the uncompressed file length. + * It also caches the uncompressed file size to avoid repeated decompression. The cache size is + * read from workerConf. 
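To illustrate the reworked signatures (these are internal private[spark] helpers, so this is only a sketch with a hypothetical log path): callers now obtain the length once via getFileLength and pass it into offsetBytes, so a gzipped rollover is sized by its uncompressed content.

    import java.io.File

    import org.apache.spark.SparkConf
    import org.apache.spark.util.Utils

    // Sketch: length of a rolled executor log; for a ".gz" file this is the uncompressed
    // size, looked up through the new cache. The path and conf here are hypothetical.
    val workerConf = new SparkConf()  // may set spark.worker.ui.compressedLogFileLengthCacheSize
    val log = new File("/tmp/work/app-0000/0/stdout.gz")
    val length = Utils.getFileLength(log, workerConf)
    val firstKb = Utils.offsetBytes(log.getAbsolutePath, length, 0, 1024)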
+ */ + def getFileLength(file: File, workConf: SparkConf): Long = { + if (file.getName.endsWith(".gz")) { + getCompressedLogFileLengthCache(workConf).get(file.getAbsolutePath) + } else { + file.length + } + } + + /** Return uncompressed file length of a compressed file. */ + private def getCompressedFileLength(file: File): Long = { + try { + // Uncompress .gz file to determine file size. + var fileSize = 0L + val gzInputStream = new GZIPInputStream(new FileInputStream(file)) + val bufSize = 1024 + val buf = new Array[Byte](bufSize) + var numBytes = ByteStreams.read(gzInputStream, buf, 0, bufSize) + while (numBytes > 0) { + fileSize += numBytes + numBytes = ByteStreams.read(gzInputStream, buf, 0, bufSize) + } + fileSize + } catch { + case e: Throwable => + logError(s"Cannot get file length of ${file}", e) + throw e + } + } + /** Return a string containing part of a file from byte 'start' to 'end'. */ - def offsetBytes(path: String, start: Long, end: Long): String = { + def offsetBytes(path: String, length: Long, start: Long, end: Long): String = { val file = new File(path) - val length = file.length() val effectiveEnd = math.min(length, end) val effectiveStart = math.max(0, start) val buff = new Array[Byte]((effectiveEnd-effectiveStart).toInt) - val stream = new FileInputStream(file) + val stream = if (path.endsWith(".gz")) { + new GZIPInputStream(new FileInputStream(file)) + } else { + new FileInputStream(file) + } try { ByteStreams.skipFully(stream, effectiveStart) @@ -1463,8 +1524,8 @@ private[spark] object Utils extends Logging { * and `endIndex` is based on the cumulative size of all the files take in * the given order. See figure below for more details. */ - def offsetBytes(files: Seq[File], start: Long, end: Long): String = { - val fileLengths = files.map { _.length } + def offsetBytes(files: Seq[File], fileLengths: Seq[Long], start: Long, end: Long): String = { + assert(files.length == fileLengths.length) val startIndex = math.max(start, 0) val endIndex = math.min(end, fileLengths.sum) val fileToLength = files.zip(fileLengths).toMap @@ -1472,7 +1533,7 @@ private[spark] object Utils extends Logging { val stringBuffer = new StringBuffer((endIndex - startIndex).toInt) var sum = 0L - for (file <- files) { + files.zip(fileLengths).foreach { case (file, fileLength) => val startIndexOfFile = sum val endIndexOfFile = sum + fileToLength(file) logDebug(s"Processing file $file, " + @@ -1491,19 +1552,19 @@ private[spark] object Utils extends Logging { if (startIndex <= startIndexOfFile && endIndex >= endIndexOfFile) { // Case C: read the whole file - stringBuffer.append(offsetBytes(file.getAbsolutePath, 0, fileToLength(file))) + stringBuffer.append(offsetBytes(file.getAbsolutePath, fileLength, 0, fileToLength(file))) } else if (startIndex > startIndexOfFile && startIndex < endIndexOfFile) { // Case A and B: read from [start of required range] to [end of file / end of range] val effectiveStartIndex = startIndex - startIndexOfFile val effectiveEndIndex = math.min(endIndex - startIndexOfFile, fileToLength(file)) stringBuffer.append(Utils.offsetBytes( - file.getAbsolutePath, effectiveStartIndex, effectiveEndIndex)) + file.getAbsolutePath, fileLength, effectiveStartIndex, effectiveEndIndex)) } else if (endIndex > startIndexOfFile && endIndex < endIndexOfFile) { // Case D: read from [start of file] to [end of require range] val effectiveStartIndex = math.max(startIndex - startIndexOfFile, 0) val effectiveEndIndex = endIndex - startIndexOfFile stringBuffer.append(Utils.offsetBytes( - 
file.getAbsolutePath, effectiveStartIndex, effectiveEndIndex)) + file.getAbsolutePath, fileLength, effectiveStartIndex, effectiveEndIndex)) } sum += fileToLength(file) logDebug(s"After processing file $file, string built is ${stringBuffer.toString}") @@ -1698,6 +1759,22 @@ private[spark] object Utils extends Logging { count } + /** + * Generate a zipWithIndex iterator, avoid index value overflowing problem + * in scala's zipWithIndex + */ + def getIteratorZipWithIndex[T](iterator: Iterator[T], startIndex: Long): Iterator[(T, Long)] = { + new Iterator[(T, Long)] { + require(startIndex >= 0, "startIndex should be >= 0.") + var index: Long = startIndex - 1L + def hasNext: Boolean = iterator.hasNext + def next(): (T, Long) = { + index += 1L + (iterator.next(), index) + } + } + } + /** * Creates a symlink. * @@ -2432,6 +2509,26 @@ private[spark] object Utils extends Logging { } } +private[util] object CallerContext extends Logging { + val callerContextSupported: Boolean = { + SparkHadoopUtil.get.conf.getBoolean("hadoop.caller.context.enabled", false) && { + try { + // scalastyle:off classforname + Class.forName("org.apache.hadoop.ipc.CallerContext") + Class.forName("org.apache.hadoop.ipc.CallerContext$Builder") + // scalastyle:on classforname + true + } catch { + case _: ClassNotFoundException => + false + case NonFatal(e) => + logWarning("Fail to load the CallerContext class", e) + false + } + } + } +} + /** * An utility class used to set up Spark caller contexts to HDFS and Yarn. The `context` will be * constructed by parameters passed in. @@ -2478,21 +2575,21 @@ private[spark] class CallerContext( * Set up the caller context [[context]] by invoking Hadoop CallerContext API of * [[org.apache.hadoop.ipc.CallerContext]], which was added in hadoop 2.8. */ - def setCurrentContext(): Boolean = { - var succeed = false - try { - // scalastyle:off classforname - val callerContext = Class.forName("org.apache.hadoop.ipc.CallerContext") - val Builder = Class.forName("org.apache.hadoop.ipc.CallerContext$Builder") - // scalastyle:on classforname - val builderInst = Builder.getConstructor(classOf[String]).newInstance(context) - val hdfsContext = Builder.getMethod("build").invoke(builderInst) - callerContext.getMethod("setCurrent", callerContext).invoke(null, hdfsContext) - succeed = true - } catch { - case NonFatal(e) => logInfo("Fail to set Spark caller context", e) + def setCurrentContext(): Unit = { + if (CallerContext.callerContextSupported) { + try { + // scalastyle:off classforname + val callerContext = Class.forName("org.apache.hadoop.ipc.CallerContext") + val builder = Class.forName("org.apache.hadoop.ipc.CallerContext$Builder") + // scalastyle:on classforname + val builderInst = builder.getConstructor(classOf[String]).newInstance(context) + val hdfsContext = builder.getMethod("build").invoke(builderInst) + callerContext.getMethod("setCurrent", callerContext).invoke(null, hdfsContext) + } catch { + case NonFatal(e) => + logWarning("Fail to set Spark caller context", e) + } } - succeed } } diff --git a/core/src/main/scala/org/apache/spark/util/logging/RollingFileAppender.scala b/core/src/main/scala/org/apache/spark/util/logging/RollingFileAppender.scala index a0eb05c7c0e82..5d8cec8447b53 100644 --- a/core/src/main/scala/org/apache/spark/util/logging/RollingFileAppender.scala +++ b/core/src/main/scala/org/apache/spark/util/logging/RollingFileAppender.scala @@ -17,9 +17,11 @@ package org.apache.spark.util.logging -import java.io.{File, FileFilter, InputStream} +import java.io._ +import 
java.util.zip.GZIPOutputStream import com.google.common.io.Files +import org.apache.commons.io.IOUtils import org.apache.spark.SparkConf @@ -45,6 +47,7 @@ private[spark] class RollingFileAppender( import RollingFileAppender._ private val maxRetainedFiles = conf.getInt(RETAINED_FILES_PROPERTY, -1) + private val enableCompression = conf.getBoolean(ENABLE_COMPRESSION, false) /** Stop the appender */ override def stop() { @@ -76,6 +79,33 @@ private[spark] class RollingFileAppender( } } + // Roll the log file and compress if enableCompression is true. + private def rotateFile(activeFile: File, rolloverFile: File): Unit = { + if (enableCompression) { + val gzFile = new File(rolloverFile.getAbsolutePath + GZIP_LOG_SUFFIX) + var gzOutputStream: GZIPOutputStream = null + var inputStream: InputStream = null + try { + inputStream = new FileInputStream(activeFile) + gzOutputStream = new GZIPOutputStream(new FileOutputStream(gzFile)) + IOUtils.copy(inputStream, gzOutputStream) + inputStream.close() + gzOutputStream.close() + activeFile.delete() + } finally { + IOUtils.closeQuietly(inputStream) + IOUtils.closeQuietly(gzOutputStream) + } + } else { + Files.move(activeFile, rolloverFile) + } + } + + // Check if the rollover file already exists. + private def rolloverFileExist(file: File): Boolean = { + file.exists || new File(file.getAbsolutePath + GZIP_LOG_SUFFIX).exists + } + /** Move the active log file to a new rollover file */ private def moveFile() { val rolloverSuffix = rollingPolicy.generateRolledOverFileSuffix() @@ -83,8 +113,8 @@ private[spark] class RollingFileAppender( activeFile.getParentFile, activeFile.getName + rolloverSuffix).getAbsoluteFile logDebug(s"Attempting to rollover file $activeFile to file $rolloverFile") if (activeFile.exists) { - if (!rolloverFile.exists) { - Files.move(activeFile, rolloverFile) + if (!rolloverFileExist(rolloverFile)) { + rotateFile(activeFile, rolloverFile) logInfo(s"Rolled over $activeFile to $rolloverFile") } else { // In case the rollover file name clashes, make a unique file name. @@ -97,11 +127,11 @@ private[spark] class RollingFileAppender( altRolloverFile = new File(activeFile.getParent, s"${activeFile.getName}$rolloverSuffix--$i").getAbsoluteFile i += 1 - } while (i < 10000 && altRolloverFile.exists) + } while (i < 10000 && rolloverFileExist(altRolloverFile)) logWarning(s"Rollover file $rolloverFile already exists, " + s"rolled over $activeFile to file $altRolloverFile") - Files.move(activeFile, altRolloverFile) + rotateFile(activeFile, altRolloverFile) } } else { logWarning(s"File $activeFile does not exist") @@ -142,6 +172,9 @@ private[spark] object RollingFileAppender { val SIZE_DEFAULT = (1024 * 1024).toString val RETAINED_FILES_PROPERTY = "spark.executor.logs.rolling.maxRetainedFiles" val DEFAULT_BUFFER_SIZE = 8192 + val ENABLE_COMPRESSION = "spark.executor.logs.rolling.enableCompression" + + val GZIP_LOG_SUFFIX = ".gz" /** * Get the sorted list of rolled over files. 
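A sketch of the configuration that exercises this rotation path: the strategy, interval, and retention keys are the pre-existing rolling options, and only enableCompression is introduced by this patch.

    import org.apache.spark.SparkConf

    // Sketch: time-based executor log rolling with gzip compression of rolled-over files.
    val conf = new SparkConf()
      .set("spark.executor.logs.rolling.strategy", "time")
      .set("spark.executor.logs.rolling.time.interval", "hourly")
      .set("spark.executor.logs.rolling.maxRetainedFiles", "24")
      .set("spark.executor.logs.rolling.enableCompression", "true")  // rolled files gain a ".gz" suffix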
This assumes that the all the rolled @@ -158,6 +191,6 @@ private[spark] object RollingFileAppender { val file = new File(directory, activeFileName).getAbsoluteFile if (file.exists) Some(file) else None } - rolledOverFiles ++ activeFile + rolledOverFiles.sortBy(_.getName.stripSuffix(GZIP_LOG_SUFFIX)) ++ activeFile } } diff --git a/core/src/test/java/org/apache/spark/io/NioBufferedFileInputStreamSuite.java b/core/src/test/java/org/apache/spark/io/NioBufferedFileInputStreamSuite.java new file mode 100644 index 0000000000000..2c1a34a607592 --- /dev/null +++ b/core/src/test/java/org/apache/spark/io/NioBufferedFileInputStreamSuite.java @@ -0,0 +1,135 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.io; + +import org.apache.commons.io.FileUtils; +import org.apache.commons.lang3.RandomUtils; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import java.io.File; +import java.io.IOException; +import java.io.InputStream; + +import static org.junit.Assert.assertEquals; + +/** + * Tests functionality of {@link NioBufferedFileInputStream} + */ +public class NioBufferedFileInputStreamSuite { + + private byte[] randomBytes; + + private File inputFile; + + @Before + public void setUp() throws IOException { + // Create a byte array of size 2 MB with random bytes + randomBytes = RandomUtils.nextBytes(2 * 1024 * 1024); + inputFile = File.createTempFile("temp-file", ".tmp"); + FileUtils.writeByteArrayToFile(inputFile, randomBytes); + } + + @After + public void tearDown() { + inputFile.delete(); + } + + @Test + public void testReadOneByte() throws IOException { + InputStream inputStream = new NioBufferedFileInputStream(inputFile); + for (int i = 0; i < randomBytes.length; i++) { + assertEquals(randomBytes[i], (byte) inputStream.read()); + } + } + + @Test + public void testReadMultipleBytes() throws IOException { + InputStream inputStream = new NioBufferedFileInputStream(inputFile); + byte[] readBytes = new byte[8 * 1024]; + int i = 0; + while (i < randomBytes.length) { + int read = inputStream.read(readBytes, 0, 8 * 1024); + for (int j = 0; j < read; j++) { + assertEquals(randomBytes[i], readBytes[j]); + i++; + } + } + } + + @Test + public void testBytesSkipped() throws IOException { + InputStream inputStream = new NioBufferedFileInputStream(inputFile); + assertEquals(1024, inputStream.skip(1024)); + for (int i = 1024; i < randomBytes.length; i++) { + assertEquals(randomBytes[i], (byte) inputStream.read()); + } + } + + @Test + public void testBytesSkippedAfterRead() throws IOException { + InputStream inputStream = new NioBufferedFileInputStream(inputFile); + for (int i = 0; i < 1024; i++) { + assertEquals(randomBytes[i], (byte) inputStream.read()); + } + assertEquals(1024, inputStream.skip(1024)); + for 
(int i = 2048; i < randomBytes.length; i++) { + assertEquals(randomBytes[i], (byte) inputStream.read()); + } + } + + @Test + public void testNegativeBytesSkippedAfterRead() throws IOException { + InputStream inputStream = new NioBufferedFileInputStream(inputFile); + for (int i = 0; i < 1024; i++) { + assertEquals(randomBytes[i], (byte) inputStream.read()); + } + // Skipping negative bytes should essential be a no-op + assertEquals(0, inputStream.skip(-1)); + assertEquals(0, inputStream.skip(-1024)); + assertEquals(0, inputStream.skip(Long.MIN_VALUE)); + assertEquals(1024, inputStream.skip(1024)); + for (int i = 2048; i < randomBytes.length; i++) { + assertEquals(randomBytes[i], (byte) inputStream.read()); + } + } + + @Test + public void testSkipFromFileChannel() throws IOException { + InputStream inputStream = new NioBufferedFileInputStream(inputFile, 10); + // Since the buffer is smaller than the skipped bytes, this will guarantee + // we skip from underlying file channel. + assertEquals(1024, inputStream.skip(1024)); + for (int i = 1024; i < 2048; i++) { + assertEquals(randomBytes[i], (byte) inputStream.read()); + } + assertEquals(256, inputStream.skip(256)); + assertEquals(256, inputStream.skip(256)); + assertEquals(512, inputStream.skip(512)); + for (int i = 3072; i < randomBytes.length; i++) { + assertEquals(randomBytes[i], (byte) inputStream.read()); + } + } + + @Test + public void testBytesSkippedAfterEOF() throws IOException { + InputStream inputStream = new NioBufferedFileInputStream(inputFile); + assertEquals(randomBytes.length, inputStream.skip(randomBytes.length + 1)); + assertEquals(-1, inputStream.read()); + } +} diff --git a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala index 9f94e36324536..b117c7709b46f 100644 --- a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala +++ b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala @@ -500,7 +500,7 @@ class CheckpointSuite extends SparkFunSuite with RDDCheckpointTester with LocalS } runTest("CheckpointRDD with zero partitions") { reliableCheckpoint: Boolean => - val rdd = new BlockRDD[Int](sc, Array[BlockId]()) + val rdd = new BlockRDD[Int](sc, Array.empty[BlockId]) assert(rdd.partitions.size === 0) assert(rdd.isCheckpointed === false) assert(rdd.isCheckpointedAndMaterialized === false) diff --git a/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala b/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala index 159b448e05b02..2b8b1805bc83f 100644 --- a/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala @@ -79,7 +79,7 @@ class SSLOptionsSuite extends SparkFunSuite with BeforeAndAfterAll { conf.set("spark.ssl.protocol", "SSLv3") val defaultOpts = SSLOptions.parse(conf, "spark.ssl", defaults = None) - val opts = SSLOptions.parse(conf, "spark.ui.ssl", defaults = Some(defaultOpts)) + val opts = SSLOptions.parse(conf, "spark.ssl.ui", defaults = Some(defaultOpts)) assert(opts.enabled === true) assert(opts.trustStore.isDefined === true) @@ -102,20 +102,20 @@ class SSLOptionsSuite extends SparkFunSuite with BeforeAndAfterAll { val conf = new SparkConf conf.set("spark.ssl.enabled", "true") - conf.set("spark.ui.ssl.enabled", "false") + conf.set("spark.ssl.ui.enabled", "false") conf.set("spark.ssl.keyStore", keyStorePath) conf.set("spark.ssl.keyStorePassword", "password") - conf.set("spark.ui.ssl.keyStorePassword", "12345") + conf.set("spark.ssl.ui.keyStorePassword", 
"12345") conf.set("spark.ssl.keyPassword", "password") conf.set("spark.ssl.trustStore", trustStorePath) conf.set("spark.ssl.trustStorePassword", "password") conf.set("spark.ssl.enabledAlgorithms", "TLS_RSA_WITH_AES_128_CBC_SHA, TLS_RSA_WITH_AES_256_CBC_SHA") - conf.set("spark.ui.ssl.enabledAlgorithms", "ABC, DEF") + conf.set("spark.ssl.ui.enabledAlgorithms", "ABC, DEF") conf.set("spark.ssl.protocol", "SSLv3") val defaultOpts = SSLOptions.parse(conf, "spark.ssl", defaults = None) - val opts = SSLOptions.parse(conf, "spark.ui.ssl", defaults = Some(defaultOpts)) + val opts = SSLOptions.parse(conf, "spark.ssl.ui", defaults = Some(defaultOpts)) assert(opts.enabled === false) assert(opts.trustStore.isDefined === true) diff --git a/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala index 2d48e75cfbd96..7093dad05c5f6 100644 --- a/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala @@ -65,7 +65,7 @@ class JsonProtocolSuite extends SparkFunSuite with JsonTestUtils { test("writeMasterState") { val workers = Array(createWorkerInfo(), createWorkerInfo()) val activeApps = Array(createAppInfo()) - val completedApps = Array[ApplicationInfo]() + val completedApps = Array.empty[ApplicationInfo] val activeDrivers = Array(createDriverInfo()) val completedDrivers = Array(createDriverInfo()) val stateResponse = new MasterStateResponse( diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 732cbfaaeea46..7c649e305a37e 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -91,7 +91,7 @@ class SparkSubmitSuite // scalastyle:off println test("prints usage on empty input") { - testPrematureExit(Array[String](), "Usage: spark-submit") + testPrematureExit(Array.empty[String], "Usage: spark-submit") } test("prints usage with only --help") { diff --git a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerArgumentsSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerArgumentsSuite.scala index 34f27ecaa07a3..de321db845a66 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerArgumentsSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerArgumentsSuite.scala @@ -33,7 +33,7 @@ class HistoryServerArgumentsSuite extends SparkFunSuite { .set("spark.testing", "true") test("No Arguments Parsing") { - val argStrings = Array[String]() + val argStrings = Array.empty[String] val hsa = new HistoryServerArguments(conf, argStrings) assert(conf.get("spark.history.fs.logDirectory") === logDir.getAbsolutePath) assert(conf.get("spark.history.fs.updateInterval") === "1") diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/ui/LogPageSuite.scala b/core/src/test/scala/org/apache/spark/deploy/worker/ui/LogPageSuite.scala index 72eaffb416981..4c3e96777940d 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/ui/LogPageSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/ui/LogPageSuite.scala @@ -22,16 +22,20 @@ import java.io.{File, FileWriter} import org.mockito.Mockito.{mock, when} import org.scalatest.PrivateMethodTester -import org.apache.spark.SparkFunSuite +import org.apache.spark.{SparkConf, SparkFunSuite} +import 
org.apache.spark.deploy.worker.Worker class LogPageSuite extends SparkFunSuite with PrivateMethodTester { test("get logs simple") { val webui = mock(classOf[WorkerWebUI]) + val worker = mock(classOf[Worker]) val tmpDir = new File(sys.props("java.io.tmpdir")) val workDir = new File(tmpDir, "work-dir") workDir.mkdir() when(webui.workDir).thenReturn(workDir) + when(webui.worker).thenReturn(worker) + when(worker.conf).thenReturn(new SparkConf()) val logPage = new LogPage(webui) // Prepare some fake log files to read later diff --git a/core/src/test/scala/org/apache/spark/io/ChunkedByteBufferSuite.scala b/core/src/test/scala/org/apache/spark/io/ChunkedByteBufferSuite.scala index 38b48a4c9e654..3b798e36b0499 100644 --- a/core/src/test/scala/org/apache/spark/io/ChunkedByteBufferSuite.scala +++ b/core/src/test/scala/org/apache/spark/io/ChunkedByteBufferSuite.scala @@ -57,7 +57,7 @@ class ChunkedByteBufferSuite extends SparkFunSuite { } test("toArray()") { - val empty = ByteBuffer.wrap(Array[Byte]()) + val empty = ByteBuffer.wrap(Array.empty[Byte]) val bytes = ByteBuffer.wrap(Array.tabulate(8)(_.toByte)) val chunkedByteBuffer = new ChunkedByteBuffer(Array(bytes, bytes, empty)) assert(chunkedByteBuffer.toArray === bytes.array() ++ bytes.array()) @@ -74,7 +74,7 @@ class ChunkedByteBufferSuite extends SparkFunSuite { } test("toInputStream()") { - val empty = ByteBuffer.wrap(Array[Byte]()) + val empty = ByteBuffer.wrap(Array.empty[Byte]) val bytes1 = ByteBuffer.wrap(Array.tabulate(256)(_.toByte)) val bytes2 = ByteBuffer.wrap(Array.tabulate(128)(_.toByte)) val chunkedByteBuffer = new ChunkedByteBuffer(Array(empty, bytes1, bytes2)) diff --git a/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala b/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala index 87600fe504b98..a757041299411 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala @@ -33,16 +33,21 @@ object FakeTask { * locations for each task (given as varargs) if this sequence is not empty. 
*/ def createTaskSet(numTasks: Int, prefLocs: Seq[TaskLocation]*): TaskSet = { - createTaskSet(numTasks, 0, prefLocs: _*) + createTaskSet(numTasks, stageAttemptId = 0, prefLocs: _*) } def createTaskSet(numTasks: Int, stageAttemptId: Int, prefLocs: Seq[TaskLocation]*): TaskSet = { + createTaskSet(numTasks, stageId = 0, stageAttemptId, prefLocs: _*) + } + + def createTaskSet(numTasks: Int, stageId: Int, stageAttemptId: Int, prefLocs: Seq[TaskLocation]*): + TaskSet = { if (prefLocs.size != 0 && prefLocs.size != numTasks) { throw new IllegalArgumentException("Wrong number of task locations") } val tasks = Array.tabulate[Task[_]](numTasks) { i => - new FakeTask(0, i, if (prefLocs.size != 0) prefLocs(i) else Nil) + new FakeTask(stageId, i, if (prefLocs.size != 0) prefLocs(i) else Nil) } - new TaskSet(tasks, 0, stageAttemptId, 0, null) + new TaskSet(tasks, stageId, stageAttemptId, priority = 0, null) } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index 69edcf3347243..1b1a764ceff95 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -261,14 +261,14 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg assert(manager.resourceOffer("exec1", "host1", PROCESS_LOCAL) == None) clock.advance(LOCALITY_WAIT_MS) - // Offer host1, exec1 again, at NODE_LOCAL level: the node local (task 2) should + // Offer host1, exec1 again, at NODE_LOCAL level: the node local (task 3) should // get chosen before the noPref task assert(manager.resourceOffer("exec1", "host1", NODE_LOCAL).get.index == 2) - // Offer host2, exec3 again, at NODE_LOCAL level: we should choose task 2 + // Offer host2, exec2, at NODE_LOCAL level: we should choose task 2 assert(manager.resourceOffer("exec2", "host2", NODE_LOCAL).get.index == 1) - // Offer host2, exec3 again, at NODE_LOCAL level: we should get noPref task + // Offer host2, exec2 again, at NODE_LOCAL level: we should get noPref task // after failing to find a node_Local task assert(manager.resourceOffer("exec2", "host2", NODE_LOCAL) == None) clock.advance(LOCALITY_WAIT_MS) @@ -904,7 +904,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg task.index == index && !sched.endedTasks.contains(task.taskId) }.getOrElse { throw new RuntimeException(s"couldn't find index $index in " + - s"tasks: ${tasks.map{t => t.index -> t.taskId}} with endedTasks:" + + s"tasks: ${tasks.map { t => t.index -> t.taskId }} with endedTasks:" + s" ${sched.endedTasks.keys}") } } @@ -974,6 +974,24 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg assert(manager.isZombie) } + test("SPARK-17894: Verify TaskSetManagers for different stage attempts have unique names") { + sc = new SparkContext("local", "test") + sched = new FakeTaskScheduler(sc, ("exec1", "host1")) + val taskSet = FakeTask.createTaskSet(numTasks = 1, stageId = 0, stageAttemptId = 0) + val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, new ManualClock) + assert(manager.name === "TaskSet_0.0") + + // Make sure a task set with the same stage ID but different attempt ID has a unique name + val taskSet2 = FakeTask.createTaskSet(numTasks = 1, stageId = 0, stageAttemptId = 1) + val manager2 = new TaskSetManager(sched, taskSet2, MAX_TASK_FAILURES, new ManualClock) + assert(manager2.name === "TaskSet_0.1") + + // Make 
sure a task set with the same attempt ID but different stage ID also has a unique name + val taskSet3 = FakeTask.createTaskSet(numTasks = 1, stageId = 1, stageAttemptId = 1) + val manager3 = new TaskSetManager(sched, taskSet3, MAX_TASK_FAILURES, new ManualClock) + assert(manager3.name === "TaskSet_1.1") + } + private def createTaskResult( id: Int, accumUpdates: Seq[AccumulatorV2[_, _]] = Seq.empty): DirectTaskResult[Int] = { diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoBenchmark.scala b/core/src/test/scala/org/apache/spark/serializer/KryoBenchmark.scala new file mode 100644 index 0000000000000..64be966276140 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/serializer/KryoBenchmark.scala @@ -0,0 +1,139 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.serializer + +import scala.reflect.ClassTag +import scala.util.Random + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.serializer.KryoTest._ +import org.apache.spark.util.Benchmark + +class KryoBenchmark extends SparkFunSuite { + val benchmark = new Benchmark("Benchmark Kryo Unsafe vs safe Serialization", 1024 * 1024 * 15, 10) + + ignore(s"Benchmark Kryo Unsafe vs safe Serialization") { + Seq (true, false).foreach (runBenchmark) + benchmark.run() + + // scalastyle:off + /* + Benchmark Kryo Unsafe vs safe Serialization: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + basicTypes: Int with unsafe:true 151 / 170 104.2 9.6 1.0X + basicTypes: Long with unsafe:true 175 / 191 89.8 11.1 0.9X + basicTypes: Float with unsafe:true 177 / 184 88.8 11.3 0.9X + basicTypes: Double with unsafe:true 193 / 216 81.4 12.3 0.8X + Array: Int with unsafe:true 513 / 587 30.7 32.6 0.3X + Array: Long with unsafe:true 1211 / 1358 13.0 77.0 0.1X + Array: Float with unsafe:true 890 / 964 17.7 56.6 0.2X + Array: Double with unsafe:true 1335 / 1428 11.8 84.9 0.1X + Map of string->Double with unsafe:true 931 / 988 16.9 59.2 0.2X + basicTypes: Int with unsafe:false 197 / 217 79.9 12.5 0.8X + basicTypes: Long with unsafe:false 219 / 240 71.8 13.9 0.7X + basicTypes: Float with unsafe:false 208 / 217 75.7 13.2 0.7X + basicTypes: Double with unsafe:false 208 / 225 75.6 13.2 0.7X + Array: Int with unsafe:false 2559 / 2681 6.1 162.7 0.1X + Array: Long with unsafe:false 3425 / 3516 4.6 217.8 0.0X + Array: Float with unsafe:false 2025 / 2134 7.8 128.7 0.1X + Array: Double with unsafe:false 2241 / 2358 7.0 142.5 0.1X + Map of string->Double with unsafe:false 1044 / 1085 15.1 66.4 0.1X + */ + // scalastyle:on + } + + private def runBenchmark(useUnsafe: Boolean): Unit = { + def check[T: ClassTag](t: T, ser: SerializerInstance): Int = { + if 
(ser.deserialize[T](ser.serialize(t)) === t) 1 else 0 + } + + // Benchmark Primitives + val basicTypeCount = 1000000 + def basicTypes[T: ClassTag](name: String, gen: () => T): Unit = { + lazy val ser = createSerializer(useUnsafe) + val arrayOfBasicType: Array[T] = Array.fill(basicTypeCount)(gen()) + + benchmark.addCase(s"basicTypes: $name with unsafe:$useUnsafe") { _ => + var sum = 0L + var i = 0 + while (i < basicTypeCount) { + sum += check(arrayOfBasicType(i), ser) + i += 1 + } + sum + } + } + basicTypes("Int", Random.nextInt) + basicTypes("Long", Random.nextLong) + basicTypes("Float", Random.nextFloat) + basicTypes("Double", Random.nextDouble) + + // Benchmark Array of Primitives + val arrayCount = 10000 + def basicTypeArray[T: ClassTag](name: String, gen: () => T): Unit = { + lazy val ser = createSerializer(useUnsafe) + val arrayOfArrays: Array[Array[T]] = + Array.fill(arrayCount)(Array.fill[T](Random.nextInt(arrayCount))(gen())) + + benchmark.addCase(s"Array: $name with unsafe:$useUnsafe") { _ => + var sum = 0L + var i = 0 + while (i < arrayCount) { + val arr = arrayOfArrays(i) + sum += check(arr, ser) + i += 1 + } + sum + } + } + basicTypeArray("Int", Random.nextInt) + basicTypeArray("Long", Random.nextLong) + basicTypeArray("Float", Random.nextFloat) + basicTypeArray("Double", Random.nextDouble) + + // Benchmark Maps + val mapsCount = 1000 + lazy val ser = createSerializer(useUnsafe) + val arrayOfMaps: Array[Map[String, Double]] = Array.fill(mapsCount) { + Array.fill(Random.nextInt(mapsCount)) { + (Random.nextString(mapsCount / 10), Random.nextDouble()) + }.toMap + } + + benchmark.addCase(s"Map of string->Double with unsafe:$useUnsafe") { _ => + var sum = 0L + var i = 0 + while (i < mapsCount) { + val map = arrayOfMaps(i) + sum += check(map, ser) + i += 1 + } + sum + } + } + + def createSerializer(useUnsafe: Boolean): SerializerInstance = { + val conf = new SparkConf() + conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") + conf.set("spark.kryo.registrator", classOf[MyRegistrator].getName) + conf.set("spark.kryo.unsafe", useUnsafe.toString) + + new KryoSerializer(conf).newInstance() + } + +} diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala index 57a82312008e9..5040841811054 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala @@ -36,6 +36,7 @@ import org.apache.spark.util.Utils class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext { conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") conf.set("spark.kryo.registrator", classOf[MyRegistrator].getName) + conf.set("spark.kryo.unsafe", "false") test("SPARK-7392 configuration limits") { val kryoBufferProperty = "spark.kryoserializer.buffer" @@ -100,7 +101,7 @@ class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext { check(Array("aaa", "bbb", null)) check(Array(true, false, true)) check(Array('a', 'b', 'c')) - check(Array[Int]()) + check(Array.empty[Int]) check(Array(Array("1", "2"), Array("1", "2", "3", "4"))) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameSuite.scala b/core/src/test/scala/org/apache/spark/serializer/UnsafeKryoSerializerSuite.scala similarity index 54% rename from sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameSuite.scala rename to 
core/src/test/scala/org/apache/spark/serializer/UnsafeKryoSerializerSuite.scala index 96e9054cd4876..d63a45ae4a6a9 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/UnsafeKryoSerializerSuite.scala @@ -15,23 +15,19 @@ * limitations under the License. */ -package org.apache.spark.sql.hive +package org.apache.spark.serializer -import org.apache.spark.sql.hive.test.TestHiveSingleton -import org.apache.spark.sql.QueryTest +class UnsafeKryoSerializerSuite extends KryoSerializerSuite { -class HiveDataFrameSuite extends QueryTest with TestHiveSingleton { - test("table name with schema") { - // regression test for SPARK-11778 - spark.sql("create schema usrdb") - spark.sql("create table usrdb.test(c int)") - spark.read.table("usrdb.test") - spark.sql("drop table usrdb.test") - spark.sql("drop schema usrdb") + // This test suite should run all tests in KryoSerializerSuite with kryo unsafe. + + override def beforeAll() { + conf.set("spark.kryo.unsafe", "true") + super.beforeAll() } - test("SPARK-15887: hive-site.xml should be loaded") { - val hiveClient = spark.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client - assert(hiveClient.getConf("hive.in.test", "") == "true") + override def afterAll() { + conf.set("spark.kryo.unsafe", "false") + super.afterAll() } } diff --git a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala index fd12a21b7927e..e5d408a167361 100644 --- a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala @@ -194,6 +194,22 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B sc.parallelize(1 to 10).map{x => Thread.sleep(10000); x}.countAsync() } + withSpark(newSparkContext(killEnabled = true)) { sc => + runSlowJob(sc) + eventually(timeout(5 seconds), interval(50 milliseconds)) { + goToUi(sc, "/jobs") + assert(hasKillLink) + } + } + + withSpark(newSparkContext(killEnabled = false)) { sc => + runSlowJob(sc) + eventually(timeout(5 seconds), interval(50 milliseconds)) { + goToUi(sc, "/jobs") + assert(!hasKillLink) + } + } + withSpark(newSparkContext(killEnabled = true)) { sc => runSlowJob(sc) eventually(timeout(5 seconds), interval(50 milliseconds)) { @@ -453,20 +469,24 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B } test("kill stage POST/GET response is correct") { - def getResponseCode(url: URL, method: String): Int = { - val connection = url.openConnection().asInstanceOf[HttpURLConnection] - connection.setRequestMethod(method) - connection.connect() - val code = connection.getResponseCode() - connection.disconnect() - code + withSpark(newSparkContext(killEnabled = true)) { sc => + sc.parallelize(1 to 10).map{x => Thread.sleep(10000); x}.countAsync() + eventually(timeout(5 seconds), interval(50 milliseconds)) { + val url = new URL( + sc.ui.get.appUIAddress.stripSuffix("/") + "/stages/stage/kill/?id=0") + // SPARK-6846: should be POST only but YARN AM doesn't proxy POST + getResponseCode(url, "GET") should be (200) + getResponseCode(url, "POST") should be (200) + } } + } + test("kill job POST/GET response is correct") { withSpark(newSparkContext(killEnabled = true)) { sc => sc.parallelize(1 to 10).map{x => Thread.sleep(10000); x}.countAsync() eventually(timeout(5 seconds), interval(50 milliseconds)) { val url = new URL( - sc.ui.get.appUIAddress.stripSuffix("/") + 
"/stages/stage/kill/?id=0&terminate=true") + sc.ui.get.appUIAddress.stripSuffix("/") + "/jobs/job/kill/?id=0") // SPARK-6846: should be POST only but YARN AM doesn't proxy POST getResponseCode(url, "GET") should be (200) getResponseCode(url, "POST") should be (200) @@ -651,6 +671,17 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B } } + def getResponseCode(url: URL, method: String): Int = { + val connection = url.openConnection().asInstanceOf[HttpURLConnection] + connection.setRequestMethod(method) + try { + connection.connect() + connection.getResponseCode() + } finally { + connection.disconnect() + } + } + def goToUi(sc: SparkContext, path: String): Unit = { goToUi(sc.ui.get, path) } diff --git a/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala b/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala index 4fa9f9a8f590f..7e2da8e141532 100644 --- a/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala @@ -20,11 +20,13 @@ package org.apache.spark.util import java.io._ import java.nio.charset.StandardCharsets import java.util.concurrent.CountDownLatch +import java.util.zip.GZIPInputStream import scala.collection.mutable.HashSet import scala.reflect._ import com.google.common.io.Files +import org.apache.commons.io.IOUtils import org.apache.log4j.{Appender, Level, Logger} import org.apache.log4j.spi.LoggingEvent import org.mockito.ArgumentCaptor @@ -72,6 +74,25 @@ class FileAppenderSuite extends SparkFunSuite with BeforeAndAfter with Logging { testRolling(appender, testOutputStream, textToAppend, rolloverIntervalMillis) } + test("rolling file appender - time-based rolling (compressed)") { + // setup input stream and appender + val testOutputStream = new PipedOutputStream() + val testInputStream = new PipedInputStream(testOutputStream, 100 * 1000) + val rolloverIntervalMillis = 100 + val durationMillis = 1000 + val numRollovers = durationMillis / rolloverIntervalMillis + val textToAppend = (1 to numRollovers).map( _.toString * 10 ) + + val sparkConf = new SparkConf() + sparkConf.set("spark.executor.logs.rolling.enableCompression", "true") + val appender = new RollingFileAppender(testInputStream, testFile, + new TimeBasedRollingPolicy(rolloverIntervalMillis, s"--HH-mm-ss-SSSS", false), + sparkConf, 10) + + testRolling( + appender, testOutputStream, textToAppend, rolloverIntervalMillis, isCompressed = true) + } + test("rolling file appender - size-based rolling") { // setup input stream and appender val testOutputStream = new PipedOutputStream() @@ -89,6 +110,25 @@ class FileAppenderSuite extends SparkFunSuite with BeforeAndAfter with Logging { } } + test("rolling file appender - size-based rolling (compressed)") { + // setup input stream and appender + val testOutputStream = new PipedOutputStream() + val testInputStream = new PipedInputStream(testOutputStream, 100 * 1000) + val rolloverSize = 1000 + val textToAppend = (1 to 3).map( _.toString * 1000 ) + + val sparkConf = new SparkConf() + sparkConf.set("spark.executor.logs.rolling.enableCompression", "true") + val appender = new RollingFileAppender(testInputStream, testFile, + new SizeBasedRollingPolicy(rolloverSize, false), sparkConf, 99) + + val files = testRolling(appender, testOutputStream, textToAppend, 0, isCompressed = true) + files.foreach { file => + logInfo(file.toString + ": " + file.length + " bytes") + assert(file.length < rolloverSize) + } + } + test("rolling file appender - cleaning") 
{ // setup input stream and appender val testOutputStream = new PipedOutputStream() @@ -273,7 +313,8 @@ class FileAppenderSuite extends SparkFunSuite with BeforeAndAfter with Logging { appender: FileAppender, outputStream: OutputStream, textToAppend: Seq[String], - sleepTimeBetweenTexts: Long + sleepTimeBetweenTexts: Long, + isCompressed: Boolean = false ): Seq[File] = { // send data to appender through the input stream, and wait for the data to be written val expectedText = textToAppend.mkString("") @@ -290,10 +331,23 @@ class FileAppenderSuite extends SparkFunSuite with BeforeAndAfter with Logging { // verify whether all the data written to rolled over files is same as expected val generatedFiles = RollingFileAppender.getSortedRolledOverFiles( testFile.getParentFile.toString, testFile.getName) - logInfo("Filtered files: \n" + generatedFiles.mkString("\n")) + logInfo("Generate files: \n" + generatedFiles.mkString("\n")) assert(generatedFiles.size > 1) + if (isCompressed) { + assert( + generatedFiles.filter(_.getName.endsWith(RollingFileAppender.GZIP_LOG_SUFFIX)).size > 0) + } val allText = generatedFiles.map { file => - Files.toString(file, StandardCharsets.UTF_8) + if (file.getName.endsWith(RollingFileAppender.GZIP_LOG_SUFFIX)) { + val inputStream = new GZIPInputStream(new FileInputStream(file)) + try { + IOUtils.toString(inputStream, StandardCharsets.UTF_8) + } finally { + IOUtils.closeQuietly(inputStream) + } + } else { + Files.toString(file, StandardCharsets.UTF_8) + } }.mkString("") assert(allText === expectedText) generatedFiles diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index bc28b2d9cb831..15ef32f21d90c 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -25,11 +25,13 @@ import java.nio.charset.StandardCharsets import java.text.DecimalFormatSymbols import java.util.Locale import java.util.concurrent.TimeUnit +import java.util.zip.GZIPOutputStream import scala.collection.mutable.ListBuffer import scala.util.Random import com.google.common.io.Files +import org.apache.commons.io.IOUtils import org.apache.commons.lang3.SystemUtils import org.apache.commons.math3.stat.inference.ChiSquareTest import org.apache.hadoop.conf.Configuration @@ -274,65 +276,109 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { assert(str(10 * hour + 59 * minute + 59 * second + 999) === "11" + sep + "00 h") } - test("reading offset bytes of a file") { + def getSuffix(isCompressed: Boolean): String = { + if (isCompressed) { + ".gz" + } else { + "" + } + } + + def writeLogFile(path: String, content: Array[Byte]): Unit = { + val outputStream = if (path.endsWith(".gz")) { + new GZIPOutputStream(new FileOutputStream(path)) + } else { + new FileOutputStream(path) + } + IOUtils.write(content, outputStream) + outputStream.close() + content.size + } + + private val workerConf = new SparkConf() + + def testOffsetBytes(isCompressed: Boolean): Unit = { val tmpDir2 = Utils.createTempDir() - val f1Path = tmpDir2 + "/f1" - val f1 = new FileOutputStream(f1Path) - f1.write("1\n2\n3\n4\n5\n6\n7\n8\n9\n".getBytes(StandardCharsets.UTF_8)) - f1.close() + val suffix = getSuffix(isCompressed) + val f1Path = tmpDir2 + "/f1" + suffix + writeLogFile(f1Path, "1\n2\n3\n4\n5\n6\n7\n8\n9\n".getBytes(StandardCharsets.UTF_8)) + val f1Length = Utils.getFileLength(new File(f1Path), workerConf) // Read first few bytes - 
assert(Utils.offsetBytes(f1Path, 0, 5) === "1\n2\n3") + assert(Utils.offsetBytes(f1Path, f1Length, 0, 5) === "1\n2\n3") // Read some middle bytes - assert(Utils.offsetBytes(f1Path, 4, 11) === "3\n4\n5\n6") + assert(Utils.offsetBytes(f1Path, f1Length, 4, 11) === "3\n4\n5\n6") // Read last few bytes - assert(Utils.offsetBytes(f1Path, 12, 18) === "7\n8\n9\n") + assert(Utils.offsetBytes(f1Path, f1Length, 12, 18) === "7\n8\n9\n") // Read some nonexistent bytes in the beginning - assert(Utils.offsetBytes(f1Path, -5, 5) === "1\n2\n3") + assert(Utils.offsetBytes(f1Path, f1Length, -5, 5) === "1\n2\n3") // Read some nonexistent bytes at the end - assert(Utils.offsetBytes(f1Path, 12, 22) === "7\n8\n9\n") + assert(Utils.offsetBytes(f1Path, f1Length, 12, 22) === "7\n8\n9\n") // Read some nonexistent bytes on both ends - assert(Utils.offsetBytes(f1Path, -3, 25) === "1\n2\n3\n4\n5\n6\n7\n8\n9\n") + assert(Utils.offsetBytes(f1Path, f1Length, -3, 25) === "1\n2\n3\n4\n5\n6\n7\n8\n9\n") Utils.deleteRecursively(tmpDir2) } - test("reading offset bytes across multiple files") { + test("reading offset bytes of a file") { + testOffsetBytes(isCompressed = false) + } + + test("reading offset bytes of a file (compressed)") { + testOffsetBytes(isCompressed = true) + } + + def testOffsetBytesMultipleFiles(isCompressed: Boolean): Unit = { val tmpDir = Utils.createTempDir() - val files = (1 to 3).map(i => new File(tmpDir, i.toString)) - Files.write("0123456789", files(0), StandardCharsets.UTF_8) - Files.write("abcdefghij", files(1), StandardCharsets.UTF_8) - Files.write("ABCDEFGHIJ", files(2), StandardCharsets.UTF_8) + val suffix = getSuffix(isCompressed) + val files = (1 to 3).map(i => new File(tmpDir, i.toString + suffix)) :+ new File(tmpDir, "4") + writeLogFile(files(0).getAbsolutePath, "0123456789".getBytes(StandardCharsets.UTF_8)) + writeLogFile(files(1).getAbsolutePath, "abcdefghij".getBytes(StandardCharsets.UTF_8)) + writeLogFile(files(2).getAbsolutePath, "ABCDEFGHIJ".getBytes(StandardCharsets.UTF_8)) + writeLogFile(files(3).getAbsolutePath, "9876543210".getBytes(StandardCharsets.UTF_8)) + val fileLengths = files.map(Utils.getFileLength(_, workerConf)) // Read first few bytes in the 1st file - assert(Utils.offsetBytes(files, 0, 5) === "01234") + assert(Utils.offsetBytes(files, fileLengths, 0, 5) === "01234") // Read bytes within the 1st file - assert(Utils.offsetBytes(files, 5, 8) === "567") + assert(Utils.offsetBytes(files, fileLengths, 5, 8) === "567") // Read bytes across 1st and 2nd file - assert(Utils.offsetBytes(files, 8, 18) === "89abcdefgh") + assert(Utils.offsetBytes(files, fileLengths, 8, 18) === "89abcdefgh") // Read bytes across 1st, 2nd and 3rd file - assert(Utils.offsetBytes(files, 5, 24) === "56789abcdefghijABCD") + assert(Utils.offsetBytes(files, fileLengths, 5, 24) === "56789abcdefghijABCD") + + // Read bytes across 3rd and 4th file + assert(Utils.offsetBytes(files, fileLengths, 25, 35) === "FGHIJ98765") // Read some nonexistent bytes in the beginning - assert(Utils.offsetBytes(files, -5, 18) === "0123456789abcdefgh") + assert(Utils.offsetBytes(files, fileLengths, -5, 18) === "0123456789abcdefgh") // Read some nonexistent bytes at the end - assert(Utils.offsetBytes(files, 18, 35) === "ijABCDEFGHIJ") + assert(Utils.offsetBytes(files, fileLengths, 18, 45) === "ijABCDEFGHIJ9876543210") // Read some nonexistent bytes on both ends - assert(Utils.offsetBytes(files, -5, 35) === "0123456789abcdefghijABCDEFGHIJ") + assert(Utils.offsetBytes(files, fileLengths, -5, 45) === + 
"0123456789abcdefghijABCDEFGHIJ9876543210") Utils.deleteRecursively(tmpDir) } + test("reading offset bytes across multiple files") { + testOffsetBytesMultipleFiles(isCompressed = false) + } + + test("reading offset bytes across multiple files (compressed)") { + testOffsetBytesMultipleFiles(isCompressed = true) + } + test("deserialize long value") { val testval : Long = 9730889947L val bbuf = ByteBuffer.allocate(8) @@ -350,6 +396,16 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { assert(Utils.getIteratorSize(iterator) === 5L) } + test("getIteratorZipWithIndex") { + val iterator = Utils.getIteratorZipWithIndex(Iterator(0, 1, 2), -1L + Int.MaxValue) + assert(iterator.toArray === Array( + (0, -1L + Int.MaxValue), (1, 0L + Int.MaxValue), (2, 1L + Int.MaxValue) + )) + intercept[IllegalArgumentException] { + Utils.getIteratorZipWithIndex(Iterator(0, 1, 2), -1L) + } + } + test("doesDirectoryContainFilesNewerThan") { // create some temporary directories and files val parent: File = Utils.createTempDir() @@ -790,14 +846,11 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { test("Set Spark CallerContext") { val context = "test" - try { + new CallerContext(context).setCurrentContext() + if (CallerContext.callerContextSupported) { val callerContext = Utils.classForName("org.apache.hadoop.ipc.CallerContext") - assert(new CallerContext(context).setCurrentContext()) assert(s"SPARK_$context" === callerContext.getMethod("getCurrent").invoke(null).toString) - } catch { - case e: ClassNotFoundException => - assert(!new CallerContext(context).setCurrentContext()) } } diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index 2a568cc8010db..dfe4eb9f8bc65 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -40,7 +40,7 @@ commons-digester-1.8.jar commons-httpclient-3.1.jar commons-io-2.4.jar commons-lang-2.6.jar -commons-lang3-3.3.2.jar +commons-lang3-3.5.jar commons-logging-1.1.3.jar commons-math3-3.4.1.jar commons-net-2.2.jar @@ -157,7 +157,7 @@ parquet-jackson-1.9.0-palantir2.jar pmml-model-1.2.15.jar pmml-schema-1.2.15.jar protobuf-java-2.5.0.jar -py4j-0.10.3.jar +py4j-0.10.4.jar pyrolite-4.13.jar scala-compiler-2.11.8.jar scala-library-2.11.8.jar diff --git a/docs/building-spark.md b/docs/building-spark.md index f5acee6b90059..ebe46a42a15c6 100644 --- a/docs/building-spark.md +++ b/docs/building-spark.md @@ -217,9 +217,8 @@ For help in setting up IntelliJ IDEA or Eclipse for Spark development, and troub Tests are run by default via the [ScalaTest Maven plugin](http://www.scalatest.org/user_guide/using_the_scalatest_maven_plugin). Note that tests should not be run as root or an admin user. -Some of the tests require Spark to be packaged first, so always run `mvn package` with `-DskipTests` the first time. The following is an example of a correct (build, test) sequence: +The following is an example of a command to run the tests: - ./build/mvn -Pyarn -Phadoop-2.3 -DskipTests -Phive -Phive-thriftserver clean package ./build/mvn -Pyarn -Phadoop-2.3 -Phive -Phive-thriftserver test The ScalaTest plugin also supports running only a specific Scala test suite as follows: @@ -233,9 +232,8 @@ or a Java test: ## Testing with SBT -Some of the tests require Spark to be packaged first, so always run `build/sbt package` the first time. 
The following is an example of a correct (build, test) sequence: +The following is an example of a command to run the tests: - ./build/sbt -Pyarn -Phadoop-2.3 -Phive -Phive-thriftserver package ./build/sbt -Pyarn -Phadoop-2.3 -Phive -Phive-thriftserver test To run only a specific test suite as follows: diff --git a/docs/configuration.md b/docs/configuration.md index 373e22d71a872..780fc94908d38 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -293,6 +293,14 @@ Apart from these, the following properties are also available, and may be useful Older log files will be deleted. Disabled by default. + + spark.executor.logs.rolling.enableCompression + false + + Enable executor log compression. If it is enabled, the rolled executor logs will be compressed. + Disabled by default. + + spark.executor.logs.rolling.maxSize (none) @@ -624,7 +632,7 @@ Apart from these, the following properties are also available, and may be useful spark.ui.killEnabled true - Allows stages and corresponding jobs to be killed from the web ui. + Allows jobs and stages to be killed from the web UI. @@ -791,6 +799,14 @@ Apart from these, the following properties are also available, and may be useful See the tuning guide for more details. + + spark.kryo.unsafe + false + + Whether to use unsafe based Kryo serializer. Can be + substantially faster by using Unsafe Based IO. + + spark.kryoserializer.buffer.max 64m @@ -1874,6 +1890,21 @@ showDF(properties, numRows = 200, truncate = FALSE) spark.r.shell.command is used for sparkR shell while spark.r.driver.command is used for running R script. + + spark.r.backendConnectionTimeout + 6000 + + Connection timeout set by R process on its connection to RBackend in seconds. + + + + spark.r.heartBeatInterval + 100 + + Interval for heartbeats sents from SparkR backend to R process to prevent connection timeout. + + + #### Deploy diff --git a/docs/ml-features.md b/docs/ml-features.md index a7f710fa52e64..64c6a160239cc 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -1103,11 +1103,16 @@ for more details on the API. `QuantileDiscretizer` takes a column with continuous features and outputs a column with binned categorical features. The number of bins is set by the `numBuckets` parameter. It is possible -that the number of buckets used will be less than this value, for example, if there are too few -distinct values of the input to create enough distinct quantiles. Note also that NaN values are -handled specially and placed into their own bucket. For example, if 4 buckets are used, then -non-NaN data will be put into buckets[0-3], but NaNs will be counted in a special bucket[4]. -The bin ranges are chosen using an approximate algorithm (see the documentation for +that the number of buckets used will be smaller than this value, for example, if there are too few +distinct values of the input to create enough distinct quantiles. + +NaN values: Note also that QuantileDiscretizer +will raise an error when it finds NaN values in the dataset, but the user can also choose to either +keep or remove NaN values within the dataset by setting `handleInvalid`. If the user chooses to keep +NaN values, they will be handled specially and placed into their own bucket, for example, if 4 buckets +are used, then non-NaN data will be put into buckets[0-3], but NaNs will be counted in a special bucket[4]. 
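As a hedged illustration of the `handleInvalid` behaviour described above, a minimal Scala sketch (assuming an existing SparkSession named `spark`; the data and column names are made up, and the `setHandleInvalid` setter is assumed to be available in your Spark version):

```scala
import org.apache.spark.ml.feature.QuantileDiscretizer

// Made-up data; one row carries a NaN feature value.
val data = Seq((0, 18.0), (1, 19.0), (2, 8.0), (3, Double.NaN), (4, 2.2))
val df = spark.createDataFrame(data).toDF("id", "hour")

val discretizer = new QuantileDiscretizer()
  .setInputCol("hour")
  .setOutputCol("bucket")
  .setNumBuckets(3)
  .setHandleInvalid("keep") // keep NaN rows in their own extra bucket instead of failing

discretizer.fit(df).transform(df).show()
```

With `handleInvalid` set to `keep`, the NaN row lands in the extra bucket described above; dropping that setter restores the default behaviour of raising an error on NaN input.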
+ +Algorithm: The bin ranges are chosen using an approximate algorithm (see the documentation for [approxQuantile](api/scala/index.html#org.apache.spark.sql.DataFrameStatFunctions) for a detailed description). The precision of the approximation can be controlled with the `relativeError` parameter. When set to zero, exact quantiles are calculated diff --git a/docs/programming-guide.md b/docs/programming-guide.md index 20b4bee0f58e1..7516579ec6dbf 100644 --- a/docs/programming-guide.md +++ b/docs/programming-guide.md @@ -182,7 +182,7 @@ variable called `sc`. Making your own SparkContext will not work. You can set wh context connects to using the `--master` argument, and you can add JARs to the classpath by passing a comma-separated list to the `--jars` argument. You can also add dependencies (e.g. Spark Packages) to your shell session by supplying a comma-separated list of maven coordinates -to the `--packages` argument. Any additional repositories where dependencies might exist (e.g. SonaType) +to the `--packages` argument. Any additional repositories where dependencies might exist (e.g. Sonatype) can be passed to the `--repositories` argument. For example, to run `bin/spark-shell` on exactly four cores, use: @@ -214,9 +214,9 @@ variable called `sc`. Making your own SparkContext will not work. You can set wh context connects to using the `--master` argument, and you can add Python .zip, .egg or .py files to the runtime path by passing a comma-separated list to `--py-files`. You can also add dependencies (e.g. Spark Packages) to your shell session by supplying a comma-separated list of maven coordinates -to the `--packages` argument. Any additional repositories where dependencies might exist (e.g. SonaType) -can be passed to the `--repositories` argument. Any python dependencies a Spark Package has (listed in -the requirements.txt of that package) must be manually installed using pip when necessary. +to the `--packages` argument. Any additional repositories where dependencies might exist (e.g. Sonatype) +can be passed to the `--repositories` argument. Any Python dependencies a Spark package has (listed in +the requirements.txt of that package) must be manually installed using `pip` when necessary. For example, to run `bin/pyspark` on exactly four cores, use: {% highlight bash %} diff --git a/docs/spark-standalone.md b/docs/spark-standalone.md index 7b82b957d5299..1c0b60f7b9346 100644 --- a/docs/spark-standalone.md +++ b/docs/spark-standalone.md @@ -250,6 +250,15 @@ SPARK_WORKER_OPTS supports the following system properties: especially if you run jobs very frequently. + + spark.worker.ui.compressedLogFileLengthCacheSize + 100 + + For compressed log files, the uncompressed file can only be computed by uncompressing the files. + Spark caches the uncompressed file size of compressed log files. This property controls the cache + size. + + # Connecting an Application to the Cluster diff --git a/docs/sparkr.md b/docs/sparkr.md index 340e7f7cb1a0b..f30bd4026fed3 100644 --- a/docs/sparkr.md +++ b/docs/sparkr.md @@ -135,7 +135,7 @@ sparkR.session(sparkPackages = "com.databricks:spark-avro_2.11:3.0.0") {% endhighlight %} -We can see how to use data sources using an example JSON input file. Note that the file that is used here is _not_ a typical JSON file. Each line in the file must contain a separate, self-contained valid JSON object. As a consequence, a regular multi-line JSON file will most often fail. +We can see how to use data sources using an example JSON input file. 
Note that the file that is used here is _not_ a typical JSON file. Each line in the file must contain a separate, self-contained valid JSON object. For more information, please see [JSON Lines text format, also called newline-delimited JSON](http://jsonlines.org/). As a consequence, a regular multi-line JSON file will most often fail.
{% highlight r %} @@ -591,3 +591,7 @@ You can inspect the search path in R with [`search()`](https://stat.ethz.ch/R-ma - The method `registerTempTable` has been deprecated to be replaced by `createOrReplaceTempView`. - The method `dropTempTable` has been deprecated to be replaced by `dropTempView`. - The `sc` SparkContext parameter is no longer required for these functions: `setJobGroup`, `clearJobGroup`, `cancelJobGroup` + +## Upgrading to SparkR 2.1.0 + + - `join` no longer performs Cartesian Product by default, use `crossJoin` instead. diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index d0f43ab0a9cc9..b9be7a7545ef8 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -140,7 +140,7 @@ As an example, the following creates a DataFrame based on the content of a JSON ## Untyped Dataset Operations (aka DataFrame Operations) -DataFrames provide a domain-specific language for structured data manipulation in [Scala](api/scala/index.html#org.apache.spark.sql.Dataset), [Java](api/java/index.html?org/apache/spark/sql/Dataset.html), [Python](api/python/pyspark.sql.html#pyspark.sql.DataFrame) and [R](api/R/DataFrame.html). +DataFrames provide a domain-specific language for structured data manipulation in [Scala](api/scala/index.html#org.apache.spark.sql.Dataset), [Java](api/java/index.html?org/apache/spark/sql/Dataset.html), [Python](api/python/pyspark.sql.html#pyspark.sql.DataFrame) and [R](api/R/SparkDataFrame.html). As mentioned above, in Spark 2.0, DataFrames are just Dataset of `Row`s in Scala and Java API. These operations are also referred as "untyped transformations" in contrast to "typed transformations" come with strongly typed Scala/Java Datasets. @@ -316,7 +316,7 @@ Serializable and has getters and setters for all of its fields. Spark SQL can convert an RDD of Row objects to a DataFrame, inferring the datatypes. Rows are constructed by passing a list of key/value pairs as kwargs to the Row class. The keys of this list define the column names of the table, -and the types are inferred by sampling the whole datase, similar to the inference that is performed on JSON files. +and the types are inferred by sampling the whole dataset, similar to the inference that is performed on JSON files. {% include_example schema_inferring python/sql/basic.py %}
@@ -422,8 +422,8 @@ In the simplest form, the default data source (`parquet` unless otherwise config You can also manually specify the data source that will be used along with any extra options that you would like to pass to the data source. Data sources are specified by their fully qualified name (i.e., `org.apache.spark.sql.parquet`), but for built-in sources you can also use their short -names (`json`, `parquet`, `jdbc`). DataFrames loaded from any data source type can be converted into other types -using this syntax. +names (`json`, `parquet`, `jdbc`, `orc`, `libsvm`, `csv`, `text`). DataFrames loaded from any data +source type can be converted into other types using this syntax.
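For illustration, a minimal Scala sketch of the short-name syntax described above (assuming a SparkSession `spark`; the paths are placeholders):

```scala
// Load using a built-in short name instead of a fully qualified class name.
val csvDF = spark.read
  .format("csv")
  .option("header", "true")
  .load("path/to/people.csv")

// A DataFrame loaded from one source can be written back out with another built-in source.
csvDF.write.format("parquet").save("path/to/people.parquet")
```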
@@ -832,8 +832,9 @@ This conversion can be done using `SparkSession.read.json()` on either an RDD of or a JSON file. Note that the file that is offered as _a json file_ is not a typical JSON file. Each -line must contain a separate, self-contained valid JSON object. As a consequence, -a regular multi-line JSON file will most often fail. +line must contain a separate, self-contained valid JSON object. For more information, please see +[JSON Lines text format, also called newline-delimited JSON](http://jsonlines.org/). As a +consequence, a regular multi-line JSON file will most often fail. {% include_example json_dataset scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala %}
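For reference, reading such a JSON Lines file from Scala is a one-liner (a sketch assuming a SparkSession `spark`; the path and records are illustrative):

```scala
// people.jsonl is assumed to contain one complete JSON object per line, e.g.
//   {"name":"Michael"}
//   {"name":"Andy", "age":30}
val peopleDF = spark.read.json("path/to/people.jsonl")
peopleDF.printSchema()
peopleDF.show()
```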
@@ -844,8 +845,9 @@ This conversion can be done using `SparkSession.read().json()` on either an RDD or a JSON file. Note that the file that is offered as _a json file_ is not a typical JSON file. Each -line must contain a separate, self-contained valid JSON object. As a consequence, -a regular multi-line JSON file will most often fail. +line must contain a separate, self-contained valid JSON object. For more information, please see +[JSON Lines text format, also called newline-delimited JSON](http://jsonlines.org/). As a +consequence, a regular multi-line JSON file will most often fail. {% include_example json_dataset java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java %}
@@ -855,8 +857,9 @@ Spark SQL can automatically infer the schema of a JSON dataset and load it as a This conversion can be done using `SparkSession.read.json` on a JSON file. Note that the file that is offered as _a json file_ is not a typical JSON file. Each -line must contain a separate, self-contained valid JSON object. As a consequence, -a regular multi-line JSON file will most often fail. +line must contain a separate, self-contained valid JSON object. For more information, please see +[JSON Lines text format, also called newline-delimited JSON](http://jsonlines.org/). As a +consequence, a regular multi-line JSON file will most often fail. {% include_example json_dataset python/sql/datasource.py %} @@ -867,8 +870,9 @@ the `read.json()` function, which loads data from a directory of JSON files wher files is a JSON object. Note that the file that is offered as _a json file_ is not a typical JSON file. Each -line must contain a separate, self-contained valid JSON object. As a consequence, -a regular multi-line JSON file will most often fail. +line must contain a separate, self-contained valid JSON object. For more information, please see +[JSON Lines text format, also called newline-delimited JSON](http://jsonlines.org/). As a +consequence, a regular multi-line JSON file will most often fail. {% include_example json_dataset r/RSparkSQLExample.R %} @@ -904,50 +908,27 @@ access data stored in Hive. Configuration of Hive is done by placing your `hive-site.xml`, `core-site.xml` (for security configuration), and `hdfs-site.xml` (for HDFS configuration) file in `conf/`. -
- -
- When working with Hive, one must instantiate `SparkSession` with Hive support, including connectivity to a persistent Hive metastore, support for Hive serdes, and Hive user-defined functions. Users who do not have an existing Hive deployment can still enable Hive support. When not configured by the `hive-site.xml`, the context automatically creates `metastore_db` in the current directory and creates a directory configured by `spark.sql.warehouse.dir`, which defaults to the directory -`spark-warehouse` in the current directory that the spark application is started. Note that +`spark-warehouse` in the current directory that the Spark application is started. Note that the `hive.metastore.warehouse.dir` property in `hive-site.xml` is deprecated since Spark 2.0.0. Instead, use `spark.sql.warehouse.dir` to specify the default location of database in warehouse. -You may need to grant write privilege to the user who starts the spark application. +You may need to grant write privilege to the user who starts the Spark application. +
+ +
{% include_example spark_hive scala/org/apache/spark/examples/sql/hive/SparkHiveExample.scala %}
- -When working with Hive, one must instantiate `SparkSession` with Hive support, including -connectivity to a persistent Hive metastore, support for Hive serdes, and Hive user-defined functions. -Users who do not have an existing Hive deployment can still enable Hive support. When not configured -by the `hive-site.xml`, the context automatically creates `metastore_db` in the current directory and -creates a directory configured by `spark.sql.warehouse.dir`, which defaults to the directory -`spark-warehouse` in the current directory that the spark application is started. Note that -the `hive.metastore.warehouse.dir` property in `hive-site.xml` is deprecated since Spark 2.0.0. -Instead, use `spark.sql.warehouse.dir` to specify the default location of database in warehouse. -You may need to grant write privilege to the user who starts the spark application. - {% include_example spark_hive java/org/apache/spark/examples/sql/hive/JavaSparkHiveExample.java %}
- -When working with Hive, one must instantiate `SparkSession` with Hive support, including -connectivity to a persistent Hive metastore, support for Hive serdes, and Hive user-defined functions. -Users who do not have an existing Hive deployment can still enable Hive support. When not configured -by the `hive-site.xml`, the context automatically creates `metastore_db` in the current directory and -creates a directory configured by `spark.sql.warehouse.dir`, which defaults to the directory -`spark-warehouse` in the current directory that the spark application is started. Note that -the `hive.metastore.warehouse.dir` property in `hive-site.xml` is deprecated since Spark 2.0.0. -Instead, use `spark.sql.warehouse.dir` to specify the default location of database in warehouse. -You may need to grant write privilege to the user who starts the spark application. - {% include_example spark_hive python/sql/hive.py %}
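A condensed Scala sketch of the Hive setup the linked examples share (the warehouse path and table name are illustrative, not prescriptive):

```scala
import org.apache.spark.sql.SparkSession

// warehouseLocation points at the default location for managed databases and tables.
val warehouseLocation = "spark-warehouse"

val spark = SparkSession.builder()
  .appName("Spark Hive sketch")
  .config("spark.sql.warehouse.dir", warehouseLocation) // replaces hive.metastore.warehouse.dir
  .enableHiveSupport()
  .getOrCreate()

spark.sql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)")
spark.sql("SELECT COUNT(*) FROM src").show()
```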
@@ -998,7 +979,7 @@ The following options can be used to configure the version of Hive that is used
  • A classpath in the standard format for the JVM. This classpath must include all of Hive and its dependencies, including the correct version of Hadoop. These jars only need to be present on the driver, but if you are running in yarn cluster mode then you must ensure - they are packaged with you application.
  • + they are packaged with your application. diff --git a/docs/streaming-flume-integration.md b/docs/streaming-flume-integration.md index 767e1f9402e01..a5d36da5b6de9 100644 --- a/docs/streaming-flume-integration.md +++ b/docs/streaming-flume-integration.md @@ -115,11 +115,11 @@ Configuring Flume on the chosen machine requires the following two steps. artifactId = scala-library version = {{site.SCALA_VERSION}} - (iii) *Commons Lang 3 JAR*: Download the Commons Lang 3 JAR. It can be found with the following artifact detail (or, [direct link](http://search.maven.org/remotecontent?filepath=org/apache/commons/commons-lang3/3.3.2/commons-lang3-3.3.2.jar)). + (iii) *Commons Lang 3 JAR*: Download the Commons Lang 3 JAR. It can be found with the following artifact detail (or, [direct link](http://search.maven.org/remotecontent?filepath=org/apache/commons/commons-lang3/3.5/commons-lang3-3.5.jar)). groupId = org.apache.commons artifactId = commons-lang3 - version = 3.3.2 + version = 3.5 2. **Configuration file**: On that machine, configure Flume agent to send data to an Avro sink by having the following in the configuration file. diff --git a/docs/streaming-kafka-0-10-integration.md b/docs/streaming-kafka-0-10-integration.md index 456b8453383db..c1ef396907db7 100644 --- a/docs/streaming-kafka-0-10-integration.md +++ b/docs/streaming-kafka-0-10-integration.md @@ -8,9 +8,9 @@ The Spark Streaming integration for Kafka 0.10 is similar in design to the 0.8 [ ### Linking For Scala/Java applications using SBT/Maven project definitions, link your streaming application with the following artifact (see [Linking section](streaming-programming-guide.html#linking) in the main programming guide for further information). - groupId = org.apache.spark - artifactId = spark-streaming-kafka-0-10_{{site.SCALA_BINARY_VERSION}} - version = {{site.SPARK_VERSION_SHORT}} + groupId = org.apache.spark + artifactId = spark-streaming-kafka-0-10_{{site.SCALA_BINARY_VERSION}} + version = {{site.SPARK_VERSION_SHORT}} ### Creating a Direct Stream Note that the namespace for the import includes the version, org.apache.spark.streaming.kafka010 @@ -44,10 +44,47 @@ For Scala/Java applications using SBT/Maven project definitions, link your strea Each item in the stream is a [ConsumerRecord](http://kafka.apache.org/0100/javadoc/org/apache/kafka/clients/consumer/ConsumerRecord.html)
+ import java.util.*;
+ import org.apache.spark.SparkConf;
+ import org.apache.spark.TaskContext;
+ import org.apache.spark.api.java.*;
+ import org.apache.spark.api.java.function.*;
+ import org.apache.spark.streaming.api.java.*;
+ import org.apache.spark.streaming.kafka010.*;
+ import org.apache.kafka.clients.consumer.ConsumerRecord;
+ import org.apache.kafka.common.TopicPartition;
+ import org.apache.kafka.common.serialization.StringDeserializer;
+ import scala.Tuple2;
+
+ Map<String, Object> kafkaParams = new HashMap<>();
+ kafkaParams.put("bootstrap.servers", "localhost:9092,anotherhost:9092");
+ kafkaParams.put("key.deserializer", StringDeserializer.class);
+ kafkaParams.put("value.deserializer", StringDeserializer.class);
+ kafkaParams.put("group.id", "use_a_separate_group_id_for_each_stream");
+ kafkaParams.put("auto.offset.reset", "latest");
+ kafkaParams.put("enable.auto.commit", false);
+
+ Collection<String> topics = Arrays.asList("topicA", "topicB");
+
+ final JavaInputDStream<ConsumerRecord<String, String>> stream =
+   KafkaUtils.createDirectStream(
+     streamingContext,
+     LocationStrategies.PreferConsistent(),
+     ConsumerStrategies.<String, String>Subscribe(topics, kafkaParams)
+   );
+
+ stream.mapToPair(
+   new PairFunction<ConsumerRecord<String, String>, String, String>() {
+     @Override
+     public Tuple2<String, String> call(ConsumerRecord<String, String> record) {
+       return new Tuple2<>(record.key(), record.value());
+     }
+   })
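For comparison, the Scala form of the same subscription, which the guide already shows in its Scala tab, looks roughly like this (a sketch assuming an existing StreamingContext `streamingContext`):

```scala
import org.apache.kafka.common.serialization.StringDeserializer
import org.apache.spark.streaming.kafka010._
import org.apache.spark.streaming.kafka010.LocationStrategies.PreferConsistent
import org.apache.spark.streaming.kafka010.ConsumerStrategies.Subscribe

// Same consumer configuration as the Java example above.
val kafkaParams = Map[String, Object](
  "bootstrap.servers" -> "localhost:9092,anotherhost:9092",
  "key.deserializer" -> classOf[StringDeserializer],
  "value.deserializer" -> classOf[StringDeserializer],
  "group.id" -> "use_a_separate_group_id_for_each_stream",
  "auto.offset.reset" -> "latest",
  "enable.auto.commit" -> (false: java.lang.Boolean)
)

val topics = Array("topicA", "topicB")
val stream = KafkaUtils.createDirectStream[String, String](
  streamingContext,
  PreferConsistent,
  Subscribe[String, String](topics, kafkaParams)
)

// Each record is a ConsumerRecord; pull out the key/value pair.
stream.map(record => (record.key, record.value))
```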
    For possible kafkaParams, see [Kafka consumer config docs](http://kafka.apache.org/documentation.html#newconsumerconfigs). +If your Spark batch duration is larger than the default Kafka heartbeat session timeout (30 seconds), increase heartbeat.interval.ms and session.timeout.ms appropriately. For batches larger than 5 minutes, this will require changing group.max.session.timeout.ms on the broker. Note that the example sets enable.auto.commit to false, for discussion see [Storing Offsets](streaming-kafka-0-10-integration.html#storing-offsets) below. ### LocationStrategies @@ -84,6 +121,20 @@ If you have a use case that is better suited to batch processing, you can create
+ // Import dependencies and create kafka params as in Create Direct Stream above
+
+ OffsetRange[] offsetRanges = {
+   // topic, partition, inclusive starting offset, exclusive ending offset
+   OffsetRange.create("test", 0, 0, 100),
+   OffsetRange.create("test", 1, 0, 100)
+ };
+
+ JavaRDD<ConsumerRecord<String, String>> rdd = KafkaUtils.createRDD(
+   sparkContext,
+   kafkaParams,
+   offsetRanges,
+   LocationStrategies.PreferConsistent()
+ );
    @@ -102,6 +153,20 @@ Note that you cannot use `PreferBrokers`, because without the stream there is no }
+ stream.foreachRDD(new VoidFunction<JavaRDD<ConsumerRecord<String, String>>>() {
+   @Override
+   public void call(JavaRDD<ConsumerRecord<String, String>> rdd) {
+     final OffsetRange[] offsetRanges = ((HasOffsetRanges) rdd.rdd()).offsetRanges();
+     rdd.foreachPartition(new VoidFunction<Iterator<ConsumerRecord<String, String>>>() {
+       @Override
+       public void call(Iterator<ConsumerRecord<String, String>> consumerRecords) {
+         OffsetRange o = offsetRanges[TaskContext.get().partitionId()];
+         System.out.println(
+           o.topic() + " " + o.partition() + " " + o.fromOffset() + " " + o.untilOffset());
+       }
+     });
+   }
+ });
    @@ -119,15 +184,24 @@ Kafka has an offset commit API that stores offsets in a special Kafka topic. By
    stream.foreachRDD { rdd => - val offsets = rdd.asInstanceOf[HasOffsetRanges].offsetRanges + val offsetRanges = rdd.asInstanceOf[HasOffsetRanges].offsetRanges // some time later, after outputs have completed - stream.asInstanceOf[CanCommitOffsets].commitAsync(offsets) + stream.asInstanceOf[CanCommitOffsets].commitAsync(offsetRanges) } As with HasOffsetRanges, the cast to CanCommitOffsets will only succeed if called on the result of createDirectStream, not after transformations. The commitAsync call is threadsafe, but must occur after outputs if you want meaningful semantics.
+ stream.foreachRDD(new VoidFunction<JavaRDD<ConsumerRecord<String, String>>>() {
+   @Override
+   public void call(JavaRDD<ConsumerRecord<String, String>> rdd) {
+     OffsetRange[] offsetRanges = ((HasOffsetRanges) rdd.rdd()).offsetRanges();
+
+     // some time later, after outputs have completed
+     ((CanCommitOffsets) stream.inputDStream()).commitAsync(offsetRanges);
+   }
+ });
    @@ -140,7 +214,7 @@ For data stores that support transactions, saving offsets in the same transactio // begin from the the offsets committed to the database val fromOffsets = selectOffsetsFromYourDatabase.map { resultSet => - new TopicPartition(resultSet.string("topic")), resultSet.int("partition")) -> resultSet.long("offset") + new TopicPartition(resultSet.string("topic"), resultSet.int("partition")) -> resultSet.long("offset") }.toMap val stream = KafkaUtils.createDirectStream[String, String]( @@ -154,16 +228,46 @@ For data stores that support transactions, saving offsets in the same transactio val results = yourCalculation(rdd) - yourTransactionBlock { - // update results + // begin your transaction - // update offsets where the end of existing offsets matches the beginning of this batch of offsets + // update results + // update offsets where the end of existing offsets matches the beginning of this batch of offsets + // assert that offsets were updated correctly - // assert that offsets were updated correctly - } + // end your transaction }
+ // The details depend on your data store, but the general idea looks like this
+
+ // begin from the offsets committed to the database
+ Map<TopicPartition, Long> fromOffsets = new HashMap<>();
+ for (resultSet : selectOffsetsFromYourDatabase) {
+   fromOffsets.put(new TopicPartition(resultSet.string("topic"), resultSet.int("partition")), resultSet.long("offset"));
+ }
+
+ JavaInputDStream<ConsumerRecord<String, String>> stream = KafkaUtils.createDirectStream(
+   streamingContext,
+   LocationStrategies.PreferConsistent(),
+   ConsumerStrategies.<String, String>Assign(fromOffsets.keySet(), kafkaParams, fromOffsets)
+ );
+
+ stream.foreachRDD(new VoidFunction<JavaRDD<ConsumerRecord<String, String>>>() {
+   @Override
+   public void call(JavaRDD<ConsumerRecord<String, String>> rdd) {
+     OffsetRange[] offsetRanges = ((HasOffsetRanges) rdd.rdd()).offsetRanges();
+
+     Object results = yourCalculation(rdd);
+
+     // begin your transaction
+
+     // update results
+     // update offsets where the end of existing offsets matches the beginning of this batch of offsets
+     // assert that offsets were updated correctly
+
+     // end your transaction
+   }
+ });
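To make the transactional pattern more concrete, here is a purely illustrative Scala sketch using plain JDBC. The connection URL, the `kafka_offsets` table, and `yourCalculation` are hypothetical placeholders, and `stream` stands for the Scala direct stream created above:

```scala
import java.sql.DriverManager

import org.apache.spark.streaming.kafka010.HasOffsetRanges

stream.foreachRDD { rdd =>
  val offsetRanges = rdd.asInstanceOf[HasOffsetRanges].offsetRanges
  val results = yourCalculation(rdd) // stand-in for your batch computation, as in the examples above

  val conn = DriverManager.getConnection("jdbc:postgresql://host/db", "user", "password")
  conn.setAutoCommit(false) // begin your transaction
  try {
    // update results within the same transaction ...

    // update offsets where the end of existing offsets matches the beginning of this batch
    val stmt = conn.prepareStatement(
      "UPDATE kafka_offsets SET untilOffset = ? WHERE topic = ? AND part = ? AND untilOffset = ?")
    offsetRanges.foreach { o =>
      stmt.setLong(1, o.untilOffset)
      stmt.setString(2, o.topic)
      stmt.setInt(3, o.partition)
      stmt.setLong(4, o.fromOffset)
      // assert that offsets were updated correctly
      require(stmt.executeUpdate() == 1, s"offset row for $o was not in the expected state")
    }
    conn.commit() // end your transaction
  } finally {
    conn.close()
  }
}
```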
    @@ -184,6 +288,14 @@ The new Kafka consumer [supports SSL](http://kafka.apache.org/documentation.html )
+ Map<String, Object> kafkaParams = new HashMap<String, Object>();
+ // the usual params, make sure to change the port in bootstrap.servers if 9092 is not TLS
+ kafkaParams.put("security.protocol", "SSL");
+ kafkaParams.put("ssl.truststore.location", "/some-directory/kafka.client.truststore.jks");
+ kafkaParams.put("ssl.truststore.password", "test1234");
+ kafkaParams.put("ssl.keystore.location", "/some-directory/kafka.client.keystore.jks");
+ kafkaParams.put("ssl.keystore.password", "test1234");
+ kafkaParams.put("ssl.key.password", "test1234");
    diff --git a/docs/structured-streaming-kafka-integration.md b/docs/structured-streaming-kafka-integration.md index 668489addf82c..a6c3b3a9024d8 100644 --- a/docs/structured-streaming-kafka-integration.md +++ b/docs/structured-streaming-kafka-integration.md @@ -150,16 +150,25 @@ The following options must be set for the Kafka source. + + + + + - + - @@ -174,16 +183,21 @@ The following configurations are optional:
    Optionvaluemeaning
    assignjson string {"topicA":[0,1],"topicB":[2,4]}Specific TopicPartitions to consume. + Only one of "assign", "subscribe" or "subscribePattern" + options can be specified for Kafka source.
    subscribe A comma-separated list of topicsThe topic list to subscribe. Only one of "subscribe" and "subscribePattern" options can be - specified for Kafka source.The topic list to subscribe. + Only one of "assign", "subscribe" or "subscribePattern" + options can be specified for Kafka source.
subscribePattern Java regex stringThe pattern used to subscribe the topic. Only one of "subscribe" and "subscribePattern" + The pattern used to subscribe to topic(s). + Only one of "assign", "subscribe" or "subscribePattern" options can be specified for Kafka source.
    - - - - + + + + - + + + + + + +
    Optionvaluedefaultmeaning
    startingOffset["earliest", "latest"]"latest"The start point when a query is started, either "earliest" which is from the earliest offset, - or "latest" which is just from the latest offset. Note: This only applies when a new Streaming q - uery is started, and that resuming will always pick up from where the query left off.startingOffsetsearliest, latest, or json string + {"topicA":{"0":23,"1":-1},"topicB":{"0":-2}} + latestThe start point when a query is started, either "earliest" which is from the earliest offsets, + "latest" which is just from the latest offsets, or a json string specifying a starting offset for + each TopicPartition. In the json, -2 as an offset can be used to refer to earliest, -1 to latest. + Note: This only applies when a new Streaming query is started, and that resuming will always pick + up from where the query left off. Newly discovered partitions during a query will start at + earliest.
    failOnDataLoss[true, false]true or false true Whether to fail the query when it's possible that data is lost (e.g., topics are deleted, or offsets are out of range). This may be a false alarm. You can disable it when it doesn't work @@ -207,6 +221,12 @@ The following configurations are optional: 10 milliseconds to wait before retrying to fetch Kafka offsets
    maxOffsetsPerTriggerlongnoneRate limit on maximum number of offsets processed per trigger interval. The specified total number of offsets will be proportionally split across topicPartitions of different volume.
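A hedged Scala sketch of passing these source options, including the json form of `startingOffsets` and the new `maxOffsetsPerTrigger` (assuming a SparkSession `spark`; the servers, topics, and offsets are illustrative):

```scala
val df = spark.readStream
  .format("kafka")
  .option("kafka.bootstrap.servers", "host1:9092,host2:9092")
  .option("subscribe", "topicA,topicB")
  // Start topicA partition 0 at offset 23, partition 1 at latest; all of topicB at earliest.
  .option("startingOffsets", """{"topicA":{"0":23,"1":-1},"topicB":{"0":-2}}""")
  // Process at most 10000 offsets per trigger, split proportionally across partitions.
  .option("maxOffsetsPerTrigger", 10000L)
  .load()

// Keys and values arrive as binary; cast them for downstream processing.
val messages = df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
```

Here `-1` refers to the latest offset and `-2` to the earliest offset for the given partition, matching the table above.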
    Kafka's own configurations can be set via `DataStreamReader.option` with `kafka.` prefix, e.g, @@ -215,10 +235,10 @@ Kafka's own configurations can be set via `DataStreamReader.option` with `kafka. Note that the following Kafka params cannot be set and the Kafka source will throw an exception: - **group.id**: Kafka source will create a unique group id for each query automatically. -- **auto.offset.reset**: Set the source option `startingOffset` to `earliest` or `latest` to specify +- **auto.offset.reset**: Set the source option `startingOffsets` to specify where to start instead. Structured Streaming manages which offsets are consumed internally, rather than rely on the kafka Consumer to do it. This will ensure that no data is missed when when new - topics/partitions are dynamically subscribed. Note that `startingOffset` only applies when a new + topics/partitions are dynamically subscribed. Note that `startingOffsets` only applies when a new Streaming query is started, and that resuming will always pick up from where the query left off. - **key.deserializer**: Keys are always deserialized as byte arrays with ByteArrayDeserializer. Use DataFrame operations to explicitly deserialize the keys. diff --git a/docs/submitting-applications.md b/docs/submitting-applications.md index 6fe3049995876..b738194eac9aa 100644 --- a/docs/submitting-applications.md +++ b/docs/submitting-applications.md @@ -190,6 +190,8 @@ is handled automatically, and with Spark standalone, automatic cleanup can be co Users may also include any other dependencies by supplying a comma-delimited list of maven coordinates with `--packages`. All transitive dependencies will be handled when using this command. Additional repositories (or resolvers in SBT) can be added in a comma-delimited fashion with the flag `--repositories`. +(Note that credentials for password-protected repositories can be supplied in some cases in the repository URI, +such as in `https://user:password@host/...`. Be careful when supplying credentials this way.) These commands can be used with `pyspark`, `spark-shell`, and `spark-submit` to include Spark Packages. For Python, the equivalent `--py-files` option can be used to distribute `.egg`, `.zip` and `.py` libraries diff --git a/examples/src/main/java/org/apache/spark/examples/sql/hive/JavaSparkHiveExample.java b/examples/src/main/java/org/apache/spark/examples/sql/hive/JavaSparkHiveExample.java index 76dd160d5568b..052153c9e9736 100644 --- a/examples/src/main/java/org/apache/spark/examples/sql/hive/JavaSparkHiveExample.java +++ b/examples/src/main/java/org/apache/spark/examples/sql/hive/JavaSparkHiveExample.java @@ -56,7 +56,7 @@ public void setValue(String value) { public static void main(String[] args) { // $example on:spark_hive$ // warehouseLocation points to the default location for managed databases and tables - String warehouseLocation = "file:" + System.getProperty("user.dir") + "spark-warehouse"; + String warehouseLocation = "spark-warehouse"; SparkSession spark = SparkSession .builder() .appName("Java Spark Hive Example") diff --git a/examples/src/main/python/ml/cross_validator.py b/examples/src/main/python/ml/cross_validator.py index 907eec67a0eb5..db7054307c2e3 100644 --- a/examples/src/main/python/ml/cross_validator.py +++ b/examples/src/main/python/ml/cross_validator.py @@ -84,10 +84,10 @@ # Prepare test documents, which are unlabeled. 
test = spark.createDataFrame([ - (4L, "spark i j k"), - (5L, "l m n"), - (6L, "mapreduce spark"), - (7L, "apache hadoop") + (4, "spark i j k"), + (5, "l m n"), + (6, "mapreduce spark"), + (7, "apache hadoop") ], ["id", "text"]) # Make predictions on test documents. cvModel uses the best model found (lrModel). diff --git a/examples/src/main/python/ml/gaussian_mixture_example.py b/examples/src/main/python/ml/gaussian_mixture_example.py index 8ad450b669fc9..e4a0d314e9d91 100644 --- a/examples/src/main/python/ml/gaussian_mixture_example.py +++ b/examples/src/main/python/ml/gaussian_mixture_example.py @@ -38,7 +38,7 @@ # loads data dataset = spark.read.format("libsvm").load("data/mllib/sample_kmeans_data.txt") - gmm = GaussianMixture().setK(2).setSeed(538009335L) + gmm = GaussianMixture().setK(2).setSeed(538009335) model = gmm.fit(dataset) print("Gaussians shown as a DataFrame: ") diff --git a/examples/src/main/python/ml/pipeline_example.py b/examples/src/main/python/ml/pipeline_example.py index f63e4db434222..e1fab7cbe6d80 100644 --- a/examples/src/main/python/ml/pipeline_example.py +++ b/examples/src/main/python/ml/pipeline_example.py @@ -35,10 +35,10 @@ # $example on$ # Prepare training documents from a list of (id, text, label) tuples. training = spark.createDataFrame([ - (0L, "a b c d e spark", 1.0), - (1L, "b d", 0.0), - (2L, "spark f g h", 1.0), - (3L, "hadoop mapreduce", 0.0) + (0, "a b c d e spark", 1.0), + (1, "b d", 0.0), + (2, "spark f g h", 1.0), + (3, "hadoop mapreduce", 0.0) ], ["id", "text", "label"]) # Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr. @@ -52,10 +52,10 @@ # Prepare test documents, which are unlabeled (id, text) tuples. test = spark.createDataFrame([ - (4L, "spark i j k"), - (5L, "l m n"), - (6L, "spark hadoop spark"), - (7L, "apache hadoop") + (4, "spark i j k"), + (5, "l m n"), + (6, "spark hadoop spark"), + (7, "apache hadoop") ], ["id", "text"]) # Make predictions on test documents and print columns of interest. 
diff --git a/examples/src/main/python/mllib/binary_classification_metrics_example.py b/examples/src/main/python/mllib/binary_classification_metrics_example.py index daf000e38dcd0..91f8378f29c0c 100644 --- a/examples/src/main/python/mllib/binary_classification_metrics_example.py +++ b/examples/src/main/python/mllib/binary_classification_metrics_example.py @@ -39,7 +39,7 @@ .rdd.map(lambda row: LabeledPoint(row[0], row[1])) # Split data into training (60%) and test (40%) - training, test = data.randomSplit([0.6, 0.4], seed=11L) + training, test = data.randomSplit([0.6, 0.4], seed=11) training.cache() # Run training algorithm to build the model diff --git a/examples/src/main/python/mllib/multi_class_metrics_example.py b/examples/src/main/python/mllib/multi_class_metrics_example.py index cd56b3c97c778..7dc5fb4f9127f 100644 --- a/examples/src/main/python/mllib/multi_class_metrics_example.py +++ b/examples/src/main/python/mllib/multi_class_metrics_example.py @@ -32,7 +32,7 @@ data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_multiclass_classification_data.txt") # Split data into training (60%) and test (40%) - training, test = data.randomSplit([0.6, 0.4], seed=11L) + training, test = data.randomSplit([0.6, 0.4], seed=11) training.cache() # Run training algorithm to build the model diff --git a/examples/src/main/python/mllib/tf_idf_example.py b/examples/src/main/python/mllib/tf_idf_example.py index c4d53333a95a9..b66412b2334e7 100644 --- a/examples/src/main/python/mllib/tf_idf_example.py +++ b/examples/src/main/python/mllib/tf_idf_example.py @@ -43,7 +43,7 @@ # In such cases, the IDF for these terms is set to 0. # This feature can be used by passing the minDocFreq value to the IDF constructor. idfIgnore = IDF(minDocFreq=2).fit(tf) - tfidfIgnore = idf.transform(tf) + tfidfIgnore = idfIgnore.transform(tf) # $example off$ print("tfidf:") diff --git a/examples/src/main/python/sql/hive.py b/examples/src/main/python/sql/hive.py index 98b48908b5a12..ad83fe1cf14b5 100644 --- a/examples/src/main/python/sql/hive.py +++ b/examples/src/main/python/sql/hive.py @@ -34,7 +34,7 @@ if __name__ == "__main__": # $example on:spark_hive$ # warehouse_location points to the default location for managed databases and tables - warehouse_location = 'file:${system:user.dir}/spark-warehouse' + warehouse_location = 'spark-warehouse' spark = SparkSession \ .builder \ diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGaussianMixture.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGaussianMixture.scala deleted file mode 100644 index 90b817b23e156..0000000000000 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGaussianMixture.scala +++ /dev/null @@ -1,75 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -// scalastyle:off println -package org.apache.spark.examples.mllib - -import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.mllib.clustering.GaussianMixture -import org.apache.spark.mllib.linalg.Vectors - -/** - * An example Gaussian Mixture Model EM app. Run with - * {{{ - * ./bin/run-example mllib.DenseGaussianMixture - * }}} - * If you use it as a template to create your own app, please use `spark-submit` to submit your app. - */ -object DenseGaussianMixture { - def main(args: Array[String]): Unit = { - if (args.length < 3) { - println("usage: DenseGmmEM [maxIterations]") - } else { - val maxIterations = if (args.length > 3) args(3).toInt else 100 - run(args(0), args(1).toInt, args(2).toDouble, maxIterations) - } - } - - private def run(inputFile: String, k: Int, convergenceTol: Double, maxIterations: Int) { - val conf = new SparkConf().setAppName("Gaussian Mixture Model EM example") - val ctx = new SparkContext(conf) - - val data = ctx.textFile(inputFile).map { line => - Vectors.dense(line.trim.split(' ').map(_.toDouble)) - }.cache() - - val clusters = new GaussianMixture() - .setK(k) - .setConvergenceTol(convergenceTol) - .setMaxIterations(maxIterations) - .run(data) - - for (i <- 0 until clusters.k) { - println("weight=%f\nmu=%s\nsigma=\n%s\n" format - (clusters.weights(i), clusters.gaussians(i).mu, clusters.gaussians(i).sigma)) - } - - println("The membership value of each vector to all mixture components (first <= 100):") - val membership = clusters.predictSoft(data) - membership.take(100).foreach { x => - print(" " + x.mkString(",")) - } - println() - println("Cluster labels (first <= 100):") - val clusterLabels = clusters.predict(data) - clusterLabels.take(100).foreach { x => - print(" " + x) - } - println() - } -} -// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLinearRegression.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLinearRegression.scala deleted file mode 100644 index e5592966f13fa..0000000000000 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLinearRegression.scala +++ /dev/null @@ -1,73 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// scalastyle:off println -package org.apache.spark.examples.mllib - -import org.apache.spark.SparkConf -import org.apache.spark.mllib.linalg.Vectors -import org.apache.spark.mllib.regression.{LabeledPoint, StreamingLinearRegressionWithSGD} -import org.apache.spark.streaming.{Seconds, StreamingContext} - -/** - * Train a linear regression model on one stream of data and make predictions - * on another stream, where the data streams arrive as text files - * into two different directories. 
- * - * The rows of the text files must be labeled data points in the form - * `(y,[x1,x2,x3,...,xn])` - * Where n is the number of features. n must be the same for train and test. - * - * Usage: StreamingLinearRegression - * - * To run on your local machine using the two directories `trainingDir` and `testDir`, - * with updates every 5 seconds, and 2 features per data point, call: - * $ bin/run-example mllib.StreamingLinearRegression trainingDir testDir 5 2 - * - * As you add text files to `trainingDir` the model will continuously update. - * Anytime you add text files to `testDir`, you'll see predictions from the current model. - * - */ -object StreamingLinearRegression { - - def main(args: Array[String]) { - - if (args.length != 4) { - System.err.println( - "Usage: StreamingLinearRegression ") - System.exit(1) - } - - val conf = new SparkConf().setMaster("local").setAppName("StreamingLinearRegression") - val ssc = new StreamingContext(conf, Seconds(args(2).toLong)) - - val trainingData = ssc.textFileStream(args(0)).map(LabeledPoint.parse) - val testData = ssc.textFileStream(args(1)).map(LabeledPoint.parse) - - val model = new StreamingLinearRegressionWithSGD() - .setInitialWeights(Vectors.zeros(args(3).toInt)) - - model.trainOn(trainingData) - model.predictOnValues(testData.map(lp => (lp.label, lp.features))).print() - - ssc.start() - ssc.awaitTermination() - - } - -} -// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLinearRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLinearRegressionExample.scala index 0a1cd2d62d5b5..2ba1a62e450ee 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLinearRegressionExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLinearRegressionExample.scala @@ -26,6 +26,25 @@ import org.apache.spark.mllib.regression.StreamingLinearRegressionWithSGD // $example off$ import org.apache.spark.streaming._ +/** + * Train a linear regression model on one stream of data and make predictions + * on another stream, where the data streams arrive as text files + * into two different directories. + * + * The rows of the text files must be labeled data points in the form + * `(y,[x1,x2,x3,...,xn])` + * Where n is the number of features. n must be the same for train and test. + * + * Usage: StreamingLinearRegressionExample + * + * To run on your local machine using the two directories `trainingDir` and `testDir`, + * with updates every 5 seconds, and 2 features per data point, call: + * $ bin/run-example mllib.StreamingLinearRegressionExample trainingDir testDir + * + * As you add text files to `trainingDir` the model will continuously update. + * Anytime you add text files to `testDir`, you'll see predictions from the current model. 
+ * + */ object StreamingLinearRegressionExample { def main(args: Array[String]): Unit = { diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/hive/SparkHiveExample.scala b/examples/src/main/scala/org/apache/spark/examples/sql/hive/SparkHiveExample.scala index 11e84c0e45632..ded18dacf1fe3 100644 --- a/examples/src/main/scala/org/apache/spark/examples/sql/hive/SparkHiveExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/sql/hive/SparkHiveExample.scala @@ -38,7 +38,7 @@ object SparkHiveExample { // $example on:spark_hive$ // warehouseLocation points to the default location for managed databases and tables - val warehouseLocation = "file:${system:user.dir}/spark-warehouse" + val warehouseLocation = "spark-warehouse" val spark = SparkSession .builder() diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/JsonUtils.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/JsonUtils.scala new file mode 100644 index 0000000000000..40d568a12c25d --- /dev/null +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/JsonUtils.scala @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.kafka010 + +import java.io.Writer + +import scala.collection.mutable.HashMap +import scala.util.control.NonFatal + +import org.apache.kafka.common.TopicPartition +import org.json4s.NoTypeHints +import org.json4s.jackson.Serialization + +/** + * Utilities for converting Kafka related objects to and from json. + */ +private object JsonUtils { + private implicit val formats = Serialization.formats(NoTypeHints) + + /** + * Read TopicPartitions from json string + */ + def partitions(str: String): Array[TopicPartition] = { + try { + Serialization.read[Map[String, Seq[Int]]](str).flatMap { case (topic, parts) => + parts.map { part => + new TopicPartition(topic, part) + } + }.toArray + } catch { + case NonFatal(x) => + throw new IllegalArgumentException( + s"""Expected e.g. 
{"topicA":[0,1],"topicB":[0,1]}, got $str""") + } + } + + /** + * Write TopicPartitions as json string + */ + def partitions(partitions: Iterable[TopicPartition]): String = { + val result = new HashMap[String, List[Int]] + partitions.foreach { tp => + val parts: List[Int] = result.getOrElse(tp.topic, Nil) + result += tp.topic -> (tp.partition::parts) + } + Serialization.write(result) + } + + /** + * Read per-TopicPartition offsets from json string + */ + def partitionOffsets(str: String): Map[TopicPartition, Long] = { + try { + Serialization.read[Map[String, Map[Int, Long]]](str).flatMap { case (topic, partOffsets) => + partOffsets.map { case (part, offset) => + new TopicPartition(topic, part) -> offset + } + }.toMap + } catch { + case NonFatal(x) => + throw new IllegalArgumentException( + s"""Expected e.g. {"topicA":{"0":23,"1":-1},"topicB":{"0":-2}}, got $str""") + } + } + + /** + * Write per-TopicPartition offsets as json string + */ + def partitionOffsets(partitionOffsets: Map[TopicPartition, Long]): String = { + val result = new HashMap[String, HashMap[Int, Long]]() + partitionOffsets.foreach { case (tp, off) => + val parts = result.getOrElse(tp.topic, new HashMap[Int, Long]) + parts += tp.partition -> off + result += tp.topic -> parts + } + Serialization.write(result) + } +} diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala index 1be70db87497e..61cba737d148a 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala @@ -22,7 +22,7 @@ import java.{util => ju} import scala.collection.JavaConverters._ import scala.util.control.NonFatal -import org.apache.kafka.clients.consumer.{Consumer, KafkaConsumer} +import org.apache.kafka.clients.consumer.{Consumer, KafkaConsumer, OffsetOutOfRangeException} import org.apache.kafka.clients.consumer.internals.NoOpConsumerRebalanceListener import org.apache.kafka.common.TopicPartition @@ -82,6 +82,7 @@ private[kafka010] case class KafkaSource( executorKafkaParams: ju.Map[String, Object], sourceOptions: Map[String, String], metadataPath: String, + startingOffsets: StartingOffsets, failOnDataLoss: Boolean) extends Source with Logging { @@ -95,6 +96,9 @@ private[kafka010] case class KafkaSource( private val offsetFetchAttemptIntervalMs = sourceOptions.getOrElse("fetchOffset.retryIntervalMs", "10").toLong + private val maxOffsetsPerTrigger = + sourceOptions.get("maxOffsetsPerTrigger").map(_.toLong) + /** * A KafkaConsumer used in the driver to query the latest Kafka offsets. This only queries the * offsets and never commits them. 
@@ -109,13 +113,19 @@ private[kafka010] case class KafkaSource( private lazy val initialPartitionOffsets = { val metadataLog = new HDFSMetadataLog[KafkaSourceOffset](sqlContext.sparkSession, metadataPath) metadataLog.get(0).getOrElse { - val offsets = KafkaSourceOffset(fetchPartitionOffsets(seekToEnd = false)) + val offsets = startingOffsets match { + case EarliestOffsets => KafkaSourceOffset(fetchEarliestOffsets()) + case LatestOffsets => KafkaSourceOffset(fetchLatestOffsets()) + case SpecificOffsets(p) => KafkaSourceOffset(fetchSpecificStartingOffsets(p)) + } metadataLog.add(0, offsets) logInfo(s"Initial offsets: $offsets") offsets }.partitionToOffsets } + private var currentPartitionOffsets: Option[Map[TopicPartition, Long]] = None + override def schema: StructType = KafkaSource.kafkaSchema /** Returns the maximum available offset for this source. */ @@ -123,9 +133,54 @@ private[kafka010] case class KafkaSource( // Make sure initialPartitionOffsets is initialized initialPartitionOffsets - val offset = KafkaSourceOffset(fetchPartitionOffsets(seekToEnd = true)) - logDebug(s"GetOffset: ${offset.partitionToOffsets.toSeq.map(_.toString).sorted}") - Some(offset) + val latest = fetchLatestOffsets() + val offsets = maxOffsetsPerTrigger match { + case None => + latest + case Some(limit) if currentPartitionOffsets.isEmpty => + rateLimit(limit, initialPartitionOffsets, latest) + case Some(limit) => + rateLimit(limit, currentPartitionOffsets.get, latest) + } + + currentPartitionOffsets = Some(offsets) + logDebug(s"GetOffset: ${offsets.toSeq.map(_.toString).sorted}") + Some(KafkaSourceOffset(offsets)) + } + + /** Proportionally distribute limit number of offsets among topicpartitions */ + private def rateLimit( + limit: Long, + from: Map[TopicPartition, Long], + until: Map[TopicPartition, Long]): Map[TopicPartition, Long] = { + val fromNew = fetchNewPartitionEarliestOffsets(until.keySet.diff(from.keySet).toSeq) + val sizes = until.flatMap { + case (tp, end) => + // If begin isn't defined, something's wrong, but let alert logic in getBatch handle it + from.get(tp).orElse(fromNew.get(tp)).flatMap { begin => + val size = end - begin + logDebug(s"rateLimit $tp size is $size") + if (size > 0) Some(tp -> size) else None + } + } + val total = sizes.values.sum.toDouble + if (total < 1) { + until + } else { + until.map { + case (tp, end) => + tp -> sizes.get(tp).map { size => + val begin = from.get(tp).getOrElse(fromNew(tp)) + val prorate = limit * (size / total) + logDebug(s"rateLimit $tp prorated amount is $prorate") + // Don't completely starve small topicpartitions + val off = begin + (if (prorate < 1) Math.ceil(prorate) else Math.floor(prorate)).toLong + logDebug(s"rateLimit $tp new offset is $off") + // Paranoia, make sure not to return an offset that's past end + Math.min(end, off) + }.getOrElse(end) + } + } } /** @@ -148,11 +203,7 @@ private[kafka010] case class KafkaSource( // Find the new partitions, and get their earliest offsets val newPartitions = untilPartitionOffsets.keySet.diff(fromPartitionOffsets.keySet) - val newPartitionOffsets = if (newPartitions.nonEmpty) { - fetchNewPartitionEarliestOffsets(newPartitions.toSeq) - } else { - Map.empty[TopicPartition, Long] - } + val newPartitionOffsets = fetchNewPartitionEarliestOffsets(newPartitions.toSeq) if (newPartitionOffsets.keySet != newPartitions) { // We cannot get from offsets for some partitions. It means they got deleted. 
val deletedPartitions = newPartitions.diff(newPartitionOffsets.keySet) @@ -216,6 +267,12 @@ private[kafka010] case class KafkaSource( logInfo("GetBatch generating RDD of offset range: " + offsetRanges.sortBy(_.topicPartition.toString).mkString(", ")) + + // On recovery, getBatch will get called before getOffset + if (currentPartitionOffsets.isEmpty) { + currentPartitionOffsets = Some(untilPartitionOffsets) + } + sqlContext.createDataFrame(rdd, schema) } @@ -227,53 +284,102 @@ private[kafka010] case class KafkaSource( override def toString(): String = s"KafkaSource[$consumerStrategy]" /** - * Fetch the offset of a partition, either seek to the latest offsets or use the current offsets - * in the consumer. + * Set consumer position to specified offsets, making sure all assignments are set. */ - private def fetchPartitionOffsets( - seekToEnd: Boolean): Map[TopicPartition, Long] = withRetriesWithoutInterrupt { - // Make sure `KafkaConsumer.poll` won't be interrupted (KAFKA-1894) - assert(Thread.currentThread().isInstanceOf[StreamExecutionThread]) + private def fetchSpecificStartingOffsets( + partitionOffsets: Map[TopicPartition, Long]): Map[TopicPartition, Long] = { + val result = withRetriesWithoutInterrupt { + // Poll to get the latest assigned partitions + consumer.poll(0) + val partitions = consumer.assignment() + consumer.pause(partitions) + assert(partitions.asScala == partitionOffsets.keySet, + "If startingOffsets contains specific offsets, you must specify all TopicPartitions.\n" + + "Use -1 for latest, -2 for earliest, if you don't care.\n" + + s"Specified: ${partitionOffsets.keySet} Assigned: ${partitions.asScala}") + logDebug(s"Partitions assigned to consumer: $partitions. Seeking to $partitionOffsets") + + partitionOffsets.foreach { + case (tp, -1) => consumer.seekToEnd(ju.Arrays.asList(tp)) + case (tp, -2) => consumer.seekToBeginning(ju.Arrays.asList(tp)) + case (tp, off) => consumer.seek(tp, off) + } + partitionOffsets.map { + case (tp, _) => tp -> consumer.position(tp) + } + } + partitionOffsets.foreach { + case (tp, off) if off != -1 && off != -2 => + if (result(tp) != off) { + reportDataLoss( + s"startingOffsets for $tp was $off but consumer reset to ${result(tp)}") + } + case _ => + // no real way to check that beginning or end is reasonable + } + result + } + + /** + * Fetch the earliest offsets of partitions. + */ + private def fetchEarliestOffsets(): Map[TopicPartition, Long] = withRetriesWithoutInterrupt { // Poll to get the latest assigned partitions consumer.poll(0) val partitions = consumer.assignment() consumer.pause(partitions) - logDebug(s"Partitioned assigned to consumer: $partitions") + logDebug(s"Partitions assigned to consumer: $partitions. Seeking to the beginning") - // Get the current or latest offset of each partition - if (seekToEnd) { - consumer.seekToEnd(partitions) - logDebug("Seeked to the end") - } + consumer.seekToBeginning(partitions) val partitionOffsets = partitions.asScala.map(p => p -> consumer.position(p)).toMap - logDebug(s"Got offsets for partition : $partitionOffsets") + logDebug(s"Got earliest offsets for partition : $partitionOffsets") partitionOffsets } /** - * Fetch the earliest offsets for newly discovered partitions. The return result may not contain - * some partitions if they are deleted. + * Fetch the latest offset of partitions. 
*/ - private def fetchNewPartitionEarliestOffsets( - newPartitions: Seq[TopicPartition]): Map[TopicPartition, Long] = withRetriesWithoutInterrupt { - // Make sure `KafkaConsumer.poll` won't be interrupted (KAFKA-1894) - assert(Thread.currentThread().isInstanceOf[StreamExecutionThread]) + private def fetchLatestOffsets(): Map[TopicPartition, Long] = withRetriesWithoutInterrupt { // Poll to get the latest assigned partitions consumer.poll(0) val partitions = consumer.assignment() - logDebug(s"\tPartitioned assigned to consumer: $partitions") + consumer.pause(partitions) + logDebug(s"Partitions assigned to consumer: $partitions. Seeking to the end.") - // Get the earliest offset of each partition - consumer.seekToBeginning(partitions) - val partitionToOffsets = newPartitions.filter { p => - // When deleting topics happen at the same time, some partitions may not be in `partitions`. - // So we need to ignore them - partitions.contains(p) - }.map(p => p -> consumer.position(p)).toMap - logDebug(s"Got offsets for new partitions: $partitionToOffsets") - partitionToOffsets + consumer.seekToEnd(partitions) + val partitionOffsets = partitions.asScala.map(p => p -> consumer.position(p)).toMap + logDebug(s"Got latest offsets for partition : $partitionOffsets") + partitionOffsets } + /** + * Fetch the earliest offsets for newly discovered partitions. The return result may not contain + * some partitions if they are deleted. + */ + private def fetchNewPartitionEarliestOffsets( + newPartitions: Seq[TopicPartition]): Map[TopicPartition, Long] = + if (newPartitions.isEmpty) { + Map.empty[TopicPartition, Long] + } else { + withRetriesWithoutInterrupt { + // Poll to get the latest assigned partitions + consumer.poll(0) + val partitions = consumer.assignment() + consumer.pause(partitions) + logDebug(s"\tPartitions assigned to consumer: $partitions") + + // Get the earliest offset of each partition + consumer.seekToBeginning(partitions) + val partitionOffsets = newPartitions.filter { p => + // When deleting topics happen at the same time, some partitions may not be in + // `partitions`. So we need to ignore them + partitions.contains(p) + }.map(p => p -> consumer.position(p)).toMap + logDebug(s"Got earliest offsets for new partitions: $partitionOffsets") + partitionOffsets + } + } + /** * Helper function that does multiple retries on the a body of code that returns offsets. * Retries are needed to handle transient failures. For e.g. 
race conditions between getting @@ -284,6 +390,9 @@ private[kafka010] case class KafkaSource( */ private def withRetriesWithoutInterrupt( body: => Map[TopicPartition, Long]): Map[TopicPartition, Long] = { + // Make sure `KafkaConsumer.poll` won't be interrupted (KAFKA-1894) + assert(Thread.currentThread().isInstanceOf[StreamExecutionThread]) + synchronized { var result: Option[Map[TopicPartition, Long]] = None var attempt = 1 @@ -302,6 +411,8 @@ private[kafka010] case class KafkaSource( try { result = Some(body) } catch { + case x: OffsetOutOfRangeException => + reportDataLoss(x.getMessage) case NonFatal(e) => lastException = e logWarning(s"Error in attempt $attempt getting Kafka offsets: ", e) @@ -358,6 +469,17 @@ private[kafka010] object KafkaSource { def createConsumer(): Consumer[Array[Byte], Array[Byte]] } + case class AssignStrategy(partitions: Array[TopicPartition], kafkaParams: ju.Map[String, Object]) + extends ConsumerStrategy { + override def createConsumer(): Consumer[Array[Byte], Array[Byte]] = { + val consumer = new KafkaConsumer[Array[Byte], Array[Byte]](kafkaParams) + consumer.assign(ju.Arrays.asList(partitions: _*)) + consumer + } + + override def toString: String = s"Assign[${partitions.mkString(", ")}]" + } + case class SubscribeStrategy(topics: Seq[String], kafkaParams: ju.Map[String, Object]) extends ConsumerStrategy { override def createConsumer(): Consumer[Array[Byte], Array[Byte]] = { diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala index 1b0a2fe955d03..585ced875caa7 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala @@ -77,10 +77,13 @@ private[kafka010] class KafkaSourceProvider extends StreamSourceProvider // id. Hence, we should generate a unique id for each query. val uniqueGroupId = s"spark-kafka-source-${UUID.randomUUID}-${metadataPath.hashCode}" - val autoOffsetResetValue = caseInsensitiveParams.get(STARTING_OFFSET_OPTION_KEY) match { - case Some(value) => value.trim() // same values as those supported by auto.offset.reset - case None => "latest" - } + val startingOffsets = + caseInsensitiveParams.get(STARTING_OFFSETS_OPTION_KEY).map(_.trim.toLowerCase) match { + case Some("latest") => LatestOffsets + case Some("earliest") => EarliestOffsets + case Some(json) => SpecificOffsets(JsonUtils.partitionOffsets(json)) + case None => LatestOffsets + } val kafkaParamsForStrategy = ConfigUpdater("source", specifiedKafkaParams) @@ -90,8 +93,9 @@ private[kafka010] class KafkaSourceProvider extends StreamSourceProvider // So that consumers in Kafka source do not mess with any existing group id .set(ConsumerConfig.GROUP_ID_CONFIG, s"$uniqueGroupId-driver") - // So that consumers can start from earliest or latest - .set(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, autoOffsetResetValue) + // Set to "earliest" to avoid exceptions. However, KafkaSource will fetch the initial + // offsets by itself instead of counting on KafkaConsumer. 
+ .set(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest") // So that consumers in the driver does not commit offsets unnecessarily .set(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG, "false") @@ -124,6 +128,10 @@ private[kafka010] class KafkaSourceProvider extends StreamSourceProvider .build() val strategy = caseInsensitiveParams.find(x => STRATEGY_OPTION_KEYS.contains(x._1)).get match { + case ("assign", value) => + AssignStrategy( + JsonUtils.partitions(value), + kafkaParamsForStrategy) case ("subscribe", value) => SubscribeStrategy( value.split(",").map(_.trim()).filter(_.nonEmpty), @@ -147,6 +155,7 @@ private[kafka010] class KafkaSourceProvider extends StreamSourceProvider kafkaParamsForExecutors, parameters, metadataPath, + startingOffsets, failOnDataLoss) } @@ -168,6 +177,13 @@ private[kafka010] class KafkaSourceProvider extends StreamSourceProvider } val strategy = caseInsensitiveParams.find(x => STRATEGY_OPTION_KEYS.contains(x._1)).get match { + case ("assign", value) => + if (!value.trim.startsWith("{")) { + throw new IllegalArgumentException( + "No topicpartitions to assign as specified value for option " + + s"'assign' is '$value'") + } + case ("subscribe", value) => val topics = value.split(",").map(_.trim).filter(_.nonEmpty) if (topics.isEmpty) { @@ -188,14 +204,6 @@ private[kafka010] class KafkaSourceProvider extends StreamSourceProvider throw new IllegalArgumentException("Unknown option") } - caseInsensitiveParams.get(STARTING_OFFSET_OPTION_KEY) match { - case Some(pos) if !STARTING_OFFSET_OPTION_VALUES.contains(pos.trim.toLowerCase) => - throw new IllegalArgumentException( - s"Illegal value '$pos' for option '$STARTING_OFFSET_OPTION_KEY', " + - s"acceptable values are: ${STARTING_OFFSET_OPTION_VALUES.mkString(", ")}") - case _ => - } - // Validate user-specified Kafka options if (caseInsensitiveParams.contains(s"kafka.${ConsumerConfig.GROUP_ID_CONFIG}")) { @@ -208,11 +216,11 @@ private[kafka010] class KafkaSourceProvider extends StreamSourceProvider throw new IllegalArgumentException( s""" |Kafka option '${ConsumerConfig.AUTO_OFFSET_RESET_CONFIG}' is not supported. - |Instead set the source option '$STARTING_OFFSET_OPTION_KEY' to 'earliest' or 'latest' to - |specify where to start. Structured Streaming manages which offsets are consumed + |Instead set the source option '$STARTING_OFFSETS_OPTION_KEY' to 'earliest' or 'latest' + |to specify where to start. Structured Streaming manages which offsets are consumed |internally, rather than relying on the kafkaConsumer to do it. This will ensure that no |data is missed when when new topics/partitions are dynamically subscribed. Note that - |'$STARTING_OFFSET_OPTION_KEY' only applies when a new Streaming query is started, and + |'$STARTING_OFFSETS_OPTION_KEY' only applies when a new Streaming query is started, and |that resuming will always pick up from where the query left off. See the docs for more |details. 
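With startingoffset renamed to startingOffsets, the source now accepts three forms for that option: the keywords "latest" (the default) and "earliest", or a JSON string giving explicit per-partition offsets, where -2 stands for earliest and -1 for latest on that partition. A hedged usage sketch with placeholder broker address and topic names:

    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder().appName("kafka-startingOffsets-sketch").getOrCreate()
    val df = spark.readStream
      .format("kafka")
      .option("kafka.bootstrap.servers", "host1:9092")   // placeholder broker
      .option("subscribe", "topicA,topicB")
      // Keyword form: "latest" is the default, "earliest" reads each partition from the start.
      .option("startingOffsets", "earliest")
      // JSON form (alternative): -2 = earliest, -1 = latest for the given partition.
      // .option("startingOffsets", """{"topicA":{"0":23,"1":-1},"topicB":{"0":-2}}""")
      .load()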
""".stripMargin) @@ -275,8 +283,7 @@ private[kafka010] class KafkaSourceProvider extends StreamSourceProvider } private[kafka010] object KafkaSourceProvider { - private val STRATEGY_OPTION_KEYS = Set("subscribe", "subscribepattern") - private val STARTING_OFFSET_OPTION_KEY = "startingoffset" - private val STARTING_OFFSET_OPTION_VALUES = Set("earliest", "latest") + private val STRATEGY_OPTION_KEYS = Set("subscribe", "subscribepattern", "assign") + private val STARTING_OFFSETS_OPTION_KEY = "startingoffsets" private val FAIL_ON_DATA_LOSS_OPTION_KEY = "failondataloss" } diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala index 496af7e39abab..802dd040aed93 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala @@ -112,6 +112,11 @@ private[kafka010] class KafkaSourceRDD( buf.toArray } + override def getPreferredLocations(split: Partition): Seq[String] = { + val part = split.asInstanceOf[KafkaSourceRDDPartition] + part.offsetRange.preferredLoc.map(Seq(_)).getOrElse(Seq.empty) + } + override def compute( thePart: Partition, context: TaskContext): Iterator[ConsumerRecord[Array[Byte], Array[Byte]]] = { diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/StartingOffsets.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/StartingOffsets.scala new file mode 100644 index 0000000000000..83959e597171a --- /dev/null +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/StartingOffsets.scala @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.kafka010 + +import org.apache.kafka.common.TopicPartition + +/* + * Values that can be specified for config startingOffsets + */ +private[kafka010] sealed trait StartingOffsets + +private[kafka010] case object EarliestOffsets extends StartingOffsets + +private[kafka010] case object LatestOffsets extends StartingOffsets + +private[kafka010] case class SpecificOffsets( + partitionOffsets: Map[TopicPartition, Long]) extends StartingOffsets diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/JsonUtilsSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/JsonUtilsSuite.scala new file mode 100644 index 0000000000000..54b980049d1a2 --- /dev/null +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/JsonUtilsSuite.scala @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.kafka010 + +import org.apache.kafka.common.TopicPartition + +import org.apache.spark.SparkFunSuite + +class JsonUtilsSuite extends SparkFunSuite { + + test("parsing partitions") { + val parsed = JsonUtils.partitions("""{"topicA":[0,1],"topicB":[4,6]}""") + val expected = Array( + new TopicPartition("topicA", 0), + new TopicPartition("topicA", 1), + new TopicPartition("topicB", 4), + new TopicPartition("topicB", 6) + ) + assert(parsed.toSeq === expected.toSeq) + } + + test("parsing partitionOffsets") { + val parsed = JsonUtils.partitionOffsets( + """{"topicA":{"0":23,"1":-1},"topicB":{"0":-2}}""") + + assert(parsed(new TopicPartition("topicA", 0)) === 23) + assert(parsed(new TopicPartition("topicA", 1)) === -1) + assert(parsed(new TopicPartition("topicB", 0)) === -2) + } +} diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala index c640b93b0a2ee..ed4cc75920e8e 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala @@ -22,13 +22,15 @@ import java.util.concurrent.atomic.AtomicInteger import scala.util.Random import org.apache.kafka.clients.producer.RecordMetadata +import org.apache.kafka.common.TopicPartition +import org.scalatest.concurrent.Eventually._ +import org.scalatest.concurrent.PatienceConfiguration.Timeout import org.scalatest.time.SpanSugar._ import org.apache.spark.sql.execution.streaming._ -import org.apache.spark.sql.streaming.StreamTest +import org.apache.spark.sql.streaming.{ ProcessingTime, StreamTest } import org.apache.spark.sql.test.SharedSQLContext - abstract class KafkaSourceTest extends StreamTest with SharedSQLContext { protected var testUtils: KafkaTestUtils = _ @@ -52,7 +54,7 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext { protected def makeSureGetOffsetCalled = AssertOnQuery { q => // Because KafkaSource's initialPartitionOffsets is set lazily, we need to make sure // its "getOffset" is called before pushing any data. Otherwise, because of the race contion, - // we don't know which data should be fetched when `startingOffset` is latest. + // we don't know which data should be fetched when `startingOffsets` is latest. 
q.processAllAvailable() true } @@ -132,6 +134,72 @@ class KafkaSourceSuite extends KafkaSourceTest { private val topicId = new AtomicInteger(0) + test("maxOffsetsPerTrigger") { + val topic = newTopic() + testUtils.createTopic(topic, partitions = 3) + testUtils.sendMessages(topic, (100 to 200).map(_.toString).toArray, Some(0)) + testUtils.sendMessages(topic, (10 to 20).map(_.toString).toArray, Some(1)) + testUtils.sendMessages(topic, Array("1"), Some(2)) + + val reader = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.metadata.max.age.ms", "1") + .option("maxOffsetsPerTrigger", 10) + .option("subscribe", topic) + .option("startingOffsets", "earliest") + val kafka = reader.load() + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] + val mapped: org.apache.spark.sql.Dataset[_] = kafka.map(kv => kv._2.toInt) + + val clock = new StreamManualClock + + val waitUntilBatchProcessed = AssertOnQuery { q => + eventually(Timeout(streamingTimeout)) { + if (!q.exception.isDefined) { + assert(clock.isStreamWaitingAt(clock.getTimeMillis())) + } + } + if (q.exception.isDefined) { + throw q.exception.get + } + true + } + + testStream(mapped)( + StartStream(ProcessingTime(100), clock), + waitUntilBatchProcessed, + // 1 from smallest, 1 from middle, 8 from biggest + CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107), + AdvanceManualClock(100), + waitUntilBatchProcessed, + // smallest now empty, 1 more from middle, 9 more from biggest + CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107, + 11, 108, 109, 110, 111, 112, 113, 114, 115, 116 + ), + StopStream, + StartStream(ProcessingTime(100), clock), + waitUntilBatchProcessed, + AdvanceManualClock(100), + waitUntilBatchProcessed, + // smallest now empty, 1 more from middle, 9 more from biggest + CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107, + 11, 108, 109, 110, 111, 112, 113, 114, 115, 116, + 12, 117, 118, 119, 120, 121, 122, 123, 124, 125 + ), + AdvanceManualClock(100), + waitUntilBatchProcessed, + // smallest now empty, 1 more from middle, 9 more from biggest + CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107, + 11, 108, 109, 110, 111, 112, 113, 114, 115, 116, + 12, 117, 118, 119, 120, 121, 122, 123, 124, 125, + 13, 126, 127, 128, 129, 130, 131, 132, 133, 134 + ) + ) + } + test("cannot stop Kafka stream") { val topic = newTopic() testUtils.createTopic(newTopic(), partitions = 5) @@ -155,26 +223,52 @@ class KafkaSourceSuite extends KafkaSourceTest { ) } + test("assign from latest offsets") { + val topic = newTopic() + testFromLatestOffsets(topic, false, "assign" -> assignString(topic, 0 to 4)) + } + + test("assign from earliest offsets") { + val topic = newTopic() + testFromEarliestOffsets(topic, false, "assign" -> assignString(topic, 0 to 4)) + } + + test("assign from specific offsets") { + val topic = newTopic() + testFromSpecificOffsets(topic, "assign" -> assignString(topic, 0 to 4)) + } + test("subscribing topic by name from latest offsets") { val topic = newTopic() - testFromLatestOffsets(topic, "subscribe" -> topic) + testFromLatestOffsets(topic, true, "subscribe" -> topic) } test("subscribing topic by name from earliest offsets") { val topic = newTopic() - testFromEarliestOffsets(topic, "subscribe" -> topic) + testFromEarliestOffsets(topic, true, "subscribe" -> topic) + } + + test("subscribing topic by name from specific offsets") { + val topic = newTopic() + testFromSpecificOffsets(topic, "subscribe" -> topic) } 
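The tests above exercise the other new source options together: "assign" takes a JSON map of topics to partition numbers (the format parsed by JsonUtils.partitions), and maxOffsetsPerTrigger caps how many records a micro-batch may read across the assigned partitions. A sketch of how the two combine, again with placeholder broker and topic names:

    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder().appName("kafka-assign-sketch").getOrCreate()
    // Read only partitions 0 and 1 of topicA plus partition 4 of topicB, at most
    // ~1000 records per trigger, spread proportionally over those partitions.
    val records = spark.readStream
      .format("kafka")
      .option("kafka.bootstrap.servers", "host1:9092")   // placeholder broker
      .option("assign", """{"topicA":[0,1],"topicB":[4]}""")
      .option("maxOffsetsPerTrigger", 1000)
      .option("startingOffsets", "latest")
      .load()
      .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")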
test("subscribing topic by pattern from latest offsets") { val topicPrefix = newTopic() val topic = topicPrefix + "-suffix" - testFromLatestOffsets(topic, "subscribePattern" -> s"$topicPrefix-.*") + testFromLatestOffsets(topic, true, "subscribePattern" -> s"$topicPrefix-.*") } test("subscribing topic by pattern from earliest offsets") { val topicPrefix = newTopic() val topic = topicPrefix + "-suffix" - testFromEarliestOffsets(topic, "subscribePattern" -> s"$topicPrefix-.*") + testFromEarliestOffsets(topic, true, "subscribePattern" -> s"$topicPrefix-.*") + } + + test("subscribing topic by pattern from specific offsets") { + val topicPrefix = newTopic() + val topic = topicPrefix + "-suffix" + testFromSpecificOffsets(topic, "subscribePattern" -> s"$topicPrefix-.*") } test("subscribing topic by pattern with topic deletions") { @@ -233,6 +327,10 @@ class KafkaSourceSuite extends KafkaSourceTest { testBadOptions("subscribe" -> "t", "subscribePattern" -> "t.*")( "only one", "options can be specified") + testBadOptions("subscribe" -> "t", "assign" -> """{"a":[0]}""")( + "only one", "options can be specified") + + testBadOptions("assign" -> "")("no topicpartitions to assign") testBadOptions("subscribe" -> "")("no topics to subscribe") testBadOptions("subscribePattern" -> "")("pattern to subscribe is empty") } @@ -264,9 +362,90 @@ class KafkaSourceSuite extends KafkaSourceTest { testUnsupportedConfig("kafka.auto.offset.reset", "latest") } + test("input row metrics") { + val topic = newTopic() + testUtils.createTopic(topic, partitions = 5) + testUtils.sendMessages(topic, Array("-1")) + require(testUtils.getLatestOffsets(Set(topic)).size === 5) + + val kafka = spark + .readStream + .format("kafka") + .option("subscribe", topic) + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .load() + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] + + val mapped = kafka.map(kv => kv._2.toInt + 1) + testStream(mapped)( + makeSureGetOffsetCalled, + AddKafkaData(Set(topic), 1, 2, 3), + CheckAnswer(2, 3, 4), + AssertOnLastQueryStatus { status => + assert(status.triggerDetails.get("numRows.input.total").toInt > 0) + assert(status.sourceStatuses(0).processingRate > 0.0) + } + ) + } + private def newTopic(): String = s"topic-${topicId.getAndIncrement()}" - private def testFromLatestOffsets(topic: String, options: (String, String)*): Unit = { + private def assignString(topic: String, partitions: Iterable[Int]): String = { + JsonUtils.partitions(partitions.map(p => new TopicPartition(topic, p))) + } + + private def testFromSpecificOffsets(topic: String, options: (String, String)*): Unit = { + val partitionOffsets = Map( + new TopicPartition(topic, 0) -> -2L, + new TopicPartition(topic, 1) -> -1L, + new TopicPartition(topic, 2) -> 0L, + new TopicPartition(topic, 3) -> 1L, + new TopicPartition(topic, 4) -> 2L + ) + val startingOffsets = JsonUtils.partitionOffsets(partitionOffsets) + + testUtils.createTopic(topic, partitions = 5) + // part 0 starts at earliest, these should all be seen + testUtils.sendMessages(topic, Array(-20, -21, -22).map(_.toString), Some(0)) + // part 1 starts at latest, these should all be skipped + testUtils.sendMessages(topic, Array(-10, -11, -12).map(_.toString), Some(1)) + // part 2 starts at 0, these should all be seen + testUtils.sendMessages(topic, Array(0, 1, 2).map(_.toString), Some(2)) + // part 3 starts at 1, first should be skipped + testUtils.sendMessages(topic, Array(10, 11, 12).map(_.toString), Some(3)) + // part 4 starts at 2, first and 
second should be skipped + testUtils.sendMessages(topic, Array(20, 21, 22).map(_.toString), Some(4)) + require(testUtils.getLatestOffsets(Set(topic)).size === 5) + + val reader = spark + .readStream + .format("kafka") + .option("startingOffsets", startingOffsets) + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.metadata.max.age.ms", "1") + options.foreach { case (k, v) => reader.option(k, v) } + val kafka = reader.load() + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] + val mapped: org.apache.spark.sql.Dataset[_] = kafka.map(kv => kv._2.toInt) + + testStream(mapped)( + makeSureGetOffsetCalled, + CheckAnswer(-20, -21, -22, 0, 1, 2, 11, 12, 22), + StopStream, + StartStream(), + CheckAnswer(-20, -21, -22, 0, 1, 2, 11, 12, 22), // Should get the data back on recovery + AddKafkaData(Set(topic), 30, 31, 32, 33, 34)(ensureDataInMultiplePartition = true), + CheckAnswer(-20, -21, -22, 0, 1, 2, 11, 12, 22, 30, 31, 32, 33, 34), + StopStream + ) + } + + private def testFromLatestOffsets( + topic: String, + addPartitions: Boolean, + options: (String, String)*): Unit = { testUtils.createTopic(topic, partitions = 5) testUtils.sendMessages(topic, Array("-1")) require(testUtils.getLatestOffsets(Set(topic)).size === 5) @@ -274,7 +453,7 @@ class KafkaSourceSuite extends KafkaSourceTest { val reader = spark .readStream .format("kafka") - .option("startingOffset", s"latest") + .option("startingOffsets", s"latest") .option("kafka.bootstrap.servers", testUtils.brokerAddress) .option("kafka.metadata.max.age.ms", "1") options.foreach { case (k, v) => reader.option(k, v) } @@ -297,7 +476,9 @@ class KafkaSourceSuite extends KafkaSourceTest { AddKafkaData(Set(topic), 7, 8), CheckAnswer(2, 3, 4, 5, 6, 7, 8, 9), AssertOnQuery("Add partitions") { query: StreamExecution => - testUtils.addPartitions(topic, 10) + if (addPartitions) { + testUtils.addPartitions(topic, 10) + } true }, AddKafkaData(Set(topic), 9, 10, 11, 12, 13, 14, 15, 16), @@ -305,7 +486,10 @@ class KafkaSourceSuite extends KafkaSourceTest { ) } - private def testFromEarliestOffsets(topic: String, options: (String, String)*): Unit = { + private def testFromEarliestOffsets( + topic: String, + addPartitions: Boolean, + options: (String, String)*): Unit = { testUtils.createTopic(topic, partitions = 5) testUtils.sendMessages(topic, (1 to 3).map { _.toString }.toArray) require(testUtils.getLatestOffsets(Set(topic)).size === 5) @@ -313,7 +497,7 @@ class KafkaSourceSuite extends KafkaSourceTest { val reader = spark.readStream reader .format(classOf[KafkaSourceProvider].getCanonicalName.stripSuffix("$")) - .option("startingOffset", s"earliest") + .option("startingOffsets", s"earliest") .option("kafka.bootstrap.servers", testUtils.brokerAddress) .option("kafka.metadata.max.age.ms", "1") options.foreach { case (k, v) => reader.option(k, v) } @@ -333,7 +517,9 @@ class KafkaSourceSuite extends KafkaSourceTest { StartStream(), CheckAnswer(2, 3, 4, 5, 6, 7, 8, 9), AssertOnQuery("Add partitions") { query: StreamExecution => - testUtils.addPartitions(topic, 10) + if (addPartitions) { + testUtils.addPartitions(topic, 10) + } true }, AddKafkaData(Set(topic), 9, 10, 11, 12, 13, 14, 15, 16), diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala index 3eb8a737ba4c8..9b24ccdd560e8 100644 --- 
a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala @@ -201,11 +201,23 @@ class KafkaTestUtils extends Logging { /** Send the array of messages to the Kafka broker */ def sendMessages(topic: String, messages: Array[String]): Seq[(String, RecordMetadata)] = { + sendMessages(topic, messages, None) + } + + /** Send the array of messages to the Kafka broker using specified partition */ + def sendMessages( + topic: String, + messages: Array[String], + partition: Option[Int]): Seq[(String, RecordMetadata)] = { producer = new KafkaProducer[String, String](producerConfiguration) val offsets = try { messages.map { m => + val record = partition match { + case Some(p) => new ProducerRecord[String, String](topic, p, null, m) + case None => new ProducerRecord[String, String](topic, m) + } val metadata = - producer.send(new ProducerRecord[String, String](topic, m)).get(10, TimeUnit.SECONDS) + producer.send(record).get(10, TimeUnit.SECONDS) logInfo(s"\tSent $m to partition ${metadata.partition}, offset ${metadata.offset}") (m, metadata) } diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala index 432537ebf05b2..7e57bb18cbd50 100644 --- a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala +++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala @@ -282,13 +282,13 @@ private[spark] class DirectKafkaInputDStream[K, V]( protected def commitAll(): Unit = { val m = new ju.HashMap[TopicPartition, OffsetAndMetadata]() - val it = commitQueue.iterator() - while (it.hasNext) { - val osr = it.next + var osr = commitQueue.poll() + while (null != osr) { val tp = osr.topicPartition val x = m.get(tp) val offset = if (null == x) { osr.untilOffset } else { Math.max(x.offset, osr.untilOffset) } m.put(tp, new OffsetAndMetadata(offset)) + osr = commitQueue.poll() } if (!m.isEmpty) { consumer.commitAsync(m, commitCallback.get) diff --git a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/BLAS.scala b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/BLAS.scala index 4ca19f3387f07..ef3890962494d 100644 --- a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/BLAS.scala +++ b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/BLAS.scala @@ -243,6 +243,24 @@ private[spark] object BLAS extends Serializable { spr(alpha, v, U.values) } + /** + * y := alpha*A*x + beta*y + * + * @param n The order of the n by n matrix A. + * @param A The upper triangular part of A in a [[DenseVector]] (column major). + * @param x The [[DenseVector]] transformed by A. + * @param y The [[DenseVector]] to be modified in place. + */ + def dspmv( + n: Int, + alpha: Double, + A: DenseVector, + x: DenseVector, + beta: Double, + y: DenseVector): Unit = { + f2jBLAS.dspmv("U", n, alpha, A.values, x.values, 1, beta, y.values, 1) + } + /** * Adds alpha * x * x.t to a matrix in-place. This is the same as BLAS's ?SPR. 
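The A argument of the new dspmv wrapper is the upper triangle of the symmetric n-by-n matrix packed column by column into a single DenseVector, which is easy to misread from the doc comment alone. A small sketch of the packing convention, using the same matrix as the BLASSuite test below; the actual call is left commented out because BLAS is private[spark] and only callable from Spark's own ml packages:

    import org.apache.spark.ml.linalg.DenseVector

    // Symmetric 4x4 matrix A:
    //   [[ 3.0, -2.0,  2.0, -4.0],
    //    [-2.0, -8.0,  4.0,  7.0],
    //    [ 2.0,  4.0, -3.0, -3.0],
    //    [-4.0,  7.0, -3.0,  0.0]]
    // Packed upper triangle, column major: col 0 gives A(0,0); col 1 gives A(0,1), A(1,1);
    // col 2 gives A(0,2), A(1,2), A(2,2); col 3 gives A(0,3), A(1,3), A(2,3), A(3,3).
    val packedA = new DenseVector(Array(3.0, -2.0, -8.0, 2.0, 4.0, -3.0, -4.0, 7.0, -3.0, 0.0))
    val x = new DenseVector(Array(5.0, 2.0, -1.0, -9.0))
    val y = new DenseVector(Array(0.0, 0.0, 0.0, 0.0))
    // y := 1.0 * A * x + 0.0 * y, i.e. y becomes A*x = [45.0, -93.0, 48.0, -3.0]:
    // BLAS.dspmv(4, 1.0, packedA, x, 0.0, y)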
* diff --git a/mllib-local/src/test/scala/org/apache/spark/ml/linalg/BLASSuite.scala b/mllib-local/src/test/scala/org/apache/spark/ml/linalg/BLASSuite.scala index 6e72a5fff0a91..877ac68983348 100644 --- a/mllib-local/src/test/scala/org/apache/spark/ml/linalg/BLASSuite.scala +++ b/mllib-local/src/test/scala/org/apache/spark/ml/linalg/BLASSuite.scala @@ -422,4 +422,49 @@ class BLASSuite extends SparkMLFunSuite { assert(dATT.multiply(sx) ~== expected absTol 1e-15) assert(sATT.multiply(sx) ~== expected absTol 1e-15) } + + test("spmv") { + /* + A = [[3.0, -2.0, 2.0, -4.0], + [-2.0, -8.0, 4.0, 7.0], + [2.0, 4.0, -3.0, -3.0], + [-4.0, 7.0, -3.0, 0.0]] + x = [5.0, 2.0, -1.0, -9.0] + Ax = [ 45., -93., 48., -3.] + */ + val A = new DenseVector(Array(3.0, -2.0, -8.0, 2.0, 4.0, -3.0, -4.0, 7.0, -3.0, 0.0)) + val x = new DenseVector(Array(5.0, 2.0, -1.0, -9.0)) + val n = 4 + + val y1 = new DenseVector(Array(-3.0, 6.0, -8.0, -3.0)) + val y2 = y1.copy + val y3 = y1.copy + val y4 = y1.copy + val y5 = y1.copy + val y6 = y1.copy + val y7 = y1.copy + + val expected1 = new DenseVector(Array(42.0, -87.0, 40.0, -6.0)) + val expected2 = new DenseVector(Array(19.5, -40.5, 16.0, -4.5)) + val expected3 = new DenseVector(Array(-25.5, 52.5, -32.0, -1.5)) + val expected4 = new DenseVector(Array(-3.0, 6.0, -8.0, -3.0)) + val expected5 = new DenseVector(Array(43.5, -90.0, 44.0, -4.5)) + val expected6 = new DenseVector(Array(46.5, -96.0, 52.0, -1.5)) + val expected7 = new DenseVector(Array(45.0, -93.0, 48.0, -3.0)) + + dspmv(n, 1.0, A, x, 1.0, y1) + dspmv(n, 0.5, A, x, 1.0, y2) + dspmv(n, -0.5, A, x, 1.0, y3) + dspmv(n, 0.0, A, x, 1.0, y4) + dspmv(n, 1.0, A, x, 0.5, y5) + dspmv(n, 1.0, A, x, -0.5, y6) + dspmv(n, 1.0, A, x, 0.0, y7) + assert(y1 ~== expected1 absTol 1e-8) + assert(y2 ~== expected2 absTol 1e-8) + assert(y3 ~== expected3 absTol 1e-8) + assert(y4 ~== expected4 absTol 1e-8) + assert(y5 ~== expected5 absTol 1e-8) + assert(y6 ~== expected6 absTol 1e-8) + assert(y7 ~== expected7 absTol 1e-8) + } } diff --git a/mllib-local/src/test/scala/org/apache/spark/ml/linalg/MatricesSuite.scala b/mllib-local/src/test/scala/org/apache/spark/ml/linalg/MatricesSuite.scala index 2796fcf2cbc22..9c0aa73938478 100644 --- a/mllib-local/src/test/scala/org/apache/spark/ml/linalg/MatricesSuite.scala +++ b/mllib-local/src/test/scala/org/apache/spark/ml/linalg/MatricesSuite.scala @@ -287,7 +287,7 @@ class MatricesSuite extends SparkMLFunSuite { val spHorz2 = Matrices.horzcat(Array(spMat1, deMat2)) val spHorz3 = Matrices.horzcat(Array(deMat1, spMat2)) val deHorz1 = Matrices.horzcat(Array(deMat1, deMat2)) - val deHorz2 = Matrices.horzcat(Array[Matrix]()) + val deHorz2 = Matrices.horzcat(Array.empty[Matrix]) assert(deHorz1.numRows === 3) assert(spHorz2.numRows === 3) @@ -341,7 +341,7 @@ class MatricesSuite extends SparkMLFunSuite { val deVert1 = Matrices.vertcat(Array(deMat1, deMat3)) val spVert2 = Matrices.vertcat(Array(spMat1, deMat3)) val spVert3 = Matrices.vertcat(Array(deMat1, spMat3)) - val deVert2 = Matrices.vertcat(Array[Matrix]()) + val deVert2 = Matrices.vertcat(Array.empty[Matrix]) assert(deVert1.numRows === 5) assert(spVert2.numRows === 5) diff --git a/mllib-local/src/test/scala/org/apache/spark/ml/util/TestingUtilsSuite.scala b/mllib-local/src/test/scala/org/apache/spark/ml/util/TestingUtilsSuite.scala index 5cbf2f04e6269..2dc0ee32d5762 100644 --- a/mllib-local/src/test/scala/org/apache/spark/ml/util/TestingUtilsSuite.scala +++ b/mllib-local/src/test/scala/org/apache/spark/ml/util/TestingUtilsSuite.scala @@ -110,9 +110,9 @@ 
class TestingUtilsSuite extends SparkMLFunSuite { assert(!(Vectors.dense(Array(3.1, 3.5)) !~= Vectors.dense(Array(3.130, 3.534)) relTol 0.01)) assert(!(Vectors.dense(Array(3.1, 3.5)) ~= Vectors.dense(Array(3.135, 3.534)) relTol 0.01)) assert(Vectors.dense(Array(3.1)) !~= Vectors.dense(Array(3.130, 3.534)) relTol 0.01) - assert(Vectors.dense(Array[Double]()) !~= Vectors.dense(Array(3.130, 3.534)) relTol 0.01) + assert(Vectors.dense(Array.empty[Double]) !~= Vectors.dense(Array(3.130, 3.534)) relTol 0.01) assert(Vectors.dense(Array(3.1)) !~== Vectors.dense(Array(3.130, 3.534)) relTol 0.01) - assert(Vectors.dense(Array[Double]()) !~== Vectors.dense(Array(3.130, 3.534)) relTol 0.01) + assert(Vectors.dense(Array.empty[Double]) !~== Vectors.dense(Array(3.130, 3.534)) relTol 0.01) // Should throw exception with message when test fails. intercept[TestFailedException]( @@ -125,7 +125,7 @@ class TestingUtilsSuite extends SparkMLFunSuite { Vectors.dense(Array(3.1)) ~== Vectors.dense(Array(3.535, 3.534)) relTol 0.01) intercept[TestFailedException]( - Vectors.dense(Array[Double]()) ~== Vectors.dense(Array(3.135)) relTol 0.01) + Vectors.dense(Array.empty[Double]) ~== Vectors.dense(Array(3.135)) relTol 0.01) // Comparing against zero should fail the test and throw exception with message // saying that the relative error is meaningless in this situation. @@ -145,7 +145,7 @@ class TestingUtilsSuite extends SparkMLFunSuite { assert(Vectors.dense(Array(3.1)) !~== Vectors.sparse(2, Array(0, 1), Array(3.130, 3.534)) relTol 0.01) - assert(Vectors.dense(Array[Double]()) !~== + assert(Vectors.dense(Array.empty[Double]) !~== Vectors.sparse(2, Array(0, 1), Array(3.130, 3.534)) relTol 0.01) } @@ -176,14 +176,14 @@ class TestingUtilsSuite extends SparkMLFunSuite { assert(!(Vectors.dense(Array(3.1)) ~= Vectors.dense(Array(3.1 + 1E-6, 3.5 + 2E-7)) absTol 1E-5)) - assert(Vectors.dense(Array[Double]()) !~= + assert(Vectors.dense(Array.empty[Double]) !~= Vectors.dense(Array(3.1 + 1E-6, 3.5 + 2E-7)) absTol 1E-5) - assert(!(Vectors.dense(Array[Double]()) ~= + assert(!(Vectors.dense(Array.empty[Double]) ~= Vectors.dense(Array(3.1 + 1E-6, 3.5 + 2E-7)) absTol 1E-5)) - assert(Vectors.dense(Array[Double]()) ~= - Vectors.dense(Array[Double]()) absTol 1E-5) + assert(Vectors.dense(Array.empty[Double]) ~= + Vectors.dense(Array.empty[Double]) absTol 1E-5) // Should throw exception with message when test fails. 
intercept[TestFailedException](Vectors.dense(Array(3.1, 3.5, 0.0)) !~== @@ -195,7 +195,7 @@ class TestingUtilsSuite extends SparkMLFunSuite { intercept[TestFailedException](Vectors.dense(Array(3.1)) ~== Vectors.dense(Array(3.1 + 1E-5, 3.5 + 2E-7)) absTol 1E-6) - intercept[TestFailedException](Vectors.dense(Array[Double]()) ~== + intercept[TestFailedException](Vectors.dense(Array.empty[Double]) ~== Vectors.dense(Array(3.1 + 1E-5, 3.5 + 2E-7)) absTol 1E-6) // Comparisons of two sparse vectors @@ -214,7 +214,7 @@ class TestingUtilsSuite extends SparkMLFunSuite { assert(Vectors.sparse(3, Array(0, 2), Array(3.1 + 1E-6, 2.4)) !~== Vectors.sparse(1, Array(0), Array(3.1)) absTol 1E-3) - assert(Vectors.sparse(0, Array[Int](), Array[Double]()) !~== + assert(Vectors.sparse(0, Array.empty[Int], Array.empty[Double]) !~== Vectors.sparse(1, Array(0), Array(3.1)) absTol 1E-3) // Comparisons of a dense vector and a sparse vector @@ -230,14 +230,14 @@ class TestingUtilsSuite extends SparkMLFunSuite { assert(Vectors.sparse(3, Array(0, 2), Array(3.1, 2.4)) !~== Vectors.dense(Array(3.1)) absTol 1E-6) - assert(Vectors.dense(Array[Double]()) !~== + assert(Vectors.dense(Array.empty[Double]) !~== Vectors.sparse(3, Array(0, 2), Array(0, 2.4)) absTol 1E-6) assert(Vectors.sparse(1, Array(0), Array(3.1)) !~== Vectors.dense(Array(3.1, 3.2)) absTol 1E-6) assert(Vectors.dense(Array(3.1)) !~== - Vectors.sparse(0, Array[Int](), Array[Double]()) absTol 1E-6) + Vectors.sparse(0, Array.empty[Int], Array.empty[Double]) absTol 1E-6) } test("Comparing Matrices using absolute error.") { diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index ba70293273f94..8bffe0cda0327 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -137,9 +137,17 @@ class GBTClassifier @Since("1.4.0") ( } val numFeatures = oldDataset.first().features.size val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Classification) + + val instr = Instrumentation.create(this, oldDataset) + instr.logParams(params: _*) + instr.logNumFeatures(numFeatures) + instr.logNumClasses(2) + val (baseLearners, learnerWeights) = GradientBoostedTrees.run(oldDataset, boostingStrategy, $(seed)) - new GBTClassificationModel(uid, baseLearners, learnerWeights, numFeatures) + val m = new GBTClassificationModel(uid, baseLearners, learnerWeights, numFeatures) + instr.logSuccess(m) + m } @Since("1.4.1") diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 862a468745fbd..8fdaae04c42ec 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -622,7 +622,7 @@ class LogisticRegression @Since("1.2.0") ( rawCoefficients(coefIndex) } } else { - Array[Double]() + Array.empty[Double] } val interceptVector = if (interceptsArray.nonEmpty && isMultinomial) { // The intercepts are never regularized, so we always center the mean. 
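GBTClassifier above (and BisectingKMeans and GaussianMixture further down) all adopt the same training-instrumentation pattern. A condensed sketch of that pattern as it appears in this patch; Instrumentation is a Spark-internal helper, and the trainImpl call here is a hypothetical stand-in for each estimator's real training code:

    // Inside an estimator's train()/fit() body (sketch of the pattern only):
    val instr = Instrumentation.create(this, oldDataset)  // ties the log lines to this run
    instr.logParams(params: _*)                           // record the resolved params
    instr.logNumFeatures(numFeatures)                     // plus whatever dataset stats apply
    val model = trainImpl(oldDataset)                     // hypothetical: the real training call
    instr.logSuccess(model)                               // mark the run as finished
    model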
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala index a97bd0fb16fd7..2718dd93dcb5a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala @@ -19,6 +19,7 @@ package org.apache.spark.ml.clustering import org.apache.hadoop.fs.Path +import org.apache.spark.SparkException import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.linalg.{Vector, VectorUDT} @@ -127,6 +128,29 @@ class BisectingKMeansModel private[ml] ( @Since("2.0.0") override def write: MLWriter = new BisectingKMeansModel.BisectingKMeansModelWriter(this) + + private var trainingSummary: Option[BisectingKMeansSummary] = None + + private[clustering] def setSummary(summary: BisectingKMeansSummary): this.type = { + this.trainingSummary = Some(summary) + this + } + + /** + * Return true if there exists summary of model. + */ + @Since("2.1.0") + def hasSummary: Boolean = trainingSummary.nonEmpty + + /** + * Gets summary of model on training set. An exception is + * thrown if `trainingSummary == None`. + */ + @Since("2.1.0") + def summary: BisectingKMeansSummary = trainingSummary.getOrElse { + throw new SparkException( + s"No training summary available for the ${this.getClass.getSimpleName}") + } } object BisectingKMeansModel extends MLReadable[BisectingKMeansModel] { @@ -228,14 +252,21 @@ class BisectingKMeans @Since("2.0.0") ( case Row(point: Vector) => OldVectors.fromML(point) } + val instr = Instrumentation.create(this, rdd) + instr.logParams(featuresCol, predictionCol, k, maxIter, seed, minDivisibleClusterSize) + val bkm = new MLlibBisectingKMeans() .setK($(k)) .setMaxIterations($(maxIter)) .setMinDivisibleClusterSize($(minDivisibleClusterSize)) .setSeed($(seed)) val parentModel = bkm.run(rdd) - val model = new BisectingKMeansModel(uid, parentModel) - copyValues(model.setParent(this)) + val model = copyValues(new BisectingKMeansModel(uid, parentModel).setParent(this)) + val summary = new BisectingKMeansSummary( + model.transform(dataset), $(predictionCol), $(featuresCol), $(k)) + model.setSummary(summary) + instr.logSuccess(model) + model } @Since("2.0.0") @@ -251,3 +282,21 @@ object BisectingKMeans extends DefaultParamsReadable[BisectingKMeans] { @Since("2.0.0") override def load(path: String): BisectingKMeans = super.load(path) } + + +/** + * :: Experimental :: + * Summary of BisectingKMeans. + * + * @param predictions [[DataFrame]] produced by [[BisectingKMeansModel.transform()]]. + * @param predictionCol Name for column of predicted clusters in `predictions`. + * @param featuresCol Name for column of features in `predictions`. + * @param k Number of clusters. 
+ */ +@Since("2.1.0") +@Experimental +class BisectingKMeansSummary private[clustering] ( + predictions: DataFrame, + predictionCol: String, + featuresCol: String, + k: Int) extends ClusteringSummary(predictions, predictionCol, featuresCol, k) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/ClusteringSummary.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/ClusteringSummary.scala new file mode 100644 index 0000000000000..8b5f525194f28 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/ClusteringSummary.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.clustering + +import org.apache.spark.annotation.Experimental +import org.apache.spark.sql.{DataFrame, Row} + +/** + * :: Experimental :: + * Summary of clustering algorithms. + * + * @param predictions [[DataFrame]] produced by model.transform(). + * @param predictionCol Name for column of predicted clusters in `predictions`. + * @param featuresCol Name for column of features in `predictions`. + * @param k Number of clusters. + */ +@Experimental +class ClusteringSummary private[clustering] ( + @transient val predictions: DataFrame, + val predictionCol: String, + val featuresCol: String, + val k: Int) extends Serializable { + + /** + * Cluster centers of the transformed data. + */ + @transient lazy val cluster: DataFrame = predictions.select(predictionCol) + + /** + * Size of (number of data points in) each cluster. 
+ */ + lazy val clusterSizes: Array[Long] = { + val sizes = Array.fill[Long](k)(0) + cluster.groupBy(predictionCol).count().select(predictionCol, "count").collect().foreach { + case Row(cluster: Int, count: Long) => sizes(cluster) = count + } + sizes + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala index 69f060ad7711e..8fac63fefbb55 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala @@ -323,6 +323,9 @@ class GaussianMixture @Since("2.0.0") ( case Row(point: Vector) => OldVectors.fromML(point) } + val instr = Instrumentation.create(this, rdd) + instr.logParams(featuresCol, predictionCol, probabilityCol, k, maxIter, seed, tol) + val algo = new MLlibGM() .setK($(k)) .setMaxIterations($(maxIter)) @@ -337,6 +340,9 @@ class GaussianMixture @Since("2.0.0") ( val summary = new GaussianMixtureSummary(model.transform(dataset), $(predictionCol), $(probabilityCol), $(featuresCol), $(k)) model.setSummary(summary) + instr.logNumFeatures(model.gaussians.head.mean.size) + instr.logSuccess(model) + model } @Since("2.0.0") @@ -356,42 +362,25 @@ object GaussianMixture extends DefaultParamsReadable[GaussianMixture] { * :: Experimental :: * Summary of GaussianMixture. * - * @param predictions [[DataFrame]] produced by [[GaussianMixtureModel.transform()]] - * @param predictionCol Name for column of predicted clusters in `predictions` - * @param probabilityCol Name for column of predicted probability of each cluster in `predictions` - * @param featuresCol Name for column of features in `predictions` - * @param k Number of clusters + * @param predictions [[DataFrame]] produced by [[GaussianMixtureModel.transform()]]. + * @param predictionCol Name for column of predicted clusters in `predictions`. + * @param probabilityCol Name for column of predicted probability of each cluster + * in `predictions`. + * @param featuresCol Name for column of features in `predictions`. + * @param k Number of clusters. */ @Since("2.0.0") @Experimental class GaussianMixtureSummary private[clustering] ( - @Since("2.0.0") @transient val predictions: DataFrame, - @Since("2.0.0") val predictionCol: String, + predictions: DataFrame, + predictionCol: String, @Since("2.0.0") val probabilityCol: String, - @Since("2.0.0") val featuresCol: String, - @Since("2.0.0") val k: Int) extends Serializable { - - /** - * Cluster centers of the transformed data. - */ - @Since("2.0.0") - @transient lazy val cluster: DataFrame = predictions.select(predictionCol) + featuresCol: String, + k: Int) extends ClusteringSummary(predictions, predictionCol, featuresCol, k) { /** * Probability of each cluster. */ @Since("2.0.0") @transient lazy val probability: DataFrame = predictions.select(probabilityCol) - - /** - * Size of (number of data points in) each cluster. 
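With the summary plumbing added above, BisectingKMeansModel now mirrors KMeansModel and GaussianMixtureModel: hasSummary reports whether the model came out of a training run, and summary exposes the shared ClusteringSummary fields. A hedged usage sketch, where dataset stands for any DataFrame with a "features" vector column:

    import org.apache.spark.ml.clustering.BisectingKMeans

    val model = new BisectingKMeans().setK(3).fit(dataset)
    if (model.hasSummary) {
      val summary = model.summary                   // throws SparkException if absent
      println(summary.k)                            // number of clusters requested
      println(summary.clusterSizes.mkString(", "))  // data points per cluster
      summary.cluster.show()                        // just the prediction column
    }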
- */ - @Since("2.0.0") - lazy val clusterSizes: Array[Long] = { - val sizes = Array.fill[Long](k)(0) - cluster.groupBy(predictionCol).count().select(predictionCol, "count").collect().foreach { - case Row(cluster: Int, count: Long) => sizes(cluster) = count - } - sizes - } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index b04e82838e714..85bb8c93b3fa9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -41,7 +41,9 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe with HasSeed with HasPredictionCol with HasTol { /** - * The number of clusters to create (k). Must be > 1. Default: 2. + * The number of clusters to create (k). Must be > 1. Note that it is possible for fewer than + * k clusters to be returned, for example, if there are fewer than k distinct points to cluster. + * Default: 2. * @group param */ @Since("1.5.0") @@ -324,9 +326,9 @@ class KMeans @Since("1.5.0") ( val model = copyValues(new KMeansModel(uid, parentModel).setParent(this)) val summary = new KMeansSummary( model.transform(dataset), $(predictionCol), $(featuresCol), $(k)) - val m = model.setSummary(summary) - instr.logSuccess(m) - m + model.setSummary(summary) + instr.logSuccess(model) + model } @Since("1.5.0") @@ -346,35 +348,15 @@ object KMeans extends DefaultParamsReadable[KMeans] { * :: Experimental :: * Summary of KMeans. * - * @param predictions [[DataFrame]] produced by [[KMeansModel.transform()]] - * @param predictionCol Name for column of predicted clusters in `predictions` - * @param featuresCol Name for column of features in `predictions` - * @param k Number of clusters + * @param predictions [[DataFrame]] produced by [[KMeansModel.transform()]]. + * @param predictionCol Name for column of predicted clusters in `predictions`. + * @param featuresCol Name for column of features in `predictions`. + * @param k Number of clusters. */ @Since("2.0.0") @Experimental class KMeansSummary private[clustering] ( - @Since("2.0.0") @transient val predictions: DataFrame, - @Since("2.0.0") val predictionCol: String, - @Since("2.0.0") val featuresCol: String, - @Since("2.0.0") val k: Int) extends Serializable { - - /** - * Cluster centers of the transformed data. - */ - @Since("2.0.0") - @transient lazy val cluster: DataFrame = predictions.select(predictionCol) - - /** - * Size of (number of data points in) each cluster. 
- */ - @Since("2.0.0") - lazy val clusterSizes: Array[Long] = { - val sizes = Array.fill[Long](k)(0) - cluster.groupBy(predictionCol).count().select(predictionCol, "count").collect().foreach { - case Row(cluster: Int, count: Long) => sizes(cluster) = count - } - sizes - } - -} + predictions: DataFrame, + predictionCol: String, + featuresCol: String, + k: Int) extends ClusteringSummary(predictions, predictionCol, featuresCol, k) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala index ec0ea05f9e1b1..1143f0f565ebd 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala @@ -27,6 +27,7 @@ import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} import org.apache.spark.ml.util._ import org.apache.spark.sql._ +import org.apache.spark.sql.expressions.UserDefinedFunction import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{DoubleType, StructField, StructType} @@ -46,6 +47,9 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String * also includes y. Splits should be of length >= 3 and strictly increasing. * Values at -inf, inf must be explicitly provided to cover all Double values; * otherwise, values outside the splits specified will be treated as errors. + * + * See also [[handleInvalid]], which can optionally create an additional bucket for NaN values. + * * @group param */ @Since("1.4.0") @@ -73,15 +77,47 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String @Since("1.4.0") def setOutputCol(value: String): this.type = set(outputCol, value) + /** + * Param for how to handle invalid entries. Options are skip (filter out rows with + * invalid values), error (throw an error), or keep (keep invalid values in a special additional + * bucket). + * Default: "error" + * @group param + */ + @Since("2.1.0") + val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "how to handle" + + "invalid entries. 
Options are skip (filter out rows with invalid values), " + + "error (throw an error), or keep (keep invalid values in a special additional bucket).", + ParamValidators.inArray(Bucketizer.supportedHandleInvalid)) + + /** @group getParam */ + @Since("2.1.0") + def getHandleInvalid: String = $(handleInvalid) + + /** @group setParam */ + @Since("2.1.0") + def setHandleInvalid(value: String): this.type = set(handleInvalid, value) + setDefault(handleInvalid, Bucketizer.ERROR_INVALID) + @Since("2.0.0") override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema) - val bucketizer = udf { feature: Double => - Bucketizer.binarySearchForBuckets($(splits), feature) + val (filteredDataset, keepInvalid) = { + if (getHandleInvalid == Bucketizer.SKIP_INVALID) { + // "skip" NaN option is set, will filter out NaN values in the dataset + (dataset.na.drop().toDF(), false) + } else { + (dataset.toDF(), getHandleInvalid == Bucketizer.KEEP_INVALID) + } + } + + val bucketizer: UserDefinedFunction = udf { (feature: Double) => + Bucketizer.binarySearchForBuckets($(splits), feature, keepInvalid) } - val newCol = bucketizer(dataset($(inputCol))) - val newField = prepOutputField(dataset.schema) - dataset.withColumn($(outputCol), newCol, newField.metadata) + + val newCol = bucketizer(filteredDataset($(inputCol))) + val newField = prepOutputField(filteredDataset.schema) + filteredDataset.withColumn($(outputCol), newCol, newField.metadata) } private def prepOutputField(schema: StructType): StructField = { @@ -106,6 +142,12 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String @Since("1.6.0") object Bucketizer extends DefaultParamsReadable[Bucketizer] { + private[feature] val SKIP_INVALID: String = "skip" + private[feature] val ERROR_INVALID: String = "error" + private[feature] val KEEP_INVALID: String = "keep" + private[feature] val supportedHandleInvalid: Array[String] = + Array(SKIP_INVALID, ERROR_INVALID, KEEP_INVALID) + /** * We require splits to be of length >= 3 and to be in strictly increasing order. * No NaN split should be accepted. @@ -126,11 +168,26 @@ object Bucketizer extends DefaultParamsReadable[Bucketizer] { /** * Binary searching in several buckets to place each data point. + * @param splits array of split points + * @param feature data point + * @param keepInvalid NaN flag. + * Set "true" to make an extra bucket for NaN values; + * Set "false" to report an error for NaN values + * @return bucket for each data point * @throws SparkException if a feature is < splits.head or > splits.last */ - private[feature] def binarySearchForBuckets(splits: Array[Double], feature: Double): Double = { + + private[feature] def binarySearchForBuckets( + splits: Array[Double], + feature: Double, + keepInvalid: Boolean): Double = { if (feature.isNaN) { - splits.length - 1 + if (keepInvalid) { + splits.length - 1 + } else { + throw new SparkException("Bucketizer encountered NaN value. To handle or skip NaNs," + + " try setting Bucketizer.handleInvalid.") + } } else if (feature == splits.last) { splits.length - 2 } else { diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala new file mode 100644 index 0000000000000..333a8c364a884 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala @@ -0,0 +1,313 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import scala.util.Random + +import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.ml.linalg.{Vector, VectorUDT} +import org.apache.spark.ml.param.{IntParam, ParamValidators} +import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} +import org.apache.spark.ml.util._ +import org.apache.spark.sql._ +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types._ + +/** + * Params for [[LSH]]. + */ +private[ml] trait LSHParams extends HasInputCol with HasOutputCol { + /** + * Param for the dimension of LSH OR-amplification. + * + * In this implementation, we use LSH OR-amplification to reduce the false negative rate. The + * higher the dimension is, the lower the false negative rate. + * @group param + */ + final val outputDim: IntParam = new IntParam(this, "outputDim", "output dimension, where" + + "increasing dimensionality lowers the false negative rate, and decreasing dimensionality" + + " improves the running performance", ParamValidators.gt(0)) + + /** @group getParam */ + final def getOutputDim: Int = $(outputDim) + + setDefault(outputDim -> 1) + + /** + * Transform the Schema for LSH + * @param schema The schema of the input dataset without [[outputCol]] + * @return A derived schema with [[outputCol]] added + */ + protected[this] final def validateAndTransformSchema(schema: StructType): StructType = { + SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT) + } +} + +/** + * Model produced by [[LSH]]. + */ +private[ml] abstract class LSHModel[T <: LSHModel[T]] + extends Model[T] with LSHParams with MLWritable { + self: T => + + /** + * The hash function of LSH, mapping a predefined KeyType to a Vector + * @return The mapping of LSH function. + */ + protected[ml] val hashFunction: Vector => Vector + + /** + * Calculate the distance between two different keys using the distance metric corresponding + * to the hashFunction + * @param x One input vector in the metric space + * @param y One input vector in the metric space + * @return The distance between x and y + */ + protected[ml] def keyDistance(x: Vector, y: Vector): Double + + /** + * Calculate the distance between two different hash Vectors. 
+ * + * @param x One of the hash vectors + * @param y Another hash vector + * @return The distance between hash vectors x and y + */ + protected[ml] def hashDistance(x: Vector, y: Vector): Double + + override def transform(dataset: Dataset[_]): DataFrame = { + transformSchema(dataset.schema, logging = true) + val transformUDF = udf(hashFunction, new VectorUDT) + dataset.withColumn($(outputCol), transformUDF(dataset($(inputCol)))) + } + + override def transformSchema(schema: StructType): StructType = { + validateAndTransformSchema(schema) + } + + /** + * Given a large dataset and an item, approximately find at most k items which have the closest + * distance to the item. If the [[outputCol]] is missing, the method will transform the data; if + * the [[outputCol]] exists, it will use the [[outputCol]]. This allows caching of the + * transformed data when necessary. + * + * This method implements two ways of fetching k nearest neighbors: + * - Single Probing: Fast, return at most k elements (Probing only one bucket) + * - Multiple Probing: Slow, return exactly k elements (Probing multiple buckets close to the key) + * + * @param dataset the dataset to search for nearest neighbors of the key + * @param key Feature vector representing the item to search for + * @param numNearestNeighbors The maximum number of nearest neighbors + * @param singleProbing True for using Single Probing; false for multiple probing + * @param distCol Output column for storing the distance between each result row and the key + * @return A dataset containing at most k items closest to the key. A distCol is added to show + * the distance between each row and the key. + */ + def approxNearestNeighbors( + dataset: Dataset[_], + key: Vector, + numNearestNeighbors: Int, + singleProbing: Boolean, + distCol: String): Dataset[_] = { + require(numNearestNeighbors > 0, "The number of nearest neighbors cannot be less than 1") + // Get the hash value of the key + val keyHash = hashFunction(key) + val modelDataset: DataFrame = if (!dataset.columns.contains($(outputCol))) { + transform(dataset) + } else { + dataset.toDF() + } + + // In the original dataset, find the hash value that is closest to the key + val hashDistUDF = udf((x: Vector) => hashDistance(x, keyHash), DataTypes.DoubleType) + val hashDistCol = hashDistUDF(col($(outputCol))) + + val modelSubset = if (singleProbing) { + modelDataset.filter(hashDistCol === 0.0) + } else { + // Compute threshold to get exactly k elements. + val modelDatasetSortedByHash = modelDataset.sort(hashDistCol).limit(numNearestNeighbors) + val thresholdDataset = modelDatasetSortedByHash.select(max(hashDistCol)) + val hashThreshold = thresholdDataset.take(1).head.getDouble(0) + + // Filter the dataset where the hash value is less than the threshold. + modelDataset.filter(hashDistCol <= hashThreshold) + } + + // Get the top k nearest neighbors by their distance to the key + val keyDistUDF = udf((x: Vector) => keyDistance(x, key), DataTypes.DoubleType) + val modelSubsetWithDistCol = modelSubset.withColumn(distCol, keyDistUDF(col($(inputCol)))) + modelSubsetWithDistCol.sort(distCol).limit(numNearestNeighbors) + } + + /** + * Overloaded method for approxNearestNeighbors. Use Single Probing as the default way to search + * nearest neighbors and "distCol" as the default distCol.
+ */ + def approxNearestNeighbors( + dataset: Dataset[_], + key: Vector, + numNearestNeighbors: Int): Dataset[_] = { + approxNearestNeighbors(dataset, key, numNearestNeighbors, true, "distCol") + } + + /** + * Preprocess step for approximate similarity join. Transform and explode the [[outputCol]] to + * two explodeCols: entry and value. "entry" is the index in hash vector, and "value" is the + * value of corresponding value of the index in the vector. + * + * @param dataset The dataset to transform and explode. + * @param explodeCols The alias for the exploded columns, must be a seq of two strings. + * @return A dataset containing idCol, inputCol and explodeCols + */ + private[this] def processDataset( + dataset: Dataset[_], + inputName: String, + explodeCols: Seq[String]): Dataset[_] = { + require(explodeCols.size == 2, "explodeCols must be two strings.") + val vectorToMap = udf((x: Vector) => x.asBreeze.iterator.toMap, + MapType(DataTypes.IntegerType, DataTypes.DoubleType)) + val modelDataset: DataFrame = if (!dataset.columns.contains($(outputCol))) { + transform(dataset) + } else { + dataset.toDF() + } + modelDataset.select( + struct(col("*")).as(inputName), + explode(vectorToMap(col($(outputCol)))).as(explodeCols)) + } + + /** + * Recreate a column using the same column name but different attribute id. Used in approximate + * similarity join. + * @param dataset The dataset where a column need to recreate + * @param colName The name of the column to recreate + * @param tmpColName A temporary column name which does not conflict with existing columns + * @return + */ + private[this] def recreateCol( + dataset: Dataset[_], + colName: String, + tmpColName: String): Dataset[_] = { + dataset + .withColumnRenamed(colName, tmpColName) + .withColumn(colName, col(tmpColName)) + .drop(tmpColName) + } + + /** + * Join two dataset to approximately find all pairs of rows whose distance are smaller than + * the threshold. If the [[outputCol]] is missing, the method will transform the data; if the + * [[outputCol]] exists, it will use the [[outputCol]]. This allows caching of the transformed + * data when necessary. + * + * @param datasetA One of the datasets to join + * @param datasetB Another dataset to join + * @param threshold The threshold for the distance of row pairs + * @param distCol Output column for storing the distance between each result row and the key + * @return A joined dataset containing pairs of rows. The original rows are in columns + * "datasetA" and "datasetB", and a distCol is added to show the distance of each pair + */ + def approxSimilarityJoin( + datasetA: Dataset[_], + datasetB: Dataset[_], + threshold: Double, + distCol: String): Dataset[_] = { + + val leftColName = "datasetA" + val rightColName = "datasetB" + val explodeCols = Seq("entry", "hashValue") + val explodedA = processDataset(datasetA, leftColName, explodeCols) + + // If this is a self join, we need to recreate the inputCol of datasetB to avoid ambiguity. + // TODO: Remove recreateCol logic once SPARK-17154 is resolved. + val explodedB = if (datasetA != datasetB) { + processDataset(datasetB, rightColName, explodeCols) + } else { + val recreatedB = recreateCol(datasetB, $(inputCol), s"${$(inputCol)}#${Random.nextString(5)}") + processDataset(recreatedB, rightColName, explodeCols) + } + + // Do a hash join on where the exploded hash values are equal. + val joinedDataset = explodedA.join(explodedB, explodeCols) + .drop(explodeCols: _*).distinct() + + // Add a new column to store the distance of the two rows. 
+ val distUDF = udf((x: Vector, y: Vector) => keyDistance(x, y), DataTypes.DoubleType) + val joinedDatasetWithDist = joinedDataset.select(col("*"), + distUDF(col(s"$leftColName.${$(inputCol)}"), col(s"$rightColName.${$(inputCol)}")).as(distCol) + ) + + // Filter the joined datasets where the distance are smaller than the threshold. + joinedDatasetWithDist.filter(col(distCol) < threshold) + } + + /** + * Overloaded method for approxSimilarityJoin. Use "distCol" as default distCol. + */ + def approxSimilarityJoin( + datasetA: Dataset[_], + datasetB: Dataset[_], + threshold: Double): Dataset[_] = { + approxSimilarityJoin(datasetA, datasetB, threshold, "distCol") + } +} + +/** + * Locality Sensitive Hashing for different metrics space. Support basic transformation with a new + * hash column, approximate nearest neighbor search with a dataset and a key, and approximate + * similarity join of two datasets. + * + * This LSH class implements OR-amplification: more than 1 hash functions can be chosen, and each + * input vector are hashed by all hash functions. Two input vectors are defined to be in the same + * bucket as long as ANY one of the hash value matches. + * + * References: + * (1) Gionis, Aristides, Piotr Indyk, and Rajeev Motwani. "Similarity search in high dimensions + * via hashing." VLDB 7 Sep. 1999: 518-529. + * (2) Wang, Jingdong et al. "Hashing for similarity search: A survey." arXiv preprint + * arXiv:1408.2927 (2014). + */ +private[ml] abstract class LSH[T <: LSHModel[T]] + extends Estimator[T] with LSHParams with DefaultParamsWritable { + self: Estimator[T] => + + /** @group setParam */ + def setInputCol(value: String): this.type = set(inputCol, value) + + /** @group setParam */ + def setOutputCol(value: String): this.type = set(outputCol, value) + + /** @group setParam */ + def setOutputDim(value: Int): this.type = set(outputDim, value) + + /** + * Validate and create a new instance of concrete LSHModel. Because different LSHModel may have + * different initial setting, developer needs to define how their LSHModel is created instead of + * using reflection in this abstract class. + * @param inputDim The dimension of the input dataset + * @return A new LSHModel instance without any params + */ + protected[this] def createRawLSHModel(inputDim: Int): T + + override def fit(dataset: Dataset[_]): T = { + transformSchema(dataset.schema, logging = true) + val inputDim = dataset.select(col($(inputCol))).head().get(0).asInstanceOf[Vector].size + val model = createRawLSHModel(inputDim).setParent(this) + copyValues(model) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinHash.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinHash.scala new file mode 100644 index 0000000000000..d9d0f32254e24 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinHash.scala @@ -0,0 +1,194 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import scala.util.Random + +import org.apache.hadoop.fs.Path + +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT} +import org.apache.spark.ml.param.ParamMap +import org.apache.spark.ml.param.shared.HasSeed +import org.apache.spark.ml.util._ +import org.apache.spark.sql.types.StructType + +/** + * :: Experimental :: + * + * Model produced by [[MinHash]], where multiple hash functions are stored. Each hash function is + * a perfect hash function: + * `h_i(x) = (x * k_i mod prime) mod numEntries` + * where `k_i` is the i-th coefficient, and both `x` and `k_i` are from `Z_prime^*` + * + * Reference: + * [[https://en.wikipedia.org/wiki/Perfect_hash_function Wikipedia on Perfect Hash Function]] + * + * @param numEntries The number of entries of the hash functions. + * @param randCoefficients An array of random coefficients, each used by one hash function. + */ +@Experimental +@Since("2.1.0") +class MinHashModel private[ml] ( + override val uid: String, + @Since("2.1.0") val numEntries: Int, + @Since("2.1.0") val randCoefficients: Array[Int]) + extends LSHModel[MinHashModel] { + + @Since("2.1.0") + override protected[ml] val hashFunction: Vector => Vector = { + elems: Vector => + require(elems.numNonzeros > 0, "Must have at least 1 non zero entry.") + val elemsList = elems.toSparse.indices.toList + val hashValues = randCoefficients.map({ randCoefficient: Int => + elemsList.map({elem: Int => + (1 + elem) * randCoefficient.toLong % MinHash.prime % numEntries + }).min.toDouble + }) + Vectors.dense(hashValues) + } + + @Since("2.1.0") + override protected[ml] def keyDistance(x: Vector, y: Vector): Double = { + val xSet = x.toSparse.indices.toSet + val ySet = y.toSparse.indices.toSet + val intersectionSize = xSet.intersect(ySet).size.toDouble + val unionSize = xSet.size + ySet.size - intersectionSize + assert(unionSize > 0, "The union of two input sets must have at least 1 elements") + 1 - intersectionSize / unionSize + } + + @Since("2.1.0") + override protected[ml] def hashDistance(x: Vector, y: Vector): Double = { + // Since it's generated by hashing, it will be a pair of dense vectors. + x.toDense.values.zip(y.toDense.values).map(pair => math.abs(pair._1 - pair._2)).min + } + + @Since("2.1.0") + override def copy(extra: ParamMap): this.type = defaultCopy(extra) + + @Since("2.1.0") + override def write: MLWriter = new MinHashModel.MinHashModelWriter(this) +} + +/** + * :: Experimental :: + * + * LSH class for Jaccard distance. + * + * The input can be dense or sparse vectors, but it is more efficient if it is sparse. For example, + * `Vectors.sparse(10, Array[(2, 1.0), (3, 1.0), (5, 1.0)])` + * means there are 10 elements in the space. This set contains elem 2, elem 3 and elem 5. + * Also, any input vector must have at least 1 non-zero indices, and all non-zero values are treated + * as binary "1" values. 
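As a usage sketch of the LSH API introduced above, exercised through this MinHash implementation; the SparkSession named spark, the ids, column names, seed and threshold are illustrative assumptions:

import org.apache.spark.ml.feature.MinHash
import org.apache.spark.ml.linalg.Vectors

import spark.implicits._
// Sparse binary vectors: the non-zero indices mark set membership.
val dfA = Seq(
  (0, Vectors.sparse(6, Seq((0, 1.0), (1, 1.0), (2, 1.0)))),
  (1, Vectors.sparse(6, Seq((2, 1.0), (3, 1.0), (4, 1.0))))
).toDF("id", "keys")
val dfB = Seq(
  (2, Vectors.sparse(6, Seq((0, 1.0), (2, 1.0), (4, 1.0))))
).toDF("id", "keys")

val mh = new MinHash()
  .setInputCol("keys")
  .setOutputCol("hashes")
  .setOutputDim(3)   // OR-amplification with 3 hash functions
  .setSeed(12345L)
val model = mh.fit(dfA)

// At most 2 approximate nearest neighbours of a query key (single-probing overload).
model.approxNearestNeighbors(dfA, Vectors.sparse(6, Seq((0, 1.0), (2, 1.0))), 2).show()

// All pairs across dfA and dfB whose Jaccard distance is below 0.9.
model.approxSimilarityJoin(dfA, dfB, 0.9).show()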
+ * + * References: + * [[https://en.wikipedia.org/wiki/MinHash Wikipedia on MinHash]] + */ +@Experimental +@Since("2.1.0") +class MinHash(override val uid: String) extends LSH[MinHashModel] with HasSeed { + + + @Since("2.1.0") + override def setInputCol(value: String): this.type = super.setInputCol(value) + + @Since("2.1.0") + override def setOutputCol(value: String): this.type = super.setOutputCol(value) + + @Since("2.1.0") + override def setOutputDim(value: Int): this.type = super.setOutputDim(value) + + @Since("2.1.0") + def this() = { + this(Identifiable.randomUID("min hash")) + } + + /** @group setParam */ + @Since("2.1.0") + def setSeed(value: Long): this.type = set(seed, value) + + @Since("2.1.0") + override protected[ml] def createRawLSHModel(inputDim: Int): MinHashModel = { + require(inputDim <= MinHash.prime / 2, + s"The input vector dimension $inputDim exceeds the threshold ${MinHash.prime / 2}.") + val rand = new Random($(seed)) + val numEntry = inputDim * 2 + val randCoofs: Array[Int] = Array.fill($(outputDim))(1 + rand.nextInt(MinHash.prime - 1)) + new MinHashModel(uid, numEntry, randCoofs) + } + + @Since("2.1.0") + override def transformSchema(schema: StructType): StructType = { + SchemaUtils.checkColumnType(schema, $(inputCol), new VectorUDT) + validateAndTransformSchema(schema) + } + + @Since("2.1.0") + override def copy(extra: ParamMap): this.type = defaultCopy(extra) +} + +@Since("2.1.0") +object MinHash extends DefaultParamsReadable[MinHash] { + // A large prime smaller than sqrt(2^63 − 1) + private[ml] val prime = 2038074743 + + @Since("2.1.0") + override def load(path: String): MinHash = super.load(path) +} + +@Since("2.1.0") +object MinHashModel extends MLReadable[MinHashModel] { + + @Since("2.1.0") + override def read: MLReader[MinHashModel] = new MinHashModelReader + + @Since("2.1.0") + override def load(path: String): MinHashModel = super.load(path) + + private[MinHashModel] class MinHashModelWriter(instance: MinHashModel) extends MLWriter { + + private case class Data(numEntries: Int, randCoefficients: Array[Int]) + + override protected def saveImpl(path: String): Unit = { + DefaultParamsWriter.saveMetadata(instance, path, sc) + val data = Data(instance.numEntries, instance.randCoefficients) + val dataPath = new Path(path, "data").toString + sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + } + } + + private class MinHashModelReader extends MLReader[MinHashModel] { + + /** Checked against metadata when loading model */ + private val className = classOf[MinHashModel].getName + + override def load(path: String): MinHashModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + + val dataPath = new Path(path, "data").toString + val data = sparkSession.read.parquet(dataPath).select("numEntries", "randCoefficients").head() + val numEntries = data.getAs[Int](0) + val randCoefficients = data.getAs[Seq[Int]](1).toArray + val model = new MinHashModel(metadata.uid, numEntries, randCoefficients) + + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala index 05e034d90f6a3..b9e01dde70d85 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala @@ -36,6 +36,9 @@ private[feature] trait QuantileDiscretizerBase extends Params /** * 
Number of buckets (quantiles, or categories) into which data points are grouped. Must * be >= 2. + * + * See also [[handleInvalid]], which can optionally create an additional bucket for NaN values. + * * default: 2 * @group param */ @@ -61,17 +64,41 @@ private[feature] trait QuantileDiscretizerBase extends Params /** @group getParam */ def getRelativeError: Double = getOrDefault(relativeError) + + /** + * Param for how to handle invalid entries. Options are skip (filter out rows with + * invalid values), error (throw an error), or keep (keep invalid values in a special additional + * bucket). + * Default: "error" + * @group param + */ + @Since("2.1.0") + val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "how to handle " + + "invalid entries. Options are skip (filter out rows with invalid values), " + + "error (throw an error), or keep (keep invalid values in a special additional bucket).", + ParamValidators.inArray(Bucketizer.supportedHandleInvalid)) + setDefault(handleInvalid, Bucketizer.ERROR_INVALID) + + /** @group getParam */ + @Since("2.1.0") + def getHandleInvalid: String = $(handleInvalid) + } /** * `QuantileDiscretizer` takes a column with continuous features and outputs a column with binned * categorical features. The number of bins can be set using the `numBuckets` parameter. It is - * possible that the number of buckets used will be less than this value, for example, if there - * are too few distinct values of the input to create enough distinct quantiles. Note also that - * NaN values are handled specially and placed into their own bucket. For example, if 4 buckets - * are used, then non-NaN data will be put into buckets(0-3), but NaNs will be counted in a special - * bucket(4). - * The bin ranges are chosen using an approximate algorithm (see the documentation for + * possible that the number of buckets used will be smaller than this value, for example, if there + * are too few distinct values of the input to create enough distinct quantiles. + * + * NaN handling: + * QuantileDiscretizer will raise an error when it finds NaN values in the dataset, but the user can + * also choose to either keep or remove NaN values within the dataset by setting `handleInvalid`. + * If the user chooses to keep NaN values, they will be handled specially and placed into their own + * bucket; for example, if 4 buckets are used, then non-NaN data will be put into buckets[0-3], + * but NaNs will be counted in a special bucket[4]. + * + * Algorithm: The bin ranges are chosen using an approximate algorithm (see the documentation for * [[org.apache.spark.sql.DataFrameStatFunctions.approxQuantile approxQuantile]] * for a detailed description). The precision of the approximation can be controlled with the * `relativeError` parameter. The lower and upper bin bounds will be `-Infinity` and `+Infinity`, @@ -100,6 +127,10 @@ final class QuantileDiscretizer @Since("1.6.0") (@Since("1.6.0") override val ui @Since("1.6.0") def setOutputCol(value: String): this.type = set(outputCol, value) + /** @group setParam */ + @Since("2.1.0") + def setHandleInvalid(value: String): this.type = set(handleInvalid, value) + @Since("1.6.0") override def transformSchema(schema: StructType): StructType = { SchemaUtils.checkNumericType(schema, $(inputCol)) @@ -124,7 +155,9 @@ final class QuantileDiscretizer @Since("1.6.0") (@Since("1.6.0") override val ui log.warn(s"Some quantiles were identical.
Bucketing to ${distinctSplits.length - 1}" + s" buckets as a result.") } - val bucketizer = new Bucketizer(uid).setSplits(distinctSplits.sorted) + val bucketizer = new Bucketizer(uid) + .setSplits(distinctSplits.sorted) + .setHandleInvalid($(handleInvalid)) copyValues(bucketizer.setParent(this)) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RandomProjection.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RandomProjection.scala new file mode 100644 index 0000000000000..1b524c6710b42 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RandomProjection.scala @@ -0,0 +1,225 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import scala.util.Random + +import breeze.linalg.normalize +import org.apache.hadoop.fs.Path + +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.ml.linalg._ +import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.shared.HasSeed +import org.apache.spark.ml.util._ +import org.apache.spark.mllib.util.MLUtils +import org.apache.spark.sql.Row +import org.apache.spark.sql.types.StructType + +/** + * :: Experimental :: + * + * Params for [[RandomProjection]]. + */ +private[ml] trait RandomProjectionParams extends Params { + + /** + * The length of each hash bucket, a larger bucket lowers the false negative rate. The number of + * buckets will be `(max L2 norm of input vectors) / bucketLength`. + * + * + * If input vectors are normalized, 1-10 times of pow(numRecords, -1/inputDim) would be a + * reasonable value + * @group param + */ + val bucketLength: DoubleParam = new DoubleParam(this, "bucketLength", + "the length of each hash bucket, a larger bucket lowers the false negative rate.", + ParamValidators.gt(0)) + + /** @group getParam */ + final def getBucketLength: Double = $(bucketLength) +} + +/** + * :: Experimental :: + * + * Model produced by [[RandomProjection]], where multiple random vectors are stored. The vectors + * are normalized to be unit vectors and each vector is used in a hash function: + * `h_i(x) = floor(r_i.dot(x) / bucketLength)` + * where `r_i` is the i-th random unit vector. The number of buckets will be `(max L2 norm of input + * vectors) / bucketLength`. + * + * @param randUnitVectors An array of random unit vectors. Each vector represents a hash function. 
+ */ +@Experimental +@Since("2.1.0") +class RandomProjectionModel private[ml] ( + override val uid: String, + @Since("2.1.0") val randUnitVectors: Array[Vector]) + extends LSHModel[RandomProjectionModel] with RandomProjectionParams { + + @Since("2.1.0") + override protected[ml] val hashFunction: (Vector) => Vector = { + key: Vector => { + val hashValues: Array[Double] = randUnitVectors.map({ + randUnitVector => Math.floor(BLAS.dot(key, randUnitVector) / $(bucketLength)) + }) + Vectors.dense(hashValues) + } + } + + @Since("2.1.0") + override protected[ml] def keyDistance(x: Vector, y: Vector): Double = { + Math.sqrt(Vectors.sqdist(x, y)) + } + + @Since("2.1.0") + override protected[ml] def hashDistance(x: Vector, y: Vector): Double = { + // Since it's generated by hashing, it will be a pair of dense vectors. + x.toDense.values.zip(y.toDense.values).map(pair => math.abs(pair._1 - pair._2)).min + } + + @Since("2.1.0") + override def copy(extra: ParamMap): this.type = defaultCopy(extra) + + @Since("2.1.0") + override def write: MLWriter = new RandomProjectionModel.RandomProjectionModelWriter(this) +} + +/** + * :: Experimental :: + * + * This [[RandomProjection]] implements Locality Sensitive Hashing functions for Euclidean + * distance metrics. + * + * The input is dense or sparse vectors, each of which represents a point in the Euclidean + * distance space. The output will be vectors of configurable dimension. Hash value in the same + * dimension is calculated by the same hash function. + * + * References: + * + * 1. [[https://en.wikipedia.org/wiki/Locality-sensitive_hashing#Stable_distributions + * Wikipedia on Stable Distributions]] + * + * 2. Wang, Jingdong et al. "Hashing for similarity search: A survey." arXiv preprint + * arXiv:1408.2927 (2014). 
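A comparable sketch for the Euclidean-distance variant defined here; again the SparkSession named spark, the column names, seed and the chosen bucketLength are illustrative assumptions:

import org.apache.spark.ml.feature.RandomProjection
import org.apache.spark.ml.linalg.Vectors

import spark.implicits._
val points = Seq(
  (0, Vectors.dense(1.0, 1.0)),
  (1, Vectors.dense(1.0, -1.0)),
  (2, Vectors.dense(-1.0, -1.0))
).toDF("id", "keys")

val rp = new RandomProjection()
  .setInputCol("keys")
  .setOutputCol("hashes")
  .setOutputDim(2)        // two random unit vectors, i.e. two hash functions
  .setBucketLength(2.0)   // h_i(x) = floor(r_i . x / bucketLength)
  .setSeed(12345L)
val model = rp.fit(points)

// Neighbours of a query point, and a self-join on Euclidean distance below 3.0.
model.approxNearestNeighbors(points, Vectors.dense(1.0, 0.0), 2).show()
model.approxSimilarityJoin(points, points, 3.0).show()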
+ */ +@Experimental +@Since("2.1.0") +class RandomProjection(override val uid: String) extends LSH[RandomProjectionModel] + with RandomProjectionParams with HasSeed { + + @Since("2.1.0") + override def setInputCol(value: String): this.type = super.setInputCol(value) + + @Since("2.1.0") + override def setOutputCol(value: String): this.type = super.setOutputCol(value) + + @Since("2.1.0") + override def setOutputDim(value: Int): this.type = super.setOutputDim(value) + + @Since("2.1.0") + def this() = { + this(Identifiable.randomUID("random projection")) + } + + /** @group setParam */ + @Since("2.1.0") + def setBucketLength(value: Double): this.type = set(bucketLength, value) + + /** @group setParam */ + @Since("2.1.0") + def setSeed(value: Long): this.type = set(seed, value) + + @Since("2.1.0") + override protected[this] def createRawLSHModel(inputDim: Int): RandomProjectionModel = { + val rand = new Random($(seed)) + val randUnitVectors: Array[Vector] = { + Array.fill($(outputDim)) { + val randArray = Array.fill(inputDim)(rand.nextGaussian()) + Vectors.fromBreeze(normalize(breeze.linalg.Vector(randArray))) + } + } + new RandomProjectionModel(uid, randUnitVectors) + } + + @Since("2.1.0") + override def transformSchema(schema: StructType): StructType = { + SchemaUtils.checkColumnType(schema, $(inputCol), new VectorUDT) + validateAndTransformSchema(schema) + } + + @Since("2.1.0") + override def copy(extra: ParamMap): this.type = defaultCopy(extra) +} + +@Since("2.1.0") +object RandomProjection extends DefaultParamsReadable[RandomProjection] { + + @Since("2.1.0") + override def load(path: String): RandomProjection = super.load(path) +} + +@Since("2.1.0") +object RandomProjectionModel extends MLReadable[RandomProjectionModel] { + + @Since("2.1.0") + override def read: MLReader[RandomProjectionModel] = new RandomProjectionModelReader + + @Since("2.1.0") + override def load(path: String): RandomProjectionModel = super.load(path) + + private[RandomProjectionModel] class RandomProjectionModelWriter(instance: RandomProjectionModel) + extends MLWriter { + + // TODO: Save using the existing format of Array[Vector] once SPARK-12878 is resolved. 
+ private case class Data(randUnitVectors: Matrix) + + override protected def saveImpl(path: String): Unit = { + DefaultParamsWriter.saveMetadata(instance, path, sc) + val numRows = instance.randUnitVectors.length + require(numRows > 0) + val numCols = instance.randUnitVectors.head.size + val values = instance.randUnitVectors.map(_.toArray).reduce(Array.concat(_, _)) + val randMatrix = Matrices.dense(numRows, numCols, values) + val data = Data(randMatrix) + val dataPath = new Path(path, "data").toString + sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + } + } + + private class RandomProjectionModelReader extends MLReader[RandomProjectionModel] { + + /** Checked against metadata when loading model */ + private val className = classOf[RandomProjectionModel].getName + + override def load(path: String): RandomProjectionModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + + val dataPath = new Path(path, "data").toString + val data = sparkSession.read.parquet(dataPath) + val Row(randUnitVectors: Matrix) = MLUtils.convertMatrixColumnsToML(data, "randUnitVectors") + .select("randUnitVectors") + .head() + val model = new RandomProjectionModel(metadata.uid, randUnitVectors.rowIter.toArray) + + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala index 259be2679ce19..b25fff973c441 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala @@ -67,7 +67,9 @@ class SQLTransformer @Since("1.6.0") (@Since("1.6.0") override val uid: String) val tableName = Identifiable.randomUID(uid) dataset.createOrReplaceTempView(tableName) val realStatement = $(statement).replace(tableIdentifier, tableName) - dataset.sparkSession.sql(realStatement) + val result = dataset.sparkSession.sql(realStatement) + dataset.sparkSession.catalog.dropTempView(tableName) + result } @Since("1.6.0") diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquares.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquares.scala index d732f53029e8c..8a6b862cda170 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquares.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquares.scala @@ -81,8 +81,8 @@ private[ml] class IterativelyReweightedLeastSquares( } // Estimate new model - model = new WeightedLeastSquares(fitIntercept, regParam, standardizeFeatures = false, - standardizeLabel = false).fit(newInstances) + model = new WeightedLeastSquares(fitIntercept, regParam, elasticNetParam = 0.0, + standardizeFeatures = false, standardizeLabel = false).fit(newInstances) // Check convergence val oldCoefficients = oldModel.coefficients diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/NormalEquationSolver.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/NormalEquationSolver.scala new file mode 100644 index 0000000000000..2f5299b010223 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/optim/NormalEquationSolver.scala @@ -0,0 +1,163 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
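Looping back to the SQLTransformer change above (the temporary view registered for the input dataset is now dropped once the statement has run), a small illustrative sketch; the SparkSession named spark and the column names are assumptions:

import org.apache.spark.ml.feature.SQLTransformer

import spark.implicits._
val df = Seq((0, 1.0, 3.0), (2, 2.0, 5.0)).toDF("id", "v1", "v2")

// "__THIS__" stands for the temporary view of the input dataset; with this
// change that view no longer lingers in the catalog after transform returns.
val sqlTrans = new SQLTransformer()
  .setStatement("SELECT *, (v1 + v2) AS v3, (v1 * v2) AS v4 FROM __THIS__")
sqlTrans.transform(df).show()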
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.ml.optim + +import breeze.linalg.{DenseVector => BDV} +import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS, OWLQN => BreezeOWLQN} +import scala.collection.mutable + +import org.apache.spark.ml.linalg.{BLAS, DenseVector, Vectors} +import org.apache.spark.mllib.linalg.CholeskyDecomposition + +/** + * A class to hold the solution to the normal equations A^T^ W A x = A^T^ W b. + * + * @param coefficients The least squares coefficients. The last element in the coefficients + * is the intercept when bias is added to A. + * @param aaInv An option containing the upper triangular part of (A^T^ W A)^-1^, in column major + * format. None when an optimization program is used to solve the normal equations. + * @param objectiveHistory Option containing the objective history when an optimization program is + * used to solve the normal equations. None when an analytic solver is used. + */ +private[ml] class NormalEquationSolution( + val coefficients: Array[Double], + val aaInv: Option[Array[Double]], + val objectiveHistory: Option[Array[Double]]) + +/** + * Interface for classes that solve the normal equations locally. + */ +private[ml] sealed trait NormalEquationSolver { + + /** Solve the normal equations from summary statistics. */ + def solve( + bBar: Double, + bbBar: Double, + abBar: DenseVector, + aaBar: DenseVector, + aBar: DenseVector): NormalEquationSolution +} + +/** + * A class that solves the normal equations directly, using Cholesky decomposition. + */ +private[ml] class CholeskySolver extends NormalEquationSolver { + + def solve( + bBar: Double, + bbBar: Double, + abBar: DenseVector, + aaBar: DenseVector, + aBar: DenseVector): NormalEquationSolution = { + val k = abBar.size + val x = CholeskyDecomposition.solve(aaBar.values, abBar.values) + val aaInv = CholeskyDecomposition.inverse(aaBar.values, k) + + new NormalEquationSolution(x, Some(aaInv), None) + } +} + +/** + * A class for solving the normal equations using Quasi-Newton optimization methods. 
+ */ +private[ml] class QuasiNewtonSolver( + fitIntercept: Boolean, + maxIter: Int, + tol: Double, + l1RegFunc: Option[(Int) => Double]) extends NormalEquationSolver { + + def solve( + bBar: Double, + bbBar: Double, + abBar: DenseVector, + aaBar: DenseVector, + aBar: DenseVector): NormalEquationSolution = { + val numFeatures = aBar.size + val numFeaturesPlusIntercept = if (fitIntercept) numFeatures + 1 else numFeatures + val initialCoefficientsWithIntercept = new Array[Double](numFeaturesPlusIntercept) + if (fitIntercept) { + initialCoefficientsWithIntercept(numFeaturesPlusIntercept - 1) = bBar + } + + val costFun = + new NormalEquationCostFun(bBar, bbBar, abBar, aaBar, aBar, fitIntercept, numFeatures) + val optimizer = l1RegFunc.map { func => + new BreezeOWLQN[Int, BDV[Double]](maxIter, 10, func, tol) + }.getOrElse(new BreezeLBFGS[BDV[Double]](maxIter, 10, tol)) + + val states = optimizer.iterations(new CachedDiffFunction(costFun), + new BDV[Double](initialCoefficientsWithIntercept)) + + val arrayBuilder = mutable.ArrayBuilder.make[Double] + var state: optimizer.State = null + while (states.hasNext) { + state = states.next() + arrayBuilder += state.adjustedValue + } + val x = state.x.toArray.clone() + new NormalEquationSolution(x, None, Some(arrayBuilder.result())) + } + + /** + * NormalEquationCostFun implements Breeze's DiffFunction[T] for the normal equation. + * It returns the loss and gradient with L2 regularization at a particular point (coefficients). + * It's used in Breeze's convex optimization routines. + */ + private class NormalEquationCostFun( + bBar: Double, + bbBar: Double, + ab: DenseVector, + aa: DenseVector, + aBar: DenseVector, + fitIntercept: Boolean, + numFeatures: Int) extends DiffFunction[BDV[Double]] { + + private val numFeaturesPlusIntercept = if (fitIntercept) numFeatures + 1 else numFeatures + + override def calculate(coefficients: BDV[Double]): (Double, BDV[Double]) = { + val coef = Vectors.fromBreeze(coefficients).toDense + if (fitIntercept) { + var j = 0 + var dotProd = 0.0 + val coefValues = coef.values + val aBarValues = aBar.values + while (j < numFeatures) { + dotProd += coefValues(j) * aBarValues(j) + j += 1 + } + coefValues(numFeatures) = bBar - dotProd + } + val aax = new DenseVector(new Array[Double](numFeaturesPlusIntercept)) + BLAS.dspmv(numFeaturesPlusIntercept, 1.0, aa, coef, 1.0, aax) + // loss = 1/2 (b^T W b - 2 x^T A^T W b + x^T A^T W A x) + val loss = 0.5 * bbBar - BLAS.dot(ab, coef) + 0.5 * BLAS.dot(coef, aax) + // gradient = A^T W A x - A^T W b + BLAS.axpy(-1.0, ab, aax) + (loss, aax.asBreeze.toDenseVector) + } + } +} + +/** + * Exception thrown when solving a linear system Ax = b for which the matrix A is non-invertible + * (singular). 
+ */ +class SingularMatrixException(message: String, cause: Throwable) + extends IllegalArgumentException(message, cause) { + + def this(message: String) = this(message, null) +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala index 8f5f4427e1f4b..90c24e1b590ea 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala @@ -20,19 +20,21 @@ package org.apache.spark.ml.optim import org.apache.spark.internal.Logging import org.apache.spark.ml.feature.Instance import org.apache.spark.ml.linalg._ -import org.apache.spark.mllib.linalg.CholeskyDecomposition import org.apache.spark.rdd.RDD /** * Model fitted by [[WeightedLeastSquares]]. + * * @param coefficients model coefficients * @param intercept model intercept * @param diagInvAtWA diagonal of matrix (A^T * W * A)^-1 + * @param objectiveHistory objective function (scaled loss + regularization) at each iteration. */ private[ml] class WeightedLeastSquaresModel( val coefficients: DenseVector, val intercept: Double, - val diagInvAtWA: DenseVector) extends Serializable { + val diagInvAtWA: DenseVector, + val objectiveHistory: Array[Double]) extends Serializable { def predict(features: Vector): Double = { BLAS.dot(coefficients, features) + intercept @@ -44,35 +46,52 @@ private[ml] class WeightedLeastSquaresModel( * Given weighted observations (w,,i,,, a,,i,,, b,,i,,), we use the following weighted least squares * formulation: * - * min,,x,z,, 1/2 sum,,i,, w,,i,, (a,,i,,^T^ x + z - b,,i,,)^2^ / sum,,i,, w_i - * + 1/2 lambda / delta sum,,j,, (sigma,,j,, x,,j,,)^2^, + * min,,x,z,, 1/2 sum,,i,, w,,i,, (a,,i,,^T^ x + z - b,,i,,)^2^ / sum,,i,, w,,i,, + * + lambda / delta (1/2 (1 - alpha) sumj,, (sigma,,j,, x,,j,,)^2^ + * + alpha sum,,j,, abs(sigma,,j,, x,,j,,)), * - * where lambda is the regularization parameter, and delta and sigma,,j,, are controlled by - * [[standardizeLabel]] and [[standardizeFeatures]], respectively. + * where lambda is the regularization parameter, alpha is the ElasticNet mixing parameter, + * and delta and sigma,,j,, are controlled by [[standardizeLabel]] and [[standardizeFeatures]], + * respectively. * * Set [[regParam]] to 0.0 and turn off both [[standardizeFeatures]] and [[standardizeLabel]] to * match R's `lm`. * Turn on [[standardizeLabel]] to match R's `glmnet`. * + * @note The coefficients and intercept are always trained in the scaled space, but are returned + * on the original scale. [[standardizeFeatures]] and [[standardizeLabel]] can be used to + * control whether regularization is applied in the original space or the scaled space. * @param fitIntercept whether to fit intercept. If false, z is 0.0. - * @param regParam L2 regularization parameter (lambda) - * @param standardizeFeatures whether to standardize features. If true, sigma_,,j,, is the + * @param regParam Regularization parameter (lambda). + * @param elasticNetParam the ElasticNet mixing parameter (alpha). + * @param standardizeFeatures whether to standardize features. If true, sigma,,j,, is the * population standard deviation of the j-th column of A. Otherwise, * sigma,,j,, is 1.0. * @param standardizeLabel whether to standardize label. If true, delta is the population standard * deviation of the label column b. Otherwise, delta is 1.0. + * @param solverType the type of solver to use for optimization. 
+ * @param maxIter maximum number of iterations. Only for QuasiNewton solverType. + * @param tol the convergence tolerance of the iterations. Only for QuasiNewton solverType. */ private[ml] class WeightedLeastSquares( val fitIntercept: Boolean, val regParam: Double, + val elasticNetParam: Double, val standardizeFeatures: Boolean, - val standardizeLabel: Boolean) extends Logging with Serializable { + val standardizeLabel: Boolean, + val solverType: WeightedLeastSquares.Solver = WeightedLeastSquares.Auto, + val maxIter: Int = 100, + val tol: Double = 1e-6) extends Logging with Serializable { import WeightedLeastSquares._ require(regParam >= 0.0, s"regParam cannot be negative: $regParam") if (regParam == 0.0) { logWarning("regParam is zero, which might cause numerical instability and overfitting.") } + require(elasticNetParam >= 0.0 && elasticNetParam <= 1.0, + s"elasticNetParam must be in [0, 1]: $elasticNetParam") + require(maxIter >= 0, s"maxIter must be a positive integer: $maxIter") + require(tol > 0, s"tol must be greater than zero: $tol") /** * Creates a [[WeightedLeastSquaresModel]] from an RDD of [[Instance]]s. @@ -82,76 +101,220 @@ private[ml] class WeightedLeastSquares( summary.validate() logInfo(s"Number of instances: ${summary.count}.") val k = if (fitIntercept) summary.k + 1 else summary.k + val numFeatures = summary.k val triK = summary.triK val wSum = summary.wSum - val bBar = summary.bBar - val bStd = summary.bStd - val aBar = summary.aBar - val aVar = summary.aVar - val abBar = summary.abBar - val aaBar = summary.aaBar - val aaValues = aaBar.values - - if (bStd == 0) { - if (fitIntercept) { - logWarning(s"The standard deviation of the label is zero, so the coefficients will be " + - s"zeros and the intercept will be the mean of the label; as a result, " + - s"training is not needed.") - val coefficients = new DenseVector(Array.ofDim(k-1)) - val intercept = bBar + + val rawBStd = summary.bStd + val rawBBar = summary.bBar + // if b is constant (rawBStd is zero), then b cannot be scaled. In this case + // setting bStd=abs(rawBBar) ensures that b is not scaled anymore in l-bfgs algorithm. + val bStd = if (rawBStd == 0.0) math.abs(rawBBar) else rawBStd + + if (rawBStd == 0) { + if (fitIntercept || rawBBar == 0.0) { + if (rawBBar == 0.0) { + logWarning(s"Mean and standard deviation of the label are zero, so the coefficients " + + s"and the intercept will all be zero; as a result, training is not needed.") + } else { + logWarning(s"The standard deviation of the label is zero, so the coefficients will be " + + s"zeros and the intercept will be the mean of the label; as a result, " + + s"training is not needed.") + } + val coefficients = new DenseVector(Array.ofDim(numFeatures)) + val intercept = rawBBar val diagInvAtWA = new DenseVector(Array(0D)) - return new WeightedLeastSquaresModel(coefficients, intercept, diagInvAtWA) + return new WeightedLeastSquaresModel(coefficients, intercept, diagInvAtWA, Array(0D)) } else { - require(!(regParam > 0.0 && standardizeLabel), - "The standard deviation of the label is zero. " + - "Model cannot be regularized with standardization=true") - logWarning(s"The standard deviation of the label is zero. " + - "Consider setting fitIntercept=true.") + require(!(regParam > 0.0 && standardizeLabel), "The standard deviation of the label is " + + "zero. Model cannot be regularized with standardization=true") + logWarning(s"The standard deviation of the label is zero. 
Consider setting " + + s"fitIntercept=true.") + } + } + + val bBar = summary.bBar / bStd + val bbBar = summary.bbBar / (bStd * bStd) + + val aStd = summary.aStd + val aStdValues = aStd.values + + val aBar = { + val _aBar = summary.aBar + val _aBarValues = _aBar.values + var i = 0 + // scale aBar to standardized space in-place + while (i < numFeatures) { + if (aStdValues(i) == 0.0) { + _aBarValues(i) = 0.0 + } else { + _aBarValues(i) /= aStdValues(i) + } + i += 1 + } + _aBar + } + val aBarValues = aBar.values + + val abBar = { + val _abBar = summary.abBar + val _abBarValues = _abBar.values + var i = 0 + // scale abBar to standardized space in-place + while (i < numFeatures) { + if (aStdValues(i) == 0.0) { + _abBarValues(i) = 0.0 + } else { + _abBarValues(i) /= (aStdValues(i) * bStd) + } + i += 1 + } + _abBar + } + val abBarValues = abBar.values + + val aaBar = { + val _aaBar = summary.aaBar + val _aaBarValues = _aaBar.values + var j = 0 + var p = 0 + // scale aaBar to standardized space in-place + while (j < numFeatures) { + val aStdJ = aStdValues(j) + var i = 0 + while (i <= j) { + val aStdI = aStdValues(i) + if (aStdJ == 0.0 || aStdI == 0.0) { + _aaBarValues(p) = 0.0 + } else { + _aaBarValues(p) /= (aStdI * aStdJ) + } + p += 1 + i += 1 + } + j += 1 } + _aaBar } + val aaBarValues = aaBar.values - // add regularization to diagonals + val effectiveRegParam = regParam / bStd + val effectiveL1RegParam = elasticNetParam * effectiveRegParam + val effectiveL2RegParam = (1.0 - elasticNetParam) * effectiveRegParam + + // add L2 regularization to diagonals var i = 0 var j = 2 while (i < triK) { - var lambda = regParam - if (standardizeFeatures) { - lambda *= aVar(j - 2) + var lambda = effectiveL2RegParam + if (!standardizeFeatures) { + val std = aStdValues(j - 2) + if (std != 0.0) { + lambda /= (std * std) + } else { + lambda = 0.0 + } } - if (standardizeLabel && bStd != 0) { - lambda /= bStd + if (!standardizeLabel) { + lambda *= bStd } - aaValues(i) += lambda + aaBarValues(i) += lambda i += j j += 1 } - val aa = if (fitIntercept) { - Array.concat(aaBar.values, aBar.values, Array(1.0)) + val aa = getAtA(aaBarValues, aBarValues) + val ab = getAtB(abBarValues, bBar) + + val solver = if ((solverType == WeightedLeastSquares.Auto && elasticNetParam != 0.0 && + regParam != 0.0) || (solverType == WeightedLeastSquares.QuasiNewton)) { + val effectiveL1RegFun: Option[(Int) => Double] = if (effectiveL1RegParam != 0.0) { + Some((index: Int) => { + if (fitIntercept && index == numFeatures) { + 0.0 + } else { + if (standardizeFeatures) { + effectiveL1RegParam + } else { + if (aStdValues(index) != 0.0) effectiveL1RegParam / aStdValues(index) else 0.0 + } + } + }) + } else { + None + } + new QuasiNewtonSolver(fitIntercept, maxIter, tol, effectiveL1RegFun) } else { - aaBar.values + new CholeskySolver } - val ab = if (fitIntercept) { - Array.concat(abBar.values, Array(bBar)) - } else { - abBar.values + + val solution = solver match { + case cholesky: CholeskySolver => + try { + cholesky.solve(bBar, bbBar, ab, aa, aBar) + } catch { + // if Auto solver is used and Cholesky fails due to singular AtA, then fall back to + // Quasi-Newton solver. + case _: SingularMatrixException if solverType == WeightedLeastSquares.Auto => + logWarning("Cholesky solver failed due to singular covariance matrix. 
" + + "Retrying with Quasi-Newton solver.") + // ab and aa were modified in place, so reconstruct them + val _aa = getAtA(aaBarValues, aBarValues) + val _ab = getAtB(abBarValues, bBar) + val newSolver = new QuasiNewtonSolver(fitIntercept, maxIter, tol, None) + newSolver.solve(bBar, bbBar, _ab, _aa, aBar) + } + case qn: QuasiNewtonSolver => + qn.solve(bBar, bbBar, ab, aa, aBar) } - val x = CholeskyDecomposition.solve(aa, ab) + val (coefficientArray, intercept) = if (fitIntercept) { + (solution.coefficients.slice(0, solution.coefficients.length - 1), + solution.coefficients.last * bStd) + } else { + (solution.coefficients, 0.0) + } - val aaInv = CholeskyDecomposition.inverse(aa, k) + // convert the coefficients from the scaled space to the original space + var q = 0 + val len = coefficientArray.length + while (q < len) { + coefficientArray(q) *= { if (aStdValues(q) != 0.0) bStd / aStdValues(q) else 0.0 } + q += 1 + } // aaInv is a packed upper triangular matrix, here we get all elements on diagonal - val diagInvAtWA = new DenseVector((1 to k).map { i => - aaInv(i + (i - 1) * i / 2 - 1) / wSum }.toArray) + val diagInvAtWA = solution.aaInv.map { inv => + new DenseVector((1 to k).map { i => + val multiplier = if (i == k && fitIntercept) { + 1.0 + } else { + aStdValues(i - 1) * aStdValues(i - 1) + } + inv(i + (i - 1) * i / 2 - 1) / (wSum * multiplier) + }.toArray) + }.getOrElse(new DenseVector(Array(0D))) + + new WeightedLeastSquaresModel(new DenseVector(coefficientArray), intercept, diagInvAtWA, + solution.objectiveHistory.getOrElse(Array(0D))) + } - val (coefficients, intercept) = if (fitIntercept) { - (new DenseVector(x.slice(0, x.length - 1)), x.last) + /** Construct A^T^ A (append bias if necessary). */ + private def getAtA(aaBar: Array[Double], aBar: Array[Double]): DenseVector = { + if (fitIntercept) { + new DenseVector(Array.concat(aaBar, aBar, Array(1.0))) } else { - (new DenseVector(x), 0.0) + new DenseVector(aaBar.clone()) } + } - new WeightedLeastSquaresModel(coefficients, intercept, diagInvAtWA) + /** Construct A^T^ b (append bias if necessary). */ + private def getAtB(abBar: Array[Double], bBar: Double): DenseVector = { + if (fitIntercept) { + new DenseVector(Array.concat(abBar, Array(bBar))) + } else { + new DenseVector(abBar.clone()) + } } } @@ -163,6 +326,13 @@ private[ml] object WeightedLeastSquares { */ val MAX_NUM_FEATURES: Int = 4096 + sealed trait Solver + case object Auto extends Solver + case object Cholesky extends Solver + case object QuasiNewton extends Solver + + val supportedSolvers = Array(Auto, Cholesky, QuasiNewton) + /** * Aggregator to provide necessary summary statistics for solving [[WeightedLeastSquares]]. */ @@ -262,6 +432,11 @@ private[ml] object WeightedLeastSquares { */ def bBar: Double = bSum / wSum + /** + * Weighted mean of squared labels. + */ + def bbBar: Double = bbSum / wSum + /** * Weighted population standard deviation of labels. */ @@ -285,6 +460,24 @@ private[ml] object WeightedLeastSquares { output } + /** + * Weighted population standard deviation of features. + */ + def aStd: DenseVector = { + val std = Array.ofDim[Double](k) + var i = 0 + var j = 2 + val aaValues = aaSum.values + while (i < triK) { + val l = j - 2 + val aw = aSum(l) / wSum + std(l) = math.sqrt(aaValues(i) / wSum - aw * aw) + i += j + j += 1 + } + new DenseVector(std) + } + /** * Weighted population variance of features. 
*/ diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/LogisticRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/LogisticRegressionWrapper.scala new file mode 100644 index 0000000000000..9b352c9863114 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/r/LogisticRegressionWrapper.scala @@ -0,0 +1,157 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.r + +import org.apache.hadoop.fs.Path +import org.json4s._ +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods._ + +import org.apache.spark.ml.{Pipeline, PipelineModel} +import org.apache.spark.ml.attribute.AttributeGroup +import org.apache.spark.ml.classification.{BinaryLogisticRegressionSummary, LogisticRegression, LogisticRegressionModel} +import org.apache.spark.ml.feature.RFormula +import org.apache.spark.ml.util._ +import org.apache.spark.sql.{DataFrame, Dataset} + +private[r] class LogisticRegressionWrapper private ( + val pipeline: PipelineModel, + val features: Array[String], + val isLoaded: Boolean = false) extends MLWritable { + + private val logisticRegressionModel: LogisticRegressionModel = + pipeline.stages(1).asInstanceOf[LogisticRegressionModel] + + lazy val totalIterations: Int = logisticRegressionModel.summary.totalIterations + + lazy val objectiveHistory: Array[Double] = logisticRegressionModel.summary.objectiveHistory + + lazy val blrSummary = + logisticRegressionModel.summary.asInstanceOf[BinaryLogisticRegressionSummary] + + lazy val roc: DataFrame = blrSummary.roc + + lazy val areaUnderROC: Double = blrSummary.areaUnderROC + + lazy val pr: DataFrame = blrSummary.pr + + lazy val fMeasureByThreshold: DataFrame = blrSummary.fMeasureByThreshold + + lazy val precisionByThreshold: DataFrame = blrSummary.precisionByThreshold + + lazy val recallByThreshold: DataFrame = blrSummary.recallByThreshold + + def transform(dataset: Dataset[_]): DataFrame = { + pipeline.transform(dataset).drop(logisticRegressionModel.getFeaturesCol) + } + + override def write: MLWriter = new LogisticRegressionWrapper.LogisticRegressionWrapperWriter(this) +} + +private[r] object LogisticRegressionWrapper + extends MLReadable[LogisticRegressionWrapper] { + + def fit( // scalastyle:ignore + data: DataFrame, + formula: String, + regParam: Double, + elasticNetParam: Double, + maxIter: Int, + tol: Double, + fitIntercept: Boolean, + family: String, + standardization: Boolean, + thresholds: Array[Double], + weightCol: String, + aggregationDepth: Int, + probability: String + ): LogisticRegressionWrapper = { + + val rFormula = new RFormula() + .setFormula(formula) + RWrapperUtils.checkDataColumns(rFormula, data) + val rFormulaModel = rFormula.fit(data) + + // get feature names from output schema + val schema = 
rFormulaModel.transform(data).schema + val featureAttrs = AttributeGroup.fromStructField(schema(rFormulaModel.getFeaturesCol)) + .attributes.get + val features = featureAttrs.map(_.name.get) + + // assemble and fit the pipeline + val logisticRegression = new LogisticRegression() + .setRegParam(regParam) + .setElasticNetParam(elasticNetParam) + .setMaxIter(maxIter) + .setTol(tol) + .setFitIntercept(fitIntercept) + .setFamily(family) + .setStandardization(standardization) + .setWeightCol(weightCol) + .setAggregationDepth(aggregationDepth) + .setFeaturesCol(rFormula.getFeaturesCol) + .setProbabilityCol(probability) + + if (thresholds.length > 1) { + logisticRegression.setThresholds(thresholds) + } else { + logisticRegression.setThreshold(thresholds(0)) + } + + val pipeline = new Pipeline() + .setStages(Array(rFormulaModel, logisticRegression)) + .fit(data) + + new LogisticRegressionWrapper(pipeline, features) + } + + override def read: MLReader[LogisticRegressionWrapper] = new LogisticRegressionWrapperReader + + override def load(path: String): LogisticRegressionWrapper = super.load(path) + + class LogisticRegressionWrapperWriter(instance: LogisticRegressionWrapper) extends MLWriter { + + override protected def saveImpl(path: String): Unit = { + val rMetadataPath = new Path(path, "rMetadata").toString + val pipelinePath = new Path(path, "pipeline").toString + + val rMetadata = ("class" -> instance.getClass.getName) ~ + ("features" -> instance.features.toSeq) + val rMetadataJson: String = compact(render(rMetadata)) + sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath) + + instance.pipeline.save(pipelinePath) + } + } + + class LogisticRegressionWrapperReader extends MLReader[LogisticRegressionWrapper] { + + override def load(path: String): LogisticRegressionWrapper = { + implicit val format = DefaultFormats + val rMetadataPath = new Path(path, "rMetadata").toString + val pipelinePath = new Path(path, "pipeline").toString + + val rMetadataStr = sc.textFile(rMetadataPath, 1).first() + val rMetadata = parse(rMetadataStr) + val features = (rMetadata \ "features").extract[Array[String]] + + val pipeline = PipelineModel.load(pipelinePath) + new LogisticRegressionWrapper(pipeline, features, isLoaded = true) + } + } +} \ No newline at end of file diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/MultilayerPerceptronClassifierWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/MultilayerPerceptronClassifierWrapper.scala index 10673003534e6..2193eb80e9fdd 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/MultilayerPerceptronClassifierWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/MultilayerPerceptronClassifierWrapper.scala @@ -24,6 +24,7 @@ import org.json4s.jackson.JsonMethods._ import org.apache.spark.ml.{Pipeline, PipelineModel} import org.apache.spark.ml.classification.{MultilayerPerceptronClassificationModel, MultilayerPerceptronClassifier} +import org.apache.spark.ml.linalg.Vectors import org.apache.spark.ml.util.{MLReadable, MLReader, MLWritable, MLWriter} import org.apache.spark.sql.{DataFrame, Dataset} @@ -58,7 +59,8 @@ private[r] object MultilayerPerceptronClassifierWrapper maxIter: Int, tol: Double, stepSize: Double, - seed: String + seed: String, + initialWeights: Array[Double] ): MultilayerPerceptronClassifierWrapper = { // get labels and feature names from output schema val schema = data.schema @@ -73,6 +75,11 @@ private[r] object MultilayerPerceptronClassifierWrapper .setStepSize(stepSize) .setPredictionCol(PREDICTED_LABEL_COL) if (seed != 
null && seed.length > 0) mlp.setSeed(seed.toInt) + if (initialWeights != null) { + require(initialWeights.length > 0) + mlp.setInitialWeights(Vectors.dense(initialWeights)) + } + val pipeline = new Pipeline() .setStages(Array(mlp)) .fit(data) diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala index d64de1b6abb63..0e09e18027ca7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala @@ -54,6 +54,12 @@ private[r] object RWrappers extends MLReader[Object] { GaussianMixtureWrapper.load(path) case "org.apache.spark.ml.r.ALSWrapper" => ALSWrapper.load(path) + case "org.apache.spark.ml.r.LogisticRegressionWrapper" => + LogisticRegressionWrapper.load(path) + case "org.apache.spark.ml.r.RandomForestRegressorWrapper" => + RandomForestRegressorWrapper.load(path) + case "org.apache.spark.ml.r.RandomForestClassifierWrapper" => + RandomForestClassifierWrapper.load(path) case _ => throw new SparkException(s"SparkR read.ml does not support load $className") } diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala new file mode 100644 index 0000000000000..b0088ddaf3b1d --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala @@ -0,0 +1,147 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.r + +import org.apache.hadoop.fs.Path +import org.json4s._ +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods._ + +import org.apache.spark.ml.{Pipeline, PipelineModel} +import org.apache.spark.ml.attribute.AttributeGroup +import org.apache.spark.ml.classification.{RandomForestClassificationModel, RandomForestClassifier} +import org.apache.spark.ml.feature.RFormula +import org.apache.spark.ml.linalg.Vector +import org.apache.spark.ml.util._ +import org.apache.spark.sql.{DataFrame, Dataset} + +private[r] class RandomForestClassifierWrapper private ( + val pipeline: PipelineModel, + val formula: String, + val features: Array[String]) extends MLWritable { + + private val DTModel: RandomForestClassificationModel = + pipeline.stages(1).asInstanceOf[RandomForestClassificationModel] + + lazy val numFeatures: Int = DTModel.numFeatures + lazy val featureImportances: Vector = DTModel.featureImportances + lazy val numTrees: Int = DTModel.getNumTrees + lazy val treeWeights: Array[Double] = DTModel.treeWeights + + def summary: String = DTModel.toDebugString + + def transform(dataset: Dataset[_]): DataFrame = { + pipeline.transform(dataset).drop(DTModel.getFeaturesCol) + } + + override def write: MLWriter = new + RandomForestClassifierWrapper.RandomForestClassifierWrapperWriter(this) +} + +private[r] object RandomForestClassifierWrapper extends MLReadable[RandomForestClassifierWrapper] { + def fit( // scalastyle:ignore + data: DataFrame, + formula: String, + maxDepth: Int, + maxBins: Int, + numTrees: Int, + impurity: String, + minInstancesPerNode: Int, + minInfoGain: Double, + checkpointInterval: Int, + featureSubsetStrategy: String, + seed: String, + subsamplingRate: Double, + probabilityCol: String, + maxMemoryInMB: Int, + cacheNodeIds: Boolean): RandomForestClassifierWrapper = { + + val rFormula = new RFormula() + .setFormula(formula) + RWrapperUtils.checkDataColumns(rFormula, data) + val rFormulaModel = rFormula.fit(data) + + // get feature names from output schema + val schema = rFormulaModel.transform(data).schema + val featureAttrs = AttributeGroup.fromStructField(schema(rFormulaModel.getFeaturesCol)) + .attributes.get + val features = featureAttrs.map(_.name.get) + + // assemble and fit the pipeline + val rfc = new RandomForestClassifier() + .setMaxDepth(maxDepth) + .setMaxBins(maxBins) + .setNumTrees(numTrees) + .setImpurity(impurity) + .setMinInstancesPerNode(minInstancesPerNode) + .setMinInfoGain(minInfoGain) + .setCheckpointInterval(checkpointInterval) + .setFeatureSubsetStrategy(featureSubsetStrategy) + .setSubsamplingRate(subsamplingRate) + .setMaxMemoryInMB(maxMemoryInMB) + .setCacheNodeIds(cacheNodeIds) + .setProbabilityCol(probabilityCol) + .setFeaturesCol(rFormula.getFeaturesCol) + if (seed != null && seed.length > 0) rfc.setSeed(seed.toLong) + + val pipeline = new Pipeline() + .setStages(Array(rFormulaModel, rfc)) + .fit(data) + + new RandomForestClassifierWrapper(pipeline, formula, features) + } + + override def read: MLReader[RandomForestClassifierWrapper] = + new RandomForestClassifierWrapperReader + + override def load(path: String): RandomForestClassifierWrapper = super.load(path) + + class RandomForestClassifierWrapperWriter(instance: RandomForestClassifierWrapper) + extends MLWriter { + + override protected def saveImpl(path: String): Unit = { + val rMetadataPath = new Path(path, "rMetadata").toString + val pipelinePath = new Path(path, "pipeline").toString + + val rMetadata = ("class" -> instance.getClass.getName) ~ + 
("formula" -> instance.formula) ~ + ("features" -> instance.features.toSeq) + val rMetadataJson: String = compact(render(rMetadata)) + + sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath) + instance.pipeline.save(pipelinePath) + } + } + + class RandomForestClassifierWrapperReader extends MLReader[RandomForestClassifierWrapper] { + + override def load(path: String): RandomForestClassifierWrapper = { + implicit val format = DefaultFormats + val rMetadataPath = new Path(path, "rMetadata").toString + val pipelinePath = new Path(path, "pipeline").toString + val pipeline = PipelineModel.load(pipelinePath) + + val rMetadataStr = sc.textFile(rMetadataPath, 1).first() + val rMetadata = parse(rMetadataStr) + val formula = (rMetadata \ "formula").extract[String] + val features = (rMetadata \ "features").extract[Array[String]] + + new RandomForestClassifierWrapper(pipeline, formula, features) + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestRegressionWrapper.scala new file mode 100644 index 0000000000000..c8874407fa75e --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestRegressionWrapper.scala @@ -0,0 +1,144 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.r + +import org.apache.hadoop.fs.Path +import org.json4s._ +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods._ + +import org.apache.spark.ml.{Pipeline, PipelineModel} +import org.apache.spark.ml.attribute.AttributeGroup +import org.apache.spark.ml.feature.RFormula +import org.apache.spark.ml.linalg.Vector +import org.apache.spark.ml.regression.{RandomForestRegressionModel, RandomForestRegressor} +import org.apache.spark.ml.util._ +import org.apache.spark.sql.{DataFrame, Dataset} + +private[r] class RandomForestRegressorWrapper private ( + val pipeline: PipelineModel, + val formula: String, + val features: Array[String]) extends MLWritable { + + private val DTModel: RandomForestRegressionModel = + pipeline.stages(1).asInstanceOf[RandomForestRegressionModel] + + lazy val numFeatures: Int = DTModel.numFeatures + lazy val featureImportances: Vector = DTModel.featureImportances + lazy val numTrees: Int = DTModel.getNumTrees + lazy val treeWeights: Array[Double] = DTModel.treeWeights + + def summary: String = DTModel.toDebugString + + def transform(dataset: Dataset[_]): DataFrame = { + pipeline.transform(dataset).drop(DTModel.getFeaturesCol) + } + + override def write: MLWriter = new + RandomForestRegressorWrapper.RandomForestRegressorWrapperWriter(this) +} + +private[r] object RandomForestRegressorWrapper extends MLReadable[RandomForestRegressorWrapper] { + def fit( // scalastyle:ignore + data: DataFrame, + formula: String, + maxDepth: Int, + maxBins: Int, + numTrees: Int, + impurity: String, + minInstancesPerNode: Int, + minInfoGain: Double, + checkpointInterval: Int, + featureSubsetStrategy: String, + seed: String, + subsamplingRate: Double, + maxMemoryInMB: Int, + cacheNodeIds: Boolean): RandomForestRegressorWrapper = { + + val rFormula = new RFormula() + .setFormula(formula) + RWrapperUtils.checkDataColumns(rFormula, data) + val rFormulaModel = rFormula.fit(data) + + // get feature names from output schema + val schema = rFormulaModel.transform(data).schema + val featureAttrs = AttributeGroup.fromStructField(schema(rFormulaModel.getFeaturesCol)) + .attributes.get + val features = featureAttrs.map(_.name.get) + + // assemble and fit the pipeline + val rfr = new RandomForestRegressor() + .setMaxDepth(maxDepth) + .setMaxBins(maxBins) + .setNumTrees(numTrees) + .setImpurity(impurity) + .setMinInstancesPerNode(minInstancesPerNode) + .setMinInfoGain(minInfoGain) + .setCheckpointInterval(checkpointInterval) + .setFeatureSubsetStrategy(featureSubsetStrategy) + .setSubsamplingRate(subsamplingRate) + .setMaxMemoryInMB(maxMemoryInMB) + .setCacheNodeIds(cacheNodeIds) + .setFeaturesCol(rFormula.getFeaturesCol) + if (seed != null && seed.length > 0) rfr.setSeed(seed.toLong) + + val pipeline = new Pipeline() + .setStages(Array(rFormulaModel, rfr)) + .fit(data) + + new RandomForestRegressorWrapper(pipeline, formula, features) + } + + override def read: MLReader[RandomForestRegressorWrapper] = new RandomForestRegressorWrapperReader + + override def load(path: String): RandomForestRegressorWrapper = super.load(path) + + class RandomForestRegressorWrapperWriter(instance: RandomForestRegressorWrapper) + extends MLWriter { + + override protected def saveImpl(path: String): Unit = { + val rMetadataPath = new Path(path, "rMetadata").toString + val pipelinePath = new Path(path, "pipeline").toString + + val rMetadata = ("class" -> instance.getClass.getName) ~ + ("formula" -> instance.formula) ~ + ("features" -> instance.features.toSeq) + val rMetadataJson: 
String = compact(render(rMetadata)) + + sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath) + instance.pipeline.save(pipelinePath) + } + } + + class RandomForestRegressorWrapperReader extends MLReader[RandomForestRegressorWrapper] { + + override def load(path: String): RandomForestRegressorWrapper = { + implicit val format = DefaultFormats + val rMetadataPath = new Path(path, "rMetadata").toString + val pipelinePath = new Path(path, "pipeline").toString + val pipeline = PipelineModel.load(pipelinePath) + + val rMetadataStr = sc.textFile(rMetadataPath, 1).first() + val rMetadata = parse(rMetadataStr) + val formula = (rMetadata \ "formula").extract[String] + val features = (rMetadata \ "features").extract[Array[String]] + + new RandomForestRegressorWrapper(pipeline, formula, features) + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala index bb01f9d5a364c..fa69d60836e68 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala @@ -123,9 +123,16 @@ class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String) val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset) val numFeatures = oldDataset.first().features.size val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Regression) + + val instr = Instrumentation.create(this, oldDataset) + instr.logParams(params: _*) + instr.logNumFeatures(numFeatures) + val (baseLearners, learnerWeights) = GradientBoostedTrees.run(oldDataset, boostingStrategy, $(seed)) - new GBTRegressionModel(uid, baseLearners, learnerWeights, numFeatures) + val m = new GBTRegressionModel(uid, baseLearners, learnerWeights, numFeatures) + instr.logSuccess(m) + m } @Since("1.4.0") diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala index bb9e150c49772..33cb25c8c7f66 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala @@ -262,7 +262,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val if (familyObj == Gaussian && linkObj == Identity) { // TODO: Make standardizeFeatures and standardizeLabel configurable. - val optimizer = new WeightedLeastSquares($(fitIntercept), $(regParam), + val optimizer = new WeightedLeastSquares($(fitIntercept), $(regParam), elasticNetParam = 0.0, standardizeFeatures = true, standardizeLabel = true) val wlsModel = optimizer.fit(instances) val model = copyValues( @@ -337,7 +337,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine Instance(eta, instance.weight, instance.features) } // TODO: Make standardizeFeatures and standardizeLabel configurable. 
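      // Sketch of the call shape the surrounding hunks converge on (arguments other than the new
      // elasticNetParam are whatever the caller already passed): WeightedLeastSquares now takes the
      // elastic-net mixing parameter explicitly, and this IRLS initialization keeps pure L2 by
      // fixing elasticNetParam = 0.0, so the existing behaviour should be unchanged.
      //   new WeightedLeastSquares(fitIntercept, regParam, elasticNetParam = 0.0,
      //     standardizeFeatures = true, standardizeLabel = true)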
- val initialModel = new WeightedLeastSquares(fitIntercept, regParam, + val initialModel = new WeightedLeastSquares(fitIntercept, regParam, elasticNetParam = 0.0, standardizeFeatures = true, standardizeLabel = true) .fit(newInstances) initialModel diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 025ed20c75a04..519f3bdec82df 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -31,7 +31,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.ml.feature.Instance import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.linalg.BLAS._ -import org.apache.spark.ml.optim.WeightedLeastSquares +import org.apache.spark.ml.optim.{NormalEquationSolver, WeightedLeastSquares} import org.apache.spark.ml.PredictorParams import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.param.shared._ @@ -177,6 +177,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String * If the dimensions of features or the number of partitions are large, * this param could be adjusted to a larger size. * Default is 2. + * * @group expertSetParam */ @Since("2.1.0") @@ -194,21 +195,18 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String Instance(label, weight, features) } - if (($(solver) == "auto" && $(elasticNetParam) == 0.0 && + if (($(solver) == "auto" && numFeatures <= WeightedLeastSquares.MAX_NUM_FEATURES) || $(solver) == "normal") { - require($(elasticNetParam) == 0.0, "Only L2 regularization can be used when normal " + - "solver is used.'") - // For low dimensional data, WeightedLeastSquares is more efficiently since the + // For low dimensional data, WeightedLeastSquares is more efficient since the // training algorithm only requires one pass through the data. (SPARK-10668) val optimizer = new WeightedLeastSquares($(fitIntercept), $(regParam), - $(standardization), true) + elasticNetParam = $(elasticNetParam), $(standardization), true, + solverType = WeightedLeastSquares.Auto, maxIter = $(maxIter), tol = $(tol)) val model = optimizer.fit(instances) // When it is trained by WeightedLeastSquares, training summary does not - // attached returned model. + // attach returned model. val lrModel = copyValues(new LinearRegressionModel(uid, model.coefficients, model.intercept)) - // WeightedLeastSquares does not run through iterations. So it does not generate - // an objective history. val (summaryModel, predictionColName) = lrModel.findSummaryModelAndPredictionCol() val trainingSummary = new LinearRegressionTrainingSummary( summaryModel.transform(dataset), @@ -217,7 +215,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String $(featuresCol), summaryModel, model.diagInvAtWA.toArray, - Array(0D)) + model.objectiveHistory) return lrModel.setSummary(trainingSummary) } @@ -243,7 +241,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String val yMean = ySummarizer.mean(0) val rawYStd = math.sqrt(ySummarizer.variance(0)) if (rawYStd == 0.0) { - if ($(fitIntercept) || yMean==0.0) { + if ($(fitIntercept) || yMean == 0.0) { // If the rawYStd is zero and fitIntercept=true, then the intercept is yMean with // zero coefficient; as a result, training is not needed. 
// Also, if yMean==0 and rawYStd==0, all the coefficients are zero regardless of diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala index 8577803743c8e..5e9e6ff1a5690 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala @@ -35,26 +35,29 @@ import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection import org.apache.spark.sql.execution.datasources._ +import org.apache.spark.sql.execution.datasources.text.TextOutputWriter import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ import org.apache.spark.util.SerializableConfiguration private[libsvm] class LibSVMOutputWriter( - path: String, + stagingDir: String, + fileNamePrefix: String, dataSchema: StructType, context: TaskAttemptContext) extends OutputWriter { + override val path: String = { + val compressionExtension = TextOutputWriter.getCompressionExtension(context) + new Path(stagingDir, fileNamePrefix + ".libsvm" + compressionExtension).toString + } + private[this] val buffer = new Text() private val recordWriter: RecordWriter[NullWritable, Text] = { new TextOutputFormat[NullWritable, Text]() { override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = { - val configuration = context.getConfiguration - val uniqueWriteJobId = configuration.get(WriterContainer.DATASOURCE_WRITEJOBUUID) - val taskAttemptId = context.getTaskAttemptID - val split = taskAttemptId.getTaskID.getId - new Path(path, f"part-r-$split%05d-$uniqueWriteJobId$extension") + new Path(path) } }.getRecordWriter(context) } @@ -132,12 +135,11 @@ private[libsvm] class LibSVMFileFormat extends TextBasedFileFormat with DataSour dataSchema: StructType): OutputWriterFactory = { new OutputWriterFactory { override def newInstance( - path: String, - bucketId: Option[Int], + stagingDir: String, + fileNamePrefix: String, dataSchema: StructType, context: TaskAttemptContext): OutputWriter = { - if (bucketId.isDefined) { sys.error("LibSVM doesn't support bucketing") } - new LibSVMOutputWriter(path, dataSchema, context) + new LibSVMOutputWriter(stagingDir, fileNamePrefix, dataSchema, context) } } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala index 4413fefdea3ca..bc4f9e6716ee8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala @@ -474,7 +474,7 @@ private[ml] object MetaAlgorithmReadWrite { case ovr: OneVsRest => Array(ovr.getClassifier) case ovrModel: OneVsRestModel => Array(ovrModel.getClassifier) ++ ovrModel.models case rformModel: RFormulaModel => Array(rformModel.pipelineModel) - case _: Params => Array() + case _: Params => Array.empty[Params] } val subStageMaps = subStages.flatMap(getUidMapImpl) List((instance.uid, instance)) ++ subStageMaps diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala index 68a7b3b6763af..ed9c064879d01 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala @@ 
-56,13 +56,15 @@ class KMeans private ( def this() = this(2, 20, KMeans.K_MEANS_PARALLEL, 2, 1e-4, Utils.random.nextLong()) /** - * Number of clusters to create (k). + * Number of clusters to create (k). Note that it is possible for fewer than k clusters to + * be returned, for example, if there are fewer than k distinct points to cluster. */ @Since("1.4.0") def getK: Int = k /** - * Set the number of clusters to create (k). Default: 2. + * Set the number of clusters to create (k). Note that it is possible for fewer than k clusters to + * be returned, for example, if there are fewer than k distinct points to cluster. Default: 2. */ @Since("0.8.0") def setK(k: Int): this.type = { @@ -323,7 +325,10 @@ class KMeans private ( * Initialize a set of cluster centers at random. */ private def initRandom(data: RDD[VectorWithNorm]): Array[VectorWithNorm] = { - data.takeSample(true, k, new XORShiftRandom(this.seed).nextInt()).map(_.toDense) + // Select without replacement; may still produce duplicates if the data has < k distinct + // points, so deduplicate the centroids to match the behavior of k-means|| in the same situation + data.takeSample(false, k, new XORShiftRandom(this.seed).nextInt()) + .map(_.vector).distinct.map(new VectorWithNorm(_)) } /** @@ -335,7 +340,7 @@ class KMeans private ( * * The original paper can be found at http://theory.stanford.edu/~sergei/papers/vldb12-kmpar.pdf. */ - private def initKMeansParallel(data: RDD[VectorWithNorm]): Array[VectorWithNorm] = { + private[clustering] def initKMeansParallel(data: RDD[VectorWithNorm]): Array[VectorWithNorm] = { // Initialize empty centers and point costs. var costs = data.map(_ => Double.PositiveInfinity) @@ -378,19 +383,21 @@ class KMeans private ( costs.unpersist(blocking = false) bcNewCentersList.foreach(_.destroy(false)) - if (centers.size == k) { - centers.toArray + val distinctCenters = centers.map(_.vector).distinct.map(new VectorWithNorm(_)) + + if (distinctCenters.size <= k) { + distinctCenters.toArray } else { - // Finally, we might have a set of more or less than k candidate centers; weight each + // Finally, we might have a set of more than k distinct candidate centers; weight each // candidate by the number of points in the dataset mapping to it and run a local k-means++ // on the weighted centers to pick k of them - val bcCenters = data.context.broadcast(centers) + val bcCenters = data.context.broadcast(distinctCenters) val countMap = data.map(KMeans.findClosest(bcCenters.value, _)._1).countByValue() bcCenters.destroy(blocking = false) - val myWeights = centers.indices.map(countMap.getOrElse(_, 0L).toDouble).toArray - LocalKMeans.kMeansPlusPlus(0, centers.toArray, myWeights, k, 30) + val myWeights = distinctCenters.indices.map(countMap.getOrElse(_, 0L).toDouble).toArray + LocalKMeans.kMeansPlusPlus(0, distinctCenters.toArray, myWeights, k, 30) } } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala index ce4421515126c..8f777cc35b93f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala @@ -73,7 +73,7 @@ class RegressionMetrics @Since("2.0.0") ( /** * Returns the variance explained by regression. 
- * explainedVariance = $\sum_i (\hat{y_i} - \bar{y})^2 / n$ + * explainedVariance = $\sum_i (\hat{y_i} - \bar{y})^2^ / n$ * @see [[https://en.wikipedia.org/wiki/Fraction_of_variance_unexplained]] */ @Since("1.2.0") diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala index c305b36278e87..f8276de4f23d4 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala @@ -234,11 +234,11 @@ class ChiSqSelector @Since("2.1.0") () extends Serializable { val features = selectorType match { case ChiSqSelector.KBest => chiSqTestResult - .sortBy { case (res, _) => -res.statistic } + .sortBy { case (res, _) => res.pValue } .take(numTopFeatures) case ChiSqSelector.Percentile => chiSqTestResult - .sortBy { case (res, _) => -res.statistic } + .sortBy { case (res, _) => res.pValue } .take((chiSqTestResult.length * percentile).toInt) case ChiSqSelector.FPR => chiSqTestResult diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/CholeskyDecomposition.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/CholeskyDecomposition.scala index 08f8f19c1e77d..68771f1afbe8c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/CholeskyDecomposition.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/CholeskyDecomposition.scala @@ -20,6 +20,8 @@ package org.apache.spark.mllib.linalg import com.github.fommil.netlib.LAPACK.{getInstance => lapack} import org.netlib.util.intW +import org.apache.spark.ml.optim.SingularMatrixException + /** * Compute Cholesky decomposition. */ @@ -60,7 +62,7 @@ private[spark] object CholeskyDecomposition { case code if code < 0 => throw new IllegalStateException(s"LAPACK.$method returned $code; arg ${-code} is illegal") case code if code > 0 => - throw new IllegalArgumentException( + throw new SingularMatrixException ( s"LAPACK.$method returned $code because A is not positive definite. Is A derived from " + "a singular matrix (e.g. 
collinear column values)?") case _ => // do nothing diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala index ff1068417d94f..377be6bfb9886 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala @@ -429,14 +429,14 @@ class BlockMatrix @Since("1.3.0") ( val rightCounterpartsHelper = rightMatrix.groupBy(_._1).mapValues(_.map(_._2)) val leftDestinations = leftMatrix.map { case (rowIndex, colIndex) => - val rightCounterparts = rightCounterpartsHelper.getOrElse(colIndex, Array()) + val rightCounterparts = rightCounterpartsHelper.getOrElse(colIndex, Array.empty[Int]) val partitions = rightCounterparts.map(b => partitioner.getPartition((rowIndex, b))) ((rowIndex, colIndex), partitions.toSet) }.toMap val leftCounterpartsHelper = leftMatrix.groupBy(_._2).mapValues(_.map(_._1)) val rightDestinations = rightMatrix.map { case (rowIndex, colIndex) => - val leftCounterparts = leftCounterpartsHelper.getOrElse(rowIndex, Array()) + val leftCounterparts = leftCounterpartsHelper.getOrElse(rowIndex, Array.empty[Int]) val partitions = leftCounterparts.map(b => partitioner.getPartition((b, colIndex))) ((rowIndex, colIndex), partitions.toSet) }.toMap diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/KolmogorovSmirnovTest.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/KolmogorovSmirnovTest.scala index c3de5d75f4f7d..a8b5955a7285d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/KolmogorovSmirnovTest.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/KolmogorovSmirnovTest.scala @@ -124,7 +124,8 @@ private[stat] object KolmogorovSmirnovTest extends Logging { val pResults = partDiffs.foldLeft(initAcc) { case ((pMin, pMax, pCt), (dl, dp)) => (math.min(pMin, dl), math.max(pMax, dp), pCt + 1) } - val results = if (pResults == initAcc) Array[(Double, Double, Double)]() else Array(pResults) + val results = + if (pResults == initAcc) Array.empty[(Double, Double, Double)] else Array(pResults) results.iterator } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index 42b56754e0835..bc631dc6d3149 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -25,14 +25,14 @@ import scala.util.control.Breaks._ import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.attribute.NominalAttribute import org.apache.spark.ml.classification.LogisticRegressionSuite._ -import org.apache.spark.ml.feature.LabeledPoint +import org.apache.spark.ml.feature.{Instance, LabeledPoint} import org.apache.spark.ml.linalg.{DenseMatrix, Matrices, SparseMatrix, SparseVector, Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{Dataset, Row} -import org.apache.spark.sql.functions.{col, lit} +import org.apache.spark.sql.functions.{col, lit, rand} import org.apache.spark.sql.types.LongType class LogisticRegressionSuite @@ -40,6 +40,7 @@ class 
LogisticRegressionSuite import testImplicits._ + private val seed = 42 @transient var smallBinaryDataset: Dataset[_] = _ @transient var smallMultinomialDataset: Dataset[_] = _ @transient var binaryDataset: Dataset[_] = _ @@ -49,7 +50,7 @@ class LogisticRegressionSuite override def beforeAll(): Unit = { super.beforeAll() - smallBinaryDataset = generateLogisticInput(1.0, 1.0, nPoints = 100, seed = 42).toDF() + smallBinaryDataset = generateLogisticInput(1.0, 1.0, nPoints = 100, seed = seed).toDF() smallMultinomialDataset = { val nPoints = 100 @@ -61,7 +62,7 @@ class LogisticRegressionSuite val xVariance = Array(0.6856, 0.1899) val testData = generateMultinomialLogisticInput( - coefficients, xMean, xVariance, addIntercept = true, nPoints, 42) + coefficients, xMean, xVariance, addIntercept = true, nPoints, seed) val df = sc.parallelize(testData, 4).toDF() df.cache() @@ -76,9 +77,9 @@ class LogisticRegressionSuite val testData = generateMultinomialLogisticInput(coefficients, xMean, xVariance, - addIntercept = true, nPoints, 42) + addIntercept = true, nPoints, seed) - sc.parallelize(testData, 4).toDF() + sc.parallelize(testData, 4).toDF().withColumn("weight", rand(seed)) } multinomialDataset = { @@ -91,9 +92,9 @@ class LogisticRegressionSuite val xVariance = Array(0.6856, 0.1899, 3.116, 0.581) val testData = generateMultinomialLogisticInput( - coefficients, xMean, xVariance, addIntercept = true, nPoints, 42) + coefficients, xMean, xVariance, addIntercept = true, nPoints, seed) - val df = sc.parallelize(testData, 4).toDF() + val df = sc.parallelize(testData, 4).toDF().withColumn("weight", rand(seed)) df.cache() df } @@ -104,11 +105,11 @@ class LogisticRegressionSuite * so we can validate the training accuracy compared with R's glmnet package. */ ignore("export test data into CSV format") { - binaryDataset.rdd.map { case Row(label: Double, features: Vector) => - label + "," + features.toArray.mkString(",") + binaryDataset.rdd.map { case Row(label: Double, features: Vector, weight: Double) => + label + "," + weight + "," + features.toArray.mkString(",") }.repartition(1).saveAsTextFile("target/tmp/LogisticRegressionSuite/binaryDataset") - multinomialDataset.rdd.map { case Row(label: Double, features: Vector) => - label + "," + features.toArray.mkString(",") + multinomialDataset.rdd.map { case Row(label: Double, features: Vector, weight: Double) => + label + "," + weight + "," + features.toArray.mkString(",") }.repartition(1).saveAsTextFile("target/tmp/LogisticRegressionSuite/multinomialDataset") } @@ -519,31 +520,35 @@ class LogisticRegressionSuite test("binary logistic regression with intercept without regularization") { val trainer1 = (new LogisticRegression).setFitIntercept(true).setStandardization(true) + .setWeightCol("weight") val trainer2 = (new LogisticRegression).setFitIntercept(true).setStandardization(false) + .setWeightCol("weight") val model1 = trainer1.fit(binaryDataset) val model2 = trainer2.fit(binaryDataset) /* - Using the following R code to load the data and train the model using glmnet package. - - library("glmnet") - data <- read.csv("path", header=FALSE) - label = factor(data$V1) - features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - coefficients = coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 0)) - coefficients + Use the following R code to load the data and train the model using glmnet package. 
+ library("glmnet") + data <- read.csv("path", header=FALSE) + label = factor(data$V1) + w = data$V2 + features = as.matrix(data.frame(data$V3, data$V4, data$V5, data$V6)) + coefficients = coef(glmnet(features, label, weights=w, family="binomial", alpha = 0, + lambda = 0)) + coefficients + $`0` + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) 2.7355261 + data.V3 -0.5734389 + data.V4 0.8911736 + data.V5 -0.3878645 + data.V6 -0.8060570 - 5 x 1 sparse Matrix of class "dgCMatrix" - s0 - (Intercept) 2.8366423 - data.V2 -0.5895848 - data.V3 0.8931147 - data.V4 -0.3925051 - data.V5 -0.7996864 */ - val interceptR = 2.8366423 - val coefficientsR = Vectors.dense(-0.5895848, 0.8931147, -0.3925051, -0.7996864) + val coefficientsR = Vectors.dense(-0.5734389, 0.8911736, -0.3878645, -0.8060570) + val interceptR = 2.7355261 assert(model1.intercept ~== interceptR relTol 1E-3) assert(model1.coefficients ~= coefficientsR relTol 1E-3) @@ -555,413 +560,374 @@ class LogisticRegressionSuite test("binary logistic regression without intercept without regularization") { val trainer1 = (new LogisticRegression).setFitIntercept(false).setStandardization(true) + .setWeightCol("weight") val trainer2 = (new LogisticRegression).setFitIntercept(false).setStandardization(false) + .setWeightCol("weight") val model1 = trainer1.fit(binaryDataset) val model2 = trainer2.fit(binaryDataset) /* - Using the following R code to load the data and train the model using glmnet package. + Use the following R code to load the data and train the model using glmnet package. - library("glmnet") - data <- read.csv("path", header=FALSE) - label = factor(data$V1) - features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - coefficients = - coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 0, intercept=FALSE)) - coefficients + library("glmnet") + data <- read.csv("path", header=FALSE) + label = factor(data$V1) + w = data$V2 + features = as.matrix(data.frame(data$V3, data$V4, data$V5, data$V6)) + coefficients = coef(glmnet(features, label, weights=w, family="binomial", alpha = 0, + lambda = 0, intercept=FALSE)) + coefficients + $`0` + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) . + data.V3 -0.3448461 + data.V4 1.2776453 + data.V5 -0.3539178 + data.V6 -0.7469384 - 5 x 1 sparse Matrix of class "dgCMatrix" - s0 - (Intercept) . - data.V2 -0.3534996 - data.V3 1.2964482 - data.V4 -0.3571741 - data.V5 -0.7407946 */ - val interceptR = 0.0 - val coefficientsR = Vectors.dense(-0.3534996, 1.2964482, -0.3571741, -0.7407946) + val coefficientsR = Vectors.dense(-0.3448461, 1.2776453, -0.3539178, -0.7469384) - assert(model1.intercept ~== interceptR relTol 1E-3) + assert(model1.intercept ~== 0.0 relTol 1E-3) assert(model1.coefficients ~= coefficientsR relTol 1E-2) // Without regularization, with or without standardization should converge to the same solution. 
- assert(model2.intercept ~== interceptR relTol 1E-3) + assert(model2.intercept ~== 0.0 relTol 1E-3) assert(model2.coefficients ~= coefficientsR relTol 1E-2) } test("binary logistic regression with intercept with L1 regularization") { val trainer1 = (new LogisticRegression).setFitIntercept(true) - .setElasticNetParam(1.0).setRegParam(0.12).setStandardization(true) + .setElasticNetParam(1.0).setRegParam(0.12).setStandardization(true).setWeightCol("weight") val trainer2 = (new LogisticRegression).setFitIntercept(true) - .setElasticNetParam(1.0).setRegParam(0.12).setStandardization(false) + .setElasticNetParam(1.0).setRegParam(0.12).setStandardization(false).setWeightCol("weight") val model1 = trainer1.fit(binaryDataset) val model2 = trainer2.fit(binaryDataset) /* - Using the following R code to load the data and train the model using glmnet package. + Use the following R code to load the data and train the model using glmnet package. - library("glmnet") - data <- read.csv("path", header=FALSE) - label = factor(data$V1) - features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - coefficients = coef(glmnet(features,label, family="binomial", alpha = 1, lambda = 0.12)) - coefficients + library("glmnet") + data <- read.csv("path", header=FALSE) + label = factor(data$V1) + w = data$V2 + features = as.matrix(data.frame(data$V3, data$V4, data$V5, data$V6)) + coefficients = coef(glmnet(features, label, weights=w, family="binomial", alpha = 1, + lambda = 0.12, standardize=T)) + coefficients + $`0` + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) -0.06775980 + data.V3 . + data.V4 . + data.V5 -0.03933146 + data.V6 -0.03047580 - 5 x 1 sparse Matrix of class "dgCMatrix" - s0 - (Intercept) -0.05627428 - data.V2 . - data.V3 . - data.V4 -0.04325749 - data.V5 -0.02481551 */ - val interceptR1 = -0.05627428 - val coefficientsR1 = Vectors.dense(0.0, 0.0, -0.04325749, -0.02481551) + val coefficientsRStd = Vectors.dense(0.0, 0.0, -0.03933146, -0.03047580) + val interceptRStd = -0.06775980 - assert(model1.intercept ~== interceptR1 relTol 1E-2) - assert(model1.coefficients ~= coefficientsR1 absTol 2E-2) + assert(model1.intercept ~== interceptRStd relTol 1E-2) + assert(model1.coefficients ~= coefficientsRStd absTol 2E-2) /* - Using the following R code to load the data and train the model using glmnet package. + Use the following R code to load the data and train the model using glmnet package. - library("glmnet") - data <- read.csv("path", header=FALSE) - label = factor(data$V1) - features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - coefficients = coef(glmnet(features,label, family="binomial", alpha = 1, lambda = 0.12, - standardize=FALSE)) - coefficients + library("glmnet") + data <- read.csv("path", header=FALSE) + label = factor(data$V1) + w = data$V2 + features = as.matrix(data.frame(data$V3, data$V4, data$V5, data$V6)) + coefficients = coef(glmnet(features, label, weights=w, family="binomial", alpha = 1, + lambda = 0.12, standardize=F)) + coefficients + $`0` + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) 0.3544768 + data.V3 . + data.V4 . + data.V5 -0.1626191 + data.V6 . - 5 x 1 sparse Matrix of class "dgCMatrix" - s0 - (Intercept) 0.3722152 - data.V2 . - data.V3 . - data.V4 -0.1665453 - data.V5 . 
*/ - val interceptR2 = 0.3722152 - val coefficientsR2 = Vectors.dense(0.0, 0.0, -0.1665453, 0.0) + val coefficientsR = Vectors.dense(0.0, 0.0, -0.1626191, 0.0) + val interceptR = 0.3544768 - assert(model2.intercept ~== interceptR2 relTol 1E-2) - assert(model2.coefficients ~== coefficientsR2 absTol 1E-3) + assert(model2.intercept ~== interceptR relTol 1E-2) + assert(model2.coefficients ~== coefficientsR absTol 1E-3) // TODO: move this to a standalone test of compression after SPARK-17471 assert(model2.coefficients.isInstanceOf[SparseVector]) } test("binary logistic regression without intercept with L1 regularization") { val trainer1 = (new LogisticRegression).setFitIntercept(false) - .setElasticNetParam(1.0).setRegParam(0.12).setStandardization(true) + .setElasticNetParam(1.0).setRegParam(0.12).setStandardization(true).setWeightCol("weight") val trainer2 = (new LogisticRegression).setFitIntercept(false) - .setElasticNetParam(1.0).setRegParam(0.12).setStandardization(false) + .setElasticNetParam(1.0).setRegParam(0.12).setStandardization(false).setWeightCol("weight") val model1 = trainer1.fit(binaryDataset) val model2 = trainer2.fit(binaryDataset) /* - Using the following R code to load the data and train the model using glmnet package. - - library("glmnet") - data <- read.csv("path", header=FALSE) - label = factor(data$V1) - features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - coefficients = coef(glmnet(features,label, family="binomial", alpha = 1, lambda = 0.12, - intercept=FALSE)) - coefficients - - 5 x 1 sparse Matrix of class "dgCMatrix" - s0 - (Intercept) . - data.V2 . - data.V3 . - data.V4 -0.05189203 - data.V5 -0.03891782 - */ - val interceptR1 = 0.0 - val coefficientsR1 = Vectors.dense(0.0, 0.0, -0.05189203, -0.03891782) - - assert(model1.intercept ~== interceptR1 relTol 1E-3) - assert(model1.coefficients ~= coefficientsR1 absTol 1E-3) + Use the following R code to load the data and train the model using glmnet package. - /* - Using the following R code to load the data and train the model using glmnet package. + library("glmnet") + data <- read.csv("path", header=FALSE) + label = factor(data$V1) + w = data$V2 + features = as.matrix(data.frame(data$V3, data$V4, data$V5, data$V6)) + coefficientsStd = coef(glmnet(features, label, weights=w, family="binomial", alpha = 1, + lambda = 0.12, intercept=F, standardize=T)) + coefficients = coef(glmnet(features, label, weights=w, family="binomial", alpha = 1, + lambda = 0.12, intercept=F, standardize=F)) + coefficientsStd + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) . + data.V3 . + data.V4 . + data.V5 -0.04967635 + data.V6 -0.04757757 - library("glmnet") - data <- read.csv("path", header=FALSE) - label = factor(data$V1) - features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - coefficients = coef(glmnet(features,label, family="binomial", alpha = 1, lambda = 0.12, - intercept=FALSE, standardize=FALSE)) - coefficients + coefficients + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) . + data.V3 . + data.V4 . + data.V5 -0.08433195 + data.V6 . - 5 x 1 sparse Matrix of class "dgCMatrix" - s0 - (Intercept) . - data.V2 . - data.V3 . - data.V4 -0.08420782 - data.V5 . 
*/ - val interceptR2 = 0.0 - val coefficientsR2 = Vectors.dense(0.0, 0.0, -0.08420782, 0.0) + val coefficientsRStd = Vectors.dense(0.0, 0.0, -0.04967635, -0.04757757) - assert(model2.intercept ~== interceptR2 absTol 1E-3) - assert(model2.coefficients ~= coefficientsR2 absTol 1E-3) + val coefficientsR = Vectors.dense(0.0, 0.0, -0.08433195, 0.0) + + assert(model1.intercept ~== 0.0 absTol 1E-3) + assert(model1.coefficients ~= coefficientsRStd absTol 1E-3) + assert(model2.intercept ~== 0.0 absTol 1E-3) + assert(model2.coefficients ~= coefficientsR absTol 1E-3) } test("binary logistic regression with intercept with L2 regularization") { val trainer1 = (new LogisticRegression).setFitIntercept(true) - .setElasticNetParam(0.0).setRegParam(1.37).setStandardization(true) + .setElasticNetParam(0.0).setRegParam(1.37).setStandardization(true).setWeightCol("weight") val trainer2 = (new LogisticRegression).setFitIntercept(true) - .setElasticNetParam(0.0).setRegParam(1.37).setStandardization(false) + .setElasticNetParam(0.0).setRegParam(1.37).setStandardization(false).setWeightCol("weight") val model1 = trainer1.fit(binaryDataset) val model2 = trainer2.fit(binaryDataset) /* - Using the following R code to load the data and train the model using glmnet package. - - library("glmnet") - data <- read.csv("path", header=FALSE) - label = factor(data$V1) - features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - coefficients = coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 1.37)) - coefficients - - 5 x 1 sparse Matrix of class "dgCMatrix" - s0 - (Intercept) 0.15021751 - data.V2 -0.07251837 - data.V3 0.10724191 - data.V4 -0.04865309 - data.V5 -0.10062872 - */ - val interceptR1 = 0.15021751 - val coefficientsR1 = Vectors.dense(-0.07251837, 0.10724191, -0.04865309, -0.10062872) - - assert(model1.intercept ~== interceptR1 relTol 1E-3) - assert(model1.coefficients ~= coefficientsR1 relTol 1E-3) + Use the following R code to load the data and train the model using glmnet package. - /* - Using the following R code to load the data and train the model using glmnet package. 
+ library("glmnet") + data <- read.csv("path", header=FALSE) + label = factor(data$V1) + w = data$V2 + features = as.matrix(data.frame(data$V3, data$V4, data$V5, data$V6)) + coefficientsStd = coef(glmnet(features, label, weights=w, family="binomial", alpha = 0, + lambda = 1.37, standardize=T)) + coefficients = coef(glmnet(features, label, weights=w, family="binomial", alpha = 0, + lambda = 1.37, standardize=F)) + coefficientsStd + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) 0.12707703 + data.V3 -0.06980967 + data.V4 0.10803933 + data.V5 -0.04800404 + data.V6 -0.10165096 - library("glmnet") - data <- read.csv("path", header=FALSE) - label = factor(data$V1) - features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - coefficients = coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 1.37, - standardize=FALSE)) - coefficients + coefficients + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) 0.46613016 + data.V3 -0.04944529 + data.V4 0.02326772 + data.V5 -0.11362772 + data.V6 -0.06312848 - 5 x 1 sparse Matrix of class "dgCMatrix" - s0 - (Intercept) 0.48657516 - data.V2 -0.05155371 - data.V3 0.02301057 - data.V4 -0.11482896 - data.V5 -0.06266838 */ - val interceptR2 = 0.48657516 - val coefficientsR2 = Vectors.dense(-0.05155371, 0.02301057, -0.11482896, -0.06266838) + val coefficientsRStd = Vectors.dense(-0.06980967, 0.10803933, -0.04800404, -0.10165096) + val interceptRStd = 0.12707703 + val coefficientsR = Vectors.dense(-0.04944529, 0.02326772, -0.11362772, -0.06312848) + val interceptR = 0.46613016 - assert(model2.intercept ~== interceptR2 relTol 1E-3) - assert(model2.coefficients ~= coefficientsR2 relTol 1E-3) + assert(model1.intercept ~== interceptRStd relTol 1E-3) + assert(model1.coefficients ~= coefficientsRStd relTol 1E-3) + assert(model2.intercept ~== interceptR relTol 1E-3) + assert(model2.coefficients ~= coefficientsR relTol 1E-3) } test("binary logistic regression without intercept with L2 regularization") { val trainer1 = (new LogisticRegression).setFitIntercept(false) - .setElasticNetParam(0.0).setRegParam(1.37).setStandardization(true) + .setElasticNetParam(0.0).setRegParam(1.37).setStandardization(true).setWeightCol("weight") val trainer2 = (new LogisticRegression).setFitIntercept(false) - .setElasticNetParam(0.0).setRegParam(1.37).setStandardization(false) + .setElasticNetParam(0.0).setRegParam(1.37).setStandardization(false).setWeightCol("weight") val model1 = trainer1.fit(binaryDataset) val model2 = trainer2.fit(binaryDataset) /* - Using the following R code to load the data and train the model using glmnet package. + Use the following R code to load the data and train the model using glmnet package. - library("glmnet") - data <- read.csv("path", header=FALSE) - label = factor(data$V1) - features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - coefficients = coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 1.37, - intercept=FALSE)) - coefficients + library("glmnet") + data <- read.csv("path", header=FALSE) + label = factor(data$V1) + w = data$V2 + features = as.matrix(data.frame(data$V3, data$V4, data$V5, data$V6)) + coefficientsStd = coef(glmnet(features, label, weights=w, family="binomial", alpha = 0, + lambda = 1.37, intercept=F, standardize=T)) + coefficients = coef(glmnet(features, label, weights=w, family="binomial", alpha = 0, + lambda = 1.37, intercept=F, standardize=F)) + coefficientsStd + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) . 
+ data.V3 -0.06000152 + data.V4 0.12598737 + data.V5 -0.04669009 + data.V6 -0.09941025 - 5 x 1 sparse Matrix of class "dgCMatrix" + coefficients + 5 x 1 sparse Matrix of class "dgCMatrix" s0 - (Intercept) . - data.V2 -0.06099165 - data.V3 0.12857058 - data.V4 -0.04708770 - data.V5 -0.09799775 - */ - val interceptR1 = 0.0 - val coefficientsR1 = Vectors.dense(-0.06099165, 0.12857058, -0.04708770, -0.09799775) - - assert(model1.intercept ~== interceptR1 absTol 1E-3) - assert(model1.coefficients ~= coefficientsR1 relTol 1E-2) - - /* - Using the following R code to load the data and train the model using glmnet package. - - library("glmnet") - data <- read.csv("path", header=FALSE) - label = factor(data$V1) - features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - coefficients = coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 1.37, - intercept=FALSE, standardize=FALSE)) - coefficients + (Intercept) . + data.V3 -0.005482255 + data.V4 0.048106338 + data.V5 -0.093411640 + data.V6 -0.054149798 - 5 x 1 sparse Matrix of class "dgCMatrix" - s0 - (Intercept) . - data.V2 -0.005679651 - data.V3 0.048967094 - data.V4 -0.093714016 - data.V5 -0.053314311 */ - val interceptR2 = 0.0 - val coefficientsR2 = Vectors.dense(-0.005679651, 0.048967094, -0.093714016, -0.053314311) + val coefficientsRStd = Vectors.dense(-0.06000152, 0.12598737, -0.04669009, -0.09941025) + val coefficientsR = Vectors.dense(-0.005482255, 0.048106338, -0.093411640, -0.054149798) - assert(model2.intercept ~== interceptR2 absTol 1E-3) - assert(model2.coefficients ~= coefficientsR2 relTol 1E-2) + assert(model1.intercept ~== 0.0 absTol 1E-3) + assert(model1.coefficients ~= coefficientsRStd relTol 1E-2) + assert(model2.intercept ~== 0.0 absTol 1E-3) + assert(model2.coefficients ~= coefficientsR relTol 1E-2) } test("binary logistic regression with intercept with ElasticNet regularization") { - val trainer1 = (new LogisticRegression).setFitIntercept(true) - .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(true) + val trainer1 = (new LogisticRegression).setFitIntercept(true).setMaxIter(200) + .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(true).setWeightCol("weight") val trainer2 = (new LogisticRegression).setFitIntercept(true) - .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(false) + .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(false).setWeightCol("weight") val model1 = trainer1.fit(binaryDataset) val model2 = trainer2.fit(binaryDataset) /* - Using the following R code to load the data and train the model using glmnet package. - - library("glmnet") - data <- read.csv("path", header=FALSE) - label = factor(data$V1) - features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - coefficients = coef(glmnet(features,label, family="binomial", alpha = 0.38, lambda = 0.21)) - coefficients - - 5 x 1 sparse Matrix of class "dgCMatrix" - s0 - (Intercept) 0.57734851 - data.V2 -0.05310287 - data.V3 . - data.V4 -0.08849250 - data.V5 -0.15458796 - */ - val interceptR1 = 0.57734851 - val coefficientsR1 = Vectors.dense(-0.05310287, 0.0, -0.08849250, -0.15458796) - - assert(model1.intercept ~== interceptR1 relTol 6E-3) - assert(model1.coefficients ~== coefficientsR1 absTol 5E-3) + Use the following R code to load the data and train the model using glmnet package. - /* - Using the following R code to load the data and train the model using glmnet package. 
+ library("glmnet") + data <- read.csv("path", header=FALSE) + label = factor(data$V1) + w = data$V2 + features = as.matrix(data.frame(data$V3, data$V4, data$V5, data$V6)) + coefficientsStd = coef(glmnet(features, label, weights=w, family="binomial", alpha = 0.38, + lambda = 0.21, standardize=T)) + coefficients = coef(glmnet(features, label, weights=w, family="binomial", alpha = 0.38, + lambda = 0.21, standardize=F)) + coefficientsStd + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) 0.49991996 + data.V3 -0.04131110 + data.V4 . + data.V5 -0.08585233 + data.V6 -0.15875400 - library("glmnet") - data <- read.csv("path", header=FALSE) - label = factor(data$V1) - features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - coefficients = coef(glmnet(features,label, family="binomial", alpha = 0.38, lambda = 0.21, - standardize=FALSE)) - coefficients + coefficients + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) 0.5024256 + data.V3 . + data.V4 . + data.V5 -0.1846038 + data.V6 -0.0559614 - 5 x 1 sparse Matrix of class "dgCMatrix" - s0 - (Intercept) 0.51555993 - data.V2 . - data.V3 . - data.V4 -0.18807395 - data.V5 -0.05350074 */ - val interceptR2 = 0.51555993 - val coefficientsR2 = Vectors.dense(0.0, 0.0, -0.18807395, -0.05350074) - - assert(model2.intercept ~== interceptR2 relTol 6E-3) - assert(model2.coefficients ~= coefficientsR2 absTol 1E-3) + val coefficientsRStd = Vectors.dense(-0.04131110, 0.0, -0.08585233, -0.15875400) + val interceptRStd = 0.49991996 + val coefficientsR = Vectors.dense(0.0, 0.0, -0.1846038, -0.0559614) + val interceptR = 0.5024256 + + assert(model1.intercept ~== interceptRStd relTol 6E-3) + assert(model1.coefficients ~== coefficientsRStd absTol 5E-3) + assert(model2.intercept ~== interceptR relTol 6E-3) + assert(model2.coefficients ~= coefficientsR absTol 1E-3) } test("binary logistic regression without intercept with ElasticNet regularization") { val trainer1 = (new LogisticRegression).setFitIntercept(false) - .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(true) + .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(true).setWeightCol("weight") val trainer2 = (new LogisticRegression).setFitIntercept(false) - .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(false) + .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(false).setWeightCol("weight") val model1 = trainer1.fit(binaryDataset) val model2 = trainer2.fit(binaryDataset) /* - Using the following R code to load the data and train the model using glmnet package. - - library("glmnet") - data <- read.csv("path", header=FALSE) - label = factor(data$V1) - features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - coefficients = coef(glmnet(features,label, family="binomial", alpha = 0.38, lambda = 0.21, - intercept=FALSE)) - coefficients - - 5 x 1 sparse Matrix of class "dgCMatrix" - s0 - (Intercept) . - data.V2 -0.001005743 - data.V3 0.072577857 - data.V4 -0.081203769 - data.V5 -0.142534158 - */ - val interceptR1 = 0.0 - val coefficientsR1 = Vectors.dense(-0.001005743, 0.072577857, -0.081203769, -0.142534158) - - assert(model1.intercept ~== interceptR1 relTol 1E-3) - assert(model1.coefficients ~= coefficientsR1 absTol 1E-2) + Use the following R code to load the data and train the model using glmnet package. - /* - Using the following R code to load the data and train the model using glmnet package. 
+ library("glmnet") + data <- read.csv("path", header=FALSE) + label = factor(data$V1) + w = data$V2 + features = as.matrix(data.frame(data$V3, data$V4, data$V5, data$V6)) + coefficientsStd = coef(glmnet(features, label, weights=w, family="binomial", alpha = 0.38, + lambda = 0.21, intercept=FALSE, standardize=T)) + coefficients = coef(glmnet(features, label, weights=w, family="binomial", alpha = 0.38, + lambda = 0.21, intercept=FALSE, standardize=F)) + coefficientsStd + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) . + data.V3 . + data.V4 0.06859390 + data.V5 -0.07900058 + data.V6 -0.14684320 - library("glmnet") - data <- read.csv("path", header=FALSE) - label = factor(data$V1) - features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - coefficients = coef(glmnet(features,label, family="binomial", alpha = 0.38, lambda = 0.21, - intercept=FALSE, standardize=FALSE)) - coefficients + coefficients + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) . + data.V3 . + data.V4 0.03060637 + data.V5 -0.11126742 + data.V6 . - 5 x 1 sparse Matrix of class "dgCMatrix" - s0 - (Intercept) . - data.V2 . - data.V3 0.03345223 - data.V4 -0.11304532 - data.V5 . */ - val interceptR2 = 0.0 - val coefficientsR2 = Vectors.dense(0.0, 0.03345223, -0.11304532, 0.0) + val coefficientsRStd = Vectors.dense(0.0, 0.06859390, -0.07900058, -0.14684320) + val coefficientsR = Vectors.dense(0.0, 0.03060637, -0.11126742, 0.0) - assert(model2.intercept ~== interceptR2 absTol 1E-3) - assert(model2.coefficients ~= coefficientsR2 absTol 1E-3) + assert(model1.intercept ~== 0.0 relTol 1E-3) + assert(model1.coefficients ~= coefficientsRStd absTol 1E-2) + assert(model2.intercept ~== 0.0 absTol 1E-3) + assert(model2.coefficients ~= coefficientsR absTol 1E-3) } test("binary logistic regression with intercept with strong L1 regularization") { - val trainer1 = (new LogisticRegression).setFitIntercept(true) + val trainer1 = (new LogisticRegression).setFitIntercept(true).setWeightCol("weight") .setElasticNetParam(1.0).setRegParam(6.0).setStandardization(true) - val trainer2 = (new LogisticRegression).setFitIntercept(true) + val trainer2 = (new LogisticRegression).setFitIntercept(true).setWeightCol("weight") .setElasticNetParam(1.0).setRegParam(6.0).setStandardization(false) val model1 = trainer1.fit(binaryDataset) val model2 = trainer2.fit(binaryDataset) - val histogram = binaryDataset.rdd.map { case Row(label: Double, features: Vector) => label } + val histogram = binaryDataset.as[Instance].rdd.map { i => (i.label, i.weight)} .treeAggregate(new MultiClassSummarizer)( seqOp = (c, v) => (c, v) match { - case (classSummarizer: MultiClassSummarizer, label: Double) => classSummarizer.add(label) + case (classSummarizer: MultiClassSummarizer, (label: Double, weight: Double)) => + classSummarizer.add(label, weight) }, combOp = (c1, c2) => (c1, c2) match { case (classSummarizer1: MultiClassSummarizer, classSummarizer2: MultiClassSummarizer) => @@ -989,25 +955,26 @@ class LogisticRegressionSuite assert(model2.coefficients ~= coefficientsTheory absTol 1E-6) /* - TODO: why is this needed? The correctness of L1 regularization is already checked elsewhere Using the following R code to load the data and train the model using glmnet package. 
library("glmnet") data <- read.csv("path", header=FALSE) label = factor(data$V1) - features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - coefficients = coef(glmnet(features,label, family="binomial", alpha = 1.0, lambda = 6.0)) + w = data$V2 + features = as.matrix(data.frame(data$V3, data$V4, data$V5, data$V6)) + coefficients = coef(glmnet(features, label, weights=w, family="binomial", alpha = 1.0, + lambda = 6.0)) coefficients 5 x 1 sparse Matrix of class "dgCMatrix" - s0 - (Intercept) -0.2480643 - data.V2 0.0000000 - data.V3 . - data.V4 . - data.V5 . + s0 + (Intercept) -0.2516986 + data.V3 0.0000000 + data.V4 . + data.V5 . + data.V6 . */ - val interceptR = -0.248065 + val interceptR = -0.2516986 val coefficientsR = Vectors.dense(0.0, 0.0, 0.0, 0.0) assert(model1.intercept ~== interceptR relTol 1E-5) @@ -1015,9 +982,9 @@ class LogisticRegressionSuite } test("multinomial logistic regression with intercept with strong L1 regularization") { - val trainer1 = (new LogisticRegression).setFitIntercept(true) + val trainer1 = (new LogisticRegression).setFitIntercept(true).setWeightCol("weight") .setElasticNetParam(1.0).setRegParam(6.0).setStandardization(true) - val trainer2 = (new LogisticRegression).setFitIntercept(true) + val trainer2 = (new LogisticRegression).setFitIntercept(true).setWeightCol("weight") .setElasticNetParam(1.0).setRegParam(6.0).setStandardization(false) val sqlContext = multinomialDataset.sqlContext @@ -1025,16 +992,17 @@ class LogisticRegressionSuite val model1 = trainer1.fit(multinomialDataset) val model2 = trainer2.fit(multinomialDataset) - val histogram = multinomialDataset.as[LabeledPoint].rdd.map(_.label) + val histogram = multinomialDataset.as[Instance].rdd.map(i => (i.label, i.weight)) .treeAggregate(new MultiClassSummarizer)( seqOp = (c, v) => (c, v) match { - case (classSummarizer: MultiClassSummarizer, label: Double) => classSummarizer.add(label) + case (classSummarizer: MultiClassSummarizer, (label: Double, weight: Double)) => + classSummarizer.add(label, weight) }, combOp = (c1, c2) => (c1, c2) match { case (classSummarizer1: MultiClassSummarizer, classSummarizer2: MultiClassSummarizer) => classSummarizer1.merge(classSummarizer2) }).histogram - val numFeatures = multinomialDataset.as[LabeledPoint].first().features.size + val numFeatures = multinomialDataset.as[Instance].first().features.size val numClasses = histogram.length /* @@ -1068,52 +1036,58 @@ class LogisticRegressionSuite test("multinomial logistic regression with intercept without regularization") { val trainer1 = (new LogisticRegression).setFitIntercept(true) - .setElasticNetParam(0.0).setRegParam(0.0).setStandardization(true).setMaxIter(100) + .setElasticNetParam(0.0).setRegParam(0.0).setStandardization(true).setWeightCol("weight") val trainer2 = (new LogisticRegression).setFitIntercept(true) - .setElasticNetParam(0.0).setRegParam(0.0).setStandardization(false) + .setElasticNetParam(0.0).setRegParam(0.0).setStandardization(false).setWeightCol("weight") val model1 = trainer1.fit(multinomialDataset) val model2 = trainer2.fit(multinomialDataset) /* - Using the following R code to load the data and train the model using glmnet package. 
- > library("glmnet") - > data <- read.csv("path", header=FALSE) - > label = as.factor(data$V1) - > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - > coefficients = coef(glmnet(features, label, family="multinomial", alpha = 0, lambda = 0)) - > coefficients - $`0` - 5 x 1 sparse Matrix of class "dgCMatrix" - s0 - -2.24493379 - V2 0.25096771 - V3 -0.03915938 - V4 0.14766639 - V5 0.36810817 - $`1` - 5 x 1 sparse Matrix of class "dgCMatrix" - s0 - 0.3778931 - V2 -0.3327489 - V3 0.8893666 - V4 -0.2306948 - V5 -0.4442330 - $`2` - 5 x 1 sparse Matrix of class "dgCMatrix" - s0 - 1.86704066 - V2 0.08178121 - V3 -0.85020722 - V4 0.08302840 - V5 0.07612480 - */ + Use the following R code to load the data and train the model using glmnet package. + library("glmnet") + data <- read.csv("path", header=FALSE) + label = as.factor(data$V1) + w = data$V2 + features = as.matrix(data.frame(data$V3, data$V4, data$V5, data$V6)) + coefficients = coef(glmnet(features, label, weights=w, family="multinomial", + alpha = 0, lambda = 0)) + coefficients + $`0` + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + -2.10320093 + data.V3 0.24337896 + data.V4 -0.05916156 + data.V5 0.14446790 + data.V6 0.35976165 + + $`1` + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + 0.3394473 + data.V3 -0.3443375 + data.V4 0.9181331 + data.V5 -0.2283959 + data.V6 -0.4388066 + + $`2` + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + 1.76375361 + data.V3 0.10095851 + data.V4 -0.85897154 + data.V5 0.08392798 + data.V6 0.07904499 + + + */ val coefficientsR = new DenseMatrix(3, 4, Array( - 0.2509677, -0.0391594, 0.1476664, 0.3681082, - -0.3327489, 0.8893666, -0.2306948, -0.4442330, - 0.0817812, -0.8502072, 0.0830284, 0.0761248), isTransposed = true) - val interceptsR = Vectors.dense(-2.2449338, 0.3778931, 1.8670407) + 0.24337896, -0.05916156, 0.14446790, 0.35976165, + -0.3443375, 0.9181331, -0.2283959, -0.4388066, + 0.10095851, -0.85897154, 0.08392798, 0.07904499), isTransposed = true) + val interceptsR = Vectors.dense(-2.10320093, 0.3394473, 1.76375361) assert(model1.coefficientMatrix ~== coefficientsR relTol 0.05) assert(model1.coefficientMatrix.toArray.sum ~== 0.0 absTol eps) @@ -1128,52 +1102,57 @@ class LogisticRegressionSuite test("multinomial logistic regression without intercept without regularization") { val trainer1 = (new LogisticRegression).setFitIntercept(false) - .setElasticNetParam(0.0).setRegParam(0.0).setStandardization(true) + .setElasticNetParam(0.0).setRegParam(0.0).setStandardization(true).setWeightCol("weight") val trainer2 = (new LogisticRegression).setFitIntercept(false) - .setElasticNetParam(0.0).setRegParam(0.0).setStandardization(false) + .setElasticNetParam(0.0).setRegParam(0.0).setStandardization(false).setWeightCol("weight") val model1 = trainer1.fit(multinomialDataset) val model2 = trainer2.fit(multinomialDataset) /* - Using the following R code to load the data and train the model using glmnet package. - library("glmnet") - data <- read.csv("path", header=FALSE) - label = as.factor(data$V1) - features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - coefficients = coef(glmnet(features, label, family="multinomial", alpha = 0, lambda = 0, - intercept=F)) - > coefficients - $`0` - 5 x 1 sparse Matrix of class "dgCMatrix" - s0 - . - V2 0.06992464 - V3 -0.36562784 - V4 0.12142680 - V5 0.32052211 - $`1` - 5 x 1 sparse Matrix of class "dgCMatrix" - s0 - . - V2 -0.3036269 - V3 0.9449630 - V4 -0.2271038 - V5 -0.4364839 - $`2` - 5 x 1 sparse Matrix of class "dgCMatrix" - s0 - . 
- V2 0.2337022 - V3 -0.5793351 - V4 0.1056770 - V5 0.1159618 - */ + Use the following R code to load the data and train the model using glmnet package. + + library("glmnet") + data <- read.csv("path", header=FALSE) + label = as.factor(data$V1) + w = data$V2 + features = as.matrix(data.frame(data$V3, data$V4, data$V5, data$V6)) + coefficients = coef(glmnet(features, label, weights=w, family="multinomial", alpha = 0, + lambda = 0, intercept=F)) + coefficients + $`0` + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + . + data.V3 0.07276291 + data.V4 -0.36325496 + data.V5 0.12015088 + data.V6 0.31397340 + + $`1` + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + . + data.V3 -0.3180040 + data.V4 0.9679074 + data.V5 -0.2252219 + data.V6 -0.4319914 + $`2` + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + . + data.V3 0.2452411 + data.V4 -0.6046524 + data.V5 0.1050710 + data.V6 0.1180180 + + + */ val coefficientsR = new DenseMatrix(3, 4, Array( - 0.0699246, -0.3656278, 0.1214268, 0.3205221, - -0.3036269, 0.9449630, -0.2271038, -0.4364839, - 0.2337022, -0.5793351, 0.1056770, 0.1159618), isTransposed = true) + 0.07276291, -0.36325496, 0.12015088, 0.31397340, + -0.3180040, 0.9679074, -0.2252219, -0.4319914, + 0.2452411, -0.6046524, 0.1050710, 0.1180180), isTransposed = true) assert(model1.coefficientMatrix ~== coefficientsR relTol 0.05) assert(model1.coefficientMatrix.toArray.sum ~== 0.0 absTol eps) @@ -1190,92 +1169,95 @@ class LogisticRegressionSuite // use tighter constraints because OWL-QN solver takes longer to converge val trainer1 = (new LogisticRegression).setFitIntercept(true) .setElasticNetParam(1.0).setRegParam(0.05).setStandardization(true) - .setMaxIter(300).setTol(1e-10) + .setMaxIter(300).setTol(1e-10).setWeightCol("weight") val trainer2 = (new LogisticRegression).setFitIntercept(true) .setElasticNetParam(1.0).setRegParam(0.05).setStandardization(false) - .setMaxIter(300).setTol(1e-10) + .setMaxIter(300).setTol(1e-10).setWeightCol("weight") val model1 = trainer1.fit(multinomialDataset) val model2 = trainer2.fit(multinomialDataset) /* - Use the following R code to load the data and train the model using glmnet package. - library("glmnet") - data <- read.csv("path", header=FALSE) - label = as.factor(data$V1) - features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - coefficientsStd = coef(glmnet(features, label, family="multinomial", alpha = 1, - lambda = 0.05, standardization=T)) - coefficients = coef(glmnet(features, label, family="multinomial", alpha = 1, lambda = 0.05, - standardization=F)) - > coefficientsStd - $`0` - 5 x 1 sparse Matrix of class "dgCMatrix" - s0 - -0.68988825 - V2 . - V3 . - V4 . - V5 0.09404023 - - $`1` - 5 x 1 sparse Matrix of class "dgCMatrix" - s0 - -0.2303499 - V2 -0.1232443 - V3 0.3258380 - V4 -0.1564688 - V5 -0.2053965 - - $`2` - 5 x 1 sparse Matrix of class "dgCMatrix" - s0 - 0.9202381 - V2 . - V3 -0.4803856 - V4 . - V5 . - - > coefficients - $`0` - 5 x 1 sparse Matrix of class "dgCMatrix" - s0 - -0.44893320 - V2 . - V3 . - V4 0.01933812 - V5 0.03666044 - - $`1` - 5 x 1 sparse Matrix of class "dgCMatrix" - s0 - 0.7376760 - V2 -0.0577182 - V3 . - V4 -0.2081718 - V5 -0.1304592 - - $`2` - 5 x 1 sparse Matrix of class "dgCMatrix" - s0 - -0.2887428 - V2 . - V3 . - V4 . - V5 . - */ + Use the following R code to load the data and train the model using glmnet package. 
- val coefficientsRStd = new DenseMatrix(3, 4, Array( - 0.0, 0.0, 0.0, 0.09404023, - -0.1232443, 0.3258380, -0.1564688, -0.2053965, - 0.0, -0.4803856, 0.0, 0.0), isTransposed = true) - val interceptsRStd = Vectors.dense(-0.68988825, -0.2303499, 0.9202381) + library("glmnet") + data <- read.csv("path", header=FALSE) + label = as.factor(data$V1) + w = data$V2 + features = as.matrix(data.frame(data$V3, data$V4, data$V5, data$V6)) + coefficientsStd = coef(glmnet(features, label, weights=w, family="multinomial", + alpha = 1, lambda = 0.05, standardize=T)) + coefficients = coef(glmnet(features, label, weights=w, family="multinomial", alpha = 1, + lambda = 0.05, standardize=F)) + coefficientsStd + $`0` + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + -0.62244703 + data.V3 . + data.V4 . + data.V5 . + data.V6 0.08419825 + $`1` + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + -0.2804845 + data.V3 -0.1336960 + data.V4 0.3717091 + data.V5 -0.1530363 + data.V6 -0.2035286 + + $`2` + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + 0.9029315 + data.V3 . + data.V4 -0.4629737 + data.V5 . + data.V6 . + + + coefficients + $`0` + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + -0.44215290 + data.V3 . + data.V4 . + data.V5 0.01767089 + data.V6 0.02542866 + + $`1` + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + 0.76308326 + data.V3 -0.06818576 + data.V4 . + data.V5 -0.20446351 + data.V6 -0.13017924 + + $`2` + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + -0.3209304 + data.V3 . + data.V4 . + data.V5 . + data.V6 . + + + */ + val coefficientsRStd = new DenseMatrix(3, 4, Array( + 0.0, 0.0, 0.0, 0.08419825, + -0.1336960, 0.3717091, -0.1530363, -0.2035286, + 0.0, -0.4629737, 0.0, 0.0), isTransposed = true) + val interceptsRStd = Vectors.dense(-0.62244703, -0.2804845, 0.9029315) val coefficientsR = new DenseMatrix(3, 4, Array( - 0.0, 0.0, 0.01933812, 0.03666044, - -0.0577182, 0.0, -0.2081718, -0.1304592, + 0.0, 0.0, 0.01767089, 0.02542866, + -0.06818576, 0.0, -0.20446351, -0.13017924, 0.0, 0.0, 0.0, 0.0), isTransposed = true) - val interceptsR = Vectors.dense(-0.44893320, 0.7376760, -0.2887428) + val interceptsR = Vectors.dense(-0.44215290, 0.76308326, -0.3209304) assert(model1.coefficientMatrix ~== coefficientsRStd absTol 0.02) assert(model1.interceptVector ~== interceptsRStd relTol 0.1) @@ -1287,87 +1269,91 @@ class LogisticRegressionSuite test("multinomial logistic regression without intercept with L1 regularization") { val trainer1 = (new LogisticRegression).setFitIntercept(false) - .setElasticNetParam(1.0).setRegParam(0.05).setStandardization(true) + .setElasticNetParam(1.0).setRegParam(0.05).setStandardization(true).setWeightCol("weight") val trainer2 = (new LogisticRegression).setFitIntercept(false) - .setElasticNetParam(1.0).setRegParam(0.05).setStandardization(false) + .setElasticNetParam(1.0).setRegParam(0.05).setStandardization(false).setWeightCol("weight") val model1 = trainer1.fit(multinomialDataset) val model2 = trainer2.fit(multinomialDataset) /* Use the following R code to load the data and train the model using glmnet package. 
+ library("glmnet") data <- read.csv("path", header=FALSE) label = as.factor(data$V1) - features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - coefficientsStd = coef(glmnet(features, label, family="multinomial", alpha = 1, - lambda = 0.05, intercept=F, standardization=T)) - coefficients = coef(glmnet(features, label, family="multinomial", alpha = 1, lambda = 0.05, - intercept=F, standardization=F)) - > coefficientsStd + w = data$V2 + features = as.matrix(data.frame(data$V3, data$V4, data$V5, data$V6)) + coefficientsStd = coef(glmnet(features, label, weights=w, family="multinomial", alpha = 1, + lambda = 0.05, intercept=F, standardize=T)) + coefficients = coef(glmnet(features, label, weights=w, family="multinomial", alpha = 1, + lambda = 0.05, intercept=F, standardize=F)) + coefficientsStd $`0` 5 x 1 sparse Matrix of class "dgCMatrix" - s0 - . - V2 . - V3 . - V4 . - V5 0.01525105 + s0 + . + data.V3 . + data.V4 . + data.V5 . + data.V6 0.01144225 $`1` 5 x 1 sparse Matrix of class "dgCMatrix" - s0 - . - V2 -0.1502410 - V3 0.5134658 - V4 -0.1601146 - V5 -0.2500232 + s0 + . + data.V3 -0.1678787 + data.V4 0.5385351 + data.V5 -0.1573039 + data.V6 -0.2471624 $`2` 5 x 1 sparse Matrix of class "dgCMatrix" - s0 - . - V2 0.003301875 - V3 . - V4 . - V5 . - - > coefficients + s0 + . + data.V3 . + data.V4 . + data.V5 . + data.V6 . + + + coefficients $`0` 5 x 1 sparse Matrix of class "dgCMatrix" - s0 - . - V2 . - V3 . - V4 . - V5 . + s0 + . + data.V3 . + data.V4 . + data.V5 . + data.V6 . $`1` 5 x 1 sparse Matrix of class "dgCMatrix" - s0 - . - V2 . - V3 0.1943624 - V4 -0.1902577 - V5 -0.1028789 + s0 + . + data.V3 . + data.V4 0.1929409 + data.V5 -0.1889121 + data.V6 -0.1010413 $`2` 5 x 1 sparse Matrix of class "dgCMatrix" - s0 - . - V2 . - V3 . - V4 . - V5 . - */ + s0 + . + data.V3 . + data.V4 . + data.V5 . + data.V6 . + + */ val coefficientsRStd = new DenseMatrix(3, 4, Array( - 0.0, 0.0, 0.0, 0.01525105, - -0.1502410, 0.5134658, -0.1601146, -0.2500232, - 0.003301875, 0.0, 0.0, 0.0), isTransposed = true) + 0.0, 0.0, 0.0, 0.01144225, + -0.1678787, 0.5385351, -0.1573039, -0.2471624, + 0.0, 0.0, 0.0, 0.0), isTransposed = true) val coefficientsR = new DenseMatrix(3, 4, Array( 0.0, 0.0, 0.0, 0.0, - 0.0, 0.1943624, -0.1902577, -0.1028789, + 0.0, 0.1929409, -0.1889121, -0.1010413, 0.0, 0.0, 0.0, 0.0), isTransposed = true) assert(model1.coefficientMatrix ~== coefficientsRStd absTol 0.01) @@ -1380,92 +1366,95 @@ class LogisticRegressionSuite test("multinomial logistic regression with intercept with L2 regularization") { val trainer1 = (new LogisticRegression).setFitIntercept(true) - .setElasticNetParam(0.0).setRegParam(0.1).setStandardization(true) + .setElasticNetParam(0.0).setRegParam(0.1).setStandardization(true).setWeightCol("weight") val trainer2 = (new LogisticRegression).setFitIntercept(true) - .setElasticNetParam(0.0).setRegParam(0.1).setStandardization(false) + .setElasticNetParam(0.0).setRegParam(0.1).setStandardization(false).setWeightCol("weight") val model1 = trainer1.fit(multinomialDataset) val model2 = trainer2.fit(multinomialDataset) /* Use the following R code to load the data and train the model using glmnet package. 
+ library("glmnet") data <- read.csv("path", header=FALSE) label = as.factor(data$V1) - features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - coefficientsStd = coef(glmnet(features, label, family="multinomial", alpha = 0, - lambda = 0.1, intercept=T, standardization=T)) - coefficients = coef(glmnet(features, label, family="multinomial", alpha = 0, - lambda = 0.1, intercept=T, standardization=F)) - > coefficientsStd + w = data$V2 + features = as.matrix(data.frame( data$V3, data$V4, data$V5, data$V6)) + coefficientsStd = coef(glmnet(features, label, weights=w, family="multinomial", + alpha = 0, lambda = 0.1, intercept=T, standardize=T)) + coefficients = coef(glmnet(features, label, weights=w, family="multinomial", alpha = 0, + lambda = 0.1, intercept=T, standardize=F)) + coefficientsStd $`0` 5 x 1 sparse Matrix of class "dgCMatrix" - s0 - -1.70040424 - V2 0.17576070 - V3 0.01527894 - V4 0.10216108 - V5 0.26099531 + s0 + -1.5898288335 + data.V3 0.1691226336 + data.V4 0.0002983651 + data.V5 0.1001732896 + data.V6 0.2554575585 $`1` 5 x 1 sparse Matrix of class "dgCMatrix" - s0 - 0.2438590 - V2 -0.2238875 - V3 0.5967610 - V4 -0.1555496 - V5 -0.3010479 + s0 + 0.2125746 + data.V3 -0.2304586 + data.V4 0.6153492 + data.V5 -0.1537017 + data.V6 -0.2975443 $`2` 5 x 1 sparse Matrix of class "dgCMatrix" - s0 - 1.45654525 - V2 0.04812679 - V3 -0.61203992 - V4 0.05338850 - V5 0.04005258 - - > coefficients + s0 + 1.37725427 + data.V3 0.06133600 + data.V4 -0.61564761 + data.V5 0.05352840 + data.V6 0.04208671 + + + coefficients $`0` 5 x 1 sparse Matrix of class "dgCMatrix" - s0 - -1.65488543 - V2 0.15715048 - V3 0.01992903 - V4 0.12428858 - V5 0.22130317 + s0 + -1.5681088 + data.V3 0.1508182 + data.V4 0.0121955 + data.V5 0.1217930 + data.V6 0.2162850 $`1` 5 x 1 sparse Matrix of class "dgCMatrix" - s0 - 1.1297533 - V2 -0.1974768 - V3 0.2776373 - V4 -0.1869445 - V5 -0.2510320 + s0 + 1.1217130 + data.V3 -0.2028984 + data.V4 0.2862431 + data.V5 -0.1843559 + data.V6 -0.2481218 $`2` 5 x 1 sparse Matrix of class "dgCMatrix" - s0 - 0.52513212 - V2 0.04032627 - V3 -0.29756637 - V4 0.06265594 - V5 0.02972883 - */ + s0 + 0.44639579 + data.V3 0.05208012 + data.V4 -0.29843864 + data.V5 0.06256289 + data.V6 0.03183676 - val coefficientsRStd = new DenseMatrix(3, 4, Array( - 0.17576070, 0.01527894, 0.10216108, 0.26099531, - -0.2238875, 0.5967610, -0.1555496, -0.3010479, - 0.04812679, -0.61203992, 0.05338850, 0.04005258), isTransposed = true) - val interceptsRStd = Vectors.dense(-1.70040424, 0.2438590, 1.45654525) + */ + val coefficientsRStd = new DenseMatrix(3, 4, Array( + 0.1691226336, 0.0002983651, 0.1001732896, 0.2554575585, + -0.2304586, 0.6153492, -0.1537017, -0.2975443, + 0.06133600, -0.61564761, 0.05352840, 0.04208671), isTransposed = true) + val interceptsRStd = Vectors.dense(-1.5898288335, 0.2125746, 1.37725427) val coefficientsR = new DenseMatrix(3, 4, Array( - 0.15715048, 0.01992903, 0.12428858, 0.22130317, - -0.1974768, 0.2776373, -0.1869445, -0.2510320, - 0.04032627, -0.29756637, 0.06265594, 0.02972883), isTransposed = true) - val interceptsR = Vectors.dense(-1.65488543, 1.1297533, 0.52513212) + 0.1508182, 0.0121955, 0.1217930, 0.2162850, + -0.2028984, 0.2862431, -0.1843559, -0.2481218, + 0.05208012, -0.29843864, 0.06256289, 0.03183676), isTransposed = true) + val interceptsR = Vectors.dense(-1.5681088, 1.1217130, 0.44639579) - assert(model1.coefficientMatrix ~== coefficientsRStd relTol 0.05) + assert(model1.coefficientMatrix ~== coefficientsRStd absTol 0.001) assert(model1.interceptVector ~== 
interceptsRStd relTol 0.05) assert(model1.interceptVector.toArray.sum ~== 0.0 absTol eps) assert(model2.coefficientMatrix ~== coefficientsR relTol 0.05) @@ -1475,86 +1464,92 @@ class LogisticRegressionSuite test("multinomial logistic regression without intercept with L2 regularization") { val trainer1 = (new LogisticRegression).setFitIntercept(false) - .setElasticNetParam(0.0).setRegParam(0.1).setStandardization(true) + .setElasticNetParam(0.0).setRegParam(0.1).setStandardization(true).setWeightCol("weight") val trainer2 = (new LogisticRegression).setFitIntercept(false) - .setElasticNetParam(0.0).setRegParam(0.1).setStandardization(false) + .setElasticNetParam(0.0).setRegParam(0.1).setStandardization(false).setWeightCol("weight") val model1 = trainer1.fit(multinomialDataset) val model2 = trainer2.fit(multinomialDataset) /* Use the following R code to load the data and train the model using glmnet package. + library("glmnet") data <- read.csv("path", header=FALSE) label = as.factor(data$V1) - features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - coefficientsStd = coef(glmnet(features, label, family="multinomial", alpha = 0, - lambda = 0.1, intercept=F, standardization=T)) - coefficients = coef(glmnet(features, label, family="multinomial", alpha = 0, - lambda = 0.1, intercept=F, standardization=F)) - > coefficientsStd + w = data$V2 + features = as.matrix(data.frame(data$V3, data$V4, data$V5, data$V6)) + coefficientsStd = coef(glmnet(features, label, weights=w, family="multinomial", alpha = 0, + lambda = 0.1, intercept=F, standardize=T)) + coefficients = coef(glmnet(features, label, weights=w, family="multinomial", alpha = 0, + lambda = 0.1, intercept=F, standardize=F)) + coefficientsStd $`0` 5 x 1 sparse Matrix of class "dgCMatrix" - s0 - . - V2 0.03904171 - V3 -0.23354322 - V4 0.08288096 - V5 0.22706393 + s0 + . + data.V3 0.04048126 + data.V4 -0.23075758 + data.V5 0.08228864 + data.V6 0.22277648 $`1` 5 x 1 sparse Matrix of class "dgCMatrix" - s0 - . - V2 -0.2061848 - V3 0.6341398 - V4 -0.1530059 - V5 -0.2958455 + s0 + . + data.V3 -0.2149745 + data.V4 0.6478666 + data.V5 -0.1515158 + data.V6 -0.2930498 $`2` 5 x 1 sparse Matrix of class "dgCMatrix" - s0 - . - V2 0.16714312 - V3 -0.40059658 - V4 0.07012496 - V5 0.06878158 - > coefficients + s0 + . + data.V3 0.17449321 + data.V4 -0.41710901 + data.V5 0.06922716 + data.V6 0.07027332 + + + coefficients $`0` 5 x 1 sparse Matrix of class "dgCMatrix" - s0 - . - V2 -0.005704542 - V3 -0.144466409 - V4 0.092080736 - V5 0.182927657 + s0 + . + data.V3 -0.003949652 + data.V4 -0.142982415 + data.V5 0.091439598 + data.V6 0.179286241 $`1` 5 x 1 sparse Matrix of class "dgCMatrix" - s0 - . - V2 -0.08469036 - V3 0.38996748 - V4 -0.16468436 - V5 -0.22522976 + s0 + . + data.V3 -0.09071124 + data.V4 0.39752531 + data.V5 -0.16233832 + data.V6 -0.22206059 $`2` 5 x 1 sparse Matrix of class "dgCMatrix" - s0 - . - V2 0.09039490 - V3 -0.24550107 - V4 0.07260362 - V5 0.04230210 + s0 + . 
+ data.V3 0.09466090 + data.V4 -0.25454290 + data.V5 0.07089872 + data.V6 0.04277435 + + */ val coefficientsRStd = new DenseMatrix(3, 4, Array( - 0.03904171, -0.23354322, 0.08288096, 0.2270639, - -0.2061848, 0.6341398, -0.1530059, -0.2958455, - 0.16714312, -0.40059658, 0.07012496, 0.06878158), isTransposed = true) + 0.04048126, -0.23075758, 0.08228864, 0.22277648, + -0.2149745, 0.6478666, -0.1515158, -0.2930498, + 0.17449321, -0.41710901, 0.06922716, 0.07027332), isTransposed = true) val coefficientsR = new DenseMatrix(3, 4, Array( - -0.005704542, -0.144466409, 0.092080736, 0.182927657, - -0.08469036, 0.38996748, -0.16468436, -0.22522976, - 0.0903949, -0.24550107, 0.07260362, 0.0423021), isTransposed = true) + -0.003949652, -0.142982415, 0.091439598, 0.179286241, + -0.09071124, 0.39752531, -0.16233832, -0.22206059, + 0.09466090, -0.25454290, 0.07089872, 0.04277435), isTransposed = true) assert(model1.coefficientMatrix ~== coefficientsRStd absTol 0.01) assert(model1.interceptVector.toArray === Array.fill(3)(0.0)) @@ -1565,10 +1560,10 @@ class LogisticRegressionSuite } test("multinomial logistic regression with intercept with elasticnet regularization") { - val trainer1 = (new LogisticRegression).setFitIntercept(true) + val trainer1 = (new LogisticRegression).setFitIntercept(true).setWeightCol("weight") .setElasticNetParam(0.5).setRegParam(0.1).setStandardization(true) .setMaxIter(300).setTol(1e-10) - val trainer2 = (new LogisticRegression).setFitIntercept(true) + val trainer2 = (new LogisticRegression).setFitIntercept(true).setWeightCol("weight") .setElasticNetParam(0.5).setRegParam(0.1).setStandardization(false) .setMaxIter(300).setTol(1e-10) @@ -1576,82 +1571,85 @@ class LogisticRegressionSuite val model2 = trainer2.fit(multinomialDataset) /* Use the following R code to load the data and train the model using glmnet package. + library("glmnet") data <- read.csv("path", header=FALSE) label = as.factor(data$V1) - features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - coefficientsStd = coef(glmnet(features, label, family="multinomial", alpha = 0.5, - lambda = 0.1, intercept=T, standardization=T)) - coefficients = coef(glmnet(features, label, family="multinomial", alpha = 0.5, - lambda = 0.1, intercept=T, standardization=F)) - > coefficientsStd + w = data$V2 + features = as.matrix(data.frame(data$V3, data$V4, data$V5, data$V6)) + coefficientsStd = coef(glmnet(features, label, weights=w, family="multinomial", alpha = 0.5, + lambda = 0.1, intercept=T, standardize=T)) + coefficients = coef(glmnet(features, label, weights=w, family="multinomial", alpha = 0.5, + lambda = 0.1, intercept=T, standardize=F)) + coefficientsStd $`0` 5 x 1 sparse Matrix of class "dgCMatrix" - s0 - -0.5521819483 - V2 0.0003092611 - V3 . - V4 . - V5 0.0913818490 + s0 + -0.50133383 + data.V3 . + data.V4 . + data.V5 . + data.V6 0.08351653 $`1` 5 x 1 sparse Matrix of class "dgCMatrix" - s0 - -0.27531989 - V2 -0.09790029 - V3 0.28502034 - V4 -0.12416487 - V5 -0.16513373 + s0 + -0.3151913 + data.V3 -0.1058702 + data.V4 0.3183251 + data.V5 -0.1212969 + data.V6 -0.1629778 $`2` 5 x 1 sparse Matrix of class "dgCMatrix" - s0 - 0.8275018 - V2 . - V3 -0.4044859 - V4 . - V5 . - - > coefficients + s0 + 0.8165252 + data.V3 . + data.V4 -0.3943069 + data.V5 . + data.V6 . + + + coefficients $`0` 5 x 1 sparse Matrix of class "dgCMatrix" - s0 - -0.39876213 - V2 . - V3 . - V4 0.02547520 - V5 0.03893991 + s0 + -0.38857157 + data.V3 . + data.V4 . 
+ data.V5 0.02384198 + data.V6 0.03127749 $`1` 5 x 1 sparse Matrix of class "dgCMatrix" - s0 - 0.61089869 - V2 -0.04224269 - V3 . - V4 -0.18923970 - V5 -0.09104249 + s0 + 0.62492165 + data.V3 -0.04949061 + data.V4 . + data.V5 -0.18584462 + data.V6 -0.08952455 $`2` 5 x 1 sparse Matrix of class "dgCMatrix" - s0 - -0.2121366 - V2 . - V3 . - V4 . - V5 . - */ + s0 + -0.2363501 + data.V3 . + data.V4 . + data.V5 . + data.V6 . - val coefficientsRStd = new DenseMatrix(3, 4, Array( - 0.0003092611, 0.0, 0.0, 0.091381849, - -0.09790029, 0.28502034, -0.12416487, -0.16513373, - 0.0, -0.4044859, 0.0, 0.0), isTransposed = true) - val interceptsRStd = Vectors.dense(-0.5521819483, -0.27531989, 0.8275018) + */ + val coefficientsRStd = new DenseMatrix(3, 4, Array( + 0.0, 0.0, 0.0, 0.08351653, + -0.1058702, 0.3183251, -0.1212969, -0.1629778, + 0.0, -0.3943069, 0.0, 0.0), isTransposed = true) + val interceptsRStd = Vectors.dense(-0.50133383, -0.3151913, 0.8165252) val coefficientsR = new DenseMatrix(3, 4, Array( - 0.0, 0.0, 0.0254752, 0.03893991, - -0.04224269, 0.0, -0.1892397, -0.09104249, + 0.0, 0.0, 0.02384198, 0.03127749, + -0.04949061, 0.0, -0.18584462, -0.08952455, 0.0, 0.0, 0.0, 0.0), isTransposed = true) - val interceptsR = Vectors.dense(-0.39876213, 0.61089869, -0.2121366) + val interceptsR = Vectors.dense(-0.38857157, 0.62492165, -0.2363501) assert(model1.coefficientMatrix ~== coefficientsRStd absTol 0.01) assert(model1.interceptVector ~== interceptsRStd absTol 0.01) @@ -1662,10 +1660,10 @@ class LogisticRegressionSuite } test("multinomial logistic regression without intercept with elasticnet regularization") { - val trainer1 = (new LogisticRegression).setFitIntercept(false) + val trainer1 = (new LogisticRegression).setFitIntercept(false).setWeightCol("weight") .setElasticNetParam(0.5).setRegParam(0.1).setStandardization(true) .setMaxIter(300).setTol(1e-10) - val trainer2 = (new LogisticRegression).setFitIntercept(false) + val trainer2 = (new LogisticRegression).setFitIntercept(false).setWeightCol("weight") .setElasticNetParam(0.5).setRegParam(0.1).setStandardization(false) .setMaxIter(300).setTol(1e-10) @@ -1673,78 +1671,83 @@ class LogisticRegressionSuite val model2 = trainer2.fit(multinomialDataset) /* Use the following R code to load the data and train the model using glmnet package. + library("glmnet") data <- read.csv("path", header=FALSE) label = as.factor(data$V1) - features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - coefficientsStd = coef(glmnet(features, label, family="multinomial", alpha = 0.5, - lambda = 0.1, intercept=F, standardization=T)) - coefficients = coef(glmnet(features, label, family="multinomial", alpha = 0.5, - lambda = 0.1, intercept=F, standardization=F)) - > coefficientsStd + w = data$V2 + features = as.matrix(data.frame(data$V3, data$V4, data$V5, data$V6)) + coefficientsStd = coef(glmnet(features, label, weights=w, family="multinomial", alpha = 0.5, + lambda = 0.1, intercept=F, standardize=T)) + coefficients = coef(glmnet(features, label, weights=w, family="multinomial", alpha = 0.5, + lambda = 0.1, intercept=F, standardize=F)) + coefficientsStd $`0` 5 x 1 sparse Matrix of class "dgCMatrix" - s0 - . - V2 . - V3 . - V4 . - V5 0.03543706 + s0 + . + data.V3 . + data.V4 . + data.V5 . + data.V6 0.03238285 $`1` 5 x 1 sparse Matrix of class "dgCMatrix" - s0 - . - V2 -0.1187387 - V3 0.4025482 - V4 -0.1270969 - V5 -0.1918386 + s0 + . 
+ data.V3 -0.1328284 + data.V4 0.4219321 + data.V5 -0.1247544 + data.V6 -0.1893318 $`2` 5 x 1 sparse Matrix of class "dgCMatrix" - s0 - . - V2 0.00774365 - V3 . - V4 . - V5 . - - > coefficients + s0 + . + data.V3 0.004572312 + data.V4 . + data.V5 . + data.V6 . + + + coefficients $`0` 5 x 1 sparse Matrix of class "dgCMatrix" - s0 - . - V2 . - V3 . - V4 . - V5 . + s0 + . + data.V3 . + data.V4 . + data.V5 . + data.V6 . $`1` 5 x 1 sparse Matrix of class "dgCMatrix" - s0 - . - V2 . - V3 0.14666497 - V4 -0.16570638 - V5 -0.05982875 + s0 + . + data.V3 . + data.V4 0.14571623 + data.V5 -0.16456351 + data.V6 -0.05866264 $`2` 5 x 1 sparse Matrix of class "dgCMatrix" - s0 - . - V2 . - V3 . - V4 . - V5 . + s0 + . + data.V3 . + data.V4 . + data.V5 . + data.V6 . + + */ val coefficientsRStd = new DenseMatrix(3, 4, Array( - 0.0, 0.0, 0.0, 0.03543706, - -0.1187387, 0.4025482, -0.1270969, -0.1918386, - 0.0, 0.0, 0.0, 0.00774365), isTransposed = true) + 0.0, 0.0, 0.0, 0.03238285, + -0.1328284, 0.4219321, -0.1247544, -0.1893318, + 0.004572312, 0.0, 0.0, 0.0), isTransposed = true) val coefficientsR = new DenseMatrix(3, 4, Array( 0.0, 0.0, 0.0, 0.0, - 0.0, 0.14666497, -0.16570638, -0.05982875, + 0.0, 0.14571623, -0.16456351, -0.05866264, 0.0, 0.0, 0.0, 0.0), isTransposed = true) assert(model1.coefficientMatrix ~== coefficientsRStd absTol 0.01) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala index c08cb695806d0..41684d92be33a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala @@ -51,7 +51,7 @@ class MultilayerPerceptronClassifierSuite test("Input Validation") { val mlpc = new MultilayerPerceptronClassifier() intercept[IllegalArgumentException] { - mlpc.setLayers(Array[Int]()) + mlpc.setLayers(Array.empty[Int]) } intercept[IllegalArgumentException] { mlpc.setLayers(Array[Int](1)) diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala index 4f7d4418a8d09..f2368a9f8dad5 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala @@ -68,7 +68,7 @@ class BisectingKMeansSuite } } - test("fit & transform") { + test("fit, transform and summary") { val predictionColName = "bisecting_kmeans_prediction" val bkm = new BisectingKMeans().setK(k).setPredictionCol(predictionColName).setSeed(1) val model = bkm.fit(dataset) @@ -85,6 +85,22 @@ class BisectingKMeansSuite assert(clusters === Set(0, 1, 2, 3, 4)) assert(model.computeCost(dataset) < 0.1) assert(model.hasParent) + + // Check validity of model summary + val numRows = dataset.count() + assert(model.hasSummary) + val summary: BisectingKMeansSummary = model.summary + assert(summary.predictionCol === predictionColName) + assert(summary.featuresCol === "features") + assert(summary.predictions.count() === numRows) + for (c <- Array(predictionColName, "features")) { + assert(summary.predictions.columns.contains(c)) + } + assert(summary.cluster.columns === Array(predictionColName)) + val clusterSizes = summary.clusterSizes + assert(clusterSizes.length === k) + assert(clusterSizes.sum === numRows) + assert(clusterSizes.forall(_ 
>= 0)) } test("read/write") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala index 04366f5250287..003fa6abf6597 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala @@ -70,7 +70,7 @@ class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext } } - test("fit, transform, and summary") { + test("fit, transform and summary") { val predictionColName = "gm_prediction" val probabilityColName = "gm_probability" val gm = new GaussianMixture().setK(k).setMaxIter(2).setPredictionCol(predictionColName) diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala index c9ba5a288aadf..ca392653557c4 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala @@ -82,7 +82,7 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR } } - test("fit, transform, and summary") { + test("fit, transform and summary") { val predictionColName = "kmeans_prediction" val kmeans = new KMeans().setK(k).setPredictionCol(predictionColName).setSeed(1) val model = kmeans.fit(dataset) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala index 87cdceb267387..aac29137d7911 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala @@ -99,21 +99,32 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa .setOutputCol("result") .setSplits(splits) + bucketizer.setHandleInvalid("keep") bucketizer.transform(dataFrame).select("result", "expected").collect().foreach { case Row(x: Double, y: Double) => assert(x === y, s"The feature value is not correct after bucketing. 
Expected $y but found $x") } + + bucketizer.setHandleInvalid("skip") + val skipResults: Array[Double] = bucketizer.transform(dataFrame) + .select("result").as[Double].collect() + assert(skipResults.length === 7) + assert(skipResults.forall(_ !== 4.0)) + + bucketizer.setHandleInvalid("error") + withClue("Bucketizer should throw error when setHandleInvalid=error and given NaN values") { + intercept[SparkException] { + bucketizer.transform(dataFrame).collect() + } + } } test("Bucket continuous features, with NaN splits") { val splits = Array(Double.NegativeInfinity, -0.5, 0.0, 0.5, Double.PositiveInfinity, Double.NaN) - withClue("Invalid NaN split was not caught as an invalid split!") { + withClue("Invalid NaN split was not caught during Bucketizer initialization") { intercept[IllegalArgumentException] { - val bucketizer: Bucketizer = new Bucketizer() - .setInputCol("feature") - .setOutputCol("result") - .setSplits(splits) + new Bucketizer().setSplits(splits) } } } @@ -138,7 +149,8 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa val data = Array.fill(100)(Random.nextDouble()) val splits: Array[Double] = Double.NegativeInfinity +: Array.fill(10)(Random.nextDouble()).sorted :+ Double.PositiveInfinity - val bsResult = Vectors.dense(data.map(x => Bucketizer.binarySearchForBuckets(splits, x))) + val bsResult = Vectors.dense(data.map(x => + Bucketizer.binarySearchForBuckets(splits, x, false))) val lsResult = Vectors.dense(data.map(x => BucketizerSuite.linearSearchForBuckets(splits, x))) assert(bsResult ~== lsResult absTol 1e-5) } @@ -169,7 +181,7 @@ private object BucketizerSuite extends SparkFunSuite { /** Check all values in splits, plus values between all splits. */ def checkBinarySearch(splits: Array[Double]): Unit = { def testFeature(feature: Double, expectedBucket: Double): Unit = { - assert(Bucketizer.binarySearchForBuckets(splits, feature) === expectedBucket, + assert(Bucketizer.binarySearchForBuckets(splits, feature, false) === expectedBucket, s"Expected feature value $feature to be in bucket $expectedBucket with splits:" + s" ${splits.mkString(", ")}") } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala index dfebfc87ea1d3..6af06d82d671a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala @@ -38,10 +38,10 @@ class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext ) val preFilteredData = Seq( - Vectors.dense(0.0), - Vectors.dense(6.0), Vectors.dense(8.0), - Vectors.dense(5.0) + Vectors.dense(0.0), + Vectors.dense(0.0), + Vectors.dense(8.0) ) val df = sc.parallelize(data.zip(preFilteredData)) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/LSHTest.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/LSHTest.scala new file mode 100644 index 0000000000000..5c025546f332b --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/LSHTest.scala @@ -0,0 +1,153 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.spark.ml.linalg.{Vector, VectorUDT} +import org.apache.spark.ml.util.SchemaUtils +import org.apache.spark.sql.Dataset +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types.DataTypes + +private[ml] object LSHTest { + /** + * For any locality sensitive function h in a metric space, we need to verify whether + * the following property is satisfied. + * + * There exist dist1, dist2, p1, p2, so that for any two elements e1 and e2, + * If dist(e1, e2) <= dist1, then Pr{h(e1) == h(e2)} >= p1 + * If dist(e1, e2) >= dist2, then Pr{h(e1) == h(e2)} <= p2 + * + * This is called the locality sensitive property. This method checks the property on an + * existing dataset and calculates the probabilities. + * (https://en.wikipedia.org/wiki/Locality-sensitive_hashing#Definition) + * + * This method hashes each element to hash buckets using LSH, and calculates the false positive + * and false negative rates: + * False positive: Of all (e1, e2) sharing any bucket, the probability of dist(e1, e2) > distFP + * False negative: Of all (e1, e2) not sharing buckets, the probability of dist(e1, e2) < distFN + * + * @param dataset The dataset to verify the locality sensitive hashing property. + * @param lsh The lsh instance to perform the hashing + * @param distFP Distance threshold for false positive + * @param distFN Distance threshold for false negative + * @tparam T The class type of lsh + * @return A tuple of two doubles, representing the false positive and false negative rates + */ + def calculateLSHProperty[T <: LSHModel[T]]( + dataset: Dataset[_], + lsh: LSH[T], + distFP: Double, + distFN: Double): (Double, Double) = { + val model = lsh.fit(dataset) + val inputCol = model.getInputCol + val outputCol = model.getOutputCol + val transformedData = model.transform(dataset) + + SchemaUtils.checkColumnType(transformedData.schema, model.getOutputCol, new VectorUDT) + + // Perform a cross join and label each pair of same_bucket and distance + val pairs = transformedData.as("a").crossJoin(transformedData.as("b")) + val distUDF = udf((x: Vector, y: Vector) => model.keyDistance(x, y), DataTypes.DoubleType) + val sameBucket = udf((x: Vector, y: Vector) => model.hashDistance(x, y) == 0.0, + DataTypes.BooleanType) + val result = pairs + .withColumn("same_bucket", sameBucket(col(s"a.$outputCol"), col(s"b.$outputCol"))) + .withColumn("distance", distUDF(col(s"a.$inputCol"), col(s"b.$inputCol"))) + + // Compute the probabilities based on the join result + val positive = result.filter(col("same_bucket")) + val negative = result.filter(!col("same_bucket")) + val falsePositiveCount = positive.filter(col("distance") > distFP).count().toDouble + val falseNegativeCount = negative.filter(col("distance") < distFN).count().toDouble + (falsePositiveCount / positive.count(), falseNegativeCount / negative.count()) + } + + /** + * Compute the precision and recall of approximate nearest neighbors + * @param lsh The lsh instance + * @param dataset The dataset to search for the key + * @param key The key to hash for the item + * @param k The maximum number of 
items closest to the key + * @tparam T The class type of lsh + * @return A tuple of two doubles, representing precision and recall rate + */ + def calculateApproxNearestNeighbors[T <: LSHModel[T]]( + lsh: LSH[T], + dataset: Dataset[_], + key: Vector, + k: Int, + singleProbing: Boolean): (Double, Double) = { + val model = lsh.fit(dataset) + + // Compute expected + val distUDF = udf((x: Vector) => model.keyDistance(x, key), DataTypes.DoubleType) + val expected = dataset.sort(distUDF(col(model.getInputCol))).limit(k) + + // Compute actual + val actual = model.approxNearestNeighbors(dataset, key, k, singleProbing, "distCol") + + assert(actual.schema.sameType(model + .transformSchema(dataset.schema) + .add("distCol", DataTypes.DoubleType)) + ) + + if (!singleProbing) { + assert(actual.count() == k) + } + + // Compute precision and recall + val correctCount = expected.join(actual, model.getInputCol).count().toDouble + (correctCount / actual.count(), correctCount / expected.count()) + } + + /** + * Compute the precision and recall of approximate similarity join + * @param lsh The lsh instance + * @param datasetA One of the datasets to join + * @param datasetB Another dataset to join + * @param threshold The threshold for the distance of record pairs + * @tparam T The class type of lsh + * @return A tuple of two doubles, representing precision and recall rate + */ + def calculateApproxSimilarityJoin[T <: LSHModel[T]]( + lsh: LSH[T], + datasetA: Dataset[_], + datasetB: Dataset[_], + threshold: Double): (Double, Double) = { + val model = lsh.fit(datasetA) + val inputCol = model.getInputCol + + // Compute expected + val distUDF = udf((x: Vector, y: Vector) => model.keyDistance(x, y), DataTypes.DoubleType) + val expected = datasetA.as("a").crossJoin(datasetB.as("b")) + .filter(distUDF(col(s"a.$inputCol"), col(s"b.$inputCol")) < threshold) + + // Compute actual + val actual = model.approxSimilarityJoin(datasetA, datasetB, threshold) + + SchemaUtils.checkColumnType(actual.schema, "distCol", DataTypes.DoubleType) + assert(actual.schema.apply("datasetA").dataType + .sameType(model.transformSchema(datasetA.schema))) + assert(actual.schema.apply("datasetB").dataType + .sameType(model.transformSchema(datasetB.schema))) + + // Compute precision and recall + val correctCount = actual.filter(col("distCol") < threshold).count().toDouble + (correctCount / actual.count(), correctCount / expected.count()) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashSuite.scala new file mode 100644 index 0000000000000..c32ca7d69cf84 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashSuite.scala @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.feature + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.linalg.{Vector, Vectors} +import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.Dataset + +class MinHashSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + + @transient var dataset: Dataset[_] = _ + + override def beforeAll(): Unit = { + super.beforeAll() + + val data = { + for (i <- 0 to 95) yield Vectors.sparse(100, (i until i + 5).map((_, 1.0))) + } + dataset = spark.createDataFrame(data.map(Tuple1.apply)).toDF("keys") + } + + test("params") { + ParamsSuite.checkParams(new MinHash) + val model = new MinHashModel("mh", numEntries = 2, randCoefficients = Array(1)) + ParamsSuite.checkParams(model) + } + + test("MinHash: default params") { + val rp = new MinHash + assert(rp.getOutputDim === 1.0) + } + + test("read/write") { + def checkModelData(model: MinHashModel, model2: MinHashModel): Unit = { + assert(model.numEntries === model2.numEntries) + assertResult(model.randCoefficients)(model2.randCoefficients) + } + val mh = new MinHash() + val settings = Map("inputCol" -> "keys", "outputCol" -> "values") + testEstimatorAndModelReadWrite(mh, dataset, settings, checkModelData) + } + + test("hashFunction") { + val model = new MinHashModel("mh", numEntries = 20, randCoefficients = Array(0, 1, 3)) + val res = model.hashFunction(Vectors.sparse(10, Seq((2, 1.0), (3, 1.0), (5, 1.0), (7, 1.0)))) + assert(res.equals(Vectors.dense(0.0, 3.0, 4.0))) + } + + test("keyDistance and hashDistance") { + val model = new MinHashModel("mh", numEntries = 20, randCoefficients = Array(1)) + val v1 = Vectors.sparse(10, Seq((2, 1.0), (3, 1.0), (5, 1.0), (7, 1.0))) + val v2 = Vectors.sparse(10, Seq((1, 1.0), (3, 1.0), (5, 1.0), (7, 1.0), (9, 1.0))) + val keyDist = model.keyDistance(v1, v2) + val hashDist = model.hashDistance(Vectors.dense(-5, 5), Vectors.dense(1, 2)) + assert(keyDist === 0.5) + assert(hashDist === 3) + } + + test("MinHash: test of LSH property") { + val mh = new MinHash() + .setOutputDim(1) + .setInputCol("keys") + .setOutputCol("values") + .setSeed(12344) + + val (falsePositive, falseNegative) = LSHTest.calculateLSHProperty(dataset, mh, 0.75, 0.5) + assert(falsePositive < 0.3) + assert(falseNegative < 0.3) + } + + test("approxNearestNeighbors for min hash") { + val mh = new MinHash() + .setOutputDim(20) + .setInputCol("keys") + .setOutputCol("values") + .setSeed(12345) + + val key: Vector = Vectors.sparse(100, + (0 until 100).filter(_.toString.contains("1")).map((_, 1.0))) + + val (precision, recall) = LSHTest.calculateApproxNearestNeighbors(mh, dataset, key, 20, + singleProbing = true) + assert(precision >= 0.7) + assert(recall >= 0.7) + } + + test("approxSimilarityJoin for minhash on different dataset") { + val data1 = { + for (i <- 0 until 20) yield Vectors.sparse(100, (5 * i until 5 * i + 5).map((_, 1.0))) + } + val df1 = spark.createDataFrame(data1.map(Tuple1.apply)).toDF("keys") + + val data2 = { + for (i <- 0 until 30) yield Vectors.sparse(100, (3 * i until 3 * i + 3).map((_, 1.0))) + } + val df2 = spark.createDataFrame(data2.map(Tuple1.apply)).toDF("keys") + + val mh = new MinHash() + .setOutputDim(20) + .setInputCol("keys") + .setOutputCol("values") + .setSeed(12345) + + val (precision, recall) = LSHTest.calculateApproxSimilarityJoin(mh, df1, df2, 0.5) + assert(precision == 1.0) + assert(recall >= 0.7) + } +} diff --git 
a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala index 6822594044a56..f219f775b2186 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala @@ -17,10 +17,10 @@ package org.apache.spark.ml.feature -import org.apache.spark.SparkFunSuite +import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.SparkSession +import org.apache.spark.sql._ import org.apache.spark.sql.functions.udf class QuantileDiscretizerSuite @@ -76,20 +76,33 @@ class QuantileDiscretizerSuite import spark.implicits._ val numBuckets = 3 - val df = sc.parallelize(Array(1.0, 1.0, 1.0, Double.NaN)) - .map(Tuple1.apply).toDF("input") + val validData = Array(-0.9, -0.5, -0.3, 0.0, 0.2, 0.5, 0.9, Double.NaN, Double.NaN, Double.NaN) + val expectedKeep = Array(0.0, 0.0, 1.0, 1.0, 2.0, 2.0, 2.0, 3.0, 3.0, 3.0) + val expectedSkip = Array(0.0, 0.0, 1.0, 1.0, 2.0, 2.0, 2.0) + val discretizer = new QuantileDiscretizer() .setInputCol("input") .setOutputCol("result") .setNumBuckets(numBuckets) - // Reserve extra one bucket for NaN - val expectedNumBuckets = discretizer.fit(df).getSplits.length - 1 - val result = discretizer.fit(df).transform(df) - val observedNumBuckets = result.select("result").distinct.count - assert(observedNumBuckets == expectedNumBuckets, - s"Observed number of buckets are not correct." + - s" Expected $expectedNumBuckets but found $observedNumBuckets") + withClue("QuantileDiscretizer with handleInvalid=error should throw exception for NaN values") { + val dataFrame: DataFrame = validData.toSeq.toDF("input") + intercept[SparkException] { + discretizer.fit(dataFrame).transform(dataFrame).collect() + } + } + + List(("keep", expectedKeep), ("skip", expectedSkip)).foreach{ + case(u, v) => + discretizer.setHandleInvalid(u) + val dataFrame: DataFrame = validData.zip(v).toSeq.toDF("input", "expected") + val result = discretizer.fit(dataFrame).transform(dataFrame) + result.select("result", "expected").collect().foreach { + case Row(x: Double, y: Double) => + assert(x === y, + s"The feature value is not correct after bucketing. Expected $y but found $x") + } + } } test("Test transform method on unseen data") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RandomProjectionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RandomProjectionSuite.scala new file mode 100644 index 0000000000000..cd82ee2117a07 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RandomProjectionSuite.scala @@ -0,0 +1,197 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import breeze.numerics.{cos, sin} +import breeze.numerics.constants.Pi + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.linalg.{Vector, Vectors} +import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.ml.util.TestingUtils._ +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.Dataset + +class RandomProjectionSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + + @transient var dataset: Dataset[_] = _ + + override def beforeAll(): Unit = { + super.beforeAll() + + val data = { + for (i <- -10 until 10; j <- -10 until 10) yield Vectors.dense(i.toDouble, j.toDouble) + } + dataset = spark.createDataFrame(data.map(Tuple1.apply)).toDF("keys") + } + + test("params") { + ParamsSuite.checkParams(new RandomProjection) + val model = new RandomProjectionModel("rp", randUnitVectors = Array(Vectors.dense(1.0, 0.0))) + ParamsSuite.checkParams(model) + } + + test("RandomProjection: default params") { + val rp = new RandomProjection + assert(rp.getOutputDim === 1.0) + } + + test("read/write") { + def checkModelData(model: RandomProjectionModel, model2: RandomProjectionModel): Unit = { + model.randUnitVectors.zip(model2.randUnitVectors) + .foreach(pair => assert(pair._1 === pair._2)) + } + val mh = new RandomProjection() + val settings = Map("inputCol" -> "keys", "outputCol" -> "values", "bucketLength" -> 1.0) + testEstimatorAndModelReadWrite(mh, dataset, settings, checkModelData) + } + + test("hashFunction") { + val randUnitVectors = Array(Vectors.dense(0.0, 1.0), Vectors.dense(1.0, 0.0)) + val model = new RandomProjectionModel("rp", randUnitVectors) + model.set(model.bucketLength, 0.5) + val res = model.hashFunction(Vectors.dense(1.23, 4.56)) + assert(res.equals(Vectors.dense(9.0, 2.0))) + } + + test("keyDistance and hashDistance") { + val model = new RandomProjectionModel("rp", Array(Vectors.dense(0.0, 1.0))) + val keyDist = model.keyDistance(Vectors.dense(1, 2), Vectors.dense(-2, -2)) + val hashDist = model.hashDistance(Vectors.dense(-5, 5), Vectors.dense(1, 2)) + assert(keyDist === 5) + assert(hashDist === 3) + } + + test("RandomProjection: randUnitVectors") { + val rp = new RandomProjection() + .setOutputDim(20) + .setInputCol("keys") + .setOutputCol("values") + .setBucketLength(1.0) + .setSeed(12345) + val unitVectors = rp.fit(dataset).randUnitVectors + unitVectors.foreach { v: Vector => + assert(Vectors.norm(v, 2.0) ~== 1.0 absTol 1e-14) + } + } + + test("RandomProjection: test of LSH property") { + // Project from 2 dimensional Euclidean Space to 1 dimensions + val rp = new RandomProjection() + .setOutputDim(1) + .setInputCol("keys") + .setOutputCol("values") + .setBucketLength(1.0) + .setSeed(12345) + + val (falsePositive, falseNegative) = LSHTest.calculateLSHProperty(dataset, rp, 8.0, 2.0) + assert(falsePositive < 0.4) + assert(falseNegative < 0.4) + } + + test("RandomProjection with high dimension data: test of LSH property") { + val numDim = 100 + val data = { + for (i <- 0 until numDim; j <- Seq(-2, -1, 1, 2)) + yield Vectors.sparse(numDim, Seq((i, j.toDouble))) + } + val df = spark.createDataFrame(data.map(Tuple1.apply)).toDF("keys") + + // Project from 100 dimensional Euclidean Space to 10 dimensions + val rp = new RandomProjection() + .setOutputDim(10) + .setInputCol("keys") + 
.setOutputCol("values") + .setBucketLength(2.5) + .setSeed(12345) + + val (falsePositive, falseNegative) = LSHTest.calculateLSHProperty(df, rp, 3.0, 2.0) + assert(falsePositive < 0.3) + assert(falseNegative < 0.3) + } + + test("approxNearestNeighbors for random projection") { + val key = Vectors.dense(1.2, 3.4) + + val rp = new RandomProjection() + .setOutputDim(2) + .setInputCol("keys") + .setOutputCol("values") + .setBucketLength(4.0) + .setSeed(12345) + + val (precision, recall) = LSHTest.calculateApproxNearestNeighbors(rp, dataset, key, 100, + singleProbing = true) + assert(precision >= 0.6) + assert(recall >= 0.6) + } + + test("approxNearestNeighbors with multiple probing") { + val key = Vectors.dense(1.2, 3.4) + + val rp = new RandomProjection() + .setOutputDim(20) + .setInputCol("keys") + .setOutputCol("values") + .setBucketLength(1.0) + .setSeed(12345) + + val (precision, recall) = LSHTest.calculateApproxNearestNeighbors(rp, dataset, key, 100, + singleProbing = false) + assert(precision >= 0.7) + assert(recall >= 0.7) + } + + test("approxSimilarityJoin for random projection on different dataset") { + val data2 = { + for (i <- 0 until 24) yield Vectors.dense(10 * sin(Pi / 12 * i), 10 * cos(Pi / 12 * i)) + } + val dataset2 = spark.createDataFrame(data2.map(Tuple1.apply)).toDF("keys") + + val rp = new RandomProjection() + .setOutputDim(2) + .setInputCol("keys") + .setOutputCol("values") + .setBucketLength(4.0) + .setSeed(12345) + + val (precision, recall) = LSHTest.calculateApproxSimilarityJoin(rp, dataset, dataset2, 1.0) + assert(precision == 1.0) + assert(recall >= 0.7) + } + + test("approxSimilarityJoin for self join") { + val data = { + for (i <- 0 until 24) yield Vectors.dense(10 * sin(Pi / 12 * i), 10 * cos(Pi / 12 * i)) + } + val df = spark.createDataFrame(data.map(Tuple1.apply)).toDF("keys") + + val rp = new RandomProjection() + .setOutputDim(2) + .setInputCol("keys") + .setOutputCol("values") + .setBucketLength(4.0) + .setSeed(12345) + + val (precision, recall) = LSHTest.calculateApproxSimilarityJoin(rp, df, df, 3.0) + assert(precision == 1.0) + assert(recall >= 0.7) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala index 23464073e6edb..753f890c48301 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala @@ -43,6 +43,7 @@ class SQLTransformerSuite assert(result.schema.toString == resultSchema.toString) assert(resultSchema == expected.schema) assert(result.collect().toSeq == expected.collect().toSeq) + assert(original.sparkSession.catalog.listTables().count() == 0) } test("read/write") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquaresSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquaresSuite.scala index b30d995794d4c..50260952ecb66 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquaresSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquaresSuite.scala @@ -85,7 +85,7 @@ class IterativelyReweightedLeastSquaresSuite extends SparkFunSuite with MLlibTes val eta = math.log(mu / (1.0 - mu)) Instance(eta, instance.weight, instance.features) } - val initial = new WeightedLeastSquares(fitIntercept, regParam = 0.0, + val initial = new WeightedLeastSquares(fitIntercept, regParam = 0.0, 
elasticNetParam = 0.0, standardizeFeatures = false, standardizeLabel = false).fit(newInstances) val irls = new IterativelyReweightedLeastSquares(initial, BinomialReweightFunc, fitIntercept, regParam = 0.0, maxIter = 25, tol = 1e-8).fit(instances1) @@ -122,7 +122,7 @@ class IterativelyReweightedLeastSquaresSuite extends SparkFunSuite with MLlibTes val eta = math.log(mu) Instance(eta, instance.weight, instance.features) } - val initial = new WeightedLeastSquares(fitIntercept, regParam = 0.0, + val initial = new WeightedLeastSquares(fitIntercept, regParam = 0.0, elasticNetParam = 0.0, standardizeFeatures = false, standardizeLabel = false).fit(newInstances) val irls = new IterativelyReweightedLeastSquares(initial, PoissonReweightFunc, fitIntercept, regParam = 0.0, maxIter = 25, tol = 1e-8).fit(instances2) @@ -155,7 +155,7 @@ class IterativelyReweightedLeastSquaresSuite extends SparkFunSuite with MLlibTes var idx = 0 for (fitIntercept <- Seq(false, true)) { - val initial = new WeightedLeastSquares(fitIntercept, regParam = 0.0, + val initial = new WeightedLeastSquares(fitIntercept, regParam = 0.0, elasticNetParam = 0.0, standardizeFeatures = false, standardizeLabel = false).fit(instances2) val irls = new IterativelyReweightedLeastSquares(initial, L1RegressionReweightFunc, fitIntercept, regParam = 0.0, maxIter = 200, tol = 1e-7).fit(instances2) diff --git a/mllib/src/test/scala/org/apache/spark/ml/optim/WeightedLeastSquaresSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/optim/WeightedLeastSquaresSuite.scala index 2cb1af0dee0bc..093d02ea7a14b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/optim/WeightedLeastSquaresSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/optim/WeightedLeastSquaresSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.ml.optim import org.apache.spark.SparkFunSuite import org.apache.spark.ml.feature.Instance -import org.apache.spark.ml.linalg.Vectors +import org.apache.spark.ml.linalg.{BLAS, Vectors} import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD @@ -28,6 +28,9 @@ class WeightedLeastSquaresSuite extends SparkFunSuite with MLlibTestSparkContext private var instances: RDD[Instance] = _ private var instancesConstLabel: RDD[Instance] = _ + private var instancesConstZeroLabel: RDD[Instance] = _ + private var collinearInstances: RDD[Instance] = _ + private var constantFeaturesInstances: RDD[Instance] = _ override def beforeAll(): Unit = { super.beforeAll() @@ -58,26 +61,121 @@ class WeightedLeastSquaresSuite extends SparkFunSuite with MLlibTestSparkContext Instance(17.0, 3.0, Vectors.dense(2.0, 11.0)), Instance(17.0, 4.0, Vectors.dense(3.0, 13.0)) ), 2) - } - test("two collinear features result in error with no regularization") { - val singularInstances = sc.parallelize(Seq( + /* + A <- matrix(c(1, 2, 3, 4, 2, 4, 6, 8), 4, 2) + b <- c(1, 2, 3, 4) + w <- c(1, 1, 1, 1) + */ + collinearInstances = sc.parallelize(Seq( Instance(1.0, 1.0, Vectors.dense(1.0, 2.0)), Instance(2.0, 1.0, Vectors.dense(2.0, 4.0)), Instance(3.0, 1.0, Vectors.dense(3.0, 6.0)), Instance(4.0, 1.0, Vectors.dense(4.0, 8.0)) ), 2) - intercept[IllegalArgumentException] { - new WeightedLeastSquares( - false, regParam = 0.0, standardizeFeatures = false, - standardizeLabel = false).fit(singularInstances) + /* + R code: + + A <- matrix(c(0, 1, 2, 3, 5, 7, 11, 13), 4, 2) + b.const <- c(0, 0, 0, 0) + w <- c(1, 2, 3, 4) + */ + instancesConstZeroLabel = sc.parallelize(Seq( + Instance(0.0, 1.0, Vectors.dense(0.0, 
5.0).toSparse), + Instance(0.0, 2.0, Vectors.dense(1.0, 7.0)), + Instance(0.0, 3.0, Vectors.dense(2.0, 11.0)), + Instance(0.0, 4.0, Vectors.dense(3.0, 13.0)) + ), 2) + + /* + R code: + + A <- matrix(c(1, 1, 1, 1, 5, 7, 11, 13), 4, 2) + b <- c(17, 19, 23, 29) + w <- c(1, 2, 3, 4) + */ + constantFeaturesInstances = sc.parallelize(Seq( + Instance(17.0, 1.0, Vectors.dense(1.0, 5.0)), + Instance(19.0, 2.0, Vectors.dense(1.0, 7.0)), + Instance(23.0, 3.0, Vectors.dense(1.0, 11.0)), + Instance(29.0, 4.0, Vectors.dense(1.0, 13.0)) + ), 2) + } + + test("WLS with strong L1 regularization") { + /* + We initialize the coefficients for WLS QN solver to be weighted average of the label. Check + here that with only an intercept the model converges to bBar. + */ + val bAgg = instances.collect().foldLeft((0.0, 0.0)) { + case ((sum, weightSum), Instance(l, w, f)) => (sum + w * l, weightSum + w) } + val bBar = bAgg._1 / bAgg._2 + val wls = new WeightedLeastSquares(true, 10, 1.0, true, true) + val model = wls.fit(instances) + assert(model.intercept ~== bBar relTol 1e-6) + } - // Should not throw an exception - new WeightedLeastSquares( - false, regParam = 1.0, standardizeFeatures = false, - standardizeLabel = false).fit(singularInstances) + test("diagonal inverse of AtWA") { + /* + library(Matrix) + A <- matrix(c(0, 1, 2, 3, 5, 7, 11, 13), 4, 2) + w <- c(1, 2, 3, 4) + W <- Diagonal(length(w), w) + A.intercept <- cbind(A, rep.int(1, length(w))) + AtA.intercept <- t(A.intercept) %*% W %*% A.intercept + inv.intercept <- solve(AtA.intercept) + print(diag(inv.intercept)) + [1] 4.02 0.50 12.02 + + AtA <- t(A) %*% W %*% A + inv <- solve(AtA) + print(diag(inv)) + [1] 0.48336106 0.02079867 + + */ + val expectedWithIntercept = Vectors.dense(4.02, 0.50, 12.02) + val expected = Vectors.dense(0.48336106, 0.02079867) + val wlsWithIntercept = new WeightedLeastSquares(fitIntercept = true, regParam = 0.0, + elasticNetParam = 0.0, standardizeFeatures = true, standardizeLabel = true, + solverType = WeightedLeastSquares.Cholesky) + val wlsModelWithIntercept = wlsWithIntercept.fit(instances) + val wls = new WeightedLeastSquares(false, 0.0, 0.0, true, true, + solverType = WeightedLeastSquares.Cholesky) + val wlsModel = wls.fit(instances) + + assert(expectedWithIntercept ~== wlsModelWithIntercept.diagInvAtWA relTol 1e-4) + assert(expected ~== wlsModel.diagInvAtWA relTol 1e-4) + } + + test("two collinear features") { + // Cholesky solver does not handle singular input + intercept[SingularMatrixException] { + new WeightedLeastSquares(fitIntercept = false, regParam = 0.0, elasticNetParam = 0.0, + standardizeFeatures = false, standardizeLabel = false, + solverType = WeightedLeastSquares.Cholesky).fit(collinearInstances) + } + + // Cholesky should not throw an exception since regularization is applied + new WeightedLeastSquares(fitIntercept = false, regParam = 1.0, elasticNetParam = 0.0, + standardizeFeatures = false, standardizeLabel = false, + solverType = WeightedLeastSquares.Cholesky).fit(collinearInstances) + + // quasi-newton solvers should handle singular input and make correct predictions + // auto solver should try Cholesky first, then fall back to QN + for (fitIntercept <- Seq(false, true); + standardization <- Seq(false, true); + solver <- Seq(WeightedLeastSquares.Auto, WeightedLeastSquares.QuasiNewton)) { + val singularModel = new WeightedLeastSquares(fitIntercept, regParam = 0.0, + elasticNetParam = 0.0, standardizeFeatures = standardization, + standardizeLabel = standardization, solverType = solver).fit(collinearInstances) 
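The exact predictions asserted just below follow from simple arithmetic on the collinear data: the second feature is exactly twice the first, so the Gram matrix A^T W A is singular (which is why the Cholesky path above must fail without regularization), yet any coefficient pair with c1 + 2*c2 = 1 reproduces every label. A minimal standalone sketch of both facts in plain Scala (no Spark classes; the object and names are purely illustrative):

object CollinearIntuition {
  def main(args: Array[String]): Unit = {
    // Rows mirroring collinearInstances: label i, features (i, 2i), unit weight.
    val rows = Seq(1.0, 2.0, 3.0, 4.0).map(i => (i, Array(i, 2 * i)))

    // The 2x2 Gram matrix A^T A has zero determinant, so a normal-equation
    // (Cholesky) solve cannot succeed without regularization.
    val g = Array.ofDim[Double](2, 2)
    for ((_, f) <- rows; i <- 0 until 2; j <- 0 until 2) g(i)(j) += f(i) * f(j)
    assert(math.abs(g(0)(0) * g(1)(1) - g(0)(1) * g(1)(0)) < 1e-9)

    // The system is still consistent: any coefficients with c1 + 2 * c2 = 1,
    // e.g. (1, 0), (0, 0.5) or (0.2, 0.4), reproduce every label exactly,
    // which is why the quasi-Newton fit can still predict perfectly.
    for (coeffs <- Seq(Array(1.0, 0.0), Array(0.0, 0.5), Array(0.2, 0.4));
         (label, f) <- rows) {
      assert(math.abs(f(0) * coeffs(0) + f(1) * coeffs(1) - label) < 1e-12)
    }
  }
}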
+ + collinearInstances.collect().foreach { case Instance(l, w, f) => + val pred = BLAS.dot(singularModel.coefficients, f) + singularModel.intercept + assert(pred ~== l absTol 1e-6) + } + } } test("WLS against lm") { @@ -100,13 +198,15 @@ class WeightedLeastSquaresSuite extends SparkFunSuite with MLlibTestSparkContext var idx = 0 for (fitIntercept <- Seq(false, true)) { - for (standardization <- Seq(false, true)) { - val wls = new WeightedLeastSquares( - fitIntercept, regParam = 0.0, standardizeFeatures = standardization, - standardizeLabel = standardization).fit(instances) - val actual = Vectors.dense(wls.intercept, wls.coefficients(0), wls.coefficients(1)) - assert(actual ~== expected(idx) absTol 1e-4) - } + for (standardization <- Seq(false, true)) { + for (solver <- WeightedLeastSquares.supportedSolvers) { + val wls = new WeightedLeastSquares(fitIntercept, regParam = 0.0, elasticNetParam = 0.0, + standardizeFeatures = standardization, standardizeLabel = standardization, + solverType = solver).fit(instances) + val actual = Vectors.dense(wls.intercept, wls.coefficients(0), wls.coefficients(1)) + assert(actual ~== expected(idx) absTol 1e-4) + } + } idx += 1 } } @@ -132,28 +232,256 @@ class WeightedLeastSquaresSuite extends SparkFunSuite with MLlibTestSparkContext var idx = 0 for (fitIntercept <- Seq(false, true)) { for (standardization <- Seq(false, true)) { - val wls = new WeightedLeastSquares( - fitIntercept, regParam = 0.0, standardizeFeatures = standardization, - standardizeLabel = standardization).fit(instancesConstLabel) - val actual = Vectors.dense(wls.intercept, wls.coefficients(0), wls.coefficients(1)) - assert(actual ~== expected(idx) absTol 1e-4) + for (solver <- WeightedLeastSquares.supportedSolvers) { + val wls = new WeightedLeastSquares(fitIntercept, regParam = 0.0, elasticNetParam = 0.0, + standardizeFeatures = standardization, standardizeLabel = standardization, + solverType = solver).fit(instancesConstLabel) + val actual = Vectors.dense(wls.intercept, wls.coefficients(0), wls.coefficients(1)) + assert(actual ~== expected(idx) absTol 1e-4) + } } idx += 1 } + + // when label is constant zero, and fitIntercept is false, we should not train and get all zeros + for (solver <- WeightedLeastSquares.supportedSolvers) { + val wls = new WeightedLeastSquares(fitIntercept = false, regParam = 0.0, + elasticNetParam = 0.0, standardizeFeatures = true, standardizeLabel = true, + solverType = solver).fit(instancesConstZeroLabel) + val actual = Vectors.dense(wls.intercept, wls.coefficients(0), wls.coefficients(1)) + assert(actual === Vectors.dense(0.0, 0.0, 0.0)) + assert(wls.objectiveHistory === Array(0.0)) + } } test("WLS with regularization when label is constant") { // if regParam is non-zero and standardization is true, the problem is ill-defined and // an exception is thrown. 
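Intuitively, the degenerate quantity in this ill-defined case is the label's standard deviation: with a constant label it is zero, so a formulation that standardizes the label (and scales the regularization by it) has nothing sensible to divide by. A minimal standalone check of that quantity in plain Scala, with purely illustrative label values standing in for a constant-label dataset such as instancesConstLabel:

object ConstantLabelVariance {
  def main(args: Array[String]): Unit = {
    val labels = Seq(17.0, 17.0, 17.0, 17.0) // illustrative constant labels
    val mean = labels.sum / labels.length
    val stdDev = math.sqrt(labels.map(l => (l - mean) * (l - mean)).sum / labels.length)
    // Zero standard deviation: nothing to standardize the label (or scale the
    // regularization) by, matching the IllegalArgumentException asserted below.
    assert(stdDev == 0.0)
  }
}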
- val wls = new WeightedLeastSquares( - fitIntercept = false, regParam = 0.1, standardizeFeatures = true, - standardizeLabel = true) - intercept[IllegalArgumentException]{ - wls.fit(instancesConstLabel) + for (solver <- WeightedLeastSquares.supportedSolvers) { + val wls = new WeightedLeastSquares(fitIntercept = false, regParam = 0.1, + elasticNetParam = 0.0, standardizeFeatures = true, standardizeLabel = true, + solverType = solver) + intercept[IllegalArgumentException]{ + wls.fit(instancesConstLabel) + } } } - test("WLS against glmnet") { + test("WLS against glmnet with constant features") { + // Cholesky solver does not handle singular input with no regularization + for (fitIntercept <- Seq(false, true); + standardization <- Seq(false, true)) { + val wls = new WeightedLeastSquares(fitIntercept, regParam = 0.0, elasticNetParam = 0.0, + standardizeFeatures = standardization, standardizeLabel = standardization, + solverType = WeightedLeastSquares.Cholesky) + intercept[SingularMatrixException] { + wls.fit(constantFeaturesInstances) + } + } + + // Cholesky also fails when regularization is added but we don't wish to standardize + val wls = new WeightedLeastSquares(fitIntercept = true, regParam = 0.5, elasticNetParam = 0.0, + standardizeFeatures = false, standardizeLabel = false, + solverType = WeightedLeastSquares.Cholesky) + intercept[SingularMatrixException] { + wls.fit(constantFeaturesInstances) + } + + /* + for (intercept in c(FALSE, TRUE)) { + model <- glmnet(A, b, weights=w, intercept=intercept, lambda=0.5, + standardize=T, alpha=0.0, thresh=1E-14) + print(as.vector(coef(model))) + } + [1] 0.000000 0.000000 2.235802 + [1] 9.798771 0.000000 1.365503 + */ + // should not fail when regularization and standardization are added + val expectedCholesky = Seq( + Vectors.dense(0.0, 0.0, 2.235802), + Vectors.dense(9.798771, 0.0, 1.365503) + ) + var idx = 0 + for (fitIntercept <- Seq(false, true)) { + val wls = new WeightedLeastSquares(fitIntercept = fitIntercept, regParam = 0.5, + elasticNetParam = 0.0, standardizeFeatures = true, + standardizeLabel = true, solverType = WeightedLeastSquares.Cholesky) + .fit(constantFeaturesInstances) + val actual = Vectors.dense(wls.intercept, wls.coefficients(0), wls.coefficients(1)) + assert(actual ~== expectedCholesky(idx) absTol 1e-6) + idx += 1 + } + + /* + for (intercept in c(FALSE, TRUE)) { + for (standardize in c(FALSE, TRUE)) { + for (regParams in list(c(0.0, 0.0), c(0.5, 0.0), c(0.5, 0.5), c(0.5, 1.0))) { + model <- glmnet(A, b, weights=w, intercept=intercept, lambda=regParams[1], + standardize=standardize, alpha=regParams[2], thresh=1E-14) + print(as.vector(coef(model))) + } + } + } + [1] 0.000000 0.000000 2.253012 + [1] 0.000000 0.000000 2.250857 + [1] 0.000000 0.000000 2.249784 + [1] 0.000000 0.000000 2.248709 + [1] 0.000000 0.000000 2.253012 + [1] 0.000000 0.000000 2.235802 + [1] 0.000000 0.000000 2.238297 + [1] 0.000000 0.000000 2.240811 + [1] 8.218905 0.000000 1.517413 + [1] 8.434286 0.000000 1.496703 + [1] 8.648497 0.000000 1.476106 + [1] 8.865672 0.000000 1.455224 + [1] 8.218905 0.000000 1.517413 + [1] 9.798771 0.000000 1.365503 + [1] 9.919095 0.000000 1.353933 + [1] 10.052804 0.000000 1.341077 + */ + val expectedQuasiNewton = Seq( + Vectors.dense(0.000000, 0.000000, 2.253012), + Vectors.dense(0.000000, 0.000000, 2.250857), + Vectors.dense(0.000000, 0.000000, 2.249784), + Vectors.dense(0.000000, 0.000000, 2.248709), + Vectors.dense(0.000000, 0.000000, 2.253012), + Vectors.dense(0.000000, 0.000000, 2.235802), + Vectors.dense(0.000000, 0.000000, 
2.238297), + Vectors.dense(0.000000, 0.000000, 2.240811), + Vectors.dense(8.218905, 0.000000, 1.517413), + Vectors.dense(8.434286, 0.000000, 1.496703), + Vectors.dense(8.648497, 0.000000, 1.476106), + Vectors.dense(8.865672, 0.000000, 1.455224), + Vectors.dense(8.218905, 0.000000, 1.517413), + Vectors.dense(9.798771, 0.000000, 1.365503), + Vectors.dense(9.919095, 0.000000, 1.353933), + Vectors.dense(10.052804, 0.000000, 1.341077)) + + idx = 0 + for (fitIntercept <- Seq(false, true); + standardization <- Seq(false, true); + (lambda, alpha) <- Seq((0.0, 0.0), (0.5, 0.0), (0.5, 0.5), (0.5, 1.0))) { + val wls = new WeightedLeastSquares(fitIntercept, regParam = lambda, elasticNetParam = alpha, + standardizeFeatures = standardization, standardizeLabel = true, + solverType = WeightedLeastSquares.QuasiNewton) + val model = wls.fit(constantFeaturesInstances) + val actual = Vectors.dense(model.intercept, model.coefficients(0), model.coefficients(1)) + assert(actual ~== expectedQuasiNewton(idx) absTol 1e-6) + + idx += 1 + } + } + + test("WLS against glmnet with L1/ElasticNet regularization") { + /* + R code: + + library(glmnet) + + for (intercept in c(FALSE, TRUE)) { + for (lambda in c(0.1, 0.5, 1.0)) { + for (standardize in c(FALSE, TRUE)) { + for (alpha in c(0.1, 0.5, 1.0)) { + model <- glmnet(A, b, weights=w, intercept=intercept, lambda=lambda, + standardize=standardize, alpha=alpha, thresh=1E-14) + print(as.vector(coef(model))) + } + } + } + } + [1] 0.000000 -3.292821 2.921188 + [1] 0.000000 -3.230854 2.908484 + [1] 0.000000 -3.145586 2.891014 + [1] 0.000000 -2.919246 2.841724 + [1] 0.000000 -2.938323 2.846369 + [1] 0.000000 -2.965397 2.852838 + [1] 0.000000 -2.137858 2.684464 + [1] 0.000000 -1.680094 2.590844 + [1] 0.0000000 -0.8194631 2.4151405 + [1] 0.0000000 -0.9608375 2.4301013 + [1] 0.0000000 -0.6187922 2.3634907 + [1] 0.000000 0.000000 2.240811 + [1] 0.000000 -1.346573 2.521293 + [1] 0.0000000 -0.3680456 2.3212362 + [1] 0.000000 0.000000 2.244406 + [1] 0.000000 0.000000 2.219816 + [1] 0.000000 0.000000 2.223694 + [1] 0.00000 0.00000 2.22861 + [1] 13.5631592 3.2811513 0.3725517 + [1] 13.6953934 3.3336271 0.3497454 + [1] 13.9600276 3.4600170 0.2999941 + [1] 14.2389889 3.6589920 0.2349065 + [1] 15.2374080 4.2119643 0.0325638 + [1] 15.4 4.3 0.0 + [1] 10.442365 1.246065 1.063991 + [1] 8.9580718 0.1938471 1.4090610 + [1] 8.865672 0.000000 1.455224 + [1] 13.0430927 2.4927151 0.5741805 + [1] 13.814429 2.722027 0.455915 + [1] 16.2 3.9 0.0 + [1] 9.8904768 0.7574694 1.2110177 + [1] 9.072226 0.000000 1.435363 + [1] 9.512438 0.000000 1.393035 + [1] 13.3677796 2.1721216 0.6046132 + [1] 14.2554457 2.2285185 0.5084151 + [1] 17.2 3.4 0.0 + */ + + val expected = Seq( + Vectors.dense(0, -3.2928206726474, 2.92118822588649), + Vectors.dense(0, -3.23085414359003, 2.90848366035008), + Vectors.dense(0, -3.14558628299477, 2.89101408157209), + Vectors.dense(0, -2.91924558816421, 2.84172398097327), + Vectors.dense(0, -2.93832343383477, 2.84636891947663), + Vectors.dense(0, -2.96539689593024, 2.85283836322185), + Vectors.dense(0, -2.13785756976542, 2.68446351346705), + Vectors.dense(0, -1.68009377560774, 2.59084422793154), + Vectors.dense(0, -0.819463123385533, 2.41514053108346), + Vectors.dense(0, -0.960837488151064, 2.43010130999756), + Vectors.dense(0, -0.618792151647599, 2.36349074148962), + Vectors.dense(0, 0, 2.24081114726441), + Vectors.dense(0, -1.34657309253953, 2.52129296638512), + Vectors.dense(0, -0.368045602821844, 2.32123616258871), + Vectors.dense(0, 0, 2.24440619621343), + Vectors.dense(0, 0, 
2.21981559944924), + Vectors.dense(0, 0, 2.22369447413621), + Vectors.dense(0, 0, 2.22861024633605), + Vectors.dense(13.5631591827557, 3.28115132060568, 0.372551747695477), + Vectors.dense(13.6953934007661, 3.3336271417751, 0.349745414969587), + Vectors.dense(13.960027608754, 3.46001702257532, 0.29999407173994), + Vectors.dense(14.2389889013085, 3.65899196445023, 0.234906458633754), + Vectors.dense(15.2374079667397, 4.21196428071551, 0.0325637953681963), + Vectors.dense(15.4, 4.3, 0), + Vectors.dense(10.4423647474653, 1.24606545153166, 1.06399080283378), + Vectors.dense(8.95807177856822, 0.193847088148233, 1.4090609658784), + Vectors.dense(8.86567164179104, 0, 1.45522388059702), + Vectors.dense(13.0430927453034, 2.49271514356687, 0.574180477650271), + Vectors.dense(13.8144287399675, 2.72202744354555, 0.455915035859752), + Vectors.dense(16.2, 3.9, 0), + Vectors.dense(9.89047681835741, 0.757469417613661, 1.21101772561685), + Vectors.dense(9.07222551185964, 0, 1.43536293155196), + Vectors.dense(9.51243781094527, 0, 1.39303482587065), + Vectors.dense(13.3677796362763, 2.17212164262107, 0.604613180623227), + Vectors.dense(14.2554457236073, 2.22851848830683, 0.508415124978748), + Vectors.dense(17.2, 3.4, 0) + ) + + var idx = 0 + for (fitIntercept <- Seq(false, true); + regParam <- Seq(0.1, 0.5, 1.0); + standardization <- Seq(false, true); + elasticNetParam <- Seq(0.1, 0.5, 1.0)) { + val wls = new WeightedLeastSquares(fitIntercept, regParam, elasticNetParam, + standardizeFeatures = standardization, standardizeLabel = true, + solverType = WeightedLeastSquares.Auto) + .fit(instances) + val actual = Vectors.dense(wls.intercept, wls.coefficients(0), wls.coefficients(1)) + assert(actual ~== expected(idx) absTol 1e-4) + idx += 1 + } + } + + test("WLS against glmnet with L2 regularization") { /* R code: @@ -200,12 +528,14 @@ class WeightedLeastSquaresSuite extends SparkFunSuite with MLlibTestSparkContext var idx = 0 for (fitIntercept <- Seq(false, true); regParam <- Seq(0.0, 0.1, 1.0); - standardizeFeatures <- Seq(false, true)) { - val wls = new WeightedLeastSquares( - fitIntercept, regParam, standardizeFeatures, standardizeLabel = true) - .fit(instances) - val actual = Vectors.dense(wls.intercept, wls.coefficients(0), wls.coefficients(1)) - assert(actual ~== expected(idx) absTol 1e-4) + standardization <- Seq(false, true)) { + for (solver <- WeightedLeastSquares.supportedSolvers) { + val wls = new WeightedLeastSquares(fitIntercept, regParam, elasticNetParam = 0.0, + standardizeFeatures = standardization, standardizeLabel = true, solverType = solver) + .fit(instances) + val actual = Vectors.dense(wls.intercept, wls.coefficients(0), wls.coefficients(1)) + assert(actual ~== expected(idx) absTol 1e-4) + } idx += 1 } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/python/MLSerDeSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/python/MLSerDeSuite.scala index 5eaef9aabda50..3bb760f2ecc1d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/python/MLSerDeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/python/MLSerDeSuite.scala @@ -54,7 +54,7 @@ class MLSerDeSuite extends SparkFunSuite { assert(matrix === nm) // Test conversion for empty matrix - val empty = Array[Double]() + val empty = Array.empty[Double] val emptyMatrix = Matrices.dense(0, 0, empty) val ne = MLSerDe.loads(MLSerDe.dumps(emptyMatrix)).asInstanceOf[DenseMatrix] assert(emptyMatrix == ne) diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala 
b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index 1c94ec67d79d1..c0e8afbf5e346 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -57,7 +57,7 @@ class LinearRegressionSuite xVariance = Array(0.7, 1.2), nPoints = 10000, seed, eps = 0.1), 2).map(_.asML).toDF() val r = new Random(seed) - // When feature size is larger than 4096, normal optimizer is choosed + // When feature size is larger than 4096, normal optimizer is chosen // as the solver of linear regression in the case of "auto" mode. val featureSize = 4100 datasetWithSparseFeature = sc.parallelize(LinearDataGenerator.generateLinearInput( @@ -155,6 +155,42 @@ class LinearRegressionSuite assert(model.numFeatures === numFeatures) } + test("linear regression handles singular matrices") { + // check for both constant columns with intercept (zero std) and collinear + val singularDataConstantColumn = sc.parallelize(Seq( + Instance(17.0, 1.0, Vectors.dense(1.0, 5.0).toSparse), + Instance(19.0, 2.0, Vectors.dense(1.0, 7.0)), + Instance(23.0, 3.0, Vectors.dense(1.0, 11.0)), + Instance(29.0, 4.0, Vectors.dense(1.0, 13.0)) + ), 2).toDF() + + Seq("auto", "l-bfgs", "normal").foreach { solver => + val trainer = new LinearRegression().setSolver(solver).setFitIntercept(true) + val model = trainer.fit(singularDataConstantColumn) + // to make it clear that WLS did not solve analytically + intercept[UnsupportedOperationException] { + model.summary.coefficientStandardErrors + } + assert(model.summary.objectiveHistory !== Array(0.0)) + } + + val singularDataCollinearFeatures = sc.parallelize(Seq( + Instance(17.0, 1.0, Vectors.dense(10.0, 5.0).toSparse), + Instance(19.0, 2.0, Vectors.dense(14.0, 7.0)), + Instance(23.0, 3.0, Vectors.dense(22.0, 11.0)), + Instance(29.0, 4.0, Vectors.dense(26.0, 13.0)) + ), 2).toDF() + + Seq("auto", "l-bfgs", "normal").foreach { solver => + val trainer = new LinearRegression().setSolver(solver).setFitIntercept(true) + val model = trainer.fit(singularDataCollinearFeatures) + intercept[UnsupportedOperationException] { + model.summary.coefficientStandardErrors + } + assert(model.summary.objectiveHistory !== Array(0.0)) + } + } + test("linear regression with intercept without regularization") { Seq("auto", "l-bfgs", "normal").foreach { solver => val trainer1 = new LinearRegression().setSolver(solver) @@ -233,12 +269,12 @@ class LinearRegressionSuite as.numeric.data3.V2. 4.70011 as.numeric.data3.V3. 7.19943 */ - val coefficientsWithourInterceptR = Vectors.dense(4.70011, 7.19943) + val coefficientsWithoutInterceptR = Vectors.dense(4.70011, 7.19943) assert(modelWithoutIntercept1.intercept ~== 0 absTol 1E-3) - assert(modelWithoutIntercept1.coefficients ~= coefficientsWithourInterceptR relTol 1E-3) + assert(modelWithoutIntercept1.coefficients ~= coefficientsWithoutInterceptR relTol 1E-3) assert(modelWithoutIntercept2.intercept ~== 0 absTol 1E-3) - assert(modelWithoutIntercept2.coefficients ~= coefficientsWithourInterceptR relTol 1E-3) + assert(modelWithoutIntercept2.coefficients ~= coefficientsWithoutInterceptR relTol 1E-3) } } @@ -249,55 +285,47 @@ class LinearRegressionSuite val trainer2 = (new LinearRegression).setElasticNetParam(1.0).setRegParam(0.57) .setSolver(solver).setStandardization(false) - // Normal optimizer is not supported with only L1 regularization case. 
- if (solver == "normal") { - intercept[IllegalArgumentException] { - trainer1.fit(datasetWithDenseFeature) - trainer2.fit(datasetWithDenseFeature) - } - } else { - val model1 = trainer1.fit(datasetWithDenseFeature) - val model2 = trainer2.fit(datasetWithDenseFeature) - - /* - coefficients <- coef(glmnet(features, label, family="gaussian", - alpha = 1.0, lambda = 0.57 )) - > coefficients - 3 x 1 sparse Matrix of class "dgCMatrix" - s0 - (Intercept) 6.242284 - as.numeric.d1.V2. 4.019605 - as.numeric.d1.V3. 6.679538 - */ - val interceptR1 = 6.242284 - val coefficientsR1 = Vectors.dense(4.019605, 6.679538) - assert(model1.intercept ~== interceptR1 relTol 1E-2) - assert(model1.coefficients ~= coefficientsR1 relTol 1E-2) - - /* - coefficients <- coef(glmnet(features, label, family="gaussian", alpha = 1.0, - lambda = 0.57, standardize=FALSE )) - > coefficients - 3 x 1 sparse Matrix of class "dgCMatrix" - s0 - (Intercept) 6.416948 - as.numeric.data.V2. 3.893869 - as.numeric.data.V3. 6.724286 - */ - val interceptR2 = 6.416948 - val coefficientsR2 = Vectors.dense(3.893869, 6.724286) - - assert(model2.intercept ~== interceptR2 relTol 1E-3) - assert(model2.coefficients ~= coefficientsR2 relTol 1E-3) - - model1.transform(datasetWithDenseFeature).select("features", "prediction") - .collect().foreach { - case Row(features: DenseVector, prediction1: Double) => - val prediction2 = - features(0) * model1.coefficients(0) + features(1) * model1.coefficients(1) + - model1.intercept - assert(prediction1 ~== prediction2 relTol 1E-5) - } + val model1 = trainer1.fit(datasetWithDenseFeature) + val model2 = trainer2.fit(datasetWithDenseFeature) + + /* + coefficients <- coef(glmnet(features, label, family="gaussian", + alpha = 1.0, lambda = 0.57 )) + > coefficients + 3 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) 6.242284 + as.numeric.d1.V2. 4.019605 + as.numeric.d1.V3. 6.679538 + */ + val interceptR1 = 6.242284 + val coefficientsR1 = Vectors.dense(4.019605, 6.679538) + assert(model1.intercept ~== interceptR1 relTol 1E-2) + assert(model1.coefficients ~= coefficientsR1 relTol 1E-2) + + /* + coefficients <- coef(glmnet(features, label, family="gaussian", alpha = 1.0, + lambda = 0.57, standardize=FALSE )) + > coefficients + 3 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) 6.416948 + as.numeric.data.V2. 3.893869 + as.numeric.data.V3. 6.724286 + */ + val interceptR2 = 6.416948 + val coefficientsR2 = Vectors.dense(3.893869, 6.724286) + + assert(model2.intercept ~== interceptR2 relTol 1E-3) + assert(model2.coefficients ~= coefficientsR2 relTol 1E-3) + + model1.transform(datasetWithDenseFeature).select("features", "prediction") + .collect().foreach { + case Row(features: DenseVector, prediction1: Double) => + val prediction2 = + features(0) * model1.coefficients(0) + features(1) * model1.coefficients(1) + + model1.intercept + assert(prediction1 ~== prediction2 relTol 1E-5) } } } @@ -309,56 +337,48 @@ class LinearRegressionSuite val trainer2 = (new LinearRegression).setElasticNetParam(1.0).setRegParam(0.57) .setFitIntercept(false).setStandardization(false).setSolver(solver) - // Normal optimizer is not supported with only L1 regularization case. 
- if (solver == "normal") { - intercept[IllegalArgumentException] { - trainer1.fit(datasetWithDenseFeature) - trainer2.fit(datasetWithDenseFeature) - } - } else { - val model1 = trainer1.fit(datasetWithDenseFeature) - val model2 = trainer2.fit(datasetWithDenseFeature) - - /* - coefficients <- coef(glmnet(features, label, family="gaussian", alpha = 1.0, - lambda = 0.57, intercept=FALSE )) - > coefficients - 3 x 1 sparse Matrix of class "dgCMatrix" - s0 - (Intercept) . - as.numeric.data.V2. 6.272927 - as.numeric.data.V3. 4.782604 - */ - val interceptR1 = 0.0 - val coefficientsR1 = Vectors.dense(6.272927, 4.782604) - - assert(model1.intercept ~== interceptR1 absTol 1E-2) - assert(model1.coefficients ~= coefficientsR1 relTol 1E-2) - - /* - coefficients <- coef(glmnet(features, label, family="gaussian", alpha = 1.0, - lambda = 0.57, intercept=FALSE, standardize=FALSE )) - > coefficients - 3 x 1 sparse Matrix of class "dgCMatrix" - s0 - (Intercept) . - as.numeric.data.V2. 6.207817 - as.numeric.data.V3. 4.775780 - */ - val interceptR2 = 0.0 - val coefficientsR2 = Vectors.dense(6.207817, 4.775780) - - assert(model2.intercept ~== interceptR2 absTol 1E-2) - assert(model2.coefficients ~= coefficientsR2 relTol 1E-2) - - model1.transform(datasetWithDenseFeature).select("features", "prediction") - .collect().foreach { - case Row(features: DenseVector, prediction1: Double) => - val prediction2 = - features(0) * model1.coefficients(0) + features(1) * model1.coefficients(1) + - model1.intercept - assert(prediction1 ~== prediction2 relTol 1E-5) - } + val model1 = trainer1.fit(datasetWithDenseFeature) + val model2 = trainer2.fit(datasetWithDenseFeature) + + /* + coefficients <- coef(glmnet(features, label, family="gaussian", alpha = 1.0, + lambda = 0.57, intercept=FALSE )) + > coefficients + 3 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) . + as.numeric.data.V2. 6.272927 + as.numeric.data.V3. 4.782604 + */ + val interceptR1 = 0.0 + val coefficientsR1 = Vectors.dense(6.272927, 4.782604) + + assert(model1.intercept ~== interceptR1 absTol 1E-2) + assert(model1.coefficients ~= coefficientsR1 relTol 1E-2) + + /* + coefficients <- coef(glmnet(features, label, family="gaussian", alpha = 1.0, + lambda = 0.57, intercept=FALSE, standardize=FALSE )) + > coefficients + 3 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) . + as.numeric.data.V2. 6.207817 + as.numeric.data.V3. 4.775780 + */ + val interceptR2 = 0.0 + val coefficientsR2 = Vectors.dense(6.207817, 4.775780) + + assert(model2.intercept ~== interceptR2 absTol 1E-2) + assert(model2.coefficients ~= coefficientsR2 relTol 1E-2) + + model1.transform(datasetWithDenseFeature).select("features", "prediction") + .collect().foreach { + case Row(features: DenseVector, prediction1: Double) => + val prediction2 = + features(0) * model1.coefficients(0) + features(1) * model1.coefficients(1) + + model1.intercept + assert(prediction1 ~== prediction2 relTol 1E-5) } } } @@ -471,56 +491,48 @@ class LinearRegressionSuite val trainer2 = (new LinearRegression).setElasticNetParam(0.3).setRegParam(1.6) .setStandardization(false).setSolver(solver) - // Normal optimizer is not supported with non-zero elasticnet parameter. 
- if (solver == "normal") { - intercept[IllegalArgumentException] { - trainer1.fit(datasetWithDenseFeature) - trainer2.fit(datasetWithDenseFeature) - } - } else { - val model1 = trainer1.fit(datasetWithDenseFeature) - val model2 = trainer2.fit(datasetWithDenseFeature) - - /* - coefficients <- coef(glmnet(features, label, family="gaussian", alpha = 0.3, - lambda = 1.6 )) - > coefficients - 3 x 1 sparse Matrix of class "dgCMatrix" - s0 - (Intercept) 5.689855 - as.numeric.d1.V2. 3.661181 - as.numeric.d1.V3. 6.000274 - */ - val interceptR1 = 5.689855 - val coefficientsR1 = Vectors.dense(3.661181, 6.000274) - - assert(model1.intercept ~== interceptR1 relTol 1E-2) - assert(model1.coefficients ~= coefficientsR1 relTol 1E-2) - - /* - coefficients <- coef(glmnet(features, label, family="gaussian", alpha = 0.3, lambda = 1.6 - standardize=FALSE)) - > coefficients - 3 x 1 sparse Matrix of class "dgCMatrix" - s0 - (Intercept) 6.113890 - as.numeric.d1.V2. 3.407021 - as.numeric.d1.V3. 6.152512 - */ - val interceptR2 = 6.113890 - val coefficientsR2 = Vectors.dense(3.407021, 6.152512) - - assert(model2.intercept ~== interceptR2 relTol 1E-2) - assert(model2.coefficients ~= coefficientsR2 relTol 1E-2) - - model1.transform(datasetWithDenseFeature).select("features", "prediction") - .collect().foreach { - case Row(features: DenseVector, prediction1: Double) => - val prediction2 = - features(0) * model1.coefficients(0) + features(1) * model1.coefficients(1) + - model1.intercept - assert(prediction1 ~== prediction2 relTol 1E-5) - } + val model1 = trainer1.fit(datasetWithDenseFeature) + val model2 = trainer2.fit(datasetWithDenseFeature) + + /* + coefficients <- coef(glmnet(features, label, family="gaussian", alpha = 0.3, + lambda = 1.6 )) + > coefficients + 3 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) 5.689855 + as.numeric.d1.V2. 3.661181 + as.numeric.d1.V3. 6.000274 + */ + val interceptR1 = 5.689855 + val coefficientsR1 = Vectors.dense(3.661181, 6.000274) + + assert(model1.intercept ~== interceptR1 relTol 1E-2) + assert(model1.coefficients ~= coefficientsR1 relTol 1E-2) + + /* + coefficients <- coef(glmnet(features, label, family="gaussian", alpha = 0.3, lambda = 1.6 + standardize=FALSE)) + > coefficients + 3 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) 6.113890 + as.numeric.d1.V2. 3.407021 + as.numeric.d1.V3. 6.152512 + */ + val interceptR2 = 6.113890 + val coefficientsR2 = Vectors.dense(3.407021, 6.152512) + + assert(model2.intercept ~== interceptR2 relTol 1E-2) + assert(model2.coefficients ~= coefficientsR2 relTol 1E-2) + + model1.transform(datasetWithDenseFeature).select("features", "prediction") + .collect().foreach { + case Row(features: DenseVector, prediction1: Double) => + val prediction2 = + features(0) * model1.coefficients(0) + features(1) * model1.coefficients(1) + + model1.intercept + assert(prediction1 ~== prediction2 relTol 1E-5) } } } @@ -532,57 +544,49 @@ class LinearRegressionSuite val trainer2 = (new LinearRegression).setElasticNetParam(0.3).setRegParam(1.6) .setFitIntercept(false).setStandardization(false).setSolver(solver) - // Normal optimizer is not supported with non-zero elasticnet parameter. 
- if (solver == "normal") { - intercept[IllegalArgumentException] { - trainer1.fit(datasetWithDenseFeature) - trainer2.fit(datasetWithDenseFeature) - } - } else { - val model1 = trainer1.fit(datasetWithDenseFeature) - val model2 = trainer2.fit(datasetWithDenseFeature) - - /* - coefficients <- coef(glmnet(features, label, family="gaussian", alpha = 0.3, - lambda = 1.6, intercept=FALSE )) - > coefficients - 3 x 1 sparse Matrix of class "dgCMatrix" - s0 - (Intercept) . - as.numeric.d1.V2. 5.643748 - as.numeric.d1.V3. 4.331519 - */ - val interceptR1 = 0.0 - val coefficientsR1 = Vectors.dense(5.643748, 4.331519) - - assert(model1.intercept ~== interceptR1 absTol 1E-2) - assert(model1.coefficients ~= coefficientsR1 relTol 1E-2) - - /* - coefficients <- coef(glmnet(features, label, family="gaussian", alpha = 0.3, - lambda = 1.6, intercept=FALSE, standardize=FALSE )) - > coefficients - 3 x 1 sparse Matrix of class "dgCMatrix" - s0 - (Intercept) . - as.numeric.d1.V2. 5.455902 - as.numeric.d1.V3. 4.312266 - - */ - val interceptR2 = 0.0 - val coefficientsR2 = Vectors.dense(5.455902, 4.312266) - - assert(model2.intercept ~== interceptR2 absTol 1E-2) - assert(model2.coefficients ~= coefficientsR2 relTol 1E-2) - - model1.transform(datasetWithDenseFeature).select("features", "prediction") - .collect().foreach { - case Row(features: DenseVector, prediction1: Double) => - val prediction2 = - features(0) * model1.coefficients(0) + features(1) * model1.coefficients(1) + - model1.intercept - assert(prediction1 ~== prediction2 relTol 1E-5) - } + val model1 = trainer1.fit(datasetWithDenseFeature) + val model2 = trainer2.fit(datasetWithDenseFeature) + + /* + coefficients <- coef(glmnet(features, label, family="gaussian", alpha = 0.3, + lambda = 1.6, intercept=FALSE )) + > coefficients + 3 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) . + as.numeric.d1.V2. 5.643748 + as.numeric.d1.V3. 4.331519 + */ + val interceptR1 = 0.0 + val coefficientsR1 = Vectors.dense(5.643748, 4.331519) + + assert(model1.intercept ~== interceptR1 absTol 1E-2) + assert(model1.coefficients ~= coefficientsR1 relTol 1E-2) + + /* + coefficients <- coef(glmnet(features, label, family="gaussian", alpha = 0.3, + lambda = 1.6, intercept=FALSE, standardize=FALSE )) + > coefficients + 3 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) . + as.numeric.d1.V2. 5.455902 + as.numeric.d1.V3. 4.312266 + + */ + val interceptR2 = 0.0 + val coefficientsR2 = Vectors.dense(5.455902, 4.312266) + + assert(model2.intercept ~== interceptR2 absTol 1E-2) + assert(model2.coefficients ~= coefficientsR2 relTol 1E-2) + + model1.transform(datasetWithDenseFeature).select("features", "prediction") + .collect().foreach { + case Row(features: DenseVector, prediction1: Double) => + val prediction2 = + features(0) * model1.coefficients(0) + features(1) * model1.coefficients(1) + + model1.intercept + assert(prediction1 ~== prediction2 relTol 1E-5) } } } @@ -757,7 +761,8 @@ class LinearRegressionSuite assert(model.summary.meanAbsoluteError ~== 0.07961668 relTol 1E-4) assert(model.summary.r2 ~== 0.9998737 relTol 1E-4) - // Normal solver uses "WeightedLeastSquares". This algorithm does not generate + // Normal solver uses "WeightedLeastSquares". If no regularization is applied or only L2 + // regularization is applied, this algorithm uses a direct solver and does not generate an // objective history because it does not run through iterations. 
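The assertions around this comment encode the distinction: an analytic (direct) solve reports only the trivial history Array(0.0), while an iterative solve (l-bfgs here, or the quasi-Newton path used once L1 regularization is requested) records one objective value per iteration, and that sequence should never increase. A minimal standalone sketch of the non-increasing check in plain Scala (the helper name and values are illustrative; it mirrors the sliding(2) assertion added further below):

object ObjectiveHistoryCheck {
  def main(args: Array[String]): Unit = {
    // Same shape of check as objectiveHistory.sliding(2).forall(x => x(0) >= x(1)).
    def nonIncreasing(history: Array[Double]): Boolean =
      history.sliding(2).forall(w => w.length < 2 || w(0) >= w(1))

    assert(nonIncreasing(Array(0.0)))                 // trivial history from a direct solve
    assert(nonIncreasing(Array(12.0, 7.5, 7.5, 3.2))) // typical converging iterative history
    assert(!nonIncreasing(Array(12.0, 7.5, 9.0)))     // an increase would fail the check
  }
}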
if (solver == "l-bfgs") { // Objective function should be monotonically decreasing for linear regression @@ -776,7 +781,7 @@ class LinearRegressionSuite val pValsR = Array(0, 0, 0) model.summary.devianceResiduals.zip(devianceResidualsR).foreach { x => assert(x._1 ~== x._2 absTol 1E-4) } - model.summary.coefficientStandardErrors.zip(seCoefR).foreach{ x => + model.summary.coefficientStandardErrors.zip(seCoefR).foreach { x => assert(x._1 ~== x._2 absTol 1E-4) } model.summary.tValues.map(_.round).zip(tValsR).foreach{ x => assert(x._1 === x._2) } model.summary.pValues.map(_.round).zip(pValsR).foreach{ x => assert(x._1 === x._2) } @@ -950,6 +955,20 @@ class LinearRegressionSuite assert(x._1 ~== x._2 absTol 1E-3) } model.summary.tValues.zip(tValsR).foreach{ x => assert(x._1 ~== x._2 absTol 1E-3) } model.summary.pValues.zip(pValsR).foreach{ x => assert(x._1 ~== x._2 absTol 1E-3) } + + val modelWithL1 = new LinearRegression() + .setWeightCol("weight") + .setSolver("normal") + .setRegParam(0.5) + .setElasticNetParam(1.0) + .fit(datasetWithWeight) + + assert(modelWithL1.summary.objectiveHistory !== Array(0.0)) + assert( + modelWithL1.summary + .objectiveHistory + .sliding(2) + .forall(x => x(0) >= x(1))) } test("linear regression summary with weighted samples and w/o intercept by normal solver") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala index 499d386e66413..3bded9c01760a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala @@ -154,10 +154,10 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { val featureSamples = Array(0, 0, 0).map(_.toDouble) val featureSamplesEmpty = Array.empty[Double] val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) - assert(splits === Array[Double]()) + assert(splits === Array.empty[Double]) val splitsEmpty = RandomForest.findSplitsForContinuousFeature(featureSamplesEmpty, fakeMetadata, 0) - assert(splitsEmpty === Array[Double]()) + assert(splitsEmpty === Array.empty[Double]) } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala index 0eb839f20c003..5f85c0d65ff2d 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala @@ -72,7 +72,7 @@ class PythonMLLibAPISuite extends SparkFunSuite { assert(matrix === nm) // Test conversion for empty matrix - val empty = Array[Double]() + val empty = Array.empty[Double] val emptyMatrix = Matrices.dense(0, 0, empty) val ne = SerDe.loads(SerDe.dumps(emptyMatrix)).asInstanceOf[DenseMatrix] assert(emptyMatrix == ne) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala index 2d35b312083c0..48bd41dc3e3bf 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala @@ -29,6 +29,8 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext { import org.apache.spark.mllib.clustering.KMeans.{K_MEANS_PARALLEL, RANDOM} + private val seed = 42 + test("single cluster") { val data = 
sc.parallelize(Array( Vectors.dense(1.0, 2.0, 6.0), @@ -38,7 +40,7 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext { val center = Vectors.dense(1.0, 3.0, 4.0) - // No matter how many runs or iterations we use, we should get one cluster, + // No matter how many iterations we use, we should get one cluster, // centered at the mean of the points var model = KMeans.train(data, k = 1, maxIterations = 1) @@ -50,44 +52,72 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext { model = KMeans.train(data, k = 1, maxIterations = 5) assert(model.clusterCenters.head ~== center absTol 1E-5) - model = KMeans.train(data, k = 1, maxIterations = 1, runs = 5) - assert(model.clusterCenters.head ~== center absTol 1E-5) - - model = KMeans.train(data, k = 1, maxIterations = 1, runs = 5) - assert(model.clusterCenters.head ~== center absTol 1E-5) - - model = KMeans.train(data, k = 1, maxIterations = 1, runs = 1, initializationMode = RANDOM) + model = KMeans.train(data, k = 1, maxIterations = 1, initializationMode = RANDOM) assert(model.clusterCenters.head ~== center absTol 1E-5) model = KMeans.train( - data, k = 1, maxIterations = 1, runs = 1, initializationMode = K_MEANS_PARALLEL) + data, k = 1, maxIterations = 1, initializationMode = K_MEANS_PARALLEL) assert(model.clusterCenters.head ~== center absTol 1E-5) } - test("no distinct points") { + test("fewer distinct points than clusters") { val data = sc.parallelize( Array( Vectors.dense(1.0, 2.0, 3.0), Vectors.dense(1.0, 2.0, 3.0), Vectors.dense(1.0, 2.0, 3.0)), 2) - val center = Vectors.dense(1.0, 2.0, 3.0) - // Make sure code runs. - var model = KMeans.train(data, k = 2, maxIterations = 1) - assert(model.clusterCenters.size === 2) - } + var model = KMeans.train(data, k = 2, maxIterations = 1, initializationMode = "random") + assert(model.clusterCenters.length === 1) - test("more clusters than points") { - val data = sc.parallelize( - Array( - Vectors.dense(1.0, 2.0, 3.0), - Vectors.dense(1.0, 3.0, 4.0)), - 2) + model = KMeans.train(data, k = 2, maxIterations = 1, initializationMode = "k-means||") + assert(model.clusterCenters.length === 1) + } - // Make sure code runs. 
- var model = KMeans.train(data, k = 3, maxIterations = 1) - assert(model.clusterCenters.size === 3) + test("unique cluster centers") { + val rng = new Random(seed) + val numDistinctPoints = 10 + val points = (0 until numDistinctPoints).map(i => Vectors.dense(Array.fill(3)(rng.nextDouble))) + val data = sc.parallelize(points.flatMap(Array.fill(1 + rng.nextInt(3))(_)), 2) + val normedData = data.map(new VectorWithNorm(_)) + + // less centers than k + val km = new KMeans().setK(50) + .setMaxIterations(5) + .setInitializationMode("k-means||") + .setInitializationSteps(10) + .setSeed(seed) + val initialCenters = km.initKMeansParallel(normedData).map(_.vector) + assert(initialCenters.length === initialCenters.distinct.length) + assert(initialCenters.length <= numDistinctPoints) + + val model = km.run(data) + val finalCenters = model.clusterCenters + assert(finalCenters.length === finalCenters.distinct.length) + + // run local k-means + val k = 10 + val km2 = new KMeans().setK(k) + .setMaxIterations(5) + .setInitializationMode("k-means||") + .setInitializationSteps(10) + .setSeed(seed) + val initialCenters2 = km2.initKMeansParallel(normedData).map(_.vector) + assert(initialCenters2.length === initialCenters2.distinct.length) + assert(initialCenters2.length === k) + + val model2 = km2.run(data) + val finalCenters2 = model2.clusterCenters + assert(finalCenters2.length === finalCenters2.distinct.length) + + val km3 = new KMeans().setK(k) + .setMaxIterations(5) + .setInitializationMode("random") + .setSeed(seed) + val model3 = km3.run(data) + val finalCenters3 = model3.clusterCenters + assert(finalCenters3.length === finalCenters3.distinct.length) } test("deterministic initialization") { @@ -97,12 +127,12 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext { for (initMode <- Seq(RANDOM, K_MEANS_PARALLEL)) { // Create three deterministic models and compare cluster means - val model1 = KMeans.train(rdd, k = 10, maxIterations = 2, runs = 1, - initializationMode = initMode, seed = 42) + val model1 = KMeans.train(rdd, k = 10, maxIterations = 2, + initializationMode = initMode, seed = seed) val centers1 = model1.clusterCenters - val model2 = KMeans.train(rdd, k = 10, maxIterations = 2, runs = 1, - initializationMode = initMode, seed = 42) + val model2 = KMeans.train(rdd, k = 10, maxIterations = 2, + initializationMode = initMode, seed = seed) val centers2 = model2.clusterCenters centers1.zip(centers2).foreach { case (c1, c2) => @@ -119,7 +149,7 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext { ) val data = sc.parallelize((1 to 100).flatMap(_ => smallData), 4) - // No matter how many runs or iterations we use, we should get one cluster, + // No matter how many iterations we use, we should get one cluster, // centered at the mean of the points val center = Vectors.dense(1.0, 3.0, 4.0) @@ -134,17 +164,10 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext { model = KMeans.train(data, k = 1, maxIterations = 5) assert(model.clusterCenters.head ~== center absTol 1E-5) - model = KMeans.train(data, k = 1, maxIterations = 1, runs = 5) + model = KMeans.train(data, k = 1, maxIterations = 1, initializationMode = RANDOM) assert(model.clusterCenters.head ~== center absTol 1E-5) - model = KMeans.train(data, k = 1, maxIterations = 1, runs = 5) - assert(model.clusterCenters.head ~== center absTol 1E-5) - - model = KMeans.train(data, k = 1, maxIterations = 1, runs = 1, initializationMode = RANDOM) - assert(model.clusterCenters.head ~== center absTol 1E-5) - - model 
= KMeans.train(data, k = 1, maxIterations = 1, runs = 1, - initializationMode = K_MEANS_PARALLEL) + model = KMeans.train(data, k = 1, maxIterations = 1, initializationMode = K_MEANS_PARALLEL) assert(model.clusterCenters.head ~== center absTol 1E-5) } @@ -165,7 +188,7 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext { data.persist() - // No matter how many runs or iterations we use, we should get one cluster, + // No matter how many iterations we use, we should get one cluster, // centered at the mean of the points val center = Vectors.sparse(n, Seq((0, 1.0), (1, 3.0), (2, 4.0))) @@ -179,17 +202,10 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext { model = KMeans.train(data, k = 1, maxIterations = 5) assert(model.clusterCenters.head ~== center absTol 1E-5) - model = KMeans.train(data, k = 1, maxIterations = 1, runs = 5) - assert(model.clusterCenters.head ~== center absTol 1E-5) - - model = KMeans.train(data, k = 1, maxIterations = 1, runs = 5) + model = KMeans.train(data, k = 1, maxIterations = 1, initializationMode = RANDOM) assert(model.clusterCenters.head ~== center absTol 1E-5) - model = KMeans.train(data, k = 1, maxIterations = 1, runs = 1, initializationMode = RANDOM) - assert(model.clusterCenters.head ~== center absTol 1E-5) - - model = KMeans.train(data, k = 1, maxIterations = 1, runs = 1, - initializationMode = K_MEANS_PARALLEL) + model = KMeans.train(data, k = 1, maxIterations = 1, initializationMode = K_MEANS_PARALLEL) assert(model.clusterCenters.head ~== center absTol 1E-5) data.unpersist() @@ -230,11 +246,6 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext { model = KMeans.train(rdd, k = 5, maxIterations = 10) assert(model.clusterCenters.sortBy(VectorWithCompare(_)) .zip(points.sortBy(VectorWithCompare(_))).forall(x => x._1 ~== (x._2) absTol 1E-5)) - - // Neither should more runs - model = KMeans.train(rdd, k = 5, maxIterations = 10, runs = 5) - assert(model.clusterCenters.sortBy(VectorWithCompare(_)) - .zip(points.sortBy(VectorWithCompare(_))).forall(x => x._1 ~== (x._2) absTol 1E-5)) } test("two clusters") { @@ -250,7 +261,7 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext { for (initMode <- Seq(RANDOM, K_MEANS_PARALLEL)) { // Two iterations are sufficient no matter where the initial centers are. 
- val model = KMeans.train(rdd, k = 2, maxIterations = 2, runs = 1, initMode) + val model = KMeans.train(rdd, k = 2, maxIterations = 2, initMode) val predicts = model.predict(rdd).collect() diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala index f316c67234f18..142d1e9812ef1 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala @@ -36,6 +36,9 @@ class MulticlassMetricsSuite extends SparkFunSuite with MLlibTestSparkContext { (1.0, 1.0), (1.0, 1.0), (2.0, 2.0), (2.0, 0.0)), 2) val metrics = new MulticlassMetrics(predictionAndLabels) val delta = 0.0000001 + val tpRate0 = 2.0 / (2 + 2) + val tpRate1 = 3.0 / (3 + 1) + val tpRate2 = 1.0 / (1 + 0) val fpRate0 = 1.0 / (9 - 4) val fpRate1 = 1.0 / (9 - 4) val fpRate2 = 1.0 / (9 - 1) @@ -53,6 +56,9 @@ class MulticlassMetricsSuite extends SparkFunSuite with MLlibTestSparkContext { val f2measure2 = (1 + 2 * 2) * precision2 * recall2 / (2 * 2 * precision2 + recall2) assert(metrics.confusionMatrix.toArray.sameElements(confusionMatrix.toArray)) + assert(math.abs(metrics.truePositiveRate(0.0) - tpRate0) < delta) + assert(math.abs(metrics.truePositiveRate(1.0) - tpRate1) < delta) + assert(math.abs(metrics.truePositiveRate(2.0) - tpRate2) < delta) assert(math.abs(metrics.falsePositiveRate(0.0) - fpRate0) < delta) assert(math.abs(metrics.falsePositiveRate(1.0) - fpRate1) < delta) assert(math.abs(metrics.falsePositiveRate(2.0) - fpRate2) < delta) @@ -75,6 +81,8 @@ class MulticlassMetricsSuite extends SparkFunSuite with MLlibTestSparkContext { assert(math.abs(metrics.accuracy - metrics.recall) < delta) assert(math.abs(metrics.accuracy - metrics.fMeasure) < delta) assert(math.abs(metrics.accuracy - metrics.weightedRecall) < delta) + assert(math.abs(metrics.weightedTruePositiveRate - + ((4.0 / 9) * tpRate0 + (4.0 / 9) * tpRate1 + (1.0 / 9) * tpRate2)) < delta) assert(math.abs(metrics.weightedFalsePositiveRate - ((4.0 / 9) * fpRate0 + (4.0 / 9) * fpRate1 + (1.0 / 9) * fpRate2)) < delta) assert(math.abs(metrics.weightedPrecision - diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MultilabelMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MultilabelMetricsSuite.scala index f3b19aeb42f84..a660492c7ae59 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MultilabelMetricsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MultilabelMetricsSuite.scala @@ -47,7 +47,7 @@ class MultilabelMetricsSuite extends SparkFunSuite with MLlibTestSparkContext { val scoreAndLabels: RDD[(Array[Double], Array[Double])] = sc.parallelize( Seq((Array(0.0, 1.0), Array(0.0, 2.0)), (Array(0.0, 2.0), Array(0.0, 1.0)), - (Array(), Array(0.0)), + (Array.empty[Double], Array(0.0)), (Array(2.0), Array(2.0)), (Array(2.0, 0.0), Array(2.0, 0.0)), (Array(0.0, 1.0, 2.0), Array(0.0, 1.0)), diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala index 8e9d910e646c9..f334be2c2ba83 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala @@ -28,7 +28,7 @@ class RankingMetricsSuite extends SparkFunSuite with 
MLlibTestSparkContext { Seq( (Array(1, 6, 2, 7, 8, 3, 9, 10, 4, 5), Array(1, 2, 3, 4, 5)), (Array(4, 1, 5, 6, 2, 7, 3, 8, 9, 10), Array(1, 2, 3)), - (Array(1, 2, 3, 4, 5), Array[Int]()) + (Array(1, 2, 3, 4, 5), Array.empty[Int]) ), 2) val eps = 1.0E-5 @@ -55,7 +55,7 @@ class RankingMetricsSuite extends SparkFunSuite with MLlibTestSparkContext { val predictionAndLabels = sc.parallelize( Seq( (Array(1, 6, 2), Array(1, 2, 3, 4, 5)), - (Array[Int](), Array(1, 2, 3)) + (Array.empty[Int], Array(1, 2, 3)) ), 2) val eps = 1.0E-5 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala index ec23a4aa7364d..ac702b4b7c69e 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala @@ -54,10 +54,10 @@ class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext { LabeledPoint(1.0, Vectors.dense(Array(0.0, 9.0, 8.0))), LabeledPoint(2.0, Vectors.dense(Array(8.0, 9.0, 5.0)))), 2) val preFilteredData = - Set(LabeledPoint(0.0, Vectors.dense(Array(0.0))), - LabeledPoint(1.0, Vectors.dense(Array(6.0))), - LabeledPoint(1.0, Vectors.dense(Array(8.0))), - LabeledPoint(2.0, Vectors.dense(Array(5.0)))) + Set(LabeledPoint(0.0, Vectors.dense(Array(8.0))), + LabeledPoint(1.0, Vectors.dense(Array(0.0))), + LabeledPoint(1.0, Vectors.dense(Array(0.0))), + LabeledPoint(2.0, Vectors.dense(Array(8.0)))) val model = new ChiSqSelector(1).fit(labeledDiscreteData) val filteredData = labeledDiscreteData.map { lp => LabeledPoint(lp.label, model.transform(lp.features)) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala index d0c4dd28e14ee..563756907d201 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala @@ -289,7 +289,7 @@ class MatricesSuite extends SparkFunSuite { val spHorz2 = Matrices.horzcat(Array(spMat1, deMat2)) val spHorz3 = Matrices.horzcat(Array(deMat1, spMat2)) val deHorz1 = Matrices.horzcat(Array(deMat1, deMat2)) - val deHorz2 = Matrices.horzcat(Array[Matrix]()) + val deHorz2 = Matrices.horzcat(Array.empty[Matrix]) assert(deHorz1.numRows === 3) assert(spHorz2.numRows === 3) @@ -343,7 +343,7 @@ class MatricesSuite extends SparkFunSuite { val deVert1 = Matrices.vertcat(Array(deMat1, deMat3)) val spVert2 = Matrices.vertcat(Array(spMat1, deMat3)) val spVert3 = Matrices.vertcat(Array(deMat1, spMat3)) - val deVert2 = Matrices.vertcat(Array[Matrix]()) + val deVert2 = Matrices.vertcat(Array.empty[Matrix]) assert(deVert1.numRows === 5) assert(spVert2.numRows === 5) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtilsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtilsSuite.scala index 1aff44480aac9..3fcf1cf2c2635 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtilsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtilsSuite.scala @@ -110,9 +110,9 @@ class TestingUtilsSuite extends SparkFunSuite { assert(!(Vectors.dense(Array(3.1, 3.5)) !~= Vectors.dense(Array(3.130, 3.534)) relTol 0.01)) assert(!(Vectors.dense(Array(3.1, 3.5)) ~= Vectors.dense(Array(3.135, 3.534)) relTol 0.01)) assert(Vectors.dense(Array(3.1)) !~= Vectors.dense(Array(3.130, 3.534)) relTol 0.01) - 
assert(Vectors.dense(Array[Double]()) !~= Vectors.dense(Array(3.130, 3.534)) relTol 0.01) + assert(Vectors.dense(Array.empty[Double]) !~= Vectors.dense(Array(3.130, 3.534)) relTol 0.01) assert(Vectors.dense(Array(3.1)) !~== Vectors.dense(Array(3.130, 3.534)) relTol 0.01) - assert(Vectors.dense(Array[Double]()) !~== Vectors.dense(Array(3.130, 3.534)) relTol 0.01) + assert(Vectors.dense(Array.empty[Double]) !~== Vectors.dense(Array(3.130, 3.534)) relTol 0.01) // Should throw exception with message when test fails. intercept[TestFailedException]( @@ -125,7 +125,7 @@ class TestingUtilsSuite extends SparkFunSuite { Vectors.dense(Array(3.1)) ~== Vectors.dense(Array(3.535, 3.534)) relTol 0.01) intercept[TestFailedException]( - Vectors.dense(Array[Double]()) ~== Vectors.dense(Array(3.135)) relTol 0.01) + Vectors.dense(Array.empty[Double]) ~== Vectors.dense(Array(3.135)) relTol 0.01) // Comparing against zero should fail the test and throw exception with message // saying that the relative error is meaningless in this situation. @@ -145,7 +145,7 @@ class TestingUtilsSuite extends SparkFunSuite { assert(Vectors.dense(Array(3.1)) !~== Vectors.sparse(2, Array(0, 1), Array(3.130, 3.534)) relTol 0.01) - assert(Vectors.dense(Array[Double]()) !~== + assert(Vectors.dense(Array.empty[Double]) !~== Vectors.sparse(2, Array(0, 1), Array(3.130, 3.534)) relTol 0.01) } @@ -176,14 +176,14 @@ class TestingUtilsSuite extends SparkFunSuite { assert(!(Vectors.dense(Array(3.1)) ~= Vectors.dense(Array(3.1 + 1E-6, 3.5 + 2E-7)) absTol 1E-5)) - assert(Vectors.dense(Array[Double]()) !~= + assert(Vectors.dense(Array.empty[Double]) !~= Vectors.dense(Array(3.1 + 1E-6, 3.5 + 2E-7)) absTol 1E-5) - assert(!(Vectors.dense(Array[Double]()) ~= + assert(!(Vectors.dense(Array.empty[Double]) ~= Vectors.dense(Array(3.1 + 1E-6, 3.5 + 2E-7)) absTol 1E-5)) - assert(Vectors.dense(Array[Double]()) ~= - Vectors.dense(Array[Double]()) absTol 1E-5) + assert(Vectors.dense(Array.empty[Double]) ~= + Vectors.dense(Array.empty[Double]) absTol 1E-5) // Should throw exception with message when test fails. 
intercept[TestFailedException](Vectors.dense(Array(3.1, 3.5, 0.0)) !~== @@ -195,7 +195,7 @@ class TestingUtilsSuite extends SparkFunSuite { intercept[TestFailedException](Vectors.dense(Array(3.1)) ~== Vectors.dense(Array(3.1 + 1E-5, 3.5 + 2E-7)) absTol 1E-6) - intercept[TestFailedException](Vectors.dense(Array[Double]()) ~== + intercept[TestFailedException](Vectors.dense(Array.empty[Double]) ~== Vectors.dense(Array(3.1 + 1E-5, 3.5 + 2E-7)) absTol 1E-6) // Comparisons of two sparse vectors @@ -214,7 +214,7 @@ class TestingUtilsSuite extends SparkFunSuite { assert(Vectors.sparse(3, Array(0, 2), Array(3.1 + 1E-6, 2.4)) !~== Vectors.sparse(1, Array(0), Array(3.1)) absTol 1E-3) - assert(Vectors.sparse(0, Array[Int](), Array[Double]()) !~== + assert(Vectors.sparse(0, Array.empty[Int], Array.empty[Double]) !~== Vectors.sparse(1, Array(0), Array(3.1)) absTol 1E-3) // Comparisons of a dense vector and a sparse vector @@ -230,14 +230,14 @@ class TestingUtilsSuite extends SparkFunSuite { assert(Vectors.sparse(3, Array(0, 2), Array(3.1, 2.4)) !~== Vectors.dense(Array(3.1)) absTol 1E-6) - assert(Vectors.dense(Array[Double]()) !~== + assert(Vectors.dense(Array.empty[Double]) !~== Vectors.sparse(3, Array(0, 2), Array(0, 2.4)) absTol 1E-6) assert(Vectors.sparse(1, Array(0), Array(3.1)) !~== Vectors.dense(Array(3.1, 3.2)) absTol 1E-6) assert(Vectors.dense(Array(3.1)) !~== - Vectors.sparse(0, Array[Int](), Array[Double]()) absTol 1E-6) + Vectors.sparse(0, Array.empty[Int], Array.empty[Double]) absTol 1E-6) } test("Comparing Matrices using absolute error.") { diff --git a/pom.xml b/pom.xml index b371b913389b7..c34db84259c5b 100644 --- a/pom.xml +++ b/pom.xml @@ -134,7 +134,7 @@ 1.2.1 10.12.1.1 - 1.9.0-palantir2 + 1.9.0-palantir3 1.6.0 9.2.16.v20160414 3.1.0 @@ -168,7 +168,7 @@ 2.6 - 3.3.2 + 3.5 3.2.10 3.0.0 2.22.2 diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index ae72d37a0b61c..350b144f8294b 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -56,10 +56,37 @@ object MimaExcludes { ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.databaseExists"), ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.tableExists"), ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.functionExists"), + + // [SPARK-17731][SQL][Streaming] Metrics for structured streaming + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.streaming.SourceStatus.this"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.streaming.SourceStatus.offsetDesc"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.streaming.StreamingQuery.status"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.streaming.SinkStatus.this"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.streaming.StreamingQueryInfo"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.streaming.StreamingQueryListener#QueryStarted.this"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.streaming.StreamingQueryListener#QueryStarted.queryInfo"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.streaming.StreamingQueryListener#QueryProgress.this"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.streaming.StreamingQueryListener#QueryProgress.queryInfo"), + 
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.streaming.StreamingQueryListener#QueryTerminated.queryInfo"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.streaming.StreamingQueryListener$QueryStarted"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.streaming.StreamingQueryListener$QueryProgress"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.streaming.StreamingQueryListener$QueryTerminated"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.streaming.StreamingQueryListener.onQueryStarted"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.streaming.StreamingQueryListener.onQueryStarted"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.streaming.StreamingQueryListener.onQueryProgress"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.streaming.StreamingQueryListener.onQueryProgress"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.streaming.StreamingQueryListener.onQueryTerminated"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.streaming.StreamingQueryListener.onQueryTerminated"), + // [SPARK-17338][SQL] add global temp view ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.dropGlobalTempView"), ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.catalog.Catalog.dropTempView"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.dropTempView") + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.dropTempView"), + + // [SPARK-18034] Upgrade to MiMa 0.1.11 to fix flakiness. 
+ ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasAggregationDepth.aggregationDepth"), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasAggregationDepth.getAggregationDepth"), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasAggregationDepth.org$apache$spark$ml$param$shared$HasAggregationDepth$_setter_$aggregationDepth_=") ) } diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index b7187069947ec..369de14c949c1 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -736,9 +736,9 @@ object Unidoc { publish := {}, unidocProjectFilter in(ScalaUnidoc, unidoc) := - inAnyProject -- inProjects(OldDeps.project, repl, examples, tools, streamingFlumeSink, yarn, tags, streamingKafka010), + inAnyProject -- inProjects(OldDeps.project, repl, examples, tools, streamingFlumeSink, yarn, tags, streamingKafka010, sqlKafka010), unidocProjectFilter in(JavaUnidoc, unidoc) := - inAnyProject -- inProjects(OldDeps.project, repl, examples, tools, streamingFlumeSink, yarn, tags, streamingKafka010), + inAnyProject -- inProjects(OldDeps.project, repl, examples, tools, streamingFlumeSink, yarn, tags, streamingKafka010, sqlKafka010), unidocAllClasspaths in (ScalaUnidoc, unidoc) := { ignoreClasspaths((unidocAllClasspaths in (ScalaUnidoc, unidoc)).value) diff --git a/project/plugins.sbt b/project/plugins.sbt index 60ef408df7672..66db01a0ae23b 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -8,7 +8,7 @@ addSbtPlugin("net.virtual-void" % "sbt-dependency-graph" % "0.8.2") addSbtPlugin("org.scalastyle" %% "scalastyle-sbt-plugin" % "0.8.0") -addSbtPlugin("com.typesafe" % "sbt-mima-plugin" % "0.1.9") +addSbtPlugin("com.typesafe" % "sbt-mima-plugin" % "0.1.11") addSbtPlugin("com.alpinenow" % "junit_xml_listener" % "0.5.1") diff --git a/python/docs/Makefile b/python/docs/Makefile index de86e97d862f0..5e4cfb8ab6fe3 100644 --- a/python/docs/Makefile +++ b/python/docs/Makefile @@ -7,7 +7,7 @@ SPHINXBUILD ?= sphinx-build PAPER ?= BUILDDIR ?= _build -export PYTHONPATH=$(realpath ..):$(realpath ../lib/py4j-0.10.3-src.zip) +export PYTHONPATH=$(realpath ..):$(realpath ../lib/py4j-0.10.4-src.zip) # User-friendly check for sphinx-build ifeq ($(shell which $(SPHINXBUILD) >/dev/null 2>&1; echo $$?), 1) diff --git a/python/lib/py4j-0.10.3-src.zip b/python/lib/py4j-0.10.3-src.zip deleted file mode 100644 index bc54f33af1515..0000000000000 Binary files a/python/lib/py4j-0.10.3-src.zip and /dev/null differ diff --git a/python/lib/py4j-0.10.4-src.zip b/python/lib/py4j-0.10.4-src.zip new file mode 100644 index 0000000000000..8c3829e328726 Binary files /dev/null and b/python/lib/py4j-0.10.4-src.zip differ diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 3f763a10d4066..d9ff356b9403a 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -758,20 +758,21 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred probabilityCol="probability", rawPredictionCol="rawPrediction", maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini", - numTrees=20, featureSubsetStrategy="auto", seed=None): + numTrees=20, featureSubsetStrategy="auto", seed=None, subsamplingRate=1.0): """ __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ 
probabilityCol="probability", rawPredictionCol="rawPrediction", \ maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini", \ - numTrees=20, featureSubsetStrategy="auto", seed=None) + numTrees=20, featureSubsetStrategy="auto", seed=None, subsamplingRate=1.0) """ super(RandomForestClassifier, self).__init__() self._java_obj = self._new_java_obj( "org.apache.spark.ml.classification.RandomForestClassifier", self.uid) self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, - impurity="gini", numTrees=20, featureSubsetStrategy="auto") + impurity="gini", numTrees=20, featureSubsetStrategy="auto", + subsamplingRate=1.0) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @@ -781,13 +782,13 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre probabilityCol="probability", rawPredictionCol="rawPrediction", maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, seed=None, - impurity="gini", numTrees=20, featureSubsetStrategy="auto"): + impurity="gini", numTrees=20, featureSubsetStrategy="auto", subsamplingRate=1.0): """ setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ probabilityCol="probability", rawPredictionCol="rawPrediction", \ maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, seed=None, \ - impurity="gini", numTrees=20, featureSubsetStrategy="auto") + impurity="gini", numTrees=20, featureSubsetStrategy="auto", subsamplingRate=1.0) Sets params for linear classification. """ kwargs = self.setParams._input_kwargs diff --git a/python/pyspark/ml/evaluation.py b/python/pyspark/ml/evaluation.py index 1fe8772da772a..7aa16fa5b90f2 100644 --- a/python/pyspark/ml/evaluation.py +++ b/python/pyspark/ml/evaluation.py @@ -22,6 +22,7 @@ from pyspark.ml.param import Param, Params, TypeConverters from pyspark.ml.param.shared import HasLabelCol, HasPredictionCol, HasRawPredictionCol from pyspark.ml.common import inherit_doc +from pyspark.ml.util import JavaMLReadable, JavaMLWritable __all__ = ['Evaluator', 'BinaryClassificationEvaluator', 'RegressionEvaluator', 'MulticlassClassificationEvaluator'] @@ -103,7 +104,8 @@ def isLargerBetter(self): @inherit_doc -class BinaryClassificationEvaluator(JavaEvaluator, HasLabelCol, HasRawPredictionCol): +class BinaryClassificationEvaluator(JavaEvaluator, HasLabelCol, HasRawPredictionCol, + JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -121,6 +123,11 @@ class BinaryClassificationEvaluator(JavaEvaluator, HasLabelCol, HasRawPrediction 0.70... >>> evaluator.evaluate(dataset, {evaluator.metricName: "areaUnderPR"}) 0.83... + >>> bce_path = temp_path + "/bce" + >>> evaluator.save(bce_path) + >>> evaluator2 = BinaryClassificationEvaluator.load(bce_path) + >>> str(evaluator2.getRawPredictionCol()) + 'raw' .. versionadded:: 1.4.0 """ @@ -172,7 +179,8 @@ def setParams(self, rawPredictionCol="rawPrediction", labelCol="label", @inherit_doc -class RegressionEvaluator(JavaEvaluator, HasLabelCol, HasPredictionCol): +class RegressionEvaluator(JavaEvaluator, HasLabelCol, HasPredictionCol, + JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -190,6 +198,11 @@ class RegressionEvaluator(JavaEvaluator, HasLabelCol, HasPredictionCol): 0.993... 
>>> evaluator.evaluate(dataset, {evaluator.metricName: "mae"}) 2.649... + >>> re_path = temp_path + "/re" + >>> evaluator.save(re_path) + >>> evaluator2 = RegressionEvaluator.load(re_path) + >>> str(evaluator2.getPredictionCol()) + 'raw' .. versionadded:: 1.4.0 """ @@ -244,7 +257,8 @@ def setParams(self, predictionCol="prediction", labelCol="label", @inherit_doc -class MulticlassClassificationEvaluator(JavaEvaluator, HasLabelCol, HasPredictionCol): +class MulticlassClassificationEvaluator(JavaEvaluator, HasLabelCol, HasPredictionCol, + JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -260,6 +274,11 @@ class MulticlassClassificationEvaluator(JavaEvaluator, HasLabelCol, HasPredictio 0.66... >>> evaluator.evaluate(dataset, {evaluator.metricName: "accuracy"}) 0.66... + >>> mce_path = temp_path + "/mce" + >>> evaluator.save(mce_path) + >>> evaluator2 = MulticlassClassificationEvaluator.load(mce_path) + >>> str(evaluator2.getPredictionCol()) + 'prediction' .. versionadded:: 1.5.0 """ @@ -311,19 +330,27 @@ def setParams(self, predictionCol="prediction", labelCol="label", if __name__ == "__main__": import doctest + import tempfile + import pyspark.ml.evaluation from pyspark.sql import SparkSession - globs = globals().copy() + globs = pyspark.ml.evaluation.__dict__.copy() # The small batch size here ensures that we see multiple batches, # even in these small test examples: spark = SparkSession.builder\ .master("local[2]")\ .appName("ml.evaluation tests")\ .getOrCreate() - sc = spark.sparkContext - globs['sc'] = sc globs['spark'] = spark - (failure_count, test_count) = doctest.testmod( - globs=globs, optionflags=doctest.ELLIPSIS) - spark.stop() + temp_path = tempfile.mkdtemp() + globs['temp_path'] = temp_path + try: + (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) + spark.stop() + finally: + from shutil import rmtree + try: + rmtree(temp_path) + except OSError: + pass if failure_count: exit(-1) diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 64b21caa616ec..94afe82a36472 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -1155,11 +1155,6 @@ class QuantileDiscretizer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadab `QuantileDiscretizer` takes a column with continuous features and outputs a column with binned categorical features. The number of bins can be set using the :py:attr:`numBuckets` parameter. - It is possible that the number of buckets used will be less than this value, for example, if - there are too few distinct values of the input to create enough distinct quantiles. Note also - that NaN values are handled specially and placed into their own bucket. For example, if 4 - buckets are used, then non-NaN data will be put into buckets(0-3), but NaNs will be counted in - a special bucket(4). The bin ranges are chosen using an approximate algorithm (see the documentation for :py:meth:`~.DataFrameStatFunctions.approxQuantile` for a detailed description). 
The precision of the approximation can be controlled with the @@ -2494,21 +2489,30 @@ class RFormula(JavaEstimator, HasFeaturesCol, HasLabelCol, JavaMLReadable, JavaM formula = Param(Params._dummy(), "formula", "R model formula", typeConverter=TypeConverters.toString) + forceIndexLabel = Param(Params._dummy(), "forceIndexLabel", + "Force to index label whether it is numeric or string", + typeConverter=TypeConverters.toBoolean) + @keyword_only - def __init__(self, formula=None, featuresCol="features", labelCol="label"): + def __init__(self, formula=None, featuresCol="features", labelCol="label", + forceIndexLabel=False): """ - __init__(self, formula=None, featuresCol="features", labelCol="label") + __init__(self, formula=None, featuresCol="features", labelCol="label", \ + forceIndexLabel=False) """ super(RFormula, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.RFormula", self.uid) + self._setDefault(forceIndexLabel=False) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @keyword_only @since("1.5.0") - def setParams(self, formula=None, featuresCol="features", labelCol="label"): + def setParams(self, formula=None, featuresCol="features", labelCol="label", + forceIndexLabel=False): """ - setParams(self, formula=None, featuresCol="features", labelCol="label") + setParams(self, formula=None, featuresCol="features", labelCol="label", \ + forceIndexLabel=False) Sets params for RFormula. """ kwargs = self.setParams._input_kwargs @@ -2528,6 +2532,20 @@ def getFormula(self): """ return self.getOrDefault(self.formula) + @since("2.1.0") + def setForceIndexLabel(self, value): + """ + Sets the value of :py:attr:`forceIndexLabel`. + """ + return self._set(forceIndexLabel=value) + + @since("2.1.0") + def getForceIndexLabel(self): + """ + Gets the value of :py:attr:`forceIndexLabel`. + """ + return self.getOrDefault(self.forceIndexLabel) + def _create_model(self, java_model): return RFormulaModel(java_model) @@ -2569,9 +2587,9 @@ class ChiSqSelector(JavaEstimator, HasFeaturesCol, HasOutputCol, HasLabelCol, Ja >>> selector = ChiSqSelector(numTopFeatures=1, outputCol="selectedFeatures") >>> model = selector.fit(df) >>> model.transform(df).head().selectedFeatures - DenseVector([1.0]) + DenseVector([18.0]) >>> model.selectedFeatures - [3] + [2] >>> chiSqSelectorPath = temp_path + "/chi-sq-selector" >>> selector.save(chiSqSelectorPath) >>> loadedSelector = ChiSqSelector.load(chiSqSelectorPath) diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index 55d38033ef72a..9233d2e7e1a77 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -594,7 +594,7 @@ class RandomForestParams(TreeEnsembleParams): featureSubsetStrategy = \ Param(Params._dummy(), "featureSubsetStrategy", "The number of features to consider for splits at each tree node. Supported " + - "options: " + ", ".join(supportedFeatureSubsetStrategies) + " (0.0-1.0], [1-n].", + "options: " + ", ".join(supportedFeatureSubsetStrategies) + ", (0.0-1.0], [1-n].", typeConverter=TypeConverters.toString) def __init__(self): @@ -828,7 +828,7 @@ def featureImportances(self): @inherit_doc class RandomForestRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasSeed, RandomForestParams, TreeRegressorParams, HasCheckpointInterval, - JavaMLWritable, JavaMLReadable): + JavaMLWritable, JavaMLReadable, HasVarianceCol): """ `Random Forest `_ learning algorithm for regression. 
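RFormula gains a forceIndexLabel flag here (with setForceIndexLabel/getForceIndexLabel accessors), and RandomForestRegressor now lists HasVarianceCol among its mixins. A sketch of the RFormula change, mirroring the unit test added later in this patch (the toy data is taken from that test):

    from pyspark.ml.feature import RFormula
    from pyspark.sql import SparkSession

    spark = SparkSession.builder.master("local[2]").appName("rformula-sketch").getOrCreate()

    df = spark.createDataFrame(
        [(1.0, 1.0, "a"), (0.0, 2.0, "b"), (1.0, 0.0, "a")], ["y", "x", "s"])

    # By default a numeric label column is passed through unchanged.
    plain = RFormula(formula="y ~ x + s").fit(df).transform(df)

    # forceIndexLabel=True runs the label through indexing even though it is numeric.
    indexed = RFormula(formula="y ~ x + s", forceIndexLabel=True).fit(df).transform(df)
    indexed.select("features", "label").show()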
@@ -876,13 +876,13 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="variance", subsamplingRate=1.0, seed=None, numTrees=20, - featureSubsetStrategy="auto"): + featureSubsetStrategy="auto", varianceCol=None): """ __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \ impurity="variance", subsamplingRate=1.0, seed=None, numTrees=20, \ - featureSubsetStrategy="auto") + featureSubsetStrategy="auto", varianceCol=None) """ super(RandomForestRegressor, self).__init__() self._java_obj = self._new_java_obj( @@ -900,13 +900,13 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="variance", subsamplingRate=1.0, seed=None, numTrees=20, - featureSubsetStrategy="auto"): + featureSubsetStrategy="auto", varianceCol=None): """ setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \ impurity="variance", subsamplingRate=1.0, seed=None, numTrees=20, \ - featureSubsetStrategy="auto") + featureSubsetStrategy="auto", varianceCol=None) Sets params for linear regression. """ kwargs = self.setParams._input_kwargs diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index e233549850888..9d46cc3b4ae64 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -477,6 +477,22 @@ def test_count_vectorizer_with_binary(self): feature, expected = r self.assertEqual(feature, expected) + def test_rformula_force_index_label(self): + df = self.spark.createDataFrame([ + (1.0, 1.0, "a"), + (0.0, 2.0, "b"), + (1.0, 0.0, "a")], ["y", "x", "s"]) + # Does not index label by default since it's numeric type. + rf = RFormula(formula="y ~ x + s") + model = rf.fit(df) + transformedDF = model.transform(df) + self.assertEqual(transformedDF.head().label, 1.0) + # Force to index label. + rf2 = RFormula(formula="y ~ x + s").setForceIndexLabel(True) + model2 = rf2.fit(df) + transformedDF2 = model2.transform(df) + self.assertEqual(transformedDF2.head().label, 0.0) + class HasInducedError(Params): diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py index 4aea81840a162..50ef7c7901c2c 100644 --- a/python/pyspark/mllib/feature.py +++ b/python/pyspark/mllib/feature.py @@ -288,15 +288,15 @@ class ChiSqSelector(object): ... ] >>> model = ChiSqSelector().setNumTopFeatures(1).fit(sc.parallelize(data)) >>> model.transform(SparseVector(3, {1: 9.0, 2: 6.0})) - SparseVector(1, {0: 6.0}) + SparseVector(1, {}) >>> model.transform(DenseVector([8.0, 9.0, 5.0])) - DenseVector([5.0]) + DenseVector([8.0]) >>> model = ChiSqSelector().setSelectorType("percentile").setPercentile(0.34).fit( ... sc.parallelize(data)) >>> model.transform(SparseVector(3, {1: 9.0, 2: 6.0})) - SparseVector(1, {0: 6.0}) + SparseVector(1, {}) >>> model.transform(DenseVector([8.0, 9.0, 5.0])) - DenseVector([5.0]) + DenseVector([8.0]) >>> data = [ ... LabeledPoint(0.0, SparseVector(4, {0: 8.0, 1: 7.0})), ... 
LabeledPoint(1.0, SparseVector(4, {1: 9.0, 2: 6.0, 3: 4.0})), diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 0e2ae19ca39aa..2de2c2fd1a60b 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -2029,12 +2029,12 @@ def coalesce(self, numPartitions, shuffle=False): [[1, 2, 3, 4, 5]] """ if shuffle: - # In Scala's repartition code, we will distribute elements evenly across output - # partitions. However, the RDD from Python is serialized as a single binary data, - # so the distribution fails and produces highly skewed partitions. We need to - # convert it to a RDD of java object before repartitioning. - data_java_rdd = self._to_java_object_rdd().coalesce(numPartitions, shuffle) - jrdd = self.ctx._jvm.SerDeUtil.javaToPython(data_java_rdd) + # Decrease the batch size in order to distribute evenly the elements across output + # partitions. Otherwise, repartition will possibly produce highly skewed partitions. + batchSize = min(10, self.ctx._batchSize or 1024) + ser = BatchedSerializer(PickleSerializer(), batchSize) + selfCopy = self._reserialize(ser) + jrdd = selfCopy._jrdd.coalesce(numPartitions, shuffle) else: jrdd = self._jrdd.coalesce(numPartitions, shuffle) return RDD(jrdd, self.ctx, self._jrdd_deserializer) diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index 8264dcf8a97d2..de4c335ad2752 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -28,7 +28,7 @@ from pyspark.sql.dataframe import DataFrame from pyspark.sql.readwriter import DataFrameReader from pyspark.sql.streaming import DataStreamReader -from pyspark.sql.types import Row, StringType +from pyspark.sql.types import IntegerType, Row, StringType from pyspark.sql.utils import install_exception_handler __all__ = ["SQLContext", "HiveContext", "UDFRegistration"] @@ -202,6 +202,32 @@ def registerFunction(self, name, f, returnType=StringType()): """ self.sparkSession.catalog.registerFunction(name, f, returnType) + @ignore_unicode_prefix + @since(2.1) + def registerJavaFunction(self, name, javaClassName, returnType=None): + """Register a java UDF so it can be used in SQL statements. + + In addition to a name and the function itself, the return type can be optionally specified. + When the return type is not specified we would infer it via reflection. + :param name: name of the UDF + :param javaClassName: fully qualified name of java class + :param returnType: a :class:`pyspark.sql.types.DataType` object + + >>> sqlContext.registerJavaFunction("javaStringLength", + ... "test.org.apache.spark.sql.JavaStringLength", IntegerType()) + >>> sqlContext.sql("SELECT javaStringLength('test')").collect() + [Row(UDF(test)=4)] + >>> sqlContext.registerJavaFunction("javaStringLength2", + ... 
"test.org.apache.spark.sql.JavaStringLength") + >>> sqlContext.sql("SELECT javaStringLength2('test')").collect() + [Row(UDF(test)=4)] + + """ + jdt = None + if returnType is not None: + jdt = self.sparkSession._jsparkSession.parseDataType(returnType.json()) + self.sparkSession._jsparkSession.udf().registerJava(name, javaClassName, jdt) + # TODO(andrew): delete this once we refactor things to take in SparkSession def _inferSchema(self, rdd, samplingRatio=None): """ diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index ce277eb204d13..29710acf54c4f 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -407,24 +407,48 @@ def foreachPartition(self, f): @since(1.3) def cache(self): - """ Persists with the default storage level (C{MEMORY_ONLY}). + """Persists the :class:`DataFrame` with the default storage level (C{MEMORY_AND_DISK}). + + .. note:: the default storage level has changed to C{MEMORY_AND_DISK} to match Scala in 2.0. """ self.is_cached = True self._jdf.cache() return self @since(1.3) - def persist(self, storageLevel=StorageLevel.MEMORY_ONLY): - """Sets the storage level to persist its values across operations - after the first time it is computed. This can only be used to assign - a new storage level if the RDD does not have a storage level set yet. - If no storage level is specified defaults to (C{MEMORY_ONLY}). + def persist(self, storageLevel=StorageLevel.MEMORY_AND_DISK): + """Sets the storage level to persist the contents of the :class:`DataFrame` across + operations after the first time it is computed. This can only be used to assign + a new storage level if the :class:`DataFrame` does not have a storage level set yet. + If no storage level is specified defaults to (C{MEMORY_AND_DISK}). + + .. note:: the default storage level has changed to C{MEMORY_AND_DISK} to match Scala in 2.0. """ self.is_cached = True javaStorageLevel = self._sc._getJavaStorageLevel(storageLevel) self._jdf.persist(javaStorageLevel) return self + @property + @since(2.1) + def storageLevel(self): + """Get the :class:`DataFrame`'s current storage level. + + >>> df.storageLevel + StorageLevel(False, False, False, False, 1) + >>> df.cache().storageLevel + StorageLevel(True, True, False, True, 1) + >>> df2.persist(StorageLevel.DISK_ONLY_2).storageLevel + StorageLevel(True, False, False, False, 2) + """ + java_storage_level = self._jdf.storageLevel() + storage_level = StorageLevel(java_storage_level.useDisk(), + java_storage_level.useMemory(), + java_storage_level.useOffHeap(), + java_storage_level.deserialized(), + java_storage_level.replication()) + return storage_level + @since(1.3) def unpersist(self, blocking=False): """Marks the :class:`DataFrame` as non-persistent, and remove all blocks for it from @@ -626,6 +650,25 @@ def alias(self, alias): assert isinstance(alias, basestring), "alias should be a string" return DataFrame(getattr(self._jdf, "as")(alias), self.sql_ctx) + @ignore_unicode_prefix + @since(2.1) + def crossJoin(self, other): + """Returns the cartesian product with another :class:`DataFrame`. + + :param other: Right side of the cartesian product. 
+ + >>> df.select("age", "name").collect() + [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')] + >>> df2.select("name", "height").collect() + [Row(name=u'Tom', height=80), Row(name=u'Bob', height=85)] + >>> df.crossJoin(df2.select("height")).select("age", "name", "height").collect() + [Row(age=2, name=u'Alice', height=80), Row(age=2, name=u'Alice', height=85), + Row(age=5, name=u'Bob', height=80), Row(age=5, name=u'Bob', height=85)] + """ + + jdf = self._jdf.crossJoin(other._jdf) + return DataFrame(jdf, self.sql_ctx) + @ignore_unicode_prefix @since(1.3) def join(self, other, on=None, how=None): @@ -666,14 +709,11 @@ def join(self, other, on=None, how=None): on = self._jseq(on) else: assert isinstance(on[0], Column), "on should be Column or list of Column" - if len(on) > 1: - on = reduce(lambda x, y: x.__and__(y), on) - else: - on = on[0] + on = reduce(lambda x, y: x.__and__(y), on) on = on._jc if on is None and how is None: - jdf = self._jdf.crossJoin(other._jdf) + jdf = self._jdf.join(other._jdf) else: if how is None: how = "inner" diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 91c2b17049fa1..bc786ef95ed03 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -160,8 +160,9 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, allowNumericLeadingZero=None, allowBackslashEscapingAnyCharacter=None, mode=None, columnNameOfCorruptRecord=None, dateFormat=None, timestampFormat=None): """ - Loads a JSON file (one object per line) or an RDD of Strings storing JSON objects - (one object per record) and returns the result as a :class`DataFrame`. + Loads a JSON file (`JSON Lines text format or newline-delimited JSON + <[http://jsonlines.org/>`_) or an RDD of Strings storing JSON objects (one object per + record) and returns the result as a :class`DataFrame`. If the ``schema`` parameter is not specified, this function goes through the input once to determine the input schema. diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index 4e438fd5bee22..559647bbabf67 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -189,6 +189,303 @@ def resetTerminated(self): self._jsqm.resetTerminated() +class StreamingQueryStatus(object): + """A class used to report information about the progress of a StreamingQuery. + + .. note:: Experimental + + .. versionadded:: 2.1 + """ + + def __init__(self, jsqs): + self._jsqs = jsqs + + def __str__(self): + """ + Pretty string of this query status. + + >>> print(sqs) + Status of query 'query' + Query id: 1 + Status timestamp: 123 + Input rate: 15.5 rows/sec + Processing rate 23.5 rows/sec + Latency: 345.0 ms + Trigger details: + isDataPresentInTrigger: true + isTriggerActive: true + latency.getBatch.total: 20 + latency.getOffset.total: 10 + numRows.input.total: 100 + triggerId: 5 + Source statuses [1 source]: + Source 1 - MySource1 + Available offset: #0 + Input rate: 15.5 rows/sec + Processing rate: 23.5 rows/sec + Trigger details: + numRows.input.source: 100 + latency.getOffset.source: 10 + latency.getBatch.source: 20 + Sink status - MySink + Committed offsets: [#1, -] + """ + return self._jsqs.toString() + + @property + @ignore_unicode_prefix + @since(2.1) + def name(self): + """ + Name of the query. This name is unique across all active queries. + + >>> sqs.name + u'query' + """ + return self._jsqs.name() + + @property + @since(2.1) + def id(self): + """ + Id of the query. 
This id is unique across all queries that have been started in + the current process. + + >>> int(sqs.id) + 1 + """ + return self._jsqs.id() + + @property + @since(2.1) + def timestamp(self): + """ + Timestamp (ms) of when this query was generated. + + >>> int(sqs.timestamp) + 123 + """ + return self._jsqs.timestamp() + + @property + @since(2.1) + def inputRate(self): + """ + Current total rate (rows/sec) at which data is being generated by all the sources. + + >>> sqs.inputRate + 15.5 + """ + return self._jsqs.inputRate() + + @property + @since(2.1) + def processingRate(self): + """ + Current rate (rows/sec) at which the query is processing data from all the sources. + + >>> sqs.processingRate + 23.5 + """ + return self._jsqs.processingRate() + + @property + @since(2.1) + def latency(self): + """ + Current average latency between the data being available in source and the sink + writing the corresponding output. + + >>> sqs.latency + 345.0 + """ + if (self._jsqs.latency().nonEmpty()): + return self._jsqs.latency().get() + else: + return None + + @property + @ignore_unicode_prefix + @since(2.1) + def sourceStatuses(self): + """ + Current statuses of the sources as a list. + + >>> len(sqs.sourceStatuses) + 1 + >>> sqs.sourceStatuses[0].description + u'MySource1' + """ + return [SourceStatus(ss) for ss in self._jsqs.sourceStatuses()] + + @property + @ignore_unicode_prefix + @since(2.1) + def sinkStatus(self): + """ + Current status of the sink. + + >>> sqs.sinkStatus.description + u'MySink' + """ + return SinkStatus(self._jsqs.sinkStatus()) + + @property + @ignore_unicode_prefix + @since(2.1) + def triggerDetails(self): + """ + Low-level details of the currently active trigger (e.g. number of rows processed + in trigger, latency of intermediate steps, etc.). + + If no trigger is currently active, then it will have details of the last completed trigger. + + >>> sqs.triggerDetails + {u'triggerId': u'5', u'latency.getBatch.total': u'20', u'numRows.input.total': u'100', + u'isTriggerActive': u'true', u'latency.getOffset.total': u'10', + u'isDataPresentInTrigger': u'true'} + """ + return self._jsqs.triggerDetails() + + +class SourceStatus(object): + """ + Status and metrics of a streaming Source. + + .. note:: Experimental + + .. versionadded:: 2.1 + """ + + def __init__(self, jss): + self._jss = jss + + def __str__(self): + """ + Pretty string of this source status. + + >>> print(sqs.sourceStatuses[0]) + Status of source MySource1 + Available offset: #0 + Input rate: 15.5 rows/sec + Processing rate: 23.5 rows/sec + Trigger details: + numRows.input.source: 100 + latency.getOffset.source: 10 + latency.getBatch.source: 20 + """ + return self._jss.toString() + + @property + @ignore_unicode_prefix + @since(2.1) + def description(self): + """ + Description of the source corresponding to this status. + + >>> sqs.sourceStatuses[0].description + u'MySource1' + """ + return self._jss.description() + + @property + @ignore_unicode_prefix + @since(2.1) + def offsetDesc(self): + """ + Description of the current offset if known. + + >>> sqs.sourceStatuses[0].offsetDesc + u'#0' + """ + return self._jss.offsetDesc() + + @property + @since(2.1) + def inputRate(self): + """ + Current rate (rows/sec) at which data is being generated by the source. + + >>> sqs.sourceStatuses[0].inputRate + 15.5 + """ + return self._jss.inputRate() + + @property + @since(2.1) + def processingRate(self): + """ + Current rate (rows/sec) at which the query is processing data from the source. 
+ + >>> sqs.sourceStatuses[0].processingRate + 23.5 + """ + return self._jss.processingRate() + + @property + @ignore_unicode_prefix + @since(2.1) + def triggerDetails(self): + """ + Low-level details of the currently active trigger (e.g. number of rows processed + in trigger, latency of intermediate steps, etc.). + + If no trigger is currently active, then it will have details of the last completed trigger. + + >>> sqs.sourceStatuses[0].triggerDetails + {u'numRows.input.source': u'100', u'latency.getOffset.source': u'10', + u'latency.getBatch.source': u'20'} + """ + return self._jss.triggerDetails() + + +class SinkStatus(object): + """ + Status and metrics of a streaming Sink. + + .. note:: Experimental + + .. versionadded:: 2.1 + """ + + def __init__(self, jss): + self._jss = jss + + def __str__(self): + """ + Pretty string of this source status. + + >>> print(sqs.sinkStatus) + Status of sink MySink + Committed offsets: [#1, -] + """ + return self._jss.toString() + + @property + @ignore_unicode_prefix + @since(2.1) + def description(self): + """ + Description of the source corresponding to this status. + + >>> sqs.sinkStatus.description + u'MySink' + """ + return self._jss.description() + + @property + @ignore_unicode_prefix + @since(2.1) + def offsetDesc(self): + """ + Description of the current offsets up to which data has been written by the sink. + + >>> sqs.sinkStatus.offsetDesc + u'[#1, -]' + """ + return self._jss.offsetDesc() + + class Trigger(object): """Used to indicate how often results should be produced by a :class:`StreamingQuery`. @@ -343,7 +640,8 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, mode=None, columnNameOfCorruptRecord=None, dateFormat=None, timestampFormat=None): """ - Loads a JSON file stream (one object per line) and returns a :class`DataFrame`. + Loads a JSON file stream (`JSON Lines text format or newline-delimited JSON + <[http://jsonlines.org/>`_) and returns a :class`DataFrame`. If the ``schema`` parameter is not specified, this function goes through the input once to determine the input schema. 
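The StreamingQueryStatus, SourceStatus and SinkStatus wrappers above expose the SPARK-17731 metrics to Python. A hedged sketch of reading them off a running query: the input directory is a placeholder, and the query.status accessor is an assumption based on the Scala-side StreamingQuery.status referenced in the MiMa exclusions earlier in this patch (only the status classes appear in this hunk):

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.master("local[2]").appName("status-sketch").getOrCreate()

    # Placeholder text source; any streaming source would do.
    stream = spark.readStream.format("text").load("/tmp/streaming-input")
    query = stream.writeStream.format("memory").queryName("sketch").start()

    status = query.status          # assumed accessor returning StreamingQueryStatus
    print(status.name, status.inputRate, status.processingRate)
    for source in status.sourceStatuses:
        print(source.description, source.offsetDesc, source.triggerDetails)
    print(status.sinkStatus.description, status.sinkStatus.offsetDesc)

    query.stop()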
@@ -753,11 +1051,14 @@ def _test(): globs['sdf_schema'] = StructType([StructField("data", StringType(), False)]) globs['df'] = \ globs['spark'].readStream.format('text').load('python/test_support/sql/streaming') + globs['sqs'] = StreamingQueryStatus( + spark.sparkContext._jvm.org.apache.spark.sql.streaming.StreamingQueryStatus.testStatus()) (failure_count, test_count) = doctest.testmod( pyspark.sql.streaming, globs=globs, optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF) globs['spark'].stop() + if failure_count: exit(-1) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 51d5e7ab0568e..3d46b852c52e1 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -1466,7 +1466,7 @@ def test_functions_broadcast(self): self.assertEqual(1, plan1.toString().count("BroadcastHashJoin")) # no join key -- should not be a broadcast join - plan2 = df1.join(broadcast(df2))._jdf.queryExecution().executedPlan() + plan2 = df1.crossJoin(broadcast(df2))._jdf.queryExecution().executedPlan() self.assertEqual(0, plan2.toString().count("BroadcastHashJoin")) # planner should not crash without a join @@ -1514,6 +1514,19 @@ def test_invalid_join_method(self): df2 = self.spark.createDataFrame([("Alice", 80), ("Bob", 90)], ["name", "height"]) self.assertRaises(IllegalArgumentException, lambda: df1.join(df2, how="invalid-join-type")) + # Cartesian products require cross join syntax + def test_require_cross(self): + from pyspark.sql.functions import broadcast + + df1 = self.spark.createDataFrame([(1, "1")], ("key", "value")) + df2 = self.spark.createDataFrame([(1, "1")], ("key", "value")) + + # joins without conditions require cross join syntax + self.assertRaises(AnalysisException, lambda: df1.join(df2).collect()) + + # works with crossJoin + self.assertEqual(1, df1.crossJoin(df2).count()) + def test_conf(self): spark = self.spark spark.conf.set("bogo", "sipeo") diff --git a/sbin/spark-config.sh b/sbin/spark-config.sh index b7284487c511d..f2d9e6b568a9b 100755 --- a/sbin/spark-config.sh +++ b/sbin/spark-config.sh @@ -28,6 +28,6 @@ export SPARK_CONF_DIR="${SPARK_CONF_DIR:-"${SPARK_HOME}/conf"}" # Add the PySpark classes to the PYTHONPATH: if [ -z "${PYSPARK_PYTHONPATH_SET}" ]; then export PYTHONPATH="${SPARK_HOME}/python:${PYTHONPATH}" - export PYTHONPATH="${SPARK_HOME}/python/lib/py4j-0.10.3-src.zip:${PYTHONPATH}" + export PYTHONPATH="${SPARK_HOME}/python/lib/py4j-0.10.4-src.zip:${PYTHONPATH}" export PYSPARK_PYTHONPATH_SET=1 fi diff --git a/sbin/spark-daemon.sh b/sbin/spark-daemon.sh index 59823571124f1..061019a55e997 100755 --- a/sbin/spark-daemon.sh +++ b/sbin/spark-daemon.sh @@ -27,6 +27,7 @@ # SPARK_PID_DIR The pid files are stored. /tmp by default. # SPARK_IDENT_STRING A string representing this instance of spark. $USER by default # SPARK_NICENESS The scheduling priority for daemons. Defaults to 0. +# SPARK_NO_DAEMONIZE If set, will run the proposed command in the foreground. It will not output a PID file. ## usage="Usage: spark-daemon.sh [--config ] (start|stop|submit|status) " @@ -122,6 +123,35 @@ if [ "$SPARK_NICENESS" = "" ]; then export SPARK_NICENESS=0 fi +execute_command() { + local command="$@" + if [ -z ${SPARK_NO_DAEMONIZE+set} ]; then + nohup -- $command >> $log 2>&1 < /dev/null & + newpid="$!" 
+ + echo "$newpid" > "$pid" + + # Poll for up to 5 seconds for the java process to start + for i in {1..10} + do + if [[ $(ps -p "$newpid" -o comm=) =~ "java" ]]; then + break + fi + sleep 0.5 + done + + sleep 2 + # Check if the process has died; in that case we'll tail the log so the user can see + if [[ ! $(ps -p "$newpid" -o comm=) =~ "java" ]]; then + echo "failed to launch $command:" + tail -2 "$log" | sed 's/^/ /' + echo "full log in $log" + fi + else + $command + fi +} + run_command() { mode="$1" shift @@ -146,13 +176,11 @@ run_command() { case "$mode" in (class) - nohup nice -n "$SPARK_NICENESS" "${SPARK_HOME}"/bin/spark-class $command "$@" >> "$log" 2>&1 < /dev/null & - newpid="$!" + execute_command nice -n "$SPARK_NICENESS" "${SPARK_HOME}"/bin/spark-class $command $@ ;; (submit) - nohup nice -n "$SPARK_NICENESS" "${SPARK_HOME}"/bin/spark-submit --class $command "$@" >> "$log" 2>&1 < /dev/null & - newpid="$!" + execute_command nice -n "$SPARK_NICENESS" bash "${SPARK_HOME}"/bin/spark-submit --class $command $@ ;; (*) @@ -161,24 +189,6 @@ run_command() { ;; esac - echo "$newpid" > "$pid" - - #Poll for up to 5 seconds for the java process to start - for i in {1..10} - do - if [[ $(ps -p "$newpid" -o comm=) =~ "java" ]]; then - break - fi - sleep 0.5 - done - - sleep 2 - # Check if the process has died; in that case we'll tail the log so the user can see - if [[ ! $(ps -p "$newpid" -o comm=) =~ "java" ]]; then - echo "failed to launch $command:" - tail -2 "$log" | sed 's/^/ /' - echo "full log in $log" - fi } case $option in diff --git a/sbin/start-master.sh b/sbin/start-master.sh index d970fcc45e2c1..97ee32159b6de 100755 --- a/sbin/start-master.sh +++ b/sbin/start-master.sh @@ -48,7 +48,14 @@ if [ "$SPARK_MASTER_PORT" = "" ]; then fi if [ "$SPARK_MASTER_HOST" = "" ]; then - SPARK_MASTER_HOST=`hostname -f` + case `uname` in + (SunOS) + SPARK_MASTER_HOST="`/usr/sbin/check-hostname | awk '{print $NF}'`" + ;; + (*) + SPARK_MASTER_HOST="`hostname -f`" + ;; + esac fi if [ "$SPARK_MASTER_WEBUI_PORT" = "" ]; then diff --git a/sbin/start-mesos-dispatcher.sh b/sbin/start-mesos-dispatcher.sh index ef65fb9539146..ecaad7ad09634 100755 --- a/sbin/start-mesos-dispatcher.sh +++ b/sbin/start-mesos-dispatcher.sh @@ -34,7 +34,14 @@ if [ "$SPARK_MESOS_DISPATCHER_PORT" = "" ]; then fi if [ "$SPARK_MESOS_DISPATCHER_HOST" = "" ]; then - SPARK_MESOS_DISPATCHER_HOST=`hostname -f` + case `uname` in + (SunOS) + SPARK_MESOS_DISPATCHER_HOST="`/usr/sbin/check-hostname | awk '{print $NF}'`" + ;; + (*) + SPARK_MESOS_DISPATCHER_HOST="`hostname -f`" + ;; + esac fi if [ "$SPARK_MESOS_DISPATCHER_NUM" = "" ]; then diff --git a/sbin/start-slaves.sh b/sbin/start-slaves.sh index 7d8871251f81b..f5269df523dac 100755 --- a/sbin/start-slaves.sh +++ b/sbin/start-slaves.sh @@ -32,7 +32,14 @@ if [ "$SPARK_MASTER_PORT" = "" ]; then fi if [ "$SPARK_MASTER_HOST" = "" ]; then - SPARK_MASTER_HOST="`hostname -f`" + case `uname` in + (SunOS) + SPARK_MASTER_HOST="`/usr/sbin/check-hostname | awk '{print $NF}'`" + ;; + (*) + SPARK_MASTER_HOST="`hostname -f`" + ;; + esac fi # Launch the slaves diff --git a/scalastyle-config.xml b/scalastyle-config.xml index 7fe0697202cd1..81d57d723a720 100644 --- a/scalastyle-config.xml +++ b/scalastyle-config.xml @@ -200,6 +200,7 @@ This file is divided into 3 sections: // scalastyle:off awaitresult Await.result(...) // scalastyle:on awaitresult + If your codes use ThreadLocal and may run in threads created by the user, use ThreadUtils.awaitResultInForkJoinSafely instead. 
]]> diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala index f3003306acc6d..7defb9df862c0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala @@ -42,6 +42,13 @@ class AnalysisException protected[sql] ( } override def getMessage: String = { + val planAnnotation = plan.map(p => s";\n$p").getOrElse("") + getSimpleMessage + planAnnotation + } + + // Outputs an exception without the logical plan. + // For testing only + def getSimpleMessage: String = { val lineAnnotation = line.map(l => s" line $l").getOrElse("") val positionAnnotation = startPosition.map(p => s" pos $p").getOrElse("") s"$message;$lineAnnotation$positionAnnotation" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala index f542f5cf40506..5b9161551a7af 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala @@ -199,34 +199,14 @@ object CatalystTypeConverters { private[this] val keyConverter = getConverterForType(keyType) private[this] val valueConverter = getConverterForType(valueType) - override def toCatalystImpl(scalaValue: Any): MapData = scalaValue match { - case m: Map[_, _] => - val length = m.size - val convertedKeys = new Array[Any](length) - val convertedValues = new Array[Any](length) - - var i = 0 - for ((key, value) <- m) { - convertedKeys(i) = keyConverter.toCatalyst(key) - convertedValues(i) = valueConverter.toCatalyst(value) - i += 1 - } - ArrayBasedMapData(convertedKeys, convertedValues) - - case jmap: JavaMap[_, _] => - val length = jmap.size() - val convertedKeys = new Array[Any](length) - val convertedValues = new Array[Any](length) - - var i = 0 - val iter = jmap.entrySet.iterator - while (iter.hasNext) { - val entry = iter.next() - convertedKeys(i) = keyConverter.toCatalyst(entry.getKey) - convertedValues(i) = valueConverter.toCatalyst(entry.getValue) - i += 1 - } - ArrayBasedMapData(convertedKeys, convertedValues) + override def toCatalystImpl(scalaValue: Any): MapData = { + val keyFunction = (k: Any) => keyConverter.toCatalyst(k) + val valueFunction = (k: Any) => valueConverter.toCatalyst(k) + + scalaValue match { + case map: Map[_, _] => ArrayBasedMapData(map, keyFunction, valueFunction) + case javaMap: JavaMap[_, _] => ArrayBasedMapData(javaMap, keyFunction, valueFunction) + } } override def toScala(catalystValue: MapData): Map[Any, Any] = { @@ -433,18 +413,11 @@ object CatalystTypeConverters { case seq: Seq[Any] => new GenericArrayData(seq.map(convertToCatalyst).toArray) case r: Row => InternalRow(r.toSeq.map(convertToCatalyst): _*) case arr: Array[Any] => new GenericArrayData(arr.map(convertToCatalyst)) - case m: Map[_, _] => - val length = m.size - val convertedKeys = new Array[Any](length) - val convertedValues = new Array[Any](length) - - var i = 0 - for ((key, value) <- m) { - convertedKeys(i) = convertToCatalyst(key) - convertedValues(i) = convertToCatalyst(value) - i += 1 - } - ArrayBasedMapData(convertedKeys, convertedValues) + case map: Map[_, _] => + ArrayBasedMapData( + map, + (key: Any) => convertToCatalyst(key), + (value: Any) => convertToCatalyst(value)) case other => other } diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala index e6f61b00ebd70..04f0cfce883f2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala @@ -59,7 +59,7 @@ object JavaTypeInference { * @param typeToken Java type * @return (SQL data type, nullable) */ - private def inferDataType(typeToken: TypeToken[_]): (DataType, Boolean) = { + private[sql] def inferDataType(typeToken: TypeToken[_]): (DataType, Boolean) = { typeToken.getRawType match { case c: Class[_] if c.isAnnotationPresent(classOf[SQLUserDefinedType]) => (c.getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance(), true) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 7923cfce82100..31c6e5def143b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -628,7 +628,7 @@ object ScalaReflection extends ScalaReflection { /* * Retrieves the runtime class corresponding to the provided type. */ - def getClassFromType(tpe: Type): Class[_] = mirror.runtimeClass(tpe.erasure.typeSymbol.asClass) + def getClassFromType(tpe: Type): Class[_] = mirror.runtimeClass(tpe.typeSymbol.asClass) case class Schema(dataType: DataType, nullable: Boolean) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 536d38777f89d..f8f4799322b3b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -838,6 +838,8 @@ class Analyzer( // attributes that its child might have or could have. val missing = missingAttrs -- g.child.outputSet g.copy(join = true, child = addMissingAttr(g.child, missing)) + case d: Distinct => + throw new AnalysisException(s"Can't add $missingAttrs to $d") case u: UnaryNode => u.withNewChildren(addMissingAttr(u.child, missingAttrs) :: Nil) case other => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 9c06069f24f76..9a7c2a944b588 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -287,7 +287,8 @@ trait CheckAnalysis extends PredicateHelper { } // Check if the data types match. 
dataTypes(child).zip(ref).zipWithIndex.foreach { case ((dt1, dt2), ci) => - if (dt1 != dt2) { + // SPARK-18058: we shall not care about the nullability of columns + if (dt1.asNullable != dt2.asNullable) { failAnalysis( s""" |${operator.nodeName} can only be performed on tables with the compatible diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala index dd93b467eeeb2..a5e02523d2889 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.catalog import org.apache.spark.sql.catalyst.analysis.{FunctionAlreadyExistsException, NoSuchDatabaseException, NoSuchFunctionException} +import org.apache.spark.sql.catalyst.expressions.Expression /** @@ -196,6 +197,19 @@ abstract class ExternalCatalog { table: String, partialSpec: Option[TablePartitionSpec] = None): Seq[CatalogTablePartition] + /** + * List the metadata of partitions that belong to the specified table, assuming it exists, that + * satisfy the given partition-pruning predicate expressions. + * + * @param db database name + * @param table table name + * @param predicates partition-pruning predicates + */ + def listPartitionsByFilter( + db: String, + table: String, + predicates: Seq[Expression]): Seq[CatalogTablePartition] + // -------------------------------------------------------------------------- // Functions // -------------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala index 3e31127118b44..f95c9f8cfa2d4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala @@ -28,6 +28,7 @@ import org.apache.spark.{SparkConf, SparkException} import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} import org.apache.spark.sql.catalyst.analysis._ +import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.util.StringUtils /** @@ -477,6 +478,15 @@ class InMemoryCatalog( catalog(db).tables(table).partitions.values.toSeq } + override def listPartitionsByFilter( + db: String, + table: String, + predicates: Seq[Expression]): Seq[CatalogTablePartition] = { + // TODO: Provide an implementation + throw new UnsupportedOperationException( + "listPartitionsByFilter is not implemented for InMemoryCatalog") + } + // -------------------------------------------------------------------------- // Functions // -------------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index fe41c41a6eb20..3d6eec81c03c8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -462,11 +462,20 @@ class SessionCatalog( * If a database is specified in `oldName`, this will rename the 
table in that database. * If no database is specified, this will first attempt to rename a temporary table with * the same name, then, if that does not exist, rename the table in the current database. + * + * This assumes the database specified in `newName` matches the one in `oldName`. */ - def renameTable(oldName: TableIdentifier, newName: String): Unit = synchronized { + def renameTable(oldName: TableIdentifier, newName: TableIdentifier): Unit = synchronized { val db = formatDatabaseName(oldName.database.getOrElse(currentDb)) + newName.database.map(formatDatabaseName).foreach { newDb => + if (db != newDb) { + throw new AnalysisException( + s"RENAME TABLE source and destination databases do not match: '$db' != '$newDb'") + } + } + val oldTableName = formatTableName(oldName.table) - val newTableName = formatTableName(newName) + val newTableName = formatTableName(newName.table) if (db == globalTempViewManager.database) { globalTempViewManager.rename(oldTableName, newTableName) } else { @@ -476,6 +485,11 @@ class SessionCatalog( requireTableNotExists(TableIdentifier(newTableName, Some(db))) externalCatalog.renameTable(db, oldTableName, newTableName) } else { + if (newName.database.isDefined) { + throw new AnalysisException( + s"RENAME TEMPORARY TABLE from '$oldName' to '$newName': cannot specify database " + + s"name '${newName.database.get}' in the destination table") + } if (tempTables.contains(newTableName)) { throw new AnalysisException(s"RENAME TEMPORARY TABLE from '$oldName' to '$newName': " + "destination table already exists") @@ -741,6 +755,20 @@ class SessionCatalog( externalCatalog.listPartitions(db, table, partialSpec) } + /** + * List the metadata of partitions that belong to the specified table, assuming it exists, that + * satisfy the given partition-pruning predicate expressions. + */ + def listPartitionsByFilter( + tableName: TableIdentifier, + predicates: Seq[Expression]): Seq[CatalogTablePartition] = { + val db = formatDatabaseName(tableName.database.getOrElse(getCurrentDatabase)) + val table = formatTableName(tableName.table) + requireDbExists(db) + requireTableExists(TableIdentifier(table, Option(db))) + externalCatalog.listPartitionsByFilter(db, table, predicates) + } + /** * Verify if the input partition spec exactly matches the existing defined partition spec * The columns must be the same but the orders could be different. 
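SessionCatalog.renameTable above now takes a fully qualified destination and rejects a rename whose destination database differs from the source (or any database qualifier when renaming a temporary view). From SQL this surfaces as an AnalysisException; a sketch with hypothetical database and table names, assuming the default in-memory catalog:

    from pyspark.sql import SparkSession
    from pyspark.sql.utils import AnalysisException

    spark = SparkSession.builder.master("local[2]").appName("rename-sketch").getOrCreate()

    spark.sql("CREATE DATABASE IF NOT EXISTS db_a")
    spark.sql("CREATE DATABASE IF NOT EXISTS db_b")
    spark.sql("CREATE TABLE IF NOT EXISTS db_a.src (id INT) USING parquet")

    # Renaming within the same database still works.
    spark.sql("ALTER TABLE db_a.src RENAME TO db_a.dst")

    # Renaming across databases is rejected with
    # "RENAME TABLE source and destination databases do not match".
    try:
        spark.sql("ALTER TABLE db_a.dst RENAME TO db_b.dst")
    except AnalysisException as e:
        print(e)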
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala index 51326ca25e9cc..7c3bec897956a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala @@ -20,11 +20,11 @@ package org.apache.spark.sql.catalyst.catalog import java.util.Date import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} -import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow, TableIdentifier} +import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Literal} import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} import org.apache.spark.sql.catalyst.util.quoteIdentifier -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{StructField, StructType} /** @@ -89,14 +89,24 @@ case class CatalogTablePartition( parameters: Map[String, String] = Map.empty) { override def toString: String = { + val specString = spec.map { case (k, v) => s"$k=$v" }.mkString(", ") val output = Seq( - s"Partition Values: [${spec.values.mkString(", ")}]", + s"Partition Values: [$specString]", s"$storage", s"Partition Parameters:{${parameters.map(p => p._1 + "=" + p._2).mkString(", ")}}") output.filter(_.nonEmpty).mkString("CatalogPartition(\n\t", "\n\t", ")") } + + /** + * Given the partition schema, returns a row with that schema holding the partition values. + */ + def toRow(partitionSchema: StructType): InternalRow = { + InternalRow.fromSeq(partitionSchema.map { field => + Cast(Literal(spec(field.name)), field.dataType).eval() + }) + } } @@ -128,6 +138,8 @@ case class BucketSpec( * Can be None if this table is a View, should be "hive" for hive serde tables. * @param unsupportedFeatures is a list of string descriptions of features that are used by the * underlying table but not supported by Spark SQL yet. + * @param partitionProviderIsHive whether this table's partition metadata is stored in the Hive + * metastore. 
*/ case class CatalogTable( identifier: TableIdentifier, @@ -145,7 +157,8 @@ case class CatalogTable( viewOriginalText: Option[String] = None, viewText: Option[String] = None, comment: Option[String] = None, - unsupportedFeatures: Seq[String] = Seq.empty) { + unsupportedFeatures: Seq[String] = Seq.empty, + partitionProviderIsHive: Boolean = false) { /** schema of this table's partition columns */ def partitionSchema: StructType = StructType(schema.filter { @@ -203,11 +216,11 @@ case class CatalogTable( comment.map("Comment: " + _).getOrElse(""), if (properties.nonEmpty) s"Properties: $tableProperties" else "", if (stats.isDefined) s"Statistics: ${stats.get.simpleString}" else "", - s"$storage") + s"$storage", + if (partitionProviderIsHive) "Partition Provider: Hive" else "") output.filter(_.nonEmpty).mkString("CatalogTable(\n\t", "\n\t", ")") } - } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index fa1a2ad56ccb3..9edc1ceff26a7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -511,7 +511,7 @@ abstract class BinaryOperator extends BinaryExpression with ExpectsInputTypes { override def checkInputDataTypes(): TypeCheckResult = { // First check whether left and right have the same type, then check if the type is acceptable. - if (left.dataType != right.dataType) { + if (!left.dataType.sameType(right.dataType)) { TypeCheckResult.TypeCheckFailure(s"differing types in '$sql' " + s"(${left.dataType.simpleString} and ${right.dataType.simpleString}).") } else if (!inputType.acceptsType(left.dataType)) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index c0200299376ca..f56bb39d10791 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -124,7 +124,13 @@ case class SortArray(base: Expression, ascendingOrder: Expression) override def checkInputDataTypes(): TypeCheckResult = base.dataType match { case ArrayType(dt, _) if RowOrdering.isOrderable(dt) => - TypeCheckResult.TypeCheckSuccess + ascendingOrder match { + case Literal(_: Boolean, BooleanType) => + TypeCheckResult.TypeCheckSuccess + case _ => + TypeCheckResult.TypeCheckFailure( + "Sort order in second argument requires a boolean literal.") + } case ArrayType(dt, _) => TypeCheckResult.TypeCheckFailure( s"$prettyName does not support sorting array of type ${dt.simpleString}") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 09e22aaf3e3d8..917aa0873130b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -427,18 +427,28 @@ case class StringToMap(text: Expression, pairDelim: Expression, keyValueDelim: E } } - override def nullSafeEval(str: Any, delim1: Any, delim2: Any): Any = { - val array = str.asInstanceOf[UTF8String] - 
.split(delim1.asInstanceOf[UTF8String], -1) - .map { kv => - val arr = kv.split(delim2.asInstanceOf[UTF8String], 2) - if (arr.length < 2) { - Array(arr(0), null) - } else { - arr - } + override def nullSafeEval( + inputString: Any, + stringDelimiter: Any, + keyValueDelimiter: Any): Any = { + val keyValues = + inputString.asInstanceOf[UTF8String].split(stringDelimiter.asInstanceOf[UTF8String], -1) + + val iterator = new Iterator[(UTF8String, UTF8String)] { + var index = 0 + val keyValueDelimiterUTF8String = keyValueDelimiter.asInstanceOf[UTF8String] + + override def hasNext: Boolean = { + keyValues.length > index } - ArrayBasedMapData(array.map(_ (0)), array.map(_ (1))) + + override def next(): (UTF8String, UTF8String) = { + val keyValueArray = keyValues(index).split(keyValueDelimiterUTF8String, 2) + index += 1 + (keyValueArray(0), if (keyValueArray.length < 2) null else keyValueArray(1)) + } + } + ArrayBasedMapData(iterator, keyValues.size, identity, identity) } override def prettyName: String = "str_to_map" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index 138ef2a1dcc01..5ead16908732f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -618,6 +618,7 @@ case class CurrentDatabase() extends LeafExpression with Unevaluable { override def dataType: DataType = StringType override def foldable: Boolean = true override def nullable: Boolean = false + override def prettyName: String = "current_database" } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 799858a6865e5..9394e39aadd9d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -84,8 +84,9 @@ trait PredicateHelper { * * For example consider a join between two relations R(a, b) and S(c, d). * - * `canEvaluate(EqualTo(a,b), R)` returns `true` where as `canEvaluate(EqualTo(a,c), R)` returns - * `false`. 
+ * - `canEvaluate(EqualTo(a,b), R)` returns `true` + * - `canEvaluate(EqualTo(a,c), R)` returns `false` + * - `canEvaluate(Literal(1), R)` returns `true` as literals CAN be evaluated on any plan */ protected def canEvaluate(expr: Expression, plan: LogicalPlan): Boolean = expr.references.subsetOf(plan.outputSet) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala index 2626057e492ef..180ad2e0ad1fa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala @@ -65,7 +65,9 @@ object ReorderJoin extends Rule[LogicalPlan] with PredicateHelper { val conditionalJoin = rest.find { planJoinPair => val plan = planJoinPair._1 val refs = left.outputSet ++ plan.outputSet - conditions.filterNot(canEvaluate(_, left)).filterNot(canEvaluate(_, plan)) + conditions + .filterNot(l => l.references.nonEmpty && canEvaluate(l, left)) + .filterNot(r => r.references.nonEmpty && canEvaluate(r, plan)) .exists(_.references.subsetOf(refs)) } // pick the next one if no condition left diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 929c1c4f2d9e4..38e9bb6c162ad 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -192,11 +192,13 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { override def visitPartitionSpec( ctx: PartitionSpecContext): Map[String, Option[String]] = withOrigin(ctx) { val parts = ctx.partitionVal.asScala.map { pVal => - val name = pVal.identifier.getText.toLowerCase + val name = pVal.identifier.getText val value = Option(pVal.constant).map(visitStringConstant) name -> value } - // Check for duplicate partition columns in one spec. + // Before calling `toMap`, we check duplicated keys to avoid silently ignore partition values + // in partition spec like PARTITION(a='1', b='2', a='3'). The real semantical check for + // partition columns will be done in analyzer. checkDuplicateKeys(parts, ctx) parts.toMap } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index bdae56881bf46..c5f92c59c88f4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -112,6 +112,7 @@ object ExtractEquiJoinKeys extends Logging with PredicateHelper { // as join keys. 
val predicates = condition.map(splitConjunctivePredicates).getOrElse(Nil) val joinKeys = predicates.flatMap { + case EqualTo(l, r) if l.references.isEmpty || r.references.isEmpty => None case EqualTo(l, r) if canEvaluate(l, left) && canEvaluate(r, right) => Some((l, r)) case EqualTo(l, r) if canEvaluate(l, right) && canEvaluate(r, left) => Some((r, l)) // Replace null with default value for joining key, then those rows with null in it could @@ -125,6 +126,7 @@ object ExtractEquiJoinKeys extends Logging with PredicateHelper { case other => None } val otherPredicates = predicates.filterNot { + case EqualTo(l, r) if l.references.isEmpty || r.references.isEmpty => false case EqualTo(l, r) => canEvaluate(l, left) && canEvaluate(r, right) || canEvaluate(l, right) && canEvaluate(r, left) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 0fb6e7d2e795a..45ee2964d4db0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -68,26 +68,104 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT case _ => Seq.empty[Attribute] } + // Collect aliases from expressions, so we may avoid producing recursive constraints. + private lazy val aliasMap = AttributeMap( + (expressions ++ children.flatMap(_.expressions)).collect { + case a: Alias => (a.toAttribute, a.child) + }) + /** * Infers an additional set of constraints from a given set of equality constraints. * For e.g., if an operator has constraints of the form (`a = 5`, `a = b`), this returns an - * additional constraint of the form `b = 5` + * additional constraint of the form `b = 5`. + * + * [SPARK-17733] We explicitly prevent producing recursive constraints of the form `a = f(a, b)` + * as they are often useless and can lead to a non-converging set of constraints. */ private def inferAdditionalConstraints(constraints: Set[Expression]): Set[Expression] = { + val constraintClasses = generateEquivalentConstraintClasses(constraints) + var inferredConstraints = Set.empty[Expression] constraints.foreach { case eq @ EqualTo(l: Attribute, r: Attribute) => - inferredConstraints ++= (constraints - eq).map(_ transform { - case a: Attribute if a.semanticEquals(l) => r + val candidateConstraints = constraints - eq + inferredConstraints ++= candidateConstraints.map(_ transform { + case a: Attribute if a.semanticEquals(l) && + !isRecursiveDeduction(r, constraintClasses) => r }) - inferredConstraints ++= (constraints - eq).map(_ transform { - case a: Attribute if a.semanticEquals(r) => l + inferredConstraints ++= candidateConstraints.map(_ transform { + case a: Attribute if a.semanticEquals(r) && + !isRecursiveDeduction(l, constraintClasses) => l }) case _ => // No inference } inferredConstraints -- constraints } + /* + * Generate a sequence of expression sets from constraints, where each set stores an equivalence + * class of expressions. For example, Set(`a = b`, `b = c`, `e = f`) will generate the following + * expression sets: (Set(a, b, c), Set(e, f)). This will be used to search all expressions equal + * to an selected attribute. 
+ */ + private def generateEquivalentConstraintClasses( + constraints: Set[Expression]): Seq[Set[Expression]] = { + var constraintClasses = Seq.empty[Set[Expression]] + constraints.foreach { + case eq @ EqualTo(l: Attribute, r: Attribute) => + // Transform [[Alias]] to its child. + val left = aliasMap.getOrElse(l, l) + val right = aliasMap.getOrElse(r, r) + // Get the expression set for an equivalence constraint class. + val leftConstraintClass = getConstraintClass(left, constraintClasses) + val rightConstraintClass = getConstraintClass(right, constraintClasses) + if (leftConstraintClass.nonEmpty && rightConstraintClass.nonEmpty) { + // Combine the two sets. + constraintClasses = constraintClasses + .diff(leftConstraintClass :: rightConstraintClass :: Nil) :+ + (leftConstraintClass ++ rightConstraintClass) + } else if (leftConstraintClass.nonEmpty) { // && rightConstraintClass.isEmpty + // Update equivalence class of `left` expression. + constraintClasses = constraintClasses + .diff(leftConstraintClass :: Nil) :+ (leftConstraintClass + right) + } else if (rightConstraintClass.nonEmpty) { // && leftConstraintClass.isEmpty + // Update equivalence class of `right` expression. + constraintClasses = constraintClasses + .diff(rightConstraintClass :: Nil) :+ (rightConstraintClass + left) + } else { // leftConstraintClass.isEmpty && rightConstraintClass.isEmpty + // Create new equivalence constraint class since neither expression presents + // in any classes. + constraintClasses = constraintClasses :+ Set(left, right) + } + case _ => // Skip + } + + constraintClasses + } + + /* + * Get all expressions equivalent to the selected expression. + */ + private def getConstraintClass( + expr: Expression, + constraintClasses: Seq[Set[Expression]]): Set[Expression] = + constraintClasses.find(_.contains(expr)).getOrElse(Set.empty[Expression]) + + /* + * Check whether replace by an [[Attribute]] will cause a recursive deduction. Generally it + * has the form like: `a -> f(a, b)`, where `a` and `b` are expressions and `f` is a function. + * Here we first get all expressions equal to `attr` and then check whether at least one of them + * is a child of the referenced expression. + */ + private def isRecursiveDeduction( + attr: Attribute, + constraintClasses: Seq[Set[Expression]]): Boolean = { + val expr = aliasMap.getOrElse(attr, attr) + getConstraintClass(expr, constraintClasses).exists { e => + expr.children.exists(_.semanticEquals(e)) + } + } + /** * An [[ExpressionSet]] that contains invariants about the rows output by this operator. 
For * example, if this set contains the expression `a = 2` then that expression is guaranteed to diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 09725473a384d..b0a4145f37767 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -293,15 +293,19 @@ abstract class UnaryNode extends LogicalPlan { * expressions with the corresponding alias */ protected def getAliasedConstraints(projectList: Seq[NamedExpression]): Set[Expression] = { - projectList.flatMap { + var allConstraints = child.constraints.asInstanceOf[Set[Expression]] + projectList.foreach { case a @ Alias(e, _) => - child.constraints.map(_ transform { + // For every alias in `projectList`, replace the reference in constraints by its attribute. + allConstraints ++= allConstraints.map(_ transform { case expr: Expression if expr.semanticEquals(e) => a.toAttribute - }).union(Set(EqualNullSafe(e, a.toAttribute))) - case _ => - Set.empty[Expression] - }.toSet + }) + allConstraints += EqualNullSafe(e, a.toAttribute) + case _ => // Don't change. + } + + allConstraints -- child.constraints } override protected def validConstraints: Set[Expression] = child.constraints diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala index 43455c989c0f4..f3e2147b8f974 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala @@ -98,7 +98,7 @@ case class StringColumnStat(statRow: InternalRow) { // The indices here must be consistent with `ColumnStatStruct.stringColumnStat`. val numNulls: Long = statRow.getLong(0) val avgColLen: Double = statRow.getDouble(1) - val maxColLen: Long = statRow.getLong(2) + val maxColLen: Long = statRow.getInt(2) val ndv: Long = statRow.getLong(3) } @@ -106,7 +106,7 @@ case class BinaryColumnStat(statRow: InternalRow) { // The indices here must be consistent with `ColumnStatStruct.binaryColumnStat`. 
val numNulls: Long = statRow.getLong(0) val avgColLen: Double = statRow.getDouble(1) - val maxColLen: Long = statRow.getLong(2) + val maxColLen: Long = statRow.getInt(2) } case class BooleanColumnStat(statRow: InternalRow) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index d2d33e40a8c8f..a48974c6322ad 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -117,6 +117,8 @@ case class Filter(condition: Expression, child: LogicalPlan) abstract class SetOperation(left: LogicalPlan, right: LogicalPlan) extends BinaryNode { + def duplicateResolved: Boolean = left.outputSet.intersect(right.outputSet).isEmpty + protected def leftConstraints: Set[Expression] = left.constraints protected def rightConstraints: Set[Expression] = { @@ -126,6 +128,13 @@ abstract class SetOperation(left: LogicalPlan, right: LogicalPlan) extends Binar case a: Attribute => attributeRewrites(a) }) } + + override lazy val resolved: Boolean = + childrenResolved && + left.output.length == right.output.length && + left.output.zip(right.output).forall { case (l, r) => + l.dataType.asNullable == r.dataType.asNullable + } && duplicateResolved } object SetOperation { @@ -134,8 +143,6 @@ object SetOperation { case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right) { - def duplicateResolved: Boolean = left.outputSet.intersect(right.outputSet).isEmpty - override def output: Seq[Attribute] = left.output.zip(right.output).map { case (leftAttr, rightAttr) => leftAttr.withNullability(leftAttr.nullable && rightAttr.nullable) @@ -144,14 +151,6 @@ case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation override protected def validConstraints: Set[Expression] = leftConstraints.union(rightConstraints) - // Intersect are only resolved if they don't introduce ambiguous expression ids, - // since the Optimizer will convert Intersect to Join. - override lazy val resolved: Boolean = - childrenResolved && - left.output.length == right.output.length && - left.output.zip(right.output).forall { case (l, r) => l.dataType == r.dataType } && - duplicateResolved - override def maxRows: Option[Long] = { if (children.exists(_.maxRows.isEmpty)) { None @@ -172,19 +171,11 @@ case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation case class Except(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right) { - def duplicateResolved: Boolean = left.outputSet.intersect(right.outputSet).isEmpty - /** We don't use right.output because those rows get excluded from the set. 
*/ override def output: Seq[Attribute] = left.output override protected def validConstraints: Set[Expression] = leftConstraints - override lazy val resolved: Boolean = - childrenResolved && - left.output.length == right.output.length && - left.output.zip(right.output).forall { case (l, r) => l.dataType == r.dataType } && - duplicateResolved - override lazy val statistics: Statistics = { left.statistics.copy() } @@ -219,9 +210,8 @@ case class Union(children: Seq[LogicalPlan]) extends LogicalPlan { child.output.length == children.head.output.length && // compare the data types with the first child child.output.zip(children.head.output).forall { - case (l, r) => l.dataType == r.dataType } + case (l, r) => l.dataType.asNullable == r.dataType.asNullable } ) - children.length > 1 && childrenResolved && allChildrenCompatible } @@ -366,26 +356,10 @@ case class InsertIntoTable( override def children: Seq[LogicalPlan] = child :: Nil override def output: Seq[Attribute] = Seq.empty - lazy val expectedColumns = { - if (table.output.isEmpty) { - None - } else { - // Note: The parser (visitPartitionSpec in AstBuilder) already turns - // keys in partition to their lowercase forms. - val staticPartCols = partition.filter(_._2.isDefined).keySet - Some(table.output.filterNot(a => staticPartCols.contains(a.name))) - } - } - assert(overwrite || !ifNotExists) assert(partition.values.forall(_.nonEmpty) || !ifNotExists) - override lazy val resolved: Boolean = - childrenResolved && table.resolved && expectedColumns.forall { expected => - child.output.size == expected.size && child.output.zip(expected).forall { - case (childAttr, tableAttr) => - DataType.equalsIgnoreCompatibleNullability(childAttr.dataType, tableAttr.dataType) - } - } + + override lazy val resolved: Boolean = childrenResolved && table.resolved } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala index fefe5a3953a6e..0ab4c9016623e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala @@ -230,6 +230,19 @@ object AppendColumns { encoderFor[U].namedExpressions, child) } + + def apply[T : Encoder, U : Encoder]( + func: T => U, + inputAttributes: Seq[Attribute], + child: LogicalPlan): AppendColumns = { + new AppendColumns( + func.asInstanceOf[Any => Any], + implicitly[Encoder[T]].clsTag.runtimeClass, + implicitly[Encoder[T]].schema, + UnresolvedDeserializer(encoderFor[T].deserializer, inputAttributes), + encoderFor[U].namedExpressions, + child) + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index 83cb375525832..ea8d8fef7bdf1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -164,6 +164,13 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { ret } + /** + * Returns a Seq containing the leaves in this tree. + */ + def collectLeaves(): Seq[BaseType] = { + this.collect { case p if p.children.isEmpty => p } + } + /** * Finds and returns the first [[TreeNode]] of the tree for which the given partial function * is defined (pre-order), and applies the partial function to it. 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapData.scala index 4449da13c083c..91b3139443696 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapData.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapData.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.util +import java.util.{Map => JavaMap} + class ArrayBasedMapData(val keyArray: ArrayData, val valueArray: ArrayData) extends MapData { require(keyArray.numElements() == valueArray.numElements()) @@ -30,12 +32,83 @@ class ArrayBasedMapData(val keyArray: ArrayData, val valueArray: ArrayData) exte } object ArrayBasedMapData { - def apply(map: Map[Any, Any]): ArrayBasedMapData = { - val array = map.toArray - ArrayBasedMapData(array.map(_._1), array.map(_._2)) + /** + * Creates a [[ArrayBasedMapData]] by applying the given converters over + * each (key -> value) pair of the input [[java.util.Map]] + * + * @param javaMap Input map + * @param keyConverter This function is applied over all the keys of the input map to + * obtain the output map's keys + * @param valueConverter This function is applied over all the values of the input map to + * obtain the output map's values + */ + def apply( + javaMap: JavaMap[_, _], + keyConverter: (Any) => Any, + valueConverter: (Any) => Any): ArrayBasedMapData = { + import scala.language.existentials + + val keys: Array[Any] = new Array[Any](javaMap.size()) + val values: Array[Any] = new Array[Any](javaMap.size()) + + var i: Int = 0 + val iterator = javaMap.entrySet().iterator() + while (iterator.hasNext) { + val entry = iterator.next() + keys(i) = keyConverter(entry.getKey) + values(i) = valueConverter(entry.getValue) + i += 1 + } + ArrayBasedMapData(keys, values) + } + + /** + * Creates a [[ArrayBasedMapData]] by applying the given converters over + * each (key -> value) pair of the input map + * + * @param map Input map + * @param keyConverter This function is applied over all the keys of the input map to + * obtain the output map's keys + * @param valueConverter This function is applied over all the values of the input map to + * obtain the output map's values + */ + def apply( + map: scala.collection.Map[_, _], + keyConverter: (Any) => Any = identity, + valueConverter: (Any) => Any = identity): ArrayBasedMapData = { + ArrayBasedMapData(map.iterator, map.size, keyConverter, valueConverter) + } + + /** + * Creates a [[ArrayBasedMapData]] by applying the given converters over + * each (key -> value) pair from the given iterator + * + * @param iterator Input iterator + * @param size Number of elements + * @param keyConverter This function is applied over all the keys extracted from the + * given iterator to obtain the output map's keys + * @param valueConverter This function is applied over all the values extracted from the + * given iterator to obtain the output map's values + */ + def apply( + iterator: Iterator[(_, _)], + size: Int, + keyConverter: (Any) => Any, + valueConverter: (Any) => Any): ArrayBasedMapData = { + + val keys: Array[Any] = new Array[Any](size) + val values: Array[Any] = new Array[Any](size) + + var i = 0 + for ((key, value) <- iterator) { + keys(i) = keyConverter(key) + values(i) = valueConverter(value) + i += 1 + } + ArrayBasedMapData(keys, values) } - def apply(keys: Array[Any], values: Array[Any]): ArrayBasedMapData = { + def apply(keys: Array[_], values: Array[_]): 
ArrayBasedMapData = { new ArrayBasedMapData(new GenericArrayData(keys), new GenericArrayData(values)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala index c741a2dd3ea30..b18fba29af0f9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala @@ -19,7 +19,10 @@ package org.apache.spark.sql.types import scala.language.existentials -private[sql] object ObjectType extends AbstractDataType { +import org.apache.spark.annotation.InterfaceStability + +@InterfaceStability.Evolving +object ObjectType extends AbstractDataType { override private[sql] def defaultConcreteType: DataType = throw new UnsupportedOperationException("null literals can't be casted to ObjectType") @@ -32,11 +35,10 @@ private[sql] object ObjectType extends AbstractDataType { } /** - * Represents a JVM object that is passing through Spark SQL expression evaluation. Note this - * is only used internally while converting into the internal format and is not intended for use - * outside of the execution engine. + * Represents a JVM object that is passing through Spark SQL expression evaluation. */ -private[sql] case class ObjectType(cls: Class[_]) extends DataType { +@InterfaceStability.Evolving +case class ObjectType(cls: Class[_]) extends DataType { override def defaultSize: Int = 4096 def asNullable: DataType = this diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 50ebad25cd258..590774c043040 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -377,4 +377,23 @@ class AnalysisSuite extends AnalysisTest { assertExpressionType(sum(Divide(Decimal(1), 2.0)), DoubleType) assertExpressionType(sum(Divide(1.0, Decimal(2.0))), DoubleType) } + + test("SPARK-18058: union and set operations shall not care about the nullability" + + " when comparing column types") { + val firstTable = LocalRelation( + AttributeReference("a", + StructType(Seq(StructField("a", IntegerType, nullable = true))), nullable = false)()) + val secondTable = LocalRelation( + AttributeReference("a", + StructType(Seq(StructField("a", IntegerType, nullable = false))), nullable = false)()) + + val unionPlan = Union(firstTable, secondTable) + assertAnalysisSuccess(unionPlan) + + val r1 = Except(firstTable, secondTable) + val r2 = Intersect(firstTable, secondTable) + + assertAnalysisSuccess(r1) + assertAnalysisSuccess(r2) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala index 915ed8f8b1787..187611bc77460 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala @@ -273,27 +273,34 @@ class SessionCatalogSuite extends SparkFunSuite { val externalCatalog = newBasicCatalog() val sessionCatalog = new SessionCatalog(externalCatalog) assert(externalCatalog.listTables("db2").toSet == Set("tbl1", "tbl2")) - sessionCatalog.renameTable(TableIdentifier("tbl1", Some("db2")), "tblone") + 
sessionCatalog.renameTable(TableIdentifier("tbl1", Some("db2")), TableIdentifier("tblone")) assert(externalCatalog.listTables("db2").toSet == Set("tblone", "tbl2")) - sessionCatalog.renameTable(TableIdentifier("tbl2", Some("db2")), "tbltwo") + sessionCatalog.renameTable(TableIdentifier("tbl2", Some("db2")), TableIdentifier("tbltwo")) assert(externalCatalog.listTables("db2").toSet == Set("tblone", "tbltwo")) // Rename table without explicitly specifying database sessionCatalog.setCurrentDatabase("db2") - sessionCatalog.renameTable(TableIdentifier("tbltwo"), "table_two") + sessionCatalog.renameTable(TableIdentifier("tbltwo"), TableIdentifier("table_two")) assert(externalCatalog.listTables("db2").toSet == Set("tblone", "table_two")) + // Renaming "db2.tblone" to "db1.tblones" should fail because databases don't match + intercept[AnalysisException] { + sessionCatalog.renameTable( + TableIdentifier("tblone", Some("db2")), TableIdentifier("tblones", Some("db1"))) + } // The new table already exists intercept[TableAlreadyExistsException] { - sessionCatalog.renameTable(TableIdentifier("tblone", Some("db2")), "table_two") + sessionCatalog.renameTable( + TableIdentifier("tblone", Some("db2")), + TableIdentifier("table_two")) } } test("rename table when database/table does not exist") { val catalog = new SessionCatalog(newBasicCatalog()) intercept[NoSuchDatabaseException] { - catalog.renameTable(TableIdentifier("tbl1", Some("unknown_db")), "tbl2") + catalog.renameTable(TableIdentifier("tbl1", Some("unknown_db")), TableIdentifier("tbl2")) } intercept[NoSuchTableException] { - catalog.renameTable(TableIdentifier("unknown_table", Some("db2")), "tbl2") + catalog.renameTable(TableIdentifier("unknown_table", Some("db2")), TableIdentifier("tbl2")) } } @@ -306,12 +313,12 @@ class SessionCatalogSuite extends SparkFunSuite { assert(sessionCatalog.getTempView("tbl1") == Option(tempTable)) assert(externalCatalog.listTables("db2").toSet == Set("tbl1", "tbl2")) // If database is not specified, temp table should be renamed first - sessionCatalog.renameTable(TableIdentifier("tbl1"), "tbl3") + sessionCatalog.renameTable(TableIdentifier("tbl1"), TableIdentifier("tbl3")) assert(sessionCatalog.getTempView("tbl1").isEmpty) assert(sessionCatalog.getTempView("tbl3") == Option(tempTable)) assert(externalCatalog.listTables("db2").toSet == Set("tbl1", "tbl2")) // If database is specified, temp tables are never renamed - sessionCatalog.renameTable(TableIdentifier("tbl2", Some("db2")), "tbl4") + sessionCatalog.renameTable(TableIdentifier("tbl2", Some("db2")), TableIdentifier("tbl4")) assert(sessionCatalog.getTempView("tbl3") == Option(tempTable)) assert(sessionCatalog.getTempView("tbl4").isEmpty) assert(externalCatalog.listTables("db2").toSet == Set("tbl1", "tbl4")) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala index 4df9062018995..4d896c2e38f10 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala @@ -66,8 +66,6 @@ case class RepeatedData( mapFieldNull: scala.collection.Map[Int, java.lang.Long], structField: PrimitiveData) -case class SpecificCollection(l: List[Int]) - /** For testing Kryo serialization based encoder. 
*/ class KryoSerializable(val value: Int) { override def hashCode(): Int = value @@ -107,6 +105,12 @@ class UDTForCaseClass extends UserDefinedType[UDTCaseClass] { } } +case class PrimitiveValueClass(wrapped: Int) extends AnyVal +case class ReferenceValueClass(wrapped: ReferenceValueClass.Container) extends AnyVal +object ReferenceValueClass { + case class Container(data: Int) +} + class ExpressionEncoderSuite extends PlanTest with AnalysisTest { OuterScopes.addOuterScope(this) @@ -290,6 +294,12 @@ class ExpressionEncoderSuite extends PlanTest with AnalysisTest { ExpressionEncoder.tuple(intEnc, ExpressionEncoder.tuple(intEnc, longEnc)) } + encodeDecodeTest( + PrimitiveValueClass(42), "primitive value class") + + encodeDecodeTest( + ReferenceValueClass(ReferenceValueClass.Container(1)), "reference value class") + productTest(("UDT", new ExamplePoint(0.1, 0.2))) test("nullable of encoder schema") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala index fdb9fa31f09c8..26978a0482fc7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala @@ -215,13 +215,13 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Substring(bytes, 2, 2), Array[Byte](2, 3)) checkEvaluation(Substring(bytes, 3, 2), Array[Byte](3, 4)) checkEvaluation(Substring(bytes, 4, 2), Array[Byte](4)) - checkEvaluation(Substring(bytes, 8, 2), Array[Byte]()) + checkEvaluation(Substring(bytes, 8, 2), Array.empty[Byte]) checkEvaluation(Substring(bytes, -1, 2), Array[Byte](4)) checkEvaluation(Substring(bytes, -2, 2), Array[Byte](3, 4)) checkEvaluation(Substring(bytes, -3, 2), Array[Byte](2, 3)) checkEvaluation(Substring(bytes, -4, 2), Array[Byte](1, 2)) checkEvaluation(Substring(bytes, -5, 2), Array[Byte](1)) - checkEvaluation(Substring(bytes, -8, 2), Array[Byte]()) + checkEvaluation(Substring(bytes, -8, 2), Array.empty[Byte]) } test("string substring_index function") { @@ -275,7 +275,7 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Base64(UnBase64(a)), "AQIDBA==", create_row("AQIDBA==")) checkEvaluation(Base64(b), "AQIDBA==", create_row(bytes)) - checkEvaluation(Base64(b), "", create_row(Array[Byte]())) + checkEvaluation(Base64(b), "", create_row(Array.empty[Byte])) checkEvaluation(Base64(b), null, create_row(null)) checkEvaluation(Base64(Literal.create(null, BinaryType)), null, create_row("abdef")) @@ -526,13 +526,13 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { // non ascii characters are not allowed in the source code, so we disable the scalastyle. 
checkEvaluation(Length(Literal("a花花c")), 4, create_row(string)) // scalastyle:on - checkEvaluation(Length(Literal(bytes)), 5, create_row(Array[Byte]())) + checkEvaluation(Length(Literal(bytes)), 5, create_row(Array.empty[Byte])) checkEvaluation(Length(a), 5, create_row(string)) checkEvaluation(Length(b), 5, create_row(bytes)) checkEvaluation(Length(a), 0, create_row("")) - checkEvaluation(Length(b), 0, create_row(Array[Byte]())) + checkEvaluation(Length(b), 0, create_row(Array.empty[Byte])) checkEvaluation(Length(a), null, create_row(null)) checkEvaluation(Length(b), null, create_row(null)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala index e7fdd5a6202b6..9f57f66a2ea20 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala @@ -27,9 +27,12 @@ import org.apache.spark.sql.catalyst.rules._ class InferFiltersFromConstraintsSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { - val batches = Batch("InferFilters", FixedPoint(5), InferFiltersFromConstraints) :: - Batch("PredicatePushdown", FixedPoint(5), PushPredicateThroughJoin) :: - Batch("CombineFilters", FixedPoint(5), CombineFilters) :: Nil + val batches = + Batch("InferAndPushDownFilters", FixedPoint(100), + PushPredicateThroughJoin, + PushDownPredicate, + InferFiltersFromConstraints, + CombineFilters) :: Nil } val testRelation = LocalRelation('a.int, 'b.int, 'c.int) @@ -120,4 +123,82 @@ class InferFiltersFromConstraintsSuite extends PlanTest { val optimized = Optimize.execute(originalQuery) comparePlans(optimized, correctAnswer) } + + test("inner join with alias: alias contains multiple attributes") { + val t1 = testRelation.subquery('t1) + val t2 = testRelation.subquery('t2) + + val originalQuery = t1.select('a, Coalesce(Seq('a, 'b)).as('int_col)).as("t") + .join(t2, Inner, Some("t.a".attr === "t2.a".attr && "t.int_col".attr === "t2.a".attr)) + .analyze + val correctAnswer = t1 + .where(IsNotNull('a) && IsNotNull(Coalesce(Seq('a, 'b))) && 'a === Coalesce(Seq('a, 'b))) + .select('a, Coalesce(Seq('a, 'b)).as('int_col)).as("t") + .join(t2.where(IsNotNull('a)), Inner, + Some("t.a".attr === "t2.a".attr && "t.int_col".attr === "t2.a".attr)) + .analyze + val optimized = Optimize.execute(originalQuery) + comparePlans(optimized, correctAnswer) + } + + test("inner join with alias: alias contains single attributes") { + val t1 = testRelation.subquery('t1) + val t2 = testRelation.subquery('t2) + + val originalQuery = t1.select('a, 'b.as('d)).as("t") + .join(t2, Inner, Some("t.a".attr === "t2.a".attr && "t.d".attr === "t2.a".attr)) + .analyze + val correctAnswer = t1 + .where(IsNotNull('a) && IsNotNull('b) && 'a <=> 'a && 'b <=> 'b &&'a === 'b) + .select('a, 'b.as('d)).as("t") + .join(t2.where(IsNotNull('a) && 'a <=> 'a), Inner, + Some("t.a".attr === "t2.a".attr && "t.d".attr === "t2.a".attr)) + .analyze + val optimized = Optimize.execute(originalQuery) + comparePlans(optimized, correctAnswer) + } + + test("inner join with alias: don't generate constraints for recursive functions") { + val t1 = testRelation.subquery('t1) + val t2 = testRelation.subquery('t2) + + val originalQuery = t1.select('a, 'b.as('d), Coalesce(Seq('a, 'b)).as('int_col)).as("t") + .join(t2, Inner, + 
Some("t.a".attr === "t2.a".attr + && "t.d".attr === "t2.a".attr + && "t.int_col".attr === "t2.a".attr)) + .analyze + val correctAnswer = t1 + .where(IsNotNull('a) && IsNotNull(Coalesce(Seq('a, 'a))) + && 'a === Coalesce(Seq('a, 'a)) && 'a <=> Coalesce(Seq('a, 'a)) && 'a <=> 'a + && Coalesce(Seq('a, 'a)) <=> 'b && Coalesce(Seq('a, 'a)) <=> Coalesce(Seq('a, 'a)) + && 'a === 'b && IsNotNull(Coalesce(Seq('a, 'b))) && 'a === Coalesce(Seq('a, 'b)) + && Coalesce(Seq('a, 'b)) <=> Coalesce(Seq('b, 'b)) && Coalesce(Seq('a, 'b)) === 'b + && IsNotNull('b) && IsNotNull(Coalesce(Seq('b, 'b))) + && 'b === Coalesce(Seq('b, 'b)) && 'b <=> Coalesce(Seq('b, 'b)) + && Coalesce(Seq('b, 'b)) <=> Coalesce(Seq('b, 'b)) && 'b <=> 'b) + .select('a, 'b.as('d), Coalesce(Seq('a, 'b)).as('int_col)).as("t") + .join(t2 + .where(IsNotNull('a) && IsNotNull(Coalesce(Seq('a, 'a))) + && 'a === Coalesce(Seq('a, 'a)) && 'a <=> Coalesce(Seq('a, 'a)) && 'a <=> 'a + && Coalesce(Seq('a, 'a)) <=> Coalesce(Seq('a, 'a))), Inner, + Some("t.a".attr === "t2.a".attr + && "t.d".attr === "t2.a".attr + && "t.int_col".attr === "t2.a".attr + && Coalesce(Seq("t.d".attr, "t.d".attr)) <=> "t.int_col".attr)) + .analyze + val optimized = Optimize.execute(originalQuery) + comparePlans(optimized, correctAnswer) + } + + test("generate correct filters for alias that don't produce recursive constraints") { + val t1 = testRelation.subquery('t1) + + val originalQuery = t1.select('a.as('x), 'b.as('y)).where('x === 1 && 'x === 'y).analyze + val correctAnswer = + t1.where('a === 1 && 'b === 1 && 'a === 'b && IsNotNull('a) && IsNotNull('b)) + .select('a.as('x), 'b.as('y)).analyze + val optimized = Optimize.execute(originalQuery) + comparePlans(optimized, correctAnswer) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala index 8d6a49a8a37b4..8068ce922e636 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala @@ -128,8 +128,16 @@ class ConstraintPropagationSuite extends SparkFunSuite { ExpressionSet(Seq(resolveColumn(aliasedRelation.analyze, "x") > 10, IsNotNull(resolveColumn(aliasedRelation.analyze, "x")), resolveColumn(aliasedRelation.analyze, "b") <=> resolveColumn(aliasedRelation.analyze, "y"), + resolveColumn(aliasedRelation.analyze, "z") <=> resolveColumn(aliasedRelation.analyze, "x"), resolveColumn(aliasedRelation.analyze, "z") > 10, IsNotNull(resolveColumn(aliasedRelation.analyze, "z"))))) + + val multiAlias = tr.where('a === 'c + 10).select('a.as('x), 'c.as('y)) + verifyConstraints(multiAlias.analyze.constraints, + ExpressionSet(Seq(IsNotNull(resolveColumn(multiAlias.analyze, "x")), + IsNotNull(resolveColumn(multiAlias.analyze, "y")), + resolveColumn(multiAlias.analyze, "x") === resolveColumn(multiAlias.analyze, "y") + 10)) + ) } test("propagating constraints in union") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala index 6310f0c2bc0ed..64e268703bf5e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.plans import 
org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression -import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, OneRowRelation, Sample} +import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util._ /** @@ -56,16 +56,37 @@ abstract class PlanTest extends SparkFunSuite with PredicateHelper { * ((expr 1 && expr 2) && expr 3), (expr 1 && expr 2 && expr 3), (expr 3 && (expr 1 && expr 2) * etc., will all now be equivalent. * - Sample the seed will replaced by 0L. + * - Join conditions will be resorted by hashCode. */ private def normalizePlan(plan: LogicalPlan): LogicalPlan = { plan transform { case filter @ Filter(condition: Expression, child: LogicalPlan) => - Filter(splitConjunctivePredicates(condition).sortBy(_.hashCode()).reduce(And), child) + Filter(splitConjunctivePredicates(condition).map(rewriteEqual(_)).sortBy(_.hashCode()) + .reduce(And), child) case sample: Sample => sample.copy(seed = 0L)(true) + case join @ Join(left, right, joinType, condition) if condition.isDefined => + val newCondition = + splitConjunctivePredicates(condition.get).map(rewriteEqual(_)).sortBy(_.hashCode()) + .reduce(And) + Join(left, right, joinType, Some(newCondition)) } } + /** + * Rewrite [[EqualTo]] and [[EqualNullSafe]] operator to keep order. The following cases will be + * equivalent: + * 1. (a = b), (b = a); + * 2. (a <=> b), (b <=> a). + */ + private def rewriteEqual(condition: Expression): Expression = condition match { + case eq @ EqualTo(l: Expression, r: Expression) => + Seq(l, r).sortBy(_.hashCode()).reduce(EqualTo) + case eq @ EqualNullSafe(l: Expression, r: Expression) => + Seq(l, r).sortBy(_.hashCode()).reduce(EqualNullSafe) + case _ => condition // Don't reorder. + } + /** Fails the test if the two plans do not match */ protected def comparePlans(plan1: LogicalPlan, plan2: LogicalPlan) { val normalized1 = normalizePlan(normalizeExprIds(plan1)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala index cb0426c7a98a1..3eff12f9eed14 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala @@ -489,6 +489,7 @@ class TreeNodeSuite extends SparkFunSuite { "owner" -> "", "createTime" -> 0, "lastAccessTime" -> -1, + "partitionProviderIsHive" -> false, "properties" -> JNull, "unsupportedFeatures" -> List.empty[String])) diff --git a/sql/core/src/main/java/org/apache/spark/sql/SaveMode.java b/sql/core/src/main/java/org/apache/spark/sql/SaveMode.java index 9665c3c46f901..1c3c9794fb6bb 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/SaveMode.java +++ b/sql/core/src/main/java/org/apache/spark/sql/SaveMode.java @@ -16,11 +16,14 @@ */ package org.apache.spark.sql; +import org.apache.spark.annotation.InterfaceStability; + /** * SaveMode is used to specify the expected behavior of saving a DataFrame to a data source. 
* * @since 1.3.0 */ +@InterfaceStability.Stable public enum SaveMode { /** * Append mode means that when saving a DataFrame to a data source, if data/table already exists, diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF1.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF1.java index ef959e35e1027..1460daf27dc20 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF1.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF1.java @@ -19,14 +19,12 @@ import java.io.Serializable; -// ************************************************** -// THIS FILE IS AUTOGENERATED BY CODE IN -// org.apache.spark.sql.api.java.FunctionRegistration -// ************************************************** +import org.apache.spark.annotation.InterfaceStability; /** * A Spark SQL UDF that has 1 arguments. */ +@InterfaceStability.Stable public interface UDF1 extends Serializable { - public R call(T1 t1) throws Exception; + R call(T1 t1) throws Exception; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF10.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF10.java index 96ab3a96c3d5e..7c4f1e4897084 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF10.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF10.java @@ -19,14 +19,12 @@ import java.io.Serializable; -// ************************************************** -// THIS FILE IS AUTOGENERATED BY CODE IN -// org.apache.spark.sql.api.java.FunctionRegistration -// ************************************************** +import org.apache.spark.annotation.InterfaceStability; /** * A Spark SQL UDF that has 10 arguments. */ +@InterfaceStability.Stable public interface UDF10 extends Serializable { - public R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10) throws Exception; + R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10) throws Exception; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF11.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF11.java index 58ae8edd6d817..26a05106aebd6 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF11.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF11.java @@ -19,14 +19,12 @@ import java.io.Serializable; -// ************************************************** -// THIS FILE IS AUTOGENERATED BY CODE IN -// org.apache.spark.sql.api.java.FunctionRegistration -// ************************************************** +import org.apache.spark.annotation.InterfaceStability; /** * A Spark SQL UDF that has 11 arguments. 
*/ +@InterfaceStability.Stable public interface UDF11 extends Serializable { - public R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11) throws Exception; + R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11) throws Exception; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF12.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF12.java index d9da0f6eddd94..8ef7a99042025 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF12.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF12.java @@ -19,14 +19,12 @@ import java.io.Serializable; -// ************************************************** -// THIS FILE IS AUTOGENERATED BY CODE IN -// org.apache.spark.sql.api.java.FunctionRegistration -// ************************************************** +import org.apache.spark.annotation.InterfaceStability; /** * A Spark SQL UDF that has 12 arguments. */ +@InterfaceStability.Stable public interface UDF12 extends Serializable { - public R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12) throws Exception; + R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12) throws Exception; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF13.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF13.java index 095fc1a8076b5..5c3b2ec1222e2 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF13.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF13.java @@ -19,14 +19,12 @@ import java.io.Serializable; -// ************************************************** -// THIS FILE IS AUTOGENERATED BY CODE IN -// org.apache.spark.sql.api.java.FunctionRegistration -// ************************************************** +import org.apache.spark.annotation.InterfaceStability; /** * A Spark SQL UDF that has 13 arguments. */ +@InterfaceStability.Stable public interface UDF13 extends Serializable { - public R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12, T13 t13) throws Exception; + R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12, T13 t13) throws Exception; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF14.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF14.java index eb27eaa180086..97e744d843466 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF14.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF14.java @@ -19,14 +19,12 @@ import java.io.Serializable; -// ************************************************** -// THIS FILE IS AUTOGENERATED BY CODE IN -// org.apache.spark.sql.api.java.FunctionRegistration -// ************************************************** +import org.apache.spark.annotation.InterfaceStability; /** * A Spark SQL UDF that has 14 arguments. 
*/ +@InterfaceStability.Stable public interface UDF14 extends Serializable { - public R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12, T13 t13, T14 t14) throws Exception; + R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12, T13 t13, T14 t14) throws Exception; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF15.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF15.java index 1fbcff56332b6..7ddbf914fc11a 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF15.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF15.java @@ -19,14 +19,12 @@ import java.io.Serializable; -// ************************************************** -// THIS FILE IS AUTOGENERATED BY CODE IN -// org.apache.spark.sql.api.java.FunctionRegistration -// ************************************************** +import org.apache.spark.annotation.InterfaceStability; /** * A Spark SQL UDF that has 15 arguments. */ +@InterfaceStability.Stable public interface UDF15 extends Serializable { - public R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12, T13 t13, T14 t14, T15 t15) throws Exception; + R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12, T13 t13, T14 t14, T15 t15) throws Exception; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF16.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF16.java index 1133561787a69..0ae5dc7195ad6 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF16.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF16.java @@ -19,14 +19,12 @@ import java.io.Serializable; -// ************************************************** -// THIS FILE IS AUTOGENERATED BY CODE IN -// org.apache.spark.sql.api.java.FunctionRegistration -// ************************************************** +import org.apache.spark.annotation.InterfaceStability; /** * A Spark SQL UDF that has 16 arguments. */ +@InterfaceStability.Stable public interface UDF16 extends Serializable { - public R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12, T13 t13, T14 t14, T15 t15, T16 t16) throws Exception; + R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12, T13 t13, T14 t14, T15 t15, T16 t16) throws Exception; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF17.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF17.java index dfae7922c9b63..03543a556c614 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF17.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF17.java @@ -19,14 +19,12 @@ import java.io.Serializable; -// ************************************************** -// THIS FILE IS AUTOGENERATED BY CODE IN -// org.apache.spark.sql.api.java.FunctionRegistration -// ************************************************** +import org.apache.spark.annotation.InterfaceStability; /** * A Spark SQL UDF that has 17 arguments. 
*/ +@InterfaceStability.Stable public interface UDF17 extends Serializable { - public R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12, T13 t13, T14 t14, T15 t15, T16 t16, T17 t17) throws Exception; + R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12, T13 t13, T14 t14, T15 t15, T16 t16, T17 t17) throws Exception; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF18.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF18.java index e9d1c6d52d4ea..46740d3443916 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF18.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF18.java @@ -19,14 +19,12 @@ import java.io.Serializable; -// ************************************************** -// THIS FILE IS AUTOGENERATED BY CODE IN -// org.apache.spark.sql.api.java.FunctionRegistration -// ************************************************** +import org.apache.spark.annotation.InterfaceStability; /** * A Spark SQL UDF that has 18 arguments. */ +@InterfaceStability.Stable public interface UDF18 extends Serializable { - public R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12, T13 t13, T14 t14, T15 t15, T16 t16, T17 t17, T18 t18) throws Exception; + R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12, T13 t13, T14 t14, T15 t15, T16 t16, T17 t17, T18 t18) throws Exception; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF19.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF19.java index 46b9d2d3c9457..33fefd8ecaf1d 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF19.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF19.java @@ -19,14 +19,12 @@ import java.io.Serializable; -// ************************************************** -// THIS FILE IS AUTOGENERATED BY CODE IN -// org.apache.spark.sql.api.java.FunctionRegistration -// ************************************************** +import org.apache.spark.annotation.InterfaceStability; /** * A Spark SQL UDF that has 19 arguments. */ +@InterfaceStability.Stable public interface UDF19 extends Serializable { - public R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12, T13 t13, T14 t14, T15 t15, T16 t16, T17 t17, T18 t18, T19 t19) throws Exception; + R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12, T13 t13, T14 t14, T15 t15, T16 t16, T17 t17, T18 t18, T19 t19) throws Exception; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF2.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF2.java index cd3fde8da419e..9822f19217d76 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF2.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF2.java @@ -19,14 +19,12 @@ import java.io.Serializable; -// ************************************************** -// THIS FILE IS AUTOGENERATED BY CODE IN -// org.apache.spark.sql.api.java.FunctionRegistration -// ************************************************** +import org.apache.spark.annotation.InterfaceStability; /** * A Spark SQL UDF that has 2 arguments. 
*/ +@InterfaceStability.Stable public interface UDF2 extends Serializable { - public R call(T1 t1, T2 t2) throws Exception; + R call(T1 t1, T2 t2) throws Exception; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF20.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF20.java index 113d3d26be4a7..8c5e90182da1c 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF20.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF20.java @@ -19,14 +19,12 @@ import java.io.Serializable; -// ************************************************** -// THIS FILE IS AUTOGENERATED BY CODE IN -// org.apache.spark.sql.api.java.FunctionRegistration -// ************************************************** +import org.apache.spark.annotation.InterfaceStability; /** * A Spark SQL UDF that has 20 arguments. */ +@InterfaceStability.Stable public interface UDF20 extends Serializable { - public R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12, T13 t13, T14 t14, T15 t15, T16 t16, T17 t17, T18 t18, T19 t19, T20 t20) throws Exception; + R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12, T13 t13, T14 t14, T15 t15, T16 t16, T17 t17, T18 t18, T19 t19, T20 t20) throws Exception; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF21.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF21.java index 74118f2cf8da7..e3b09f5167cff 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF21.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF21.java @@ -19,14 +19,12 @@ import java.io.Serializable; -// ************************************************** -// THIS FILE IS AUTOGENERATED BY CODE IN -// org.apache.spark.sql.api.java.FunctionRegistration -// ************************************************** +import org.apache.spark.annotation.InterfaceStability; /** * A Spark SQL UDF that has 21 arguments. */ +@InterfaceStability.Stable public interface UDF21 extends Serializable { - public R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12, T13 t13, T14 t14, T15 t15, T16 t16, T17 t17, T18 t18, T19 t19, T20 t20, T21 t21) throws Exception; + R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12, T13 t13, T14 t14, T15 t15, T16 t16, T17 t17, T18 t18, T19 t19, T20 t20, T21 t21) throws Exception; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF22.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF22.java index 0e7cc40be45ec..dc6cfa9097bab 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF22.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF22.java @@ -19,14 +19,12 @@ import java.io.Serializable; -// ************************************************** -// THIS FILE IS AUTOGENERATED BY CODE IN -// org.apache.spark.sql.api.java.FunctionRegistration -// ************************************************** +import org.apache.spark.annotation.InterfaceStability; /** * A Spark SQL UDF that has 22 arguments. 
*/ +@InterfaceStability.Stable public interface UDF22 extends Serializable { - public R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12, T13 t13, T14 t14, T15 t15, T16 t16, T17 t17, T18 t18, T19 t19, T20 t20, T21 t21, T22 t22) throws Exception; + R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12, T13 t13, T14 t14, T15 t15, T16 t16, T17 t17, T18 t18, T19 t19, T20 t20, T21 t21, T22 t22) throws Exception; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF3.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF3.java index 6a880f16be47a..7c264b69ba195 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF3.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF3.java @@ -19,14 +19,12 @@ import java.io.Serializable; -// ************************************************** -// THIS FILE IS AUTOGENERATED BY CODE IN -// org.apache.spark.sql.api.java.FunctionRegistration -// ************************************************** +import org.apache.spark.annotation.InterfaceStability; /** * A Spark SQL UDF that has 3 arguments. */ +@InterfaceStability.Stable public interface UDF3 extends Serializable { - public R call(T1 t1, T2 t2, T3 t3) throws Exception; + R call(T1 t1, T2 t2, T3 t3) throws Exception; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF4.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF4.java index fcad2febb18e6..58df38fc3c911 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF4.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF4.java @@ -19,14 +19,12 @@ import java.io.Serializable; -// ************************************************** -// THIS FILE IS AUTOGENERATED BY CODE IN -// org.apache.spark.sql.api.java.FunctionRegistration -// ************************************************** +import org.apache.spark.annotation.InterfaceStability; /** * A Spark SQL UDF that has 4 arguments. */ +@InterfaceStability.Stable public interface UDF4 extends Serializable { - public R call(T1 t1, T2 t2, T3 t3, T4 t4) throws Exception; + R call(T1 t1, T2 t2, T3 t3, T4 t4) throws Exception; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF5.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF5.java index ce0cef43a2144..4146f96e2eed5 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF5.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF5.java @@ -19,14 +19,12 @@ import java.io.Serializable; -// ************************************************** -// THIS FILE IS AUTOGENERATED BY CODE IN -// org.apache.spark.sql.api.java.FunctionRegistration -// ************************************************** +import org.apache.spark.annotation.InterfaceStability; /** * A Spark SQL UDF that has 5 arguments. 
*/ +@InterfaceStability.Stable public interface UDF5 extends Serializable { - public R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5) throws Exception; + R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5) throws Exception; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF6.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF6.java index f56b806684e61..25d39654c1095 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF6.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF6.java @@ -19,14 +19,12 @@ import java.io.Serializable; -// ************************************************** -// THIS FILE IS AUTOGENERATED BY CODE IN -// org.apache.spark.sql.api.java.FunctionRegistration -// ************************************************** +import org.apache.spark.annotation.InterfaceStability; /** * A Spark SQL UDF that has 6 arguments. */ +@InterfaceStability.Stable public interface UDF6 extends Serializable { - public R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6) throws Exception; + R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6) throws Exception; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF7.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF7.java index 25bd6d3241bd4..ce63b6a91adbb 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF7.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF7.java @@ -19,14 +19,12 @@ import java.io.Serializable; -// ************************************************** -// THIS FILE IS AUTOGENERATED BY CODE IN -// org.apache.spark.sql.api.java.FunctionRegistration -// ************************************************** +import org.apache.spark.annotation.InterfaceStability; /** * A Spark SQL UDF that has 7 arguments. */ +@InterfaceStability.Stable public interface UDF7 extends Serializable { - public R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7) throws Exception; + R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7) throws Exception; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF8.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF8.java index a3b7ac5f94ce7..0e00209ef6b9f 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF8.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF8.java @@ -19,14 +19,12 @@ import java.io.Serializable; -// ************************************************** -// THIS FILE IS AUTOGENERATED BY CODE IN -// org.apache.spark.sql.api.java.FunctionRegistration -// ************************************************** +import org.apache.spark.annotation.InterfaceStability; /** * A Spark SQL UDF that has 8 arguments. 
*/ +@InterfaceStability.Stable public interface UDF8 extends Serializable { - public R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8) throws Exception; + R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8) throws Exception; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF9.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF9.java index 205e72a1522fc..077981bb3e3ee 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF9.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF9.java @@ -19,14 +19,12 @@ import java.io.Serializable; -// ************************************************** -// THIS FILE IS AUTOGENERATED BY CODE IN -// org.apache.spark.sql.api.java.FunctionRegistration -// ************************************************** +import org.apache.spark.annotation.InterfaceStability; /** * A Spark SQL UDF that has 9 arguments. */ +@InterfaceStability.Stable public interface UDF9 extends Serializable { - public R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9) throws Exception; + R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9) throws Exception; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/expressions/javalang/typed.java b/sql/core/src/main/java/org/apache/spark/sql/expressions/javalang/typed.java index 247e94b86c349..ec9c107b1c119 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/expressions/javalang/typed.java +++ b/sql/core/src/main/java/org/apache/spark/sql/expressions/javalang/typed.java @@ -18,6 +18,7 @@ package org.apache.spark.sql.expressions.javalang; import org.apache.spark.annotation.Experimental; +import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.api.java.function.MapFunction; import org.apache.spark.sql.TypedColumn; import org.apache.spark.sql.execution.aggregate.TypedAverage; @@ -34,6 +35,7 @@ * @since 2.0.0 */ @Experimental +@InterfaceStability.Evolving public class typed { // Note: make sure to keep in sync with typed.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index d22bb17934ce7..05e867bf5be96 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql import scala.language.implicitConversions -import org.apache.spark.annotation.{Experimental, InterfaceStability} +import org.apache.spark.annotation.InterfaceStability import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder} @@ -1181,13 +1181,11 @@ class Column(protected[sql] val expr: Expression) extends Logging { /** - * :: Experimental :: * A convenient class used for constructing schema. 
* * @since 1.3.0 */ -@Experimental -@InterfaceStability.Evolving +@InterfaceStability.Stable class ColumnName(name: String) extends Column(name) { /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala index 65a9c008f9650..0d43f09bc54cd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala @@ -21,20 +21,18 @@ import java.{lang => jl} import scala.collection.JavaConverters._ -import org.apache.spark.annotation.{Experimental, InterfaceStability} +import org.apache.spark.annotation.InterfaceStability import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ /** - * :: Experimental :: * Functionality for working with missing data in [[DataFrame]]s. * * @since 1.3.1 */ -@Experimental -@InterfaceStability.Evolving +@InterfaceStability.Stable final class DataFrameNaFunctions private[sql](df: DataFrame) { /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index a716a916b7f7f..a77937efd7e15 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -232,14 +232,15 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { parts: Array[Partition], connectionProperties: Properties): DataFrame = { // connectionProperties should override settings in extraOptions. - val params = extraOptions.toMap ++ connectionProperties.asScala.toMap - val options = new JDBCOptions(url, table, params) - val relation = JDBCRelation(parts, options)(sparkSession) - sparkSession.baseRelationToDataFrame(relation) + this.extraOptions = this.extraOptions ++ connectionProperties.asScala + // explicit url and dbtable should override all + this.extraOptions += ("url" -> url, "dbtable" -> table) + format("jdbc").load() } /** - * Loads a JSON file (one object per line) and returns the result as a [[DataFrame]]. + * Loads a JSON file ([[http://jsonlines.org/ JSON Lines text format or newline-delimited JSON]]) + * and returns the result as a [[DataFrame]]. * See the documentation on the overloaded `json()` method with varargs for more details. * * @since 1.4.0 @@ -250,7 +251,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { } /** - * Loads a JSON file (one object per line) and returns the result as a [[DataFrame]]. + * Loads a JSON file ([[http://jsonlines.org/ JSON Lines text format or newline-delimited JSON]]) + * and returns the result as a [[DataFrame]]. * * This function goes through the input once to determine the input schema. If you know the * schema in advance, use the version that specifies the schema to avoid the extra scan. @@ -295,8 +297,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { def json(paths: String*): DataFrame = format("json").load(paths : _*) /** - * Loads a `JavaRDD[String]` storing JSON objects (one object per record) and - * returns the result as a [[DataFrame]]. + * Loads a `JavaRDD[String]` storing JSON objects ([[http://jsonlines.org/ JSON Lines text format + * or newline-delimited JSON]]) and returns the result as a [[DataFrame]]. 
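
Since the reader docs above now describe the input as JSON Lines (one JSON object per line), a short usage sketch may help. It assumes an active SparkSession named `spark`; the file path and schema are illustrative, not taken from the patch.

    import org.apache.spark.sql.types.{LongType, StringType, StructField, StructType}

    // Each input line is a single JSON object, e.g. {"name":"alice","age":30}
    val inferred = spark.read.json("/tmp/people.jsonl")   // extra pass over the data to infer the schema

    // Supplying the schema up front avoids that extra scan, as the scaladoc notes.
    val schema = StructType(Seq(StructField("name", StringType), StructField("age", LongType)))
    val people = spark.read.schema(schema).json("/tmp/people.jsonl")
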
* * Unless the schema is specified using [[schema]] function, this function goes through the * input once to determine the input schema. @@ -307,8 +309,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { def json(jsonRDD: JavaRDD[String]): DataFrame = json(jsonRDD.rdd) /** - * Loads an `RDD[String]` storing JSON objects (one object per record) and - * returns the result as a [[DataFrame]]. + * Loads an `RDD[String]` storing JSON objects ([[http://jsonlines.org/ JSON Lines text format or + * newline-delimited JSON]]) and returns the result as a [[DataFrame]]. * * Unless the schema is specified using [[schema]] function, this function goes through the * input once to determine the input schema. @@ -363,7 +365,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { * type. *
  • `quote` (default `"`): sets the single character used for escaping quoted values where * the separator can be part of the value. If you would like to turn off quotations, you need to - * set not `null` but an empty string. This behaviour is different form + * set not `null` but an empty string. This behaviour is different from * `com.databricks.spark.csv`.
  • `escape` (default `\`): sets the single character used for escaping quotes inside * an already quoted value.
  • diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala index a212bb6205328..b5bbcee37150f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala @@ -21,20 +21,18 @@ import java.{lang => jl, util => ju} import scala.collection.JavaConverters._ -import org.apache.spark.annotation.{Experimental, InterfaceStability} +import org.apache.spark.annotation.InterfaceStability import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.stat._ import org.apache.spark.sql.types._ import org.apache.spark.util.sketch.{BloomFilter, CountMinSketch} /** - * :: Experimental :: * Statistic functions for [[DataFrame]]s. * * @since 1.4.0 */ -@Experimental -@InterfaceStability.Evolving +@InterfaceStability.Stable final class DataFrameStatFunctions private[sql](df: DataFrame) { /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 35ef050dcb169..11dd1df909938 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -25,7 +25,8 @@ import org.apache.spark.annotation.InterfaceStability import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, CatalogTable, CatalogTableType} -import org.apache.spark.sql.catalyst.plans.logical.InsertIntoTable +import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, Union} +import org.apache.spark.sql.execution.command.AlterTableRecoverPartitionsCommand import org.apache.spark.sql.execution.datasources.{CaseInsensitiveMap, CreateTable, DataSource, HadoopFsRelation} import org.apache.spark.sql.types.StructType @@ -387,8 +388,14 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { partitionColumnNames = partitioningColumns.getOrElse(Nil), bucketSpec = getBucketSpec ) - val cmd = CreateTable(tableDesc, mode, Some(df.logicalPlan)) - df.sparkSession.sessionState.executePlan(cmd).toRdd + df.sparkSession.sessionState.executePlan( + CreateTable(tableDesc, mode, Some(df.logicalPlan))).toRdd + if (tableDesc.partitionColumnNames.nonEmpty && + df.sparkSession.sqlContext.conf.manageFilesourcePartitions) { + // Need to recover partitions into the metastore so our saved data is visible. + df.sparkSession.sessionState.executePlan( + AlterTableRecoverPartitionsCommand(tableDesc.identifier)).toRdd + } } } @@ -426,15 +433,16 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { def jdbc(url: String, table: String, connectionProperties: Properties): Unit = { assertNotPartitioned("jdbc") assertNotBucketed("jdbc") - // connectionProperties should override settings in extraOptions - this.extraOptions = this.extraOptions ++ (connectionProperties.asScala) + // connectionProperties should override settings in extraOptions. + this.extraOptions = this.extraOptions ++ connectionProperties.asScala // explicit url and dbtable should override all this.extraOptions += ("url" -> url, "dbtable" -> table) format("jdbc").save() } /** - * Saves the content of the [[DataFrame]] in JSON format at the specified path. 
+ * Saves the content of the [[DataFrame]] in JSON format ([[http://jsonlines.org/ JSON Lines text + * format or newline-delimited JSON]]) at the specified path. * This is equivalent to: * {{{ * format("json").save(path) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index e59a483075c94..286d8549bfe27 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -43,7 +43,7 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util.usePrettyExpression import org.apache.spark.sql.execution.{FileRelation, LogicalRDD, QueryExecution, SQLExecution} import org.apache.spark.sql.execution.command.{CreateViewCommand, ExplainCommand, GlobalTempView, LocalTempView} -import org.apache.spark.sql.execution.datasources.LogicalRelation +import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation} import org.apache.spark.sql.execution.datasources.json.JacksonGenerator import org.apache.spark.sql.execution.python.EvaluatePython import org.apache.spark.sql.streaming.{DataStreamWriter, StreamingQuery} @@ -556,7 +556,7 @@ class Dataset[T] private[sql]( * 1983 03 0.410516 0.442194 * 1984 04 0.450090 0.483521 * }}} - * + * * @param numRows Number of rows to show * @param truncate If set to more than 0, truncates strings to `truncate` characters and * all cells will be aligned right. @@ -774,7 +774,7 @@ class Dataset[T] private[sql]( * @param right Right side of the join operation. * * @group untypedrel - * @since 2.0.0 + * @since 2.1.0 */ def crossJoin(right: Dataset[_]): DataFrame = withPlan { Join(logicalPlan, right.logicalPlan, joinType = Cross, None) @@ -1524,7 +1524,7 @@ class Dataset[T] private[sql]( * @group typedrel * @since 2.0.0 */ - def union(other: Dataset[T]): Dataset[T] = withTypedPlan { + def union(other: Dataset[T]): Dataset[T] = withSetOperator { // This breaks caching, but it's usually ok because it addresses a very specific use case: // using union to union many files or partitions. CombineUnions(Union(logicalPlan, other.logicalPlan)) @@ -1540,7 +1540,7 @@ class Dataset[T] private[sql]( * @group typedrel * @since 1.6.0 */ - def intersect(other: Dataset[T]): Dataset[T] = withTypedPlan { + def intersect(other: Dataset[T]): Dataset[T] = withSetOperator { Intersect(logicalPlan, other.logicalPlan) } @@ -1554,7 +1554,7 @@ class Dataset[T] private[sql]( * @group typedrel * @since 2.0.0 */ - def except(other: Dataset[T]): Dataset[T] = withTypedPlan { + def except(other: Dataset[T]): Dataset[T] = withSetOperator { Except(logicalPlan, other.logicalPlan) } @@ -2401,6 +2401,18 @@ class Dataset[T] private[sql]( this } + /** + * Get the Dataset's current storage level, or StorageLevel.NONE if not persisted. + * + * @group basic + * @since 2.1.0 + */ + def storageLevel: StorageLevel = { + sparkSession.sharedState.cacheManager.lookupCachedData(this).map { cachedData => + cachedData.cachedRepresentation.storageLevel + }.getOrElse(StorageLevel.NONE) + } + /** * Mark the Dataset as non-persistent, and remove all blocks for it from memory and disk. 
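
The `storageLevel` accessor added to Dataset above looks the Dataset up in the CacheManager and falls back to `StorageLevel.NONE` when nothing is cached. A minimal sketch of how it behaves, assuming an active SparkSession named `spark` (the data is illustrative):

    import org.apache.spark.storage.StorageLevel

    val df = spark.range(0, 100).toDF("id")
    df.storageLevel                            // StorageLevel.NONE: the Dataset is not cached
    df.persist(StorageLevel.MEMORY_AND_DISK)
    df.storageLevel                            // MEMORY_AND_DISK: recorded as soon as persist is called
    df.count()                                 // an action actually populates the cache
    df.unpersist()
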
* @@ -2602,7 +2614,7 @@ class Dataset[T] private[sql]( * @since 2.0.0 */ def inputFiles: Array[String] = { - val files: Seq[String] = logicalPlan.collect { + val files: Seq[String] = queryExecution.optimizedPlan.collect { case LogicalRelation(fsBasedRelation: FileRelation, _, _) => fsBasedRelation.inputFiles case fr: FileRelation => @@ -2713,4 +2725,14 @@ class Dataset[T] private[sql]( @inline private def withTypedPlan[U : Encoder](logicalPlan: => LogicalPlan): Dataset[U] = { Dataset(sparkSession, logicalPlan) } + + /** A convenient function to wrap a set based logical plan and produce a Dataset. */ + @inline private def withSetOperator[U : Encoder](logicalPlan: => LogicalPlan): Dataset[U] = { + if (classTag.runtimeClass.isAssignableFrom(classOf[Row])) { + // Set operators widen types (change the schema), so we cannot reuse the row encoder. + Dataset.ofRows(sparkSession, logicalPlan).asInstanceOf[Dataset[U]] + } else { + Dataset(sparkSession, logicalPlan) + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala index 828eb94efe598..4cb0313aa9037 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala @@ -66,6 +66,48 @@ class KeyValueGroupedDataset[K, V] private[sql]( dataAttributes, groupingAttributes) + /** + * Returns a new [[KeyValueGroupedDataset]] where the given function `func` has been applied + * to the data. The grouping key is unchanged by this. + * + * {{{ + * // Create values grouped by key from a Dataset[(K, V)] + * ds.groupByKey(_._1).mapValues(_._2) // Scala + * }}} + * + * @since 2.1.0 + */ + def mapValues[W : Encoder](func: V => W): KeyValueGroupedDataset[K, W] = { + val withNewData = AppendColumns(func, dataAttributes, logicalPlan) + val projected = Project(withNewData.newColumns ++ groupingAttributes, withNewData) + val executed = sparkSession.sessionState.executePlan(projected) + + new KeyValueGroupedDataset( + encoderFor[K], + encoderFor[W], + executed, + withNewData.newColumns, + groupingAttributes) + } + + /** + * Returns a new [[KeyValueGroupedDataset]] where the given function `func` has been applied + * to the data. The grouping key is unchanged by this. + * + * {{{ + * // Create Integer values grouped by String key from a Dataset> + * Dataset> ds = ...; + * KeyValueGroupedDataset grouped = + * ds.groupByKey(t -> t._1, Encoders.STRING()).mapValues(t -> t._2, Encoders.INT()); // Java 8 + * }}} + * + * @since 2.1.0 + */ + def mapValues[W](func: MapFunction[V, W], encoder: Encoder[W]): KeyValueGroupedDataset[K, W] = { + implicit val uEnc = encoder + mapValues { (v: V) => func.call(v) } + } + /** * Returns a [[Dataset]] that contains each unique key. This is equivalent to doing mapping * over the Dataset to extract the keys and then running a distinct operation on those. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index 137c426b4b88d..3045eb69f427f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -64,7 +64,7 @@ import org.apache.spark.util.Utils * SparkSession.builder() * .master("local") * .appName("Word Count") - * .config("spark.some.config.option", "some-value"). 
+ * .config("spark.some.config.option", "some-value") * .getOrCreate() * }}} */ @@ -814,7 +814,7 @@ object SparkSession { if ((session ne null) && !session.sparkContext.isStopped) { options.foreach { case (k, v) => session.sessionState.conf.setConfString(k, v) } if (options.nonEmpty) { - logWarning("Use an existing SparkSession, some configuration may not take effect.") + logWarning("Using an existing SparkSession; some configuration may not take effect.") } return session } @@ -826,7 +826,7 @@ object SparkSession { if ((session ne null) && !session.sparkContext.isStopped) { options.foreach { case (k, v) => session.sessionState.conf.setConfString(k, v) } if (options.nonEmpty) { - logWarning("Use an existing SparkSession, some configuration may not take effect.") + logWarning("Using an existing SparkSession; some configuration may not take effect.") } return session } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala index 617a14793697b..0444ad10d34fb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala @@ -17,19 +17,25 @@ package org.apache.spark.sql +import java.io.IOException +import java.lang.reflect.{ParameterizedType, Type} + import scala.reflect.runtime.universe.TypeTag import scala.util.Try +import com.google.common.reflect.TypeToken + import org.apache.spark.annotation.InterfaceStability import org.apache.spark.internal.Logging import org.apache.spark.sql.api.java._ +import org.apache.spark.sql.catalyst.{JavaTypeInference, ScalaReflection} import org.apache.spark.sql.catalyst.analysis.FunctionRegistry -import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUDF} import org.apache.spark.sql.execution.aggregate.ScalaUDAF import org.apache.spark.sql.execution.python.UserDefinedPythonFunction import org.apache.spark.sql.expressions.{UserDefinedAggregateFunction, UserDefinedFunction} -import org.apache.spark.sql.types.DataType +import org.apache.spark.sql.types.{DataType, DataTypes} +import org.apache.spark.util.Utils /** * Functions for registering user-defined functions. Use [[SQLContext.udf]] to access this. @@ -413,6 +419,71 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends ////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////// + /** + * Register a Java UDF class using reflection, for use from pyspark + * + * @param name udf name + * @param className fully qualified class name of udf + * @param returnDataType return type of udf. If it is null, spark would try to infer + * via reflection. 
+ */ + private[sql] def registerJava(name: String, className: String, returnDataType: DataType): Unit = { + + try { + val clazz = Utils.classForName(className) + val udfInterfaces = clazz.getGenericInterfaces + .filter(_.isInstanceOf[ParameterizedType]) + .map(_.asInstanceOf[ParameterizedType]) + .filter(e => e.getRawType.isInstanceOf[Class[_]] && e.getRawType.asInstanceOf[Class[_]].getCanonicalName.startsWith("org.apache.spark.sql.api.java.UDF")) + if (udfInterfaces.length == 0) { + throw new IOException(s"UDF class ${className} doesn't implement any UDF interface") + } else if (udfInterfaces.length > 1) { + throw new IOException(s"It is invalid to implement multiple UDF interfaces, UDF class ${className}") + } else { + try { + val udf = clazz.newInstance() + val udfReturnType = udfInterfaces(0).getActualTypeArguments.last + var returnType = returnDataType + if (returnType == null) { + returnType = JavaTypeInference.inferDataType(TypeToken.of(udfReturnType))._1 + } + + udfInterfaces(0).getActualTypeArguments.length match { + case 2 => register(name, udf.asInstanceOf[UDF1[_, _]], returnType) + case 3 => register(name, udf.asInstanceOf[UDF2[_, _, _]], returnType) + case 4 => register(name, udf.asInstanceOf[UDF3[_, _, _, _]], returnType) + case 5 => register(name, udf.asInstanceOf[UDF4[_, _, _, _, _]], returnType) + case 6 => register(name, udf.asInstanceOf[UDF5[_, _, _, _, _, _]], returnType) + case 7 => register(name, udf.asInstanceOf[UDF6[_, _, _, _, _, _, _]], returnType) + case 8 => register(name, udf.asInstanceOf[UDF7[_, _, _, _, _, _, _, _]], returnType) + case 9 => register(name, udf.asInstanceOf[UDF8[_, _, _, _, _, _, _, _, _]], returnType) + case 10 => register(name, udf.asInstanceOf[UDF9[_, _, _, _, _, _, _, _, _, _]], returnType) + case 11 => register(name, udf.asInstanceOf[UDF10[_, _, _, _, _, _, _, _, _, _, _]], returnType) + case 12 => register(name, udf.asInstanceOf[UDF11[_, _, _, _, _, _, _, _, _, _, _, _]], returnType) + case 13 => register(name, udf.asInstanceOf[UDF12[_, _, _, _, _, _, _, _, _, _, _, _, _]], returnType) + case 14 => register(name, udf.asInstanceOf[UDF13[_, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType) + case 15 => register(name, udf.asInstanceOf[UDF14[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType) + case 16 => register(name, udf.asInstanceOf[UDF15[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType) + case 17 => register(name, udf.asInstanceOf[UDF16[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType) + case 18 => register(name, udf.asInstanceOf[UDF17[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType) + case 19 => register(name, udf.asInstanceOf[UDF18[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType) + case 20 => register(name, udf.asInstanceOf[UDF19[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType) + case 21 => register(name, udf.asInstanceOf[UDF20[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType) + case 22 => register(name, udf.asInstanceOf[UDF21[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType) + case 23 => register(name, udf.asInstanceOf[UDF22[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType) + case n => logError(s"UDF class with ${n} type arguments is not supported ") + } + } catch { + case e @ (_: InstantiationException | _: IllegalArgumentException) => + logError(s"Can not instantiate class ${className}, please make sure it has public non argument constructor") 
+ } + } + } catch { + case e: ClassNotFoundException => logError(s"Can not load class ${className}, please make sure it is on the classpath") + } + + } + /** * Register a user-defined function with 1 arguments. * @since 1.3.0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala index 18cba8ce28b4d..aecdda1c36498 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalog -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, InterfaceStability} import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset} import org.apache.spark.sql.types.StructType @@ -27,6 +27,7 @@ import org.apache.spark.sql.types.StructType * * @since 2.0.0 */ +@InterfaceStability.Stable abstract class Catalog { /** @@ -193,6 +194,7 @@ abstract class Catalog { * @since 2.0.0 */ @Experimental + @InterfaceStability.Evolving def createExternalTable(tableName: String, path: String): DataFrame /** @@ -203,6 +205,7 @@ abstract class Catalog { * @since 2.0.0 */ @Experimental + @InterfaceStability.Evolving def createExternalTable(tableName: String, path: String, source: String): DataFrame /** @@ -213,6 +216,7 @@ abstract class Catalog { * @since 2.0.0 */ @Experimental + @InterfaceStability.Evolving def createExternalTable( tableName: String, source: String, @@ -227,6 +231,7 @@ abstract class Catalog { * @since 2.0.0 */ @Experimental + @InterfaceStability.Evolving def createExternalTable( tableName: String, source: String, @@ -240,6 +245,7 @@ abstract class Catalog { * @since 2.0.0 */ @Experimental + @InterfaceStability.Evolving def createExternalTable( tableName: String, source: String, @@ -255,6 +261,7 @@ abstract class Catalog { * @since 2.0.0 */ @Experimental + @InterfaceStability.Evolving def createExternalTable( tableName: String, source: String, @@ -336,7 +343,8 @@ abstract class Catalog { /** * Invalidate and refresh all the cached data (and the associated metadata) for any dataframe that - * contains the given data source path. + * contains the given data source path. Path matching is by prefix, i.e. "/" would invalidate + * everything that is cached. * * @since 2.0.0 */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalog/interface.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalog/interface.scala index 33032f07f7bea..c0c5ebc2ba2d6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalog/interface.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalog/interface.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalog import javax.annotation.Nullable +import org.apache.spark.annotation.InterfaceStability import org.apache.spark.sql.catalyst.DefinedByConstructorParams @@ -33,6 +34,7 @@ import org.apache.spark.sql.catalyst.DefinedByConstructorParams * @param locationUri path (in the form of a uri) to data files. * @since 2.0.0 */ +@InterfaceStability.Stable class Database( val name: String, @Nullable val description: String, @@ -59,6 +61,7 @@ class Database( * @param isTemporary whether the table is a temporary table. * @since 2.0.0 */ +@InterfaceStability.Stable class Table( val name: String, @Nullable val database: String, @@ -90,6 +93,7 @@ class Table( * @param isBucket whether the column is a bucket column. 
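
Returning to the reflection-based `registerJava` helper above: it requires the named class to implement exactly one `UDFn` interface, picks the matching `register` overload from the number of type arguments, and, when `returnDataType` is null, infers the return type from the interface's last type argument. A hedged sketch of the kind of class it expects; the package and class name are made up for illustration:

    package com.example.udfs

    import org.apache.spark.sql.api.java.UDF1

    // Implements exactly one UDF interface, so the reflection lookup is unambiguous.
    // With a null returnDataType, the return type would be inferred from the second
    // type argument (java.lang.Integer here).
    class StringLength extends UDF1[String, Integer] {
      override def call(s: String): Integer = Integer.valueOf(s.length)
    }
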
* @since 2.0.0 */ +@InterfaceStability.Stable class Column( val name: String, @Nullable val description: String, @@ -122,6 +126,7 @@ class Column( * @param isTemporary whether the function is a temporary function or not. * @since 2.0.0 */ +@InterfaceStability.Stable class Function( val name: String, @Nullable val database: String, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala index 83b7c779ab818..526623a36d2a1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala @@ -177,7 +177,7 @@ class CacheManager extends Logging { /** * Traverses a given `plan` and searches for the occurrences of `qualifiedPath` in the - * [[org.apache.spark.sql.execution.datasources.FileCatalog]] of any [[HadoopFsRelation]] nodes + * [[org.apache.spark.sql.execution.datasources.FileIndex]] of any [[HadoopFsRelation]] nodes * in the plan. If found, we refresh the metadata and return true. Otherwise, this method returns * false. */ @@ -185,9 +185,10 @@ class CacheManager extends Logging { plan match { case lr: LogicalRelation => lr.relation match { case hr: HadoopFsRelation => - val invalidate = hr.location.paths - .map(_.makeQualified(fs.getUri, fs.getWorkingDirectory)) - .contains(qualifiedPath) + val prefixToInvalidate = qualifiedPath.toString + val invalidate = hr.location.rootPaths + .map(_.makeQualified(fs.getUri, fs.getWorkingDirectory).toString) + .exists(_.startsWith(prefixToInvalidate)) if (invalidate) hr.location.refresh() invalidate case _ => false diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index 9c7c0ffad0d60..5e1e4e395a49f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -225,13 +225,27 @@ case class FileSourceScanExec( } // These metadata values make scan plans uniquely identifiable for equality checking. 
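
The CacheManager change above pairs with the `refreshByPath` doc tweak earlier: invalidation now matches any cached HadoopFsRelation whose root path merely starts with the given prefix. A brief sketch, assuming an active SparkSession named `spark` and an illustrative path:

    // Refreshes the file listing of every cached relation rooted under this prefix.
    spark.catalog.refreshByPath("/warehouse/events")
    // Per the updated scaladoc, passing "/" would invalidate everything that is cached.
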
- override val metadata: Map[String, String] = Map( - "Format" -> relation.fileFormat.toString, - "ReadSchema" -> outputSchema.catalogString, - "Batched" -> supportsBatch.toString, - "PartitionFilters" -> partitionFilters.mkString("[", ", ", "]"), - "PushedFilters" -> dataFilters.mkString("[", ", ", "]"), - "InputPaths" -> relation.location.paths.mkString(", ")) + override val metadata: Map[String, String] = { + def seqToString(seq: Seq[Any]) = seq.mkString("[", ", ", "]") + val location = relation.location + val locationDesc = + location.getClass.getSimpleName + seqToString(location.rootPaths) + val metadata = + Map( + "Format" -> relation.fileFormat.toString, + "ReadSchema" -> outputSchema.catalogString, + "Batched" -> supportsBatch.toString, + "PartitionFilters" -> seqToString(partitionFilters), + "PushedFilters" -> seqToString(dataFilters), + "Location" -> locationDesc) + val withOptPartitionCount = + relation.partitionSchemaOption.map { _ => + metadata + ("PartitionCount" -> selectedPartitions.size.toString) + } getOrElse { + metadata + } + withOptPartitionCount + } private lazy val inputRDD: RDD[InternalRow] = { val originalPartitions = relation.location.listFiles(partitionFilters) @@ -435,7 +449,7 @@ case class FileSourceScanExec( private def createBucketedReadRDD( bucketSpec: BucketSpec, readFile: (PartitionedFile) => Iterator[InternalRow], - selectedPartitions: Seq[Partition], + selectedPartitions: Seq[PartitionDirectory], fsRelation: HadoopFsRelation): RDD[InternalRow] = { logInfo(s"Planning with ${bucketSpec.numBuckets} buckets") val bucketed = @@ -467,7 +481,7 @@ case class FileSourceScanExec( */ private def createNonBucketedReadRDD( readFile: (PartitionedFile) => Iterator[InternalRow], - selectedPartitions: Seq[Partition], + selectedPartitions: Seq[PartitionDirectory], fsRelation: HadoopFsRelation): RDD[InternalRow] = { val defaultMaxSplitBytes = fsRelation.sparkSession.sessionState.conf.filesMaxPartitionBytes diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala index 39189a2b0c72c..2663129562660 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.execution.metric.SQLMetrics /** @@ -60,6 +61,8 @@ case class GenerateExec( override def producedAttributes: AttributeSet = AttributeSet(output) + override def outputPartitioning: Partitioning = child.outputPartitioning + val boundGenerator = BindReferences.bindReference(generator, child.output) protected override def doExecute(): RDD[InternalRow] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala index 6598fa381aa3d..e366b9af35c62 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala @@ -64,10 +64,13 @@ case class LocalTableScanExec( } override def executeCollect(): Array[InternalRow] = { + longMetric("numOutputRows").add(unsafeRows.size) unsafeRows } override def executeTake(limit: 
Int): Array[InternalRow] = { - unsafeRows.take(limit) + val taken = unsafeRows.take(limit) + longMetric("numOutputRows").add(taken.size) + taken } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala index d8e0675e3eb65..cc576bbc4c802 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala @@ -23,7 +23,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} -import org.apache.spark.sql.catalyst.plans.physical.{Distribution, OrderedDistribution, UnspecifiedDistribution} +import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.metric.SQLMetrics /** @@ -45,6 +45,10 @@ case class SortExec( override def outputOrdering: Seq[SortOrder] = sortOrder + // sort performed is local within a given partition so will retain + // child operator's partitioning + override def outputPartitioning: Partitioning = child.outputPartitioning + override def requiredChildDistribution: Seq[Distribution] = if (global) OrderedDistribution(sortOrder) :: Nil else UnspecifiedDistribution :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala index 8b762b5d6c5f2..981728331d361 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.ExperimentalMethods import org.apache.spark.sql.catalyst.catalog.SessionCatalog import org.apache.spark.sql.catalyst.optimizer.Optimizer +import org.apache.spark.sql.execution.datasources.PruneFileSourcePartitions import org.apache.spark.sql.execution.python.ExtractPythonUDFFromAggregate import org.apache.spark.sql.internal.SQLConf @@ -32,5 +33,6 @@ class SparkOptimizer( override def batches: Seq[Batch] = super.batches :+ Batch("Optimize Metadata Only Query", Once, OptimizeMetadataOnlyQuery(catalog, conf)) :+ Batch("Extract Python UDF from Aggregate", Once, ExtractPythonUDFFromAggregate) :+ + Batch("Prune File Source Table Partitions", Once, PruneFileSourcePartitions) :+ Batch("User Provided Optimizers", fixedPoint, experimentalMethods.extraOptimizations: _*) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 48d6ef6dcd44a..24d0cffef82a2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -395,8 +395,6 @@ trait UnaryExecNode extends SparkPlan { def child: SparkPlan override final def children: Seq[SparkPlan] = child :: Nil - - override def outputPartitioning: Partitioning = child.outputPartitioning } trait BinaryExecNode extends SparkPlan { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index be2eddbb0e423..fe183d0097d03 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -98,9 +98,13 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { * }}} */ override def visitAnalyze(ctx: AnalyzeContext): LogicalPlan = withOrigin(ctx) { - if (ctx.partitionSpec == null && - ctx.identifier != null && - ctx.identifier.getText.toLowerCase == "noscan") { + if (ctx.partitionSpec != null) { + logWarning(s"Partition specification is ignored: ${ctx.partitionSpec.getText}") + } + if (ctx.identifier != null) { + if (ctx.identifier.getText.toLowerCase != "noscan") { + throw new ParseException(s"Expected `NOSCAN` instead of `${ctx.identifier.getText}`", ctx) + } AnalyzeTableCommand(visitTableIdentifier(ctx.tableIdentifier)) } else if (ctx.identifierSeq() == null) { AnalyzeTableCommand(visitTableIdentifier(ctx.tableIdentifier), noscan = false) @@ -168,17 +172,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { * }}} */ override def visitShowColumns(ctx: ShowColumnsContext): LogicalPlan = withOrigin(ctx) { - val table = visitTableIdentifier(ctx.tableIdentifier) - - val lookupTable = Option(ctx.db) match { - case None => table - case Some(db) if table.database.exists(_ != db) => - operationNotAllowed( - s"SHOW COLUMNS with conflicting databases: '$db' != '${table.database.get}'", - ctx) - case Some(db) => TableIdentifier(table.identifier, Some(db.getText)) - } - ShowColumnsCommand(lookupTable) + ShowColumnsCommand(Option(ctx.db).map(_.getText), visitTableIdentifier(ctx.tableIdentifier)) } /** @@ -689,15 +683,9 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { * }}} */ override def visitRenameTable(ctx: RenameTableContext): LogicalPlan = withOrigin(ctx) { - val fromName = visitTableIdentifier(ctx.from) - val toName = visitTableIdentifier(ctx.to) - if (toName.database.isDefined) { - operationNotAllowed("Can not specify database in table/view name after RENAME TO", ctx) - } - AlterTableRenameCommand( - fromName, - toName.table, + visitTableIdentifier(ctx.from), + visitTableIdentifier(ctx.to), ctx.VIEW != null) } @@ -1010,9 +998,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { .orElse(Some("org.apache.hadoop.mapred.TextInputFormat")), outputFormat = defaultHiveSerde.flatMap(_.outputFormat) .orElse(Some("org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat")), - // Note: Keep this unspecified because we use the presence of the serde to decide - // whether to convert a table created by CTAS to a datasource table. 
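
For the `visitAnalyze` change above, the statements the parser accepts after this patch boil down to the following; the table and column names are illustrative:

    spark.sql("ANALYZE TABLE events COMPUTE STATISTICS")                   // full scan: size and row count
    spark.sql("ANALYZE TABLE events COMPUTE STATISTICS NOSCAN")            // size only; any other trailing identifier is now a ParseException
    spark.sql("ANALYZE TABLE events COMPUTE STATISTICS FOR COLUMNS a, b")  // column-level statistics
    // A PARTITION(...) spec is ignored, now with an explicit warning rather than silently.
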
- serde = None, + serde = defaultHiveSerde.flatMap(_.serde), compressed = false, properties = Map()) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index 62bf6f4a81eec..6303483f22fd3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -218,7 +218,9 @@ trait CodegenSupport extends SparkPlan { case class InputAdapter(child: SparkPlan) extends UnaryExecNode with CodegenSupport { override def output: Seq[Attribute] = child.output + override def outputPartitioning: Partitioning = child.outputPartitioning + override def outputOrdering: Seq[SortOrder] = child.outputOrdering override def doExecute(): RDD[InternalRow] = { @@ -292,7 +294,9 @@ object WholeStageCodegenExec { case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with CodegenSupport { override def output: Seq[Attribute] = child.output + override def outputPartitioning: Partitioning = child.outputPartitioning + override def outputOrdering: Seq[SortOrder] = child.outputOrdering override lazy val metrics = Map( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 06199ef3e8243..4529ed067e565 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -63,6 +63,8 @@ case class HashAggregateExec( override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute) + override def outputPartitioning: Partitioning = child.outputPartitioning + override def producedAttributes: AttributeSet = AttributeSet(aggregateAttributes) ++ AttributeSet(resultExpressions.diff(groupingExpressions).map(_.toAttribute)) ++ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala index 2a81a823c44b3..be3198b8e7d82 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala @@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, UnspecifiedDistribution} +import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.util.Utils @@ -66,6 +66,8 @@ case class SortAggregateExec( groupingExpressions.map(SortOrder(_, Ascending)) :: Nil } + override def outputPartitioning: Partitioning = child.outputPartitioning + override def outputOrdering: Seq[SortOrder] = { groupingExpressions.map(SortOrder(_, Ascending)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index dd78a784915d2..a5291e0c12f88 100644 
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -78,6 +78,8 @@ case class ProjectExec(projectList: Seq[NamedExpression], child: SparkPlan) } override def outputOrdering: Seq[SortOrder] = child.outputOrdering + + override def outputPartitioning: Partitioning = child.outputPartitioning } @@ -214,6 +216,8 @@ case class FilterExec(condition: Expression, child: SparkPlan) } override def outputOrdering: Seq[SortOrder] = child.outputOrdering + + override def outputPartitioning: Partitioning = child.outputPartitioning } /** @@ -234,6 +238,8 @@ case class SampleExec( child: SparkPlan) extends UnaryExecNode with CodegenSupport { override def output: Seq[Attribute] = child.output + override def outputPartitioning: Partitioning = child.outputPartitioning + override lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) @@ -517,7 +523,9 @@ case class SubqueryExec(name: String, child: SparkPlan) extends UnaryExecNode { "collectTime" -> SQLMetrics.createMetric(sparkContext, "time to collect (ms)")) override def output: Seq[Attribute] = child.output + override def outputPartitioning: Partitioning = child.outputPartitioning + override def outputOrdering: Seq[SortOrder] = child.outputOrdering override def sameResult(o: SparkPlan): Boolean = o match { @@ -562,7 +570,7 @@ case class SubqueryExec(name: String, child: SparkPlan) extends UnaryExecNode { } override def executeCollect(): Array[InternalRow] = { - ThreadUtils.awaitResult(relationFuture, Duration.Inf) + ThreadUtils.awaitResultInForkJoinSafely(relationFuture, Duration.Inf) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala index 7066378279971..f873f34a845ef 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala @@ -50,7 +50,8 @@ case class AnalyzeColumnCommand( AnalyzeTableCommand.calculateTotalSize(sessionState, catalogRel.catalogTable)) case logicalRel: LogicalRelation if logicalRel.catalogTable.isDefined => - updateStats(logicalRel.catalogTable.get, logicalRel.relation.sizeInBytes) + updateStats(logicalRel.catalogTable.get, + AnalyzeTableCommand.calculateTotalSize(sessionState, logicalRel.catalogTable.get)) case otherRelation => throw new AnalysisException("ANALYZE TABLE is not supported for " + @@ -59,10 +60,12 @@ case class AnalyzeColumnCommand( def updateStats(catalogTable: CatalogTable, newTotalSize: Long): Unit = { val (rowCount, columnStats) = computeColStats(sparkSession, relation) + // We also update table-level stats in order to keep them consistent with column-level stats. val statistics = Statistics( sizeInBytes = newTotalSize, rowCount = Some(rowCount), - colStats = columnStats ++ catalogTable.stats.map(_.colStats).getOrElse(Map())) + // Newly computed column stats should override the existing ones. + colStats = catalogTable.stats.map(_.colStats).getOrElse(Map()) ++ columnStats) sessionState.catalog.alterTable(catalogTable.copy(stats = Some(statistics))) // Refresh the cached data source table in the catalog. 
sessionState.catalog.refreshTable(tableIdentWithDB) @@ -90,8 +93,9 @@ case class AnalyzeColumnCommand( } } if (duplicatedColumns.nonEmpty) { - logWarning(s"Duplicated columns ${duplicatedColumns.mkString("(", ", ", ")")} detected " + - s"when analyzing columns ${columnNames.mkString("(", ", ", ")")}, ignoring them.") + logWarning("Duplicate column names were deduplicated in `ANALYZE TABLE` statement. " + + s"Input columns: ${columnNames.mkString("(", ", ", ")")}. " + + s"Duplicate columns: ${duplicatedColumns.mkString("(", ", ", ")")}.") } // Collect statistics per column. @@ -116,22 +120,24 @@ case class AnalyzeColumnCommand( } object ColumnStatStruct { - val zero = Literal(0, LongType) - val one = Literal(1, LongType) + private val zero = Literal(0, LongType) + private val one = Literal(1, LongType) - def numNulls(e: Expression): Expression = if (e.nullable) Sum(If(IsNull(e), one, zero)) else zero - def max(e: Expression): Expression = Max(e) - def min(e: Expression): Expression = Min(e) - def ndv(e: Expression, relativeSD: Double): Expression = { + private def numNulls(e: Expression): Expression = { + if (e.nullable) Sum(If(IsNull(e), one, zero)) else zero + } + private def max(e: Expression): Expression = Max(e) + private def min(e: Expression): Expression = Min(e) + private def ndv(e: Expression, relativeSD: Double): Expression = { // the approximate ndv should never be larger than the number of rows Least(Seq(HyperLogLogPlusPlus(e, relativeSD), Count(one))) } - def avgLength(e: Expression): Expression = Average(Length(e)) - def maxLength(e: Expression): Expression = Max(Length(e)) - def numTrues(e: Expression): Expression = Sum(If(e, one, zero)) - def numFalses(e: Expression): Expression = Sum(If(Not(e), one, zero)) + private def avgLength(e: Expression): Expression = Average(Length(e)) + private def maxLength(e: Expression): Expression = Max(Length(e)) + private def numTrues(e: Expression): Expression = Sum(If(e, one, zero)) + private def numFalses(e: Expression): Expression = Sum(If(Not(e), one, zero)) - def getStruct(exprs: Seq[Expression]): CreateStruct = { + private def getStruct(exprs: Seq[Expression]): CreateStruct = { CreateStruct(exprs.map { expr: Expression => expr.transformUp { case af: AggregateFunction => af.toAggregateExpression() @@ -139,19 +145,19 @@ object ColumnStatStruct { }) } - def numericColumnStat(e: Expression, relativeSD: Double): Seq[Expression] = { + private def numericColumnStat(e: Expression, relativeSD: Double): Seq[Expression] = { Seq(numNulls(e), max(e), min(e), ndv(e, relativeSD)) } - def stringColumnStat(e: Expression, relativeSD: Double): Seq[Expression] = { + private def stringColumnStat(e: Expression, relativeSD: Double): Seq[Expression] = { Seq(numNulls(e), avgLength(e), maxLength(e), ndv(e, relativeSD)) } - def binaryColumnStat(e: Expression): Seq[Expression] = { + private def binaryColumnStat(e: Expression): Seq[Expression] = { Seq(numNulls(e), avgLength(e), maxLength(e)) } - def booleanColumnStat(e: Expression): Seq[Expression] = { + private def booleanColumnStat(e: Expression): Seq[Expression] = { Seq(numNulls(e), numTrues(e), numFalses(e)) } @@ -162,14 +168,14 @@ object ColumnStatStruct { } } - def apply(e: Attribute, relativeSD: Double): CreateStruct = e.dataType match { + def apply(attr: Attribute, relativeSD: Double): CreateStruct = attr.dataType match { // Use aggregate functions to compute statistics we need. 
- case _: NumericType | TimestampType | DateType => getStruct(numericColumnStat(e, relativeSD)) - case StringType => getStruct(stringColumnStat(e, relativeSD)) - case BinaryType => getStruct(binaryColumnStat(e)) - case BooleanType => getStruct(booleanColumnStat(e)) + case _: NumericType | TimestampType | DateType => getStruct(numericColumnStat(attr, relativeSD)) + case StringType => getStruct(stringColumnStat(attr, relativeSD)) + case BinaryType => getStruct(binaryColumnStat(attr)) + case BooleanType => getStruct(booleanColumnStat(attr)) case otherType => throw new AnalysisException("Analyzing columns is not supported for column " + - s"${e.name} of data type: ${e.dataType}.") + s"${attr.name} of data type: ${attr.dataType}.") } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala index 7b0e49b665f42..52a8fc88c56cd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala @@ -51,7 +51,8 @@ case class AnalyzeTableCommand( // data source tables have been converted into LogicalRelations case logicalRel: LogicalRelation if logicalRel.catalogTable.isDefined => - updateTableStats(logicalRel.catalogTable.get, logicalRel.relation.sizeInBytes) + updateTableStats(logicalRel.catalogTable.get, + AnalyzeTableCommand.calculateTotalSize(sessionState, logicalRel.catalogTable.get)) case otherRelation => throw new AnalysisException("ANALYZE TABLE is not supported for " + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala index 698c625d617fc..d82e54e57564c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala @@ -66,6 +66,8 @@ case class ExecutedCommandExec(cmd: RunnableCommand) extends SparkPlan { override def executeCollect(): Array[InternalRow] = sideEffectResult.toArray + override def executeToIterator: Iterator[InternalRow] = sideEffectResult.toIterator + override def executeTake(limit: Int): Array[InternalRow] = sideEffectResult.take(limit).toArray protected override def doExecute(): RDD[InternalRow] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala index a04a13e698c43..2a9743130d4c4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala @@ -67,7 +67,7 @@ case class CreateDataSourceTableCommand(table: CatalogTable, ignoreIfExists: Boo dataSource match { case fs: HadoopFsRelation => - if (table.tableType == CatalogTableType.EXTERNAL && fs.location.paths.isEmpty) { + if (table.tableType == CatalogTableType.EXTERNAL && fs.location.rootPaths.isEmpty) { throw new AnalysisException( "Cannot create a file-based external data source table without path") } @@ -94,10 +94,16 @@ case class CreateDataSourceTableCommand(table: CatalogTable, ignoreIfExists: Boo val newTable = table.copy( storage = table.storage.copy(properties = optionsWithPath), schema = dataSource.schema, - partitionColumnNames = 
partitionColumnNames) + partitionColumnNames = partitionColumnNames, + // If metastore partition management for file source tables is enabled, we start off with + // partition provider hive, but no partitions in the metastore. The user has to call + // `msck repair table` to populate the table partitions. + partitionProviderIsHive = partitionColumnNames.nonEmpty && + sparkSession.sessionState.conf.manageFilesourcePartitions) // We will return Nil or throw exception at the beginning if the table already exists, so when // we reach here, the table should not exist and we should set `ignoreIfExists` to false. sessionState.catalog.createTable(newTable, ignoreIfExists = false) + Seq.empty[Row] } } @@ -232,6 +238,15 @@ case class CreateDataSourceTableAsSelectCommand( sessionState.catalog.createTable(newTable, ignoreIfExists = false) } + result match { + case fs: HadoopFsRelation if table.partitionColumnNames.nonEmpty && + sparkSession.sqlContext.conf.manageFilesourcePartitions => + // Need to recover partitions into the metastore so our saved data is visible. + sparkSession.sessionState.executePlan( + AlterTableRecoverPartitionsCommand(table.identifier)).toRdd + case _ => + } + // Refresh the cache of the table in the catalog. sessionState.catalog.refreshTable(tableIdentWithDB) Seq.empty[Row] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala index 45fa293e58951..61e0550cef5e3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala @@ -28,10 +28,11 @@ import org.apache.hadoop.mapred.{FileInputFormat, JobConf} import org.apache.spark.sql.{AnalysisException, Row, SparkSession} import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.analysis.Resolver import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, CatalogTable, CatalogTablePartition, CatalogTableType, SessionCatalog} import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} -import org.apache.spark.sql.execution.datasources.PartitioningUtils +import org.apache.spark.sql.execution.datasources.{CaseInsensitiveMap, PartitioningUtils} import org.apache.spark.sql.types._ import org.apache.spark.util.SerializableConfiguration @@ -346,13 +347,15 @@ case class AlterTableAddPartitionCommand( val catalog = sparkSession.sessionState.catalog val table = catalog.getTableMetadata(tableName) DDLUtils.verifyAlterTableType(catalog, table, isView = false) - if (DDLUtils.isDatasourceTable(table)) { - throw new AnalysisException( - "ALTER TABLE ADD PARTITION is not allowed for tables defined using the datasource API") - } + DDLUtils.verifyPartitionProviderIsHive(sparkSession, table, "ALTER TABLE ADD PARTITION") val parts = partitionSpecsAndLocs.map { case (spec, location) => + val normalizedSpec = PartitioningUtils.normalizePartitionSpec( + spec, + table.partitionColumnNames, + table.identifier.quotedString, + sparkSession.sessionState.conf.resolver) // inherit table storage format (possibly except for location) - CatalogTablePartition(spec, table.storage.copy(locationUri = location)) + CatalogTablePartition(normalizedSpec, table.storage.copy(locationUri = location)) } catalog.createPartitions(table.identifier, parts, ignoreIfExists = ifNotExists) Seq.empty[Row] @@ -377,13 +380,23 @@ case class 
AlterTableRenamePartitionCommand( override def run(sparkSession: SparkSession): Seq[Row] = { val catalog = sparkSession.sessionState.catalog val table = catalog.getTableMetadata(tableName) - if (DDLUtils.isDatasourceTable(table)) { - throw new AnalysisException( - "ALTER TABLE RENAME PARTITION is not allowed for tables defined using the datasource API") - } DDLUtils.verifyAlterTableType(catalog, table, isView = false) + DDLUtils.verifyPartitionProviderIsHive(sparkSession, table, "ALTER TABLE RENAME PARTITION") + + val normalizedOldPartition = PartitioningUtils.normalizePartitionSpec( + oldPartition, + table.partitionColumnNames, + table.identifier.quotedString, + sparkSession.sessionState.conf.resolver) + + val normalizedNewPartition = PartitioningUtils.normalizePartitionSpec( + newPartition, + table.partitionColumnNames, + table.identifier.quotedString, + sparkSession.sessionState.conf.resolver) + catalog.renamePartitions( - tableName, Seq(oldPartition), Seq(newPartition)) + tableName, Seq(normalizedOldPartition), Seq(normalizedNewPartition)) Seq.empty[Row] } @@ -414,11 +427,18 @@ case class AlterTableDropPartitionCommand( val catalog = sparkSession.sessionState.catalog val table = catalog.getTableMetadata(tableName) DDLUtils.verifyAlterTableType(catalog, table, isView = false) - if (DDLUtils.isDatasourceTable(table)) { - throw new AnalysisException( - "ALTER TABLE DROP PARTITIONS is not allowed for tables defined using the datasource API") + DDLUtils.verifyPartitionProviderIsHive(sparkSession, table, "ALTER TABLE DROP PARTITION") + + val normalizedSpecs = specs.map { spec => + PartitioningUtils.normalizePartitionSpec( + spec, + table.partitionColumnNames, + table.identifier.quotedString, + sparkSession.sessionState.conf.resolver) } - catalog.dropPartitions(table.identifier, specs, ignoreIfNotExists = ifExists, purge = purge) + + catalog.dropPartitions( + table.identifier, normalizedSpecs, ignoreIfNotExists = ifExists, purge = purge) Seq.empty[Row] } @@ -465,33 +485,39 @@ case class AlterTableRecoverPartitionsCommand( } } + private def getBasePath(table: CatalogTable): Option[String] = { + if (table.provider == Some("hive")) { + table.storage.locationUri + } else { + new CaseInsensitiveMap(table.storage.properties).get("path") + } + } + override def run(spark: SparkSession): Seq[Row] = { val catalog = spark.sessionState.catalog val table = catalog.getTableMetadata(tableName) val tableIdentWithDB = table.identifier.quotedString DDLUtils.verifyAlterTableType(catalog, table, isView = false) - if (DDLUtils.isDatasourceTable(table)) { - throw new AnalysisException( - s"Operation not allowed: $cmd on datasource tables: $tableIdentWithDB") - } if (table.partitionColumnNames.isEmpty) { throw new AnalysisException( s"Operation not allowed: $cmd only works on partitioned tables: $tableIdentWithDB") } - if (table.storage.locationUri.isEmpty) { + + val tablePath = getBasePath(table) + if (tablePath.isEmpty) { throw new AnalysisException(s"Operation not allowed: $cmd only works on table with " + s"location provided: $tableIdentWithDB") } - val root = new Path(table.storage.locationUri.get) + val root = new Path(tablePath.get) logInfo(s"Recover all the partitions in $root") val fs = root.getFileSystem(spark.sparkContext.hadoopConfiguration) val threshold = spark.conf.get("spark.rdd.parallelListingThreshold", "10").toInt val hadoopConf = spark.sparkContext.hadoopConfiguration val pathFilter = getPathFilter(hadoopConf) - val partitionSpecsAndLocs = scanPartitions( - spark, fs, pathFilter, root, Map(), 
table.partitionColumnNames.map(_.toLowerCase), threshold) + val partitionSpecsAndLocs = scanPartitions(spark, fs, pathFilter, root, Map(), + table.partitionColumnNames, threshold, spark.sessionState.conf.resolver) val total = partitionSpecsAndLocs.length logInfo(s"Found $total partitions in $root") @@ -503,6 +529,11 @@ case class AlterTableRecoverPartitionsCommand( logInfo(s"Finished to gather the fast stats for all $total partitions.") addPartitions(spark, table, partitionSpecsAndLocs, partitionStats) + // Updates the table to indicate that its partition metadata is stored in the Hive metastore. + // This is always the case for Hive format tables, but is not true for Datasource tables created + // before Spark 2.1 unless they are converted via `msck repair table`. + spark.sessionState.catalog.alterTable(table.copy(partitionProviderIsHive = true)) + catalog.refreshTable(tableName) logInfo(s"Recovered all partitions ($total).") Seq.empty[Row] } @@ -516,7 +547,8 @@ case class AlterTableRecoverPartitionsCommand( path: Path, spec: TablePartitionSpec, partitionNames: Seq[String], - threshold: Int): GenSeq[(TablePartitionSpec, Path)] = { + threshold: Int, + resolver: Resolver): GenSeq[(TablePartitionSpec, Path)] = { if (partitionNames.isEmpty) { return Seq(spec -> path) } @@ -535,15 +567,15 @@ case class AlterTableRecoverPartitionsCommand( val name = st.getPath.getName if (st.isDirectory && name.contains("=")) { val ps = name.split("=", 2) - val columnName = PartitioningUtils.unescapePathName(ps(0)).toLowerCase + val columnName = PartitioningUtils.unescapePathName(ps(0)) // TODO: Validate the value val value = PartitioningUtils.unescapePathName(ps(1)) - // comparing with case-insensitive, but preserve the case - if (columnName == partitionNames.head) { - scanPartitions(spark, fs, filter, st.getPath, spec ++ Map(columnName -> value), - partitionNames.drop(1), threshold) + if (resolver(columnName, partitionNames.head)) { + scanPartitions(spark, fs, filter, st.getPath, spec ++ Map(partitionNames.head -> value), + partitionNames.drop(1), threshold, resolver) } else { - logWarning(s"expect partition column ${partitionNames.head}, but got ${ps(0)}, ignore it") + logWarning( + s"expected partition column ${partitionNames.head}, but got ${ps(0)}, ignoring it") Seq() } } else { @@ -648,16 +680,11 @@ case class AlterTableSetLocationCommand( DDLUtils.verifyAlterTableType(catalog, table, isView = false) partitionSpec match { case Some(spec) => + DDLUtils.verifyPartitionProviderIsHive( + sparkSession, table, "ALTER TABLE ... SET LOCATION") // Partition spec is specified, so we set the location only for this partition val part = catalog.getPartition(table.identifier, spec) - val newPart = - if (DDLUtils.isDatasourceTable(table)) { - throw new AnalysisException( - "ALTER TABLE SET LOCATION for partition is not allowed for tables defined " + - "using the datasource API") - } else { - part.copy(storage = part.storage.copy(locationUri = Some(location))) - } + val newPart = part.copy(storage = part.storage.copy(locationUri = Some(location))) catalog.alterPartitions(table.identifier, Seq(newPart)) case None => // No partition spec is specified, so we set the location for the table itself @@ -681,6 +708,25 @@ object DDLUtils { table.provider.isDefined && table.provider.get != "hive" } + /** + * Throws a standard error for actions that require partitionProvider = hive. 
+ */ + def verifyPartitionProviderIsHive( + spark: SparkSession, table: CatalogTable, action: String): Unit = { + val tableName = table.identifier.table + if (!spark.sqlContext.conf.manageFilesourcePartitions && isDatasourceTable(table)) { + throw new AnalysisException( + s"$action is not allowed on $tableName since filesource partition management is " + + "disabled (spark.sql.hive.manageFilesourcePartitions = false).") + } + if (!table.partitionProviderIsHive && isDatasourceTable(table)) { + throw new AnalysisException( + s"$action is not allowed on $tableName since its partition metadata is not stored in " + + "the Hive metastore. To import this information into the metastore, run " + + s"`msck repair table $tableName`") + } + } + /** * If the command ALTER VIEW is to alter a table or ALTER TABLE is to alter a view, * issue an exception [[AnalysisException]]. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala index 424ef58d76c5e..4acfffb628047 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.command import java.io.File import java.net.URI +import java.nio.file.FileSystems import java.util.Date import scala.collection.mutable.ArrayBuffer @@ -146,7 +147,7 @@ case class CreateTableCommand(table: CatalogTable, ifNotExists: Boolean) extends */ case class AlterTableRenameCommand( oldName: TableIdentifier, - newName: String, + newName: TableIdentifier, isView: Boolean) extends RunnableCommand { @@ -159,7 +160,6 @@ case class AlterTableRenameCommand( } else { val table = catalog.getTableMetadata(oldName) DDLUtils.verifyAlterTableType(catalog, table, isView) - val newTblName = TableIdentifier(newName, oldName.database) // If an exception is thrown here we can just assume the table is uncached; // this can happen with Hive tables when the underlying catalog is in-memory. 
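
(A small sketch of the `Try(...).getOrElse(false)` fallback used just below, with a hypothetical lookup: any failure while checking the cache simply degrades to "not cached" instead of failing the rename.)

    import scala.util.Try
    // The lookup is a by-name argument, so it only runs inside Try and any
    // exception it throws is converted into the default value `false`.
    def safeIsCached(lookup: => Boolean): Boolean = Try(lookup).getOrElse(false)
    assert(!safeIsCached(throw new RuntimeException("table not found")))
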
val wasCached = Try(sparkSession.catalog.isCached(oldName.unquotedString)).getOrElse(false) @@ -172,7 +172,7 @@ case class AlterTableRenameCommand( } // For datasource tables, we also need to update the "path" serde property if (DDLUtils.isDatasourceTable(table) && table.tableType == CatalogTableType.MANAGED) { - val newPath = catalog.defaultTablePath(newTblName) + val newPath = catalog.defaultTablePath(newName) val newTable = table.withNewStorage( properties = table.storage.properties ++ Map("path" -> newPath)) catalog.alterTable(newTable) @@ -182,7 +182,7 @@ case class AlterTableRenameCommand( catalog.refreshTable(oldName) catalog.renameTable(oldName, newName) if (wasCached) { - sparkSession.catalog.cacheTable(newTblName.unquotedString) + sparkSession.catalog.cacheTable(newName.unquotedString) } } Seq.empty[Row] @@ -246,7 +246,27 @@ case class LoadDataCommand( val loadPath = if (isLocal) { val uri = Utils.resolveURI(path) - if (!new File(uri.getPath()).exists()) { + val filePath = uri.getPath() + val exists = if (filePath.contains("*")) { + val fileSystem = FileSystems.getDefault + val pathPattern = fileSystem.getPath(filePath) + val dir = pathPattern.getParent.toString + if (dir.contains("*")) { + throw new AnalysisException( + s"LOAD DATA input path allows only filename wildcard: $path") + } + + val files = new File(dir).listFiles() + if (files == null) { + false + } else { + val matcher = fileSystem.getPathMatcher("glob:" + pathPattern.toAbsolutePath) + files.exists(f => matcher.matches(fileSystem.getPath(f.getAbsolutePath))) + } + } else { + new File(filePath).exists() + } + if (!exists) { throw new AnalysisException(s"LOAD DATA input path does not exist: $path") } uri @@ -338,19 +358,16 @@ case class TruncateTableCommand( throw new AnalysisException( s"Operation not allowed: TRUNCATE TABLE on views: $tableIdentwithDB") } - val isDatasourceTable = DDLUtils.isDatasourceTable(table) - if (isDatasourceTable && partitionSpec.isDefined) { - throw new AnalysisException( - s"Operation not allowed: TRUNCATE TABLE ... PARTITION is not supported " + - s"for tables created using the data sources API: $tableIdentwithDB") - } if (table.partitionColumnNames.isEmpty && partitionSpec.isDefined) { throw new AnalysisException( s"Operation not allowed: TRUNCATE TABLE ... PARTITION is not supported " + s"for tables that are not partitioned: $tableIdentwithDB") } + if (partitionSpec.isDefined) { + DDLUtils.verifyPartitionProviderIsHive(spark, table, "TRUNCATE TABLE ... 
PARTITION") + } val locations = - if (isDatasourceTable) { + if (DDLUtils.isDatasourceTable(table)) { Seq(table.storage.properties.get("path")) } else if (table.partitionColumnNames.isEmpty) { Seq(table.storage.locationUri) @@ -433,7 +450,7 @@ case class DescribeTableCommand( describeFormattedTableInfo(metadata, result) } } else { - describeDetailedPartitionInfo(catalog, metadata, result) + describeDetailedPartitionInfo(sparkSession, catalog, metadata, result) } } @@ -472,6 +489,10 @@ case class DescribeTableCommand( describeStorageInfo(table, buffer) if (table.tableType == CatalogTableType.VIEW) describeViewInfo(table, buffer) + + if (DDLUtils.isDatasourceTable(table) && table.partitionProviderIsHive) { + append(buffer, "Partition Provider:", "Hive", "") + } } private def describeStorageInfo(metadata: CatalogTable, buffer: ArrayBuffer[Row]): Unit = { @@ -508,6 +529,7 @@ case class DescribeTableCommand( } private def describeDetailedPartitionInfo( + spark: SparkSession, catalog: SessionCatalog, metadata: CatalogTable, result: ArrayBuffer[Row]): Unit = { @@ -515,10 +537,7 @@ case class DescribeTableCommand( throw new AnalysisException( s"DESC PARTITION is not allowed on a view: ${table.identifier}") } - if (DDLUtils.isDatasourceTable(metadata)) { - throw new AnalysisException( - s"DESC PARTITION is not allowed on a datasource table: ${table.identifier}") - } + DDLUtils.verifyPartitionProviderIsHive(spark, metadata, "DESC PARTITION") val partition = catalog.getPartition(table, partitionSpec) if (isExtended) { describeExtendedDetailedPartitionInfo(table, metadata, partition, result) @@ -651,14 +670,24 @@ case class ShowTablePropertiesCommand(table: TableIdentifier, propertyKey: Optio * SHOW COLUMNS (FROM | IN) table_identifier [(FROM | IN) database]; * }}} */ -case class ShowColumnsCommand(tableName: TableIdentifier) extends RunnableCommand { +case class ShowColumnsCommand( + databaseName: Option[String], + tableName: TableIdentifier) extends RunnableCommand { override val output: Seq[Attribute] = { AttributeReference("col_name", StringType, nullable = false)() :: Nil } override def run(sparkSession: SparkSession): Seq[Row] = { val catalog = sparkSession.sessionState.catalog - val table = catalog.getTempViewOrPermanentTableMetadata(tableName) + val resolver = sparkSession.sessionState.conf.resolver + val lookupTable = databaseName match { + case None => tableName + case Some(db) if tableName.database.exists(!resolver(_, db)) => + throw new AnalysisException( + s"SHOW COLUMNS with conflicting databases: '$db' != '${tableName.database.get}'") + case Some(db) => TableIdentifier(tableName.identifier, Some(db)) + } + val table = catalog.getTempViewOrPermanentTableMetadata(lookupTable) table.schema.map { c => Row(c.name) } @@ -713,10 +742,7 @@ case class ShowPartitionsCommand( s"SHOW PARTITIONS is not allowed on a table that is not partitioned: $tableIdentWithDB") } - if (DDLUtils.isDatasourceTable(table)) { - throw new AnalysisException( - s"SHOW PARTITIONS is not allowed on a datasource table: $tableIdentWithDB") - } + DDLUtils.verifyPartitionProviderIsHive(sparkSession, table, "SHOW PARTITIONS") /** * Validate the partitioning spec by making sure all the referenced columns are @@ -864,18 +890,11 @@ case class ShowCreateTableCommand(table: TableIdentifier) extends RunnableComman private def showHiveTableProperties(metadata: CatalogTable, builder: StringBuilder): Unit = { if (metadata.properties.nonEmpty) { - val filteredProps = metadata.properties.filterNot { - // Skips "EXTERNAL" property for 
external tables - case (key, _) => key == "EXTERNAL" && metadata.tableType == EXTERNAL - } - - val props = filteredProps.map { case (key, value) => + val props = metadata.properties.map { case (key, value) => s"'${escapeSingleQuotedString(key)}' = '${escapeSingleQuotedString(value)}'" } - if (props.nonEmpty) { - builder ++= props.mkString("TBLPROPERTIES (\n ", ",\n ", "\n)\n") - } + builder ++= props.mkString("TBLPROPERTIES (\n ", ",\n ", "\n)\n") } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CatalogFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CatalogFileIndex.scala new file mode 100644 index 0000000000000..092aabc89a36c --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CatalogFileIndex.scala @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources + +import org.apache.hadoop.fs.Path + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.catalog.CatalogTable +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types.StructType + + +/** + * A [[FileIndex]] for a metastore catalog table. + * + * @param sparkSession a [[SparkSession]] + * @param table the metadata of the table + * @param sizeInBytes the table's data size in bytes + */ +class CatalogFileIndex( + sparkSession: SparkSession, + val table: CatalogTable, + override val sizeInBytes: Long) extends FileIndex { + + protected val hadoopConf = sparkSession.sessionState.newHadoopConf + + private val fileStatusCache = FileStatusCache.newCache(sparkSession) + + assert(table.identifier.database.isDefined, + "The table identifier must be qualified in CatalogFileIndex") + + private val baseLocation = table.storage.locationUri + + override def partitionSchema: StructType = table.partitionSchema + + override def rootPaths: Seq[Path] = baseLocation.map(new Path(_)).toSeq + + override def listFiles(filters: Seq[Expression]): Seq[PartitionDirectory] = { + filterPartitions(filters).listFiles(Nil) + } + + override def refresh(): Unit = fileStatusCache.invalidateAll() + + /** + * Returns a [[InMemoryFileIndex]] for this table restricted to the subset of partitions + * specified by the given partition-pruning filters. 
+ * + * @param filters partition-pruning filters + */ + def filterPartitions(filters: Seq[Expression]): InMemoryFileIndex = { + if (table.partitionColumnNames.nonEmpty) { + val selectedPartitions = sparkSession.sessionState.catalog.listPartitionsByFilter( + table.identifier, filters) + val partitions = selectedPartitions.map { p => + PartitionPath(p.toRow(partitionSchema), p.storage.locationUri.get) + } + val partitionSpec = PartitionSpec(partitionSchema, partitions) + new PrunedInMemoryFileIndex( + sparkSession, new Path(baseLocation.get), fileStatusCache, partitionSpec) + } else { + new InMemoryFileIndex(sparkSession, rootPaths, table.storage.properties, None) + } + } + + override def inputFiles: Array[String] = filterPartitions(Nil).inputFiles + + // `CatalogFileIndex` may be a member of `HadoopFsRelation`, `HadoopFsRelation` may be a member + // of `LogicalRelation`, and `LogicalRelation` may be used as the cache key. So we need to + // implement `equals` and `hashCode` here, to make it work with cache lookup. + override def equals(o: Any): Boolean = o match { + case other: CatalogFileIndex => this.table.identifier == other.table.identifier + case _ => false + } + + override def hashCode(): Int = table.identifier.hashCode() +} + +/** + * An override of the standard HDFS listing based catalog, that overrides the partition spec with + * the information from the metastore. + * + * @param tableBasePath The default base path of the Hive metastore table + * @param partitionSpec The partition specifications from Hive metastore + */ +private class PrunedInMemoryFileIndex( + sparkSession: SparkSession, + tableBasePath: Path, + fileStatusCache: FileStatusCache, + override val partitionSpec: PartitionSpec) + extends InMemoryFileIndex( + sparkSession, + partitionSpec.partitions.map(_.path), + Map.empty, + Some(partitionSpec.partitionColumns), + fileStatusCache) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index e75e7d2770b4e..996109865fdc7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -30,7 +30,7 @@ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute -import org.apache.spark.sql.catalyst.catalog.BucketSpec +import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogTable} import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat import org.apache.spark.sql.execution.datasources.jdbc.JdbcRelationProvider @@ -65,6 +65,8 @@ import org.apache.spark.util.Utils * @param partitionColumns A list of column names that the relation is partitioned by. When this * list is empty, the relation is unpartitioned. * @param bucketSpec An optional specification for bucketing (hash-partitioning) of the data. + * @param catalogTable Optional catalog table reference that can be used to push down operations + * over the datasource to the catalog service. 
*/ case class DataSource( sparkSession: SparkSession, @@ -73,9 +75,10 @@ case class DataSource( userSpecifiedSchema: Option[StructType] = None, partitionColumns: Seq[String] = Seq.empty, bucketSpec: Option[BucketSpec] = None, - options: Map[String, String] = Map.empty) extends Logging { + options: Map[String, String] = Map.empty, + catalogTable: Option[CatalogTable] = None) extends Logging { - case class SourceInfo(name: String, schema: StructType) + case class SourceInfo(name: String, schema: StructType, partitionColumns: Seq[String]) lazy val providingClass: Class[_] = lookupDataSource(className) lazy val sourceInfo = sourceSchema() @@ -186,8 +189,11 @@ case class DataSource( } } - private def inferFileFormatSchema(format: FileFormat): StructType = { - userSpecifiedSchema.orElse { + /** + * Infer the schema of the given FileFormat, returns a pair of schema and partition column names. + */ + private def inferFileFormatSchema(format: FileFormat): (StructType, Seq[String]) = { + userSpecifiedSchema.map(_ -> partitionColumns).orElse { val caseInsensitiveOptions = new CaseInsensitiveMap(options) val allPaths = caseInsensitiveOptions.get("path") val globbedPaths = allPaths.toSeq.flatMap { path => @@ -196,15 +202,15 @@ case class DataSource( val qualified = hdfsPath.makeQualified(fs.getUri, fs.getWorkingDirectory) SparkHadoopUtil.get.globPathIfNecessary(qualified) }.toArray - val fileCatalog = new ListingFileCatalog(sparkSession, globbedPaths, options, None) - val partitionCols = fileCatalog.partitionSpec().partitionColumns.fields + val fileCatalog = new InMemoryFileIndex(sparkSession, globbedPaths, options, None) + val partitionSchema = fileCatalog.partitionSpec().partitionColumns val inferred = format.inferSchema( sparkSession, caseInsensitiveOptions, fileCatalog.allFiles()) inferred.map { inferredSchema => - StructType(inferredSchema ++ partitionCols) + StructType(inferredSchema ++ partitionSchema) -> partitionSchema.map(_.name) } }.getOrElse { throw new AnalysisException("Unable to infer schema. 
It must be specified manually.") @@ -217,7 +223,7 @@ case class DataSource( case s: StreamSourceProvider => val (name, schema) = s.sourceSchema( sparkSession.sqlContext, userSpecifiedSchema, className, options) - SourceInfo(name, schema) + SourceInfo(name, schema, Nil) case format: FileFormat => val caseInsensitiveOptions = new CaseInsensitiveMap(options) @@ -246,7 +252,8 @@ case class DataSource( "you may be able to create a static DataFrame on that directory with " + "'spark.read.load(directory)' and infer schema from it.") } - SourceInfo(s"FileSource[$path]", inferFileFormatSchema(format)) + val (schema, partCols) = inferFileFormatSchema(format) + SourceInfo(s"FileSource[$path]", schema, partCols) case _ => throw new UnsupportedOperationException( @@ -266,7 +273,13 @@ case class DataSource( throw new IllegalArgumentException("'path' is not specified") }) new FileStreamSource( - sparkSession, path, className, sourceInfo.schema, metadataPath, options) + sparkSession = sparkSession, + path = path, + fileFormatClassName = className, + schema = sourceInfo.schema, + partitionColumns = sourceInfo.partitionColumns, + metadataPath = metadataPath, + options = options) case _ => throw new UnsupportedOperationException( s"Data source $className does not support streamed reading") @@ -351,7 +364,7 @@ case class DataSource( case (format: FileFormat, _) if hasMetadata(caseInsensitiveOptions.get("path").toSeq ++ paths) => val basePath = new Path((caseInsensitiveOptions.get("path").toSeq ++ paths).head) - val fileCatalog = new MetadataLogFileCatalog(sparkSession, basePath) + val fileCatalog = new MetadataLogFileIndex(sparkSession, basePath) val dataSchema = userSpecifiedSchema.orElse { format.inferSchema( sparkSession, @@ -402,9 +415,16 @@ case class DataSource( }) } - val fileCatalog = - new ListingFileCatalog( + val fileCatalog = if (sparkSession.sqlContext.conf.manageFilesourcePartitions && + catalogTable.isDefined && catalogTable.get.partitionProviderIsHive) { + new CatalogFileIndex( + sparkSession, + catalogTable.get, + catalogTable.get.stats.map(_.sizeInBytes.toLong).getOrElse(0L)) + } else { + new InMemoryFileIndex( sparkSession, globbedPaths, options, partitionSchema) + } val dataSchema = userSpecifiedSchema.map { schema => val equality = sparkSession.sessionState.conf.resolver @@ -413,7 +433,7 @@ case class DataSource( format.inferSchema( sparkSession, caseInsensitiveOptions, - fileCatalog.allFiles()) + fileCatalog.asInstanceOf[InMemoryFileIndex].allFiles()) }.getOrElse { throw new AnalysisException( s"Unable to infer schema for $format at ${allPaths.take(2).mkString(",")}. 
" + @@ -422,7 +442,7 @@ case class DataSource( HadoopFsRelation( fileCatalog, - partitionSchema = fileCatalog.partitionSpec().partitionColumns, + partitionSchema = fileCatalog.partitionSchema, dataSchema = dataSchema.asNullable, bucketSpec = bucketSpec, format, @@ -471,9 +491,7 @@ case class DataSource( val existingPartitionColumns = Try { resolveRelation() .asInstanceOf[HadoopFsRelation] - .location - .partitionSpec() - .partitionColumns + .partitionSchema .fieldNames .toSeq }.getOrElse(Seq.empty[String]) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 6f9ed50a02b09..f0bcf94eadc96 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -30,11 +30,11 @@ import org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project} +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Union} import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, UnknownPartitioning} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.{RowDataSourceScanExec, SparkPlan} -import org.apache.spark.sql.execution.command.{DDLUtils, ExecutedCommandExec} +import org.apache.spark.sql.execution.command.{AlterTableRecoverPartitionsCommand, DDLUtils, ExecutedCommandExec} import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -163,14 +163,14 @@ case class DataSourceAnalysis(conf: CatalystConf) extends Rule[LogicalPlan] { if query.resolved && t.schema.asNullable == query.schema.asNullable => // Sanity checks - if (t.location.paths.size != 1) { + if (t.location.rootPaths.size != 1) { throw new AnalysisException( "Can only write data to relations with a single path.") } - val outputPath = t.location.paths.head + val outputPath = t.location.rootPaths.head val inputPaths = query.collect { - case LogicalRelation(r: HadoopFsRelation, _, _) => r.location.paths + case LogicalRelation(r: HadoopFsRelation, _, _) => r.location.rootPaths }.flatten val mode = if (overwrite) SaveMode.Overwrite else SaveMode.Append @@ -179,15 +179,24 @@ case class DataSourceAnalysis(conf: CatalystConf) extends Rule[LogicalPlan] { "Cannot overwrite a path that is also being read from.") } - InsertIntoHadoopFsRelationCommand( + val insertCmd = InsertIntoHadoopFsRelationCommand( outputPath, query.resolve(t.partitionSchema, t.sparkSession.sessionState.analyzer.resolver), t.bucketSpec, t.fileFormat, - () => t.refresh(), + () => t.location.refresh(), t.options, query, mode) + + if (l.catalogTable.isDefined && l.catalogTable.get.partitionColumnNames.nonEmpty && + l.catalogTable.get.partitionProviderIsHive) { + // TODO(ekl) we should be more efficient here and only recover the newly added partitions + val recoverPartitionCmd = AlterTableRecoverPartitionsCommand(l.catalogTable.get.identifier) + Union(insertCmd, recoverPartitionCmd) + } else { + insertCmd + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala new file mode 100644 index 0000000000000..b31c4d51c7923 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala @@ -0,0 +1,192 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs._ +import org.apache.hadoop.io.compress.{CompressionCodecFactory, SplittableCompressionCodec} +import org.apache.hadoop.mapreduce.Job + +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection +import org.apache.spark.sql.sources.Filter +import org.apache.spark.sql.types.StructType + + +/** + * Used to read and write data stored in files to/from the [[InternalRow]] format. + */ +trait FileFormat { + /** + * When possible, this method should return the schema of the given `files`. When the format + * does not support inference, or no valid files are given, it should return None. In these cases + * Spark will require that the user specify the schema manually. + */ + def inferSchema( + sparkSession: SparkSession, + options: Map[String, String], + files: Seq[FileStatus]): Option[StructType] + + /** + * Prepares a write job and returns an [[OutputWriterFactory]]. Client side job preparation can + * be put here. For example, a user-defined output committer can be configured here + * by setting the output committer class in the conf key spark.sql.sources.outputCommitterClass. + */ + def prepareWrite( + sparkSession: SparkSession, + job: Job, + options: Map[String, String], + dataSchema: StructType): OutputWriterFactory + + /** + * Returns a [[OutputWriterFactory]] for generating output writers that can write data. + * This method is currently used only by FileStreamSinkWriter to generate output writers that + * do not use output committers to write data. The OutputWriter generated by the returned + * [[OutputWriterFactory]] must implement the method `newWriter(path)`. + */ + def buildWriter( + sqlContext: SQLContext, + dataSchema: StructType, + options: Map[String, String]): OutputWriterFactory = { + // TODO: Remove this default implementation when the other formats have been ported + throw new UnsupportedOperationException(s"buildWriter is not supported for $this") + } + + /** + * Returns whether this format supports returning a columnar batch or not. + * + * TODO: we should just have different traits for the different formats. 
+ */ + def supportBatch(sparkSession: SparkSession, dataSchema: StructType): Boolean = { + false + } + + /** + * Allow FileFormats to have a pluggable way to utilize pushed filters to eliminate partitions + * before execution. By default no pruning is performed and the original partitioning is + * preserved. + */ + def filterPartitions( + filters: Seq[Filter], + schema: StructType, + conf: Configuration, + allFiles: Seq[FileStatus], + root: Path, + partitions: Seq[Partition]): Seq[Partition] = { + partitions + } + + /** + * Returns whether a file with `path` can be split or not. + */ + def isSplitable( + sparkSession: SparkSession, + options: Map[String, String], + path: Path): Boolean = { + false + } + + /** + * Returns a function that can be used to read a single file in as an Iterator of InternalRow. + * + * @param dataSchema The global data schema. It can be either specified by the user, or + * reconciled/merged from all underlying data files. If any partition columns + * are contained in the files, they are preserved in this schema. + * @param partitionSchema The schema of the partition column row that will be present in each + * PartitionedFile. These columns should be appended to the rows that + * are produced by the iterator. + * @param requiredSchema The schema of the data that should be output for each row. This may be a + * subset of the columns that are present in the file if column pruning has + * occurred. + * @param filters A set of filters that can optionally be used to reduce the number of rows output + * @param options A set of string -> string configuration options. + * @return + */ + def buildReader( + sparkSession: SparkSession, + dataSchema: StructType, + partitionSchema: StructType, + requiredSchema: StructType, + filters: Seq[Filter], + options: Map[String, String], + hadoopConf: Configuration): PartitionedFile => Iterator[InternalRow] = { + // TODO: Remove this default implementation when the other formats have been ported + // Until then we guard in [[FileSourceStrategy]] to only call this method on supported formats. + throw new UnsupportedOperationException(s"buildReader is not supported for $this") + } + + /** + * Exactly the same as [[buildReader]] except that the reader function returned by this method + * appends partition values to [[InternalRow]]s produced by the reader function [[buildReader]] + * returns. + */ + def buildReaderWithPartitionValues( + sparkSession: SparkSession, + dataSchema: StructType, + partitionSchema: StructType, + requiredSchema: StructType, + filters: Seq[Filter], + options: Map[String, String], + hadoopConf: Configuration): PartitionedFile => Iterator[InternalRow] = { + val dataReader = buildReader( + sparkSession, dataSchema, partitionSchema, requiredSchema, filters, options, hadoopConf) + + new (PartitionedFile => Iterator[InternalRow]) with Serializable { + private val fullSchema = requiredSchema.toAttributes ++ partitionSchema.toAttributes + + private val joinedRow = new JoinedRow() + + // Using lazy val to avoid serialization + private lazy val appendPartitionColumns = + GenerateUnsafeProjection.generate(fullSchema, fullSchema) + + override def apply(file: PartitionedFile): Iterator[InternalRow] = { + // Using local val to avoid per-row lazy val check (pre-mature optimization?...) + val converter = appendPartitionColumns + + // Note that we have to apply the converter even though `file.partitionValues` is empty. 
+ // This is because the converter is also responsible for converting safe `InternalRow`s into + // `UnsafeRow`s. + dataReader(file).map { dataRow => + converter(joinedRow(dataRow, file.partitionValues)) + } + } + } + } + +} + +/** + * The base class file format that is based on text file. + */ +abstract class TextBasedFileFormat extends FileFormat { + private var codecFactory: CompressionCodecFactory = null + override def isSplitable( + sparkSession: SparkSession, + options: Map[String, String], + path: Path): Boolean = { + if (codecFactory == null) { + codecFactory = new CompressionCodecFactory( + sparkSession.sessionState.newHadoopConfWithOptions(options)) + } + val codec = codecFactory.getCodec(path) + codec == null || codec.isInstanceOf[SplittableCompressionCodec] + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileIndex.scala new file mode 100644 index 0000000000000..277223d52ec52 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileIndex.scala @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources + +import org.apache.hadoop.fs._ + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types.StructType + +/** + * A collection of data files from a partitioned relation, along with the partition values in the + * form of an [[InternalRow]]. + */ +case class PartitionDirectory(values: InternalRow, files: Seq[FileStatus]) + +/** + * An interface for objects capable of enumerating the root paths of a relation as well as the + * partitions of a relation subject to some pruning expressions. + */ +trait FileIndex { + + /** + * Returns the list of root input paths from which the catalog will get files. There may be a + * single root path from which partitions are discovered, or individual partitions may be + * specified by each path. + */ + def rootPaths: Seq[Path] + + /** + * Returns all valid files grouped into partitions when the data is partitioned. If the data is + * unpartitioned, this will return a single partition with no partition values. + * + * @param filters The filters used to prune which partitions are returned. These filters must + * only refer to partition columns and this method will only return files + * where these predicates are guaranteed to evaluate to `true`. Thus, these + * filters will not need to be evaluated again on the returned data. + */ + def listFiles(filters: Seq[Expression]): Seq[PartitionDirectory] + + /** + * Returns the list of files that will be read when scanning this relation. 
This call may be + * very expensive for large tables. + */ + def inputFiles: Array[String] + + /** Refresh any cached file listings */ + def refresh(): Unit + + /** Sum of table file sizes, in bytes */ + def sizeInBytes: Long + + /** Schema of the partitioning columns, or the empty schema if the table is not partitioned. */ + def partitionSchema: StructType +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileStatusCache.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileStatusCache.scala new file mode 100644 index 0000000000000..7c2e6fd04d5db --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileStatusCache.scala @@ -0,0 +1,149 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources + +import java.util.concurrent.ConcurrentHashMap +import java.util.concurrent.atomic.AtomicBoolean + +import scala.collection.JavaConverters._ + +import com.google.common.cache._ +import org.apache.hadoop.fs.{FileStatus, Path} + +import org.apache.spark.internal.Logging +import org.apache.spark.metrics.source.HiveCatalogMetrics +import org.apache.spark.sql.SparkSession +import org.apache.spark.util.{SerializableConfiguration, SizeEstimator} + +/** + * A cache of the leaf files of partition directories. We cache these files in order to speed + * up iterated queries over the same set of partitions. Otherwise, each query would have to + * hit remote storage in order to gather file statistics for physical planning. + * + * Each resolved catalog table has its own FileStatusCache. When the backing relation for the + * table is refreshed via refreshTable() or refreshByPath(), this cache will be invalidated. + */ +abstract class FileStatusCache { + /** + * @return the leaf files for the specified path from this cache, or None if not cached. + */ + def getLeafFiles(path: Path): Option[Array[FileStatus]] = None + + /** + * Saves the given set of leaf files for a path in this cache. + */ + def putLeafFiles(path: Path, leafFiles: Array[FileStatus]): Unit + + /** + * Invalidates all data held by this cache. + */ + def invalidateAll(): Unit +} + +object FileStatusCache { + private var sharedCache: SharedInMemoryCache = null + + /** + * @return a new FileStatusCache based on session configuration. Cache memory quota is + * shared across all clients. 
+ */ + def newCache(session: SparkSession): FileStatusCache = { + synchronized { + if (session.sqlContext.conf.manageFilesourcePartitions && + session.sqlContext.conf.filesourcePartitionFileCacheSize > 0) { + if (sharedCache == null) { + sharedCache = new SharedInMemoryCache( + session.sqlContext.conf.filesourcePartitionFileCacheSize) + } + sharedCache.getForNewClient() + } else { + NoopCache + } + } + } + + def resetForTesting(): Unit = synchronized { + sharedCache = null + } +} + +/** + * An implementation that caches partition file statuses in memory. + * + * @param maxSizeInBytes max allowable cache size before entries start getting evicted + */ +private class SharedInMemoryCache(maxSizeInBytes: Long) extends Logging { + import FileStatusCache._ + + // Opaque object that uniquely identifies a shared cache user + private type ClientId = Object + + private val warnedAboutEviction = new AtomicBoolean(false) + + // we use a composite cache key in order to distinguish entries inserted by different clients + private val cache: Cache[(ClientId, Path), Array[FileStatus]] = CacheBuilder.newBuilder() + .weigher(new Weigher[(ClientId, Path), Array[FileStatus]] { + override def weigh(key: (ClientId, Path), value: Array[FileStatus]): Int = { + (SizeEstimator.estimate(key) + SizeEstimator.estimate(value)).toInt + }}) + .removalListener(new RemovalListener[(ClientId, Path), Array[FileStatus]]() { + override def onRemoval(removed: RemovalNotification[(ClientId, Path), Array[FileStatus]]) = { + if (removed.getCause() == RemovalCause.SIZE && + warnedAboutEviction.compareAndSet(false, true)) { + logWarning( + "Evicting cached table partition metadata from memory due to size constraints " + + "(spark.sql.hive.filesourcePartitionFileCacheSize = " + maxSizeInBytes + " bytes). " + + "This may impact query planning performance.") + } + }}) + .maximumWeight(maxSizeInBytes) + .build() + + /** + * @return a FileStatusCache that does not share any entries with any other client, but does + * share memory resources for the purpose of cache eviction. + */ + def getForNewClient(): FileStatusCache = new FileStatusCache { + val clientId = new Object() + + override def getLeafFiles(path: Path): Option[Array[FileStatus]] = { + Option(cache.getIfPresent((clientId, path))) + } + + override def putLeafFiles(path: Path, leafFiles: Array[FileStatus]): Unit = { + cache.put((clientId, path), leafFiles.toArray) + } + + override def invalidateAll(): Unit = { + cache.asMap.asScala.foreach { case (key, value) => + if (key._1 == clientId) { + cache.invalidate(key) + } + } + } + } +} + +/** + * A non-caching implementation used when partition file status caching is disabled. + */ +object NoopCache extends FileStatusCache { + override def getLeafFiles(path: Path): Option[Array[FileStatus]] = None + override def putLeafFiles(path: Path, leafFiles: Array[FileStatus]): Unit = {} + override def invalidateAll(): Unit = {} +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelation.scala new file mode 100644 index 0000000000000..014abd454f5c0 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelation.scala @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources + +import org.apache.spark.sql.{SparkSession, SQLContext} +import org.apache.spark.sql.catalyst.catalog.BucketSpec +import org.apache.spark.sql.execution.FileRelation +import org.apache.spark.sql.sources.{BaseRelation, DataSourceRegister} +import org.apache.spark.sql.types.StructType + + +/** + * Acts as a container for all of the metadata required to read from a datasource. All discovery, + * resolution and merging logic for schemas and partitions has been removed. + * + * @param location A [[FileIndex]] that can enumerate the locations of all the files that + * comprise this relation. + * @param partitionSchema The schema of the columns (if any) that are used to partition the relation + * @param dataSchema The schema of any remaining columns. Note that if any partition columns are + * present in the actual data files as well, they are preserved. + * @param bucketSpec Describes the bucketing (hash-partitioning of the files by some column values). + * @param fileFormat A file format that can be used to read and write the data in files. + * @param options Configuration used when reading / writing data. + */ +case class HadoopFsRelation( + location: FileIndex, + partitionSchema: StructType, + dataSchema: StructType, + bucketSpec: Option[BucketSpec], + fileFormat: FileFormat, + options: Map[String, String])(val sparkSession: SparkSession) + extends BaseRelation with FileRelation { + + override def sqlContext: SQLContext = sparkSession.sqlContext + + val schema: StructType = { + val dataSchemaColumnNames = dataSchema.map(_.name.toLowerCase).toSet + StructType(dataSchema ++ partitionSchema.filterNot { column => + dataSchemaColumnNames.contains(column.name.toLowerCase) + }) + } + + def partitionSchemaOption: Option[StructType] = + if (partitionSchema.isEmpty) None else Some(partitionSchema) + + override def toString: String = { + fileFormat match { + case source: DataSourceRegister => source.shortName() + case _ => "HadoopFiles" + } + } + + override def sizeInBytes: Long = location.sizeInBytes + + override def inputFiles: Array[String] = location.inputFiles +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala new file mode 100644 index 0000000000000..7531f0ae02e75 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources + +import scala.collection.mutable + +import org.apache.hadoop.fs._ + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.types.StructType + + +/** + * A [[FileIndex]] that generates the list of files to process by recursively listing all the + * files present in `paths`. + * + * @param rootPaths the list of root table paths to scan + * @param parameters as set of options to control discovery + * @param partitionSchema an optional partition schema that will be use to provide types for the + * discovered partitions + */ +class InMemoryFileIndex( + sparkSession: SparkSession, + override val rootPaths: Seq[Path], + parameters: Map[String, String], + partitionSchema: Option[StructType], + fileStatusCache: FileStatusCache = NoopCache) + extends PartitioningAwareFileIndex( + sparkSession, parameters, partitionSchema, fileStatusCache) { + + @volatile private var cachedLeafFiles: mutable.LinkedHashMap[Path, FileStatus] = _ + @volatile private var cachedLeafDirToChildrenFiles: Map[Path, Array[FileStatus]] = _ + @volatile private var cachedPartitionSpec: PartitionSpec = _ + + refresh0() + + override def partitionSpec(): PartitionSpec = { + if (cachedPartitionSpec == null) { + cachedPartitionSpec = inferPartitioning() + } + logTrace(s"Partition spec: $cachedPartitionSpec") + cachedPartitionSpec + } + + override protected def leafFiles: mutable.LinkedHashMap[Path, FileStatus] = { + cachedLeafFiles + } + + override protected def leafDirToChildrenFiles: Map[Path, Array[FileStatus]] = { + cachedLeafDirToChildrenFiles + } + + override def refresh(): Unit = { + refresh0() + fileStatusCache.invalidateAll() + } + + private def refresh0(): Unit = { + val files = listLeafFiles(rootPaths) + cachedLeafFiles = + new mutable.LinkedHashMap[Path, FileStatus]() ++= files.map(f => f.getPath -> f) + cachedLeafDirToChildrenFiles = files.toArray.groupBy(_.getPath.getParent) + cachedPartitionSpec = null + } + + override def equals(other: Any): Boolean = other match { + case hdfs: InMemoryFileIndex => rootPaths.toSet == hdfs.rootPaths.toSet + case _ => false + } + + override def hashCode(): Int = rootPaths.toSet.hashCode() +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala index 99ca3df673568..22dbe7149531c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala @@ -20,18 +20,12 @@ package org.apache.spark.sql.execution.datasources import java.io.IOException import org.apache.hadoop.fs.Path -import org.apache.hadoop.mapreduce._ -import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat -import org.apache.spark._ import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.catalog.BucketSpec -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet} 
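// Illustrative aside: a small sketch (paths invented) of the bookkeeping refresh0() performs
// in InMemoryFileIndex above. Listed leaf files are kept both as an ordered map keyed by path
// and grouped by their parent directory, which is what the later partition listing relies on.
import scala.collection.mutable

import org.apache.hadoop.fs.Path

object LeafFileBookkeepingSketch {
  def main(args: Array[String]): Unit = {
    val leafPaths = Seq(
      new Path("/table/year=2016/part-00000"),
      new Path("/table/year=2016/part-00001"),
      new Path("/table/year=2015/part-00000"))

    val leafFiles =
      new mutable.LinkedHashMap[Path, String]() ++= leafPaths.map(p => p -> p.getName)
    val leafDirToChildrenFiles = leafPaths.groupBy(_.getParent)

    leafDirToChildrenFiles.foreach { case (dir, children) =>
      println(s"$dir -> ${children.size} file(s)")  // /table/year=2016 -> 2, /table/year=2015 -> 1
    }
    println(leafFiles.size)                         // 3
  }
}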
+import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.command.RunnableCommand -import org.apache.spark.sql.internal.SQLConf /** * A command for writing data to a [[HadoopFsRelation]]. Supports both overwriting and appending. @@ -40,20 +34,6 @@ import org.apache.spark.sql.internal.SQLConf * implementation of [[HadoopFsRelation]] should use this UUID together with task id to generate * unique file path for each task output file. This UUID is passed to executor side via a * property named `spark.sql.sources.writeJobUUID`. - * - * Different writer containers, [[DefaultWriterContainer]] and [[DynamicPartitionWriterContainer]] - * are used to write to normal tables and tables with dynamic partitions. - * - * Basic work flow of this command is: - * - * 1. Driver side setup, including output committer initialization and data source specific - * preparation work for the write job to be issued. - * 2. Issues a write job consists of one or more executor side tasks, each of which writes all - * rows within an RDD partition. - * 3. If no exception is thrown in a task, commits that task, otherwise aborts that task; If any - * exception is thrown during task commitment, also aborts that task. - * 4. If all tasks are committed, commit the job, otherwise aborts the job; If any exception is - * thrown during job commitment, also aborts the job. */ case class InsertIntoHadoopFsRelationCommand( outputPath: Path, @@ -103,52 +83,17 @@ case class InsertIntoHadoopFsRelationCommand( val isAppend = pathExists && (mode == SaveMode.Append) if (doInsertion) { - val job = Job.getInstance(hadoopConf) - job.setOutputKeyClass(classOf[Void]) - job.setOutputValueClass(classOf[InternalRow]) - FileOutputFormat.setOutputPath(job, qualifiedOutputPath) - - val partitionSet = AttributeSet(partitionColumns) - val dataColumns = query.output.filterNot(partitionSet.contains) - - val queryExecution = Dataset.ofRows(sparkSession, query).queryExecution - SQLExecution.withNewExecutionId(sparkSession, queryExecution) { - val relation = - WriteRelation( - sparkSession, - dataColumns.toStructType, - qualifiedOutputPath.toString, - fileFormat.prepareWrite(sparkSession, _, options, dataColumns.toStructType), - bucketSpec) - - val writerContainer = if (partitionColumns.isEmpty && bucketSpec.isEmpty) { - new DefaultWriterContainer(relation, job, isAppend) - } else { - new DynamicPartitionWriterContainer( - relation, - job, - partitionColumns = partitionColumns, - dataColumns = dataColumns, - inputSchema = query.output, - PartitioningUtils.DEFAULT_PARTITION_NAME, - sparkSession.sessionState.conf.partitionMaxFiles, - isAppend) - } - - // This call shouldn't be put into the `try` block below because it only initializes and - // prepares the job, any exception thrown from here shouldn't cause abortJob() to be called. 
- writerContainer.driverSideSetup() - - try { - sparkSession.sparkContext.runJob(queryExecution.toRdd, writerContainer.writeRows _) - writerContainer.commitJob() - refreshFunction() - } catch { case cause: Throwable => - logError("Aborting job.", cause) - writerContainer.abortJob() - throw new SparkException("Job aborted.", cause) - } - } + WriteOutput.write( + sparkSession, + query, + fileFormat, + qualifiedOutputPath, + hadoopConf, + partitionColumns, + bucketSpec, + refreshFunction, + options, + isAppend) } else { logInfo("Skipping insertion into a relation that already exists.") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ListingFileCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ListingFileCatalog.scala deleted file mode 100644 index 32532084236cf..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ListingFileCatalog.scala +++ /dev/null @@ -1,154 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.datasources - -import java.io.FileNotFoundException - -import scala.collection.mutable - -import org.apache.hadoop.fs.{FileStatus, LocatedFileStatus, Path} -import org.apache.hadoop.mapred.{FileInputFormat, JobConf} - -import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.types.StructType - - -/** - * A [[FileCatalog]] that generates the list of files to process by recursively listing all the - * files present in `paths`. 
- * - * @param parameters as set of options to control discovery - * @param paths a list of paths to scan - * @param partitionSchema an optional partition schema that will be use to provide types for the - * discovered partitions - */ -class ListingFileCatalog( - sparkSession: SparkSession, - override val paths: Seq[Path], - parameters: Map[String, String], - partitionSchema: Option[StructType]) - extends PartitioningAwareFileCatalog(sparkSession, parameters, partitionSchema) { - - @volatile private var cachedLeafFiles: mutable.LinkedHashMap[Path, FileStatus] = _ - @volatile private var cachedLeafDirToChildrenFiles: Map[Path, Array[FileStatus]] = _ - @volatile private var cachedPartitionSpec: PartitionSpec = _ - - refresh() - - override def partitionSpec(): PartitionSpec = { - if (cachedPartitionSpec == null) { - cachedPartitionSpec = inferPartitioning() - } - logTrace(s"Partition spec: $cachedPartitionSpec") - cachedPartitionSpec - } - - override protected def leafFiles: mutable.LinkedHashMap[Path, FileStatus] = { - cachedLeafFiles - } - - override protected def leafDirToChildrenFiles: Map[Path, Array[FileStatus]] = { - cachedLeafDirToChildrenFiles - } - - override def refresh(): Unit = { - val files = listLeafFiles(paths) - cachedLeafFiles = - new mutable.LinkedHashMap[Path, FileStatus]() ++= files.map(f => f.getPath -> f) - cachedLeafDirToChildrenFiles = files.toArray.groupBy(_.getPath.getParent) - cachedPartitionSpec = null - } - - /** - * List leaf files of given paths. This method will submit a Spark job to do parallel - * listing whenever there is a path having more files than the parallel partition discovery - * discovery threshold. - * - * This is publicly visible for testing. - */ - def listLeafFiles(paths: Seq[Path]): mutable.LinkedHashSet[FileStatus] = { - if (paths.length >= sparkSession.sessionState.conf.parallelPartitionDiscoveryThreshold) { - HadoopFsRelation.listLeafFilesInParallel(paths, hadoopConf, sparkSession) - } else { - // Right now, the number of paths is less than the value of - // parallelPartitionDiscoveryThreshold. So, we will list file statues at the driver. - // If there is any child that has more files than the threshold, we will use parallel - // listing. - - // Dummy jobconf to get to the pathFilter defined in configuration - val jobConf = new JobConf(hadoopConf, this.getClass) - val pathFilter = FileInputFormat.getInputPathFilter(jobConf) - - val statuses: Seq[FileStatus] = paths.flatMap { path => - val fs = path.getFileSystem(hadoopConf) - logTrace(s"Listing $path on driver") - - val childStatuses = { - try { - val stats = fs.listStatus(path) - if (pathFilter != null) stats.filter(f => pathFilter.accept(f.getPath)) else stats - } catch { - case _: FileNotFoundException => - logWarning(s"The directory $path was not found. Was it deleted very recently?") - Array.empty[FileStatus] - } - } - - childStatuses.map { - case f: LocatedFileStatus => f - - // NOTE: - // - // - Although S3/S3A/S3N file system can be quite slow for remote file metadata - // operations, calling `getFileBlockLocations` does no harm here since these file system - // implementations don't actually issue RPC for this method. - // - // - Here we are calling `getFileBlockLocations` in a sequential manner, but it should not - // be a big deal since we always use to `listLeafFilesInParallel` when the number of - // paths exceeds threshold. - case f => - if (f.isDirectory ) { - // If f is a directory, we do not need to call getFileBlockLocations (SPARK-14959). 
- f - } else { - HadoopFsRelation.createLocatedFileStatus(f, fs.getFileBlockLocations(f, 0, f.getLen)) - } - } - }.filterNot { status => - val name = status.getPath.getName - HadoopFsRelation.shouldFilterOut(name) - } - - val (dirs, files) = statuses.partition(_.isDirectory) - - // It uses [[LinkedHashSet]] since the order of files can affect the results. (SPARK-11500) - if (dirs.isEmpty) { - mutable.LinkedHashSet(files: _*) - } else { - mutable.LinkedHashSet(files: _*) ++ listLeafFiles(dirs.map(_.getPath)) - } - } - } - - override def equals(other: Any): Boolean = other match { - case hdfs: ListingFileCatalog => paths.toSet == hdfs.paths.toSet - case _ => false - } - - override def hashCode(): Int = paths.toSet.hashCode() -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala index d9562fd32e87d..7c28d48f26416 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala @@ -94,7 +94,7 @@ case class LogicalRelation( } override def refresh(): Unit = relation match { - case fs: HadoopFsRelation => fs.refresh() + case fs: HadoopFsRelation => fs.location.refresh() case _ => // Do nothing. } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/OutputWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/OutputWriter.scala new file mode 100644 index 0000000000000..fbf6e96d3f850 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/OutputWriter.scala @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources + +import org.apache.hadoop.mapreduce.TaskAttemptContext + +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} +import org.apache.spark.sql.types.StructType + + +/** + * A factory that produces [[OutputWriter]]s. A new [[OutputWriterFactory]] is created on driver + * side for each write job issued when writing to a [[HadoopFsRelation]], and then gets serialized + * to executor side to create actual [[OutputWriter]]s on the fly. + */ +abstract class OutputWriterFactory extends Serializable { + /** + * When writing to a [[HadoopFsRelation]], this method gets called by each task on executor side + * to instantiate new [[OutputWriter]]s. + * + * @param stagingDir Base path (directory) of the file to which this [[OutputWriter]] is supposed + * to write. Note that this may not point to the final output file. 
For + * example, `FileOutputFormat` writes to temporary directories and then merge + * written files back to the final destination. In this case, `path` points to + * a temporary output file under the temporary directory. + * @param fileNamePrefix Prefix of the file name. The returned OutputWriter must make sure this + * prefix is used in the actual file name. For example, if the prefix is + * "part-1-2-3", then the file name must start with "part_1_2_3" but can + * end in arbitrary extension that is deterministic given the configuration + * (i.e. the suffix extension should not depend on any task id, attempt id, + * or partition id). + * @param dataSchema Schema of the rows to be written. Partition columns are not included in the + * schema if the relation being written is partitioned. + * @param context The Hadoop MapReduce task context. + */ + def newInstance( + stagingDir: String, + fileNamePrefix: String, + dataSchema: StructType, + context: TaskAttemptContext): OutputWriter + + /** + * Returns a new instance of [[OutputWriter]] that will write data to the given path. + * This method gets called by each task on executor to write InternalRows to + * format-specific files. Compared to the other `newInstance()`, this is a newer API that + * passes only the path that the writer must write to. The writer must write to the exact path + * and not modify it (do not add subdirectories, extensions, etc.). All other + * file-format-specific information needed to create the writer must be passed + * through the [[OutputWriterFactory]] implementation. + */ + def newWriter(path: String): OutputWriter = { + throw new UnsupportedOperationException("newInstance with just path not supported") + } +} + + +/** + * [[OutputWriter]] is used together with [[HadoopFsRelation]] for persisting rows to the + * underlying file system. Subclasses of [[OutputWriter]] must provide a zero-argument constructor. + * An [[OutputWriter]] instance is created and initialized when a new output file is opened on + * executor side. This instance is used to persist rows to this single output file. + */ +abstract class OutputWriter { + + /** + * The path of the file to be written out. This path should include the staging directory and + * the file name prefix passed into the associated createOutputWriter function. + */ + def path: String + + /** + * Persists a single row. Invoked on the executor side. When writing to dynamically partitioned + * tables, dynamic partition columns are not included in rows to be written. + */ + def write(row: Row): Unit + + /** + * Closes the [[OutputWriter]]. Invoked on the executor side after all rows are persisted, before + * the task output is committed. 
+ */ + def close(): Unit + + private var converter: InternalRow => Row = _ + + protected[sql] def initConverter(dataSchema: StructType) = { + converter = + CatalystTypeConverters.createToScalaConverter(dataSchema).asInstanceOf[InternalRow => Row] + } + + protected[sql] def writeInternal(row: InternalRow): Unit = { + write(converter(row)) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileCatalog.scala deleted file mode 100644 index 702ba97222e34..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileCatalog.scala +++ /dev/null @@ -1,217 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.datasources - -import scala.collection.mutable - -import org.apache.hadoop.fs.{FileStatus, Path} - -import org.apache.spark.internal.Logging -import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.catalyst.{expressions, InternalRow} -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.types.{StringType, StructType} - - -/** - * An abstract class that represents [[FileCatalog]]s that are aware of partitioned tables. - * It provides the necessary methods to parse partition data based on a set of files. 
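// Illustrative aside: a hypothetical, heavily simplified OutputWriterFactory/OutputWriter pair
// (the "Echo" names and the ".txt" suffix are invented) showing the contract described above:
// the writer derives its `path` from `stagingDir` plus `fileNamePrefix`, writes one text line
// per row, and closes its stream when the task finishes.
import org.apache.hadoop.fs.Path
import org.apache.hadoop.mapreduce.TaskAttemptContext

import org.apache.spark.sql.Row
import org.apache.spark.sql.execution.datasources.{OutputWriter, OutputWriterFactory}
import org.apache.spark.sql.types.StructType

class EchoOutputWriterFactory extends OutputWriterFactory {
  override def newInstance(
      stagingDir: String,
      fileNamePrefix: String,
      dataSchema: StructType,
      context: TaskAttemptContext): OutputWriter =
    new EchoOutputWriter(stagingDir, fileNamePrefix, context)
}

class EchoOutputWriter(
    stagingDir: String,
    fileNamePrefix: String,
    context: TaskAttemptContext) extends OutputWriter {

  // The prefix must be preserved in the final name; the extension is format-specific.
  override val path: String = new Path(stagingDir, fileNamePrefix + ".txt").toString

  private val stream = new Path(path)
    .getFileSystem(context.getConfiguration)
    .create(new Path(path), false)

  override def write(row: Row): Unit = stream.writeBytes(row.mkString(",") + "\n")

  override def close(): Unit = stream.close()
}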
- * - * @param parameters as set of options to control partition discovery - * @param partitionSchema an optional partition schema that will be use to provide types for the - * discovered partitions -*/ -abstract class PartitioningAwareFileCatalog( - sparkSession: SparkSession, - parameters: Map[String, String], - partitionSchema: Option[StructType]) - extends FileCatalog with Logging { - - protected val hadoopConf = sparkSession.sessionState.newHadoopConfWithOptions(parameters) - - protected def leafFiles: mutable.LinkedHashMap[Path, FileStatus] - - protected def leafDirToChildrenFiles: Map[Path, Array[FileStatus]] - - override def listFiles(filters: Seq[Expression]): Seq[Partition] = { - val selectedPartitions = if (partitionSpec().partitionColumns.isEmpty) { - Partition(InternalRow.empty, allFiles().filter(f => isDataPath(f.getPath))) :: Nil - } else { - prunePartitions(filters, partitionSpec()).map { - case PartitionDirectory(values, path) => - val files: Seq[FileStatus] = leafDirToChildrenFiles.get(path) match { - case Some(existingDir) => - // Directory has children files in it, return them - existingDir.filter(f => isDataPath(f.getPath)) - - case None => - // Directory does not exist, or has no children files - Nil - } - Partition(values, files) - } - } - logTrace("Selected files after partition pruning:\n\t" + selectedPartitions.mkString("\n\t")) - selectedPartitions - } - - override def allFiles(): Seq[FileStatus] = { - if (partitionSpec().partitionColumns.isEmpty) { - // For each of the input paths, get the list of files inside them - paths.flatMap { path => - // Make the path qualified (consistent with listLeafFiles and listLeafFilesInParallel). - val fs = path.getFileSystem(hadoopConf) - val qualifiedPathPre = fs.makeQualified(path) - val qualifiedPath: Path = if (qualifiedPathPre.isRoot && !qualifiedPathPre.isAbsolute) { - // SPARK-17613: Always append `Path.SEPARATOR` to the end of parent directories, - // because the `leafFile.getParent` would have returned an absolute path with the - // separator at the end. - new Path(qualifiedPathPre, Path.SEPARATOR) - } else { - qualifiedPathPre - } - - // There are three cases possible with each path - // 1. The path is a directory and has children files in it. Then it must be present in - // leafDirToChildrenFiles as those children files will have been found as leaf files. - // Find its children files from leafDirToChildrenFiles and include them. - // 2. The path is a file, then it will be present in leafFiles. Include this path. - // 3. The path is a directory, but has no children files. Do not include this path. - - leafDirToChildrenFiles.get(qualifiedPath) - .orElse { leafFiles.get(qualifiedPath).map(Array(_)) } - .getOrElse(Array.empty) - } - } else { - leafFiles.values.toSeq - } - } - - protected def inferPartitioning(): PartitionSpec = { - // We use leaf dirs containing data files to discover the schema. - val leafDirs = leafDirToChildrenFiles.filter { case (_, files) => - // SPARK-15895: Metadata files (e.g. Parquet summary files) and temporary files should not be - // counted as data files, so that they shouldn't participate partition discovery. 
- files.exists(f => isDataPath(f.getPath)) - }.keys.toSeq - partitionSchema match { - case Some(userProvidedSchema) if userProvidedSchema.nonEmpty => - val spec = PartitioningUtils.parsePartitions( - leafDirs, - PartitioningUtils.DEFAULT_PARTITION_NAME, - typeInference = false, - basePaths = basePaths) - - // Without auto inference, all of value in the `row` should be null or in StringType, - // we need to cast into the data type that user specified. - def castPartitionValuesToUserSchema(row: InternalRow) = { - InternalRow((0 until row.numFields).map { i => - Cast( - Literal.create(row.getUTF8String(i), StringType), - userProvidedSchema.fields(i).dataType).eval() - }: _*) - } - - PartitionSpec(userProvidedSchema, spec.partitions.map { part => - part.copy(values = castPartitionValuesToUserSchema(part.values)) - }) - case _ => - PartitioningUtils.parsePartitions( - leafDirs, - PartitioningUtils.DEFAULT_PARTITION_NAME, - typeInference = sparkSession.sessionState.conf.partitionColumnTypeInferenceEnabled, - basePaths = basePaths) - } - } - - private def prunePartitions( - predicates: Seq[Expression], - partitionSpec: PartitionSpec): Seq[PartitionDirectory] = { - val PartitionSpec(partitionColumns, partitions) = partitionSpec - val partitionColumnNames = partitionColumns.map(_.name).toSet - val partitionPruningPredicates = predicates.filter { - _.references.map(_.name).toSet.subsetOf(partitionColumnNames) - } - - if (partitionPruningPredicates.nonEmpty) { - val predicate = partitionPruningPredicates.reduce(expressions.And) - - val boundPredicate = InterpretedPredicate.create(predicate.transform { - case a: AttributeReference => - val index = partitionColumns.indexWhere(a.name == _.name) - BoundReference(index, partitionColumns(index).dataType, nullable = true) - }) - - val selected = partitions.filter { - case PartitionDirectory(values, _) => boundPredicate(values) - } - logInfo { - val total = partitions.length - val selectedSize = selected.length - val percentPruned = (1 - selectedSize.toDouble / total.toDouble) * 100 - s"Selected $selectedSize partitions out of $total, pruned $percentPruned% partitions." - } - - selected - } else { - partitions - } - } - - /** - * Contains a set of paths that are considered as the base dirs of the input datasets. - * The partitioning discovery logic will make sure it will stop when it reaches any - * base path. - * - * By default, the paths of the dataset provided by users will be base paths. - * Below are three typical examples, - * Case 1) `spark.read.parquet("/path/something=true/")`: the base path will be - * `/path/something=true/`, and the returned DataFrame will not contain a column of `something`. - * Case 2) `spark.read.parquet("/path/something=true/a.parquet")`: the base path will be - * still `/path/something=true/`, and the returned DataFrame will also not contain a column of - * `something`. - * Case 3) `spark.read.parquet("/path/")`: the base path will be `/path/`, and the returned - * DataFrame will have the column of `something`. - * - * Users also can override the basePath by setting `basePath` in the options to pass the new base - * path to the data source. - * For example, `spark.read.option("basePath", "/path/").parquet("/path/something=true/")`, - * and the returned DataFrame will have the column of `something`. 
- */ - private def basePaths: Set[Path] = { - parameters.get("basePath").map(new Path(_)) match { - case Some(userDefinedBasePath) => - val fs = userDefinedBasePath.getFileSystem(hadoopConf) - if (!fs.isDirectory(userDefinedBasePath)) { - throw new IllegalArgumentException("Option 'basePath' must be a directory") - } - Set(fs.makeQualified(userDefinedBasePath)) - - case None => - paths.map { path => - // Make the path qualified (consistent with listLeafFiles and listLeafFilesInParallel). - val qualifiedPath = path.getFileSystem(hadoopConf).makeQualified(path) - if (leafFiles.contains(qualifiedPath)) qualifiedPath.getParent else qualifiedPath }.toSet - } - } - - private def isDataPath(path: Path): Boolean = { - val name = path.getName - !((name.startsWith("_") && !name.contains("=")) || name.startsWith(".")) - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala new file mode 100644 index 0000000000000..a8a722dd3c620 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala @@ -0,0 +1,437 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources + +import java.io.FileNotFoundException + +import scala.collection.mutable + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs._ +import org.apache.hadoop.mapred.{FileInputFormat, JobConf} + +import org.apache.spark.internal.Logging +import org.apache.spark.metrics.source.HiveCatalogMetrics +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.{expressions, InternalRow} +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types.{StringType, StructType} +import org.apache.spark.util.SerializableConfiguration + +/** + * An abstract class that represents [[FileIndex]]s that are aware of partitioned tables. + * It provides the necessary methods to parse partition data based on a set of files. + * + * @param parameters as set of options to control partition discovery + * @param userPartitionSchema an optional partition schema that will be use to provide types for + * the discovered partitions + */ +abstract class PartitioningAwareFileIndex( + sparkSession: SparkSession, + parameters: Map[String, String], + userPartitionSchema: Option[StructType], + fileStatusCache: FileStatusCache = NoopCache) extends FileIndex with Logging { + import PartitioningAwareFileIndex.BASE_PATH_PARAM + + /** Returns the specification of the partitions inferred from the data. 
*/ + def partitionSpec(): PartitionSpec + + override def partitionSchema: StructType = partitionSpec().partitionColumns + + protected val hadoopConf = sparkSession.sessionState.newHadoopConfWithOptions(parameters) + + protected def leafFiles: mutable.LinkedHashMap[Path, FileStatus] + + protected def leafDirToChildrenFiles: Map[Path, Array[FileStatus]] + + override def listFiles(filters: Seq[Expression]): Seq[PartitionDirectory] = { + val selectedPartitions = if (partitionSpec().partitionColumns.isEmpty) { + PartitionDirectory(InternalRow.empty, allFiles().filter(f => isDataPath(f.getPath))) :: Nil + } else { + prunePartitions(filters, partitionSpec()).map { + case PartitionPath(values, path) => + val files: Seq[FileStatus] = leafDirToChildrenFiles.get(path) match { + case Some(existingDir) => + // Directory has children files in it, return them + existingDir.filter(f => isDataPath(f.getPath)) + + case None => + // Directory does not exist, or has no children files + Nil + } + PartitionDirectory(values, files) + } + } + logTrace("Selected files after partition pruning:\n\t" + selectedPartitions.mkString("\n\t")) + selectedPartitions + } + + /** Returns the list of files that will be read when scanning this relation. */ + override def inputFiles: Array[String] = + allFiles().map(_.getPath.toUri.toString).toArray + + override def sizeInBytes: Long = allFiles().map(_.getLen).sum + + def allFiles(): Seq[FileStatus] = { + if (partitionSpec().partitionColumns.isEmpty) { + // For each of the root input paths, get the list of files inside them + rootPaths.flatMap { path => + // Make the path qualified (consistent with listLeafFiles and listLeafFilesInParallel). + val fs = path.getFileSystem(hadoopConf) + val qualifiedPathPre = fs.makeQualified(path) + val qualifiedPath: Path = if (qualifiedPathPre.isRoot && !qualifiedPathPre.isAbsolute) { + // SPARK-17613: Always append `Path.SEPARATOR` to the end of parent directories, + // because the `leafFile.getParent` would have returned an absolute path with the + // separator at the end. + new Path(qualifiedPathPre, Path.SEPARATOR) + } else { + qualifiedPathPre + } + + // There are three cases possible with each path + // 1. The path is a directory and has children files in it. Then it must be present in + // leafDirToChildrenFiles as those children files will have been found as leaf files. + // Find its children files from leafDirToChildrenFiles and include them. + // 2. The path is a file, then it will be present in leafFiles. Include this path. + // 3. The path is a directory, but has no children files. Do not include this path. + + leafDirToChildrenFiles.get(qualifiedPath) + .orElse { leafFiles.get(qualifiedPath).map(Array(_)) } + .getOrElse(Array.empty) + } + } else { + leafFiles.values.toSeq + } + } + + protected def inferPartitioning(): PartitionSpec = { + // We use leaf dirs containing data files to discover the schema. + val leafDirs = leafDirToChildrenFiles.filter { case (_, files) => + files.exists(f => isDataPath(f.getPath)) + }.keys.toSeq + userPartitionSchema match { + case Some(userProvidedSchema) if userProvidedSchema.nonEmpty => + val spec = PartitioningUtils.parsePartitions( + leafDirs, + PartitioningUtils.DEFAULT_PARTITION_NAME, + typeInference = false, + basePaths = basePaths) + + // Without auto inference, all of value in the `row` should be null or in StringType, + // we need to cast into the data type that user specified. 
+ def castPartitionValuesToUserSchema(row: InternalRow) = { + InternalRow((0 until row.numFields).map { i => + Cast( + Literal.create(row.getUTF8String(i), StringType), + userProvidedSchema.fields(i).dataType).eval() + }: _*) + } + + PartitionSpec(userProvidedSchema, spec.partitions.map { part => + part.copy(values = castPartitionValuesToUserSchema(part.values)) + }) + case _ => + PartitioningUtils.parsePartitions( + leafDirs, + PartitioningUtils.DEFAULT_PARTITION_NAME, + typeInference = sparkSession.sessionState.conf.partitionColumnTypeInferenceEnabled, + basePaths = basePaths) + } + } + + private def prunePartitions( + predicates: Seq[Expression], + partitionSpec: PartitionSpec): Seq[PartitionPath] = { + val PartitionSpec(partitionColumns, partitions) = partitionSpec + val partitionColumnNames = partitionColumns.map(_.name).toSet + val partitionPruningPredicates = predicates.filter { + _.references.map(_.name).toSet.subsetOf(partitionColumnNames) + } + + if (partitionPruningPredicates.nonEmpty) { + val predicate = partitionPruningPredicates.reduce(expressions.And) + + val boundPredicate = InterpretedPredicate.create(predicate.transform { + case a: AttributeReference => + val index = partitionColumns.indexWhere(a.name == _.name) + BoundReference(index, partitionColumns(index).dataType, nullable = true) + }) + + val selected = partitions.filter { + case PartitionPath(values, _) => boundPredicate(values) + } + logInfo { + val total = partitions.length + val selectedSize = selected.length + val percentPruned = (1 - selectedSize.toDouble / total.toDouble) * 100 + s"Selected $selectedSize partitions out of $total, pruned $percentPruned% partitions." + } + + selected + } else { + partitions + } + } + + /** + * Contains a set of paths that are considered as the base dirs of the input datasets. + * The partitioning discovery logic will make sure it will stop when it reaches any + * base path. + * + * By default, the paths of the dataset provided by users will be base paths. + * Below are three typical examples, + * Case 1) `spark.read.parquet("/path/something=true/")`: the base path will be + * `/path/something=true/`, and the returned DataFrame will not contain a column of `something`. + * Case 2) `spark.read.parquet("/path/something=true/a.parquet")`: the base path will be + * still `/path/something=true/`, and the returned DataFrame will also not contain a column of + * `something`. + * Case 3) `spark.read.parquet("/path/")`: the base path will be `/path/`, and the returned + * DataFrame will have the column of `something`. + * + * Users also can override the basePath by setting `basePath` in the options to pass the new base + * path to the data source. + * For example, `spark.read.option("basePath", "/path/").parquet("/path/something=true/")`, + * and the returned DataFrame will have the column of `something`. + */ + private def basePaths: Set[Path] = { + parameters.get(BASE_PATH_PARAM).map(new Path(_)) match { + case Some(userDefinedBasePath) => + val fs = userDefinedBasePath.getFileSystem(hadoopConf) + if (!fs.isDirectory(userDefinedBasePath)) { + throw new IllegalArgumentException(s"Option '$BASE_PATH_PARAM' must be a directory") + } + Set(fs.makeQualified(userDefinedBasePath)) + + case None => + rootPaths.map { path => + // Make the path qualified (consistent with listLeafFiles and listLeafFilesInParallel). 
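// Illustrative aside: a deliberately simplified, catalyst-free sketch of prunePartitions()
// above. Predicates are modelled as (referenced column names, evaluator) pairs instead of
// Expressions; only those that reference partition columns alone participate in pruning, and
// the same "selected X out of Y" statistic is computed for logging.
object PrunePartitionsSketch {
  case class SimplePartition(values: Map[String, Int], path: String)

  def main(args: Array[String]): Unit = {
    val partitionColumnNames = Set("year", "month")
    val partitions = Seq(
      SimplePartition(Map("year" -> 2016, "month" -> 1), "/t/year=2016/month=1"),
      SimplePartition(Map("year" -> 2016, "month" -> 2), "/t/year=2016/month=2"),
      SimplePartition(Map("year" -> 2015, "month" -> 1), "/t/year=2015/month=1"))

    val predicates = Seq(
      (Set("year"), (p: SimplePartition) => p.values("year") == 2016),
      (Set("id"), (p: SimplePartition) => true))   // references a data column, so it is skipped

    val pruningPredicates =
      predicates.filter { case (refs, _) => refs.subsetOf(partitionColumnNames) }
    val selected = partitions.filter(p => pruningPredicates.forall { case (_, eval) => eval(p) })

    val percentPruned = (1 - selected.size.toDouble / partitions.size.toDouble) * 100
    println(s"Selected ${selected.size} partitions out of ${partitions.size}, " +
      s"pruned $percentPruned% partitions.")
  }
}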
+ val qualifiedPath = path.getFileSystem(hadoopConf).makeQualified(path) + if (leafFiles.contains(qualifiedPath)) qualifiedPath.getParent else qualifiedPath }.toSet + } + } + + // SPARK-15895: Metadata files (e.g. Parquet summary files) and temporary files should not be + // counted as data files, so that they shouldn't participate partition discovery. + private def isDataPath(path: Path): Boolean = { + val name = path.getName + !((name.startsWith("_") && !name.contains("=")) || name.startsWith(".")) + } + + /** + * List leaf files of given paths. This method will submit a Spark job to do parallel + * listing whenever there is a path having more files than the parallel partition discovery + * discovery threshold. + * + * This is publicly visible for testing. + */ + def listLeafFiles(paths: Seq[Path]): mutable.LinkedHashSet[FileStatus] = { + val output = mutable.LinkedHashSet[FileStatus]() + val pathsToFetch = mutable.ArrayBuffer[Path]() + for (path <- paths) { + fileStatusCache.getLeafFiles(path) match { + case Some(files) => + HiveCatalogMetrics.incrementFileCacheHits(files.length) + output ++= files + case None => + pathsToFetch += path + } + } + val discovered = if (pathsToFetch.length >= + sparkSession.sessionState.conf.parallelPartitionDiscoveryThreshold) { + PartitioningAwareFileIndex.listLeafFilesInParallel(pathsToFetch, hadoopConf, sparkSession) + } else { + PartitioningAwareFileIndex.listLeafFilesInSerial(pathsToFetch, hadoopConf) + } + discovered.foreach { case (path, leafFiles) => + HiveCatalogMetrics.incrementFilesDiscovered(leafFiles.size) + fileStatusCache.putLeafFiles(path, leafFiles.toArray) + output ++= leafFiles + } + output + } +} + +object PartitioningAwareFileIndex extends Logging { + val BASE_PATH_PARAM = "basePath" + + /** A serializable variant of HDFS's BlockLocation. */ + private case class SerializableBlockLocation( + names: Array[String], + hosts: Array[String], + offset: Long, + length: Long) + + /** A serializable variant of HDFS's FileStatus. */ + private case class SerializableFileStatus( + path: String, + length: Long, + isDir: Boolean, + blockReplication: Short, + blockSize: Long, + modificationTime: Long, + accessTime: Long, + blockLocations: Array[SerializableBlockLocation]) + + /** + * List a collection of path recursively. + */ + private def listLeafFilesInSerial( + paths: Seq[Path], + hadoopConf: Configuration): Seq[(Path, Seq[FileStatus])] = { + // Dummy jobconf to get to the pathFilter defined in configuration + val jobConf = new JobConf(hadoopConf, this.getClass) + val filter = FileInputFormat.getInputPathFilter(jobConf) + + paths.map { path => + val fs = path.getFileSystem(hadoopConf) + (path, listLeafFiles0(fs, path, filter)) + } + } + + /** + * List a collection of path recursively in parallel (using Spark executors). + * Each task launched will use [[listLeafFilesInSerial]] to list. + */ + private def listLeafFilesInParallel( + paths: Seq[Path], + hadoopConf: Configuration, + sparkSession: SparkSession): Seq[(Path, Seq[FileStatus])] = { + assert(paths.size >= sparkSession.sessionState.conf.parallelPartitionDiscoveryThreshold) + logInfo(s"Listing leaf files and directories in parallel under: ${paths.mkString(", ")}") + + val sparkContext = sparkSession.sparkContext + val serializableConfiguration = new SerializableConfiguration(hadoopConf) + val serializedPaths = paths.map(_.toString) + + // Set the number of parallelism to prevent following file listing from generating many tasks + // in case of large #defaultParallelism. 
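// Illustrative aside: a toy sketch (string paths, invented threshold and listing stubs) of the
// flow in listLeafFiles() above: answer from the FileStatusCache where possible, list only the
// missing paths, switch to the parallel (Spark-job) listing once their number crosses the
// configured threshold, and backfill the cache with whatever was discovered.
object CachedListingSketch {
  def main(args: Array[String]): Unit = {
    val parallelThreshold = 32
    val cache = scala.collection.mutable.Map("/t/p=1" -> Seq("/t/p=1/part-00000"))

    def listInSerial(paths: Seq[String]): Map[String, Seq[String]] =
      paths.map(p => p -> Seq(s"$p/part-00000")).toMap   // stands in for fs.listStatus
    def listInParallel(paths: Seq[String]): Map[String, Seq[String]] =
      listInSerial(paths)                                // stands in for the listing Spark job

    val requested = Seq("/t/p=1", "/t/p=2", "/t/p=3")
    val (hits, misses) = requested.partition(cache.contains)

    val output = scala.collection.mutable.LinkedHashSet[String]()
    hits.foreach(p => output ++= cache(p))

    val discovered =
      if (misses.length >= parallelThreshold) listInParallel(misses) else listInSerial(misses)
    discovered.foreach { case (path, files) =>
      cache(path) = files                                // the next lookup is then a cache hit
      output ++= files
    }

    println(output.mkString("\n"))
  }
}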
+ val numParallelism = Math.min(paths.size, 10000) + + val statusMap = sparkContext + .parallelize(serializedPaths, numParallelism) + .mapPartitions { paths => + val hadoopConf = serializableConfiguration.value + listLeafFilesInSerial(paths.map(new Path(_)).toSeq, hadoopConf).iterator + }.map { case (path, statuses) => + val serializableStatuses = statuses.map { status => + // Turn FileStatus into SerializableFileStatus so we can send it back to the driver + val blockLocations = status match { + case f: LocatedFileStatus => + f.getBlockLocations.map { loc => + SerializableBlockLocation( + loc.getNames, + loc.getHosts, + loc.getOffset, + loc.getLength) + } + + case _ => + Array.empty[SerializableBlockLocation] + } + + SerializableFileStatus( + status.getPath.toString, + status.getLen, + status.isDirectory, + status.getReplication, + status.getBlockSize, + status.getModificationTime, + status.getAccessTime, + blockLocations) + } + (path.toString, serializableStatuses) + }.collect() + + // turn SerializableFileStatus back to Status + statusMap.map { case (path, serializableStatuses) => + val statuses = serializableStatuses.map { f => + val blockLocations = f.blockLocations.map { loc => + new BlockLocation(loc.names, loc.hosts, loc.offset, loc.length) + } + new LocatedFileStatus( + new FileStatus( + f.length, f.isDir, f.blockReplication, f.blockSize, f.modificationTime, + new Path(f.path)), + blockLocations) + } + (new Path(path), statuses) + } + } + + /** + * List a single path, provided as a FileStatus, in serial. + */ + private def listLeafFiles0( + fs: FileSystem, path: Path, filter: PathFilter): Seq[FileStatus] = { + logTrace(s"Listing $path") + val name = path.getName.toLowerCase + if (shouldFilterOut(name)) { + Seq.empty[FileStatus] + } else { + // [SPARK-17599] Prevent InMemoryFileIndex from failing if path doesn't exist + // Note that statuses only include FileStatus for the files and dirs directly under path, + // and does not include anything else recursively. + val statuses = try fs.listStatus(path) catch { + case _: FileNotFoundException => + logWarning(s"The directory $path was not found. Was it deleted very recently?") + Array.empty[FileStatus] + } + + val allLeafStatuses = { + val (dirs, files) = statuses.partition(_.isDirectory) + val stats = files ++ dirs.flatMap(dir => listLeafFiles0(fs, dir.getPath, filter)) + if (filter != null) stats.filter(f => filter.accept(f.getPath)) else stats + } + + allLeafStatuses.filterNot(status => shouldFilterOut(status.getPath.getName)).map { + case f: LocatedFileStatus => + f + + // NOTE: + // + // - Although S3/S3A/S3N file system can be quite slow for remote file metadata + // operations, calling `getFileBlockLocations` does no harm here since these file system + // implementations don't actually issue RPC for this method. + // + // - Here we are calling `getFileBlockLocations` in a sequential manner, but it should not + // be a big deal since we always use to `listLeafFilesInParallel` when the number of + // paths exceeds threshold. + case f => + // The other constructor of LocatedFileStatus will call FileStatus.getPermission(), + // which is very slow on some file system (RawLocalFileSystem, which is launch a + // subprocess and parse the stdout). 
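// Illustrative aside: a minimal standalone sketch of the round trip performed above. Hadoop's
// FileStatus is not Serializable, so the listing job ships a plain case class back to the
// driver and rebuilds FileStatus objects there. Block locations are omitted here to keep the
// sketch short; the field values are invented.
import org.apache.hadoop.fs.{FileStatus, Path}

object SerializableStatusSketch {
  case class SerializableFileStatus(path: String, length: Long, isDir: Boolean,
      blockReplication: Short, blockSize: Long, modificationTime: Long)

  def toSerializable(s: FileStatus): SerializableFileStatus =
    SerializableFileStatus(s.getPath.toString, s.getLen, s.isDirectory,
      s.getReplication, s.getBlockSize, s.getModificationTime)

  def fromSerializable(s: SerializableFileStatus): FileStatus =
    new FileStatus(s.length, s.isDir, s.blockReplication, s.blockSize,
      s.modificationTime, new Path(s.path))

  def main(args: Array[String]): Unit = {
    val original =
      new FileStatus(1024L, false, 3, 128L * 1024 * 1024, 0L, new Path("/t/part-00000"))
    val restored = fromSerializable(toSerializable(original))
    println(s"${restored.getPath} ${restored.getLen}")   // /t/part-00000 1024
  }
}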
+ val locations = fs.getFileBlockLocations(f, 0, f.getLen) + val lfs = new LocatedFileStatus(f.getLen, f.isDirectory, f.getReplication, f.getBlockSize, + f.getModificationTime, 0, null, null, null, null, f.getPath, locations) + if (f.isSymlink) { + lfs.setSymlink(f.getSymlink) + } + lfs + } + } + } + + /** Checks if we should filter out this path name. */ + def shouldFilterOut(pathName: String): Boolean = { + // We filter everything that starts with _ and ., except _common_metadata and _metadata + // because Parquet needs to find those metadata files from leaf files returned by this method. + // We should refactor this logic to not mix metadata files with data files. + ((pathName.startsWith("_") && !pathName.contains("=")) || pathName.startsWith(".")) && + !pathName.startsWith("_common_metadata") && !pathName.startsWith("_metadata") + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala index 504464216e5a4..f66e8b4e2b551 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.datasources import java.lang.{Double => JDouble, Long => JLong} import java.math.{BigDecimal => JBigDecimal} +import java.sql.{Date => JDate, Timestamp => JTimestamp} import scala.collection.mutable.ArrayBuffer import scala.util.Try @@ -28,13 +29,14 @@ import org.apache.hadoop.util.Shell import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.Resolver import org.apache.spark.sql.catalyst.expressions.{Cast, Literal} import org.apache.spark.sql.types._ // TODO: We should tighten up visibility of the classes here once we clean up Hive coupling. -object PartitionDirectory { - def apply(values: InternalRow, path: String): PartitionDirectory = +object PartitionPath { + def apply(values: InternalRow, path: String): PartitionPath = apply(values, new Path(path)) } @@ -42,14 +44,14 @@ object PartitionDirectory { * Holds a directory in a partitioned collection of files as well as as the partition values * in the form of a Row. Before scanning, the files at `path` need to be enumerated. */ -case class PartitionDirectory(values: InternalRow, path: Path) +case class PartitionPath(values: InternalRow, path: Path) case class PartitionSpec( partitionColumns: StructType, - partitions: Seq[PartitionDirectory]) + partitions: Seq[PartitionPath]) object PartitionSpec { - val emptySpec = PartitionSpec(StructType(Seq.empty[StructField]), Seq.empty[PartitionDirectory]) + val emptySpec = PartitionSpec(StructType(Seq.empty[StructField]), Seq.empty[PartitionPath]) } object PartitioningUtils { @@ -141,7 +143,7 @@ object PartitioningUtils { // Finally, we create `Partition`s based on paths and resolved partition values. val partitions = resolvedPartitionValues.zip(pathsWithPartitionValues).map { case (PartitionValues(_, literals), (path, _)) => - PartitionDirectory(InternalRow.fromSeq(literals.map(_.value)), path) + PartitionPath(InternalRow.fromSeq(literals.map(_.value)), path) } PartitionSpec(StructType(fields), partitions) @@ -242,6 +244,35 @@ object PartitioningUtils { } } + /** + * Normalize the column names in partition specification, w.r.t. the real partition column names + * and case sensitivity. 
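// Illustrative aside: a quick standalone check of the shouldFilterOut() rules above,
// reimplemented verbatim so the expected behaviour for a few representative file names can be
// seen at a glance. The sample names are illustrative only.
object ShouldFilterOutSketch {
  def shouldFilterOut(pathName: String): Boolean =
    ((pathName.startsWith("_") && !pathName.contains("=")) || pathName.startsWith(".")) &&
      !pathName.startsWith("_common_metadata") && !pathName.startsWith("_metadata")

  def main(args: Array[String]): Unit = {
    // _SUCCESS and the hidden .crc file are skipped; the Parquet metadata files, data files
    // and partition directories are kept.
    Seq("_SUCCESS", ".part-00000.crc", "_metadata", "_common_metadata", "part-00000", "year=2016")
      .foreach(name => println(f"$name%-20s filtered out: ${shouldFilterOut(name)}"))
  }
}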
e.g., if the partition spec has a column named `monTh`, and there is a + * partition column named `month`, and it's case insensitive, we will normalize `monTh` to + * `month`. + */ + def normalizePartitionSpec[T]( + partitionSpec: Map[String, T], + partColNames: Seq[String], + tblName: String, + resolver: Resolver): Map[String, T] = { + val normalizedPartSpec = partitionSpec.toSeq.map { case (key, value) => + val normalizedKey = partColNames.find(resolver(_, key)).getOrElse { + throw new AnalysisException(s"$key is not a valid partition column in table $tblName.") + } + normalizedKey -> value + } + + if (normalizedPartSpec.map(_._1).distinct.length != normalizedPartSpec.length) { + val duplicateColumns = normalizedPartSpec.map(_._1).groupBy(identity).collect { + case (x, ys) if ys.length > 1 => x + } + throw new AnalysisException(s"Found duplicated columns in partition specification: " + + duplicateColumns.mkString(", ")) + } + + normalizedPartSpec.toMap + } + /** * Resolves possible type conflicts between partitions by up-casting "lower" types. The up- * casting order is: @@ -307,20 +338,34 @@ object PartitioningUtils { /** * Converts a string to a [[Literal]] with automatic type inference. Currently only supports - * [[IntegerType]], [[LongType]], [[DoubleType]], [[DecimalType.SYSTEM_DEFAULT]], and - * [[StringType]]. + * [[IntegerType]], [[LongType]], [[DoubleType]], [[DecimalType]], [[DateType]] + * [[TimestampType]], and [[StringType]]. */ private[datasources] def inferPartitionColumnValue( raw: String, defaultPartitionName: String, typeInference: Boolean): Literal = { + val decimalTry = Try { + // `BigDecimal` conversion can fail when the `field` is not a form of number. + val bigDecimal = new JBigDecimal(raw) + // It reduces the cases for decimals by disallowing values having scale (eg. `1.1`). + require(bigDecimal.scale <= 0) + // `DecimalType` conversion can fail when + // 1. The precision is bigger than 38. + // 2. scale is bigger than precision. + Literal(bigDecimal) + } + if (typeInference) { // First tries integral types Try(Literal.create(Integer.parseInt(raw), IntegerType)) .orElse(Try(Literal.create(JLong.parseLong(raw), LongType))) + .orElse(decimalTry) // Then falls back to fractional types .orElse(Try(Literal.create(JDouble.parseDouble(raw), DoubleType))) - .orElse(Try(Literal(new JBigDecimal(raw)))) + // Then falls back to date/timestamp types + .orElse(Try(Literal(JDate.valueOf(raw)))) + .orElse(Try(Literal(JTimestamp.valueOf(unescapePathName(raw))))) // Then falls back to string .getOrElse { if (raw == defaultPartitionName) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala new file mode 100644 index 0000000000000..8566a8061034b --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
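// Illustrative aside: a catalyst-free sketch of the inference order added above: integral
// types first, then scale-0 decimals, doubles, java.sql.Date/Timestamp, and finally plain
// strings. The sample values are illustrative; the real code wraps each result in a Literal
// and unescapes the raw path name before trying Timestamp.valueOf.
import java.math.{BigDecimal => JBigDecimal}
import java.sql.{Date => JDate, Timestamp => JTimestamp}

import scala.util.Try

object InferPartitionValueSketch {
  def infer(raw: String): Any = {
    val decimalTry = Try {
      val d = new JBigDecimal(raw)
      require(d.scale <= 0)   // values with a scale (e.g. "1.1") fall through to Double
      d
    }
    Try[Any](Integer.parseInt(raw))
      .orElse(Try(java.lang.Long.parseLong(raw)))
      .orElse(decimalTry)
      .orElse(Try(java.lang.Double.parseDouble(raw)))
      .orElse(Try(JDate.valueOf(raw)))
      .orElse(Try(JTimestamp.valueOf(raw)))
      .getOrElse(raw)
  }

  def main(args: Array[String]): Unit = {
    Seq("42", "9999999999", "123456789012345678901234567890", "1.1",
        "2016-10-01", "2016-10-01 12:34:56", "foo")
      .foreach { v =>
        val inferred = infer(v)
        println(s"$v -> $inferred (${inferred.getClass.getSimpleName})")
      }
  }
}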
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.planning.PhysicalOperation +import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project} +import org.apache.spark.sql.catalyst.rules.Rule + +private[sql] object PruneFileSourcePartitions extends Rule[LogicalPlan] { + override def apply(plan: LogicalPlan): LogicalPlan = plan transformDown { + case op @ PhysicalOperation(projects, filters, + logicalRelation @ + LogicalRelation(fsRelation @ + HadoopFsRelation( + catalogFileIndex: CatalogFileIndex, + partitionSchema, + _, + _, + _, + _), + _, + _)) + if filters.nonEmpty && fsRelation.partitionSchemaOption.isDefined => + // The attribute name of predicate could be different than the one in schema in case of + // case insensitive, we should change them to match the one in schema, so we donot need to + // worry about case sensitivity anymore. + val normalizedFilters = filters.map { e => + e transform { + case a: AttributeReference => + a.withName(logicalRelation.output.find(_.semanticEquals(a)).get.name) + } + } + + val sparkSession = fsRelation.sparkSession + val partitionColumns = + logicalRelation.resolve( + partitionSchema, sparkSession.sessionState.analyzer.resolver) + val partitionSet = AttributeSet(partitionColumns) + val partitionKeyFilters = + ExpressionSet(normalizedFilters.filter(_.references.subsetOf(partitionSet))) + + if (partitionKeyFilters.nonEmpty) { + val prunedFileIndex = catalogFileIndex.filterPartitions(partitionKeyFilters.toSeq) + val prunedFsRelation = + fsRelation.copy(location = prunedFileIndex)(sparkSession) + val prunedLogicalRelation = logicalRelation.copy( + relation = prunedFsRelation, + expectedOutputAttributes = Some(logicalRelation.output)) + + // Keep partition-pruning predicates so that they are visible in physical planning + val filterExpression = filters.reduceLeft(And) + val filter = Filter(filterExpression, prunedLogicalRelation) + Project(projects, filter) + } else { + op + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriteOutput.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriteOutput.scala new file mode 100644 index 0000000000000..bd56e511d0ccf --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriteOutput.scala @@ -0,0 +1,476 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources + +import java.util.{Date, UUID} + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path +import org.apache.hadoop.mapreduce._ +import org.apache.hadoop.mapreduce.lib.output.{FileOutputCommitter, FileOutputFormat} +import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl + +import org.apache.spark._ +import org.apache.spark.internal.Logging +import org.apache.spark.mapred.SparkHadoopMapRedUtil +import org.apache.spark.sql.{Dataset, SparkSession} +import org.apache.spark.sql.catalyst.catalog.BucketSpec +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.execution.{SQLExecution, UnsafeKVExternalSorter} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} +import org.apache.spark.util.{SerializableConfiguration, Utils} +import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter + + +/** A helper object for writing data out to a location. */ +object WriteOutput extends Logging { + + /** A shared job description for all the write tasks. */ + private class WriteJobDescription( + val uuid: String, // prevent collision between different (appending) write jobs + val serializableHadoopConf: SerializableConfiguration, + val outputWriterFactory: OutputWriterFactory, + val allColumns: Seq[Attribute], + val partitionColumns: Seq[Attribute], + val nonPartitionColumns: Seq[Attribute], + val bucketSpec: Option[BucketSpec], + val isAppend: Boolean, + val path: String, + val outputFormatClass: Class[_ <: OutputFormat[_, _]]) + extends Serializable { + + assert(AttributeSet(allColumns) == AttributeSet(partitionColumns ++ nonPartitionColumns), + s""" + |All columns: ${allColumns.mkString(", ")} + |Partition columns: ${partitionColumns.mkString(", ")} + |Non-partition columns: ${nonPartitionColumns.mkString(", ")} + """.stripMargin) + } + + /** + * Basic work flow of this command is: + * 1. Driver side setup, including output committer initialization and data source specific + * preparation work for the write job to be issued. + * 2. Issues a write job consists of one or more executor side tasks, each of which writes all + * rows within an RDD partition. + * 3. If no exception is thrown in a task, commits that task, otherwise aborts that task; If any + * exception is thrown during task commitment, also aborts that task. + * 4. If all tasks are committed, commit the job, otherwise aborts the job; If any exception is + * thrown during job commitment, also aborts the job. 
+ */ + def write( + sparkSession: SparkSession, + plan: LogicalPlan, + fileFormat: FileFormat, + outputPath: Path, + hadoopConf: Configuration, + partitionColumns: Seq[Attribute], + bucketSpec: Option[BucketSpec], + refreshFunction: () => Unit, + options: Map[String, String], + isAppend: Boolean): Unit = { + + val job = Job.getInstance(hadoopConf) + job.setOutputKeyClass(classOf[Void]) + job.setOutputValueClass(classOf[InternalRow]) + FileOutputFormat.setOutputPath(job, outputPath) + + val partitionSet = AttributeSet(partitionColumns) + val dataColumns = plan.output.filterNot(partitionSet.contains) + val queryExecution = Dataset.ofRows(sparkSession, plan).queryExecution + + // Note: prepareWrite has side effect. It sets "job". + val outputWriterFactory = + fileFormat.prepareWrite(sparkSession, job, options, dataColumns.toStructType) + + val description = new WriteJobDescription( + uuid = UUID.randomUUID().toString, + serializableHadoopConf = new SerializableConfiguration(job.getConfiguration), + outputWriterFactory = outputWriterFactory, + allColumns = plan.output, + partitionColumns = partitionColumns, + nonPartitionColumns = dataColumns, + bucketSpec = bucketSpec, + isAppend = isAppend, + path = outputPath.toString, + outputFormatClass = job.getOutputFormatClass) + + SQLExecution.withNewExecutionId(sparkSession, queryExecution) { + // This call shouldn't be put into the `try` block below because it only initializes and + // prepares the job, any exception thrown from here shouldn't cause abortJob() to be called. + val committer = setupDriverCommitter(job, outputPath.toString, isAppend) + + try { + sparkSession.sparkContext.runJob(queryExecution.toRdd, + (taskContext: TaskContext, iter: Iterator[InternalRow]) => { + executeTask( + description = description, + sparkStageId = taskContext.stageId(), + sparkPartitionId = taskContext.partitionId(), + sparkAttemptNumber = taskContext.attemptNumber(), + iterator = iter) + }) + + committer.commitJob(job) + logInfo(s"Job ${job.getJobID} committed.") + refreshFunction() + } catch { case cause: Throwable => + logError(s"Aborting job ${job.getJobID}.", cause) + committer.abortJob(job, JobStatus.State.FAILED) + throw new SparkException("Job aborted.", cause) + } + } + } + + /** Writes data out in a single Spark task. */ + private def executeTask( + description: WriteJobDescription, + sparkStageId: Int, + sparkPartitionId: Int, + sparkAttemptNumber: Int, + iterator: Iterator[InternalRow]): Unit = { + + val jobId = SparkHadoopWriter.createJobID(new Date, sparkStageId) + val taskId = new TaskID(jobId, TaskType.MAP, sparkPartitionId) + val taskAttemptId = new TaskAttemptID(taskId, sparkAttemptNumber) + + // Set up the attempt context required to use in the output committer. + val taskAttemptContext: TaskAttemptContext = { + // Set up the configuration object + val hadoopConf = description.serializableHadoopConf.value + hadoopConf.set("mapred.job.id", jobId.toString) + hadoopConf.set("mapred.tip.id", taskAttemptId.getTaskID.toString) + hadoopConf.set("mapred.task.id", taskAttemptId.toString) + hadoopConf.setBoolean("mapred.task.is.map", true) + hadoopConf.setInt("mapred.task.partition", 0) + + new TaskAttemptContextImpl(hadoopConf, taskAttemptId) + } + + val committer = newOutputCommitter( + description.outputFormatClass, taskAttemptContext, description.path, description.isAppend) + committer.setupTask(taskAttemptContext) + + // Figure out where we need to write data to for staging. 
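
executeTask above rebuilds a Hadoop task identity from Spark's stage/partition/attempt numbers and still sets the legacy mapred.* keys so that committers written against the old API see a consistent identity. A rough standalone sketch of that setup; `new JobID("illustrative-job", stageId)` is an assumed stand-in for Spark's internal job-ID helper:

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.mapreduce.{JobID, TaskAttemptContext, TaskAttemptID, TaskID, TaskType}
import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl

object TaskAttemptContextSketch {
  def makeContext(
      conf: Configuration,
      stageId: Int,
      partitionId: Int,
      attemptNumber: Int): TaskAttemptContext = {
    val jobId = new JobID("illustrative-job", stageId)   // stand-in for Spark's job-ID helper
    val taskId = new TaskID(jobId, TaskType.MAP, partitionId)
    val taskAttemptId = new TaskAttemptID(taskId, attemptNumber)

    // The legacy mapred.* keys are still set so committers written against the old
    // mapred API can see the job/task identity.
    conf.set("mapred.job.id", jobId.toString)
    conf.set("mapred.tip.id", taskAttemptId.getTaskID.toString)
    conf.set("mapred.task.id", taskAttemptId.toString)
    conf.setBoolean("mapred.task.is.map", true)
    conf.setInt("mapred.task.partition", 0)

    new TaskAttemptContextImpl(conf, taskAttemptId)
  }
}
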
+ // For FileOutputCommitter it has its own staging path called "work path". + val stagingPath = committer match { + case f: FileOutputCommitter => f.getWorkPath.toString + case _ => description.path + } + + val writeTask = + if (description.partitionColumns.isEmpty && description.bucketSpec.isEmpty) { + new SingleDirectoryWriteTask(description, taskAttemptContext, stagingPath) + } else { + new DynamicPartitionWriteTask(description, taskAttemptContext, stagingPath) + } + + try { + Utils.tryWithSafeFinallyAndFailureCallbacks(block = { + // Execute the task to write rows out + writeTask.execute(iterator) + writeTask.releaseResources() + + // Commit the task + SparkHadoopMapRedUtil.commitTask(committer, taskAttemptContext, jobId.getId, taskId.getId) + })(catchBlock = { + // If there is an error, release resource and then abort the task + try { + writeTask.releaseResources() + } finally { + committer.abortTask(taskAttemptContext) + logError(s"Job $jobId aborted.") + } + }) + } catch { + case t: Throwable => + throw new SparkException("Task failed while writing rows", t) + } + } + + /** + * A simple trait for writing out data in a single Spark task, without any concerns about how + * to commit or abort tasks. Exceptions thrown by the implementation of this trait will + * automatically trigger task aborts. + */ + private trait ExecuteWriteTask { + def execute(iterator: Iterator[InternalRow]): Unit + def releaseResources(): Unit + + final def filePrefix(split: Int, uuid: String, bucketId: Option[Int]): String = { + val bucketString = bucketId.map(BucketingUtils.bucketIdToString).getOrElse("") + f"part-r-$split%05d-$uuid$bucketString" + } + } + + /** Writes data to a single directory (used for non-dynamic-partition writes). */ + private class SingleDirectoryWriteTask( + description: WriteJobDescription, + taskAttemptContext: TaskAttemptContext, + stagingPath: String) extends ExecuteWriteTask { + + private[this] var outputWriter: OutputWriter = { + val split = taskAttemptContext.getTaskAttemptID.getTaskID.getId + + val outputWriter = description.outputWriterFactory.newInstance( + stagingDir = stagingPath, + fileNamePrefix = filePrefix(split, description.uuid, None), + dataSchema = description.nonPartitionColumns.toStructType, + context = taskAttemptContext) + outputWriter.initConverter(dataSchema = description.nonPartitionColumns.toStructType) + outputWriter + } + + override def execute(iter: Iterator[InternalRow]): Unit = { + while (iter.hasNext) { + val internalRow = iter.next() + outputWriter.writeInternal(internalRow) + } + } + + override def releaseResources(): Unit = { + if (outputWriter != null) { + outputWriter.close() + outputWriter = null + } + } + } + + /** + * Writes data to using dynamic partition writes, meaning this single function can write to + * multiple directories (partitions) or files (bucketing). 
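
filePrefix in the trait above fixes the on-disk naming scheme for every write task. The small sketch below reproduces that format for a couple of inputs; the underscore-padded bucket suffix mimics what BucketingUtils.bucketIdToString produces and is an assumption here:

object FilePrefixSketch {
  def filePrefix(split: Int, uuid: String, bucketId: Option[Int]): String = {
    // Assumed rendering of the bucket suffix (Spark derives it via BucketingUtils).
    val bucketString = bucketId.map(id => f"_$id%05d").getOrElse("")
    f"part-r-$split%05d-$uuid$bucketString"
  }

  def main(args: Array[String]): Unit = {
    val uuid = "ea518ad4-455a-4431-b471-d24e03814677"
    println(filePrefix(9, uuid, None))     // part-r-00009-ea518ad4-455a-4431-b471-d24e03814677
    println(filePrefix(9, uuid, Some(2)))  // same prefix with an _00002 bucket suffix
  }
}
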
+ */ + private class DynamicPartitionWriteTask( + description: WriteJobDescription, + taskAttemptContext: TaskAttemptContext, + stagingPath: String) extends ExecuteWriteTask { + + // currentWriter is initialized whenever we see a new key + private var currentWriter: OutputWriter = _ + + private val bucketColumns: Seq[Attribute] = description.bucketSpec.toSeq.flatMap { + spec => spec.bucketColumnNames.map(c => description.allColumns.find(_.name == c).get) + } + + private val sortColumns: Seq[Attribute] = description.bucketSpec.toSeq.flatMap { + spec => spec.sortColumnNames.map(c => description.allColumns.find(_.name == c).get) + } + + private def bucketIdExpression: Option[Expression] = description.bucketSpec.map { spec => + // Use `HashPartitioning.partitionIdExpression` as our bucket id expression, so that we can + // guarantee the data distribution is same between shuffle and bucketed data source, which + // enables us to only shuffle one side when join a bucketed table and a normal one. + HashPartitioning(bucketColumns, spec.numBuckets).partitionIdExpression + } + + /** Expressions that given a partition key build a string like: col1=val/col2=val/... */ + private def partitionStringExpression: Seq[Expression] = { + description.partitionColumns.zipWithIndex.flatMap { case (c, i) => + val escaped = ScalaUDF( + PartitioningUtils.escapePathName _, + StringType, + Seq(Cast(c, StringType)), + Seq(StringType)) + val str = If(IsNull(c), Literal(PartitioningUtils.DEFAULT_PARTITION_NAME), escaped) + val partitionName = Literal(c.name + "=") :: str :: Nil + if (i == 0) partitionName else Literal(Path.SEPARATOR) :: partitionName + } + } + + /** + * Open and returns a new OutputWriter given a partition key and optional bucket id. + * If bucket id is specified, we will append it to the end of the file name, but before the + * file extension, e.g. part-r-00009-ea518ad4-455a-4431-b471-d24e03814677-00002.gz.parquet + */ + private def newOutputWriter(key: InternalRow, partString: UnsafeProjection): OutputWriter = { + val path = + if (description.partitionColumns.nonEmpty) { + val partitionPath = partString(key).getString(0) + new Path(stagingPath, partitionPath).toString + } else { + stagingPath + } + + // If the bucket spec is defined, the bucket column is right after the partition columns + val bucketId = if (description.bucketSpec.isDefined) { + Some(key.getInt(description.partitionColumns.length)) + } else { + None + } + + val split = taskAttemptContext.getTaskAttemptID.getTaskID.getId + val newWriter = description.outputWriterFactory.newInstance( + stagingDir = path, + fileNamePrefix = filePrefix(split, description.uuid, bucketId), + dataSchema = description.nonPartitionColumns.toStructType, + context = taskAttemptContext) + newWriter.initConverter(description.nonPartitionColumns.toStructType) + newWriter + } + + override def execute(iter: Iterator[InternalRow]): Unit = { + // We should first sort by partition columns, then bucket id, and finally sorting columns. + val sortingExpressions: Seq[Expression] = + description.partitionColumns ++ bucketIdExpression ++ sortColumns + val getSortingKey = UnsafeProjection.create(sortingExpressions, description.allColumns) + + val sortingKeySchema = StructType(sortingExpressions.map { + case a: Attribute => StructField(a.name, a.dataType, a.nullable) + // The sorting expressions are all `Attribute` except bucket id. 
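
partitionStringExpression above renders a row's partition key as a col=value path, escaping unsafe characters and substituting a default name for nulls. A plain-Scala approximation (the escaping and default name below are deliberately simplified compared to PartitioningUtils.escapePathName and DEFAULT_PARTITION_NAME):

object PartitionPathSketch {
  private val defaultPartitionName = "__HIVE_DEFAULT_PARTITION__"

  // Simplified escaping; Spark's PartitioningUtils.escapePathName covers more characters.
  private def escape(value: String): String =
    value.flatMap {
      case c if "/\\=:".contains(c) => "%%%02X".format(c.toInt)
      case c => c.toString
    }

  def partitionPath(columns: Seq[(String, Option[String])]): String =
    columns.map { case (name, value) =>
      s"$name=${value.map(escape).getOrElse(defaultPartitionName)}"
    }.mkString("/")

  def main(args: Array[String]): Unit = {
    println(partitionPath(Seq("year" -> Some("2016"), "event" -> Some("a=b"), "city" -> None)))
    // year=2016/event=a%3Db/city=__HIVE_DEFAULT_PARTITION__
  }
}
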
+ case _ => StructField("bucketId", IntegerType, nullable = false) + }) + + // Returns the data columns to be written given an input row + val getOutputRow = UnsafeProjection.create( + description.nonPartitionColumns, description.allColumns) + + // Returns the partition path given a partition key. + val getPartitionString = UnsafeProjection.create( + Seq(Concat(partitionStringExpression)), description.partitionColumns) + + // Sorts the data before write, so that we only need one writer at the same time. + val sorter = new UnsafeKVExternalSorter( + sortingKeySchema, + StructType.fromAttributes(description.nonPartitionColumns), + SparkEnv.get.blockManager, + SparkEnv.get.serializerManager, + TaskContext.get().taskMemoryManager().pageSizeBytes, + SparkEnv.get.conf.getLong("spark.shuffle.spill.numElementsForceSpillThreshold", + UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD)) + + while (iter.hasNext) { + val currentRow = iter.next() + sorter.insertKV(getSortingKey(currentRow), getOutputRow(currentRow)) + } + logInfo(s"Sorting complete. Writing out partition files one at a time.") + + val getBucketingKey: InternalRow => InternalRow = if (sortColumns.isEmpty) { + identity + } else { + UnsafeProjection.create(sortingExpressions.dropRight(sortColumns.length).zipWithIndex.map { + case (expr, ordinal) => BoundReference(ordinal, expr.dataType, expr.nullable) + }) + } + + val sortedIterator = sorter.sortedIterator() + + // If anything below fails, we should abort the task. + var currentKey: UnsafeRow = null + while (sortedIterator.next()) { + val nextKey = getBucketingKey(sortedIterator.getKey).asInstanceOf[UnsafeRow] + if (currentKey != nextKey) { + if (currentWriter != null) { + currentWriter.close() + currentWriter = null + } + currentKey = nextKey.copy() + logDebug(s"Writing partition: $currentKey") + + currentWriter = newOutputWriter(currentKey, getPartitionString) + } + currentWriter.writeInternal(sortedIterator.getValue) + } + if (currentWriter != null) { + currentWriter.close() + currentWriter = null + } + } + + override def releaseResources(): Unit = { + if (currentWriter != null) { + currentWriter.close() + currentWriter = null + } + } + } + + private def setupDriverCommitter(job: Job, path: String, isAppend: Boolean): OutputCommitter = { + // Setup IDs + val jobId = SparkHadoopWriter.createJobID(new Date, 0) + val taskId = new TaskID(jobId, TaskType.MAP, 0) + val taskAttemptId = new TaskAttemptID(taskId, 0) + + // Set up the configuration object + job.getConfiguration.set("mapred.job.id", jobId.toString) + job.getConfiguration.set("mapred.tip.id", taskAttemptId.getTaskID.toString) + job.getConfiguration.set("mapred.task.id", taskAttemptId.toString) + job.getConfiguration.setBoolean("mapred.task.is.map", true) + job.getConfiguration.setInt("mapred.task.partition", 0) + + val taskAttemptContext = new TaskAttemptContextImpl(job.getConfiguration, taskAttemptId) + val outputCommitter = newOutputCommitter( + job.getOutputFormatClass, taskAttemptContext, path, isAppend) + outputCommitter.setupJob(job) + outputCommitter + } + + private def newOutputCommitter( + outputFormatClass: Class[_ <: OutputFormat[_, _]], + context: TaskAttemptContext, + path: String, + isAppend: Boolean): OutputCommitter = { + val defaultOutputCommitter = outputFormatClass.newInstance().getOutputCommitter(context) + + if (isAppend) { + // If we are appending data to an existing dir, we will only use the output committer + // associated with the file output format since it is not safe to use a custom + // 
committer for appending. For example, in S3, direct parquet output committer may + // leave partial data in the destination dir when the appending job fails. + // See SPARK-8578 for more details + logInfo( + s"Using default output committer ${defaultOutputCommitter.getClass.getCanonicalName} " + + "for appending.") + defaultOutputCommitter + } else { + val configuration = context.getConfiguration + val clazz = + configuration.getClass(SQLConf.OUTPUT_COMMITTER_CLASS.key, null, classOf[OutputCommitter]) + + if (clazz != null) { + logInfo(s"Using user defined output committer class ${clazz.getCanonicalName}") + + // Every output format based on org.apache.hadoop.mapreduce.lib.output.OutputFormat + // has an associated output committer. To override this output committer, + // we will first try to use the output committer set in SQLConf.OUTPUT_COMMITTER_CLASS. + // If a data source needs to override the output committer, it needs to set the + // output committer in prepareForWrite method. + if (classOf[FileOutputCommitter].isAssignableFrom(clazz)) { + // The specified output committer is a FileOutputCommitter. + // So, we will use the FileOutputCommitter-specified constructor. + val ctor = clazz.getDeclaredConstructor(classOf[Path], classOf[TaskAttemptContext]) + ctor.newInstance(new Path(path), context) + } else { + // The specified output committer is just an OutputCommitter. + // So, we will use the no-argument constructor. + val ctor = clazz.getDeclaredConstructor() + ctor.newInstance() + } + } else { + // If output committer class is not set, we will use the one associated with the + // file output format. + logInfo( + s"Using output committer class ${defaultOutputCommitter.getClass.getCanonicalName}") + defaultOutputCommitter + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala deleted file mode 100644 index 7880c7cfa16f8..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala +++ /dev/null @@ -1,456 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
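
newOutputCommitter above instantiates a user-supplied committer class reflectively, choosing the (Path, TaskAttemptContext) constructor for FileOutputCommitter subclasses and the no-arg constructor otherwise. The same selection logic, sketched with toy classes so it runs without Hadoop on the classpath:

// Toy stand-ins for Hadoop's OutputCommitter and FileOutputCommitter.
class ToyCommitter
class ToyFileCommitter(val path: String, val context: String) extends ToyCommitter

object CommitterReflectionSketch {
  // File-committer-like classes get the (path, context) constructor; others the no-arg one.
  def instantiate(clazz: Class[_ <: ToyCommitter], path: String, context: String): ToyCommitter =
    if (classOf[ToyFileCommitter].isAssignableFrom(clazz)) {
      clazz.getDeclaredConstructor(classOf[String], classOf[String]).newInstance(path, context)
    } else {
      clazz.getDeclaredConstructor().newInstance()
    }

  def main(args: Array[String]): Unit = {
    println(instantiate(classOf[ToyFileCommitter], "/out", "ctx").getClass.getSimpleName)
    println(instantiate(classOf[ToyCommitter], "/out", "ctx").getClass.getSimpleName)
  }
}
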
- */ - -package org.apache.spark.sql.execution.datasources - -import java.util.{Date, UUID} - -import org.apache.hadoop.fs.Path -import org.apache.hadoop.mapreduce._ -import org.apache.hadoop.mapreduce.lib.output.{FileOutputCommitter => MapReduceFileOutputCommitter} -import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl - -import org.apache.spark._ -import org.apache.spark.internal.Logging -import org.apache.spark.mapred.SparkHadoopMapRedUtil -import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.catalyst.catalog.BucketSpec -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.execution.UnsafeKVExternalSorter -import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} -import org.apache.spark.util.{SerializableConfiguration, Utils} -import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter - - -/** A container for all the details required when writing to a table. */ -private[datasources] case class WriteRelation( - sparkSession: SparkSession, - dataSchema: StructType, - path: String, - prepareJobForWrite: Job => OutputWriterFactory, - bucketSpec: Option[BucketSpec]) - -object WriterContainer { - val DATASOURCE_WRITEJOBUUID = "spark.sql.sources.writeJobUUID" - val DATASOURCE_OUTPUTPATH = "spark.sql.sources.output.path" -} - -private[datasources] abstract class BaseWriterContainer( - @transient val relation: WriteRelation, - @transient private val job: Job, - isAppend: Boolean) - extends Logging with Serializable { - - protected val dataSchema = relation.dataSchema - - protected val serializableConf = - new SerializableConfiguration(job.getConfiguration) - - // This UUID is used to avoid output file name collision between different appending write jobs. - // These jobs may belong to different SparkContext instances. Concrete data source implementations - // may use this UUID to generate unique file names (e.g., `part-r--.parquet`). - // The reason why this ID is used to identify a job rather than a single task output file is - // that, speculative tasks must generate the same output file name as the original task. - private val uniqueWriteJobId = UUID.randomUUID() - - // This is only used on driver side. - @transient private val jobContext: JobContext = job - - private val speculationEnabled: Boolean = - relation.sparkSession.sparkContext.conf.getBoolean("spark.speculation", defaultValue = false) - - // The following fields are initialized and used on both driver and executor side. - @transient protected var outputCommitter: OutputCommitter = _ - @transient private var jobId: JobID = _ - @transient private var taskId: TaskID = _ - @transient private var taskAttemptId: TaskAttemptID = _ - @transient protected var taskAttemptContext: TaskAttemptContext = _ - - protected val outputPath: String = relation.path - - protected var outputWriterFactory: OutputWriterFactory = _ - - private var outputFormatClass: Class[_ <: OutputFormat[_, _]] = _ - - def writeRows(taskContext: TaskContext, iterator: Iterator[InternalRow]): Unit - - def driverSideSetup(): Unit = { - setupIDs(0, 0, 0) - setupConf() - - // This UUID is sent to executor side together with the serialized `Configuration` object within - // the `Job` instance. `OutputWriters` on the executor side should use this UUID to generate - // unique task output files. 
- job.getConfiguration.set(WriterContainer.DATASOURCE_WRITEJOBUUID, uniqueWriteJobId.toString) - - // Order of the following two lines is important. For Hadoop 1, TaskAttemptContext constructor - // clones the Configuration object passed in. If we initialize the TaskAttemptContext first, - // configurations made in prepareJobForWrite(job) are not populated into the TaskAttemptContext. - // - // Also, the `prepareJobForWrite` call must happen before initializing output format and output - // committer, since their initialization involve the job configuration, which can be potentially - // decorated in `prepareJobForWrite`. - outputWriterFactory = relation.prepareJobForWrite(job) - taskAttemptContext = new TaskAttemptContextImpl(serializableConf.value, taskAttemptId) - - outputFormatClass = job.getOutputFormatClass - outputCommitter = newOutputCommitter(taskAttemptContext) - outputCommitter.setupJob(jobContext) - } - - def executorSideSetup(taskContext: TaskContext): Unit = { - setupIDs(taskContext.stageId(), taskContext.partitionId(), taskContext.attemptNumber()) - setupConf() - taskAttemptContext = new TaskAttemptContextImpl(serializableConf.value, taskAttemptId) - outputCommitter = newOutputCommitter(taskAttemptContext) - outputCommitter.setupTask(taskAttemptContext) - } - - protected def getWorkPath: String = { - outputCommitter match { - // FileOutputCommitter writes to a temporary location returned by `getWorkPath`. - case f: MapReduceFileOutputCommitter => f.getWorkPath.toString - case _ => outputPath - } - } - - protected def newOutputWriter(path: String, bucketId: Option[Int] = None): OutputWriter = { - try { - outputWriterFactory.newInstance(path, bucketId, dataSchema, taskAttemptContext) - } catch { - case e: org.apache.hadoop.fs.FileAlreadyExistsException => - if (outputCommitter.getClass.getName.contains("Direct")) { - // SPARK-11382: DirectParquetOutputCommitter is not idempotent, meaning on retry - // attempts, the task will fail because the output file is created from a prior attempt. - // This often means the most visible error to the user is misleading. Augment the error - // to tell the user to look for the actual error. - throw new SparkException("The output file already exists but this could be due to a " + - "failure from an earlier attempt. Look through the earlier logs or stage page for " + - "the first error.\n File exists error: " + e, e) - } else { - throw e - } - } - } - - private def newOutputCommitter(context: TaskAttemptContext): OutputCommitter = { - val defaultOutputCommitter = outputFormatClass.newInstance().getOutputCommitter(context) - - if (isAppend) { - // If we are appending data to an existing dir, we will only use the output committer - // associated with the file output format since it is not safe to use a custom - // committer for appending. For example, in S3, direct parquet output committer may - // leave partial data in the destination dir when the appending job fails. 
- // - // See SPARK-8578 for more details - logInfo( - s"Using default output committer ${defaultOutputCommitter.getClass.getCanonicalName} " + - "for appending.") - defaultOutputCommitter - } else { - val configuration = context.getConfiguration - val committerClass = configuration.getClass( - SQLConf.OUTPUT_COMMITTER_CLASS.key, null, classOf[OutputCommitter]) - - Option(committerClass).map { clazz => - logInfo(s"Using user defined output committer class ${clazz.getCanonicalName}") - - // Every output format based on org.apache.hadoop.mapreduce.lib.output.OutputFormat - // has an associated output committer. To override this output committer, - // we will first try to use the output committer set in SQLConf.OUTPUT_COMMITTER_CLASS. - // If a data source needs to override the output committer, it needs to set the - // output committer in prepareForWrite method. - if (classOf[MapReduceFileOutputCommitter].isAssignableFrom(clazz)) { - // The specified output committer is a FileOutputCommitter. - // So, we will use the FileOutputCommitter-specified constructor. - val ctor = clazz.getDeclaredConstructor(classOf[Path], classOf[TaskAttemptContext]) - ctor.newInstance(new Path(outputPath), context) - } else { - // The specified output committer is just an OutputCommitter. - // So, we will use the no-argument constructor. - val ctor = clazz.getDeclaredConstructor() - ctor.newInstance() - } - }.getOrElse { - // If output committer class is not set, we will use the one associated with the - // file output format. - logInfo( - s"Using output committer class ${defaultOutputCommitter.getClass.getCanonicalName}") - defaultOutputCommitter - } - } - } - - private def setupIDs(jobId: Int, splitId: Int, attemptId: Int): Unit = { - this.jobId = SparkHadoopWriter.createJobID(new Date, jobId) - this.taskId = new TaskID(this.jobId, TaskType.MAP, splitId) - this.taskAttemptId = new TaskAttemptID(taskId, attemptId) - } - - private def setupConf(): Unit = { - serializableConf.value.set("mapred.job.id", jobId.toString) - serializableConf.value.set("mapred.tip.id", taskAttemptId.getTaskID.toString) - serializableConf.value.set("mapred.task.id", taskAttemptId.toString) - serializableConf.value.setBoolean("mapred.task.is.map", true) - serializableConf.value.setInt("mapred.task.partition", 0) - } - - def commitTask(): Unit = { - SparkHadoopMapRedUtil.commitTask(outputCommitter, taskAttemptContext, jobId.getId, taskId.getId) - } - - def abortTask(): Unit = { - if (outputCommitter != null) { - outputCommitter.abortTask(taskAttemptContext) - } - logError(s"Task attempt $taskAttemptId aborted.") - } - - def commitJob(): Unit = { - outputCommitter.commitJob(jobContext) - logInfo(s"Job $jobId committed.") - } - - def abortJob(): Unit = { - if (outputCommitter != null) { - outputCommitter.abortJob(jobContext, JobStatus.State.FAILED) - } - logError(s"Job $jobId aborted.") - } -} - -/** - * A writer that writes all of the rows in a partition to a single file. - */ -private[datasources] class DefaultWriterContainer( - relation: WriteRelation, - job: Job, - isAppend: Boolean) - extends BaseWriterContainer(relation, job, isAppend) { - - def writeRows(taskContext: TaskContext, iterator: Iterator[InternalRow]): Unit = { - executorSideSetup(taskContext) - val configuration = taskAttemptContext.getConfiguration - configuration.set(WriterContainer.DATASOURCE_OUTPUTPATH, outputPath) - var writer = newOutputWriter(getWorkPath) - writer.initConverter(dataSchema) - - // If anything below fails, we should abort the task. 
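
The write path that follows wraps its work in Utils.tryWithSafeFinallyAndFailureCallbacks so that a failure runs the abort path without masking the original error. A simplified standalone version of that pattern (Spark's helper additionally registers task failure callbacks, which is omitted here):

object SafeFailureCallbackSketch {
  def tryWithFailureCallback[T](block: => T)(catchBlock: => Unit): T =
    try {
      block
    } catch {
      case original: Throwable =>
        try {
          catchBlock                 // e.g. close the writer and abort the task
        } catch {
          case suppressed: Throwable => original.addSuppressed(suppressed)
        }
        throw original
    }
}
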
- try { - Utils.tryWithSafeFinallyAndFailureCallbacks { - while (iterator.hasNext) { - val internalRow = iterator.next() - writer.writeInternal(internalRow) - } - commitTask() - }(catchBlock = abortTask()) - } catch { - case t: Throwable => - throw new SparkException("Task failed while writing rows", t) - } - - def commitTask(): Unit = { - try { - if (writer != null) { - writer.close() - writer = null - } - super.commitTask() - } catch { - case cause: Throwable => - // This exception will be handled in `InsertIntoHadoopFsRelation.insert$writeRows`, and - // will cause `abortTask()` to be invoked. - throw new RuntimeException("Failed to commit task", cause) - } - } - - def abortTask(): Unit = { - try { - if (writer != null) { - writer.close() - } - } finally { - super.abortTask() - } - } - } -} - -/** - * A writer that dynamically opens files based on the given partition columns. Internally this is - * done by maintaining a HashMap of open files until `maxFiles` is reached. If this occurs, the - * writer externally sorts the remaining rows and then writes out them out one file at a time. - */ -private[datasources] class DynamicPartitionWriterContainer( - relation: WriteRelation, - job: Job, - partitionColumns: Seq[Attribute], - dataColumns: Seq[Attribute], - inputSchema: Seq[Attribute], - defaultPartitionName: String, - maxOpenFiles: Int, - isAppend: Boolean) - extends BaseWriterContainer(relation, job, isAppend) { - - private val bucketSpec = relation.bucketSpec - - private val bucketColumns: Seq[Attribute] = bucketSpec.toSeq.flatMap { - spec => spec.bucketColumnNames.map(c => inputSchema.find(_.name == c).get) - } - - private val sortColumns: Seq[Attribute] = bucketSpec.toSeq.flatMap { - spec => spec.sortColumnNames.map(c => inputSchema.find(_.name == c).get) - } - - private def bucketIdExpression: Option[Expression] = bucketSpec.map { spec => - // Use `HashPartitioning.partitionIdExpression` as our bucket id expression, so that we can - // guarantee the data distribution is same between shuffle and bucketed data source, which - // enables us to only shuffle one side when join a bucketed table and a normal one. - HashPartitioning(bucketColumns, spec.numBuckets).partitionIdExpression - } - - // Expressions that given a partition key build a string like: col1=val/col2=val/... - private def partitionStringExpression: Seq[Expression] = { - partitionColumns.zipWithIndex.flatMap { case (c, i) => - val escaped = - ScalaUDF( - PartitioningUtils.escapePathName _, - StringType, - Seq(Cast(c, StringType)), - Seq(StringType)) - val str = If(IsNull(c), Literal(defaultPartitionName), escaped) - val partitionName = Literal(c.name + "=") :: str :: Nil - if (i == 0) partitionName else Literal(Path.SEPARATOR) :: partitionName - } - } - - private def getBucketIdFromKey(key: InternalRow): Option[Int] = bucketSpec.map { _ => - key.getInt(partitionColumns.length) - } - - /** - * Open and returns a new OutputWriter given a partition key and optional bucket id. - * If bucket id is specified, we will append it to the end of the file name, but before the - * file extension, e.g. 
part-r-00009-ea518ad4-455a-4431-b471-d24e03814677-00002.gz.parquet - */ - private def newOutputWriter( - key: InternalRow, - getPartitionString: UnsafeProjection): OutputWriter = { - val configuration = taskAttemptContext.getConfiguration - val path = if (partitionColumns.nonEmpty) { - val partitionPath = getPartitionString(key).getString(0) - configuration.set( - WriterContainer.DATASOURCE_OUTPUTPATH, - new Path(outputPath, partitionPath).toString) - new Path(getWorkPath, partitionPath).toString - } else { - configuration.set(WriterContainer.DATASOURCE_OUTPUTPATH, outputPath) - getWorkPath - } - val bucketId = getBucketIdFromKey(key) - val newWriter = super.newOutputWriter(path, bucketId) - newWriter.initConverter(dataSchema) - newWriter - } - - def writeRows(taskContext: TaskContext, iterator: Iterator[InternalRow]): Unit = { - executorSideSetup(taskContext) - - // We should first sort by partition columns, then bucket id, and finally sorting columns. - val sortingExpressions: Seq[Expression] = partitionColumns ++ bucketIdExpression ++ sortColumns - val getSortingKey = UnsafeProjection.create(sortingExpressions, inputSchema) - - val sortingKeySchema = StructType(sortingExpressions.map { - case a: Attribute => StructField(a.name, a.dataType, a.nullable) - // The sorting expressions are all `Attribute` except bucket id. - case _ => StructField("bucketId", IntegerType, nullable = false) - }) - - // Returns the data columns to be written given an input row - val getOutputRow = UnsafeProjection.create(dataColumns, inputSchema) - - // Returns the partition path given a partition key. - val getPartitionString = - UnsafeProjection.create(Concat(partitionStringExpression) :: Nil, partitionColumns) - - // Sorts the data before write, so that we only need one writer at the same time. - // TODO: inject a local sort operator in planning. - val sorter = new UnsafeKVExternalSorter( - sortingKeySchema, - StructType.fromAttributes(dataColumns), - SparkEnv.get.blockManager, - SparkEnv.get.serializerManager, - TaskContext.get().taskMemoryManager().pageSizeBytes, - SparkEnv.get.conf.getLong("spark.shuffle.spill.numElementsForceSpillThreshold", - UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD)) - - while (iterator.hasNext) { - val currentRow = iterator.next() - sorter.insertKV(getSortingKey(currentRow), getOutputRow(currentRow)) - } - logInfo(s"Sorting complete. Writing out partition files one at a time.") - - val getBucketingKey: InternalRow => InternalRow = if (sortColumns.isEmpty) { - identity - } else { - UnsafeProjection.create(sortingExpressions.dropRight(sortColumns.length).zipWithIndex.map { - case (expr, ordinal) => BoundReference(ordinal, expr.dataType, expr.nullable) - }) - } - - val sortedIterator = sorter.sortedIterator() - - // If anything below fails, we should abort the task. 
- var currentWriter: OutputWriter = null - try { - Utils.tryWithSafeFinallyAndFailureCallbacks { - var currentKey: UnsafeRow = null - while (sortedIterator.next()) { - val nextKey = getBucketingKey(sortedIterator.getKey).asInstanceOf[UnsafeRow] - if (currentKey != nextKey) { - if (currentWriter != null) { - currentWriter.close() - currentWriter = null - } - currentKey = nextKey.copy() - logDebug(s"Writing partition: $currentKey") - - currentWriter = newOutputWriter(currentKey, getPartitionString) - } - currentWriter.writeInternal(sortedIterator.getValue) - } - if (currentWriter != null) { - currentWriter.close() - currentWriter = null - } - - commitTask() - }(catchBlock = { - if (currentWriter != null) { - currentWriter.close() - } - abortTask() - }) - } catch { - case t: Throwable => - throw new SparkException("Task failed while writing rows", t) - } - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala index 55cb26d6513af..a35cfdb2c234f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala @@ -31,7 +31,8 @@ import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.GenericInternalRow import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.execution.datasources.{OutputWriter, OutputWriterFactory, PartitionedFile, WriterContainer} +import org.apache.spark.sql.execution.datasources.{OutputWriter, OutputWriterFactory, PartitionedFile} +import org.apache.spark.sql.execution.datasources.text.TextOutputWriter import org.apache.spark.sql.types._ object CSVRelation extends Logging { @@ -170,21 +171,26 @@ object CSVRelation extends Logging { private[csv] class CSVOutputWriterFactory(params: CSVOptions) extends OutputWriterFactory { override def newInstance( - path: String, - bucketId: Option[Int], + stagingDir: String, + fileNamePrefix: String, dataSchema: StructType, context: TaskAttemptContext): OutputWriter = { - if (bucketId.isDefined) sys.error("csv doesn't support bucketing") - new CsvOutputWriter(path, dataSchema, context, params) + new CsvOutputWriter(stagingDir, fileNamePrefix, dataSchema, context, params) } } private[csv] class CsvOutputWriter( - path: String, + stagingDir: String, + fileNamePrefix: String, dataSchema: StructType, context: TaskAttemptContext, params: CSVOptions) extends OutputWriter with Logging { + override val path: String = { + val compressionExtension = TextOutputWriter.getCompressionExtension(context) + new Path(stagingDir, fileNamePrefix + ".csv" + compressionExtension).toString + } + // create the Generator without separator inserted between 2 records private[this] val text = new Text() @@ -199,11 +205,7 @@ private[csv] class CsvOutputWriter( private val recordWriter: RecordWriter[NullWritable, Text] = { new TextOutputFormat[NullWritable, Text]() { override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = { - val configuration = context.getConfiguration - val uniqueWriteJobId = configuration.get(WriterContainer.DATASOURCE_WRITEJOBUUID) - val taskAttemptId = context.getTaskAttemptID - val split = taskAttemptId.getTaskID.getId - new Path(path, f"part-r-$split%05d-$uniqueWriteJobId.csv$extension") + new Path(path) } }.getRecordWriter(context) } diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/fileSourceInterfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/fileSourceInterfaces.scala deleted file mode 100644 index a8f693de7b1a2..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/fileSourceInterfaces.scala +++ /dev/null @@ -1,518 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.datasources - -import scala.collection.mutable - -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs._ -import org.apache.hadoop.io.compress.{CompressionCodecFactory, SplittableCompressionCodec} -import org.apache.hadoop.mapred.{FileInputFormat, JobConf} -import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext} - -import org.apache.spark.annotation.Experimental -import org.apache.spark.internal.Logging -import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} -import org.apache.spark.sql.catalyst.catalog.BucketSpec -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection -import org.apache.spark.sql.execution.FileRelation -import org.apache.spark.sql.sources.{BaseRelation, DataSourceRegister, Filter} -import org.apache.spark.sql.types.StructType -import org.apache.spark.util.SerializableConfiguration - -/** - * ::Experimental:: - * A factory that produces [[OutputWriter]]s. A new [[OutputWriterFactory]] is created on driver - * side for each write job issued when writing to a [[HadoopFsRelation]], and then gets serialized - * to executor side to create actual [[OutputWriter]]s on the fly. - * - * @since 1.4.0 - */ -@Experimental -abstract class OutputWriterFactory extends Serializable { - /** - * When writing to a [[HadoopFsRelation]], this method gets called by each task on executor side - * to instantiate new [[OutputWriter]]s. - * - * @param path Path of the file to which this [[OutputWriter]] is supposed to write. Note that - * this may not point to the final output file. For example, `FileOutputFormat` writes to - * temporary directories and then merge written files back to the final destination. In - * this case, `path` points to a temporary output file under the temporary directory. - * @param dataSchema Schema of the rows to be written. Partition columns are not included in the - * schema if the relation being written is partitioned. - * @param context The Hadoop MapReduce task context. - * @since 1.4.0 - */ - def newInstance( - path: String, - bucketId: Option[Int], // TODO: This doesn't belong here... 
- dataSchema: StructType, - context: TaskAttemptContext): OutputWriter - - /** - * Returns a new instance of [[OutputWriter]] that will write data to the given path. - * This method gets called by each task on executor to write [[InternalRow]]s to - * format-specific files. Compared to the other `newInstance()`, this is a newer API that - * passes only the path that the writer must write to. The writer must write to the exact path - * and not modify it (do not add subdirectories, extensions, etc.). All other - * file-format-specific information needed to create the writer must be passed - * through the [[OutputWriterFactory]] implementation. - * @since 2.0.0 - */ - def newWriter(path: String): OutputWriter = { - throw new UnsupportedOperationException("newInstance with just path not supported") - } -} - -/** - * ::Experimental:: - * [[OutputWriter]] is used together with [[HadoopFsRelation]] for persisting rows to the - * underlying file system. Subclasses of [[OutputWriter]] must provide a zero-argument constructor. - * An [[OutputWriter]] instance is created and initialized when a new output file is opened on - * executor side. This instance is used to persist rows to this single output file. - * - * @since 1.4.0 - */ -@Experimental -abstract class OutputWriter { - /** - * Persists a single row. Invoked on the executor side. When writing to dynamically partitioned - * tables, dynamic partition columns are not included in rows to be written. - * - * @since 1.4.0 - */ - def write(row: Row): Unit - - /** - * Closes the [[OutputWriter]]. Invoked on the executor side after all rows are persisted, before - * the task output is committed. - * - * @since 1.4.0 - */ - def close(): Unit - - private var converter: InternalRow => Row = _ - - protected[sql] def initConverter(dataSchema: StructType) = { - converter = - CatalystTypeConverters.createToScalaConverter(dataSchema).asInstanceOf[InternalRow => Row] - } - - protected[sql] def writeInternal(row: InternalRow): Unit = { - write(converter(row)) - } -} - -/** - * Acts as a container for all of the metadata required to read from a datasource. All discovery, - * resolution and merging logic for schemas and partitions has been removed. - * - * @param location A [[FileCatalog]] that can enumerate the locations of all the files that comprise - * this relation. - * @param partitionSchema The schema of the columns (if any) that are used to partition the relation - * @param dataSchema The schema of any remaining columns. Note that if any partition columns are - * present in the actual data files as well, they are preserved. - * @param bucketSpec Describes the bucketing (hash-partitioning of the files by some column values). - * @param fileFormat A file format that can be used to read and write the data in files. - * @param options Configuration used when reading / writing data. 
- */ -case class HadoopFsRelation( - location: FileCatalog, - partitionSchema: StructType, - dataSchema: StructType, - bucketSpec: Option[BucketSpec], - fileFormat: FileFormat, - options: Map[String, String])(val sparkSession: SparkSession) - extends BaseRelation with FileRelation { - - override def sqlContext: SQLContext = sparkSession.sqlContext - - val schema: StructType = { - val dataSchemaColumnNames = dataSchema.map(_.name.toLowerCase).toSet - StructType(dataSchema ++ partitionSchema.filterNot { column => - dataSchemaColumnNames.contains(column.name.toLowerCase) - }) - } - - def partitionSchemaOption: Option[StructType] = - if (partitionSchema.isEmpty) None else Some(partitionSchema) - def partitionSpec: PartitionSpec = location.partitionSpec() - - def refresh(): Unit = location.refresh() - - override def toString: String = { - fileFormat match { - case source: DataSourceRegister => source.shortName() - case _ => "HadoopFiles" - } - } - - /** Returns the list of files that will be read when scanning this relation. */ - override def inputFiles: Array[String] = - location.allFiles().map(_.getPath.toUri.toString).toArray - - override def sizeInBytes: Long = location.allFiles().map(_.getLen).sum -} - -/** - * Used to read and write data stored in files to/from the [[InternalRow]] format. - */ -trait FileFormat { - /** - * When possible, this method should return the schema of the given `files`. When the format - * does not support inference, or no valid files are given should return None. In these cases - * Spark will require that user specify the schema manually. - */ - def inferSchema( - sparkSession: SparkSession, - options: Map[String, String], - files: Seq[FileStatus]): Option[StructType] - - /** - * Prepares a write job and returns an [[OutputWriterFactory]]. Client side job preparation can - * be put here. For example, user defined output committer can be configured here - * by setting the output committer class in the conf of spark.sql.sources.outputCommitterClass. - */ - def prepareWrite( - sparkSession: SparkSession, - job: Job, - options: Map[String, String], - dataSchema: StructType): OutputWriterFactory - - /** - * Returns a [[OutputWriterFactory]] for generating output writers that can write data. - * This method is current used only by FileStreamSinkWriter to generate output writers that - * does not use output committers to write data. The OutputWriter generated by the returned - * [[OutputWriterFactory]] must implement the method `newWriter(path)`.. - */ - def buildWriter( - sqlContext: SQLContext, - dataSchema: StructType, - options: Map[String, String]): OutputWriterFactory = { - // TODO: Remove this default implementation when the other formats have been ported - throw new UnsupportedOperationException(s"buildWriter is not supported for $this") - } - - /** - * Returns whether this format support returning columnar batch or not. - * - * TODO: we should just have different traits for the different formats. - */ - def supportBatch(sparkSession: SparkSession, dataSchema: StructType): Boolean = { - false - } - - /** - * Allow FileFormats to have a pluggable way to utilize pushed filters to eliminate partitions - * before execution. By default no pruning is performed and the original partitioning is - * preserved. 
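
HadoopFsRelation.schema above appends partition columns to the data schema unless a data column with the same name (compared case-insensitively) already exists. The same merge, sketched with plain (name, type) pairs instead of StructType fields:

object SchemaMergeSketch {
  type Field = (String, String)   // (column name, type name)

  def mergeSchemas(dataSchema: Seq[Field], partitionSchema: Seq[Field]): Seq[Field] = {
    val dataColumnNames = dataSchema.map(_._1.toLowerCase).toSet
    dataSchema ++ partitionSchema.filterNot {
      case (name, _) => dataColumnNames.contains(name.toLowerCase)
    }
  }

  def main(args: Array[String]): Unit = {
    val merged = mergeSchemas(
      Seq("id" -> "int", "Date" -> "string"),
      Seq("date" -> "string", "country" -> "string"))
    println(merged)   // List((id,int), (Date,string), (country,string))
  }
}
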
- */ - def filterPartitions( - filters: Seq[Filter], - schema: StructType, - conf: Configuration, - allFiles: Seq[FileStatus], - root: Path, - partitions: Seq[Partition]): Seq[Partition] = { - partitions - } - - /** - * Returns whether a file with `path` could be splitted or not. - */ - def isSplitable( - sparkSession: SparkSession, - options: Map[String, String], - path: Path): Boolean = { - false - } - - /** - * Returns a function that can be used to read a single file in as an Iterator of InternalRow. - * - * @param dataSchema The global data schema. It can be either specified by the user, or - * reconciled/merged from all underlying data files. If any partition columns - * are contained in the files, they are preserved in this schema. - * @param partitionSchema The schema of the partition column row that will be present in each - * PartitionedFile. These columns should be appended to the rows that - * are produced by the iterator. - * @param requiredSchema The schema of the data that should be output for each row. This may be a - * subset of the columns that are present in the file if column pruning has - * occurred. - * @param filters A set of filters than can optionally be used to reduce the number of rows output - * @param options A set of string -> string configuration options. - * @return - */ - def buildReader( - sparkSession: SparkSession, - dataSchema: StructType, - partitionSchema: StructType, - requiredSchema: StructType, - filters: Seq[Filter], - options: Map[String, String], - hadoopConf: Configuration): PartitionedFile => Iterator[InternalRow] = { - // TODO: Remove this default implementation when the other formats have been ported - // Until then we guard in [[FileSourceStrategy]] to only call this method on supported formats. - throw new UnsupportedOperationException(s"buildReader is not supported for $this") - } - - /** - * Exactly the same as [[buildReader]] except that the reader function returned by this method - * appends partition values to [[InternalRow]]s produced by the reader function [[buildReader]] - * returns. - */ - def buildReaderWithPartitionValues( - sparkSession: SparkSession, - dataSchema: StructType, - partitionSchema: StructType, - requiredSchema: StructType, - filters: Seq[Filter], - options: Map[String, String], - hadoopConf: Configuration): PartitionedFile => Iterator[InternalRow] = { - val dataReader = buildReader( - sparkSession, dataSchema, partitionSchema, requiredSchema, filters, options, hadoopConf) - - new (PartitionedFile => Iterator[InternalRow]) with Serializable { - private val fullSchema = requiredSchema.toAttributes ++ partitionSchema.toAttributes - - private val joinedRow = new JoinedRow() - - // Using lazy val to avoid serialization - private lazy val appendPartitionColumns = - GenerateUnsafeProjection.generate(fullSchema, fullSchema) - - override def apply(file: PartitionedFile): Iterator[InternalRow] = { - // Using local val to avoid per-row lazy val check (pre-mature optimization?...) - val converter = appendPartitionColumns - - // Note that we have to apply the converter even though `file.partitionValues` is empty. - // This is because the converter is also responsible for converting safe `InternalRow`s into - // `UnsafeRow`s. - dataReader(file).map { dataRow => - converter(joinedRow(dataRow, file.partitionValues)) - } - } - } - } - -} - -/** - * The base class file format that is based on text file. 
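
buildReaderWithPartitionValues above differs from buildReader only in that every row coming out of the file reader is extended with that file's partition values. A plain-Scala illustration, with simple sequences standing in for InternalRow/JoinedRow:

object AppendPartitionValuesSketch {
  type Row = Seq[Any]

  // Every data row read from a file gets that file's partition values appended to it.
  def withPartitionValues(dataRows: Iterator[Row], partitionValues: Row): Iterator[Row] =
    dataRows.map(dataRow => dataRow ++ partitionValues)

  def main(args: Array[String]): Unit = {
    val fromFile = Iterator[Row](Seq("alice", 10), Seq("bob", 20))
    withPartitionValues(fromFile, Seq("2016-10-14")).foreach(println)
    // List(alice, 10, 2016-10-14)
    // List(bob, 20, 2016-10-14)
  }
}
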
- */ -abstract class TextBasedFileFormat extends FileFormat { - private var codecFactory: CompressionCodecFactory = null - override def isSplitable( - sparkSession: SparkSession, - options: Map[String, String], - path: Path): Boolean = { - if (codecFactory == null) { - codecFactory = new CompressionCodecFactory( - sparkSession.sessionState.newHadoopConfWithOptions(options)) - } - val codec = codecFactory.getCodec(path) - codec == null || codec.isInstanceOf[SplittableCompressionCodec] - } -} - -/** - * A collection of data files from a partitioned relation, along with the partition values in the - * form of an [[InternalRow]]. - */ -case class Partition(values: InternalRow, files: Seq[FileStatus]) - -/** - * An interface for objects capable of enumerating the files that comprise a relation as well - * as the partitioning characteristics of those files. - */ -trait FileCatalog { - - /** Returns the list of input paths from which the catalog will get files. */ - def paths: Seq[Path] - - /** Returns the specification of the partitions inferred from the data. */ - def partitionSpec(): PartitionSpec - - /** - * Returns all valid files grouped into partitions when the data is partitioned. If the data is - * unpartitioned, this will return a single partition with no partition values. - * - * @param filters The filters used to prune which partitions are returned. These filters must - * only refer to partition columns and this method will only return files - * where these predicates are guaranteed to evaluate to `true`. Thus, these - * filters will not need to be evaluated again on the returned data. - */ - def listFiles(filters: Seq[Expression]): Seq[Partition] - - /** Returns all the valid files. */ - def allFiles(): Seq[FileStatus] - - /** Refresh the file listing */ - def refresh(): Unit -} - - -/** - * Helper methods for gathering metadata from HDFS. - */ -object HadoopFsRelation extends Logging { - - /** Checks if we should filter out this path name. */ - def shouldFilterOut(pathName: String): Boolean = { - // We filter everything that starts with _ and ., except _common_metadata and _metadata - // because Parquet needs to find those metadata files from leaf files returned by this method. - // We should refactor this logic to not mix metadata files with data files. - ((pathName.startsWith("_") && !pathName.contains("=")) || pathName.startsWith(".")) && - !pathName.startsWith("_common_metadata") && !pathName.startsWith("_metadata") - } - - /** - * Create a LocatedFileStatus using FileStatus and block locations. - */ - def createLocatedFileStatus(f: FileStatus, locations: Array[BlockLocation]): LocatedFileStatus = { - // The other constructor of LocatedFileStatus will call FileStatus.getPermission(), which is - // very slow on some file system (RawLocalFileSystem, which is launch a subprocess and parse the - // stdout). - val lfs = new LocatedFileStatus(f.getLen, f.isDirectory, f.getReplication, f.getBlockSize, - f.getModificationTime, 0, null, null, null, null, f.getPath, locations) - if (f.isSymlink) { - lfs.setSymlink(f.getSymlink) - } - lfs - } - - // We don't filter files/directories whose name start with "_" except "_temporary" here, as - // specific data sources may take advantages over them (e.g. Parquet _metadata and - // _common_metadata files). "_temporary" directories are explicitly ignored since failed - // tasks/jobs may leave partial/corrupted data files there. Files and directories whose name - // start with "." are also ignored. 
- def listLeafFiles(fs: FileSystem, status: FileStatus, filter: PathFilter): Array[FileStatus] = { - logTrace(s"Listing ${status.getPath}") - val name = status.getPath.getName.toLowerCase - if (shouldFilterOut(name)) { - Array.empty[FileStatus] - } else { - val statuses = { - val (dirs, files) = fs.listStatus(status.getPath).partition(_.isDirectory) - val stats = files ++ dirs.flatMap(dir => listLeafFiles(fs, dir, filter)) - if (filter != null) stats.filter(f => filter.accept(f.getPath)) else stats - } - // statuses do not have any dirs. - statuses.filterNot(status => shouldFilterOut(status.getPath.getName)).map { - case f: LocatedFileStatus => f - - // NOTE: - // - // - Although S3/S3A/S3N file system can be quite slow for remote file metadata - // operations, calling `getFileBlockLocations` does no harm here since these file system - // implementations don't actually issue RPC for this method. - // - // - Here we are calling `getFileBlockLocations` in a sequential manner, but it should not - // be a big deal since we always use to `listLeafFilesInParallel` when the number of - // paths exceeds threshold. - case f => createLocatedFileStatus(f, fs.getFileBlockLocations(f, 0, f.getLen)) - } - } - } - - // `FileStatus` is Writable but not serializable. What make it worse, somehow it doesn't play - // well with `SerializableWritable`. So there seems to be no way to serialize a `FileStatus`. - // Here we use `FakeFileStatus` to extract key components of a `FileStatus` to serialize it from - // executor side and reconstruct it on driver side. - case class FakeBlockLocation( - names: Array[String], - hosts: Array[String], - offset: Long, - length: Long) - - case class FakeFileStatus( - path: String, - length: Long, - isDir: Boolean, - blockReplication: Short, - blockSize: Long, - modificationTime: Long, - accessTime: Long, - blockLocations: Array[FakeBlockLocation]) - - def listLeafFilesInParallel( - paths: Seq[Path], - hadoopConf: Configuration, - sparkSession: SparkSession): mutable.LinkedHashSet[FileStatus] = { - assert(paths.size >= sparkSession.sessionState.conf.parallelPartitionDiscoveryThreshold) - logInfo(s"Listing leaf files and directories in parallel under: ${paths.mkString(", ")}") - - val sparkContext = sparkSession.sparkContext - val serializableConfiguration = new SerializableConfiguration(hadoopConf) - val serializedPaths = paths.map(_.toString) - - // Set the number of parallelism to prevent following file listing from generating many tasks - // in case of large #defaultParallelism. 
- val numParallelism = Math.min(paths.size, 10000) - - val fakeStatuses = sparkContext - .parallelize(serializedPaths, numParallelism) - .mapPartitions { paths => - // Dummy jobconf to get to the pathFilter defined in configuration - // It's very expensive to create a JobConf(ClassUtil.findContainingJar() is slow) - val jobConf = new JobConf(serializableConfiguration.value, this.getClass) - val pathFilter = FileInputFormat.getInputPathFilter(jobConf) - paths.map(new Path(_)).flatMap { path => - val fs = path.getFileSystem(serializableConfiguration.value) - listLeafFiles(fs, fs.getFileStatus(path), pathFilter) - } - }.map { status => - val blockLocations = status match { - case f: LocatedFileStatus => - f.getBlockLocations.map { loc => - FakeBlockLocation( - loc.getNames, - loc.getHosts, - loc.getOffset, - loc.getLength) - } - - case _ => - Array.empty[FakeBlockLocation] - } - - FakeFileStatus( - status.getPath.toString, - status.getLen, - status.isDirectory, - status.getReplication, - status.getBlockSize, - status.getModificationTime, - status.getAccessTime, - blockLocations) - }.collect() - - val hadoopFakeStatuses = fakeStatuses.map { f => - val blockLocations = f.blockLocations.map { loc => - new BlockLocation(loc.names, loc.hosts, loc.offset, loc.length) - } - new LocatedFileStatus( - new FileStatus( - f.length, f.isDir, f.blockReplication, f.blockSize, f.modificationTime, new Path(f.path)), - blockLocations) - } - mutable.LinkedHashSet(hadoopFakeStatuses: _*) - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index e32db73bd6c6a..41edb6511c2ce 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -607,7 +607,7 @@ object JdbcUtils extends Logging { } catch { case e: SQLException => val cause = e.getNextException - if (e.getCause != cause) { + if (cause != null && e.getCause != cause) { if (e.getCause == null) { e.initCause(cause) } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala index 9fe38ccc9fdc6..651fa78a4e924 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala @@ -35,6 +35,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.json.{JacksonParser, JSONOptions} import org.apache.spark.sql.catalyst.util.CompressionCodecs import org.apache.spark.sql.execution.datasources._ +import org.apache.spark.sql.execution.datasources.text.TextOutputWriter import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.StructType import org.apache.spark.util.SerializableConfiguration @@ -82,11 +83,11 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister { new OutputWriterFactory { override def newInstance( - path: String, - bucketId: Option[Int], + stagingDir: String, + fileNamePrefix: String, dataSchema: StructType, context: TaskAttemptContext): OutputWriter = { - new JsonOutputWriter(path, parsedOptions, bucketId, dataSchema, context) + new JsonOutputWriter(stagingDir, parsedOptions, fileNamePrefix, dataSchema, context) } } } @@ -153,13 
+154,18 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister { } private[json] class JsonOutputWriter( - path: String, + stagingDir: String, options: JSONOptions, - bucketId: Option[Int], + fileNamePrefix: String, dataSchema: StructType, context: TaskAttemptContext) extends OutputWriter with Logging { + override val path: String = { + val compressionExtension = TextOutputWriter.getCompressionExtension(context) + new Path(stagingDir, fileNamePrefix + ".json" + compressionExtension).toString + } + private[this] val writer = new CharArrayWriter() // create the Generator without separator inserted between 2 records private[this] val gen = new JacksonGenerator(dataSchema, writer, options) @@ -168,12 +174,7 @@ private[json] class JsonOutputWriter( private val recordWriter: RecordWriter[NullWritable, Text] = { new TextOutputFormat[NullWritable, Text]() { override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = { - val configuration = context.getConfiguration - val uniqueWriteJobId = configuration.get(WriterContainer.DATASOURCE_WRITEJOBUUID) - val taskAttemptId = context.getTaskAttemptID - val split = taskAttemptId.getTaskID.getId - val bucketString = bucketId.map(BucketingUtils.bucketIdToString).getOrElse("") - new Path(path, f"part-r-$split%05d-$uniqueWriteJobId$bucketString.json$extension") + new Path(path) } }.getRecordWriter(context) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala index f08936e31214d..83687d488c4d5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala @@ -47,7 +47,6 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection import org.apache.spark.sql.catalyst.parser.LegacyTypeStringParser import org.apache.spark.sql.execution.datasources._ -import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ @@ -127,7 +126,7 @@ class ParquetFileFormat sparkSession.sessionState.conf.writeLegacyParquetFormat.toString) // Sets compression scheme - conf.set(ParquetOutputFormat.COMPRESSION, parquetOptions.compressionCodec) + conf.set(ParquetOutputFormat.COMPRESSION, parquetOptions.compressionCodecClassName) // SPARK-15719: Disables writing Parquet summary files by default. if (conf.get(ParquetOutputFormat.ENABLE_JOB_SUMMARY) == null) { @@ -139,10 +138,10 @@ class ParquetFileFormat new OutputWriterFactory { override def newInstance( path: String, - bucketId: Option[Int], + fileNamePrefix: String, dataSchema: StructType, context: TaskAttemptContext): OutputWriter = { - new ParquetOutputWriter(path, bucketId, context) + new ParquetOutputWriter(path, fileNamePrefix, context) } } } @@ -488,150 +487,6 @@ class ParquetFileFormat } -/** - * A factory for generating OutputWriters for writing parquet files. This implemented is different - * from the [[ParquetOutputWriter]] as this does not use any [[OutputCommitter]]. It simply - * writes the data to the path used to generate the output writer. Callers of this factory - * has to ensure which files are to be considered as committed. 
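The signature change above, stagingDir and fileNamePrefix in place of path and bucketId, means every writer now derives its final file name itself and exposes it through the new path field. A minimal sketch of that naming convention, using a hypothetical outputFilePath helper and made-up example values rather than anything in this patch:

import org.apache.hadoop.fs.Path

// <stagingDir>/<fileNamePrefix><format suffix><compression extension>
def outputFilePath(
    stagingDir: String,
    fileNamePrefix: String,
    formatSuffix: String,
    compressionExtension: String): String = {
  new Path(stagingDir, fileNamePrefix + formatSuffix + compressionExtension).toString
}

// e.g. outputFilePath("/data/_temporary/0", "part-r-00000-abc123", ".json", ".gz")
// returns "/data/_temporary/0/part-r-00000-abc123.json.gz"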
- */ -private[parquet] class ParquetOutputWriterFactory( - sqlConf: SQLConf, - dataSchema: StructType, - hadoopConf: Configuration, - options: Map[String, String]) extends OutputWriterFactory { - - private val serializableConf: SerializableConfiguration = { - val job = Job.getInstance(hadoopConf) - val conf = ContextUtil.getConfiguration(job) - val parquetOptions = new ParquetOptions(options, sqlConf) - - // We're not really using `ParquetOutputFormat[Row]` for writing data here, because we override - // it in `ParquetOutputWriter` to support appending and dynamic partitioning. The reason why - // we set it here is to setup the output committer class to `ParquetOutputCommitter`, which is - // bundled with `ParquetOutputFormat[Row]`. - job.setOutputFormatClass(classOf[ParquetOutputFormat[Row]]) - - ParquetOutputFormat.setWriteSupportClass(job, classOf[ParquetWriteSupport]) - - // We want to clear this temporary metadata from saving into Parquet file. - // This metadata is only useful for detecting optional columns when pushing down filters. - val dataSchemaToWrite = StructType.removeMetadata( - StructType.metadataKeyForOptionalField, - dataSchema).asInstanceOf[StructType] - ParquetWriteSupport.setSchema(dataSchemaToWrite, conf) - - // Sets flags for `CatalystSchemaConverter` (which converts Catalyst schema to Parquet schema) - // and `CatalystWriteSupport` (writing actual rows to Parquet files). - conf.set( - SQLConf.PARQUET_BINARY_AS_STRING.key, - sqlConf.isParquetBinaryAsString.toString) - - conf.set( - SQLConf.PARQUET_INT96_AS_TIMESTAMP.key, - sqlConf.isParquetINT96AsTimestamp.toString) - - conf.set( - SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key, - sqlConf.writeLegacyParquetFormat.toString) - - // Sets compression scheme - conf.set(ParquetOutputFormat.COMPRESSION, parquetOptions.compressionCodec) - new SerializableConfiguration(conf) - } - - /** - * Returns a [[OutputWriter]] that writes data to the give path without using - * [[OutputCommitter]]. - */ - override def newWriter(path: String): OutputWriter = new OutputWriter { - - // Create TaskAttemptContext that is used to pass on Configuration to the ParquetRecordWriter - private val hadoopTaskAttemptId = new TaskAttemptID(new TaskID(new JobID, TaskType.MAP, 0), 0) - private val hadoopAttemptContext = new TaskAttemptContextImpl( - serializableConf.value, hadoopTaskAttemptId) - - // Instance of ParquetRecordWriter that does not use OutputCommitter - private val recordWriter = createNoCommitterRecordWriter(path, hadoopAttemptContext) - - override def write(row: Row): Unit = { - throw new UnsupportedOperationException("call writeInternal") - } - - protected[sql] override def writeInternal(row: InternalRow): Unit = { - recordWriter.write(null, row) - } - - override def close(): Unit = recordWriter.close(hadoopAttemptContext) - } - - /** Create a [[ParquetRecordWriter]] that writes the given path without using OutputCommitter */ - private def createNoCommitterRecordWriter( - path: String, - hadoopAttemptContext: TaskAttemptContext): RecordWriter[Void, InternalRow] = { - // Custom ParquetOutputFormat that disable use of committer and writes to the given path - val outputFormat = new ParquetOutputFormat[InternalRow]() { - override def getOutputCommitter(c: TaskAttemptContext): OutputCommitter = { null } - override def getDefaultWorkFile(c: TaskAttemptContext, ext: String): Path = { new Path(path) } - } - outputFormat.getRecordWriter(hadoopAttemptContext) - } - - /** Disable the use of the older API. 
*/ - def newInstance( - path: String, - bucketId: Option[Int], - dataSchema: StructType, - context: TaskAttemptContext): OutputWriter = { - throw new UnsupportedOperationException( - "this version of newInstance not supported for " + - "ParquetOutputWriterFactory") - } -} - - -// NOTE: This class is instantiated and used on executor side only, no need to be serializable. -private[parquet] class ParquetOutputWriter( - path: String, - bucketId: Option[Int], - context: TaskAttemptContext) - extends OutputWriter { - - private val recordWriter: RecordWriter[Void, InternalRow] = { - val outputFormat = { - new ParquetOutputFormat[InternalRow]() { - // Here we override `getDefaultWorkFile` for two reasons: - // - // 1. To allow appending. We need to generate unique output file names to avoid - // overwriting existing files (either exist before the write job, or are just written - // by other tasks within the same write job). - // - // 2. To allow dynamic partitioning. Default `getDefaultWorkFile` uses - // `FileOutputCommitter.getWorkPath()`, which points to the base directory of all - // partitions in the case of dynamic partitioning. - override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = { - val configuration = context.getConfiguration - val uniqueWriteJobId = configuration.get(WriterContainer.DATASOURCE_WRITEJOBUUID) - val taskAttemptId = context.getTaskAttemptID - val split = taskAttemptId.getTaskID.getId - val bucketString = bucketId.map(BucketingUtils.bucketIdToString).getOrElse("") - // It has the `.parquet` extension at the end because (de)compression tools - // such as gunzip would not be able to decompress this as the compression - // is not applied on this whole file but on each "page" in Parquet format. - new Path(path, f"part-r-$split%05d-$uniqueWriteJobId$bucketString$extension") - } - } - } - - outputFormat.getRecordWriter(context) - } - - override def write(row: Row): Unit = throw new UnsupportedOperationException("call writeInternal") - - override def writeInternal(row: InternalRow): Unit = recordWriter.write(null, row) - - override def close(): Unit = recordWriter.close(context) -} - object ParquetFileFormat extends Logging { private[parquet] def readSchema( footers: Seq[Footer], sparkSession: SparkSession): Option[StructType] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala index 615731889dfad..d0fd23605bea8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala @@ -35,7 +35,7 @@ private[parquet] class ParquetOptions( * Compression codec to use. By default use the value specified in SQLConf. * Acceptable values are defined in [[shortParquetCompressionCodecNames]]. 
*/ - val compressionCodec: String = { + val compressionCodecClassName: String = { val codecName = parameters.getOrElse("compression", sqlConf.parquetCompressionCodec).toLowerCase if (!shortParquetCompressionCodecNames.contains(codecName)) { val availableCodecs = shortParquetCompressionCodecNames.keys.map(_.toLowerCase) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOutputWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOutputWriter.scala new file mode 100644 index 0000000000000..1300069c42b05 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOutputWriter.scala @@ -0,0 +1,164 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.parquet + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path +import org.apache.hadoop.mapreduce._ +import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl +import org.apache.parquet.hadoop.{ParquetOutputFormat, ParquetRecordWriter} +import org.apache.parquet.hadoop.codec.CodecConfig +import org.apache.parquet.hadoop.util.ContextUtil + +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.execution.datasources.{OutputWriter, OutputWriterFactory} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.StructType +import org.apache.spark.util.SerializableConfiguration + + +/** + * A factory for generating OutputWriters for writing parquet files. This implemented is different + * from the [[ParquetOutputWriter]] as this does not use any [[OutputCommitter]]. It simply + * writes the data to the path used to generate the output writer. Callers of this factory + * has to ensure which files are to be considered as committed. + */ +private[parquet] class ParquetOutputWriterFactory( + sqlConf: SQLConf, + dataSchema: StructType, + hadoopConf: Configuration, + options: Map[String, String]) + extends OutputWriterFactory { + + private val serializableConf: SerializableConfiguration = { + val job = Job.getInstance(hadoopConf) + val conf = ContextUtil.getConfiguration(job) + val parquetOptions = new ParquetOptions(options, sqlConf) + + // We're not really using `ParquetOutputFormat[Row]` for writing data here, because we override + // it in `ParquetOutputWriter` to support appending and dynamic partitioning. The reason why + // we set it here is to setup the output committer class to `ParquetOutputCommitter`, which is + // bundled with `ParquetOutputFormat[Row]`. 
+ job.setOutputFormatClass(classOf[ParquetOutputFormat[Row]]) + + ParquetOutputFormat.setWriteSupportClass(job, classOf[ParquetWriteSupport]) + + // We want to clear this temporary metadata from saving into Parquet file. + // This metadata is only useful for detecting optional columns when pushing down filters. + val dataSchemaToWrite = StructType.removeMetadata( + StructType.metadataKeyForOptionalField, + dataSchema).asInstanceOf[StructType] + ParquetWriteSupport.setSchema(dataSchemaToWrite, conf) + + // Sets flags for `CatalystSchemaConverter` (which converts Catalyst schema to Parquet schema) + // and `CatalystWriteSupport` (writing actual rows to Parquet files). + conf.set( + SQLConf.PARQUET_BINARY_AS_STRING.key, + sqlConf.isParquetBinaryAsString.toString) + + conf.set( + SQLConf.PARQUET_INT96_AS_TIMESTAMP.key, + sqlConf.isParquetINT96AsTimestamp.toString) + + conf.set( + SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key, + sqlConf.writeLegacyParquetFormat.toString) + + // Sets compression scheme + conf.set(ParquetOutputFormat.COMPRESSION, parquetOptions.compressionCodecClassName) + new SerializableConfiguration(conf) + } + + /** + * Returns a [[OutputWriter]] that writes data to the give path without using + * [[OutputCommitter]]. + */ + override def newWriter(path1: String): OutputWriter = new OutputWriter { + + // Create TaskAttemptContext that is used to pass on Configuration to the ParquetRecordWriter + private val hadoopTaskAttemptId = new TaskAttemptID(new TaskID(new JobID, TaskType.MAP, 0), 0) + private val hadoopAttemptContext = new TaskAttemptContextImpl( + serializableConf.value, hadoopTaskAttemptId) + + // Instance of ParquetRecordWriter that does not use OutputCommitter + private val recordWriter = createNoCommitterRecordWriter(path, hadoopAttemptContext) + + override def path: String = path1 + + override def write(row: Row): Unit = { + throw new UnsupportedOperationException("call writeInternal") + } + + protected[sql] override def writeInternal(row: InternalRow): Unit = { + recordWriter.write(null, row) + } + + override def close(): Unit = recordWriter.close(hadoopAttemptContext) + } + + /** Create a [[ParquetRecordWriter]] that writes the given path without using OutputCommitter */ + private def createNoCommitterRecordWriter( + path: String, + hadoopAttemptContext: TaskAttemptContext): RecordWriter[Void, InternalRow] = { + // Custom ParquetOutputFormat that disable use of committer and writes to the given path + val outputFormat = new ParquetOutputFormat[InternalRow]() { + override def getOutputCommitter(c: TaskAttemptContext): OutputCommitter = { null } + override def getDefaultWorkFile(c: TaskAttemptContext, ext: String): Path = { new Path(path) } + } + outputFormat.getRecordWriter(hadoopAttemptContext) + } + + /** Disable the use of the older API. */ + override def newInstance( + path: String, + fileNamePrefix: String, + dataSchema: StructType, + context: TaskAttemptContext): OutputWriter = { + throw new UnsupportedOperationException("this version of newInstance not supported for " + + "ParquetOutputWriterFactory") + } +} + + +// NOTE: This class is instantiated and used on executor side only, no need to be serializable. 
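The serializableConf value above exists because Hadoop's Configuration is not java.io.Serializable, yet it has to travel to the executors that call newWriter. A minimal sketch of that wrapper pattern, with a hypothetical SerializableHadoopConf class standing in for Spark's own utility:

import java.io.{ObjectInputStream, ObjectOutputStream}
import org.apache.hadoop.conf.Configuration

// Configuration implements Hadoop's Writable, so it can be written field-by-field from inside
// the standard Java serialization hooks.
class SerializableHadoopConf(@transient var value: Configuration) extends Serializable {

  private def writeObject(out: ObjectOutputStream): Unit = {
    out.defaultWriteObject()
    value.write(out)
  }

  private def readObject(in: ObjectInputStream): Unit = {
    in.defaultReadObject()
    value = new Configuration(false)
    value.readFields(in)
  }
}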
+private[parquet] class ParquetOutputWriter( + stagingDir: String, + fileNamePrefix: String, + context: TaskAttemptContext) + extends OutputWriter { + + override val path: String = { + val filename = fileNamePrefix + CodecConfig.from(context).getCodec.getExtension + ".parquet" + new Path(stagingDir, filename).toString + } + + private val recordWriter: RecordWriter[Void, InternalRow] = { + new ParquetOutputFormat[InternalRow]() { + override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = { + new Path(path) + } + }.getRecordWriter(context) + } + + override def write(row: Row): Unit = throw new UnsupportedOperationException("call writeInternal") + + override def writeInternal(row: InternalRow): Unit = recordWriter.write(null, row) + + override def close(): Unit = recordWriter.close(context) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala index f1a35dd8a6200..4dea8cf29ec58 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala @@ -269,11 +269,15 @@ private[parquet] object ParquetReadSupport { */ private def clipParquetGroupFields( parquetRecord: GroupType, structType: StructType): Seq[Type] = { - val parquetFieldMap = parquetRecord.getFields.asScala.map(f => f.getName -> f).toMap + val parquetFieldMap = parquetRecord.getFields.asScala + .map(f => f.getName -> f).toMap + val caseInsensitiveParquetFieldMap = parquetRecord.getFields.asScala + .map(f => f.getName.toLowerCase -> f).toMap val toParquet = new ParquetSchemaConverter(writeLegacyParquetFormat = false) structType.map { f => parquetFieldMap .get(f.name) + .orElse(caseInsensitiveParquetFieldMap.get(f.name.toLowerCase)) .map(clipParquetType(_, f.dataType)) .getOrElse(toParquet.convertField(f)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala index bd6eb6e0535ab..4647b11af4dfb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala @@ -187,8 +187,8 @@ case class AnalyzeCreateTable(sparkSession: SparkSession) extends Rule[LogicalPl colName: String, colType: String): String = { val tableCols = schema.map(_.name) - val conf = sparkSession.sessionState.conf - tableCols.find(conf.resolver(_, colName)).getOrElse { + val resolver = sparkSession.sessionState.conf.resolver + tableCols.find(resolver(_, colName)).getOrElse { failAnalysis(s"$colType column $colName is not defined in table $tableIdent, " + s"defined table columns are: ${tableCols.mkString(", ")}") } @@ -209,50 +209,55 @@ case class PreprocessTableInsertion(conf: SQLConf) extends Rule[LogicalPlan] { tblName: String, partColNames: Seq[String]): InsertIntoTable = { - val expectedColumns = insert.expectedColumns - if (expectedColumns.isDefined && expectedColumns.get.length != insert.child.schema.length) { + val normalizedPartSpec = PartitioningUtils.normalizePartitionSpec( + insert.partition, partColNames, tblName, conf.resolver) + + val expectedColumns = { + val staticPartCols = normalizedPartSpec.filter(_._2.isDefined).keySet + insert.table.output.filterNot(a => 
staticPartCols.contains(a.name)) + } + + if (expectedColumns.length != insert.child.schema.length) { throw new AnalysisException( s"Cannot insert into table $tblName because the number of columns are different: " + - s"need ${expectedColumns.get.length} columns, " + + s"need ${expectedColumns.length} columns, " + s"but query has ${insert.child.schema.length} columns.") } - if (insert.partition.nonEmpty) { - // the query's partitioning must match the table's partitioning - // this is set for queries like: insert into ... partition (one = "a", two = ) - val samePartitionColumns = - if (conf.caseSensitiveAnalysis) { - insert.partition.keySet == partColNames.toSet - } else { - insert.partition.keySet.map(_.toLowerCase) == partColNames.map(_.toLowerCase).toSet - } - if (!samePartitionColumns) { + if (normalizedPartSpec.nonEmpty) { + if (normalizedPartSpec.size != partColNames.length) { throw new AnalysisException( s""" |Requested partitioning does not match the table $tblName: - |Requested partitions: ${insert.partition.keys.mkString(",")} + |Requested partitions: ${normalizedPartSpec.keys.mkString(",")} |Table partitions: ${partColNames.mkString(",")} """.stripMargin) } - expectedColumns.map(castAndRenameChildOutput(insert, _)).getOrElse(insert) + + castAndRenameChildOutput(insert.copy(partition = normalizedPartSpec), expectedColumns) } else { - // All partition columns are dynamic because because the InsertIntoTable command does + // All partition columns are dynamic because the InsertIntoTable command does // not explicitly specify partitioning columns. - expectedColumns.map(castAndRenameChildOutput(insert, _)).getOrElse(insert) + castAndRenameChildOutput(insert, expectedColumns) .copy(partition = partColNames.map(_ -> None).toMap) } } - // TODO: do we really need to rename? 
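The expectedColumns computation above reads as: the columns the query must supply are the table's output minus any partition columns that received static values in the PARTITION clause. A toy model of just that rule, with a hypothetical Col type in place of Catalyst attributes:

case class Col(name: String)

def expectedColumns(
    tableOutput: Seq[Col],
    normalizedPartSpec: Map[String, Option[String]]): Seq[Col] = {
  val staticPartCols = normalizedPartSpec.filter(_._2.isDefined).keySet
  tableOutput.filterNot(c => staticPartCols.contains(c.name))
}

// INSERT INTO t PARTITION (p = '2016', q) SELECT ...   (only p is static)
expectedColumns(
  Seq(Col("id"), Col("value"), Col("p"), Col("q")),
  Map("p" -> Some("2016"), "q" -> None))
// => Seq(Col("id"), Col("value"), Col("q")), so the query must produce exactly three columns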
- def castAndRenameChildOutput( + private def castAndRenameChildOutput( insert: InsertIntoTable, expectedOutput: Seq[Attribute]): InsertIntoTable = { val newChildOutput = expectedOutput.zip(insert.child.output).map { case (expected, actual) => - if (expected.dataType.sameType(actual.dataType) && expected.name == actual.name) { + if (expected.dataType.sameType(actual.dataType) && + expected.name == actual.name && + expected.metadata == actual.metadata) { actual } else { - Alias(Cast(actual, expected.dataType), expected.name)() + // Renaming is needed for handling the following cases like + // 1) Column names/types do not match, e.g., INSERT INTO TABLE tab1 SELECT 1, 2 + // 2) Target tables have column metadata + Alias(Cast(actual, expected.dataType), expected.name)( + explicitMetadata = Option(expected.metadata)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala index 9f96667311015..d40b5725199a8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala @@ -20,8 +20,10 @@ package org.apache.spark.sql.execution.datasources.text import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.hadoop.io.{NullWritable, Text} +import org.apache.hadoop.io.compress.GzipCodec import org.apache.hadoop.mapreduce.{Job, RecordWriter, TaskAttemptContext} -import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat +import org.apache.hadoop.mapreduce.lib.output.{FileOutputFormat, TextOutputFormat} +import org.apache.hadoop.util.ReflectionUtils import org.apache.spark.TaskContext import org.apache.spark.sql.{AnalysisException, Row, SparkSession} @@ -73,14 +75,11 @@ class TextFileFormat extends TextBasedFileFormat with DataSourceRegister { new OutputWriterFactory { override def newInstance( - path: String, - bucketId: Option[Int], + stagingDir: String, + fileNamePrefix: String, dataSchema: StructType, context: TaskAttemptContext): OutputWriter = { - if (bucketId.isDefined) { - throw new AnalysisException("Text doesn't support bucketing") - } - new TextOutputWriter(path, dataSchema, context) + new TextOutputWriter(stagingDir, fileNamePrefix, dataSchema, context) } } } @@ -124,19 +123,24 @@ class TextFileFormat extends TextBasedFileFormat with DataSourceRegister { } } -class TextOutputWriter(path: String, dataSchema: StructType, context: TaskAttemptContext) +class TextOutputWriter( + stagingDir: String, + fileNamePrefix: String, + dataSchema: StructType, + context: TaskAttemptContext) extends OutputWriter { + override val path: String = { + val compressionExtension = TextOutputWriter.getCompressionExtension(context) + new Path(stagingDir, fileNamePrefix + ".txt" + compressionExtension).toString + } + private[this] val buffer = new Text() private val recordWriter: RecordWriter[NullWritable, Text] = { new TextOutputFormat[NullWritable, Text]() { override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = { - val configuration = context.getConfiguration - val uniqueWriteJobId = configuration.get(WriterContainer.DATASOURCE_WRITEJOBUUID) - val taskAttemptId = context.getTaskAttemptID - val split = taskAttemptId.getTaskID.getId - new Path(path, f"part-r-$split%05d-$uniqueWriteJobId.txt$extension") + new Path(path) } }.getRecordWriter(context) } 
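The ".txt" plus compression suffix assembled in the path above comes from the helper added in the next hunk, which asks the configured output compressor for its default extension. A quick standalone check of what that yields for Hadoop's default codec (assuming only Hadoop classes on the classpath, nothing from this patch):

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.io.compress.GzipCodec
import org.apache.hadoop.util.ReflectionUtils

// getDefaultExtension already includes the leading dot, so a compressed text file written with
// prefix part-r-00000-abc123 ends up as part-r-00000-abc123.txt.gz.
val codec = ReflectionUtils.newInstance(classOf[GzipCodec], new Configuration())
println(codec.getDefaultExtension)   // prints ".gz"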
@@ -153,3 +157,17 @@ class TextOutputWriter(path: String, dataSchema: StructType, context: TaskAttemp recordWriter.close(context) } } + + +object TextOutputWriter { + /** Returns the compression codec extension to be used in a file name, e.g. ".gzip"). */ + def getCompressionExtension(context: TaskAttemptContext): String = { + // Set the compression extension, similar to code in TextOutputFormat.getDefaultWorkFile + if (FileOutputFormat.getCompressOutput(context)) { + val codecClass = FileOutputFormat.getOutputCompressorClass(context, classOf[GzipCodec]) + ReflectionUtils.newInstance(codecClass, context.getConfiguration).getDefaultExtension + } else { + "" + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala index d321f4cd76877..0395c43ba2cbc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala @@ -27,8 +27,8 @@ import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.expressions.codegen.{CodeFormatter, CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.catalyst.trees.TreeNodeRef -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.util.{AccumulatorV2, LongAccumulator} /** @@ -69,15 +69,6 @@ package object debug { output } - /** - * Augments [[SparkSession]] with debug methods. - */ - implicit class DebugSQLContext(sparkSession: SparkSession) { - def debug(): Unit = { - sparkSession.conf.set(SQLConf.DATAFRAME_EAGER_ANALYSIS.key, false) - } - } - /** * Augments [[Dataset]]s with debug methods. */ @@ -171,6 +162,8 @@ package object debug { } } + override def outputPartitioning: Partitioning = child.outputPartitioning + override def inputRDDs(): Seq[RDD[InternalRow]] = { child.asInstanceOf[CodegenSupport].inputRDDs() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala index 7be5d31d4a765..ce5013daeb1f9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala @@ -128,7 +128,8 @@ case class BroadcastExchangeExec( } override protected[sql] def doExecuteBroadcast[T](): broadcast.Broadcast[T] = { - ThreadUtils.awaitResult(relationFuture, timeout).asInstanceOf[broadcast.Broadcast[T]] + ThreadUtils.awaitResultInForkJoinSafely(relationFuture, timeout) + .asInstanceOf[broadcast.Broadcast[T]] } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala index 86a8770715600..9918ac327f2dd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala @@ -26,7 +26,6 @@ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.exchange.ShuffleExchange import org.apache.spark.util.Utils - /** * Take the first `limit` elements and collect them to a single partition. 
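A recurring change from here on is that unary physical operators declare outputPartitioning (and sometimes outputOrdering) explicitly instead of relying on a default. Schematically, with hypothetical PlanNode and PassThroughNode types rather than Spark's SparkPlan hierarchy:

sealed trait Partitioning
case object UnknownPartitioning extends Partitioning

trait PlanNode {
  def outputPartitioning: Partitioning = UnknownPartitioning
}

// A row-preserving unary operator forwards the partitioning of its only child; advertising
// UnknownPartitioning instead could make the planner add an unnecessary exchange above it.
case class PassThroughNode(child: PlanNode) extends PlanNode {
  override def outputPartitioning: Partitioning = child.outputPartitioning
}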
* @@ -54,8 +53,7 @@ case class CollectLimitExec(limit: Int, child: SparkPlan) extends UnaryExecNode trait BaseLimitExec extends UnaryExecNode with CodegenSupport { val limit: Int override def output: Seq[Attribute] = child.output - override def outputOrdering: Seq[SortOrder] = child.outputOrdering - override def outputPartitioning: Partitioning = child.outputPartitioning + protected override def doExecute(): RDD[InternalRow] = child.execute().mapPartitions { iter => iter.take(limit) } @@ -95,14 +93,22 @@ trait BaseLimitExec extends UnaryExecNode with CodegenSupport { * Take the first `limit` elements of each child partition, but do not collect or shuffle them. */ case class LocalLimitExec(limit: Int, child: SparkPlan) extends BaseLimitExec { + override def outputOrdering: Seq[SortOrder] = child.outputOrdering + + override def outputPartitioning: Partitioning = child.outputPartitioning } /** * Take the first `limit` elements of the child's single output partition. */ case class GlobalLimitExec(limit: Int, child: SparkPlan) extends BaseLimitExec { + override def requiredChildDistribution: List[Distribution] = AllTuples :: Nil + + override def outputPartitioning: Partitioning = child.outputPartitioning + + override def outputOrdering: Seq[SortOrder] = child.outputOrdering } /** @@ -122,8 +128,6 @@ case class TakeOrderedAndProjectExec( projectList.map(_.toAttribute) } - override def outputPartitioning: Partitioning = SinglePartition - override def executeCollect(): Array[InternalRow] = { val ord = new LazilyGeneratedOrdering(sortOrder, child.output) val data = child.execute().map(_.copy()).takeOrdered(limit)(ord) @@ -160,6 +164,8 @@ case class TakeOrderedAndProjectExec( override def outputOrdering: Seq[SortOrder] = sortOrder + override def outputPartitioning: Partitioning = SinglePartition + override def simpleString: String = { val orderByString = Utils.truncatedString(sortOrder, "[", ",", "]") val outputString = Utils.truncatedString(output, "[", ",", "]") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala index 2acc5110e8950..9df56bbf1ef87 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala @@ -68,6 +68,8 @@ case class DeserializeToObjectExec( outputObjAttr: Attribute, child: SparkPlan) extends UnaryExecNode with ObjectProducerExec with CodegenSupport { + override def outputPartitioning: Partitioning = child.outputPartitioning + override def inputRDDs(): Seq[RDD[InternalRow]] = { child.asInstanceOf[CodegenSupport].inputRDDs() } @@ -102,6 +104,8 @@ case class SerializeFromObjectExec( override def output: Seq[Attribute] = serializer.map(_.toAttribute) + override def outputPartitioning: Partitioning = child.outputPartitioning + override def inputRDDs(): Seq[RDD[InternalRow]] = { child.asInstanceOf[CodegenSupport].inputRDDs() } @@ -171,6 +175,8 @@ case class MapPartitionsExec( child: SparkPlan) extends ObjectConsumerExec with ObjectProducerExec { + override def outputPartitioning: Partitioning = child.outputPartitioning + override protected def doExecute(): RDD[InternalRow] = { child.execute().mapPartitionsInternal { iter => val getObject = ObjectOperator.unwrapObjectFromRow(child.output.head.dataType) @@ -231,6 +237,8 @@ case class MapElementsExec( } override def outputOrdering: Seq[SortOrder] = child.outputOrdering + + override def outputPartitioning: Partitioning = 
child.outputPartitioning } /** @@ -244,6 +252,8 @@ case class AppendColumnsExec( override def output: Seq[Attribute] = child.output ++ serializer.map(_.toAttribute) + override def outputPartitioning: Partitioning = child.outputPartitioning + private def newColumnSchema = serializer.map(_.toAttribute).toStructType override protected def doExecute(): RDD[InternalRow] = { @@ -272,6 +282,8 @@ case class AppendColumnsWithObjectExec( override def output: Seq[Attribute] = (inputSerializer ++ newColumnsSerializer).map(_.toAttribute) + override def outputPartitioning: Partitioning = child.outputPartitioning + private def inputSchema = inputSerializer.map(_.toAttribute).toStructType private def newColumnSchema = newColumnsSerializer.map(_.toAttribute).toStructType @@ -304,6 +316,8 @@ case class MapGroupsExec( outputObjAttr: Attribute, child: SparkPlan) extends UnaryExecNode with ObjectProducerExec { + override def outputPartitioning: Partitioning = child.outputPartitioning + override def requiredChildDistribution: Seq[Distribution] = ClusteredDistribution(groupingAttributes) :: Nil @@ -347,6 +361,9 @@ case class FlatMapGroupsInRExec( child: SparkPlan) extends UnaryExecNode with ObjectProducerExec { override def output: Seq[Attribute] = outputObjAttr :: Nil + + override def outputPartitioning: Partitioning = child.outputPartitioning + override def producedAttributes: AttributeSet = AttributeSet(outputObjAttr) override def requiredChildDistribution: Seq[Distribution] = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala index 724025b4647f4..46fd54e5c7420 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala @@ -124,11 +124,11 @@ object EvaluatePython { case (c, ArrayType(elementType, _)) if c.getClass.isArray => new GenericArrayData(c.asInstanceOf[Array[_]].map(e => fromJava(e, elementType))) - case (c: java.util.Map[_, _], MapType(keyType, valueType, _)) => - val keyValues = c.asScala.toSeq - val keys = keyValues.map(kv => fromJava(kv._1, keyType)).toArray - val values = keyValues.map(kv => fromJava(kv._2, valueType)).toArray - ArrayBasedMapData(keys, values) + case (javaMap: java.util.Map[_, _], MapType(keyType, valueType, _)) => + ArrayBasedMapData( + javaMap, + (key: Any) => fromJava(key, keyType), + (value: Any) => fromJava(value, valueType)) case (c, StructType(fields)) if c.getClass.isArray => val array = c.asInstanceOf[Array[_]] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/RowQueue.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/RowQueue.scala index 422a3f862d96f..cd1e77f524afd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/RowQueue.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/RowQueue.scala @@ -22,6 +22,7 @@ import java.io._ import com.google.common.io.Closeables import org.apache.spark.SparkException +import org.apache.spark.io.NioBufferedFileInputStream import org.apache.spark.memory.{MemoryConsumer, TaskMemoryManager} import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.unsafe.Platform @@ -130,7 +131,7 @@ private[python] case class DiskRowQueue(file: File, fields: Int) extends RowQueu if (out != null) { out.close() out = null - in = new DataInputStream(new BufferedInputStream(new 
FileInputStream(file.toString))) + in = new DataInputStream(new NioBufferedFileInputStream(file)) } if (unreadBytes > 0) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLog.scala index c14feea91ed7d..b26edeeb04009 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLog.scala @@ -146,7 +146,7 @@ abstract class CompactibleFileStreamLog[T: ClassTag]( */ def allFiles(): Array[T] = { var latestId = getLatest().map(_._1).getOrElse(-1L) - // There is a race condition when `FileStreamSink` is deleting old files and `StreamFileCatalog` + // There is a race condition when `FileStreamSink` is deleting old files and `StreamFileIndex` // is calling this method. This loop will retry the reading to deal with the // race condition. while (true) { @@ -158,7 +158,7 @@ abstract class CompactibleFileStreamLog[T: ClassTag]( } catch { case e: IOException => // Another process using `CompactibleFileStreamLog` may delete the batch files when - // `StreamFileCatalog` are reading. However, it only happens when a compaction is + // `StreamFileIndex` are reading. However, it only happens when a compaction is // deleting old files. If so, let's try the next compaction batch and we should find it. // Otherwise, this is a real IO issue and we should throw it. latestId = nextCompactionBatchId(latestId, compactInterval) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala index 614a6261e7c28..680df01acc1a6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala @@ -24,7 +24,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging import org.apache.spark.sql.{DataFrame, Dataset, SparkSession} -import org.apache.spark.sql.execution.datasources.{DataSource, ListingFileCatalog, LogicalRelation} +import org.apache.spark.sql.execution.datasources.{DataSource, InMemoryFileIndex, LogicalRelation} import org.apache.spark.sql.types.StructType /** @@ -35,6 +35,7 @@ class FileStreamSource( path: String, fileFormatClassName: String, override val schema: StructType, + partitionColumns: Seq[String], metadataPath: String, options: Map[String, String]) extends Source with Logging { @@ -142,6 +143,7 @@ class FileStreamSource( sparkSession, paths = files.map(_.path), userSpecifiedSchema = Some(schema), + partitionColumns = partitionColumns, className = fileFormatClassName, options = optionsWithPartitionBasePath) Dataset.ofRows(sparkSession, LogicalRelation(newDataSource.resolveRelation( @@ -154,7 +156,7 @@ class FileStreamSource( private def fetchAllFiles(): Seq[(String, Long)] = { val startTime = System.nanoTime val globbedPaths = SparkHadoopUtil.get.globPathIfNecessary(qualifiedBasePath) - val catalog = new ListingFileCatalog(sparkSession, globbedPaths, options, Some(new StructType)) + val catalog = new InMemoryFileIndex(sparkSession, globbedPaths, options, Some(new StructType)) val files = catalog.allFiles().sortBy(_.getModificationTime).map { status => (status.getPath.toUri.toString, 
status.getModificationTime) } @@ -174,6 +176,15 @@ class FileStreamSource( override def toString: String = s"FileStreamSource[$qualifiedBasePath]" + /** + * Informs the source that Spark has completed processing all data for offsets less than or + * equal to `end` and will only request offsets greater than `end` in the future. + */ + override def commit(end: Offset): Unit = { + // No-op for now; FileStreamSource currently garbage-collects files based on timestamp + // and the value of the maxFileAge parameter. + } + override def stop() {} } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala index 082664aa23f04..24f98b9211f12 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala @@ -68,19 +68,16 @@ class ForeachSink[T : Encoder](writer: ForeachWriter[T]) extends Sink with Seria } datasetWithIncrementalExecution.foreachPartition { iter => if (writer.open(TaskContext.getPartitionId(), batchId)) { - var isFailed = false try { while (iter.hasNext) { writer.process(iter.next()) } } catch { case e: Throwable => - isFailed = true writer.close(e) + throw e } - if (!isFailed) { - writer.close(null) - } + writer.close(null) } else { writer.close(null) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetadataLogFileCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetadataLogFileCatalog.scala index a32c4671e3475..aeaa134736937 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetadataLogFileCatalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetadataLogFileCatalog.scala @@ -26,11 +26,11 @@ import org.apache.spark.sql.execution.datasources._ /** - * A [[FileCatalog]] that generates the list of files to processing by reading them from the + * A [[FileIndex]] that generates the list of files to processing by reading them from the * metadata log files generated by the [[FileStreamSink]]. */ -class MetadataLogFileCatalog(sparkSession: SparkSession, path: Path) - extends PartitioningAwareFileCatalog(sparkSession, Map.empty, None) { +class MetadataLogFileIndex(sparkSession: SparkSession, path: Path) + extends PartitioningAwareFileIndex(sparkSession, Map.empty, None) { private val metadataDirectory = new Path(path, FileStreamSink.metadataDir) logInfo(s"Reading streaming file log from $metadataDirectory") @@ -47,7 +47,7 @@ class MetadataLogFileCatalog(sparkSession: SparkSession, path: Path) allFilesFromLog.toArray.groupBy(_.getPath.getParent) } - override def paths: Seq[Path] = path :: Nil + override def rootPaths: Seq[Path] = path :: Nil override def refresh(): Unit = { } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Source.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Source.scala index 971147840d2fd..f3bd5bfe23fdf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Source.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Source.scala @@ -30,16 +30,30 @@ trait Source { /** Returns the schema of the data from this source */ def schema: StructType - /** Returns the maximum available offset for this source. */ + /** + * Returns the maximum available offset for this source. 
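The contract spelled out in this hunk, together with the commit method introduced just below, is easiest to see in a self-contained toy: getOffset reports the highest offset seen so far, getBatch returns the half-open range (start, end], and commit tells the source that data at or below the committed offset can be dropped. ToySource here is a hypothetical in-memory stand-in, not an implementation of Spark's Source trait:

import scala.collection.mutable.ArrayBuffer

class ToySource[T] {
  private val pending = ArrayBuffer.empty[T]   // values with offsets in (committed, received]
  private var received = 0L
  private var committed = 0L

  def add(value: T): Unit = { pending += value; received += 1 }

  def getOffset: Option[Long] = if (received == 0) None else Some(received)

  // Data in the half-open offset range (start, end].
  def getBatch(start: Option[Long], end: Long): Seq[T] = {
    val from = math.max(start.getOrElse(0L), committed)
    pending.slice((from - committed).toInt, (end - committed).toInt)
  }

  // Offsets up to and including `end` will never be requested again, so drop their data.
  def commit(end: Long): Unit = {
    pending.remove(0, (end - committed).toInt)
    committed = end
  }
}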
+ * Returns `None` if this source has never received any data. + */ def getOffset: Option[Offset] /** - * Returns the data that is between the offsets (`start`, `end`]. When `start` is `None` then - * the batch should begin with the first available record. This method must always return the - * same data for a particular `start` and `end` pair. + * Returns the data that is between the offsets (`start`, `end`]. When `start` is `None`, + * then the batch should begin with the first record. This method must always return the + * same data for a particular `start` and `end` pair; even after the Source has been restarted + * on a different node. + * + * Higher layers will always call this method with a value of `start` greater than or equal + * to the last value passed to `commit` and a value of `end` less than or equal to the + * last value returned by `getOffset` */ def getBatch(start: Option[Offset], end: Offset): DataFrame + /** + * Informs the source that Spark has completed processing all data for offsets less than or + * equal to `end` and will only request offsets greater than `end` in the future. + */ + def commit(end: Offset) : Unit = {} + /** Stop this source and free any resources it has allocated. */ def stop(): Unit } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala index 4d0283fbef1d0..ad8238f189c64 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala @@ -22,7 +22,9 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection +import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.execution +import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.execution.streaming.state._ import org.apache.spark.sql.execution.SparkPlan @@ -56,7 +58,12 @@ case class StateStoreRestoreExec( child: SparkPlan) extends execution.UnaryExecNode with StatefulOperator { + override lazy val metrics = Map( + "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) + override protected def doExecute(): RDD[InternalRow] = { + val numOutputRows = longMetric("numOutputRows") + child.execute().mapPartitionsWithStateStore( getStateId.checkpointLocation, operatorId = getStateId.operatorId, @@ -69,11 +76,15 @@ case class StateStoreRestoreExec( iter.flatMap { row => val key = getKey(row) val savedState = store.get(key) + numOutputRows += 1 row +: savedState.toSeq } } } + override def output: Seq[Attribute] = child.output + + override def outputPartitioning: Partitioning = child.outputPartitioning } /** @@ -86,7 +97,13 @@ case class StateStoreSaveExec( child: SparkPlan) extends execution.UnaryExecNode with StatefulOperator { + override lazy val metrics = Map( + "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), + "numTotalStateRows" -> SQLMetrics.createMetric(sparkContext, "number of total state rows"), + "numUpdatedStateRows" -> SQLMetrics.createMetric(sparkContext, "number of updated state rows")) + override protected def doExecute(): RDD[InternalRow] = { + metrics // force lazy init at driver assert(returnAllStates.nonEmpty, "Incorrect 
planning in IncrementalExecution, returnAllStates have not been set") val saveAndReturnFunc = if (returnAllStates.get) saveAndReturnAll _ else saveAndReturnUpdated _ @@ -103,6 +120,8 @@ case class StateStoreSaveExec( override def output: Seq[Attribute] = child.output + override def outputPartitioning: Partitioning = child.outputPartitioning + /** * Save all the rows to the state store, and return all the rows in the state store. * Note that this returns an iterator that pipelines the saving to store with downstream @@ -111,6 +130,10 @@ case class StateStoreSaveExec( private def saveAndReturnUpdated( store: StateStore, iter: Iterator[InternalRow]): Iterator[InternalRow] = { + val numOutputRows = longMetric("numOutputRows") + val numTotalStateRows = longMetric("numTotalStateRows") + val numUpdatedStateRows = longMetric("numUpdatedStateRows") + new Iterator[InternalRow] { private[this] val baseIterator = iter private[this] val getKey = GenerateUnsafeProjection.generate(keyExpressions, child.output) @@ -118,6 +141,7 @@ case class StateStoreSaveExec( override def hasNext: Boolean = { if (!baseIterator.hasNext) { store.commit() + numTotalStateRows += store.numKeys() false } else { true @@ -128,6 +152,8 @@ case class StateStoreSaveExec( val row = baseIterator.next().asInstanceOf[UnsafeRow] val key = getKey(row) store.put(key.copy(), row.copy()) + numOutputRows += 1 + numUpdatedStateRows += 1 row } } @@ -142,12 +168,21 @@ case class StateStoreSaveExec( store: StateStore, iter: Iterator[InternalRow]): Iterator[InternalRow] = { val getKey = GenerateUnsafeProjection.generate(keyExpressions, child.output) + val numOutputRows = longMetric("numOutputRows") + val numTotalStateRows = longMetric("numTotalStateRows") + val numUpdatedStateRows = longMetric("numUpdatedStateRows") + while (iter.hasNext) { val row = iter.next().asInstanceOf[UnsafeRow] val key = getKey(row) store.put(key.copy(), row.copy()) + numUpdatedStateRows += 1 } store.commit() - store.iterator().map(_._2.asInstanceOf[InternalRow]) + numTotalStateRows += store.numKeys() + store.iterator().map { case (k, v) => + numOutputRows += 1 + v.asInstanceOf[InternalRow] + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index 333239f875bd3..37af1a550aaf1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap} import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.util._ -import org.apache.spark.sql.execution.QueryExecution +import org.apache.spark.sql.execution.{QueryExecution, SparkPlan} import org.apache.spark.sql.execution.command.ExplainCommand import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming._ @@ -57,6 +57,7 @@ class StreamExecution( extends StreamingQuery with Logging { import org.apache.spark.sql.streaming.StreamingQueryListener._ + import StreamMetrics._ private val pollingDelayMs = sparkSession.sessionState.conf.streamingPollingDelay @@ -72,6 +73,9 @@ class StreamExecution( /** * Tracks how much data we have processed and committed to the sink or state store from each * input source. 
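The counters wired into StateStoreRestoreExec and StateStoreSaveExec above are SQLMetric instances, which are accumulators underneath and surface next to the operator in the SQL UI. Since SQLMetrics is internal to the sql package, the counting pattern is illustrated here with a plain named LongAccumulator (a sketch, not the operator code):

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[2]").appName("metrics-sketch").getOrCreate()
val numOutputRows = spark.sparkContext.longAccumulator("number of output rows")

// Incremented on executors as rows flow through, aggregated back on the driver.
spark.range(0, 100).rdd.map { i => numOutputRows.add(1); i }.count()
println(numOutputRows.value)   // 100

spark.stop()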
+ * Only the scheduler thread should modify this field, and only in atomic steps. + * Other threads should make a shallow copy if they are going to access this field more than + * once, since the field's value may change at any time. */ @volatile var committedOffsets = new StreamProgress @@ -79,6 +83,9 @@ class StreamExecution( /** * Tracks the offsets that are available to be processed, but have not yet be committed to the * sink. + * Only the scheduler thread should modify this field, and only in atomic steps. + * Other threads should make a shallow copy if they are going to access this field more than + * once, since the field's value may change at any time. */ @volatile private var availableOffsets = new StreamProgress @@ -105,11 +112,22 @@ class StreamExecution( var lastExecution: QueryExecution = null @volatile - var streamDeathCause: StreamingQueryException = null + private var streamDeathCause: StreamingQueryException = null /* Get the call site in the caller thread; will pass this into the micro batch thread */ private val callSite = Utils.getCallSite() + /** Metrics for this query */ + private val streamMetrics = + new StreamMetrics(uniqueSources.toSet, triggerClock, s"StructuredStreaming.$name") + + @volatile + private var currentStatus: StreamingQueryStatus = null + + /** Flag that signals whether any error with input metrics have already been logged */ + @volatile + private var metricWarningLogged: Boolean = false + /** * The thread that runs the micro-batches of this stream. Note that this thread must be * [[org.apache.spark.util.UninterruptibleThread]] to avoid potential deadlocks in using @@ -136,16 +154,14 @@ class StreamExecution( /** Whether the query is currently active or not */ override def isActive: Boolean = state == ACTIVE + /** Returns the current status of the query. */ + override def status: StreamingQueryStatus = currentStatus + /** Returns current status of all the sources. */ - override def sourceStatuses: Array[SourceStatus] = { - val localAvailableOffsets = availableOffsets - sources.map(s => - new SourceStatus(s.toString, localAvailableOffsets.get(s).map(_.toString))).toArray - } + override def sourceStatuses: Array[SourceStatus] = currentStatus.sourceStatuses.toArray /** Returns current status of the sink. */ - override def sinkStatus: SinkStatus = - new SinkStatus(sink.toString, committedOffsets.toCompositeOffset(sources).toString) + override def sinkStatus: SinkStatus = currentStatus.sinkStatus /** Returns the [[StreamingQueryException]] if the query was terminated by an exception. */ override def exception: Option[StreamingQueryException] = Option(streamDeathCause) @@ -155,7 +171,7 @@ class StreamExecution( new Path(new Path(checkpointRoot), name).toUri.toString /** - * Starts the execution. This returns only after the thread has started and [[QueryStarted]] event + * Starts the execution. This returns only after the thread has started and [[QueryStartedEvent]] * has been posted to all the listeners. */ def start(): Unit = { @@ -167,16 +183,21 @@ class StreamExecution( /** * Repeatedly attempts to run batches as data arrives. * - * Note that this method ensures that [[QueryStarted]] and [[QueryTerminated]] events are posted - * such that listeners are guaranteed to get a start event before a termination. Furthermore, this - * method also ensures that [[QueryStarted]] event is posted before the `start()` method returns. 
+ * Note that this method ensures that [[QueryStartedEvent]] and [[QueryTerminatedEvent]] are + * posted such that listeners are guaranteed to get a start event before a termination. + * Furthermore, this method also ensures that [[QueryStartedEvent]] event is posted before the + * `start()` method returns. */ private def runBatches(): Unit = { try { // Mark ACTIVE and then post the event. QueryStarted event is synchronously sent to listeners, // so must mark this as ACTIVE first. state = ACTIVE - postEvent(new QueryStarted(this.toInfo)) // Assumption: Does not throw exception. + if (sparkSession.sessionState.conf.streamingMetricsEnabled) { + sparkSession.sparkContext.env.metricsSystem.registerSource(streamMetrics) + } + updateStatus() + postEvent(new QueryStartedEvent(currentStatus)) // Assumption: Does not throw exception. // Unblock starting thread startLatch.countDown() @@ -185,25 +206,41 @@ class StreamExecution( SparkSession.setActiveSession(sparkSession) triggerExecutor.execute(() => { - if (isActive) { - if (currentBatchId < 0) { - // We'll do this initialization only once - populateStartOffsets() - logDebug(s"Stream running from $committedOffsets to $availableOffsets") + streamMetrics.reportTriggerStarted(currentBatchId) + streamMetrics.reportTriggerDetail(STATUS_MESSAGE, "Finding new data from sources") + updateStatus() + val isTerminated = reportTimeTaken(TRIGGER_LATENCY) { + if (isActive) { + if (currentBatchId < 0) { + // We'll do this initialization only once + populateStartOffsets() + logDebug(s"Stream running from $committedOffsets to $availableOffsets") + } else { + constructNextBatch() + } + if (dataAvailable) { + streamMetrics.reportTriggerDetail(IS_DATA_PRESENT_IN_TRIGGER, true) + streamMetrics.reportTriggerDetail(STATUS_MESSAGE, "Processing new data") + updateStatus() + runBatch() + // We'll increase currentBatchId after we complete processing current batch's data + currentBatchId += 1 + } else { + streamMetrics.reportTriggerDetail(IS_DATA_PRESENT_IN_TRIGGER, false) + streamMetrics.reportTriggerDetail(STATUS_MESSAGE, "No new data") + updateStatus() + Thread.sleep(pollingDelayMs) + } + true } else { - constructNextBatch() + false } - if (dataAvailable) { - runBatch() - // We'll increase currentBatchId after we complete processing current batch's data - currentBatchId += 1 - } else { - Thread.sleep(pollingDelayMs) - } - true - } else { - false } + // Update metrics and notify others + streamMetrics.reportTriggerFinished() + updateStatus() + postEvent(new QueryProgressEvent(currentStatus)) + isTerminated }) } catch { case _: InterruptedException if state == TERMINATED => // interrupted by stop() @@ -221,8 +258,16 @@ class StreamExecution( } } finally { state = TERMINATED + + // Update metrics and status + streamMetrics.stop() + sparkSession.sparkContext.env.metricsSystem.removeSource(streamMetrics) + updateStatus() + + // Notify others sparkSession.streams.notifyQueryTermination(StreamExecution.this) - postEvent(new QueryTerminated(this.toInfo, exception.map(_.cause).map(Utils.exceptionString))) + postEvent( + new QueryTerminatedEvent(currentStatus, exception.map(_.cause).map(Utils.exceptionString))) terminationLatch.countDown() } } @@ -248,7 +293,6 @@ class StreamExecution( committedOffsets = lastOffsets.toStreamProgress(sources) logDebug(s"Resuming with committed offsets: $committedOffsets") } - case None => // We are starting this stream for the first time. 
logInfo(s"Starting new streaming query.") currentBatchId = 0 @@ -278,8 +322,14 @@ class StreamExecution( val hasNewData = { awaitBatchLock.lock() try { - val newData = uniqueSources.flatMap(s => s.getOffset.map(o => s -> o)) - availableOffsets ++= newData + reportTimeTaken(GET_OFFSET_LATENCY) { + val latestOffsets: Map[Source, Option[Offset]] = uniqueSources.map { s => + reportTimeTaken(s, SOURCE_GET_OFFSET_LATENCY) { + (s, s.getOffset) + } + }.toMap + availableOffsets ++= latestOffsets.filter { case (s, o) => o.nonEmpty }.mapValues(_.get) + } if (dataAvailable) { true @@ -292,16 +342,29 @@ class StreamExecution( } } if (hasNewData) { - assert(offsetLog.add(currentBatchId, availableOffsets.toCompositeOffset(sources)), - s"Concurrent update to the log. Multiple streaming jobs detected for $currentBatchId") - logInfo(s"Committed offsets for batch $currentBatchId.") - - // Now that we have logged the new batch, no further processing will happen for - // the previous batch, and it is safe to discard the old metadata. - // Note that purge is exclusive, i.e. it purges everything before currentBatchId. - // NOTE: If StreamExecution implements pipeline parallelism (multiple batches in - // flight at the same time), this cleanup logic will need to change. - offsetLog.purge(currentBatchId) + reportTimeTaken(OFFSET_WAL_WRITE_LATENCY) { + assert(offsetLog.add(currentBatchId, availableOffsets.toCompositeOffset(sources)), + s"Concurrent update to the log. Multiple streaming jobs detected for $currentBatchId") + logInfo(s"Committed offsets for batch $currentBatchId.") + + // NOTE: The following code is correct because runBatches() processes exactly one + // batch at a time. If we add pipeline parallelism (multiple batches in flight at + // the same time), this cleanup logic will need to change. + + // Now that we've updated the scheduler's persistent checkpoint, it is safe for the + // sources to discard data from the previous batch. + val prevBatchOff = offsetLog.get(currentBatchId - 1) + if (prevBatchOff.isDefined) { + prevBatchOff.get.toStreamProgress(sources).foreach { + case (src, off) => src.commit(off) + } + } + + // Now that we have logged the new batch, no further processing will happen for + // the batch before the previous batch, and it is safe to discard the old metadata. + // Note that purge is exclusive, i.e. it purges everything before the target ID. + offsetLog.purge(currentBatchId - 1) + } } else { awaitBatchLock.lock() try { @@ -311,26 +374,30 @@ class StreamExecution( awaitBatchLock.unlock() } } + reportTimestamp(GET_OFFSET_TIMESTAMP) } /** * Processes any data available between `availableOffsets` and `committedOffsets`. */ private def runBatch(): Unit = { - val startTime = System.nanoTime() - // TODO: Move this to IncrementalExecution. // Request unprocessed data from all sources. 
- val newData = availableOffsets.flatMap { - case (source, available) + val newData = reportTimeTaken(GET_BATCH_LATENCY) { + availableOffsets.flatMap { + case (source, available) if committedOffsets.get(source).map(_ != available).getOrElse(true) => - val current = committedOffsets.get(source) - val batch = source.getBatch(current, available) - logDebug(s"Retrieving data from $source: $current -> $available") - Some(source -> batch) - case _ => None - }.toMap + val current = committedOffsets.get(source) + val batch = reportTimeTaken(source, SOURCE_GET_BATCH_LATENCY) { + source.getBatch(current, available) + } + logDebug(s"Retrieving data from $source: $current -> $available") + Some(source -> batch) + case _ => None + } + } + reportTimestamp(GET_BATCH_TIMESTAMP) // A list of attributes that will need to be updated. var replacements = new ArrayBuffer[(Attribute, Attribute)] @@ -351,25 +418,24 @@ class StreamExecution( // Rewire the plan to use the new attributes that were returned by the source. val replacementMap = AttributeMap(replacements) - val newPlan = withNewSources transformAllExpressions { + val triggerLogicalPlan = withNewSources transformAllExpressions { case a: Attribute if replacementMap.contains(a) => replacementMap(a) } - val optimizerStart = System.nanoTime() - lastExecution = new IncrementalExecution( - sparkSession, - newPlan, - outputMode, - checkpointFile("state"), - currentBatchId) - - lastExecution.executedPlan - val optimizerTime = (System.nanoTime() - optimizerStart).toDouble / 1000000 - logDebug(s"Optimized batch in ${optimizerTime}ms") + val executedPlan = reportTimeTaken(OPTIMIZER_LATENCY) { + lastExecution = new IncrementalExecution( + sparkSession, + triggerLogicalPlan, + outputMode, + checkpointFile("state"), + currentBatchId) + lastExecution.executedPlan // Force the lazy generation of execution plan + } val nextBatch = new Dataset(sparkSession, lastExecution, RowEncoder(lastExecution.analyzed.schema)) sink.addBatch(currentBatchId, nextBatch) + reportNumRows(executedPlan, triggerLogicalPlan, newData) awaitBatchLock.lock() try { @@ -379,11 +445,8 @@ class StreamExecution( awaitBatchLock.unlock() } - val batchTime = (System.nanoTime() - startTime).toDouble / 1000000 - logInfo(s"Completed up to $availableOffsets in ${batchTime}ms") // Update committed offsets. committedOffsets ++= availableOffsets - postEvent(new QueryProgress(this.toInfo)) } private def postEvent(event: StreamingQueryListener.Event) { @@ -408,7 +471,7 @@ class StreamExecution( /** * Blocks the current thread until processing for data from the given `source` has reached at - * least the given `Offset`. This method is indented for use primarily when writing tests. + * least the given `Offset`. This method is intended for use primarily when writing tests. 
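The per-phase latencies recorded above (GET_OFFSET_LATENCY, GET_BATCH_LATENCY, OPTIMIZER_LATENCY, and the full TRIGGER_LATENCY) all go through the same reportTimeTaken-style wrapper: run a block, measure the elapsed time, and store it under a trigger-detail key. The following is a standalone sketch of that pattern, using System.currentTimeMillis and a plain concurrent map in place of the trigger clock and StreamMetrics.

```scala
import scala.collection.concurrent.TrieMap

object TimeTakenSketch {
  // Records the most recent duration per named phase, much like the latency.* trigger details.
  private val lastDurations = TrieMap.empty[String, Long]

  def reportTimeTaken[T](key: String)(body: => T): T = {
    val start = System.currentTimeMillis()   // StreamExecution uses the trigger clock instead
    val result = body                        // run the phase being measured
    val elapsed = math.max(System.currentTimeMillis() - start, 0L)
    lastDurations.put(key, elapsed)          // overwrite: only the latest trigger matters
    result
  }

  def main(args: Array[String]): Unit = {
    val rows = reportTimeTaken("latency.getBatch.total") {
      Thread.sleep(50)                       // pretend to fetch a batch from a source
      Seq(1, 2, 3)
    }
    reportTimeTaken("latency.optimizer") {
      Thread.sleep(10)                       // pretend to plan and optimize the batch
    }
    println(s"fetched ${rows.size} rows; timings = $lastDurations")
  }
}
```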
*/ private[sql] def awaitOffset(source: Source, newOffset: Offset): Unit = { def notDone = { @@ -516,12 +579,131 @@ class StreamExecution( """.stripMargin } - private def toInfo: StreamingQueryInfo = { - new StreamingQueryInfo( - this.name, - this.id, - this.sourceStatuses, - this.sinkStatus) + /** + * Report row metrics of the executed trigger + * @param triggerExecutionPlan Execution plan of the trigger + * @param triggerLogicalPlan Logical plan of the trigger, generated from the query logical plan + * @param sourceToDF Source to DataFrame returned by the source.getBatch in this trigger + */ + private def reportNumRows( + triggerExecutionPlan: SparkPlan, + triggerLogicalPlan: LogicalPlan, + sourceToDF: Map[Source, DataFrame]): Unit = { + // We want to associate execution plan leaves to sources that generate them, so that we match + // their metrics (e.g. numOutputRows) to the sources. To do this we do the following. + // Consider the translation from the streaming logical plan to the final executed plan. + // + // streaming logical plan (with sources) <==> trigger's logical plan <==> executed plan + // + // 1. We keep track of streaming sources associated with each leaf in the trigger's logical plan + // - Each logical plan leaf will be associated with a single streaming source. + // - There can be multiple logical plan leaves associated with a streaming source. + // - There can be leaves not associated with any streaming source, because they were + // generated from a batch source (e.g. stream-batch joins) + // + // 2. Assuming that the executed plan has the same number of leaves in the same order as that of + // the trigger logical plan, we associate executed plan leaves with corresponding + // streaming sources. + // + // 3. For each source, we sum the metrics of the associated execution plan leaves.
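To make the association described in steps 1 to 3 concrete, here is a toy, self-contained illustration in plain Scala; LogicalLeaf, ExecLeaf and Source are hypothetical stand-ins rather than Spark's plan and source classes.

```scala
object LeafMetricsSketch {
  final case class LogicalLeaf(id: Int)
  final case class ExecLeaf(id: Int, numOutputRows: Long)
  final case class Source(name: String)

  def main(args: Array[String]): Unit = {
    val kafka = Source("kafka")
    val files = Source("files")

    // Step 1: which logical leaves came from which streaming source
    // (leaf 2 has no entry: it came from a static/batch relation).
    val logicalLeafToSource =
      Map(LogicalLeaf(0) -> kafka, LogicalLeaf(1) -> kafka, LogicalLeaf(3) -> files)

    // The two plans list their leaves in the same order.
    val logicalLeaves = Seq(LogicalLeaf(0), LogicalLeaf(1), LogicalLeaf(2), LogicalLeaf(3))
    val execLeaves = Seq(ExecLeaf(0, 100L), ExecLeaf(1, 20L), ExecLeaf(2, 7L), ExecLeaf(3, 5L))

    // Step 2: zip positionally, keeping only leaves that map back to a streaming source.
    val execLeafToSource = logicalLeaves.zip(execLeaves).flatMap {
      case (lp, ep) => logicalLeafToSource.get(lp).map(src => ep -> src)
    }

    // Step 3: sum numOutputRows per source.
    val inputRowsPerSource = execLeafToSource
      .groupBy { case (_, src) => src }
      .map { case (src, leaves) => src -> leaves.map(_._1.numOutputRows).sum }

    println(inputRowsPerSource) // e.g. Map(Source(kafka) -> 120, Source(files) -> 5)
  }
}
```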
+ // + val logicalPlanLeafToSource = sourceToDF.flatMap { case (source, df) => + df.logicalPlan.collectLeaves().map { leaf => leaf -> source } + } + val allLogicalPlanLeaves = triggerLogicalPlan.collectLeaves() // includes non-streaming sources + val allExecPlanLeaves = triggerExecutionPlan.collectLeaves() + val sourceToNumInputRows: Map[Source, Long] = + if (allLogicalPlanLeaves.size == allExecPlanLeaves.size) { + val execLeafToSource = allLogicalPlanLeaves.zip(allExecPlanLeaves).flatMap { + case (lp, ep) => logicalPlanLeafToSource.get(lp).map { source => ep -> source } + } + val sourceToNumInputRows = execLeafToSource.map { case (execLeaf, source) => + val numRows = execLeaf.metrics.get("numOutputRows").map(_.value).getOrElse(0L) + source -> numRows + } + sourceToNumInputRows.groupBy(_._1).mapValues(_.map(_._2).sum) // sum up rows for each source + } else { + if (!metricWarningLogged) { + def toString[T](seq: Seq[T]): String = s"(size = ${seq.size}), ${seq.mkString(", ")}" + logWarning( + "Could not report metrics as number leaves in trigger logical plan did not match that" + + s" of the execution plan:\n" + + s"logical plan leaves: ${toString(allLogicalPlanLeaves)}\n" + + s"execution plan leaves: ${toString(allExecPlanLeaves)}\n") + metricWarningLogged = true + } + Map.empty + } + val numOutputRows = triggerExecutionPlan.metrics.get("numOutputRows").map(_.value) + val stateNodes = triggerExecutionPlan.collect { + case p if p.isInstanceOf[StateStoreSaveExec] => p + } + + streamMetrics.reportNumInputRows(sourceToNumInputRows) + stateNodes.zipWithIndex.foreach { case (s, i) => + streamMetrics.reportTriggerDetail( + NUM_TOTAL_STATE_ROWS(i + 1), + s.metrics.get("numTotalStateRows").map(_.value).getOrElse(0L)) + streamMetrics.reportTriggerDetail( + NUM_UPDATED_STATE_ROWS(i + 1), + s.metrics.get("numUpdatedStateRows").map(_.value).getOrElse(0L)) + } + updateStatus() + } + + private def reportTimeTaken[T](triggerDetailKey: String)(body: => T): T = { + val startTime = triggerClock.getTimeMillis() + val result = body + val endTime = triggerClock.getTimeMillis() + val timeTaken = math.max(endTime - startTime, 0) + streamMetrics.reportTriggerDetail(triggerDetailKey, timeTaken) + updateStatus() + if (triggerDetailKey == TRIGGER_LATENCY) { + logInfo(s"Completed up to $availableOffsets in $timeTaken ms") + } + result + } + + private def reportTimeTaken[T](source: Source, triggerDetailKey: String)(body: => T): T = { + val startTime = triggerClock.getTimeMillis() + val result = body + val endTime = triggerClock.getTimeMillis() + streamMetrics.reportSourceTriggerDetail( + source, triggerDetailKey, math.max(endTime - startTime, 0)) + updateStatus() + result + } + + private def reportTimestamp(triggerDetailKey: String): Unit = { + streamMetrics.reportTriggerDetail(triggerDetailKey, triggerClock.getTimeMillis) + updateStatus() + } + + private def updateStatus(): Unit = { + val localAvailableOffsets = availableOffsets + val sourceStatuses = sources.map { s => + SourceStatus( + s.toString, + localAvailableOffsets.get(s).map(_.toString).getOrElse("-"), // TODO: use json if available + streamMetrics.currentSourceInputRate(s), + streamMetrics.currentSourceProcessingRate(s), + streamMetrics.currentSourceTriggerDetails(s)) + }.toArray + val sinkStatus = SinkStatus( + sink.toString, + committedOffsets.toCompositeOffset(sources).toString) + + currentStatus = + StreamingQueryStatus( + name = name, + id = id, + timestamp = triggerClock.getTimeMillis(), + inputRate = streamMetrics.currentInputRate(), + processingRate = 
streamMetrics.currentProcessingRate(), + latency = streamMetrics.currentLatency(), + sourceStatuses = sourceStatuses, + sinkStatus = sinkStatus, + triggerDetails = streamMetrics.currentTriggerDetails()) } trait State diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamMetrics.scala new file mode 100644 index 0000000000000..e98d1883e4596 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamMetrics.scala @@ -0,0 +1,242 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import java.{util => ju} + +import scala.collection.mutable + +import com.codahale.metrics.{Gauge, MetricRegistry} + +import org.apache.spark.internal.Logging +import org.apache.spark.metrics.source.{Source => CodahaleSource} +import org.apache.spark.util.Clock + +/** + * Class that manages all the metrics related to a StreamingQuery. It does the following. + * - Calculates metrics (rates, latencies, etc.) based on information reported by StreamExecution. + * - Allows the current metric values to be queried + * - Serves some of the metrics through Codahale/DropWizard metrics + * + * @param sources Unique set of sources in a query + * @param triggerClock Clock used for triggering in StreamExecution + * @param codahaleSourceName Root name for all the Codahale metrics + */ +class StreamMetrics(sources: Set[Source], triggerClock: Clock, codahaleSourceName: String) + extends CodahaleSource with Logging { + + import StreamMetrics._ + + // Trigger infos + private val triggerDetails = new mutable.HashMap[String, String] + private val sourceTriggerDetails = new mutable.HashMap[Source, mutable.HashMap[String, String]] + + // Rate estimators for sources and sinks + private val inputRates = new mutable.HashMap[Source, RateCalculator] + private val processingRates = new mutable.HashMap[Source, RateCalculator] + + // Number of input rows in the current trigger + private val numInputRows = new mutable.HashMap[Source, Long] + private var currentTriggerStartTimestamp: Long = -1 + private var previousTriggerStartTimestamp: Long = -1 + private var latency: Option[Double] = None + + override val sourceName: String = codahaleSourceName + override val metricRegistry: MetricRegistry = new MetricRegistry + + // =========== Initialization =========== + + // Metric names should not have . 
in them, so that all the metrics of a query are identified + // together in Ganglia as a single metric group + registerGauge("inputRate-total", currentInputRate) + registerGauge("processingRate-total", () => currentProcessingRate) + registerGauge("latency", () => currentLatency().getOrElse(-1.0)) + + sources.foreach { s => + inputRates.put(s, new RateCalculator) + processingRates.put(s, new RateCalculator) + sourceTriggerDetails.put(s, new mutable.HashMap[String, String]) + + registerGauge(s"inputRate-${s.toString}", () => currentSourceInputRate(s)) + registerGauge(s"processingRate-${s.toString}", () => currentSourceProcessingRate(s)) + } + + // =========== Setter methods =========== + + def reportTriggerStarted(triggerId: Long): Unit = synchronized { + numInputRows.clear() + triggerDetails.clear() + sourceTriggerDetails.values.foreach(_.clear()) + + reportTriggerDetail(TRIGGER_ID, triggerId) + sources.foreach(s => reportSourceTriggerDetail(s, TRIGGER_ID, triggerId)) + reportTriggerDetail(IS_TRIGGER_ACTIVE, true) + currentTriggerStartTimestamp = triggerClock.getTimeMillis() + reportTriggerDetail(START_TIMESTAMP, currentTriggerStartTimestamp) + } + + def reportTriggerDetail[T](key: String, value: T): Unit = synchronized { + triggerDetails.put(key, value.toString) + } + + def reportSourceTriggerDetail[T](source: Source, key: String, value: T): Unit = synchronized { + sourceTriggerDetails(source).put(key, value.toString) + } + + def reportNumInputRows(inputRows: Map[Source, Long]): Unit = synchronized { + numInputRows ++= inputRows + } + + def reportTriggerFinished(): Unit = synchronized { + require(currentTriggerStartTimestamp >= 0) + val currentTriggerFinishTimestamp = triggerClock.getTimeMillis() + reportTriggerDetail(FINISH_TIMESTAMP, currentTriggerFinishTimestamp) + triggerDetails.remove(STATUS_MESSAGE) + reportTriggerDetail(IS_TRIGGER_ACTIVE, false) + + // Report number of rows + val totalNumInputRows = numInputRows.values.sum + reportTriggerDetail(NUM_INPUT_ROWS, totalNumInputRows) + numInputRows.foreach { case (s, r) => + reportSourceTriggerDetail(s, NUM_SOURCE_INPUT_ROWS, r) + } + + val currentTriggerDuration = currentTriggerFinishTimestamp - currentTriggerStartTimestamp + val previousInputIntervalOption = if (previousTriggerStartTimestamp >= 0) { + Some(currentTriggerStartTimestamp - previousTriggerStartTimestamp) + } else None + + // Update input rate = num rows received by each source during the previous trigger interval + // Interval is measures as interval between start times of previous and current trigger. + // + // TODO: Instead of trigger start, we should use time when getOffset was called on each source + // as this may be different for each source if there are many sources in the query plan + // and getOffset is called serially on them. 
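In other words, per trigger the input rate is rows divided by the interval between the two most recent trigger starts, the processing rate is rows divided by the current trigger's duration, and latency is estimated as half the previous trigger interval plus the current trigger duration. A self-contained sketch of that arithmetic follows; the clock values and row count are made up.

```scala
object TriggerRateSketch {
  /** Rows per second, guarding against a zero or negative time gap (as RateCalculator does). */
  def rate(numRows: Long, timeGapMs: Long): Option[Double] =
    if (timeGapMs > 0) Some(numRows.toDouble * 1000 / timeGapMs) else None

  def main(args: Array[String]): Unit = {
    val previousTriggerStart = 1000L     // ms
    val currentTriggerStart = 11000L
    val currentTriggerFinish = 13000L
    val numInputRows = 500L

    val previousInterval = currentTriggerStart - previousTriggerStart  // 10000 ms
    val triggerDuration = currentTriggerFinish - currentTriggerStart   //  2000 ms

    val inputRate = rate(numInputRows, previousInterval)       // Some(50.0) rows/s
    val processingRate = rate(numInputRows, triggerDuration)   // Some(250.0) rows/s

    // latency = 0.5 * previous trigger interval + current trigger duration, only if data arrived
    val latencyMs =
      if (numInputRows > 0) Some(previousInterval.toDouble / 2 + triggerDuration) else None

    println(s"inputRate=$inputRate processingRate=$processingRate latencyMs=$latencyMs")
  }
}
```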
+ if (previousInputIntervalOption.nonEmpty) { + sources.foreach { s => + inputRates(s).update(numInputRows.getOrElse(s, 0), previousInputIntervalOption.get) + } + } + + // Update processing rate = num rows processed for each source in current trigger duration + sources.foreach { s => + processingRates(s).update(numInputRows.getOrElse(s, 0), currentTriggerDuration) + } + + // Update latency = if data present, 0.5 * previous trigger interval + current trigger duration + if (previousInputIntervalOption.nonEmpty && totalNumInputRows > 0) { + latency = Some((previousInputIntervalOption.get.toDouble / 2) + currentTriggerDuration) + } else { + latency = None + } + + previousTriggerStartTimestamp = currentTriggerStartTimestamp + currentTriggerStartTimestamp = -1 + } + + // =========== Getter methods =========== + + def currentInputRate(): Double = synchronized { + // Since we are calculating source input rates using the same time interval for all sources + // it is fine to calculate total input rate as the sum of per source input rate. + inputRates.map(_._2.currentRate).sum + } + + def currentSourceInputRate(source: Source): Double = synchronized { + inputRates(source).currentRate + } + + def currentProcessingRate(): Double = synchronized { + // Since we are calculating source processing rates using the same time interval for all sources + // it is fine to calculate total processing rate as the sum of per source processing rate. + processingRates.map(_._2.currentRate).sum + } + + def currentSourceProcessingRate(source: Source): Double = synchronized { + processingRates(source).currentRate + } + + def currentLatency(): Option[Double] = synchronized { latency } + + def currentTriggerDetails(): Map[String, String] = synchronized { triggerDetails.toMap } + + def currentSourceTriggerDetails(source: Source): Map[String, String] = synchronized { + sourceTriggerDetails(source).toMap + } + + // =========== Other methods =========== + + private def registerGauge[T](name: String, f: () => T)(implicit num: Numeric[T]): Unit = { + synchronized { + metricRegistry.register(name, new Gauge[T] { + override def getValue: T = f() + }) + } + } + + def stop(): Unit = synchronized { + triggerDetails.clear() + inputRates.valuesIterator.foreach { _.stop() } + processingRates.valuesIterator.foreach { _.stop() } + latency = None + } +} + +object StreamMetrics extends Logging { + /** Simple utility class to calculate rate while avoiding DivideByZero */ + class RateCalculator { + @volatile private var rate: Option[Double] = None + + def update(numRows: Long, timeGapMs: Long): Unit = { + if (timeGapMs > 0) { + rate = Some(numRows.toDouble * 1000 / timeGapMs) + } else { + rate = None + logDebug(s"Rate updates cannot with zero or negative time gap $timeGapMs") + } + } + + def currentRate: Double = rate.getOrElse(0.0) + + def stop(): Unit = { rate = None } + } + + + val TRIGGER_ID = "triggerId" + val IS_TRIGGER_ACTIVE = "isTriggerActive" + val IS_DATA_PRESENT_IN_TRIGGER = "isDataPresentInTrigger" + val STATUS_MESSAGE = "statusMessage" + + val START_TIMESTAMP = "timestamp.triggerStart" + val GET_OFFSET_TIMESTAMP = "timestamp.afterGetOffset" + val GET_BATCH_TIMESTAMP = "timestamp.afterGetBatch" + val FINISH_TIMESTAMP = "timestamp.triggerFinish" + + val GET_OFFSET_LATENCY = "latency.getOffset.total" + val GET_BATCH_LATENCY = "latency.getBatch.total" + val OFFSET_WAL_WRITE_LATENCY = "latency.offsetLogWrite" + val OPTIMIZER_LATENCY = "latency.optimizer" + val TRIGGER_LATENCY = "latency.fullTrigger" + val SOURCE_GET_OFFSET_LATENCY = 
"latency.getOffset.source" + val SOURCE_GET_BATCH_LATENCY = "latency.getBatch.source" + + val NUM_INPUT_ROWS = "numRows.input.total" + val NUM_SOURCE_INPUT_ROWS = "numRows.input.source" + def NUM_TOTAL_STATE_ROWS(aggId: Int): String = s"numRows.state.aggregation$aggId.total" + def NUM_UPDATED_STATE_ROWS(aggId: Int): String = s"numRows.state.aggregation$aggId.updated" +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingQueryListenerBus.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingQueryListenerBus.scala index 1e663956f980b..fc2190d39da4f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingQueryListenerBus.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingQueryListenerBus.scala @@ -40,7 +40,7 @@ class StreamingQueryListenerBus(sparkListenerBus: LiveListenerBus) */ def post(event: StreamingQueryListener.Event) { event match { - case s: QueryStarted => + case s: QueryStartedEvent => postToAll(s) case _ => sparkListenerBus.post(event) @@ -59,11 +59,11 @@ class StreamingQueryListenerBus(sparkListenerBus: LiveListenerBus) listener: StreamingQueryListener, event: StreamingQueryListener.Event): Unit = { event match { - case queryStarted: QueryStarted => + case queryStarted: QueryStartedEvent => listener.onQueryStarted(queryStarted) - case queryProgress: QueryProgress => + case queryProgress: QueryProgressEvent => listener.onQueryProgress(queryProgress) - case queryTerminated: QueryTerminated => + case queryTerminated: QueryTerminatedEvent => listener.onQueryTerminated(queryTerminated) case _ => } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index 5052c4d50c5ed..48d9791faf1e9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.streaming import java.util.concurrent.atomic.AtomicInteger import javax.annotation.concurrent.GuardedBy -import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.{ArrayBuffer, ListBuffer} import scala.util.control.NonFatal import org.apache.spark.internal.Logging @@ -51,12 +51,23 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) protected val logicalPlan = StreamingExecutionRelation(this) protected val output = logicalPlan.output + /** + * All batches from `lastCommittedOffset + 1` to `currentOffset`, inclusive. + * Stored in a ListBuffer to facilitate removing committed batches. + */ @GuardedBy("this") - protected val batches = new ArrayBuffer[Dataset[A]] + protected val batches = new ListBuffer[Dataset[A]] @GuardedBy("this") protected var currentOffset: LongOffset = new LongOffset(-1) + /** + * Last offset that was discarded, or -1 if no commits have occurred. Note that the value + * -1 is used in calculations below and isn't just an arbitrary constant. 
+ */ + @GuardedBy("this") + protected var lastOffsetCommitted : LongOffset = new LongOffset(-1) + def schema: StructType = encoder.schema def toDS()(implicit sqlContext: SQLContext): Dataset[A] = { @@ -85,21 +96,25 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) override def toString: String = s"MemoryStream[${Utils.truncatedString(output, ",")}]" override def getOffset: Option[Offset] = synchronized { - if (batches.isEmpty) { + if (currentOffset.offset == -1) { None } else { Some(currentOffset) } } - /** - * Returns the data that is between the offsets (`start`, `end`]. - */ override def getBatch(start: Option[Offset], end: Offset): DataFrame = { + // Compute the internal batch numbers to fetch: [startOrdinal, endOrdinal) val startOrdinal = start.map(_.asInstanceOf[LongOffset]).getOrElse(LongOffset(-1)).offset.toInt + 1 val endOrdinal = end.asInstanceOf[LongOffset].offset.toInt + 1 - val newBlocks = synchronized { batches.slice(startOrdinal, endOrdinal) } + + // Internal buffer only holds the batches after lastCommittedOffset. + val newBlocks = synchronized { + val sliceStart = startOrdinal - lastOffsetCommitted.offset.toInt - 1 + val sliceEnd = endOrdinal - lastOffsetCommitted.offset.toInt - 1 + batches.slice(sliceStart, sliceEnd) + } logDebug( s"MemoryBatch [$startOrdinal, $endOrdinal]: ${newBlocks.flatMap(_.collect()).mkString(", ")}") @@ -111,7 +126,30 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) } } + override def commit(end: Offset): Unit = synchronized { + end match { + case newOffset: LongOffset => + val offsetDiff = (newOffset.offset - lastOffsetCommitted.offset).toInt + + if (offsetDiff < 0) { + sys.error(s"Offsets committed out of order: $lastOffsetCommitted followed by $end") + } + + batches.trimStart(offsetDiff) + lastOffsetCommitted = newOffset + case _ => + sys.error(s"MemoryStream.commit() received an offset ($end) that did not originate with " + + "an instance of this class") + } + } + override def stop() {} + + def reset(): Unit = synchronized { + batches.clear() + currentOffset = new LongOffset(-1) + lastOffsetCommitted = new LongOffset(-1) + } } /** @@ -165,6 +203,8 @@ class MemorySink(val schema: StructType, outputMode: OutputMode) extends Sink wi logDebug(s"Skipping already committed batch: $batchId") } } + + override def toString(): String = "MemorySink" } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/socket.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/socket.scala index fb15239f9af98..c662e7c6bc775 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/socket.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/socket.scala @@ -24,14 +24,15 @@ import java.text.SimpleDateFormat import java.util.Calendar import javax.annotation.concurrent.GuardedBy -import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.ListBuffer import scala.util.{Failure, Success, Try} import org.apache.spark.internal.Logging -import org.apache.spark.sql.{AnalysisException, DataFrame, SQLContext} +import org.apache.spark.sql._ import org.apache.spark.sql.sources.{DataSourceRegister, StreamSourceProvider} import org.apache.spark.sql.types.{StringType, StructField, StructType, TimestampType} + object TextSocketSource { val SCHEMA_REGULAR = StructType(StructField("value", StringType) :: Nil) val SCHEMA_TIMESTAMP = StructType(StructField("value", StringType) :: @@ -53,8 +54,18 @@ class TextSocketSource(host: String, 
port: Int, includeTimestamp: Boolean, sqlCo @GuardedBy("this") private var readThread: Thread = null + /** + * All batches from `lastCommittedOffset + 1` to `currentOffset`, inclusive. + * Stored in a ListBuffer to facilitate removing committed batches. + */ + @GuardedBy("this") + protected val batches = new ListBuffer[(String, Timestamp)] + + @GuardedBy("this") + protected var currentOffset: LongOffset = new LongOffset(-1) + @GuardedBy("this") - private var lines = new ArrayBuffer[(String, Timestamp)] + protected var lastOffsetCommitted : LongOffset = new LongOffset(-1) initialize() @@ -74,10 +85,12 @@ class TextSocketSource(host: String, port: Int, includeTimestamp: Boolean, sqlCo return } TextSocketSource.this.synchronized { - lines += ((line, + val newData = (line, Timestamp.valueOf( TextSocketSource.DATE_FORMAT.format(Calendar.getInstance().getTime())) - )) + ) + currentOffset = currentOffset + 1 + batches.append(newData) } } } catch { @@ -92,21 +105,54 @@ class TextSocketSource(host: String, port: Int, includeTimestamp: Boolean, sqlCo override def schema: StructType = if (includeTimestamp) TextSocketSource.SCHEMA_TIMESTAMP else TextSocketSource.SCHEMA_REGULAR - /** Returns the maximum available offset for this source. */ override def getOffset: Option[Offset] = synchronized { - if (lines.isEmpty) None else Some(LongOffset(lines.size - 1)) + if (currentOffset.offset == -1) { + None + } else { + Some(currentOffset) + } } /** Returns the data that is between the offsets (`start`, `end`]. */ override def getBatch(start: Option[Offset], end: Offset): DataFrame = synchronized { - val startIdx = start.map(_.asInstanceOf[LongOffset].offset.toInt + 1).getOrElse(0) - val endIdx = end.asInstanceOf[LongOffset].offset.toInt + 1 - val data = synchronized { lines.slice(startIdx, endIdx) } + val startOrdinal = + start.map(_.asInstanceOf[LongOffset]).getOrElse(LongOffset(-1)).offset.toInt + 1 + val endOrdinal = end.asInstanceOf[LongOffset].offset.toInt + 1 + + // Internal buffer only holds the batches after lastOffsetCommitted + val rawList = synchronized { + val sliceStart = startOrdinal - lastOffsetCommitted.offset.toInt - 1 + val sliceEnd = endOrdinal - lastOffsetCommitted.offset.toInt - 1 + batches.slice(sliceStart, sliceEnd) + } + import sqlContext.implicits._ + val rawBatch = sqlContext.createDataset(rawList) + + // Underlying MemoryStream has schema (String, Timestamp); strip out the timestamp + // if requested. if (includeTimestamp) { - data.toDF("value", "timestamp") + rawBatch.toDF("value", "timestamp") + } else { + // Strip out timestamp + rawBatch.select("_1").toDF("value") + } + } + + override def commit(end: Offset): Unit = synchronized { + if (end.isInstanceOf[LongOffset]) { + val newOffset = end.asInstanceOf[LongOffset] + val offsetDiff = (newOffset.offset - lastOffsetCommitted.offset).toInt + + if (offsetDiff < 0) { + sys.error(s"Offsets committed out of order: $lastOffsetCommitted followed by $end") + } + + batches.trimStart(offsetDiff) + lastOffsetCommitted = newOffset } else { - data.map(_._1).toDF("value") + sys.error(s"TextSocketStream.commit() received an offset ($end) that did not " + + s"originate with an instance of this class") } } @@ -141,7 +187,7 @@ class TextSocketSourceProvider extends StreamSourceProvider with DataSourceRegis providerName: String, parameters: Map[String, String]): (String, StructType) = { logWarning("The socket source should not be used for production applications! 
" + - "It does not support recovery and stores state indefinitely.") + "It does not support recovery.") if (!parameters.contains("host")) { throw new AnalysisException("Set a host to read from with option(\"host\", ...).") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala index bec966b15ed0f..f1e7f1d113ce7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala @@ -159,7 +159,7 @@ private[state] class HDFSBackedStateStoreProvider( } catch { case NonFatal(e) => throw new IllegalStateException( - s"Error committing version $newVersion into ${HDFSBackedStateStoreProvider.this}", e) + s"Error committing version $newVersion into $this", e) } } @@ -197,12 +197,18 @@ private[state] class HDFSBackedStateStoreProvider( allUpdates.values().asScala.toIterator } + override def numKeys(): Long = mapToUpdate.size() + /** * Whether all updates have been committed */ override private[state] def hasCommitted: Boolean = { state == COMMITTED } + + override def toString(): String = { + s"HDFSStateStore[id = (op=${id.operatorId}, part=${id.partitionId}), dir = $baseDir]" + } } /** Get the state store for making updates to create a new `version` of the store. */ @@ -213,7 +219,7 @@ private[state] class HDFSBackedStateStoreProvider( newMap.putAll(loadMap(version)) } val store = new HDFSBackedStateStore(version, newMap) - logInfo(s"Retrieved version $version of $this for update") + logInfo(s"Retrieved version $version of ${HDFSBackedStateStoreProvider.this} for update") store } @@ -229,7 +235,7 @@ private[state] class HDFSBackedStateStoreProvider( } override def toString(): String = { - s"StateStore[id = (op=${id.operatorId}, part=${id.partitionId}), dir = $baseDir]" + s"HDFSStateStoreProvider[id = (op=${id.operatorId}, part=${id.partitionId}), dir = $baseDir]" } /* Internal classes and methods */ @@ -491,10 +497,12 @@ private[state] class HDFSBackedStateStoreProvider( val mapsToRemove = loadedMaps.keys.filter(_ < earliestVersionToRetain).toSeq mapsToRemove.foreach(loadedMaps.remove) } - files.filter(_.version < earliestFileToRetain.version).foreach { f => + val filesToDelete = files.filter(_.version < earliestFileToRetain.version) + filesToDelete.foreach { f => fs.delete(f.path, true) } - logInfo(s"Deleted files older than ${earliestFileToRetain.version} for $this") + logInfo(s"Deleted files older than ${earliestFileToRetain.version} for $this: " + + filesToDelete.mkString(", ")) } } } catch { @@ -558,7 +566,7 @@ private[state] class HDFSBackedStateStoreProvider( } } val storeFiles = versionToFiles.values.toSeq.sortBy(_.version) - logDebug(s"Current set of files for $this: $storeFiles") + logDebug(s"Current set of files for $this: ${storeFiles.mkString(", ")}") storeFiles } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index a67fdceb3cee6..7132e284c28f4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -77,6 +77,9 @@ trait StateStore { */ def updates(): Iterator[StoreUpdate] + 
/** Number of keys in the state store */ + def numKeys(): Long + /** * Whether all updates have been committed */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala index d945d7aff2da4..267d17623d5e5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala @@ -38,7 +38,7 @@ private case class VerifyIfInstanceActive(storeId: StateStoreId, executorId: Str private case class GetLocation(storeId: StateStoreId) extends StateStoreCoordinatorMessage -private case class DeactivateInstances(storeRootLocation: String) +private case class DeactivateInstances(checkpointLocation: String) extends StateStoreCoordinatorMessage private object StopCoordinator @@ -111,11 +111,13 @@ class StateStoreCoordinatorRef private(rpcEndpointRef: RpcEndpointRef) { * Class for coordinating instances of [[StateStore]]s loaded in executors across the cluster, * and get their locations for job scheduling. */ -private class StateStoreCoordinator(override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint { +private class StateStoreCoordinator(override val rpcEnv: RpcEnv) + extends ThreadSafeRpcEndpoint with Logging { private val instances = new mutable.HashMap[StateStoreId, ExecutorCacheTaskLocation] override def receive: PartialFunction[Any, Unit] = { case ReportActiveInstance(id, host, executorId) => + logDebug(s"Reported state store $id is active at $executorId") instances.put(id, ExecutorCacheTaskLocation(host, executorId)) } @@ -125,19 +127,25 @@ private class StateStoreCoordinator(override val rpcEnv: RpcEnv) extends ThreadS case Some(location) => location.executorId == execId case None => false } + logDebug(s"Verified that state store $id is active: $response") context.reply(response) case GetLocation(id) => - context.reply(instances.get(id).map(_.toString)) + val executorId = instances.get(id).map(_.toString) + logDebug(s"Got location of the state store $id: $executorId") + context.reply(executorId) - case DeactivateInstances(loc) => + case DeactivateInstances(checkpointLocation) => val storeIdsToRemove = - instances.keys.filter(_.checkpointLocation == loc).toSeq + instances.keys.filter(_.checkpointLocation == checkpointLocation).toSeq instances --= storeIdsToRemove + logDebug(s"Deactivating instances related to checkpoint location $checkpointLocation: " + + storeIdsToRemove.mkString(", ")) context.reply(true) case StopCoordinator => stop() // Stop before replying to ensure that endpoint name has been deregistered + logInfo("StateStoreCoordinator stopped") context.reply(true) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala index 1dd281ebf1034..80b87d5ffa797 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala @@ -103,6 +103,8 @@ case class WindowExec( override def outputOrdering: Seq[SortOrder] = child.outputOrdering + override def outputPartitioning: Partitioning = child.outputPartitioning + /** * Create a bound ordering object for a given frame type and offset. 
A bound ordering object is * used to determine which input row lies within the frame boundaries of an output row. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala index 51179a528c503..eea98414003ba 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.expressions -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, InterfaceStability} import org.apache.spark.sql.{Dataset, Encoder, TypedColumn} import org.apache.spark.sql.catalyst.encoders.encoderFor import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete} @@ -51,6 +51,7 @@ import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression * @since 1.6.0 */ @Experimental +@InterfaceStability.Evolving abstract class Aggregator[-IN, BUF, OUT] extends Serializable { /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala index 49fdec57558e8..28598af781653 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.expressions -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.InterfaceStability import org.apache.spark.sql.catalyst.expressions.ScalaUDF import org.apache.spark.sql.Column import org.apache.spark.sql.functions @@ -39,12 +39,17 @@ import org.apache.spark.sql.types.DataType * * @since 1.3.0 */ -@Experimental +@InterfaceStability.Stable case class UserDefinedFunction protected[sql] ( f: AnyRef, dataType: DataType, inputTypes: Option[Seq[DataType]]) { + /** + * Returns an expression that invokes the UDF, using the given arguments. + * + * @since 1.3.0 + */ def apply(exprs: Column*): Column = { Column(ScalaUDF(f, dataType, exprs.map(_.expr), inputTypes.getOrElse(Nil))) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala index 3c1f6e897ea62..0b26d863cac5d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala @@ -17,12 +17,11 @@ package org.apache.spark.sql.expressions -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.InterfaceStability import org.apache.spark.sql.Column import org.apache.spark.sql.catalyst.expressions._ /** - * :: Experimental :: * Utility functions for defining window in DataFrames. * * {{{ @@ -36,7 +35,7 @@ import org.apache.spark.sql.catalyst.expressions._ * * @since 1.4.0 */ -@Experimental +@InterfaceStability.Stable object Window { /** @@ -163,7 +162,6 @@ object Window { } /** - * :: Experimental :: * Utility functions for defining window in DataFrames. * * {{{ @@ -176,5 +174,5 @@ object Window { * * @since 1.4.0 */ -@Experimental +@InterfaceStability.Stable class Window private() // So we can see Window in JavaDoc. 
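For reference, the Window and WindowSpec API that is being promoted to InterfaceStability.Stable here is used as follows; the DataFrame and column names are made up for the example.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{rank, sum}

object WindowExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("window-example").getOrCreate()
    import spark.implicits._

    val sales = Seq(
      ("north", "a", 10), ("north", "b", 30), ("south", "c", 20), ("south", "d", 5)
    ).toDF("region", "item", "amount")

    // A WindowSpec: partition by region, order by amount descending.
    val byRegion = Window.partitionBy($"region").orderBy($"amount".desc)

    sales
      .withColumn("rank_in_region", rank().over(byRegion))
      .withColumn("region_total", sum($"amount").over(Window.partitionBy($"region")))
      .show()

    spark.stop()
  }
}
```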
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala index 8ebed399bf2d0..1e85b6e7881ad 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala @@ -17,23 +17,22 @@ package org.apache.spark.sql.expressions -import org.apache.spark.annotation.Experimental -import org.apache.spark.sql.{catalyst, Column} +import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.sql.Column import org.apache.spark.sql.catalyst.expressions._ /** - * :: Experimental :: * A window specification that defines the partitioning, ordering, and frame boundaries. * * Use the static methods in [[Window]] to create a [[WindowSpec]]. * * @since 1.4.0 */ -@Experimental +@InterfaceStability.Stable class WindowSpec private[sql]( partitionSpec: Seq[Expression], orderSpec: Seq[SortOrder], - frame: catalyst.expressions.WindowFrame) { + frame: WindowFrame) { /** * Defines the partitioning columns in a [[WindowSpec]]. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/scalalang/typed.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/scalalang/typed.scala index 60d7b7d0894d0..aa71cb9e3bc85 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/scalalang/typed.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/scalalang/typed.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.expressions.scalalang -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, InterfaceStability} import org.apache.spark.sql._ import org.apache.spark.sql.execution.aggregate._ @@ -30,6 +30,7 @@ import org.apache.spark.sql.execution.aggregate._ * @since 2.0.0 */ @Experimental +@InterfaceStability.Evolving // scalastyle:off object typed { // scalastyle:on diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala index 5417a0e481158..bc9788d81fe6a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala @@ -17,17 +17,18 @@ package org.apache.spark.sql.expressions -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.InterfaceStability import org.apache.spark.sql.{Column, Row} import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete} import org.apache.spark.sql.execution.aggregate.ScalaUDAF import org.apache.spark.sql.types._ /** - * :: Experimental :: * The base class for implementing user-defined aggregate functions (UDAF). + * + * @since 1.5.0 */ -@Experimental +@InterfaceStability.Stable abstract class UserDefinedAggregateFunction extends Serializable { /** @@ -43,6 +44,8 @@ abstract class UserDefinedAggregateFunction extends Serializable { * * The name of a field of this [[StructType]] is only used to identify the corresponding * input argument. Users can choose names to identify the input arguments. + * + * @since 1.5.0 */ def inputSchema: StructType @@ -60,17 +63,23 @@ abstract class UserDefinedAggregateFunction extends Serializable { * * The name of a field of this [[StructType]] is only used to identify the corresponding * buffer value. Users can choose names to identify the input arguments. 
+ * + * @since 1.5.0 */ def bufferSchema: StructType /** * The [[DataType]] of the returned value of this [[UserDefinedAggregateFunction]]. + * + * @since 1.5.0 */ def dataType: DataType /** * Returns true iff this function is deterministic, i.e. given the same input, * always return the same output. + * + * @since 1.5.0 */ def deterministic: Boolean @@ -80,6 +89,8 @@ abstract class UserDefinedAggregateFunction extends Serializable { * The contract should be that applying the merge function on two initial buffers should just * return the initial buffer itself, i.e. * `merge(initialBuffer, initialBuffer)` should equal `initialBuffer`. + * + * @since 1.5.0 */ def initialize(buffer: MutableAggregationBuffer): Unit @@ -87,6 +98,8 @@ abstract class UserDefinedAggregateFunction extends Serializable { * Updates the given aggregation buffer `buffer` with new input data from `input`. * * This is called once per input row. + * + * @since 1.5.0 */ def update(buffer: MutableAggregationBuffer, input: Row): Unit @@ -94,17 +107,23 @@ abstract class UserDefinedAggregateFunction extends Serializable { * Merges two aggregation buffers and stores the updated buffer values back to `buffer1`. * * This is called when we merge two partially aggregated data together. + * + * @since 1.5.0 */ def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit /** * Calculates the final result of this [[UserDefinedAggregateFunction]] based on the given * aggregation buffer. + * + * @since 1.5.0 */ def evaluate(buffer: Row): Any /** * Creates a [[Column]] for this UDAF using given [[Column]]s as input arguments. + * + * @since 1.5.0 */ @scala.annotation.varargs def apply(exprs: Column*): Column = { @@ -119,6 +138,8 @@ abstract class UserDefinedAggregateFunction extends Serializable { /** * Creates a [[Column]] for this UDAF using the distinct values of the given * [[Column]]s as input arguments. + * + * @since 1.5.0 */ @scala.annotation.varargs def distinct(exprs: Column*): Column = { @@ -132,12 +153,13 @@ abstract class UserDefinedAggregateFunction extends Serializable { } /** - * :: Experimental :: * A [[Row]] representing a mutable aggregation buffer. * * This is not meant to be extended outside of Spark. + * + * @since 1.5.0 */ -@Experimental +@InterfaceStability.Stable abstract class MutableAggregationBuffer extends Row { /** Update the ith value of this buffer. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index de4943152720c..5f1efd22d8204 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -37,7 +37,6 @@ import org.apache.spark.util.Utils /** - * :: Experimental :: * Functions available for DataFrame operations. 
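Returning to the UserDefinedAggregateFunction contract annotated above, a minimal implementation that sums a long column might look like the sketch below; the class name and column handling are illustrative, not taken from this patch.

```scala
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

// Sums a long column; each method corresponds to one part of the UDAF contract.
class LongSum extends UserDefinedAggregateFunction {
  override def inputSchema: StructType = StructType(StructField("value", LongType) :: Nil)
  override def bufferSchema: StructType = StructType(StructField("sum", LongType) :: Nil)
  override def dataType: DataType = LongType
  override def deterministic: Boolean = true

  override def initialize(buffer: MutableAggregationBuffer): Unit = buffer(0) = 0L

  override def update(buffer: MutableAggregationBuffer, input: Row): Unit =
    if (!input.isNullAt(0)) buffer(0) = buffer.getLong(0) + input.getLong(0)

  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit =
    buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)

  override def evaluate(buffer: Row): Any = buffer.getLong(0)
}
```

Such a UDAF can be registered with spark.udf.register("longSum", new LongSum) and then invoked from both the DataFrame API and SQL.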
* * @groupname udf_funcs UDF functions @@ -53,8 +52,7 @@ import org.apache.spark.util.Utils * @groupname Ungrouped Support functions for DataFrames * @since 1.3.0 */ -@Experimental -@InterfaceStability.Evolving +@InterfaceStability.Stable // scalastyle:off object functions { // scalastyle:on diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 192083e2ea5f5..dc31f3bc323f6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -30,6 +30,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.network.util.ByteUnit import org.apache.spark.sql.catalyst.CatalystConf +import org.apache.spark.util.Utils //////////////////////////////////////////////////////////////////////////////////////////////////// // This file defines the configuration options for Spark SQL. @@ -56,7 +57,7 @@ object SQLConf { val WAREHOUSE_PATH = SQLConfigBuilder("spark.sql.warehouse.dir") .doc("The default location for managed databases and tables.") .stringConf - .createWithDefault("${system:user.dir}/spark-warehouse") + .createWithDefault(Utils.resolveURI("spark-warehouse").toString) val OPTIMIZER_MAX_ITERATIONS = SQLConfigBuilder("spark.sql.optimizer.maxIterations") .internal() @@ -265,9 +266,28 @@ object SQLConf { val HIVE_METASTORE_PARTITION_PRUNING = SQLConfigBuilder("spark.sql.hive.metastorePartitionPruning") .doc("When true, some predicates will be pushed down into the Hive metastore so that " + - "unmatching partitions can be eliminated earlier.") + "unmatching partitions can be eliminated earlier. This only affects Hive tables " + + "not converted to filesource relations (see HiveUtils.CONVERT_METASTORE_PARQUET and " + + "HiveUtils.CONVERT_METASTORE_ORC for more information).") .booleanConf - .createWithDefault(false) + .createWithDefault(true) + + val HIVE_MANAGE_FILESOURCE_PARTITIONS = + SQLConfigBuilder("spark.sql.hive.manageFilesourcePartitions") + .doc("When true, enable metastore partition management for file source tables as well. " + + "This includes both datasource and converted Hive tables. When partition managment " + + "is enabled, datasource tables store partition in the Hive metastore, and use the " + + "metastore to prune partitions during query planning.") + .booleanConf + .createWithDefault(true) + + val HIVE_FILESOURCE_PARTITION_FILE_CACHE_SIZE = + SQLConfigBuilder("spark.sql.hive.filesourcePartitionFileCacheSize") + .doc("When nonzero, enable caching of partition file metadata in memory. All tables share " + + "a cache that can use up to specified num bytes for file metadata. 
This conf only " + + "has an effect when hive filesource partition management is enabled.") + .longConf + .createWithDefault(250 * 1024 * 1024) val OPTIMIZER_METADATA_ONLY = SQLConfigBuilder("spark.sql.optimizer.metadataOnly") .doc("When true, enable the metadata-only query optimization that use the table's metadata " + @@ -332,13 +352,6 @@ object SQLConf { .booleanConf .createWithDefault(true) - val PARTITION_MAX_FILES = - SQLConfigBuilder("spark.sql.sources.maxConcurrentWrites") - .doc("The maximum number of concurrent files to open before falling back on sorting when " + - "writing out files using dynamic partitioning.") - .intConf - .createWithDefault(1) - val BUCKETING_ENABLED = SQLConfigBuilder("spark.sql.sources.bucketing.enabled") .doc("When false, we will treat bucketed table as normal table") .booleanConf @@ -381,14 +394,6 @@ object SQLConf { .intConf .createWithDefault(32) - // Whether to perform eager analysis when constructing a dataframe. - // Set to false when debugging requires the ability to look at invalid query plans. - val DATAFRAME_EAGER_ANALYSIS = SQLConfigBuilder("spark.sql.eagerAnalysis") - .internal() - .doc("When true, eagerly applies query analysis on DataFrame operations.") - .booleanConf - .createWithDefault(true) - // Whether to automatically resolve ambiguity in join conditions for self-joins. // See SPARK-6231. val DATAFRAME_SELF_JOIN_AUTO_RESOLVE_AMBIGUITY = @@ -569,10 +574,17 @@ object SQLConf { .timeConf(TimeUnit.MILLISECONDS) .createWithDefault(10L) + val STREAMING_METRICS_ENABLED = + SQLConfigBuilder("spark.sql.streaming.metricsEnabled") + .doc("Whether Dropwizard/Codahale metrics will be reported for active streaming queries.") + .booleanConf + .createWithDefault(false) + val NDV_MAX_ERROR = SQLConfigBuilder("spark.sql.statistics.ndv.maxError") .internal() - .doc("The maximum estimation error allowed in HyperLogLog++ algorithm.") + .doc("The maximum estimation error allowed in HyperLogLog++ algorithm when generating " + + "column level statistics.") .doubleConf .createWithDefault(0.05) @@ -635,6 +647,8 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging { def streamingPollingDelay: Long = getConf(STREAMING_POLLING_DELAY) + def streamingMetricsEnabled: Boolean = getConf(STREAMING_METRICS_ENABLED) + def filesMaxPartitionBytes: Long = getConf(FILES_MAX_PARTITION_BYTES) def filesOpenCostInBytes: Long = getConf(FILES_OPEN_COST_IN_BYTES) @@ -667,6 +681,10 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging { def metastorePartitionPruning: Boolean = getConf(HIVE_METASTORE_PARTITION_PRUNING) + def manageFilesourcePartitions: Boolean = getConf(HIVE_MANAGE_FILESOURCE_PARTITIONS) + + def filesourcePartitionFileCacheSize: Long = getConf(HIVE_FILESOURCE_PARTITION_FILE_CACHE_SIZE) + def gatherFastStats: Boolean = getConf(GATHER_FASTSTAT) def optimizerMetadataOnly: Boolean = getConf(OPTIMIZER_METADATA_ONLY) @@ -723,15 +741,11 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging { def partitionColumnTypeInferenceEnabled: Boolean = getConf(SQLConf.PARTITION_COLUMN_TYPE_INFERENCE) - def partitionMaxFiles: Int = getConf(PARTITION_MAX_FILES) - def parallelPartitionDiscoveryThreshold: Int = getConf(SQLConf.PARALLEL_PARTITION_DISCOVERY_THRESHOLD) def bucketingEnabled: Boolean = getConf(SQLConf.BUCKETING_ENABLED) - def dataFrameEagerAnalysis: Boolean = getConf(DATAFRAME_EAGER_ANALYSIS) - def dataFrameSelfJoinAutoResolveAmbiguity: Boolean = 
getConf(DATAFRAME_SELF_JOIN_AUTO_RESOLVE_AMBIGUITY) @@ -916,8 +930,11 @@ object StaticSQLConf { .intConf .createWithDefault(4000) + // When enabling the debug, Spark SQL internal table properties are not filtered out; however, + // some related DDL commands (e.g., ANALYZE TABLE and CREATE TABLE LIKE) might not work properly. val DEBUG_MODE = buildConf("spark.sql.debug") .internal() + .doc("Only used for internal debugging. Not all functions are supported when it is enabled.") .booleanConf .createWithDefault(false) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index 8dd4b8f662713..dec316be7aea1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.jdbc import java.sql.Connection -import org.apache.spark.annotation.{DeveloperApi, Since} +import org.apache.spark.annotation.{DeveloperApi, InterfaceStability, Since} import org.apache.spark.sql.types._ /** @@ -31,6 +31,7 @@ import org.apache.spark.sql.types._ * send a null value to the database. */ @DeveloperApi +@InterfaceStability.Evolving case class JdbcType(databaseTypeDefinition : String, jdbcNullType : Int) /** @@ -53,6 +54,7 @@ case class JdbcType(databaseTypeDefinition : String, jdbcNullType : Int) * for the given Catalyst type. */ @DeveloperApi +@InterfaceStability.Evolving abstract class JdbcDialect extends Serializable { /** * Check if this dialect instance can handle a certain jdbc url. @@ -142,6 +144,7 @@ abstract class JdbcDialect extends Serializable { * sure to register your dialects first. */ @DeveloperApi +@InterfaceStability.Evolving object JdbcDialects { /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala index 13c0766219a8e..e0494dfd9343b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.sources +import org.apache.spark.annotation.InterfaceStability + //////////////////////////////////////////////////////////////////////////////////////////////////// // This file defines all the filters that we can push down to the data sources. //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -26,6 +28,7 @@ package org.apache.spark.sql.sources * * @since 1.3.0 */ +@InterfaceStability.Stable abstract class Filter { /** * List of columns that are referenced by this filter. 
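The filters defined below (EqualTo, GreaterThan, And, and so on) are what a relation receives in buildScan. A common pattern in filter-aware sources is to translate the subset of filters they understand into a WHERE fragment and leave the rest for Spark to re-evaluate; the sketch below uses simplistic, hypothetical quoting rules and is not taken from Spark's own JDBC code.

```scala
import org.apache.spark.sql.sources._

object FilterToSqlSketch {
  /** Best-effort translation; None means "cannot push down, let Spark re-evaluate it". */
  def compile(filter: Filter): Option[String] = filter match {
    case EqualTo(attr, value)     => Some(s"$attr = ${quote(value)}")
    case GreaterThan(attr, value) => Some(s"$attr > ${quote(value)}")
    case LessThan(attr, value)    => Some(s"$attr < ${quote(value)}")
    case IsNull(attr)             => Some(s"$attr IS NULL")
    case IsNotNull(attr)          => Some(s"$attr IS NOT NULL")
    case And(left, right)         => for (l <- compile(left); r <- compile(right)) yield s"($l AND $r)"
    case Or(left, right)          => for (l <- compile(left); r <- compile(right)) yield s"($l OR $r)"
    case Not(child)               => compile(child).map(c => s"(NOT $c)")
    case _                        => None // In, StringStartsWith, ... not handled in this sketch
  }

  private def quote(value: Any): String = value match {
    case s: String => "'" + s.replace("'", "''") + "'"
    case other     => other.toString
  }

  def main(args: Array[String]): Unit = {
    val pushed = Seq(
      And(EqualTo("region", "north"), GreaterThan("amount", 10)),
      IsNotNull("item")
    ).flatMap(compile)
    println(pushed.mkString(" AND ")) // (region = 'north' AND amount > 10) AND item IS NOT NULL
  }
}
```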
@@ -45,6 +48,7 @@ abstract class Filter { * * @since 1.3.0 */ +@InterfaceStability.Stable case class EqualTo(attribute: String, value: Any) extends Filter { override def references: Array[String] = Array(attribute) ++ findReferences(value) } @@ -56,6 +60,7 @@ case class EqualTo(attribute: String, value: Any) extends Filter { * * @since 1.5.0 */ +@InterfaceStability.Stable case class EqualNullSafe(attribute: String, value: Any) extends Filter { override def references: Array[String] = Array(attribute) ++ findReferences(value) } @@ -66,6 +71,7 @@ case class EqualNullSafe(attribute: String, value: Any) extends Filter { * * @since 1.3.0 */ +@InterfaceStability.Stable case class GreaterThan(attribute: String, value: Any) extends Filter { override def references: Array[String] = Array(attribute) ++ findReferences(value) } @@ -76,6 +82,7 @@ case class GreaterThan(attribute: String, value: Any) extends Filter { * * @since 1.3.0 */ +@InterfaceStability.Stable case class GreaterThanOrEqual(attribute: String, value: Any) extends Filter { override def references: Array[String] = Array(attribute) ++ findReferences(value) } @@ -86,6 +93,7 @@ case class GreaterThanOrEqual(attribute: String, value: Any) extends Filter { * * @since 1.3.0 */ +@InterfaceStability.Stable case class LessThan(attribute: String, value: Any) extends Filter { override def references: Array[String] = Array(attribute) ++ findReferences(value) } @@ -96,6 +104,7 @@ case class LessThan(attribute: String, value: Any) extends Filter { * * @since 1.3.0 */ +@InterfaceStability.Stable case class LessThanOrEqual(attribute: String, value: Any) extends Filter { override def references: Array[String] = Array(attribute) ++ findReferences(value) } @@ -105,6 +114,7 @@ case class LessThanOrEqual(attribute: String, value: Any) extends Filter { * * @since 1.3.0 */ +@InterfaceStability.Stable case class In(attribute: String, values: Array[Any]) extends Filter { override def hashCode(): Int = { var h = attribute.hashCode @@ -131,6 +141,7 @@ case class In(attribute: String, values: Array[Any]) extends Filter { * * @since 1.3.0 */ +@InterfaceStability.Stable case class IsNull(attribute: String) extends Filter { override def references: Array[String] = Array(attribute) } @@ -140,6 +151,7 @@ case class IsNull(attribute: String) extends Filter { * * @since 1.3.0 */ +@InterfaceStability.Stable case class IsNotNull(attribute: String) extends Filter { override def references: Array[String] = Array(attribute) } @@ -149,6 +161,7 @@ case class IsNotNull(attribute: String) extends Filter { * * @since 1.3.0 */ +@InterfaceStability.Stable case class And(left: Filter, right: Filter) extends Filter { override def references: Array[String] = left.references ++ right.references } @@ -158,6 +171,7 @@ case class And(left: Filter, right: Filter) extends Filter { * * @since 1.3.0 */ +@InterfaceStability.Stable case class Or(left: Filter, right: Filter) extends Filter { override def references: Array[String] = left.references ++ right.references } @@ -167,6 +181,7 @@ case class Or(left: Filter, right: Filter) extends Filter { * * @since 1.3.0 */ +@InterfaceStability.Stable case class Not(child: Filter) extends Filter { override def references: Array[String] = child.references } @@ -177,6 +192,7 @@ case class Not(child: Filter) extends Filter { * * @since 1.3.1 */ +@InterfaceStability.Stable case class StringStartsWith(attribute: String, value: String) extends Filter { override def references: Array[String] = Array(attribute) } @@ -187,6 +203,7 @@ case class 
StringStartsWith(attribute: String, value: String) extends Filter { * * @since 1.3.1 */ +@InterfaceStability.Stable case class StringEndsWith(attribute: String, value: String) extends Filter { override def references: Array[String] = Array(attribute) } @@ -197,6 +214,7 @@ case class StringEndsWith(attribute: String, value: String) extends Filter { * * @since 1.3.1 */ +@InterfaceStability.Stable case class StringContains(attribute: String, value: String) extends Filter { override def references: Array[String] = Array(attribute) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index 6484c782b5d15..15a48072525b2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.sources -import org.apache.spark.annotation.{DeveloperApi, Experimental} +import org.apache.spark.annotation.{DeveloperApi, Experimental, InterfaceStability} import org.apache.spark.rdd.RDD import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow @@ -27,7 +27,6 @@ import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType /** - * ::DeveloperApi:: * Data sources should implement this trait so that they can register an alias to their data source. * This allows users to give the data source alias as the format type over the fully qualified * class name. @@ -36,7 +35,7 @@ import org.apache.spark.sql.types.StructType * * @since 1.5.0 */ -@DeveloperApi +@InterfaceStability.Stable trait DataSourceRegister { /** @@ -53,7 +52,6 @@ trait DataSourceRegister { } /** - * ::DeveloperApi:: * Implemented by objects that produce relations for a specific kind of data source. When * Spark SQL is given a DDL operation with a USING clause specified (to specify the implemented * RelationProvider), this interface is used to pass in the parameters specified by a user. @@ -67,7 +65,7 @@ trait DataSourceRegister { * * @since 1.3.0 */ -@DeveloperApi +@InterfaceStability.Stable trait RelationProvider { /** * Returns a new base relation with the given parameters. @@ -78,7 +76,6 @@ trait RelationProvider { } /** - * ::DeveloperApi:: * Implemented by objects that produce relations for a specific kind of data source * with a given schema. When Spark SQL is given a DDL operation with a USING clause specified ( * to specify the implemented SchemaRelationProvider) and a user defined schema, this interface @@ -98,7 +95,7 @@ trait RelationProvider { * * @since 1.3.0 */ -@DeveloperApi +@InterfaceStability.Stable trait SchemaRelationProvider { /** * Returns a new base relation with the given parameters and user defined schema. @@ -114,17 +111,26 @@ trait SchemaRelationProvider { /** * ::Experimental:: * Implemented by objects that can produce a streaming [[Source]] for a specific format or system. + * + * @since 2.0.0 */ @Experimental +@InterfaceStability.Unstable trait StreamSourceProvider { - /** Returns the name and schema of the source that can be used to continually read data. */ + /** + * Returns the name and schema of the source that can be used to continually read data. 
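As a rough illustration of how a data source might consume the pushed-down filters defined in this file, the sketch below translates a Filter tree into a simplified WHERE-clause fragment. The quoting is deliberately naive (all values are rendered as string literals) and unsupported filters are simply not pushed down; this is not the translation any built-in source actually uses.

    import org.apache.spark.sql.sources._

    def compileFilter(f: Filter): Option[String] = f match {
      case EqualTo(attr, value)           => Some(s"$attr = '$value'")
      case GreaterThan(attr, value)       => Some(s"$attr > '$value'")
      case LessThan(attr, value)          => Some(s"$attr < '$value'")
      case IsNull(attr)                   => Some(s"$attr IS NULL")
      case IsNotNull(attr)                => Some(s"$attr IS NOT NULL")
      case StringStartsWith(attr, prefix) => Some(s"$attr LIKE '$prefix%'")
      case And(left, right) =>
        for (l <- compileFilter(left); r <- compileFilter(right)) yield s"($l AND $r)"
      case Or(left, right) =>
        for (l <- compileFilter(left); r <- compileFilter(right)) yield s"($l OR $r)"
      case Not(child) => compileFilter(child).map(c => s"(NOT $c)")
      case _          => None  // e.g. In, StringContains: handled by Spark after the scan
    }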
+ * @since 2.0.0 + */ def sourceSchema( sqlContext: SQLContext, schema: Option[StructType], providerName: String, parameters: Map[String, String]): (String, StructType) + /** + * @since 2.0.0 + */ def createSource( sqlContext: SQLContext, metadataPath: String, @@ -136,8 +142,11 @@ trait StreamSourceProvider { /** * ::Experimental:: * Implemented by objects that can produce a streaming [[Sink]] for a specific format or system. + * + * @since 2.0.0 */ @Experimental +@InterfaceStability.Unstable trait StreamSinkProvider { def createSink( sqlContext: SQLContext, @@ -149,7 +158,7 @@ trait StreamSinkProvider { /** * @since 1.3.0 */ -@DeveloperApi +@InterfaceStability.Stable trait CreatableRelationProvider { /** * Save the DataFrame to the destination and return a relation with the given parameters based on @@ -173,7 +182,6 @@ trait CreatableRelationProvider { } /** - * ::DeveloperApi:: * Represents a collection of tuples with a known schema. Classes that extend BaseRelation must * be able to produce the schema of their data in the form of a [[StructType]]. Concrete * implementation should inherit from one of the descendant `Scan` classes, which define various @@ -185,7 +193,7 @@ trait CreatableRelationProvider { * * @since 1.3.0 */ -@DeveloperApi +@InterfaceStability.Stable abstract class BaseRelation { def sqlContext: SQLContext def schema: StructType @@ -231,30 +239,27 @@ abstract class BaseRelation { } /** - * ::DeveloperApi:: * A BaseRelation that can produce all of its tuples as an RDD of Row objects. * * @since 1.3.0 */ -@DeveloperApi +@InterfaceStability.Stable trait TableScan { def buildScan(): RDD[Row] } /** - * ::DeveloperApi:: * A BaseRelation that can eliminate unneeded columns before producing an RDD * containing all of its tuples as Row objects. * * @since 1.3.0 */ -@DeveloperApi +@InterfaceStability.Stable trait PrunedScan { def buildScan(requiredColumns: Array[String]): RDD[Row] } /** - * ::DeveloperApi:: * A BaseRelation that can eliminate unneeded columns and filter using selected * predicates before producing an RDD containing all matching tuples as Row objects. * @@ -267,13 +272,12 @@ trait PrunedScan { * * @since 1.3.0 */ -@DeveloperApi +@InterfaceStability.Stable trait PrunedFilteredScan { def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] } /** - * ::DeveloperApi:: * A BaseRelation that can be used to insert data into it through the insert method. * If overwrite in insert method is true, the old data in the relation should be overwritten with * the new data. If overwrite in insert method is false, the new data should be appended. 
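To make the relationship between these traits concrete, a hypothetical minimal data source combining DataSourceRegister, RelationProvider, BaseRelation and TableScan could be sketched as follows; the "demo" alias, the "rows" option and the single-column schema are assumptions for illustration only.

    import org.apache.spark.rdd.RDD
    import org.apache.spark.sql.{Row, SQLContext}
    import org.apache.spark.sql.sources._
    import org.apache.spark.sql.types._

    // Entry point looked up by the data source resolution machinery.
    class DefaultSource extends RelationProvider with DataSourceRegister {
      override def shortName(): String = "demo"

      override def createRelation(
          sqlContext: SQLContext,
          parameters: Map[String, String]): BaseRelation = {
        new DemoRelation(sqlContext, parameters.getOrElse("rows", "10").toInt)
      }
    }

    // Produces a fixed range of ids as its only "table".
    class DemoRelation(val sqlContext: SQLContext, numRows: Int)
      extends BaseRelation with TableScan {

      override def schema: StructType =
        new StructType().add("id", LongType, nullable = false)

      override def buildScan(): RDD[Row] =
        sqlContext.sparkContext.range(0, numRows).map(Row(_))
    }

Note that resolving the short "demo" alias additionally requires a META-INF/services/org.apache.spark.sql.sources.DataSourceRegister entry on the classpath; without it, the fully qualified package name can be passed to format() instead.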
@@ -290,7 +294,7 @@ trait PrunedFilteredScan { * * @since 1.3.0 */ -@DeveloperApi +@InterfaceStability.Stable trait InsertableRelation { def insert(data: DataFrame, overwrite: Boolean): Unit } @@ -306,6 +310,7 @@ trait InsertableRelation { * @since 1.3.0 */ @Experimental +@InterfaceStability.Unstable trait CatalystScan { def buildScan(requiredColumns: Seq[Attribute], filters: Seq[Expression]): RDD[Row] } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index 87b73062180e4..40b482e4c01a5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -134,7 +134,8 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo } /** - * Loads a JSON file stream (one object per line) and returns the result as a [[DataFrame]]. + * Loads a JSON file stream ([[http://jsonlines.org/ JSON Lines text format or newline-delimited + * JSON]]) and returns the result as a [[DataFrame]]. * * This function goes through the input once to determine the input schema. If you know the * schema in advance, use the version that specifies the schema to avoid the extra scan. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/SinkStatus.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/SinkStatus.scala index de1efe961f8bd..ab19602207ad8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/SinkStatus.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/SinkStatus.scala @@ -17,18 +17,50 @@ package org.apache.spark.sql.streaming +import org.json4s._ +import org.json4s.JsonAST.JValue +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods._ + import org.apache.spark.annotation.Experimental -import org.apache.spark.sql.execution.streaming.Sink +import org.apache.spark.sql.streaming.StreamingQueryStatus.indent /** * :: Experimental :: - * Status and metrics of a streaming [[Sink]]. + * Status and metrics of a streaming sink. * - * @param description Description of the source corresponding to this status - * @param offsetDesc Description of the current offset up to which data has been written by the sink + * @param description Description of the source corresponding to this status. + * @param offsetDesc Description of the current offsets up to which data has been written + * by the sink. * @since 2.0.0 */ @Experimental -class SinkStatus private[sql]( +class SinkStatus private( val description: String, - val offsetDesc: String) + val offsetDesc: String) { + + /** The compact JSON representation of this status. */ + def json: String = compact(render(jsonValue)) + + /** The pretty (i.e. indented) JSON representation of this status. 
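A quick sketch of the schema-first usage recommended by the updated DataStreamReader.json documentation, assuming a SparkSession named spark (as in spark-shell) and a hypothetical input directory:

    import org.apache.spark.sql.types._

    // Supplying the schema up front avoids the extra inference pass over the input.
    val schema = new StructType()
      .add("id", LongType)
      .add("name", StringType)

    // Each input file is expected to contain one JSON object per line (JSON Lines).
    val people = spark.readStream.schema(schema).json("/tmp/people-json")
    assert(people.isStreaming)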
*/ + def prettyJson: String = pretty(render(jsonValue)) + + override def toString: String = + "Status of sink " + indent(prettyString).trim + + private[sql] def jsonValue: JValue = { + ("description" -> JString(description)) ~ + ("offsetDesc" -> JString(offsetDesc)) + } + + private[sql] def prettyString: String = { + s"""$description + |Committed offsets: $offsetDesc + |""".stripMargin + } +} + +/** Companion object, primarily for creating SinkStatus instances internally */ +private[sql] object SinkStatus { + def apply(desc: String, offsetDesc: String): SinkStatus = new SinkStatus(desc, offsetDesc) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/SourceStatus.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/SourceStatus.scala index bd0c8485e4fdd..cfdf11370e06d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/SourceStatus.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/SourceStatus.scala @@ -17,18 +17,79 @@ package org.apache.spark.sql.streaming +import java.{util => ju} + +import scala.collection.JavaConverters._ + +import org.json4s._ +import org.json4s.JsonAST.JValue +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods._ + import org.apache.spark.annotation.Experimental -import org.apache.spark.sql.execution.streaming.Source +import org.apache.spark.sql.streaming.StreamingQueryStatus.indent +import org.apache.spark.util.JsonProtocol /** * :: Experimental :: - * Status and metrics of a streaming [[Source]]. + * Status and metrics of a streaming Source. * - * @param description Description of the source corresponding to this status - * @param offsetDesc Description of the current [[Source]] offset if known + * @param description Description of the source corresponding to this status. + * @param offsetDesc Description of the current offset if known. + * @param inputRate Current rate (rows/sec) at which data is being generated by the source. + * @param processingRate Current rate (rows/sec) at which the query is processing data from + * the source. + * @param triggerDetails Low-level details of the currently active trigger (e.g. number of + * rows processed in trigger, latency of intermediate steps, etc.). + * If no trigger is active, then it will have details of the last completed + * trigger. * @since 2.0.0 */ @Experimental -class SourceStatus private[sql] ( +class SourceStatus private( val description: String, - val offsetDesc: Option[String]) + val offsetDesc: String, + val inputRate: Double, + val processingRate: Double, + val triggerDetails: ju.Map[String, String]) { + + /** The compact JSON representation of this status. */ + def json: String = compact(render(jsonValue)) + + /** The pretty (i.e. indented) JSON representation of this status. 
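Assuming an already running StreamingQuery named query, the new JSON accessors on SourceStatus and SinkStatus introduced here could be exercised roughly as follows:

    import org.apache.spark.sql.streaming.StreamingQuery

    def logStatuses(query: StreamingQuery): Unit = {
      // Per-source snapshot: description, offset, input/processing rates, trigger details.
      query.sourceStatuses.foreach(s => println(s.json))
      // Sink snapshot rendered as indented JSON.
      println(query.sinkStatus.prettyJson)
    }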
*/ + def prettyJson: String = pretty(render(jsonValue)) + + override def toString: String = + "Status of source " + indent(prettyString).trim + + private[sql] def jsonValue: JValue = { + ("description" -> JString(description)) ~ + ("offsetDesc" -> JString(offsetDesc)) ~ + ("inputRate" -> JDouble(inputRate)) ~ + ("processingRate" -> JDouble(processingRate)) ~ + ("triggerDetails" -> JsonProtocol.mapToJson(triggerDetails.asScala)) + } + + private[sql] def prettyString: String = { + val triggerDetailsLines = + triggerDetails.asScala.map { case (k, v) => s"$k: $v" } + s"""$description + |Available offset: $offsetDesc + |Input rate: $inputRate rows/sec + |Processing rate: $processingRate rows/sec + |Trigger details: + |""".stripMargin + indent(triggerDetailsLines) + } +} + +/** Companion object, primarily for creating SourceStatus instances internally */ +private[sql] object SourceStatus { + def apply( + desc: String, + offsetDesc: String, + inputRate: Double, + processingRate: Double, + triggerDetails: Map[String, String]): SourceStatus = { + new SourceStatus(desc, offsetDesc, inputRate, processingRate, triggerDetails.asJava) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQuery.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQuery.scala index 91f0a1e3446a1..0a85414451981 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQuery.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQuery.scala @@ -62,13 +62,24 @@ trait StreamingQuery { */ def exception: Option[StreamingQueryException] + /** + * Returns the current status of the query. + * @since 2.0.2 + */ + def status: StreamingQueryStatus + /** * Returns current status of all the sources. * @since 2.0.0 */ + @deprecated("use status.sourceStatuses", "2.0.2") def sourceStatuses: Array[SourceStatus] - /** Returns current status of the sink. */ + /** + * Returns current status of the sink. + * @since 2.0.0 + */ + @deprecated("use status.sinkStatus", "2.0.2") def sinkStatus: SinkStatus /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala index 8a8855d85a4c7..9e311fae842be 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala @@ -41,7 +41,7 @@ abstract class StreamingQueryListener { * don't block this method as it will block your query. * @since 2.0.0 */ - def onQueryStarted(queryStarted: QueryStarted): Unit + def onQueryStarted(event: QueryStartedEvent): Unit /** * Called when there is some status update (ingestion rate updated, etc.) @@ -49,16 +49,16 @@ abstract class StreamingQueryListener { * @note This method is asynchronous. The status in [[StreamingQuery]] will always be * latest no matter when this method is called. Therefore, the status of [[StreamingQuery]] * may be changed before/when you process the event. E.g., you may find [[StreamingQuery]] - * is terminated when you are processing [[QueryProgress]]. + * is terminated when you are processing [[QueryProgressEvent]]. * @since 2.0.0 */ - def onQueryProgress(queryProgress: QueryProgress): Unit + def onQueryProgress(event: QueryProgressEvent): Unit /** * Called when a query is stopped, with or without error. 
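A small migration sketch for the deprecation above, again assuming an active StreamingQuery: the per-query accessors sourceStatuses and sinkStatus are deprecated in 2.0.2 in favour of the single status snapshot.

    import org.apache.spark.sql.streaming.StreamingQuery

    def report(query: StreamingQuery): Unit = {
      val s = query.status  // one consistent snapshot of the whole query
      println(s"Query '${s.name}' input rate: ${s.inputRate} rows/sec")
      println(s.prettyJson) // includes source statuses, sink status and trigger details
    }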
* @since 2.0.0 */ - def onQueryTerminated(queryTerminated: QueryTerminated): Unit + def onQueryTerminated(event: QueryTerminatedEvent): Unit } @@ -84,7 +84,7 @@ object StreamingQueryListener { * @since 2.0.0 */ @Experimental - class QueryStarted private[sql](val queryInfo: StreamingQueryInfo) extends Event + class QueryStartedEvent private[sql](val queryStatus: StreamingQueryStatus) extends Event /** * :: Experimental :: @@ -92,19 +92,19 @@ object StreamingQueryListener { * @since 2.0.0 */ @Experimental - class QueryProgress private[sql](val queryInfo: StreamingQueryInfo) extends Event + class QueryProgressEvent private[sql](val queryStatus: StreamingQueryStatus) extends Event /** * :: Experimental :: * Event representing that termination of a query * - * @param queryInfo Information about the status of the query. + * @param queryStatus Information about the status of the query. * @param exception The exception message of the [[StreamingQuery]] if the query was terminated * with an exception. Otherwise, it will be `None`. * @since 2.0.0 */ @Experimental - class QueryTerminated private[sql]( - val queryInfo: StreamingQueryInfo, + class QueryTerminatedEvent private[sql]( + val queryStatus: StreamingQueryStatus, val exception: Option[String]) extends Event } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryStatus.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryStatus.scala new file mode 100644 index 0000000000000..a50b0d96c13f7 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryStatus.scala @@ -0,0 +1,162 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming + +import java.{util => ju} + +import scala.collection.JavaConverters._ + +import org.json4s._ +import org.json4s.JsonAST.JValue +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods._ + +import org.apache.spark.annotation.Experimental +import org.apache.spark.sql.execution.streaming.{CompositeOffset, LongOffset} +import org.apache.spark.util.JsonProtocol + +/** + * :: Experimental :: + * A class used to report information about the progress of a [[StreamingQuery]]. + * + * @param name Name of the query. This name is unique across all active queries. + * @param id Id of the query. This id is unique across + * all queries that have been started in the current process. + * @param timestamp Timestamp (ms) of when this query was generated. + * @param inputRate Current rate (rows/sec) at which data is being generated by all the sources. + * @param processingRate Current rate (rows/sec) at which the query is processing data from + * all the sources. 
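Putting the renamed event classes together, a listener that simply logs to stdout might look like the sketch below; the listener name and messages are illustrative only.

    import org.apache.spark.sql.streaming.StreamingQueryListener
    import org.apache.spark.sql.streaming.StreamingQueryListener._

    class ConsoleQueryListener extends StreamingQueryListener {
      override def onQueryStarted(event: QueryStartedEvent): Unit =
        println(s"Started: ${event.queryStatus.name}")

      override def onQueryProgress(event: QueryProgressEvent): Unit =
        println(s"Input rate: ${event.queryStatus.inputRate} rows/sec")

      override def onQueryTerminated(event: QueryTerminatedEvent): Unit =
        println(s"Terminated: ${event.queryStatus.name}, exception = ${event.exception}")
    }

    // spark: SparkSession is assumed to be in scope.
    spark.streams.addListener(new ConsoleQueryListener)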
+ * @param latency Current average latency between the data being available in source and the sink + * writing the corresponding output. + * @param sourceStatuses Current statuses of the sources. + * @param sinkStatus Current status of the sink. + * @param triggerDetails Low-level details of the currently active trigger (e.g. number of + * rows processed in trigger, latency of intermediate steps, etc.). + * If no trigger is active, then it will have details of the last completed + * trigger. + * @since 2.0.0 + */ +@Experimental +class StreamingQueryStatus private( + val name: String, + val id: Long, + val timestamp: Long, + val inputRate: Double, + val processingRate: Double, + val latency: Option[Double], + val sourceStatuses: Array[SourceStatus], + val sinkStatus: SinkStatus, + val triggerDetails: ju.Map[String, String]) { + + import StreamingQueryStatus._ + + /** The compact JSON representation of this status. */ + def json: String = compact(render(jsonValue)) + + /** The pretty (i.e. indented) JSON representation of this status. */ + def prettyJson: String = pretty(render(jsonValue)) + + override def toString: String = { + val sourceStatusLines = sourceStatuses.zipWithIndex.map { case (s, i) => + s"Source ${i + 1} - " + indent(s.prettyString).trim + } + val sinkStatusLines = sinkStatus.prettyString.trim + val triggerDetailsLines = triggerDetails.asScala.map { case (k, v) => s"$k: $v" }.toSeq.sorted + val numSources = sourceStatuses.length + val numSourcesString = s"$numSources source" + { if (numSources > 1) "s" else "" } + + val allLines = + s"""|Query id: $id + |Status timestamp: $timestamp + |Input rate: $inputRate rows/sec + |Processing rate $processingRate rows/sec + |Latency: ${latency.getOrElse("-")} ms + |Trigger details: + |${indent(triggerDetailsLines)} + |Source statuses [$numSourcesString]: + |${indent(sourceStatusLines)} + |Sink status - ${indent(sinkStatusLines).trim}""".stripMargin + + s"Status of query '$name'\n${indent(allLines)}" + } + + private[sql] def jsonValue: JValue = { + ("name" -> JString(name)) ~ + ("id" -> JInt(id)) ~ + ("timestamp" -> JInt(timestamp)) ~ + ("inputRate" -> JDouble(inputRate)) ~ + ("processingRate" -> JDouble(processingRate)) ~ + ("latency" -> latency.map(JDouble).getOrElse(JNothing)) ~ + ("triggerDetails" -> JsonProtocol.mapToJson(triggerDetails.asScala)) + ("sourceStatuses" -> JArray(sourceStatuses.map(_.jsonValue).toList)) ~ + ("sinkStatus" -> sinkStatus.jsonValue) + } +} + +/** Companion object, primarily for creating StreamingQueryInfo instances internally */ +private[sql] object StreamingQueryStatus { + def apply( + name: String, + id: Long, + timestamp: Long, + inputRate: Double, + processingRate: Double, + latency: Option[Double], + sourceStatuses: Array[SourceStatus], + sinkStatus: SinkStatus, + triggerDetails: Map[String, String]): StreamingQueryStatus = { + new StreamingQueryStatus(name, id, timestamp, inputRate, processingRate, + latency, sourceStatuses, sinkStatus, triggerDetails.asJava) + } + + def indent(strings: Iterable[String]): String = strings.map(indent).mkString("\n") + def indent(string: String): String = string.split("\n").map(" " + _).mkString("\n") + + /** Create an instance of status for python testing */ + def testStatus(): StreamingQueryStatus = { + import org.apache.spark.sql.execution.streaming.StreamMetrics._ + StreamingQueryStatus( + name = "query", + id = 1, + timestamp = 123, + inputRate = 15.5, + processingRate = 23.5, + latency = Some(345), + sourceStatuses = Array( + SourceStatus( + desc = "MySource1", + 
offsetDesc = LongOffset(0).toString, + inputRate = 15.5, + processingRate = 23.5, + triggerDetails = Map( + NUM_SOURCE_INPUT_ROWS -> "100", + SOURCE_GET_OFFSET_LATENCY -> "10", + SOURCE_GET_BATCH_LATENCY -> "20"))), + sinkStatus = SinkStatus( + desc = "MySink", + offsetDesc = CompositeOffset(Some(LongOffset(1)) :: None :: Nil).toString), + triggerDetails = Map( + TRIGGER_ID -> "5", + IS_TRIGGER_ACTIVE -> "true", + IS_DATA_PRESENT_IN_TRIGGER -> "true", + GET_OFFSET_LATENCY -> "10", + GET_BATCH_LATENCY -> "20", + NUM_INPUT_ROWS -> "100" + )) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryInfo.scala b/sql/core/src/test/java/test/org/apache/spark/sql/JavaStringLength.java similarity index 52% rename from sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryInfo.scala rename to sql/core/src/test/java/test/org/apache/spark/sql/JavaStringLength.java index 1af2668817eae..b90224f2ae397 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryInfo.scala +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaStringLength.java @@ -15,23 +15,16 @@ * limitations under the License. */ -package org.apache.spark.sql.streaming +package test.org.apache.spark.sql; -import org.apache.spark.annotation.Experimental +import org.apache.spark.sql.api.java.UDF1; /** - * :: Experimental :: - * A class used to report information about the progress of a [[StreamingQuery]]. - * - * @param name The [[StreamingQuery]] name. This name is unique across all active queries. - * @param id The [[StreamingQuery]] id. This id is unique across - * all queries that have been started in the current process. - * @param sourceStatuses The current statuses of the [[StreamingQuery]]'s sources. - * @param sinkStatus The current status of the [[StreamingQuery]]'s sink. 
+ * It is used for register Java UDF from PySpark */ -@Experimental -class StreamingQueryInfo private[sql]( - val name: String, - val id: Long, - val sourceStatuses: Seq[SourceStatus], - val sinkStatus: SinkStatus) +public class JavaStringLength implements UDF1 { + @Override + public Integer call(String str) throws Exception { + return new Integer(str.length()); + } +} diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java index 2274912521a56..8bf3278c43880 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java @@ -87,4 +87,25 @@ public Integer call(String str1, String str2) { Row result = spark.sql("SELECT stringLengthTest('test', 'test2')").head(); Assert.assertEquals(9, result.getInt(0)); } + + public static class StringLengthTest implements UDF2 { + @Override + public Integer call(String str1, String str2) throws Exception { + return new Integer(str1.length() + str2.length()); + } + } + + @SuppressWarnings("unchecked") + @Test + public void udf3Test() { + spark.udf().registerJava("stringLengthTest", StringLengthTest.class.getName(), + DataTypes.IntegerType); + Row result = spark.sql("SELECT stringLengthTest('test', 'test2')").head(); + Assert.assertEquals(9, result.getInt(0)); + + // returnType is not provided + spark.udf().registerJava("stringLengthTest2", StringLengthTest.class.getName(), null); + result = spark.sql("SELECT stringLengthTest('test', 'test2')").head(); + Assert.assertEquals(9, result.getInt(0)); + } } diff --git a/sql/core/src/test/resources/sql-tests/inputs/array.sql b/sql/core/src/test/resources/sql-tests/inputs/array.sql index 4038a0da41d2b..984321ab795fc 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/array.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/array.sql @@ -71,6 +71,12 @@ select sort_array(timestamp_array) from primitive_arrays; +-- sort_array with an invalid string literal for the argument of sort order. +select sort_array(array('b', 'd'), '1'); + +-- sort_array with an invalid null literal casted as boolean for the argument of sort order. 
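From Scala (or spark-shell), the registration path exercised by udf3Test above could be driven roughly as follows, assuming spark is a SparkSession and the JavaStringLength test helper added in this patch is on the classpath:

    import org.apache.spark.sql.types.DataTypes

    // Register the Java UDF class by its fully-qualified name, then smoke-test it in SQL.
    spark.udf.registerJava(
      "javaStringLength", "test.org.apache.spark.sql.JavaStringLength", DataTypes.IntegerType)
    spark.sql("SELECT javaStringLength('hello')").show()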
+select sort_array(array('b', 'd'), cast(NULL as boolean)); + -- size select size(boolean_array), diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-analytics.sql b/sql/core/src/test/resources/sql-tests/inputs/group-analytics.sql new file mode 100644 index 0000000000000..f8135389a9e5a --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/group-analytics.sql @@ -0,0 +1,57 @@ +CREATE OR REPLACE TEMPORARY VIEW testData AS SELECT * FROM VALUES +(1, 1), (1, 2), (2, 1), (2, 2), (3, 1), (3, 2) +AS testData(a, b); + +-- CUBE on overlapping columns +SELECT a + b, b, SUM(a - b) FROM testData GROUP BY a + b, b WITH CUBE; + +SELECT a, b, SUM(b) FROM testData GROUP BY a, b WITH CUBE; + +-- ROLLUP on overlapping columns +SELECT a + b, b, SUM(a - b) FROM testData GROUP BY a + b, b WITH ROLLUP; + +SELECT a, b, SUM(b) FROM testData GROUP BY a, b WITH ROLLUP; + +CREATE OR REPLACE TEMPORARY VIEW courseSales AS SELECT * FROM VALUES +("dotNET", 2012, 10000), ("Java", 2012, 20000), ("dotNET", 2012, 5000), ("dotNET", 2013, 48000), ("Java", 2013, 30000) +AS courseSales(course, year, earnings); + +-- ROLLUP +SELECT course, year, SUM(earnings) FROM courseSales GROUP BY ROLLUP(course, year) ORDER BY course, year; + +-- CUBE +SELECT course, year, SUM(earnings) FROM courseSales GROUP BY CUBE(course, year) ORDER BY course, year; + +-- GROUPING SETS +SELECT course, year, SUM(earnings) FROM courseSales GROUP BY course, year GROUPING SETS(course, year); +SELECT course, year, SUM(earnings) FROM courseSales GROUP BY course, year GROUPING SETS(course); +SELECT course, year, SUM(earnings) FROM courseSales GROUP BY course, year GROUPING SETS(year); + +-- GROUPING SETS with aggregate functions containing groupBy columns +SELECT course, SUM(earnings) AS sum FROM courseSales +GROUP BY course, earnings GROUPING SETS((), (course), (course, earnings)) ORDER BY course, sum; +SELECT course, SUM(earnings) AS sum, GROUPING_ID(course, earnings) FROM courseSales +GROUP BY course, earnings GROUPING SETS((), (course), (course, earnings)) ORDER BY course, sum; + +-- GROUPING/GROUPING_ID +SELECT course, year, GROUPING(course), GROUPING(year), GROUPING_ID(course, year) FROM courseSales +GROUP BY CUBE(course, year); +SELECT course, year, GROUPING(course) FROM courseSales GROUP BY course, year; +SELECT course, year, GROUPING_ID(course, year) FROM courseSales GROUP BY course, year; +SELECT course, year, grouping__id FROM courseSales GROUP BY CUBE(course, year); + +-- GROUPING/GROUPING_ID in having clause +SELECT course, year FROM courseSales GROUP BY CUBE(course, year) +HAVING GROUPING(year) = 1 AND GROUPING_ID(course, year) > 0; +SELECT course, year FROM courseSales GROUP BY course, year HAVING GROUPING(course) > 0; +SELECT course, year FROM courseSales GROUP BY course, year HAVING GROUPING_ID(course) > 0; +SELECT course, year FROM courseSales GROUP BY CUBE(course, year) HAVING grouping__id > 0; + +-- GROUPING/GROUPING_ID in orderBy clause +SELECT course, year, GROUPING(course), GROUPING(year) FROM courseSales GROUP BY CUBE(course, year) +ORDER BY GROUPING(course), GROUPING(year), course, year; +SELECT course, year, GROUPING_ID(course, year) FROM courseSales GROUP BY CUBE(course, year) +ORDER BY GROUPING(course), GROUPING(year), course, year; +SELECT course, year FROM courseSales GROUP BY course, year ORDER BY GROUPING(course); +SELECT course, year FROM courseSales GROUP BY course, year ORDER BY GROUPING_ID(course); +SELECT course, year FROM courseSales GROUP BY CUBE(course, year) ORDER BY grouping__id; \ No newline at 
end of file diff --git a/sql/core/src/test/resources/sql-tests/inputs/show_columns.sql b/sql/core/src/test/resources/sql-tests/inputs/show_columns.sql new file mode 100644 index 0000000000000..3894082255088 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/show_columns.sql @@ -0,0 +1,58 @@ +CREATE DATABASE showdb; + +USE showdb; + +CREATE TABLE showcolumn1 (col1 int, `col 2` int); +CREATE TABLE showcolumn2 (price int, qty int) partitioned by (year int, month int); +CREATE TEMPORARY VIEW showColumn3 (col3 int, `col 4` int) USING parquet; +CREATE GLOBAL TEMP VIEW showColumn4 AS SELECT 1 as col1, 'abc' as `col 5`; + + +-- only table name +SHOW COLUMNS IN showcolumn1; + +-- qualified table name +SHOW COLUMNS IN showdb.showcolumn1; + +-- table name and database name +SHOW COLUMNS IN showcolumn1 FROM showdb; + +-- partitioned table +SHOW COLUMNS IN showcolumn2 IN showdb; + +-- Non-existent table. Raise an error in this case +SHOW COLUMNS IN badtable FROM showdb; + +-- database in table identifier and database name in different case +SHOW COLUMNS IN showdb.showcolumn1 from SHOWDB; + +-- different database name in table identifier and database name. +-- Raise an error in this case. +SHOW COLUMNS IN showdb.showcolumn1 FROM baddb; + +-- show column on temporary view +SHOW COLUMNS IN showcolumn3; + +-- error temp view can't be qualified with a database +SHOW COLUMNS IN showdb.showcolumn3; + +-- error temp view can't be qualified with a database +SHOW COLUMNS IN showcolumn3 FROM showdb; + +-- error global temp view needs to be qualified +SHOW COLUMNS IN showcolumn4; + +-- global temp view qualified with database +SHOW COLUMNS IN global_temp.showcolumn4; + +-- global temp view qualified with database +SHOW COLUMNS IN showcolumn4 FROM global_temp; + +DROP TABLE showcolumn1; +DROP TABLE showColumn2; +DROP VIEW showcolumn3; +DROP VIEW global_temp.showcolumn4; + +use default; + +DROP DATABASE showdb; diff --git a/sql/core/src/test/resources/sql-tests/results/array.sql.out b/sql/core/src/test/resources/sql-tests/results/array.sql.out index 4a1d149c1f362..499a3d5fb72f6 100644 --- a/sql/core/src/test/resources/sql-tests/results/array.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/array.sql.out @@ -124,8 +124,23 @@ struct,sort_array(tinyint_array, -- !query 8 output [true] [1,2] [1,2] [1,2] [1,2] [9223372036854775808,9223372036854775809] [1.0,2.0] [1.0,2.0] [2016-03-13,2016-03-14] [2016-11-12 20:54:00.0,2016-11-15 20:54:00.0] - -- !query 9 +select sort_array(array('b', 'd'), '1') +-- !query 9 schema +struct<> +-- !query 9 output +org.apache.spark.sql.AnalysisException +cannot resolve 'sort_array(array('b', 'd'), '1')' due to data type mismatch: Sort order in second argument requires a boolean literal.; line 1 pos 7 + +-- !query 10 +select sort_array(array('b', 'd'), cast(NULL as boolean)) +-- !query 10 schema +struct<> +-- !query 10 output +org.apache.spark.sql.AnalysisException +cannot resolve 'sort_array(array('b', 'd'), CAST(NULL AS BOOLEAN))' due to data type mismatch: Sort order in second argument requires a boolean literal.; line 1 pos 7 + +-- !query 11 select size(boolean_array), size(tinyint_array), @@ -138,7 +153,7 @@ select size(date_array), size(timestamp_array) from primitive_arrays --- !query 9 schema +-- !query 11 schema struct --- !query 9 output +-- !query 11 output 1 2 2 2 2 2 2 2 2 2 diff --git a/sql/core/src/test/resources/sql-tests/results/group-analytics.sql.out b/sql/core/src/test/resources/sql-tests/results/group-analytics.sql.out new file mode 100644 index 
0000000000000..825e8f5488c8b --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/group-analytics.sql.out @@ -0,0 +1,330 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 26 + + +-- !query 0 +CREATE OR REPLACE TEMPORARY VIEW testData AS SELECT * FROM VALUES +(1, 1), (1, 2), (2, 1), (2, 2), (3, 1), (3, 2) +AS testData(a, b) +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +SELECT a + b, b, SUM(a - b) FROM testData GROUP BY a + b, b WITH CUBE +-- !query 1 schema +struct<(a + b):int,b:int,sum((a - b)):bigint> +-- !query 1 output +2 1 0 +2 NULL 0 +3 1 1 +3 2 -1 +3 NULL 0 +4 1 2 +4 2 0 +4 NULL 2 +5 2 1 +5 NULL 1 +NULL 1 3 +NULL 2 0 +NULL NULL 3 + + +-- !query 2 +SELECT a, b, SUM(b) FROM testData GROUP BY a, b WITH CUBE +-- !query 2 schema +struct +-- !query 2 output +1 1 1 +1 2 2 +1 NULL 3 +2 1 1 +2 2 2 +2 NULL 3 +3 1 1 +3 2 2 +3 NULL 3 +NULL 1 3 +NULL 2 6 +NULL NULL 9 + + +-- !query 3 +SELECT a + b, b, SUM(a - b) FROM testData GROUP BY a + b, b WITH ROLLUP +-- !query 3 schema +struct<(a + b):int,b:int,sum((a - b)):bigint> +-- !query 3 output +2 1 0 +2 NULL 0 +3 1 1 +3 2 -1 +3 NULL 0 +4 1 2 +4 2 0 +4 NULL 2 +5 2 1 +5 NULL 1 +NULL NULL 3 + + +-- !query 4 +SELECT a, b, SUM(b) FROM testData GROUP BY a, b WITH ROLLUP +-- !query 4 schema +struct +-- !query 4 output +1 1 1 +1 2 2 +1 NULL 3 +2 1 1 +2 2 2 +2 NULL 3 +3 1 1 +3 2 2 +3 NULL 3 +NULL NULL 9 + + +-- !query 5 +CREATE OR REPLACE TEMPORARY VIEW courseSales AS SELECT * FROM VALUES +("dotNET", 2012, 10000), ("Java", 2012, 20000), ("dotNET", 2012, 5000), ("dotNET", 2013, 48000), ("Java", 2013, 30000) +AS courseSales(course, year, earnings) +-- !query 5 schema +struct<> +-- !query 5 output + + + +-- !query 6 +SELECT course, year, SUM(earnings) FROM courseSales GROUP BY ROLLUP(course, year) ORDER BY course, year +-- !query 6 schema +struct +-- !query 6 output +NULL NULL 113000 +Java NULL 50000 +Java 2012 20000 +Java 2013 30000 +dotNET NULL 63000 +dotNET 2012 15000 +dotNET 2013 48000 + + +-- !query 7 +SELECT course, year, SUM(earnings) FROM courseSales GROUP BY CUBE(course, year) ORDER BY course, year +-- !query 7 schema +struct +-- !query 7 output +NULL NULL 113000 +NULL 2012 35000 +NULL 2013 78000 +Java NULL 50000 +Java 2012 20000 +Java 2013 30000 +dotNET NULL 63000 +dotNET 2012 15000 +dotNET 2013 48000 + + +-- !query 8 +SELECT course, year, SUM(earnings) FROM courseSales GROUP BY course, year GROUPING SETS(course, year) +-- !query 8 schema +struct +-- !query 8 output +Java NULL 50000 +NULL 2012 35000 +NULL 2013 78000 +dotNET NULL 63000 + + +-- !query 9 +SELECT course, year, SUM(earnings) FROM courseSales GROUP BY course, year GROUPING SETS(course) +-- !query 9 schema +struct +-- !query 9 output +Java NULL 50000 +dotNET NULL 63000 + + +-- !query 10 +SELECT course, year, SUM(earnings) FROM courseSales GROUP BY course, year GROUPING SETS(year) +-- !query 10 schema +struct +-- !query 10 output +NULL 2012 35000 +NULL 2013 78000 + + +-- !query 11 +SELECT course, SUM(earnings) AS sum FROM courseSales +GROUP BY course, earnings GROUPING SETS((), (course), (course, earnings)) ORDER BY course, sum +-- !query 11 schema +struct +-- !query 11 output +NULL 113000 +Java 20000 +Java 30000 +Java 50000 +dotNET 5000 +dotNET 10000 +dotNET 48000 +dotNET 63000 + + +-- !query 12 +SELECT course, SUM(earnings) AS sum, GROUPING_ID(course, earnings) FROM courseSales +GROUP BY course, earnings GROUPING SETS((), (course), (course, earnings)) ORDER BY course, sum +-- !query 12 schema +struct +-- !query 12 output +NULL 
113000 3 +Java 20000 0 +Java 30000 0 +Java 50000 1 +dotNET 5000 0 +dotNET 10000 0 +dotNET 48000 0 +dotNET 63000 1 + + +-- !query 13 +SELECT course, year, GROUPING(course), GROUPING(year), GROUPING_ID(course, year) FROM courseSales +GROUP BY CUBE(course, year) +-- !query 13 schema +struct +-- !query 13 output +Java 2012 0 0 0 +Java 2013 0 0 0 +Java NULL 0 1 1 +NULL 2012 1 0 2 +NULL 2013 1 0 2 +NULL NULL 1 1 3 +dotNET 2012 0 0 0 +dotNET 2013 0 0 0 +dotNET NULL 0 1 1 + + +-- !query 14 +SELECT course, year, GROUPING(course) FROM courseSales GROUP BY course, year +-- !query 14 schema +struct<> +-- !query 14 output +org.apache.spark.sql.AnalysisException +grouping() can only be used with GroupingSets/Cube/Rollup; + + +-- !query 15 +SELECT course, year, GROUPING_ID(course, year) FROM courseSales GROUP BY course, year +-- !query 15 schema +struct<> +-- !query 15 output +org.apache.spark.sql.AnalysisException +grouping_id() can only be used with GroupingSets/Cube/Rollup; + + +-- !query 16 +SELECT course, year, grouping__id FROM courseSales GROUP BY CUBE(course, year) +-- !query 16 schema +struct<> +-- !query 16 output +org.apache.spark.sql.AnalysisException +grouping__id is deprecated; use grouping_id() instead; + + +-- !query 17 +SELECT course, year FROM courseSales GROUP BY CUBE(course, year) +HAVING GROUPING(year) = 1 AND GROUPING_ID(course, year) > 0 +-- !query 17 schema +struct +-- !query 17 output +Java NULL +NULL NULL +dotNET NULL + + +-- !query 18 +SELECT course, year FROM courseSales GROUP BY course, year HAVING GROUPING(course) > 0 +-- !query 18 schema +struct<> +-- !query 18 output +org.apache.spark.sql.AnalysisException +grouping()/grouping_id() can only be used with GroupingSets/Cube/Rollup; + + +-- !query 19 +SELECT course, year FROM courseSales GROUP BY course, year HAVING GROUPING_ID(course) > 0 +-- !query 19 schema +struct<> +-- !query 19 output +org.apache.spark.sql.AnalysisException +grouping()/grouping_id() can only be used with GroupingSets/Cube/Rollup; + + +-- !query 20 +SELECT course, year FROM courseSales GROUP BY CUBE(course, year) HAVING grouping__id > 0 +-- !query 20 schema +struct<> +-- !query 20 output +org.apache.spark.sql.AnalysisException +grouping__id is deprecated; use grouping_id() instead; + + +-- !query 21 +SELECT course, year, GROUPING(course), GROUPING(year) FROM courseSales GROUP BY CUBE(course, year) +ORDER BY GROUPING(course), GROUPING(year), course, year +-- !query 21 schema +struct +-- !query 21 output +Java 2012 0 0 +Java 2013 0 0 +dotNET 2012 0 0 +dotNET 2013 0 0 +Java NULL 0 1 +dotNET NULL 0 1 +NULL 2012 1 0 +NULL 2013 1 0 +NULL NULL 1 1 + + +-- !query 22 +SELECT course, year, GROUPING_ID(course, year) FROM courseSales GROUP BY CUBE(course, year) +ORDER BY GROUPING(course), GROUPING(year), course, year +-- !query 22 schema +struct +-- !query 22 output +Java 2012 0 +Java 2013 0 +dotNET 2012 0 +dotNET 2013 0 +Java NULL 1 +dotNET NULL 1 +NULL 2012 2 +NULL 2013 2 +NULL NULL 3 + + +-- !query 23 +SELECT course, year FROM courseSales GROUP BY course, year ORDER BY GROUPING(course) +-- !query 23 schema +struct<> +-- !query 23 output +org.apache.spark.sql.AnalysisException +grouping()/grouping_id() can only be used with GroupingSets/Cube/Rollup; + + +-- !query 24 +SELECT course, year FROM courseSales GROUP BY course, year ORDER BY GROUPING_ID(course) +-- !query 24 schema +struct<> +-- !query 24 output +org.apache.spark.sql.AnalysisException +grouping()/grouping_id() can only be used with GroupingSets/Cube/Rollup; + + +-- !query 25 +SELECT course, year FROM 
courseSales GROUP BY CUBE(course, year) ORDER BY grouping__id +-- !query 25 schema +struct<> +-- !query 25 output +org.apache.spark.sql.AnalysisException +grouping__id is deprecated; use grouping_id() instead; diff --git a/sql/core/src/test/resources/sql-tests/results/show_columns.sql.out b/sql/core/src/test/resources/sql-tests/results/show_columns.sql.out new file mode 100644 index 0000000000000..832e6e25bb2bd --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/show_columns.sql.out @@ -0,0 +1,217 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 25 + + +-- !query 0 +CREATE DATABASE showdb +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +USE showdb +-- !query 1 schema +struct<> +-- !query 1 output + + + +-- !query 2 +CREATE TABLE showcolumn1 (col1 int, `col 2` int) +-- !query 2 schema +struct<> +-- !query 2 output + + + +-- !query 3 +CREATE TABLE showcolumn2 (price int, qty int) partitioned by (year int, month int) +-- !query 3 schema +struct<> +-- !query 3 output + + + +-- !query 4 +CREATE TEMPORARY VIEW showColumn3 (col3 int, `col 4` int) USING parquet +-- !query 4 schema +struct<> +-- !query 4 output + + + +-- !query 5 +CREATE GLOBAL TEMP VIEW showColumn4 AS SELECT 1 as col1, 'abc' as `col 5` +-- !query 5 schema +struct<> +-- !query 5 output + + + +-- !query 6 +SHOW COLUMNS IN showcolumn1 +-- !query 6 schema +struct +-- !query 6 output +col 2 +col1 + + +-- !query 7 +SHOW COLUMNS IN showdb.showcolumn1 +-- !query 7 schema +struct +-- !query 7 output +col 2 +col1 + + +-- !query 8 +SHOW COLUMNS IN showcolumn1 FROM showdb +-- !query 8 schema +struct +-- !query 8 output +col 2 +col1 + + +-- !query 9 +SHOW COLUMNS IN showcolumn2 IN showdb +-- !query 9 schema +struct +-- !query 9 output +month +price +qty +year + + +-- !query 10 +SHOW COLUMNS IN badtable FROM showdb +-- !query 10 schema +struct<> +-- !query 10 output +org.apache.spark.sql.catalyst.analysis.NoSuchTableException +Table or view 'badtable' not found in database 'showdb'; + + +-- !query 11 +SHOW COLUMNS IN showdb.showcolumn1 from SHOWDB +-- !query 11 schema +struct +-- !query 11 output +col 2 +col1 + + +-- !query 12 +SHOW COLUMNS IN showdb.showcolumn1 FROM baddb +-- !query 12 schema +struct<> +-- !query 12 output +org.apache.spark.sql.AnalysisException +SHOW COLUMNS with conflicting databases: 'baddb' != 'showdb'; + + +-- !query 13 +SHOW COLUMNS IN showcolumn3 +-- !query 13 schema +struct +-- !query 13 output +col 4 +col3 + + +-- !query 14 +SHOW COLUMNS IN showdb.showcolumn3 +-- !query 14 schema +struct<> +-- !query 14 output +org.apache.spark.sql.catalyst.analysis.NoSuchTableException +Table or view 'showcolumn3' not found in database 'showdb'; + + +-- !query 15 +SHOW COLUMNS IN showcolumn3 FROM showdb +-- !query 15 schema +struct<> +-- !query 15 output +org.apache.spark.sql.catalyst.analysis.NoSuchTableException +Table or view 'showcolumn3' not found in database 'showdb'; + + +-- !query 16 +SHOW COLUMNS IN showcolumn4 +-- !query 16 schema +struct<> +-- !query 16 output +org.apache.spark.sql.catalyst.analysis.NoSuchTableException +Table or view 'showcolumn4' not found in database 'showdb'; + + +-- !query 17 +SHOW COLUMNS IN global_temp.showcolumn4 +-- !query 17 schema +struct +-- !query 17 output +col 5 +col1 + + +-- !query 18 +SHOW COLUMNS IN showcolumn4 FROM global_temp +-- !query 18 schema +struct +-- !query 18 output +col 5 +col1 + + +-- !query 19 +DROP TABLE showcolumn1 +-- !query 19 schema +struct<> +-- !query 19 output + + + +-- !query 20 +DROP TABLE showColumn2 +-- 
!query 20 schema +struct<> +-- !query 20 output + + + +-- !query 21 +DROP VIEW showcolumn3 +-- !query 21 schema +struct<> +-- !query 21 output + + + +-- !query 22 +DROP VIEW global_temp.showcolumn4 +-- !query 22 schema +struct<> +-- !query 22 output + + + +-- !query 23 +use default +-- !query 23 schema +struct<> +-- !query 23 output + + + +-- !query 24 +DROP DATABASE showdb +-- !query 24 schema +struct<> +-- !query 24 output + diff --git a/sql/core/src/test/resources/structured-streaming/query-event-logs-version-2.0.0.txt b/sql/core/src/test/resources/structured-streaming/query-event-logs-version-2.0.0.txt new file mode 100644 index 0000000000000..aa7e9a8c20c43 --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/query-event-logs-version-2.0.0.txt @@ -0,0 +1,4 @@ +{"Event":"org.apache.spark.sql.streaming.StreamingQueryListener$QueryProgress","queryInfo":{"name":"hello","id":0,"sourceStatuses":[{"description":"FileStreamSource[file:/Users/zsx/stream]","offsetDesc":"#0"}],"sinkStatus":{"description":"org.apache.spark.sql.execution.streaming.MemorySink@2b85b3a5","offsetDesc":"[#0]"}}} +{"Event":"org.apache.spark.sql.streaming.StreamingQueryListener$QueryTerminated","queryInfo":{"name":"hello","id":0,"sourceStatuses":[{"description":"FileStreamSource[file:/Users/zsx/stream]","offsetDesc":"#0"}],"sinkStatus":{"description":"org.apache.spark.sql.execution.streaming.MemorySink@2b85b3a5","offsetDesc":"[#0]"}},"exception":null,"stackTrace":[]} +{"Event":"org.apache.spark.sql.streaming.StreamingQueryListener$QueryTerminated","queryInfo":{"name":"hello","id":0,"sourceStatuses":[{"description":"FileStreamSource[file:/Users/zsx/stream]","offsetDesc":"#0"}],"sinkStatus":{"description":"org.apache.spark.sql.execution.streaming.MemorySink@514502dc","offsetDesc":"[-]"}},"exception":"Query hello terminated with exception: Job aborted due to stage failure: Task 0 in stage 0.0 failed 1 times, most recent failure: Lost task 0.0 in stage 0.0 (TID 0, localhost): java.lang.ArithmeticException: / by zero\n\tat $line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1.apply(:25)\n\tat $line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1.apply(:25)\n\tat org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIterator.processNext(Unknown Source)\n\tat org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)\n\tat org.apache.spark.sql.execution.WholeStageCodegenExec$$anonfun$8$$anon$1.hasNext(WholeStageCodegenExec.scala:370)\n\tat org.apache.spark.sql.execution.SparkPlan$$anonfun$4.apply(SparkPlan.scala:246)\n\tat org.apache.spark.sql.execution.SparkPlan$$anonfun$4.apply(SparkPlan.scala:240)\n\tat org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$24.apply(RDD.scala:784)\n\tat org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$24.apply(RDD.scala:784)\n\tat org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:38)\n\tat org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:319)\n\tat org.apache.spark.rdd.RDD.iterator(RDD.scala:283)\n\tat org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:70)\n\tat org.apache.spark.scheduler.Task.run(Task.scala:85)\n\tat org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:274)\n\tat java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142)\n\tat java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617)\n\tat java.lang.Thread.run(Thread.java:745)\n\nDriver 
stacktrace:","stackTrace":[{"methodName":"org$apache$spark$sql$execution$streaming$StreamExecution$$runBatches","fileName":"StreamExecution.scala","lineNumber":208,"className":"org.apache.spark.sql.execution.streaming.StreamExecution","nativeMethod":false},{"methodName":"run","fileName":"StreamExecution.scala","lineNumber":120,"className":"org.apache.spark.sql.execution.streaming.StreamExecution$$anon$1","nativeMethod":false}]} +{"Event":"SparkListenerApplicationEnd","Timestamp":1477593059313} diff --git a/sql/core/src/test/resources/structured-streaming/query-event-logs-version-2.0.1.txt b/sql/core/src/test/resources/structured-streaming/query-event-logs-version-2.0.1.txt new file mode 100644 index 0000000000000..646cf107183b4 --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/query-event-logs-version-2.0.1.txt @@ -0,0 +1,4 @@ +{"Event":"org.apache.spark.sql.streaming.StreamingQueryListener$QueryProgress","queryInfo":{"name":"hello","id":0,"sourceStatuses":[{"description":"FileStreamSource[file:/Users/zsx/stream]","offsetDesc":"#0"}],"sinkStatus":{"description":"org.apache.spark.sql.execution.streaming.MemorySink@10e5ec94","offsetDesc":"[#0]"}}} +{"Event":"org.apache.spark.sql.streaming.StreamingQueryListener$QueryTerminated","queryInfo":{"name":"hello","id":0,"sourceStatuses":[{"description":"FileStreamSource[file:/Users/zsx/stream]","offsetDesc":"#0"}],"sinkStatus":{"description":"org.apache.spark.sql.execution.streaming.MemorySink@10e5ec94","offsetDesc":"[#0]"}},"exception":null} +{"Event":"org.apache.spark.sql.streaming.StreamingQueryListener$QueryTerminated","queryInfo":{"name":"hello","id":0,"sourceStatuses":[{"description":"FileStreamSource[file:/Users/zsx/stream]","offsetDesc":"#0"}],"sinkStatus":{"description":"org.apache.spark.sql.execution.streaming.MemorySink@70c61dc8","offsetDesc":"[-]"}},"exception":"org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in stage 0.0 failed 1 times, most recent failure: Lost task 0.0 in stage 0.0 (TID 0, localhost): java.lang.ArithmeticException: / by zero\n\tat $line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1.apply(:25)\n\tat $line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1.apply(:25)\n\tat org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIterator.processNext(Unknown Source)\n\tat org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)\n\tat org.apache.spark.sql.execution.WholeStageCodegenExec$$anonfun$8$$anon$1.hasNext(WholeStageCodegenExec.scala:370)\n\tat org.apache.spark.sql.execution.SparkPlan$$anonfun$4.apply(SparkPlan.scala:246)\n\tat org.apache.spark.sql.execution.SparkPlan$$anonfun$4.apply(SparkPlan.scala:240)\n\tat org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$24.apply(RDD.scala:803)\n\tat org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$24.apply(RDD.scala:803)\n\tat org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:38)\n\tat org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:319)\n\tat org.apache.spark.rdd.RDD.iterator(RDD.scala:283)\n\tat org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:70)\n\tat org.apache.spark.scheduler.Task.run(Task.scala:86)\n\tat org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:274)\n\tat java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142)\n\tat java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617)\n\tat 
java.lang.Thread.run(Thread.java:745)\n\nDriver stacktrace:\n\tat org.apache.spark.scheduler.DAGScheduler.org$apache$spark$scheduler$DAGScheduler$$failJobAndIndependentStages(DAGScheduler.scala:1454)\n\tat org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:1442)\n\tat org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:1441)\n\tat scala.collection.mutable.ResizableArray$class.foreach(ResizableArray.scala:59)\n\tat scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:48)\n\tat org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:1441)\n\tat org.apache.spark.scheduler.DAGScheduler$$anonfun$handleTaskSetFailed$1.apply(DAGScheduler.scala:811)\n\tat org.apache.spark.scheduler.DAGScheduler$$anonfun$handleTaskSetFailed$1.apply(DAGScheduler.scala:811)\n\tat scala.Option.foreach(Option.scala:257)\n\tat org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:811)\n\tat org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:1667)\n\tat org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:1622)\n\tat org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:1611)\n\tat org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:48)\n\tat org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:632)\n\tat org.apache.spark.SparkContext.runJob(SparkContext.scala:1890)\n\tat org.apache.spark.SparkContext.runJob(SparkContext.scala:1903)\n\tat org.apache.spark.SparkContext.runJob(SparkContext.scala:1916)\n\tat org.apache.spark.SparkContext.runJob(SparkContext.scala:1930)\n\tat org.apache.spark.rdd.RDD$$anonfun$collect$1.apply(RDD.scala:912)\n\tat org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)\n\tat org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:112)\n\tat org.apache.spark.rdd.RDD.withScope(RDD.scala:358)\n\tat org.apache.spark.rdd.RDD.collect(RDD.scala:911)\n\tat org.apache.spark.sql.execution.SparkPlan.executeCollect(SparkPlan.scala:290)\n\tat org.apache.spark.sql.Dataset$$anonfun$org$apache$spark$sql$Dataset$$execute$1$1.apply(Dataset.scala:2193)\n\tat org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId(SQLExecution.scala:57)\n\tat org.apache.spark.sql.Dataset.withNewExecutionId(Dataset.scala:2546)\n\tat org.apache.spark.sql.Dataset.org$apache$spark$sql$Dataset$$execute$1(Dataset.scala:2192)\n\tat org.apache.spark.sql.Dataset$$anonfun$org$apache$spark$sql$Dataset$$collect$1.apply(Dataset.scala:2197)\n\tat org.apache.spark.sql.Dataset$$anonfun$org$apache$spark$sql$Dataset$$collect$1.apply(Dataset.scala:2197)\n\tat org.apache.spark.sql.Dataset.withCallback(Dataset.scala:2559)\n\tat org.apache.spark.sql.Dataset.org$apache$spark$sql$Dataset$$collect(Dataset.scala:2197)\n\tat org.apache.spark.sql.Dataset.collect(Dataset.scala:2173)\n\tat org.apache.spark.sql.execution.streaming.MemorySink.addBatch(memory.scala:154)\n\tat org.apache.spark.sql.execution.streaming.StreamExecution.org$apache$spark$sql$execution$streaming$StreamExecution$$runBatch(StreamExecution.scala:366)\n\tat org.apache.spark.sql.execution.streaming.StreamExecution$$anonfun$org$apache$spark$sql$execution$streaming$StreamExecution$$runBatches$1.apply$mcZ$sp(StreamExecution.scala:197)\n\tat org.apache.spark.sql.execution.streaming.ProcessingTimeExecutor.execute(TriggerExecutor.scala:43)\n\tat 
org.apache.spark.sql.execution.streaming.StreamExecution.org$apache$spark$sql$execution$streaming$StreamExecution$$runBatches(StreamExecution.scala:187)\n\tat org.apache.spark.sql.execution.streaming.StreamExecution$$anon$1.run(StreamExecution.scala:124)\nCaused by: java.lang.ArithmeticException: / by zero\n\tat $line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1.apply(:25)\n\tat $line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1.apply(:25)\n\tat org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIterator.processNext(Unknown Source)\n\tat org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)\n\tat org.apache.spark.sql.execution.WholeStageCodegenExec$$anonfun$8$$anon$1.hasNext(WholeStageCodegenExec.scala:370)\n\tat org.apache.spark.sql.execution.SparkPlan$$anonfun$4.apply(SparkPlan.scala:246)\n\tat org.apache.spark.sql.execution.SparkPlan$$anonfun$4.apply(SparkPlan.scala:240)\n\tat org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$24.apply(RDD.scala:803)\n\tat org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$24.apply(RDD.scala:803)\n\tat org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:38)\n\tat org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:319)\n\tat org.apache.spark.rdd.RDD.iterator(RDD.scala:283)\n\tat org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:70)\n\tat org.apache.spark.scheduler.Task.run(Task.scala:86)\n\tat org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:274)\n\tat java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142)\n\tat java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617)\n\tat java.lang.Thread.run(Thread.java:745)\n"} +{"Event":"SparkListenerApplicationEnd","Timestamp":1477701734609} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 45db61515e9b6..586a0fffeb7a1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -273,7 +273,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { test("sort_array function") { val df = Seq( (Array[Int](2, 1, 3), Array("b", "c", "a")), - (Array[Int](), Array[String]()), + (Array.empty[Int], Array.empty[String]), (null, null) ).toDF("a", "b") checkAnswer( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index 73026c749db45..1383208874a19 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -150,6 +150,10 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext { assert(math.abs(d1 - 2 * q1 * n) < error_double) assert(math.abs(d2 - 2 * q2 * n) < error_double) } + // test approxQuantile on NaN values + val dfNaN = Seq(Double.NaN, 1.0, Double.NaN, Double.NaN).toDF("input") + val resNaN = dfNaN.stat.approxQuantile("input", Array(q1, q2), epsilons.head) + assert(resNaN.count(_.isNaN) === 0) } test("crosstab") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 16cc368208485..33b3b78c9f04f 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql import java.io.File import java.nio.charset.StandardCharsets +import java.sql.{Date, Timestamp} import java.util.UUID import scala.util.Random @@ -1598,6 +1599,24 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { assert(df.persist.take(1).apply(0).toSeq(100).asInstanceOf[Long] == 100) } + test("SPARK-17409: Do Not Optimize Query in CTAS (Data source tables) More Than Once") { + withTable("bar") { + withTempView("foo") { + withSQLConf(SQLConf.DEFAULT_DATA_SOURCE_NAME.key -> "json") { + sql("select 0 as id").createOrReplaceTempView("foo") + val df = sql("select * from foo group by id") + // If we optimize the query in CTAS more than once, the following saveAsTable will fail + // with the error: `GROUP BY position 0 is not in select list (valid range is [1, 1])` + df.write.mode("overwrite").saveAsTable("bar") + checkAnswer(spark.table("bar"), Row(0) :: Nil) + val tableMetadata = spark.sessionState.catalog.getTableMetadata(TableIdentifier("bar")) + assert(tableMetadata.provider == Some("json"), + "the expected table is a data source table using json") + } + } + } + } + test("copy results for sampling with replacement") { val df = Seq((1, 0), (2, 0), (3, 0)).toDF("a", "b") val sampleDf = df.sample(true, 2.00) @@ -1615,4 +1634,28 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { qe.assertAnalyzed() } } + + test("SPARK-17123: Performing set operations that combine non-scala native types") { + val dates = Seq( + (new Date(0), BigDecimal.valueOf(1), new Timestamp(2)), + (new Date(3), BigDecimal.valueOf(4), new Timestamp(5)) + ).toDF("date", "timestamp", "decimal") + + val widenTypedRows = Seq( + (new Timestamp(2), 10.5D, "string") + ).toDF("date", "timestamp", "decimal") + + dates.union(widenTypedRows).collect() + dates.except(widenTypedRows).collect() + dates.intersect(widenTypedRows).collect() + } + + test("SPARK-18070 binary operator should not consider nullability when comparing input types") { + val rows = Seq(Row(Seq(1), Seq(1))) + val schema = new StructType() + .add("array1", ArrayType(IntegerType)) + .add("array2", ArrayType(IntegerType, containsNull = false)) + val df = spark.createDataFrame(spark.sparkContext.makeRDD(rows), schema) + assert(df.filter($"array1" === $"array2").count() == 1) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala index 8d5e9645df894..e0561ee2797a5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala @@ -19,11 +19,32 @@ package org.apache.spark.sql import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.storage.StorageLevel class DatasetCacheSuite extends QueryTest with SharedSQLContext { import testImplicits._ + test("get storage level") { + val ds1 = Seq("1", "2").toDS().as("a") + val ds2 = Seq(2, 3).toDS().as("b") + + // default storage level + ds1.persist() + ds2.cache() + assert(ds1.storageLevel == StorageLevel.MEMORY_AND_DISK) + assert(ds2.storageLevel == StorageLevel.MEMORY_AND_DISK) + // unpersist + ds1.unpersist() + assert(ds1.storageLevel == StorageLevel.NONE) + // non-default storage level + ds1.persist(StorageLevel.MEMORY_ONLY_2) + assert(ds1.storageLevel == 
StorageLevel.MEMORY_ONLY_2) + // joined Dataset should not be persisted + val joined = ds1.joinWith(ds2, $"a.value" === $"b.value") + assert(joined.storageLevel == StorageLevel.NONE) + } + test("persist and unpersist") { val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS().select(expr("_2 + 1").as[Int]) val cached = ds.cache() @@ -37,8 +58,7 @@ class DatasetCacheSuite extends QueryTest with SharedSQLContext { 2, 3, 4) // Drop the cache. cached.unpersist() - assert(spark.sharedState.cacheManager.lookupCachedData(cached).isEmpty, - "The Dataset should not be cached.") + assert(cached.storageLevel == StorageLevel.NONE, "The Dataset should not be cached.") } test("persist and then rebind right encoder when join 2 datasets") { @@ -55,11 +75,9 @@ class DatasetCacheSuite extends QueryTest with SharedSQLContext { assertCached(joined, 2) ds1.unpersist() - assert(spark.sharedState.cacheManager.lookupCachedData(ds1).isEmpty, - "The Dataset ds1 should not be cached.") + assert(ds1.storageLevel == StorageLevel.NONE, "The Dataset ds1 should not be cached.") ds2.unpersist() - assert(spark.sharedState.cacheManager.lookupCachedData(ds2).isEmpty, - "The Dataset ds2 should not be cached.") + assert(ds2.storageLevel == StorageLevel.NONE, "The Dataset ds2 should not be cached.") } test("persist and then groupBy columns asKey, map") { @@ -74,10 +92,8 @@ class DatasetCacheSuite extends QueryTest with SharedSQLContext { assertCached(agged.filter(_._1 == "b")) ds.unpersist() - assert(spark.sharedState.cacheManager.lookupCachedData(ds).isEmpty, - "The Dataset ds should not be cached.") + assert(ds.storageLevel == StorageLevel.NONE, "The Dataset ds should not be cached.") agged.unpersist() - assert(spark.sharedState.cacheManager.lookupCachedData(agged).isEmpty, - "The Dataset agged should not be cached.") + assert(agged.storageLevel == StorageLevel.NONE, "The Dataset agged should not be cached.") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 5fce9b4fe97ea..cc367acae2ba4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -336,6 +336,17 @@ class DatasetSuite extends QueryTest with SharedSQLContext { "a", "30", "b", "3", "c", "1") } + test("groupBy function, mapValues, flatMap") { + val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() + val keyValue = ds.groupByKey(_._1).mapValues(_._2) + val agged = keyValue.mapGroups { case (g, iter) => (g, iter.sum) } + checkDataset(agged, ("a", 30), ("b", 3), ("c", 1)) + + val keyValue1 = ds.groupByKey(t => (t._1, "key")).mapValues(t => (t._2, "value")) + val agged1 = keyValue1.mapGroups { case (g, iter) => (g._1, iter.map(_._1).sum) } + checkDataset(agged, ("a", 30), ("b", 3), ("c", 1)) + } + test("groupBy function, reduce") { val ds = Seq("abc", "xyz", "hello").toDS() val agged = ds.groupByKey(_.length).reduceGroups(_ + _) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 0ee8c959eeb4d..1a43d0b2205ca 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -19,12 +19,9 @@ package org.apache.spark.sql import java.io.File import java.math.MathContext -import java.sql.{Date, Timestamp} +import java.sql.Timestamp import org.apache.spark.{AccumulatorSuite, SparkException} 
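Illustrative sketch (not part of the patch): the DatasetCacheSuite and DatasetSuite additions above exercise two public APIs, Dataset.storageLevel and KeyValueGroupedDataset.mapValues. A minimal sketch of their expected behaviour, assuming an existing SparkSession named `spark` (e.g. in spark-shell):

    import org.apache.spark.storage.StorageLevel
    import spark.implicits._

    // Dataset.storageLevel reports the level set by persist/cache, or NONE when uncached.
    val ds = spark.range(10)
    ds.persist(StorageLevel.MEMORY_ONLY_2)
    assert(ds.storageLevel == StorageLevel.MEMORY_ONLY_2)
    ds.unpersist()
    assert(ds.storageLevel == StorageLevel.NONE)

    // KeyValueGroupedDataset.mapValues transforms the grouped values before aggregation.
    val kv = Seq(("a", 10), ("a", 20), ("b", 1)).toDS().groupByKey(_._1).mapValues(_._2)
    kv.mapGroups { case (k, it) => (k, it.sum) }.collect()   // Array(("a", 30), ("b", 1))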
-import org.apache.spark.sql.catalyst.analysis.UnresolvedException -import org.apache.spark.sql.catalyst.expressions.SortOrder -import org.apache.spark.sql.catalyst.plans.logical.Aggregate import org.apache.spark.sql.catalyst.util.StringUtils import org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, CartesianProductExec, SortMergeJoinExec} @@ -1106,6 +1103,30 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { ) } + test("SPARK-17863: SELECT distinct does not work correctly if order by missing attribute") { + checkAnswer( + sql("""select distinct struct.a, struct.b + |from ( + | select named_struct('a', 1, 'b', 2, 'c', 3) as struct + | union all + | select named_struct('a', 1, 'b', 2, 'c', 4) as struct) tmp + |order by a, b + |""".stripMargin), + Row(1, 2) :: Nil) + + val error = intercept[AnalysisException] { + sql("""select distinct struct.a, struct.b + |from ( + | select named_struct('a', 1, 'b', 2, 'c', 3) as struct + | union all + | select named_struct('a', 1, 'b', 2, 'c', 4) as struct) tmp + |order by struct.a, struct.b + |""".stripMargin) + } + assert(error.message contains "cannot resolve '`struct.a`' given input columns: [a, b]") + + } + test("cast boolean to string") { // TODO Ensure true/false string letter casing is consistent with Hive in all cases. checkAnswer( @@ -1984,195 +2005,6 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { Row(false) :: Row(true) :: Nil) } - test("rollup") { - checkAnswer( - sql("select course, year, sum(earnings) from courseSales group by rollup(course, year)" + - " order by course, year"), - Row(null, null, 113000.0) :: - Row("Java", null, 50000.0) :: - Row("Java", 2012, 20000.0) :: - Row("Java", 2013, 30000.0) :: - Row("dotNET", null, 63000.0) :: - Row("dotNET", 2012, 15000.0) :: - Row("dotNET", 2013, 48000.0) :: Nil - ) - } - - test("grouping sets when aggregate functions containing groupBy columns") { - checkAnswer( - sql("select course, sum(earnings) as sum from courseSales group by course, earnings " + - "grouping sets((), (course), (course, earnings)) " + - "order by course, sum"), - Row(null, 113000.0) :: - Row("Java", 20000.0) :: - Row("Java", 30000.0) :: - Row("Java", 50000.0) :: - Row("dotNET", 5000.0) :: - Row("dotNET", 10000.0) :: - Row("dotNET", 48000.0) :: - Row("dotNET", 63000.0) :: Nil - ) - - checkAnswer( - sql("select course, sum(earnings) as sum, grouping_id(course, earnings) from courseSales " + - "group by course, earnings grouping sets((), (course), (course, earnings)) " + - "order by course, sum"), - Row(null, 113000.0, 3) :: - Row("Java", 20000.0, 0) :: - Row("Java", 30000.0, 0) :: - Row("Java", 50000.0, 1) :: - Row("dotNET", 5000.0, 0) :: - Row("dotNET", 10000.0, 0) :: - Row("dotNET", 48000.0, 0) :: - Row("dotNET", 63000.0, 1) :: Nil - ) - } - - test("cube") { - checkAnswer( - sql("select course, year, sum(earnings) from courseSales group by cube(course, year)"), - Row("Java", 2012, 20000.0) :: - Row("Java", 2013, 30000.0) :: - Row("Java", null, 50000.0) :: - Row("dotNET", 2012, 15000.0) :: - Row("dotNET", 2013, 48000.0) :: - Row("dotNET", null, 63000.0) :: - Row(null, 2012, 35000.0) :: - Row(null, 2013, 78000.0) :: - Row(null, null, 113000.0) :: Nil - ) - } - - test("grouping sets") { - checkAnswer( - sql("select course, year, sum(earnings) from courseSales group by course, year " + - "grouping sets(course, year)"), - Row("Java", null, 50000.0) :: - Row("dotNET", null, 63000.0) :: - Row(null, 2012, 35000.0) :: - Row(null, 2013, 78000.0) :: 
Nil - ) - - checkAnswer( - sql("select course, year, sum(earnings) from courseSales group by course, year " + - "grouping sets(course)"), - Row("Java", null, 50000.0) :: - Row("dotNET", null, 63000.0) :: Nil - ) - - checkAnswer( - sql("select course, year, sum(earnings) from courseSales group by course, year " + - "grouping sets(year)"), - Row(null, 2012, 35000.0) :: - Row(null, 2013, 78000.0) :: Nil - ) - } - - test("grouping and grouping_id") { - checkAnswer( - sql("select course, year, grouping(course), grouping(year), grouping_id(course, year)" + - " from courseSales group by cube(course, year)"), - Row("Java", 2012, 0, 0, 0) :: - Row("Java", 2013, 0, 0, 0) :: - Row("Java", null, 0, 1, 1) :: - Row("dotNET", 2012, 0, 0, 0) :: - Row("dotNET", 2013, 0, 0, 0) :: - Row("dotNET", null, 0, 1, 1) :: - Row(null, 2012, 1, 0, 2) :: - Row(null, 2013, 1, 0, 2) :: - Row(null, null, 1, 1, 3) :: Nil - ) - - var error = intercept[AnalysisException] { - sql("select course, year, grouping(course) from courseSales group by course, year") - } - assert(error.getMessage contains "grouping() can only be used with GroupingSets/Cube/Rollup") - error = intercept[AnalysisException] { - sql("select course, year, grouping_id(course, year) from courseSales group by course, year") - } - assert(error.getMessage contains "grouping_id() can only be used with GroupingSets/Cube/Rollup") - error = intercept[AnalysisException] { - sql("select course, year, grouping__id from courseSales group by cube(course, year)") - } - assert(error.getMessage contains "grouping__id is deprecated; use grouping_id() instead") - } - - test("grouping and grouping_id in having") { - checkAnswer( - sql("select course, year from courseSales group by cube(course, year)" + - " having grouping(year) = 1 and grouping_id(course, year) > 0"), - Row("Java", null) :: - Row("dotNET", null) :: - Row(null, null) :: Nil - ) - - var error = intercept[AnalysisException] { - sql("select course, year from courseSales group by course, year" + - " having grouping(course) > 0") - } - assert(error.getMessage contains - "grouping()/grouping_id() can only be used with GroupingSets/Cube/Rollup") - error = intercept[AnalysisException] { - sql("select course, year from courseSales group by course, year" + - " having grouping_id(course, year) > 0") - } - assert(error.getMessage contains - "grouping()/grouping_id() can only be used with GroupingSets/Cube/Rollup") - error = intercept[AnalysisException] { - sql("select course, year from courseSales group by cube(course, year)" + - " having grouping__id > 0") - } - assert(error.getMessage contains "grouping__id is deprecated; use grouping_id() instead") - } - - test("grouping and grouping_id in sort") { - checkAnswer( - sql("select course, year, grouping(course), grouping(year) from courseSales" + - " group by cube(course, year) order by grouping_id(course, year), course, year"), - Row("Java", 2012, 0, 0) :: - Row("Java", 2013, 0, 0) :: - Row("dotNET", 2012, 0, 0) :: - Row("dotNET", 2013, 0, 0) :: - Row("Java", null, 0, 1) :: - Row("dotNET", null, 0, 1) :: - Row(null, 2012, 1, 0) :: - Row(null, 2013, 1, 0) :: - Row(null, null, 1, 1) :: Nil - ) - - checkAnswer( - sql("select course, year, grouping_id(course, year) from courseSales" + - " group by cube(course, year) order by grouping(course), grouping(year), course, year"), - Row("Java", 2012, 0) :: - Row("Java", 2013, 0) :: - Row("dotNET", 2012, 0) :: - Row("dotNET", 2013, 0) :: - Row("Java", null, 1) :: - Row("dotNET", null, 1) :: - Row(null, 2012, 2) :: - Row(null, 2013, 2) 
:: - Row(null, null, 3) :: Nil - ) - - var error = intercept[AnalysisException] { - sql("select course, year from courseSales group by course, year" + - " order by grouping(course)") - } - assert(error.getMessage contains - "grouping()/grouping_id() can only be used with GroupingSets/Cube/Rollup") - error = intercept[AnalysisException] { - sql("select course, year from courseSales group by course, year" + - " order by grouping_id(course, year)") - } - assert(error.getMessage contains - "grouping()/grouping_id() can only be used with GroupingSets/Cube/Rollup") - error = intercept[AnalysisException] { - sql("select course, year from courseSales group by cube(course, year)" + - " order by grouping__id") - } - assert(error.getMessage contains "grouping__id is deprecated; use grouping_id() instead") - } - test("filter on a grouping column that is not presented in SELECT") { checkAnswer( sql("select count(1) from (select 1 as a) t group by a having a > 0"), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala index 55d5a56f1040a..2d73d9f1fc802 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.catalyst.util.{fileToString, stringToFile} +import org.apache.spark.sql.execution.command.ShowColumnsCommand import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.StructType @@ -196,7 +197,7 @@ class SQLQueryTestSuite extends QueryTest with SharedSQLContext { assertResult(expected.schema, s"Schema did not match for query #$i\n${expected.sql}") { output.schema } - assertResult(expected.output, s"Result dit not match for query #$i\n${expected.sql}") { + assertResult(expected.output, s"Result did not match for query #$i\n${expected.sql}") { output.output } } @@ -220,6 +221,9 @@ class SQLQueryTestSuite extends QueryTest with SharedSQLContext { if (isSorted(df.queryExecution.analyzed)) (schema, answer) else (schema, answer.sorted) } catch { + case a: AnalysisException if a.plan.nonEmpty => + // Do not output the logical plan tree which contains expression IDs. + (StructType(Seq.empty), Seq(a.getClass.getName, a.getSimpleMessage)) case NonFatal(e) => // If there is an exception, put the exception class followed by the message. 
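Illustrative sketch (not part of the patch): the rollup/cube/grouping-sets coverage removed from SQLQuerySuite above boils down to standard GROUP BY ROLLUP/CUBE/GROUPING SETS semantics. A compressed example, assuming the same `courseSales` temporary view (course, year, earnings) used by the original tests:

    // Per-(course, year) rows, per-course subtotals and a grand total in one query;
    // grouping_id() identifies which grouping columns were rolled up in each output row.
    spark.sql(
      """SELECT course, year, SUM(earnings), grouping_id(course, year)
        |FROM courseSales
        |GROUP BY ROLLUP(course, year)
        |ORDER BY grouping_id(course, year), course, year""".stripMargin).show()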
(StructType(Seq.empty), Seq(e.getClass.getName, e.getMessage)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsColumnSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsColumnSuite.scala index 0ee0547c45591..f1a201abd8da6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsColumnSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsColumnSuite.scala @@ -150,7 +150,7 @@ class StatisticsColumnSuite extends StatisticsTest { val colStat = ColumnStat(InternalRow( values.count(_.isEmpty).toLong, nonNullValues.map(_.length).sum / nonNullValues.length.toDouble, - nonNullValues.map(_.length).max.toLong, + nonNullValues.map(_.length).max.toInt, nonNullValues.distinct.length.toLong)) (f, colStat) } @@ -165,7 +165,7 @@ class StatisticsColumnSuite extends StatisticsTest { val colStat = ColumnStat(InternalRow( values.count(_.isEmpty).toLong, nonNullValues.map(_.length).sum / nonNullValues.length.toDouble, - nonNullValues.map(_.length).max.toLong)) + nonNullValues.map(_.length).max.toInt)) (f, colStat) } checkColStats(df, expectedColStatsSeq) @@ -255,10 +255,10 @@ class StatisticsColumnSuite extends StatisticsTest { doubleSeq.distinct.length.toLong)) case StringType => ColumnStat(InternalRow(0L, stringSeq.map(_.length).sum / stringSeq.length.toDouble, - stringSeq.map(_.length).max.toLong, stringSeq.distinct.length.toLong)) + stringSeq.map(_.length).max.toInt, stringSeq.distinct.length.toLong)) case BinaryType => ColumnStat(InternalRow(0L, binarySeq.map(_.length).sum / binarySeq.length.toDouble, - binarySeq.map(_.length).max.toLong)) + binarySeq.map(_.length).max.toInt)) case BooleanType => ColumnStat(InternalRow(0L, booleanSeq.count(_.equals(true)).toLong, booleanSeq.count(_.equals(false)).toLong)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ReferenceSort.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ReferenceSort.scala index a19ea51af7c01..6abcb1f067968 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ReferenceSort.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ReferenceSort.scala @@ -57,4 +57,6 @@ case class ReferenceSort( override def output: Seq[Attribute] = child.output override def outputOrdering: Seq[SortOrder] = sortOrder + + override def outputPartitioning: Partitioning = child.outputPartitioning } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala index 679150e9ae4c0..797fe9ffa8be1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala @@ -23,8 +23,8 @@ import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.execution.command.{DescribeFunctionCommand, DescribeTableCommand, - ShowFunctionsCommand} +import org.apache.spark.sql.execution.command.{AnalyzeTableCommand, DescribeFunctionCommand, + DescribeTableCommand, ShowFunctionsCommand} import org.apache.spark.sql.execution.datasources.{CreateTable, CreateTempViewUsing} import org.apache.spark.sql.internal.{HiveSerDe, SQLConf} import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructType} @@ -220,4 +220,18 @@ 
class SparkSqlParserSuite extends PlanTest { intercept("explain describe tables x", "Unsupported SQL statement") } + + test("SPARK-18106 analyze table") { + assertEqual("analyze table t compute statistics", + AnalyzeTableCommand(TableIdentifier("t"), noscan = false)) + assertEqual("analyze table t compute statistics noscan", + AnalyzeTableCommand(TableIdentifier("t"), noscan = true)) + assertEqual("analyze table t partition (a) compute statistics noscan", + AnalyzeTableCommand(TableIdentifier("t"), noscan = true)) + + intercept("analyze table t compute statistics xxxx", + "Expected `NOSCAN` instead of `xxxx`") + intercept("analyze table t partition (a) compute statistics xxxx", + "Expected `NOSCAN` instead of `xxxx`") + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala index 547fb63813750..d31e7aeb3a78a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala @@ -387,20 +387,22 @@ class DDLCommandSuite extends PlanTest { val parsed_table = parser.parsePlan(sql_table) val parsed_view = parser.parsePlan(sql_view) val expected_table = AlterTableRenameCommand( - TableIdentifier("table_name", None), - "new_table_name", + TableIdentifier("table_name"), + TableIdentifier("new_table_name"), isView = false) val expected_view = AlterTableRenameCommand( - TableIdentifier("table_name", None), - "new_table_name", + TableIdentifier("table_name"), + TableIdentifier("new_table_name"), isView = true) comparePlans(parsed_table, expected_table) comparePlans(parsed_view, expected_view) + } - val e = intercept[ParseException]( - parser.parsePlan("ALTER TABLE db1.tbl RENAME TO db1.tbl2") - ) - assert(e.getMessage.contains("Can not specify database in table/view name after RENAME TO")) + test("alter table: rename table with database") { + val query = "ALTER TABLE db1.tbl RENAME TO db1.tbl2" + val plan = parseAs[AlterTableRenameCommand](query) + assert(plan.oldName == TableIdentifier("tbl", Some("db1"))) + assert(plan.newName == TableIdentifier("tbl2", Some("db1"))) } // ALTER TABLE table_name SET TBLPROPERTIES ('comment' = new_comment); @@ -822,22 +824,24 @@ class DDLCommandSuite extends PlanTest { val sql1 = "SHOW COLUMNS FROM t1" val sql2 = "SHOW COLUMNS IN db1.t1" val sql3 = "SHOW COLUMNS FROM t1 IN db1" - val sql4 = "SHOW COLUMNS FROM db1.t1 IN db1" - val sql5 = "SHOW COLUMNS FROM db1.t1 IN db2" + val sql4 = "SHOW COLUMNS FROM db1.t1 IN db2" val parsed1 = parser.parsePlan(sql1) - val expected1 = ShowColumnsCommand(TableIdentifier("t1", None)) + val expected1 = ShowColumnsCommand(None, TableIdentifier("t1", None)) val parsed2 = parser.parsePlan(sql2) - val expected2 = ShowColumnsCommand(TableIdentifier("t1", Some("db1"))) + val expected2 = ShowColumnsCommand(None, TableIdentifier("t1", Some("db1"))) val parsed3 = parser.parsePlan(sql3) - val parsed4 = parser.parsePlan(sql3) + val expected3 = ShowColumnsCommand(Some("db1"), TableIdentifier("t1", None)) + val parsed4 = parser.parsePlan(sql4) + val expected4 = ShowColumnsCommand(Some("db2"), TableIdentifier("t1", Some("db1"))) + comparePlans(parsed1, expected1) comparePlans(parsed2, expected2) - comparePlans(parsed3, expected2) - comparePlans(parsed4, expected2) - assertUnsupported(sql5) + comparePlans(parsed3, expected3) + comparePlans(parsed4, expected4) } + test("show partitions") { val sql1 = 
"SHOW PARTITIONS t1" val sql2 = "SHOW PARTITIONS db1.t1" diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index 097dc2441351f..9fb0f5384d889 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -43,8 +43,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { // drop all databases, tables and functions after each test spark.sessionState.catalog.reset() } finally { - val path = System.getProperty("user.dir") + "/spark-warehouse" - Utils.deleteRecursively(new File(path)) + Utils.deleteRecursively(new File("spark-warehouse")) super.afterEach() } } @@ -96,7 +95,8 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { .add("b", "int"), provider = Some("hive"), partitionColumnNames = Seq("a", "b"), - createTime = 0L) + createTime = 0L, + partitionProviderIsHive = true) } private def createTable(catalog: SessionCatalog, name: TableIdentifier): Unit = { @@ -116,7 +116,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { val catalog = spark.sessionState.catalog withTempDir { tmpDir => - val path = tmpDir.toString + val path = tmpDir.getCanonicalPath // The generated temp path is not qualified. assert(!path.startsWith("file:/")) val uri = tmpDir.toURI @@ -148,7 +148,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { test("Create/Drop Database") { withTempDir { tmpDir => - val path = tmpDir.toString + val path = tmpDir.getCanonicalPath withSQLConf(SQLConf.WAREHOUSE_PATH.key -> path) { val catalog = spark.sessionState.catalog val databaseNames = Seq("db1", "`database`") @@ -159,7 +159,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { sql(s"CREATE DATABASE $dbName") val db1 = catalog.getDatabaseMetadata(dbNameWithoutBackTicks) - val expectedLocation = makeQualifiedPath(path + "/" + s"$dbNameWithoutBackTicks.db") + val expectedLocation = makeQualifiedPath(s"$path/$dbNameWithoutBackTicks.db") assert(db1 == CatalogDatabase( dbNameWithoutBackTicks, "", @@ -184,9 +184,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { try { sql(s"CREATE DATABASE $dbName") val db1 = catalog.getDatabaseMetadata(dbName) - val expectedLocation = - makeQualifiedPath(s"${System.getProperty("user.dir")}/spark-warehouse" + - "/" + s"$dbName.db") + val expectedLocation = makeQualifiedPath(s"spark-warehouse/$dbName.db") assert(db1 == CatalogDatabase( dbName, "", @@ -204,7 +202,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { val catalog = spark.sessionState.catalog val databaseNames = Seq("db1", "`database`") withTempDir { tmpDir => - val path = new Path(tmpDir.toString).toUri.toString + val path = new Path(tmpDir.getCanonicalPath).toUri databaseNames.foreach { dbName => try { val dbNameWithoutBackTicks = cleanIdentifier(dbName) @@ -227,7 +225,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { test("Create Database - database already exists") { withTempDir { tmpDir => - val path = tmpDir.toString + val path = tmpDir.getCanonicalPath withSQLConf(SQLConf.WAREHOUSE_PATH.key -> path) { val catalog = spark.sessionState.catalog val databaseNames = Seq("db1", "`database`") @@ -237,7 +235,7 @@ class DDLSuite extends QueryTest with 
SharedSQLContext with BeforeAndAfterEach { val dbNameWithoutBackTicks = cleanIdentifier(dbName) sql(s"CREATE DATABASE $dbName") val db1 = catalog.getDatabaseMetadata(dbNameWithoutBackTicks) - val expectedLocation = makeQualifiedPath(path + "/" + s"$dbNameWithoutBackTicks.db") + val expectedLocation = makeQualifiedPath(s"$path/$dbNameWithoutBackTicks.db") assert(db1 == CatalogDatabase( dbNameWithoutBackTicks, "", @@ -476,7 +474,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { test("Alter/Describe Database") { withTempDir { tmpDir => - val path = tmpDir.toString + val path = tmpDir.getCanonicalPath withSQLConf(SQLConf.WAREHOUSE_PATH.key -> path) { val catalog = spark.sessionState.catalog val databaseNames = Seq("db1", "`database`") @@ -484,7 +482,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { databaseNames.foreach { dbName => try { val dbNameWithoutBackTicks = cleanIdentifier(dbName) - val location = makeQualifiedPath(path + "/" + s"$dbNameWithoutBackTicks.db") + val location = makeQualifiedPath(s"$path/$dbNameWithoutBackTicks.db") sql(s"CREATE DATABASE $dbName") @@ -665,16 +663,27 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { createDatabase(catalog, "dbx") createDatabase(catalog, "dby") createTable(catalog, tableIdent1) + assert(catalog.listTables("dbx") == Seq(tableIdent1)) - sql("ALTER TABLE dbx.tab1 RENAME TO tab2") + sql("ALTER TABLE dbx.tab1 RENAME TO dbx.tab2") assert(catalog.listTables("dbx") == Seq(tableIdent2)) + + // The database in destination table name can be omitted, and we will use the database of source + // table for it. + sql("ALTER TABLE dbx.tab2 RENAME TO tab1") + assert(catalog.listTables("dbx") == Seq(tableIdent1)) + catalog.setCurrentDatabase("dbx") // rename without explicitly specifying database - sql("ALTER TABLE tab2 RENAME TO tab1") - assert(catalog.listTables("dbx") == Seq(tableIdent1)) + sql("ALTER TABLE tab1 RENAME TO tab2") + assert(catalog.listTables("dbx") == Seq(tableIdent2)) // table to rename does not exist intercept[AnalysisException] { - sql("ALTER TABLE dbx.does_not_exist RENAME TO tab2") + sql("ALTER TABLE dbx.does_not_exist RENAME TO dbx.tab2") + } + // destination database is different + intercept[AnalysisException] { + sql("ALTER TABLE dbx.tab1 RENAME TO dby.tab2") } } @@ -696,6 +705,31 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { assert(spark.table("teachers").collect().toSeq == df.collect().toSeq) } + test("rename temporary table - destination table with database name") { + withTempView("tab1") { + sql( + """ + |CREATE TEMPORARY TABLE tab1 + |USING org.apache.spark.sql.sources.DDLScanSource + |OPTIONS ( + | From '1', + | To '10', + | Table 'test1' + |) + """.stripMargin) + + val e = intercept[AnalysisException] { + sql("ALTER TABLE tab1 RENAME TO default.tab2") + } + assert(e.getMessage.contains( + "RENAME TEMPORARY TABLE from '`tab1`' to '`default`.`tab2`': " + + "cannot specify database name 'default' in the destination table")) + + val catalog = spark.sessionState.catalog + assert(catalog.listTables("default") == Seq(TableIdentifier("tab1"))) + } + } + test("rename temporary table") { withTempView("tab1", "tab2") { spark.range(10).createOrReplaceTempView("tab1") @@ -736,7 +770,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { sql("ALTER TABLE tab1 RENAME TO tab2") } assert(e.getMessage.contains( - "RENAME TEMPORARY TABLE from '`tab1`' to 'tab2': destination 
table already exists")) + "RENAME TEMPORARY TABLE from '`tab1`' to '`tab2`': destination table already exists")) val catalog = spark.sessionState.catalog assert(catalog.listTables("default") == Seq(TableIdentifier("tab1"), TableIdentifier("tab2"))) @@ -890,58 +924,11 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { } test("alter table: rename partition") { - val catalog = spark.sessionState.catalog - val tableIdent = TableIdentifier("tab1", Some("dbx")) - createPartitionedTable(tableIdent, isDatasourceTable = false) - sql("ALTER TABLE dbx.tab1 PARTITION (a='1', b='q') RENAME TO PARTITION (a='100', b='p')") - sql("ALTER TABLE dbx.tab1 PARTITION (a='2', b='c') RENAME TO PARTITION (a='20', b='c')") - assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == - Set(Map("a" -> "100", "b" -> "p"), Map("a" -> "20", "b" -> "c"), Map("a" -> "3", "b" -> "p"))) - // rename without explicitly specifying database - catalog.setCurrentDatabase("dbx") - sql("ALTER TABLE tab1 PARTITION (a='100', b='p') RENAME TO PARTITION (a='10', b='p')") - assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == - Set(Map("a" -> "10", "b" -> "p"), Map("a" -> "20", "b" -> "c"), Map("a" -> "3", "b" -> "p"))) - // table to alter does not exist - intercept[NoSuchTableException] { - sql("ALTER TABLE does_not_exist PARTITION (c='3') RENAME TO PARTITION (c='333')") - } - // partition to rename does not exist - intercept[NoSuchPartitionException] { - sql("ALTER TABLE tab1 PARTITION (a='not_found', b='1') RENAME TO PARTITION (a='1', b='2')") - } + testRenamePartitions(isDatasourceTable = false) } test("alter table: rename partition (datasource table)") { - createPartitionedTable(TableIdentifier("tab1", Some("dbx")), isDatasourceTable = true) - val e = intercept[AnalysisException] { - sql("ALTER TABLE dbx.tab1 PARTITION (a='1', b='q') RENAME TO PARTITION (a='100', b='p')") - }.getMessage - assert(e.contains( - "ALTER TABLE RENAME PARTITION is not allowed for tables defined using the datasource API")) - // table to alter does not exist - intercept[NoSuchTableException] { - sql("ALTER TABLE does_not_exist PARTITION (c='3') RENAME TO PARTITION (c='333')") - } - } - - private def createPartitionedTable( - tableIdent: TableIdentifier, - isDatasourceTable: Boolean): Unit = { - val catalog = spark.sessionState.catalog - val part1 = Map("a" -> "1", "b" -> "q") - val part2 = Map("a" -> "2", "b" -> "c") - val part3 = Map("a" -> "3", "b" -> "p") - createDatabase(catalog, "dbx") - createTable(catalog, tableIdent) - createTablePartition(catalog, part1, tableIdent) - createTablePartition(catalog, part2, tableIdent) - createTablePartition(catalog, part3, tableIdent) - assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == - Set(part1, part2, part3)) - if (isDatasourceTable) { - convertToDatasourceTable(catalog, tableIdent) - } + testRenamePartitions(isDatasourceTable = true) } test("show tables") { @@ -1156,7 +1143,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { if (isDatasourceTable) { if (spec.isDefined) { assert(storageFormat.properties.isEmpty) - assert(storageFormat.locationUri.isEmpty) + assert(storageFormat.locationUri === Some(expected)) } else { assert(storageFormat.properties.get("path") === Some(expected)) assert(storageFormat.locationUri === Some(expected)) @@ -1169,18 +1156,14 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { sql("ALTER TABLE dbx.tab1 SET LOCATION '/path/to/your/lovely/heart'") 
verifyLocation("/path/to/your/lovely/heart") // set table partition location - maybeWrapException(isDatasourceTable) { - sql("ALTER TABLE dbx.tab1 PARTITION (a='1', b='2') SET LOCATION '/path/to/part/ways'") - } + sql("ALTER TABLE dbx.tab1 PARTITION (a='1', b='2') SET LOCATION '/path/to/part/ways'") verifyLocation("/path/to/part/ways", Some(partSpec)) // set table location without explicitly specifying database catalog.setCurrentDatabase("dbx") sql("ALTER TABLE tab1 SET LOCATION '/swanky/steak/place'") verifyLocation("/swanky/steak/place") // set table partition location without explicitly specifying database - maybeWrapException(isDatasourceTable) { - sql("ALTER TABLE tab1 PARTITION (a='1', b='2') SET LOCATION 'vienna'") - } + sql("ALTER TABLE tab1 PARTITION (a='1', b='2') SET LOCATION 'vienna'") verifyLocation("vienna", Some(partSpec)) // table to alter does not exist intercept[AnalysisException] { @@ -1301,6 +1284,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { val part2 = Map("a" -> "2", "b" -> "6") val part3 = Map("a" -> "3", "b" -> "7") val part4 = Map("a" -> "4", "b" -> "8") + val part5 = Map("a" -> "9", "b" -> "9") createDatabase(catalog, "dbx") createTable(catalog, tableIdent) createTablePartition(catalog, part1, tableIdent) @@ -1308,40 +1292,40 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { convertToDatasourceTable(catalog, tableIdent) } assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == Set(part1)) - maybeWrapException(isDatasourceTable) { - sql("ALTER TABLE dbx.tab1 ADD IF NOT EXISTS " + - "PARTITION (a='2', b='6') LOCATION 'paris' PARTITION (a='3', b='7')") - } - if (!isDatasourceTable) { - assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == Set(part1, part2, part3)) - assert(catalog.getPartition(tableIdent, part1).storage.locationUri.isEmpty) - assert(catalog.getPartition(tableIdent, part2).storage.locationUri == Option("paris")) - assert(catalog.getPartition(tableIdent, part3).storage.locationUri.isEmpty) - } + + // basic add partition + sql("ALTER TABLE dbx.tab1 ADD IF NOT EXISTS " + + "PARTITION (a='2', b='6') LOCATION 'paris' PARTITION (a='3', b='7')") + assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == Set(part1, part2, part3)) + assert(catalog.getPartition(tableIdent, part1).storage.locationUri.isEmpty) + assert(catalog.getPartition(tableIdent, part2).storage.locationUri == Option("paris")) + assert(catalog.getPartition(tableIdent, part3).storage.locationUri.isEmpty) + // add partitions without explicitly specifying database catalog.setCurrentDatabase("dbx") - maybeWrapException(isDatasourceTable) { - sql("ALTER TABLE tab1 ADD IF NOT EXISTS PARTITION (a='4', b='8')") - } - if (!isDatasourceTable) { - assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == - Set(part1, part2, part3, part4)) - } + sql("ALTER TABLE tab1 ADD IF NOT EXISTS PARTITION (a='4', b='8')") + assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == + Set(part1, part2, part3, part4)) + // table to alter does not exist intercept[AnalysisException] { sql("ALTER TABLE does_not_exist ADD IF NOT EXISTS PARTITION (a='4', b='9')") } + // partition to add already exists intercept[AnalysisException] { sql("ALTER TABLE tab1 ADD PARTITION (a='4', b='8')") } - maybeWrapException(isDatasourceTable) { - sql("ALTER TABLE tab1 ADD IF NOT EXISTS PARTITION (a='4', b='8')") - } - if (!isDatasourceTable) { - assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == - Set(part1, part2, part3, 
part4)) - } + + // partition to add already exists when using IF NOT EXISTS + sql("ALTER TABLE tab1 ADD IF NOT EXISTS PARTITION (a='4', b='8')") + assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == + Set(part1, part2, part3, part4)) + + // partition spec in ADD PARTITION should be case insensitive by default + sql("ALTER TABLE tab1 ADD PARTITION (A='9', B='9')") + assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == + Set(part1, part2, part3, part4, part5)) } private def testDropPartitions(isDatasourceTable: Boolean): Unit = { @@ -1362,34 +1346,77 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { if (isDatasourceTable) { convertToDatasourceTable(catalog, tableIdent) } - maybeWrapException(isDatasourceTable) { - sql("ALTER TABLE dbx.tab1 DROP IF EXISTS PARTITION (a='4', b='8'), PARTITION (a='3', b='7')") - } - if (!isDatasourceTable) { - assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == Set(part1, part2)) - } + + // basic drop partition + sql("ALTER TABLE dbx.tab1 DROP IF EXISTS PARTITION (a='4', b='8'), PARTITION (a='3', b='7')") + assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == Set(part1, part2)) + // drop partitions without explicitly specifying database catalog.setCurrentDatabase("dbx") - maybeWrapException(isDatasourceTable) { - sql("ALTER TABLE tab1 DROP IF EXISTS PARTITION (a='2', b ='6')") - } - if (!isDatasourceTable) { - assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == Set(part1)) - } + sql("ALTER TABLE tab1 DROP IF EXISTS PARTITION (a='2', b ='6')") + assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == Set(part1)) + // table to alter does not exist intercept[AnalysisException] { sql("ALTER TABLE does_not_exist DROP IF EXISTS PARTITION (a='2')") } + // partition to drop does not exist intercept[AnalysisException] { sql("ALTER TABLE tab1 DROP PARTITION (a='300')") } - maybeWrapException(isDatasourceTable) { - sql("ALTER TABLE tab1 DROP IF EXISTS PARTITION (a='300')") + + // partition to drop does not exist when using IF EXISTS + sql("ALTER TABLE tab1 DROP IF EXISTS PARTITION (a='300')") + assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == Set(part1)) + + // partition spec in DROP PARTITION should be case insensitive by default + sql("ALTER TABLE tab1 DROP PARTITION (A='1', B='5')") + assert(catalog.listPartitions(tableIdent).isEmpty) + } + + private def testRenamePartitions(isDatasourceTable: Boolean): Unit = { + val catalog = spark.sessionState.catalog + val tableIdent = TableIdentifier("tab1", Some("dbx")) + val part1 = Map("a" -> "1", "b" -> "q") + val part2 = Map("a" -> "2", "b" -> "c") + val part3 = Map("a" -> "3", "b" -> "p") + createDatabase(catalog, "dbx") + createTable(catalog, tableIdent) + createTablePartition(catalog, part1, tableIdent) + createTablePartition(catalog, part2, tableIdent) + createTablePartition(catalog, part3, tableIdent) + assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == Set(part1, part2, part3)) + if (isDatasourceTable) { + convertToDatasourceTable(catalog, tableIdent) + } + + // basic rename partition + sql("ALTER TABLE dbx.tab1 PARTITION (a='1', b='q') RENAME TO PARTITION (a='100', b='p')") + sql("ALTER TABLE dbx.tab1 PARTITION (a='2', b='c') RENAME TO PARTITION (a='20', b='c')") + assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == + Set(Map("a" -> "100", "b" -> "p"), Map("a" -> "20", "b" -> "c"), Map("a" -> "3", "b" -> "p"))) + + // rename without explicitly specifying database + 
catalog.setCurrentDatabase("dbx") + sql("ALTER TABLE tab1 PARTITION (a='100', b='p') RENAME TO PARTITION (a='10', b='p')") + assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == + Set(Map("a" -> "10", "b" -> "p"), Map("a" -> "20", "b" -> "c"), Map("a" -> "3", "b" -> "p"))) + + // table to alter does not exist + intercept[NoSuchTableException] { + sql("ALTER TABLE does_not_exist PARTITION (c='3') RENAME TO PARTITION (c='333')") } - if (!isDatasourceTable) { - assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == Set(part1)) + + // partition to rename does not exist + intercept[NoSuchPartitionException] { + sql("ALTER TABLE tab1 PARTITION (a='not_found', b='1') RENAME TO PARTITION (a='1', b='2')") } + + // partition spec in RENAME PARTITION should be case insensitive by default + sql("ALTER TABLE tab1 PARTITION (A='10', B='p') RENAME TO PARTITION (A='1', B='p')") + assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == + Set(Map("a" -> "1", "b" -> "p"), Map("a" -> "20", "b" -> "c"), Map("a" -> "3", "b" -> "p"))) } test("drop build-in function") { @@ -1608,12 +1635,16 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { } } - // truncating partitioned data source tables is not supported withTable("rectangles", "rectangles2") { data.write.saveAsTable("rectangles") data.write.partitionBy("length").saveAsTable("rectangles2") + + // not supported since the table is not partitioned assertUnsupported("TRUNCATE TABLE rectangles PARTITION (width=1)") - assertUnsupported("TRUNCATE TABLE rectangles2 PARTITION (width=1)") + + // supported since partitions are stored in the metastore + sql("TRUNCATE TABLE rectangles2 PARTITION (width=1)") + assert(spark.table("rectangles2").collect().isEmpty) } } @@ -1713,4 +1744,28 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { assert(sql("show user functions").count() === 1L) } } + + test("show columns - negative test") { + // When case sensitivity is true, the user supplied database name in table identifier + // should match the supplied database name in case sensitive way. 
+ withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + withTempDatabase { db => + val tabName = s"$db.showcolumn" + withTable(tabName) { + sql(s"CREATE TABLE $tabName(col1 int, col2 string) USING parquet ") + val message = intercept[AnalysisException] { + sql(s"SHOW COLUMNS IN $db.showcolumn FROM ${db.toUpperCase}") + }.getMessage + assert(message.contains("SHOW COLUMNS with conflicting databases")) + } + } + } + } + + test("SPARK-18009 calling toLocalIterator on commands") { + import scala.collection.JavaConverters._ + val df = sql("show databases") + val rows: Seq[Row] = df.toLocalIterator().asScala.toSeq + assert(rows.length > 0) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileCatalogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileCatalogSuite.scala index fa3abd0098f5b..56df1face6364 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileCatalogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileCatalogSuite.scala @@ -28,15 +28,15 @@ import org.apache.hadoop.fs.{FileStatus, Path, RawLocalFileSystem} import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.test.SharedSQLContext -class FileCatalogSuite extends SharedSQLContext { +class FileIndexSuite extends SharedSQLContext { - test("ListingFileCatalog: leaf files are qualified paths") { + test("InMemoryFileIndex: leaf files are qualified paths") { withTempDir { dir => val file = new File(dir, "text.txt") stringToFile(file, "text") val path = new Path(file.getCanonicalPath) - val catalog = new ListingFileCatalog(spark, Seq(path), Map.empty, None) { + val catalog = new InMemoryFileIndex(spark, Seq(path), Map.empty, None) { def leafFilePaths: Seq[Path] = leafFiles.keys.toSeq def leafDirPaths: Seq[Path] = leafDirToChildrenFiles.keys.toSeq } @@ -45,7 +45,7 @@ class FileCatalogSuite extends SharedSQLContext { } } - test("ListingFileCatalog: input paths are converted to qualified paths") { + test("InMemoryFileIndex: input paths are converted to qualified paths") { withTempDir { dir => val file = new File(dir, "text.txt") stringToFile(file, "text") @@ -59,31 +59,42 @@ class FileCatalogSuite extends SharedSQLContext { val qualifiedFilePath = fs.makeQualified(new Path(file.getCanonicalPath)) require(qualifiedFilePath.toString.startsWith("file:")) - val catalog1 = new ListingFileCatalog( + val catalog1 = new InMemoryFileIndex( spark, Seq(unqualifiedDirPath), Map.empty, None) assert(catalog1.allFiles.map(_.getPath) === Seq(qualifiedFilePath)) - val catalog2 = new ListingFileCatalog( + val catalog2 = new InMemoryFileIndex( spark, Seq(unqualifiedFilePath), Map.empty, None) assert(catalog2.allFiles.map(_.getPath) === Seq(qualifiedFilePath)) } } - test("ListingFileCatalog: folders that don't exist don't throw exceptions") { + test("InMemoryFileIndex: folders that don't exist don't throw exceptions") { withTempDir { dir => val deletedFolder = new File(dir, "deleted") assert(!deletedFolder.exists()) - val catalog1 = new ListingFileCatalog( + val catalog1 = new InMemoryFileIndex( spark, Seq(new Path(deletedFolder.getCanonicalPath)), Map.empty, None) // doesn't throw an exception - assert(catalog1.listLeafFiles(catalog1.paths).isEmpty) + assert(catalog1.listLeafFiles(catalog1.rootPaths).isEmpty) } } - test("SPARK-17613 - PartitioningAwareFileCatalog: base path w/o '/' at end") { + test("PartitioningAwareFileIndex - file filtering") { + assert(!PartitioningAwareFileIndex.shouldFilterOut("abcd")) + 
assert(PartitioningAwareFileIndex.shouldFilterOut(".ab")) + assert(PartitioningAwareFileIndex.shouldFilterOut("_cd")) + assert(!PartitioningAwareFileIndex.shouldFilterOut("_metadata")) + assert(!PartitioningAwareFileIndex.shouldFilterOut("_common_metadata")) + assert(PartitioningAwareFileIndex.shouldFilterOut("_ab_metadata")) + assert(PartitioningAwareFileIndex.shouldFilterOut("_cd_common_metadata")) + } + + test("SPARK-17613 - PartitioningAwareFileIndex: base path w/o '/' at end") { class MockCatalog( - override val paths: Seq[Path]) extends PartitioningAwareFileCatalog(spark, Map.empty, None) { + override val rootPaths: Seq[Path]) + extends PartitioningAwareFileIndex(spark, Map.empty, None) { override def refresh(): Unit = {} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala index c5deb31fec183..d900ce7bb2370 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala @@ -393,9 +393,9 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi util.stringToFile(file, fileName) } - val fileCatalog = new ListingFileCatalog( + val fileCatalog = new InMemoryFileIndex( sparkSession = spark, - paths = Seq(new Path(tempDir)), + rootPaths = Seq(new Path(tempDir)), parameters = Map.empty[String, String], partitionSchema = None) // This should not fail. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelationSuite.scala index 3c68dc8bb98d8..89d57653adcbd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelationSuite.scala @@ -39,15 +39,4 @@ class HadoopFsRelationSuite extends QueryTest with SharedSQLContext { assert(df.queryExecution.logical.statistics.sizeInBytes === BigInt(totalSize)) } } - - test("file filtering") { - assert(!HadoopFsRelation.shouldFilterOut("abcd")) - assert(HadoopFsRelation.shouldFilterOut(".ab")) - assert(HadoopFsRelation.shouldFilterOut("_cd")) - - assert(!HadoopFsRelation.shouldFilterOut("_metadata")) - assert(!HadoopFsRelation.shouldFilterOut("_common_metadata")) - assert(HadoopFsRelation.shouldFilterOut("_ab_metadata")) - assert(HadoopFsRelation.shouldFilterOut("_cd_common_metadata")) - } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala index 8d18be9300f7e..120a3a2ef33aa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.datasources.parquet import java.io.File import java.math.BigInteger -import java.sql.Timestamp +import java.sql.{Date, Timestamp} import scala.collection.mutable.ArrayBuffer @@ -30,7 +30,8 @@ import org.apache.parquet.hadoop.ParquetOutputFormat import org.apache.spark.sql._ import 
org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Literal -import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation, PartitionDirectory => Partition, PartitioningUtils, PartitionSpec} +import org.apache.spark.sql.execution.datasources._ +import org.apache.spark.sql.execution.datasources.{PartitionPath => Partition} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext @@ -56,8 +57,14 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha check("10", Literal.create(10, IntegerType)) check("1000000000000000", Literal.create(1000000000000000L, LongType)) + val decimal = Decimal("1" * 20) + check("1" * 20, + Literal.create(decimal, DecimalType(decimal.precision, decimal.scale))) check("1.5", Literal.create(1.5, DoubleType)) check("hello", Literal.create("hello", StringType)) + check("1990-02-24", Literal.create(Date.valueOf("1990-02-24"), DateType)) + check("1990-02-24 12:00:30", + Literal.create(Timestamp.valueOf("1990-02-24 12:00:30"), TimestampType)) check(defaultPartitionName, Literal.create(null, NullType)) } @@ -626,10 +633,11 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha (1 to 10).map(i => (i, i.toString)).toDF("a", "b").write.parquet(dir.getCanonicalPath) val queryExecution = spark.read.parquet(dir.getCanonicalPath).queryExecution queryExecution.analyzed.collectFirst { - case LogicalRelation(relation: HadoopFsRelation, _, _) => - assert(relation.partitionSpec === PartitionSpec.emptySpec) + case LogicalRelation( + HadoopFsRelation(location: PartitioningAwareFileIndex, _, _, _, _, _), _, _) => + assert(location.partitionSpec() === PartitionSpec.emptySpec) }.getOrElse { - fail(s"Expecting a ParquetRelation2, but got:\n$queryExecution") + fail(s"Expecting a matching HadoopFsRelation, but got:\n$queryExecution") } } } @@ -687,6 +695,40 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha } } + test("Various inferred partition value types") { + val row = + Row( + Long.MaxValue, + 4.5, + new java.math.BigDecimal(new BigInteger("1" * 20)), + java.sql.Date.valueOf("2015-05-23"), + java.sql.Timestamp.valueOf("1990-02-24 12:00:30"), + "This is a string, /[]?=:", + "This is not a partition column") + + val partitionColumnTypes = + Seq( + LongType, + DoubleType, + DecimalType(20, 0), + DateType, + TimestampType, + StringType) + + val partitionColumns = partitionColumnTypes.zipWithIndex.map { + case (t, index) => StructField(s"p_$index", t) + } + + val schema = StructType(partitionColumns :+ StructField(s"i", StringType)) + val df = spark.createDataFrame(sparkContext.parallelize(row :: Nil), schema) + + withTempPath { dir => + df.write.format("parquet").partitionBy(partitionColumns.map(_.name): _*).save(dir.toString) + val fields = schema.map(f => Column(f.name)) + checkAnswer(spark.read.load(dir.toString).select(fields: _*), row) + } + } + test("SPARK-8037: Ignores files whose name starts with dot") { withTempPath { dir => val df = (1 to 3).map(i => (i, i, i, i)).toDF("a", "b", "c", "d") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala index 8a980a7eb538f..c3d202ced24c8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala @@ -1080,6 +1080,34 @@ class ParquetSchemaSuite extends ParquetSchemaTest { } } + testSchemaClipping( + "falls back to case insensitive resolution", + + parquetSchema = + """message root { + | required group A { + | optional int32 B; + | } + | optional int32 c; + |} + """.stripMargin, + + catalystSchema = { + val nestedType = new StructType().add("b", IntegerType, nullable = true) + new StructType() + .add("a", nestedType, nullable = true) + .add("c", IntegerType, nullable = true) + }, + + expectedSchema = + """message root { + | required group A { + | optional int32 B; + | } + | optional int32 c; + |} + """.stripMargin) + testSchemaClipping( "simple nested struct", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index bba40c6510cfb..229d8814e0143 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.metric import org.apache.spark.SparkFunSuite import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.execution.SparkPlanInfo import org.apache.spark.sql.execution.ui.SparkPlanGraph import org.apache.spark.sql.functions._ @@ -85,6 +86,22 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { } } + test("LocalTableScanExec computes metrics in collect and take") { + val df1 = spark.createDataset(Seq(1, 2, 3)) + val logical = df1.queryExecution.logical + require(logical.isInstanceOf[LocalRelation]) + df1.collect() + val metrics1 = df1.queryExecution.executedPlan.collectLeaves().head.metrics + assert(metrics1.contains("numOutputRows")) + assert(metrics1("numOutputRows").value === 3) + + val df2 = spark.createDataset(Seq(1, 2, 3)).limit(2) + df2.collect() + val metrics2 = df2.queryExecution.executedPlan.collectLeaves().head.metrics + assert(metrics2.contains("numOutputRows")) + assert(metrics2("numOutputRows").value === 2) + } + test("Filter metrics") { // Assume the execution plan is // PhysicalRDD(nodeId = 1) -> Filter(nodeId = 0) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/FileStreamSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/FileStreamSourceSuite.scala index 3e1e1126f9e6b..4a47c04d3f084 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/FileStreamSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/FileStreamSourceSuite.scala @@ -94,7 +94,7 @@ class FileStreamSourceSuite extends SparkFunSuite with SharedSQLContext { new FileStreamSourceLog(FileStreamSourceLog.VERSION, spark, dir.getAbsolutePath) assert(metadataLog.add(0, Array(FileEntry(s"$scheme:///file1", 100L, 0)))) - val newSource = new FileStreamSource(spark, s"$scheme:///", "parquet", StructType(Nil), + val newSource = new FileStreamSource(spark, s"$scheme:///", "parquet", StructType(Nil), Nil, dir.getAbsolutePath, Map.empty) // this method should throw an exception if `fs.exists` is called during resolveRelation newSource.getBatch(None, LongOffset(1)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala 
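Illustrative sketch (not part of the patch): the ForeachSinkSuite change below reflects the new behaviour that an exception thrown from ForeachWriter.process fails the streaming query instead of being swallowed. A minimal sketch of the kind of writer used in that test:

    import org.apache.spark.sql.ForeachWriter

    val failingWriter = new ForeachWriter[Int] {
      override def open(partitionId: Long, version: Long): Boolean = true
      override def process(value: Int): Unit = throw new RuntimeException("error")
      // close() is still invoked and receives the error that failed the task.
      override def close(errorOrNull: Throwable): Unit = ()
    }
    // After this change, query.processAllAvailable() surfaces the failure as a
    // StreamingQueryException whose cause chain contains the RuntimeException,
    // and query.isActive becomes false.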
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala index 7928b8e8775c2..9e059216110f2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala @@ -23,8 +23,9 @@ import scala.collection.mutable import org.scalatest.BeforeAndAfter +import org.apache.spark.SparkException import org.apache.spark.sql.ForeachWriter -import org.apache.spark.sql.streaming.{OutputMode, StreamTest} +import org.apache.spark.sql.streaming.{OutputMode, StreamingQueryException, StreamTest} import org.apache.spark.sql.test.SharedSQLContext class ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAfter { @@ -136,7 +137,7 @@ class ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAf } } - test("foreach with error") { + testQuietly("foreach with error") { withTempDir { checkpointDir => val input = MemoryStream[Int] val query = input.toDS().repartition(1).writeStream @@ -148,16 +149,24 @@ class ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAf } }).start() input.addData(1, 2, 3, 4) - query.processAllAvailable() + + // Error in `process` should fail the Spark job + val e = intercept[StreamingQueryException] { + query.processAllAvailable() + } + assert(e.getCause.isInstanceOf[SparkException]) + assert(e.getCause.getCause.getMessage === "error") + assert(query.isActive === false) val allEvents = ForeachSinkSuite.allEvents() assert(allEvents.size === 1) assert(allEvents(0)(0) === ForeachSinkSuite.Open(partition = 0, version = 0)) - assert(allEvents(0)(1) === ForeachSinkSuite.Process(value = 1)) + assert(allEvents(0)(1) === ForeachSinkSuite.Process(value = 1)) + + // `close` should be called with the error val errorEvent = allEvents(0)(2).asInstanceOf[ForeachSinkSuite.Close] assert(errorEvent.error.get.isInstanceOf[RuntimeException]) assert(errorEvent.error.get.getMessage === "error") - query.stop() } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/StreamMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/StreamMetricsSuite.scala new file mode 100644 index 0000000000000..938423db64745 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/StreamMetricsSuite.scala @@ -0,0 +1,213 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.streaming + +import org.scalactic.TolerantNumerics + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.types.{StructField, StructType} +import org.apache.spark.util.ManualClock + +class StreamMetricsSuite extends SparkFunSuite { + import StreamMetrics._ + + // To make === between double tolerate inexact values + implicit val doubleEquality = TolerantNumerics.tolerantDoubleEquality(0.01) + + test("rates, latencies, trigger details - basic life cycle") { + val sm = newStreamMetrics(source) + assert(sm.currentInputRate() === 0.0) + assert(sm.currentProcessingRate() === 0.0) + assert(sm.currentSourceInputRate(source) === 0.0) + assert(sm.currentSourceProcessingRate(source) === 0.0) + assert(sm.currentLatency() === None) + assert(sm.currentTriggerDetails().isEmpty) + + // When trigger started, the rates should not change, but should return + // reported trigger details + sm.reportTriggerStarted(1) + sm.reportTriggerDetail("key", "value") + sm.reportSourceTriggerDetail(source, "key2", "value2") + assert(sm.currentInputRate() === 0.0) + assert(sm.currentProcessingRate() === 0.0) + assert(sm.currentSourceInputRate(source) === 0.0) + assert(sm.currentSourceProcessingRate(source) === 0.0) + assert(sm.currentLatency() === None) + assert(sm.currentTriggerDetails() === + Map(TRIGGER_ID -> "1", IS_TRIGGER_ACTIVE -> "true", + START_TIMESTAMP -> "0", "key" -> "value")) + assert(sm.currentSourceTriggerDetails(source) === + Map(TRIGGER_ID -> "1", "key2" -> "value2")) + + // Finishing the trigger should calculate the rates, except input rate which needs + // to have another trigger interval + sm.reportNumInputRows(Map(source -> 100L)) // 100 input rows, 10 output rows + clock.advance(1000) + sm.reportTriggerFinished() + assert(sm.currentInputRate() === 0.0) + assert(sm.currentProcessingRate() === 100.0) // 100 input rows processed in 1 sec + assert(sm.currentSourceInputRate(source) === 0.0) + assert(sm.currentSourceProcessingRate(source) === 100.0) + assert(sm.currentLatency() === None) + assert(sm.currentTriggerDetails() === + Map(TRIGGER_ID -> "1", IS_TRIGGER_ACTIVE -> "false", + START_TIMESTAMP -> "0", FINISH_TIMESTAMP -> "1000", + NUM_INPUT_ROWS -> "100", "key" -> "value")) + assert(sm.currentSourceTriggerDetails(source) === + Map(TRIGGER_ID -> "1", NUM_SOURCE_INPUT_ROWS -> "100", "key2" -> "value2")) + + // After another trigger starts, the rates and latencies should not change until + // new rows are reported + clock.advance(1000) + sm.reportTriggerStarted(2) + assert(sm.currentInputRate() === 0.0) + assert(sm.currentProcessingRate() === 100.0) + assert(sm.currentSourceInputRate(source) === 0.0) + assert(sm.currentSourceProcessingRate(source) === 100.0) + assert(sm.currentLatency() === None) + + // Reporting new rows should update the rates and latencies + sm.reportNumInputRows(Map(source -> 200L)) // 200 input rows + clock.advance(500) + sm.reportTriggerFinished() + assert(sm.currentInputRate() === 100.0) // 200 input rows generated in 2 seconds b/w starts + assert(sm.currentProcessingRate() === 400.0) // 200 output rows processed in 0.5 sec + assert(sm.currentSourceInputRate(source) === 100.0) + assert(sm.currentSourceProcessingRate(source) === 400.0) + assert(sm.currentLatency().get === 1500.0) // 2000 ms / 2 + 500 ms + + // Rates should be set to 0 after stop + sm.stop() + assert(sm.currentInputRate() === 0.0) + assert(sm.currentProcessingRate() === 0.0) + assert(sm.currentSourceInputRate(source) 
=== 0.0) + assert(sm.currentSourceProcessingRate(source) === 0.0) + assert(sm.currentLatency() === None) + assert(sm.currentTriggerDetails().isEmpty) + } + + test("rates and latencies - after trigger with no data") { + val sm = newStreamMetrics(source) + // Trigger 1 with data + sm.reportTriggerStarted(1) + sm.reportNumInputRows(Map(source -> 100L)) // 100 input rows + clock.advance(1000) + sm.reportTriggerFinished() + + // Trigger 2 with data + clock.advance(1000) + sm.reportTriggerStarted(2) + sm.reportNumInputRows(Map(source -> 200L)) // 200 input rows + clock.advance(500) + sm.reportTriggerFinished() + + // Make sure that all rates are set + require(sm.currentInputRate() === 100.0) // 200 input rows generated in 2 seconds b/w starts + require(sm.currentProcessingRate() === 400.0) // 200 output rows processed in 0.5 sec + require(sm.currentSourceInputRate(source) === 100.0) + require(sm.currentSourceProcessingRate(source) === 400.0) + require(sm.currentLatency().get === 1500.0) // 2000 ms / 2 + 500 ms + + // Trigger 3 with data + clock.advance(500) + sm.reportTriggerStarted(3) + clock.advance(500) + sm.reportTriggerFinished() + + // Rates are set to zero and latency is set to None + assert(sm.currentInputRate() === 0.0) + assert(sm.currentProcessingRate() === 0.0) + assert(sm.currentSourceInputRate(source) === 0.0) + assert(sm.currentSourceProcessingRate(source) === 0.0) + assert(sm.currentLatency() === None) + sm.stop() + } + + test("rates - after trigger with multiple sources, and one source having no info") { + val source1 = TestSource(1) + val source2 = TestSource(2) + val sm = newStreamMetrics(source1, source2) + // Trigger 1 with data + sm.reportTriggerStarted(1) + sm.reportNumInputRows(Map(source1 -> 100L, source2 -> 100L)) + clock.advance(1000) + sm.reportTriggerFinished() + + // Trigger 2 with data + clock.advance(1000) + sm.reportTriggerStarted(2) + sm.reportNumInputRows(Map(source1 -> 200L, source2 -> 200L)) + clock.advance(500) + sm.reportTriggerFinished() + + // Make sure that all rates are set + assert(sm.currentInputRate() === 200.0) // 200*2 input rows generated in 2 seconds b/w starts + assert(sm.currentProcessingRate() === 800.0) // 200*2 output rows processed in 0.5 sec + assert(sm.currentSourceInputRate(source1) === 100.0) + assert(sm.currentSourceInputRate(source2) === 100.0) + assert(sm.currentSourceProcessingRate(source1) === 400.0) + assert(sm.currentSourceProcessingRate(source2) === 400.0) + + // Trigger 3 with data + clock.advance(500) + sm.reportTriggerStarted(3) + clock.advance(500) + sm.reportNumInputRows(Map(source1 -> 200L)) + sm.reportTriggerFinished() + + // Rates are set to zero and latency is set to None + assert(sm.currentInputRate() === 200.0) + assert(sm.currentProcessingRate() === 400.0) + assert(sm.currentSourceInputRate(source1) === 200.0) + assert(sm.currentSourceInputRate(source2) === 0.0) + assert(sm.currentSourceProcessingRate(source1) === 400.0) + assert(sm.currentSourceProcessingRate(source2) === 0.0) + sm.stop() + } + + test("registered Codahale metrics") { + import scala.collection.JavaConverters._ + val sm = newStreamMetrics(source) + val gaugeNames = sm.metricRegistry.getGauges().keySet().asScala + + // so that all metrics are considered as a single metric group in Ganglia + assert(!gaugeNames.exists(_.contains("."))) + assert(gaugeNames === Set( + "inputRate-total", + "inputRate-source0", + "processingRate-total", + "processingRate-source0", + "latency")) + } + + private def newStreamMetrics(sources: Source*): StreamMetrics = { + new 
StreamMetrics(sources.toSet, clock, "test") + } + + private val clock = new ManualClock() + private val source = TestSource(0) + + case class TestSource(id: Int) extends Source { + override def schema: StructType = StructType(Array.empty[StructField]) + override def getOffset: Option[Offset] = Some(new LongOffset(0)) + override def getBatch(start: Option[Offset], end: Offset): DataFrame = { null } + override def stop() {} + override def toString(): String = s"source$id" + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/TextSocketStreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/TextSocketStreamSuite.scala index 6b0ba7acb4804..5174a0415304c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/TextSocketStreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/TextSocketStreamSuite.scala @@ -156,6 +156,30 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before } } + test("input row metrics") { + serverThread = new ServerThread() + serverThread.start() + + val provider = new TextSocketSourceProvider + val parameters = Map("host" -> "localhost", "port" -> serverThread.port.toString) + source = provider.createSource(sqlContext, "", None, "", parameters) + + failAfter(streamingTimeout) { + serverThread.enqueue("hello") + while (source.getOffset.isEmpty) { + Thread.sleep(10) + } + val batch = source.getBatch(None, source.getOffset.get).as[String] + batch.collect() + val numRowsMetric = + batch.queryExecution.executedPlan.collectLeaves().head.metrics.get("numOutputRows") + assert(numRowsMetric.nonEmpty) + assert(numRowsMetric.get.value === 1) + source.stop() + source = null + } + } + private class ServerThread extends Thread with Logging { private val serverSocket = new ServerSocket(0) private val messageQueue = new LinkedBlockingQueue[String]() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala index 984b84fd13fbd..fcf300b3c81bb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala @@ -74,6 +74,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth // Verify state after updating put(store, "a", 1) + assert(store.numKeys() === 1) intercept[IllegalStateException] { store.iterator() } @@ -85,7 +86,9 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth // Make updates, commit and then verify state put(store, "b", 2) put(store, "aa", 3) + assert(store.numKeys() === 3) remove(store, _.startsWith("a")) + assert(store.numKeys() === 1) assert(store.commit() === 1) assert(store.hasCommitted) @@ -107,7 +110,9 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth val reloadedProvider = new HDFSBackedStateStoreProvider( store.id, keySchema, valueSchema, StateStoreConf.empty, new Configuration) val reloadedStore = reloadedProvider.getStore(1) + assert(reloadedStore.numKeys() === 1) put(reloadedStore, "c", 4) + assert(reloadedStore.numKeys() === 2) assert(reloadedStore.commit() === 2) assert(rowsToSet(reloadedStore.iterator()) === Set("b" -> 2, "c" -> 4)) assert(getDataFromFiles(provider) === Set("b" -> 2, "c" -> 4)) @@ -362,7 +367,10 @@ class StateStoreSuite extends 
SparkFunSuite with BeforeAndAfter with PrivateMeth val conf = new SparkConf() .setMaster("local") .setAppName("test") + // Make maintenance thread do snapshots and cleanups very fast .set(StateStore.MAINTENANCE_INTERVAL_CONFIG, "10ms") + // Make sure that when SparkContext stops, the StateStore maintenance thread 'quickly' + // fails to talk to the StateStoreCoordinator and unloads all the StateStores .set("spark.rpc.numRetries", "1") val opId = 0 val dir = Utils.createDirectory(tempDir, Random.nextString(5)).toString @@ -372,37 +380,49 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth val provider = new HDFSBackedStateStoreProvider( storeId, keySchema, valueSchema, storeConf, hadoopConf) + var latestStoreVersion = 0 + + def generateStoreVersions() { + for (i <- 1 to 20) { + val store = StateStore.get( + storeId, keySchema, valueSchema, latestStoreVersion, storeConf, hadoopConf) + put(store, "a", i) + store.commit() + latestStoreVersion += 1 + } + } quietly { withSpark(new SparkContext(conf)) { sc => withCoordinatorRef(sc) { coordinatorRef => require(!StateStore.isMaintenanceRunning, "StateStore is unexpectedly running") - for (i <- 1 to 20) { - val store = StateStore.get( - storeId, keySchema, valueSchema, i - 1, storeConf, hadoopConf) - put(store, "a", i) - store.commit() - } + // Generate sufficient versions of store for snapshots + generateStoreVersions() eventually(timeout(10 seconds)) { + // Store should have been reported to the coordinator assert(coordinatorRef.getLocation(storeId).nonEmpty, "active instance was not reported") - } - // Background maintenance should clean up and generate snapshots - assert(StateStore.isMaintenanceRunning, "Maintenance task is not running") - - eventually(timeout(10 seconds)) { - // Earliest delta file should get cleaned up - assert(!fileExists(provider, 1, isSnapshot = false), "earliest file not deleted") + // Background maintenance should clean up and generate snapshots + assert(StateStore.isMaintenanceRunning, "Maintenance task is not running") // Some snapshots should have been generated - val snapshotVersions = (0 to 20).filter { version => + val snapshotVersions = (1 to latestStoreVersion).filter { version => fileExists(provider, version, isSnapshot = true) } assert(snapshotVersions.nonEmpty, "no snapshot file found") } + // Generate more versions such that there is another snapshot and + // the earliest delta file will be cleaned up + generateStoreVersions() + + // Earliest delta file should get cleaned up + eventually(timeout(10 seconds)) { + assert(!fileExists(provider, 1, isSnapshot = false), "earliest file not deleted") + } + // If driver decides to deactivate all instances of the store, then this instance // should be unloaded coordinatorRef.deactivateInstances(dir) @@ -411,7 +431,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth } // Reload the store and verify - StateStore.get(storeId, keySchema, valueSchema, 20, storeConf, hadoopConf) + StateStore.get(storeId, keySchema, valueSchema, latestStoreVersion, storeConf, hadoopConf) assert(StateStore.isLoaded(storeId)) // If some other executor loads the store, then this instance should be unloaded @@ -421,14 +441,14 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth } // Reload the store and verify - StateStore.get(storeId, keySchema, valueSchema, 20, storeConf, hadoopConf) + StateStore.get(storeId, keySchema, valueSchema, latestStoreVersion, storeConf, hadoopConf) 
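
Editorial note, not part of the patch: for readers skimming the StateStoreSuite changes above, the new `generateStoreVersions()` helper and the reload assertions reduce to the pattern sketched below. This is only a sketch; `put` is the suite's own helper that wraps the key and value into UnsafeRows, and `storeId`, `keySchema`, `valueSchema`, `storeConf`, and `hadoopConf` are the suite's existing fixtures.

```scala
// Sketch of the store lifecycle exercised above (assumes StateStoreSuite's fixtures).
var latestStoreVersion = 0

def generateStoreVersions(): Unit = {
  for (i <- 1 to 20) {
    // Load the store at the latest committed version, mutate it, and commit.
    val store = StateStore.get(
      storeId, keySchema, valueSchema, latestStoreVersion, storeConf, hadoopConf)
    put(store, "a", i)   // suite helper: converts ("a", i) into UnsafeRows before store.put
    store.commit()       // writes a new delta file for version latestStoreVersion + 1
    latestStoreVersion += 1
  }
}

generateStoreVersions()
// Reloading at the tracked version (instead of a hard-coded 20) is what the
// updated assertions in this hunk rely on.
StateStore.get(storeId, keySchema, valueSchema, latestStoreVersion, storeConf, hadoopConf)
assert(StateStore.isLoaded(storeId))
```

Tracking `latestStoreVersion` matters because the test now calls the generator twice (once before and once after the snapshot checks), so the reload can no longer assume version 20.
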
assert(StateStore.isLoaded(storeId)) } } // Verify if instance is unloaded if SparkContext is stopped - require(SparkEnv.get === null) eventually(timeout(10 seconds)) { + require(SparkEnv.get === null) assert(!StateStore.isLoaded(storeId)) assert(!StateStore.isMaintenanceRunning) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala index df640ffab91de..11d4693f1c2a3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala @@ -19,11 +19,11 @@ package org.apache.spark.sql.internal import org.apache.hadoop.fs.Path -import org.apache.spark.SparkContext import org.apache.spark.sql._ import org.apache.spark.sql.execution.WholeStageCodegenExec import org.apache.spark.sql.internal.StaticSQLConf._ import org.apache.spark.sql.test.{SharedSQLContext, TestSQLContext} +import org.apache.spark.util.Utils class SQLConfSuite extends QueryTest with SharedSQLContext { import testImplicits._ @@ -215,12 +215,15 @@ class SQLConfSuite extends QueryTest with SharedSQLContext { } test("default value of WAREHOUSE_PATH") { + val original = spark.conf.get(SQLConf.WAREHOUSE_PATH) try { // to get the default value, always unset it spark.conf.unset(SQLConf.WAREHOUSE_PATH.key) - assert(spark.sessionState.conf.warehousePath - === new Path(s"${System.getProperty("user.dir")}/spark-warehouse").toString) + // JVM adds a trailing slash if the directory exists and leaves it as-is, if it doesn't + // In our comparison, strip trailing slash off of both sides, to account for such cases + assert(new Path(Utils.resolveURI("spark-warehouse")).toString.stripSuffix("/") === spark + .sessionState.conf.warehousePath.stripSuffix("/")) } finally { sql(s"set ${SQLConf.WAREHOUSE_PATH}=$original") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala index c39005f6a1063..5cc9467395adc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala @@ -238,7 +238,7 @@ class CreateTableAsSelectSuite } } - test("CTAS of decimal calculation") { + test("SPARK-17409: CTAS of decimal calculation") { withTable("tab2") { withTempView("tab1") { spark.range(99, 101).createOrReplaceTempView("tab1") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala index 5eb54643f204f..4a85b5975ea53 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala @@ -185,6 +185,48 @@ class InsertSuite extends DataSourceTest with SharedSQLContext { ) } + test("INSERT INTO TABLE with Comment in columns") { + val tabName = "tab1" + withTable(tabName) { + sql( + s""" + |CREATE TABLE $tabName(col1 int COMMENT 'a', col2 int) + |USING parquet + """.stripMargin) + sql(s"INSERT INTO TABLE $tabName SELECT 1, 2") + + checkAnswer( + sql(s"SELECT col1, col2 FROM $tabName"), + Row(1, 2) :: Nil + ) + } + } + + test("INSERT INTO TABLE - complex type but different names") { + val tab1 = "tab1" + val tab2 = "tab2" + withTable(tab1, tab2) { + sql( + s""" + |CREATE TABLE $tab1 (s struct) + |USING parquet + """.stripMargin) + 
sql(s"INSERT INTO TABLE $tab1 SELECT named_struct('col1','1','col2','2')") + + sql( + s""" + |CREATE TABLE $tab2 (p struct) + |USING parquet + """.stripMargin) + sql(s"INSERT INTO TABLE $tab2 SELECT * FROM $tab1") + + checkAnswer( + spark.table(tab1), + spark.table(tab2) + ) + } + } + test("it is not allowed to write to a table while querying it.") { val message = intercept[AnalysisException] { sql( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala index 19c89f5c4100c..18b42a81a098c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.execution.DataSourceScanExec import org.apache.spark.sql.execution.datasources._ -import org.apache.spark.sql.execution.streaming.{FileStreamSinkWriter, MemoryStream, MetadataLogFileCatalog} +import org.apache.spark.sql.execution.streaming.{FileStreamSinkWriter, MemoryStream, MetadataLogFileIndex} import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{IntegerType, StructField, StructType} @@ -179,14 +179,14 @@ class FileStreamSinkSuite extends StreamTest { .add(StructField("id", IntegerType)) assert(outputDf.schema === expectedSchema) - // Verify that MetadataLogFileCatalog is being used and the correct partitioning schema has + // Verify that MetadataLogFileIndex is being used and the correct partitioning schema has // been inferred val hadoopdFsRelations = outputDf.queryExecution.analyzed.collect { case LogicalRelation(baseRelation, _, _) if baseRelation.isInstanceOf[HadoopFsRelation] => baseRelation.asInstanceOf[HadoopFsRelation] } assert(hadoopdFsRelations.size === 1) - assert(hadoopdFsRelations.head.location.isInstanceOf[MetadataLogFileCatalog]) + assert(hadoopdFsRelations.head.location.isInstanceOf[MetadataLogFileIndex]) assert(hadoopdFsRelations.head.partitionSchema.exists(_.name == "id")) assert(hadoopdFsRelations.head.dataSchema.exists(_.name == "value")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala index 7f9c981a4e9c9..47018b3a3c495 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala @@ -664,7 +664,9 @@ class FileStreamSourceSuite extends FileStreamSourceTest { def createFile(content: String, src: File, tmp: File): Unit = { val tempFile = Utils.tempFileWith(new File(tmp, "text")) val finalFile = new File(src, tempFile.getName) - src.mkdirs() + require(!src.exists(), s"$src exists, dir: ${src.isDirectory}, file: ${src.isFile}") + require(src.mkdirs(), s"Cannot create $src") + require(src.isDirectory(), s"$src is not a directory") require(stringToFile(tempFile, content).renameTo(finalFile)) } @@ -877,7 +879,7 @@ class FileStreamSourceSuite extends FileStreamSourceTest { val numFiles = 10000 // This is to avoid running a spark job to list of files in parallel - // by the ListingFileCatalog. + // by the InMemoryFileIndex. 
spark.sessionState.conf.setConf(SQLConf.PARALLEL_PARTITION_DISCOVERY_THRESHOLD, numFiles * 2) withTempDirs { case (root, tmp) => @@ -998,6 +1000,20 @@ class FileStreamSourceSuite extends FileStreamSourceTest { } } } + + test("input row metrics") { + withTempDirs { case (src, tmp) => + val input = spark.readStream.format("text").load(src.getCanonicalPath) + testStream(input)( + AddTextFileData("100", src, tmp), + CheckAnswer("100"), + AssertOnLastQueryStatus { status => + assert(status.triggerDetails.get("numRows.input.total") === "1") + assert(status.sourceStatuses(0).processingRate > 0.0) + } + ) + } + } } class FileStreamSourceStressTestSuite extends FileStreamSourceTest { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala index cdbad901dba8e..6bdf47901ae68 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala @@ -161,7 +161,7 @@ class StreamSuite extends StreamTest { val inputData = MemoryStream[Int] testStream(inputData.toDS())( - StartStream(ProcessingTime("10 seconds"), new ManualClock), + StartStream(ProcessingTime("10 seconds"), new StreamManualClock), /* -- batch 0 ----------------------- */ // Add some data in batch 0 @@ -199,7 +199,7 @@ class StreamSuite extends StreamTest { /* Stop then restart the Stream */ StopStream, - StartStream(ProcessingTime("10 seconds"), new ManualClock), + StartStream(ProcessingTime("10 seconds"), new StreamManualClock(60 * 1000)), /* -- batch 1 rerun ----------------- */ // this batch 1 would re-run because the latest batch id logged in offset log is 1 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index fa13d385cce75..742833065144d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -28,6 +28,8 @@ import scala.util.control.NonFatal import org.scalatest.Assertions import org.scalatest.concurrent.{Eventually, Timeouts} +import org.scalatest.concurrent.AsyncAssertions.Waiter +import org.scalatest.concurrent.Eventually._ import org.scalatest.concurrent.PatienceConfiguration.Timeout import org.scalatest.exceptions.TestFailedDueToTimeoutException import org.scalatest.time.Span @@ -38,6 +40,7 @@ import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder, Ro import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.streaming.StreamingQueryListener._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.util.{Clock, ManualClock, SystemClock, Utils} @@ -198,6 +201,27 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts { } } + case class AssertOnLastQueryStatus(condition: StreamingQueryStatus => Unit) + extends StreamAction + + class StreamManualClock(time: Long = 0L) extends ManualClock(time) { + private var waitStartTime: Option[Long] = None + + override def waitTillTime(targetTime: Long): Long = synchronized { + try { + waitStartTime = Some(getTimeMillis()) + super.waitTillTime(targetTime) + } finally { + waitStartTime = None + } + } + + def isStreamWaitingAt(time: Long): Boolean = synchronized { + waitStartTime == Some(time) + } + } 
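
Editorial note, not part of the patch: the `StreamManualClock` introduced above is what lets `AdvanceManualClock` verify, via `isStreamWaitingAt`, that the stream is really parked at the expected trigger time before the clock moves forward. Below is a minimal sketch of how the StreamTest DSL drives batches with it, mirroring the StreamSuite usage earlier in this diff; identifiers other than the DSL actions are illustrative.

```scala
// Sketch only: deterministic batch triggering with StreamManualClock.
val inputData = MemoryStream[Int]

testStream(inputData.toDS())(
  // StartStream now accepts only SystemClock or StreamManualClock as the trigger clock.
  StartStream(ProcessingTime("10 seconds"), new StreamManualClock),
  AddData(inputData, 1, 2, 3),
  // AdvanceManualClock first waits (via isStreamWaitingAt) until the stream is
  // blocked at the expected time, then advances the clock so the next batch fires.
  AdvanceManualClock(10 * 1000),
  CheckAnswer(1, 2, 3)
)
```

The extra bookkeeping with `manualClockExpectedTime` in the `AdvanceManualClock` handler that follows is what avoids the SPARK-16002 race: the clock is advanced only once the stream is known to be waiting at the previously expected time.
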
+ + /** * Executes the specified actions on the given streaming DataFrame and provides helpful * error messages in the case of failures or incorrect answers. @@ -299,12 +323,21 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts { val testThread = Thread.currentThread() val metadataRoot = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath - + val statusCollector = new QueryStatusCollector + var manualClockExpectedTime = -1L try { + spark.streams.addListener(statusCollector) startedTest.foreach { action => + logInfo(s"Processing test stream action: $action") action match { case StartStream(trigger, triggerClock) => verify(currentStream == null, "stream already running") + verify(triggerClock.isInstanceOf[SystemClock] + || triggerClock.isInstanceOf[StreamManualClock], + "Use either SystemClock or StreamManualClock to start the stream") + if (triggerClock.isInstanceOf[StreamManualClock]) { + manualClockExpectedTime = triggerClock.asInstanceOf[StreamManualClock].getTimeMillis() + } lastStream = currentStream currentStream = spark @@ -328,14 +361,19 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts { case AdvanceManualClock(timeToAdd) => verify(currentStream != null, "can not advance manual clock when a stream is not running") - verify(currentStream.triggerClock.isInstanceOf[ManualClock], + verify(currentStream.triggerClock.isInstanceOf[StreamManualClock], s"can not advance clock of type ${currentStream.triggerClock.getClass}") - val clock = currentStream.triggerClock.asInstanceOf[ManualClock] + val clock = currentStream.triggerClock.asInstanceOf[StreamManualClock] + assert(manualClockExpectedTime >= 0) // Make sure we don't advance ManualClock too early. See SPARK-16002. - eventually("ManualClock has not yet entered the waiting state") { - assert(clock.isWaiting) + eventually("StreamManualClock has not yet entered the waiting state") { + assert(clock.isStreamWaitingAt(manualClockExpectedTime)) } - currentStream.triggerClock.asInstanceOf[ManualClock].advance(timeToAdd) + clock.advance(timeToAdd) + manualClockExpectedTime += timeToAdd + verify(clock.getTimeMillis() === manualClockExpectedTime, + s"Unexpected clock time after updating: " + + s"expecting $manualClockExpectedTime, current ${clock.getTimeMillis()}") case StopStream => verify(currentStream != null, "can not stop a stream that is not running") @@ -399,6 +437,13 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts { val streamToAssert = Option(currentStream).getOrElse(lastStream) verify({ a.run(); true }, s"Assert failed: ${a.message}") + case a: AssertOnLastQueryStatus => + Eventually.eventually(timeout(streamingTimeout)) { + require(statusCollector.lastTriggerStatus.nonEmpty) + } + val status = statusCollector.lastTriggerStatus.get + verify({ a.condition(status); true }, "Assert on last query status failed") + case a: AddData => try { // Add data and get the source where it was added, and the expected offset of the @@ -473,6 +518,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts { if (currentStream != null && currentStream.microBatchThread.isAlive) { currentStream.stop() } + spark.streams.removeListener(statusCollector) } } @@ -606,4 +652,58 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts { } } } + + + class QueryStatusCollector extends StreamingQueryListener { + // to catch errors in the async listener events + @volatile private var asyncTestWaiter = new Waiter + + @volatile var startStatus: 
StreamingQueryStatus = null + @volatile var terminationStatus: StreamingQueryStatus = null + @volatile var terminationException: Option[String] = null + + private val progressStatuses = new mutable.ArrayBuffer[StreamingQueryStatus] + + /** Get the info of the last trigger that processed data */ + def lastTriggerStatus: Option[StreamingQueryStatus] = synchronized { + progressStatuses.filter { i => + i.triggerDetails.get("isTriggerActive").toBoolean == false && + i.triggerDetails.get("isDataPresentInTrigger").toBoolean == true + }.lastOption + } + + def reset(): Unit = { + startStatus = null + terminationStatus = null + progressStatuses.clear() + asyncTestWaiter = new Waiter + } + + def checkAsyncErrors(): Unit = { + asyncTestWaiter.await(timeout(10 seconds)) + } + + + override def onQueryStarted(queryStarted: QueryStartedEvent): Unit = { + asyncTestWaiter { + startStatus = queryStarted.queryStatus + } + } + + override def onQueryProgress(queryProgress: QueryProgressEvent): Unit = { + asyncTestWaiter { + assert(startStatus != null, "onQueryProgress called before onQueryStarted") + synchronized { progressStatuses += queryProgress.queryStatus } + } + } + + override def onQueryTerminated(queryTerminated: QueryTerminatedEvent): Unit = { + asyncTestWaiter { + assert(startStatus != null, "onQueryTerminated called before onQueryStarted") + terminationStatus = queryTerminated.queryStatus + terminationException = queryTerminated.exception + } + asyncTestWaiter.dismiss() + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala index 8681199817fe6..e59b5491f90b6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala @@ -22,6 +22,7 @@ import org.scalatest.BeforeAndAfterAll import org.apache.spark.SparkException import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.InternalOutputModes._ +import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.state.StateStore import org.apache.spark.sql.expressions.scalalang.typed @@ -129,6 +130,59 @@ class StreamingAggregationSuite extends StreamTest with BeforeAndAfterAll { ) } + test("state metrics") { + val inputData = MemoryStream[Int] + + val aggregated = + inputData.toDS() + .flatMap(x => Seq(x, x + 1)) + .toDF("value") + .groupBy($"value") + .agg(count("*")) + .as[(Int, Long)] + + implicit class RichStreamExecution(query: StreamExecution) { + def stateNodes: Seq[SparkPlan] = { + query.lastExecution.executedPlan.collect { + case p if p.isInstanceOf[StateStoreSaveExec] => p + } + } + } + + // Test with Update mode + testStream(aggregated, Update)( + AddData(inputData, 1), + CheckLastBatch((1, 1), (2, 1)), + AssertOnQuery { _.stateNodes.size === 1 }, + AssertOnQuery { _.stateNodes.head.metrics.get("numOutputRows").get.value === 2 }, + AssertOnQuery { _.stateNodes.head.metrics.get("numUpdatedStateRows").get.value === 2 }, + AssertOnQuery { _.stateNodes.head.metrics.get("numTotalStateRows").get.value === 2 }, + AddData(inputData, 2, 3), + CheckLastBatch((2, 2), (3, 2), (4, 1)), + AssertOnQuery { _.stateNodes.size === 1 }, + AssertOnQuery { _.stateNodes.head.metrics.get("numOutputRows").get.value === 3 }, + AssertOnQuery { _.stateNodes.head.metrics.get("numUpdatedStateRows").get.value 
=== 3 }, + AssertOnQuery { _.stateNodes.head.metrics.get("numTotalStateRows").get.value === 4 } + ) + + // Test with Complete mode + inputData.reset() + testStream(aggregated, Complete)( + AddData(inputData, 1), + CheckLastBatch((1, 1), (2, 1)), + AssertOnQuery { _.stateNodes.size === 1 }, + AssertOnQuery { _.stateNodes.head.metrics.get("numOutputRows").get.value === 2 }, + AssertOnQuery { _.stateNodes.head.metrics.get("numUpdatedStateRows").get.value === 2 }, + AssertOnQuery { _.stateNodes.head.metrics.get("numTotalStateRows").get.value === 2 }, + AddData(inputData, 2, 3), + CheckLastBatch((1, 1), (2, 2), (3, 2), (4, 1)), + AssertOnQuery { _.stateNodes.size === 1 }, + AssertOnQuery { _.stateNodes.head.metrics.get("numOutputRows").get.value === 4 }, + AssertOnQuery { _.stateNodes.head.metrics.get("numUpdatedStateRows").get.value === 3 }, + AssertOnQuery { _.stateNodes.head.metrics.get("numTotalStateRows").get.value === 4 } + ) + } + test("multiple keys") { val inputData = MemoryStream[Int] diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala index 831543a47420a..cebb32a0a56cc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala @@ -17,92 +17,100 @@ package org.apache.spark.sql.streaming -import java.util.concurrent.ConcurrentLinkedQueue +import scala.collection.mutable +import org.scalactic.TolerantNumerics import org.scalatest.BeforeAndAfter import org.scalatest.PrivateMethodTester._ -import org.scalatest.concurrent.AsyncAssertions.Waiter -import org.scalatest.concurrent.Eventually._ -import org.scalatest.concurrent.PatienceConfiguration.Timeout -import org.scalatest.time.SpanSugar._ import org.apache.spark.SparkException +import org.apache.spark.scheduler._ +import org.apache.spark.sql.DataFrame import org.apache.spark.sql.execution.streaming._ -import org.apache.spark.util.JsonProtocol +import org.apache.spark.sql.functions._ +import org.apache.spark.util.{JsonProtocol, ManualClock} class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter { import testImplicits._ - import StreamingQueryListener._ + import StreamingQueryListenerSuite._ + + // To make === between double tolerate inexact values + implicit val doubleEquality = TolerantNumerics.tolerantDoubleEquality(0.01) after { spark.streams.active.foreach(_.stop()) assert(spark.streams.active.isEmpty) assert(addedListeners.isEmpty) // Make sure we don't leak any events to the next test - spark.sparkContext.listenerBus.waitUntilEmpty(10000) } - test("single listener") { - val listener = new QueryStatusCollector - val input = MemoryStream[Int] - withListenerAdded(listener) { - testStream(input.toDS)( - StartStream(), - AssertOnQuery("Incorrect query status in onQueryStarted") { query => - val status = listener.startStatus - assert(status != null) - assert(status.name === query.name) - assert(status.id === query.id) - assert(status.sourceStatuses.size === 1) - assert(status.sourceStatuses(0).description.contains("Memory")) - - // The source and sink offsets must be None as this must be called before the - // batches have started - assert(status.sourceStatuses(0).offsetDesc === None) - assert(status.sinkStatus.offsetDesc === CompositeOffset(None :: Nil).toString) - - // No progress events or termination events - 
assert(listener.progressStatuses.isEmpty) - assert(listener.terminationStatus === null) - true - }, - AddDataMemory(input, Seq(1, 2, 3)), - CheckAnswer(1, 2, 3), - AssertOnQuery("Incorrect query status in onQueryProgress") { query => - eventually(Timeout(streamingTimeout)) { - - // There should be only on progress event as batch has been processed - assert(listener.progressStatuses.size === 1) - val status = listener.progressStatuses.peek() - assert(status != null) - assert(status.name === query.name) - assert(status.id === query.id) - assert(status.sourceStatuses(0).offsetDesc === Some(LongOffset(0).toString)) - assert(status.sinkStatus.offsetDesc === CompositeOffset.fill(LongOffset(0)).toString) - - // No termination events - assert(listener.terminationStatus === null) - } - true - }, - StopStream, - AssertOnQuery("Incorrect query status in onQueryTerminated") { query => - eventually(Timeout(streamingTimeout)) { - val status = listener.terminationStatus - assert(status != null) - assert(status.name === query.name) - assert(status.id === query.id) - assert(status.sourceStatuses(0).offsetDesc === Some(LongOffset(0).toString)) - assert(status.sinkStatus.offsetDesc === CompositeOffset.fill(LongOffset(0)).toString) - assert(listener.terminationException === None) - } - listener.checkAsyncErrors() - true + test("single listener, check trigger statuses") { + import StreamingQueryListenerSuite._ + clock = new StreamManualClock + + /** Custom MemoryStream that waits for manual clock to reach a time */ + val inputData = new MemoryStream[Int](0, sqlContext) { + // Wait for manual clock to be 100 first time there is data + override def getOffset: Option[Offset] = { + val offset = super.getOffset + if (offset.nonEmpty) { + clock.waitTillTime(100) } - ) + offset + } + + // Wait for manual clock to be 300 first time there is data + override def getBatch(start: Option[Offset], end: Offset): DataFrame = { + clock.waitTillTime(300) + super.getBatch(start, end) + } + } + + // This is to make sure thatquery waits for manual clock to be 600 first time there is data + val mapped = inputData.toDS().agg(count("*")).as[Long].coalesce(1).map { x => + clock.waitTillTime(600) + x } + + testStream(mapped, OutputMode.Complete)( + StartStream(triggerClock = clock), + AddData(inputData, 1, 2), + AdvanceManualClock(100), // unblock getOffset, will block on getBatch + AdvanceManualClock(200), // unblock getBatch, will block on computation + AdvanceManualClock(300), // unblock computation + AssertOnQuery { _ => clock.getTimeMillis() === 600 }, + AssertOnLastQueryStatus { status: StreamingQueryStatus => + // Check the correctness of the trigger info of the last completed batch reported by + // onQueryProgress + assert(status.triggerDetails.containsKey("triggerId")) + assert(status.triggerDetails.get("isTriggerActive") === "false") + assert(status.triggerDetails.get("isDataPresentInTrigger") === "true") + + assert(status.triggerDetails.get("timestamp.triggerStart") === "0") + assert(status.triggerDetails.get("timestamp.afterGetOffset") === "100") + assert(status.triggerDetails.get("timestamp.afterGetBatch") === "300") + assert(status.triggerDetails.get("timestamp.triggerFinish") === "600") + + assert(status.triggerDetails.get("latency.getOffset.total") === "100") + assert(status.triggerDetails.get("latency.getBatch.total") === "200") + assert(status.triggerDetails.get("latency.optimizer") === "0") + assert(status.triggerDetails.get("latency.offsetLogWrite") === "0") + assert(status.triggerDetails.get("latency.fullTrigger") 
=== "600") + + assert(status.triggerDetails.get("numRows.input.total") === "2") + assert(status.triggerDetails.get("numRows.state.aggregation1.total") === "1") + assert(status.triggerDetails.get("numRows.state.aggregation1.updated") === "1") + + assert(status.sourceStatuses.length === 1) + assert(status.sourceStatuses(0).triggerDetails.containsKey("triggerId")) + assert(status.sourceStatuses(0).triggerDetails.get("latency.getOffset.source") === "100") + assert(status.sourceStatuses(0).triggerDetails.get("latency.getBatch.source") === "200") + assert(status.sourceStatuses(0).triggerDetails.get("numRows.input.source") === "2") + }, + CheckAnswer(2) + ) } test("adding and removing listener") { @@ -172,56 +180,77 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter { } test("QueryStarted serialization") { - val queryStartedInfo = new StreamingQueryInfo( - "name", - 1, - Seq(new SourceStatus("source1", None), new SourceStatus("source2", None)), - new SinkStatus("sink", CompositeOffset(None :: None :: Nil).toString)) - val queryStarted = new StreamingQueryListener.QueryStarted(queryStartedInfo) + val queryStarted = new StreamingQueryListener.QueryStartedEvent(StreamingQueryStatus.testStatus) val json = JsonProtocol.sparkEventToJson(queryStarted) val newQueryStarted = JsonProtocol.sparkEventFromJson(json) - .asInstanceOf[StreamingQueryListener.QueryStarted] - assertStreamingQueryInfoEquals(queryStarted.queryInfo, newQueryStarted.queryInfo) + .asInstanceOf[StreamingQueryListener.QueryStartedEvent] + assertStreamingQueryInfoEquals(queryStarted.queryStatus, newQueryStarted.queryStatus) } test("QueryProgress serialization") { - val queryProcessInfo = new StreamingQueryInfo( - "name", - 1, - Seq( - new SourceStatus("source1", Some(LongOffset(0).toString)), - new SourceStatus("source2", Some(LongOffset(1).toString))), - new SinkStatus("sink", new CompositeOffset(Array(None, Some(LongOffset(1)))).toString)) - val queryProcess = new StreamingQueryListener.QueryProgress(queryProcessInfo) + val queryProcess = new StreamingQueryListener.QueryProgressEvent( + StreamingQueryStatus.testStatus) val json = JsonProtocol.sparkEventToJson(queryProcess) val newQueryProcess = JsonProtocol.sparkEventFromJson(json) - .asInstanceOf[StreamingQueryListener.QueryProgress] - assertStreamingQueryInfoEquals(queryProcess.queryInfo, newQueryProcess.queryInfo) + .asInstanceOf[StreamingQueryListener.QueryProgressEvent] + assertStreamingQueryInfoEquals(queryProcess.queryStatus, newQueryProcess.queryStatus) } test("QueryTerminated serialization") { - val queryTerminatedInfo = new StreamingQueryInfo( - "name", - 1, - Seq( - new SourceStatus("source1", Some(LongOffset(0).toString)), - new SourceStatus("source2", Some(LongOffset(1).toString))), - new SinkStatus("sink", new CompositeOffset(Array(None, Some(LongOffset(1)))).toString)) val exception = new RuntimeException("exception") - val queryQueryTerminated = new StreamingQueryListener.QueryTerminated( - queryTerminatedInfo, + val queryQueryTerminated = new StreamingQueryListener.QueryTerminatedEvent( + StreamingQueryStatus.testStatus, Some(exception.getMessage)) val json = JsonProtocol.sparkEventToJson(queryQueryTerminated) val newQueryTerminated = JsonProtocol.sparkEventFromJson(json) - .asInstanceOf[StreamingQueryListener.QueryTerminated] - assertStreamingQueryInfoEquals(queryQueryTerminated.queryInfo, newQueryTerminated.queryInfo) + .asInstanceOf[StreamingQueryListener.QueryTerminatedEvent] + assertStreamingQueryInfoEquals(queryQueryTerminated.queryStatus, 
newQueryTerminated.queryStatus) assert(queryQueryTerminated.exception === newQueryTerminated.exception) } + test("ReplayListenerBus should ignore broken event jsons generated in 2.0.0") { + // query-event-logs-version-2.0.0.txt has all types of events generated by + // Structured Streaming in Spark 2.0.0. + // SparkListenerApplicationEnd is the only valid event and it's the last event. We use it + // to verify that we can skip broken jsons generated by Structured Streaming. + testReplayListenerBusWithBorkenEventJsons("query-event-logs-version-2.0.0.txt") + } + + test("ReplayListenerBus should ignore broken event jsons generated in 2.0.1") { + // query-event-logs-version-2.0.1.txt has all types of events generated by + // Structured Streaming in Spark 2.0.1. + // SparkListenerApplicationEnd is the only valid event and it's the last event. We use it + // to verify that we can skip broken jsons generated by Structured Streaming. + testReplayListenerBusWithBorkenEventJsons("query-event-logs-version-2.0.1.txt") + } + + private def testReplayListenerBusWithBorkenEventJsons(fileName: String): Unit = { + val input = getClass.getResourceAsStream(s"/structured-streaming/$fileName") + val events = mutable.ArrayBuffer[SparkListenerEvent]() + try { + val replayer = new ReplayListenerBus() { + // Redirect all parsed events to `events` + override def doPostEvent( + listener: SparkListenerInterface, + event: SparkListenerEvent): Unit = { + events += event + } + } + // Add a dummy listener so that "doPostEvent" will be called. + replayer.addListener(new SparkListener {}) + replayer.replay(input, fileName) + // SparkListenerApplicationEnd is the only valid event + assert(events.size === 1) + assert(events(0).isInstanceOf[SparkListenerApplicationEnd]) + } finally { + input.close() + } + } + private def assertStreamingQueryInfoEquals( - expected: StreamingQueryInfo, - actual: StreamingQueryInfo): Unit = { + expected: StreamingQueryStatus, + actual: StreamingQueryStatus): Unit = { assert(expected.name === actual.name) assert(expected.sourceStatuses.size === actual.sourceStatuses.size) expected.sourceStatuses.zip(actual.sourceStatuses).foreach { @@ -243,7 +272,7 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter { private def withListenerAdded(listener: StreamingQueryListener)(body: => Unit): Unit = { try { - failAfter(1 minute) { + failAfter(streamingTimeout) { spark.streams.addListener(listener) body } @@ -258,49 +287,9 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter { val listenerBus = spark.streams invokePrivate listenerBusMethod() listenerBus.listeners.toArray.map(_.asInstanceOf[StreamingQueryListener]) } +} - class QueryStatusCollector extends StreamingQueryListener { - // to catch errors in the async listener events - @volatile private var asyncTestWaiter = new Waiter - - @volatile var startStatus: StreamingQueryInfo = null - @volatile var terminationStatus: StreamingQueryInfo = null - @volatile var terminationException: Option[String] = null - - val progressStatuses = new ConcurrentLinkedQueue[StreamingQueryInfo] - - def reset(): Unit = { - startStatus = null - terminationStatus = null - progressStatuses.clear() - asyncTestWaiter = new Waiter - } - - def checkAsyncErrors(): Unit = { - asyncTestWaiter.await(timeout(streamingTimeout)) - } - - - override def onQueryStarted(queryStarted: QueryStarted): Unit = { - asyncTestWaiter { - startStatus = queryStarted.queryInfo - } - } - - override def onQueryProgress(queryProgress: QueryProgress): Unit = { - 
asyncTestWaiter { - assert(startStatus != null, "onQueryProgress called before onQueryStarted") - progressStatuses.add(queryProgress.queryInfo) - } - } - - override def onQueryTerminated(queryTerminated: QueryTerminated): Unit = { - asyncTestWaiter { - assert(startStatus != null, "onQueryTerminated called before onQueryStarted") - terminationStatus = queryTerminated.queryInfo - terminationException = queryTerminated.exception - } - asyncTestWaiter.dismiss() - } - } +object StreamingQueryListenerSuite { + // Singleton reference to clock that does not get serialized in task closures + @volatile var clock: ManualClock = null } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryStatusSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryStatusSuite.scala new file mode 100644 index 0000000000000..1a98cf2ba74e6 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryStatusSuite.scala @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.streaming + +import org.apache.spark.SparkFunSuite + +class StreamingQueryStatusSuite extends SparkFunSuite { + test("toString") { + assert(StreamingQueryStatus.testStatus.sourceStatuses(0).toString === + """ + |Status of source MySource1 + | Available offset: #0 + | Input rate: 15.5 rows/sec + | Processing rate: 23.5 rows/sec + | Trigger details: + | numRows.input.source: 100 + | latency.getOffset.source: 10 + | latency.getBatch.source: 20 + """.stripMargin.trim, "SourceStatus.toString does not match") + + assert(StreamingQueryStatus.testStatus.sinkStatus.toString === + """ + |Status of sink MySink + | Committed offsets: [#1, -] + """.stripMargin.trim, "SinkStatus.toString does not match") + + assert(StreamingQueryStatus.testStatus.toString === + """ + |Status of query 'query' + | Query id: 1 + | Status timestamp: 123 + | Input rate: 15.5 rows/sec + | Processing rate 23.5 rows/sec + | Latency: 345.0 ms + | Trigger details: + | isDataPresentInTrigger: true + | isTriggerActive: true + | latency.getBatch.total: 20 + | latency.getOffset.total: 10 + | numRows.input.total: 100 + | triggerId: 5 + | Source statuses [1 source]: + | Source 1 - MySource1 + | Available offset: #0 + | Input rate: 15.5 rows/sec + | Processing rate: 23.5 rows/sec + | Trigger details: + | numRows.input.source: 100 + | latency.getOffset.source: 10 + | latency.getBatch.source: 20 + | Sink status - MySink + | Committed offsets: [#1, -] + """.stripMargin.trim, "StreamingQueryStatus.toString does not match") + + } + + test("json") { + assert(StreamingQueryStatus.testStatus.json === + """ + |{"sourceStatuses":[{"description":"MySource1","offsetDesc":"#0","inputRate":15.5, + |"processingRate":23.5,"triggerDetails":{"numRows.input.source":"100", + |"latency.getOffset.source":"10","latency.getBatch.source":"20"}}], + |"sinkStatus":{"description":"MySink","offsetDesc":"[#1, -]"}} + """.stripMargin.replace("\n", "").trim) + } + + test("prettyJson") { + assert( + StreamingQueryStatus.testStatus.prettyJson === + """ + |{ + | "sourceStatuses" : [ { + | "description" : "MySource1", + | "offsetDesc" : "#0", + | "inputRate" : 15.5, + | "processingRate" : 23.5, + | "triggerDetails" : { + | "numRows.input.source" : "100", + | "latency.getOffset.source" : "10", + | "latency.getBatch.source" : "20" + | } + | } ], + | "sinkStatus" : { + | "description" : "MySink", + | "offsetDesc" : "[#1, -]" + | } + |} + """.stripMargin.trim) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala index 88f1f188ab2af..464c443beb6e7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala @@ -17,18 +17,27 @@ package org.apache.spark.sql.streaming +import org.scalactic.TolerantNumerics +import org.scalatest.concurrent.Eventually._ import org.scalatest.BeforeAndAfter +import org.apache.spark.internal.Logging +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.streaming.StreamingQueryListener._ +import org.apache.spark.sql.types.StructType import org.apache.spark.SparkException -import org.apache.spark.sql.execution.streaming.{CompositeOffset, LongOffset, MemoryStream, StreamExecution} +import org.apache.spark.sql.execution.streaming._ import org.apache.spark.util.Utils -class StreamingQuerySuite extends StreamTest with BeforeAndAfter { +class StreamingQuerySuite 
extends StreamTest with BeforeAndAfter with Logging { import AwaitTerminationTester._ import testImplicits._ + // To make === between double tolerate inexact values + implicit val doubleEquality = TolerantNumerics.tolerantDoubleEquality(0.01) + after { sqlContext.streams.active.foreach(_.stop()) } @@ -100,37 +109,151 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter { ) } - testQuietly("source and sink statuses") { + testQuietly("query statuses") { val inputData = MemoryStream[Int] val mapped = inputData.toDS().map(6 / _) - testStream(mapped)( - AssertOnQuery(_.sourceStatuses.length === 1), + AssertOnQuery(q => q.status.name === q.name), + AssertOnQuery(q => q.status.id === q.id), + AssertOnQuery(_.status.timestamp <= System.currentTimeMillis), + AssertOnQuery(_.status.inputRate === 0.0), + AssertOnQuery(_.status.processingRate === 0.0), + AssertOnQuery(_.status.sourceStatuses.length === 1), + AssertOnQuery(_.status.sourceStatuses(0).description.contains("Memory")), + AssertOnQuery(_.status.sourceStatuses(0).offsetDesc === "-"), + AssertOnQuery(_.status.sourceStatuses(0).inputRate === 0.0), + AssertOnQuery(_.status.sourceStatuses(0).processingRate === 0.0), + AssertOnQuery(_.status.sinkStatus.description.contains("Memory")), + AssertOnQuery(_.status.sinkStatus.offsetDesc === CompositeOffset(None :: Nil).toString), AssertOnQuery(_.sourceStatuses(0).description.contains("Memory")), - AssertOnQuery(_.sourceStatuses(0).offsetDesc === None), + AssertOnQuery(_.sourceStatuses(0).offsetDesc === "-"), + AssertOnQuery(_.sourceStatuses(0).inputRate === 0.0), + AssertOnQuery(_.sourceStatuses(0).processingRate === 0.0), AssertOnQuery(_.sinkStatus.description.contains("Memory")), AssertOnQuery(_.sinkStatus.offsetDesc === new CompositeOffset(None :: Nil).toString), + AddData(inputData, 1, 2), CheckAnswer(6, 3), - AssertOnQuery(_.sourceStatuses(0).offsetDesc === Some(LongOffset(0).toString)), + AssertOnQuery(_.status.timestamp <= System.currentTimeMillis), + AssertOnQuery(_.status.inputRate >= 0.0), + AssertOnQuery(_.status.processingRate >= 0.0), + AssertOnQuery(_.status.sourceStatuses.length === 1), + AssertOnQuery(_.status.sourceStatuses(0).description.contains("Memory")), + AssertOnQuery(_.status.sourceStatuses(0).offsetDesc === LongOffset(0).toString), + AssertOnQuery(_.status.sourceStatuses(0).inputRate >= 0.0), + AssertOnQuery(_.status.sourceStatuses(0).processingRate >= 0.0), + AssertOnQuery(_.status.sinkStatus.description.contains("Memory")), + AssertOnQuery(_.status.sinkStatus.offsetDesc === + CompositeOffset.fill(LongOffset(0)).toString), + AssertOnQuery(_.sourceStatuses(0).offsetDesc === LongOffset(0).toString), + AssertOnQuery(_.sourceStatuses(0).inputRate >= 0.0), + AssertOnQuery(_.sourceStatuses(0).processingRate >= 0.0), AssertOnQuery(_.sinkStatus.offsetDesc === CompositeOffset.fill(LongOffset(0)).toString), + AddData(inputData, 1, 2), CheckAnswer(6, 3, 6, 3), - AssertOnQuery(_.sourceStatuses(0).offsetDesc === Some(LongOffset(1).toString)), + AssertOnQuery(_.status.sourceStatuses(0).offsetDesc === LongOffset(1).toString), + AssertOnQuery(_.status.sinkStatus.offsetDesc === + CompositeOffset.fill(LongOffset(1)).toString), + AssertOnQuery(_.sourceStatuses(0).offsetDesc === LongOffset(1).toString), AssertOnQuery(_.sinkStatus.offsetDesc === CompositeOffset.fill(LongOffset(1)).toString), + + StopStream, + AssertOnQuery(_.status.inputRate === 0.0), + AssertOnQuery(_.status.processingRate === 0.0), + AssertOnQuery(_.status.sourceStatuses.length === 1), + 
AssertOnQuery(_.status.sourceStatuses(0).offsetDesc === LongOffset(1).toString), + AssertOnQuery(_.status.sourceStatuses(0).inputRate === 0.0), + AssertOnQuery(_.status.sourceStatuses(0).processingRate === 0.0), + AssertOnQuery(_.status.sinkStatus.offsetDesc === + CompositeOffset.fill(LongOffset(1)).toString), + AssertOnQuery(_.sourceStatuses(0).offsetDesc === LongOffset(1).toString), + AssertOnQuery(_.sourceStatuses(0).inputRate === 0.0), + AssertOnQuery(_.sourceStatuses(0).processingRate === 0.0), + AssertOnQuery(_.sinkStatus.offsetDesc === CompositeOffset.fill(LongOffset(1)).toString), + AssertOnQuery(_.status.triggerDetails.isEmpty), + + StartStream(), AddData(inputData, 0), ExpectFailure[SparkException], - AssertOnQuery(_.sourceStatuses(0).offsetDesc === Some(LongOffset(2).toString)), + AssertOnQuery(_.status.inputRate === 0.0), + AssertOnQuery(_.status.processingRate === 0.0), + AssertOnQuery(_.status.sourceStatuses.length === 1), + AssertOnQuery(_.status.sourceStatuses(0).offsetDesc === LongOffset(2).toString), + AssertOnQuery(_.status.sourceStatuses(0).inputRate === 0.0), + AssertOnQuery(_.status.sourceStatuses(0).processingRate === 0.0), + AssertOnQuery(_.status.sinkStatus.offsetDesc === + CompositeOffset.fill(LongOffset(1)).toString), + AssertOnQuery(_.sourceStatuses(0).offsetDesc === LongOffset(2).toString), + AssertOnQuery(_.sourceStatuses(0).inputRate === 0.0), + AssertOnQuery(_.sourceStatuses(0).processingRate === 0.0), AssertOnQuery(_.sinkStatus.offsetDesc === CompositeOffset.fill(LongOffset(1)).toString) ) } + test("codahale metrics") { + val inputData = MemoryStream[Int] + + /** Whether metrics of a query is registered for reporting */ + def isMetricsRegistered(query: StreamingQuery): Boolean = { + val sourceName = s"StructuredStreaming.${query.name}" + val sources = spark.sparkContext.env.metricsSystem.getSourcesByName(sourceName) + require(sources.size <= 1) + sources.nonEmpty + } + // Disabled by default + assert(spark.conf.get("spark.sql.streaming.metricsEnabled").toBoolean === false) + + withSQLConf("spark.sql.streaming.metricsEnabled" -> "false") { + testStream(inputData.toDF)( + AssertOnQuery { q => !isMetricsRegistered(q) }, + StopStream, + AssertOnQuery { q => !isMetricsRegistered(q) } + ) + } + + // Registered when enabled + withSQLConf("spark.sql.streaming.metricsEnabled" -> "true") { + testStream(inputData.toDF)( + AssertOnQuery { q => isMetricsRegistered(q) }, + StopStream, + AssertOnQuery { q => !isMetricsRegistered(q) } + ) + } + } + + test("input row calculation with mixed batch and streaming sources") { + val streamingTriggerDF = spark.createDataset(1 to 10).toDF + val streamingInputDF = createSingleTriggerStreamingDF(streamingTriggerDF).toDF("value") + val staticInputDF = spark.createDataFrame(Seq(1 -> "1", 2 -> "2")).toDF("value", "anotherValue") + + // Trigger input has 10 rows, static input has 2 rows, + // therefore after the first trigger, the calculated input rows should be 10 + val status = getFirstTriggerStatus(streamingInputDF.join(staticInputDF, "value")) + assert(status.triggerDetails.get("numRows.input.total") === "10") + assert(status.sourceStatuses.size === 1) + assert(status.sourceStatuses(0).triggerDetails.get("numRows.input.source") === "10") + } + + test("input row calculation with trigger DF having multiple leaves") { + val streamingTriggerDF = + spark.createDataset(1 to 5).toDF.union(spark.createDataset(6 to 10).toDF) + require(streamingTriggerDF.logicalPlan.collectLeaves().size > 1) + val streamingInputDF = 
createSingleTriggerStreamingDF(streamingTriggerDF)
+
+    // After the first trigger, the calculated input rows should be 10
+    val status = getFirstTriggerStatus(streamingInputDF)
+    assert(status.triggerDetails.get("numRows.input.total") === "10")
+    assert(status.sourceStatuses.size === 1)
+    assert(status.sourceStatuses(0).triggerDetails.get("numRows.input.source") === "10")
+  }
+
   testQuietly("StreamExecution metadata garbage collection") {
     val inputData = MemoryStream[Int]
     val mapped = inputData.toDS().map(6 / _)
 
-    // Run 3 batches, and then assert that only 1 metadata file is left at the end
-    // since the first 2 should have been purged.
+    // Run 3 batches, and then assert that only 2 metadata files are left at the end
+    // since the first one should have been purged.
     testStream(mapped)(
       AddData(inputData, 1, 2),
       CheckAnswer(6, 3),
@@ -139,16 +262,55 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter {
       AddData(inputData, 4, 6),
       CheckAnswer(6, 3, 6, 3, 1, 1),
 
-      AssertOnQuery("metadata log should contain only one file") { q =>
+      AssertOnQuery("metadata log should contain only two files") { q =>
         val metadataLogDir = new java.io.File(q.offsetLog.metadataPath.toString)
         val logFileNames = metadataLogDir.listFiles().toSeq.map(_.getName())
-        val toTest = logFileNames.filter(! _.endsWith(".crc")) // Workaround for SPARK-17475
-        assert(toTest.size == 1 && toTest.head == "2")
+        val toTest = logFileNames.filter(! _.endsWith(".crc")).sorted // Workaround for SPARK-17475
+        assert(toTest.size == 2 && toTest.head == "1")
         true
       }
     )
   }
 
+  /** Creates a streaming DF that executes only one batch, returning the given static DF */
+  private def createSingleTriggerStreamingDF(triggerDF: DataFrame): DataFrame = {
+    require(!triggerDF.isStreaming)
+    // A streaming Source that generates one trigger and returns the given DataFrame as its batch
+    val source = new Source() {
+      override def schema: StructType = triggerDF.schema
+      override def getOffset: Option[Offset] = Some(LongOffset(0))
+      override def getBatch(start: Option[Offset], end: Offset): DataFrame = triggerDF
+      override def stop(): Unit = {}
+    }
+    StreamingExecutionRelation(source)
+  }
+
+  /** Returns the query status at the end of the first trigger of the streaming DF */
+  private def getFirstTriggerStatus(streamingDF: DataFrame): StreamingQueryStatus = {
+    // A StreamingQueryListener that gets the query status after the first completed trigger
+    val listener = new StreamingQueryListener {
+      @volatile var firstStatus: StreamingQueryStatus = null
+      override def onQueryStarted(queryStarted: QueryStartedEvent): Unit = { }
+      override def onQueryProgress(queryProgress: QueryProgressEvent): Unit = {
+        if (firstStatus == null) firstStatus = queryProgress.queryStatus
+      }
+      override def onQueryTerminated(queryTerminated: QueryTerminatedEvent): Unit = { }
+    }
+
+    try {
+      spark.streams.addListener(listener)
+      val q = streamingDF.writeStream.format("memory").queryName("test").start()
+      q.processAllAvailable()
+      eventually(timeout(streamingTimeout)) {
+        assert(listener.firstStatus != null)
+      }
+      listener.firstStatus
+    } finally {
+      spark.streams.active.map(_.stop())
+      spark.streams.removeListener(listener)
+    }
+  }
+
   /**
   * A [[StreamAction]] to test the behavior of `StreamingQuery.awaitTermination()`.
* diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala index 6a5117aea492d..226b7e175a9d9 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala @@ -79,6 +79,9 @@ private[hive] class SparkSQLSessionManager(hiveServer: HiveServer2, sqlContext: sqlContext.newSession() } ctx.setConf("spark.sql.hive.version", HiveUtils.hiveExecutionVersion) + if (sessionConf != null && sessionConf.containsKey("use:database")) { + ctx.sql(s"use ${sessionConf.get("use:database")}") + } sparkSqlOperationManager.sessionToContexts.put(sessionHandle, ctx) sessionHandle } diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/JdbcConnectionUriSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/JdbcConnectionUriSuite.scala new file mode 100644 index 0000000000000..fb8a7e273ae44 --- /dev/null +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/JdbcConnectionUriSuite.scala @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.hive.thriftserver + +import java.sql.DriverManager + +import org.apache.hive.jdbc.HiveDriver + +import org.apache.spark.util.Utils + +class JdbcConnectionUriSuite extends HiveThriftServer2Test { + Utils.classForName(classOf[HiveDriver].getCanonicalName) + + override def mode: ServerMode.Value = ServerMode.binary + + val JDBC_TEST_DATABASE = "jdbc_test_database" + val USER = System.getProperty("user.name") + val PASSWORD = "" + + override protected def beforeAll(): Unit = { + super.beforeAll() + + val jdbcUri = s"jdbc:hive2://localhost:$serverPort/" + val connection = DriverManager.getConnection(jdbcUri, USER, PASSWORD) + val statement = connection.createStatement() + statement.execute(s"CREATE DATABASE $JDBC_TEST_DATABASE") + connection.close() + } + + override protected def afterAll(): Unit = { + try { + val jdbcUri = s"jdbc:hive2://localhost:$serverPort/" + val connection = DriverManager.getConnection(jdbcUri, USER, PASSWORD) + val statement = connection.createStatement() + statement.execute(s"DROP DATABASE $JDBC_TEST_DATABASE") + connection.close() + } finally { + super.afterAll() + } + } + + test("SPARK-17819 Support default database in connection URIs") { + val jdbcUri = s"jdbc:hive2://localhost:$serverPort/$JDBC_TEST_DATABASE" + val connection = DriverManager.getConnection(jdbcUri, USER, PASSWORD) + val statement = connection.createStatement() + try { + val resultSet = statement.executeQuery("select current_database()") + resultSet.next() + assert(resultSet.getString(1) === JDBC_TEST_DATABASE) + } finally { + statement.close() + connection.close() + } + } +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index 237b829da882f..409c316c6802c 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -29,16 +29,18 @@ import org.apache.thrift.TException import org.apache.spark.SparkConf import org.apache.spark.internal.Logging import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier} import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException import org.apache.spark.sql.catalyst.catalog._ +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Statistics} import org.apache.spark.sql.execution.command.{ColumnStatStruct, DDLUtils} import org.apache.spark.sql.execution.datasources.CaseInsensitiveMap import org.apache.spark.sql.hive.client.HiveClient import org.apache.spark.sql.internal.HiveSerDe +import org.apache.spark.sql.internal.SQLConf._ import org.apache.spark.sql.internal.StaticSQLConf._ -import org.apache.spark.sql.types.{DataType, StructType} +import org.apache.spark.sql.types.{DataType, StructField, StructType} /** @@ -104,13 +106,11 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat * metastore. 
*/ private def verifyTableProperties(table: CatalogTable): Unit = { - val invalidKeys = table.properties.keys.filter { key => - key.startsWith(DATASOURCE_PREFIX) || key.startsWith(STATISTICS_PREFIX) - } + val invalidKeys = table.properties.keys.filter(_.startsWith(SPARK_SQL_PREFIX)) if (invalidKeys.nonEmpty) { throw new AnalysisException(s"Cannot persistent ${table.qualifiedName} into hive metastore " + - s"as table property keys may not start with '$DATASOURCE_PREFIX' or '$STATISTICS_PREFIX':" + - s" ${invalidKeys.mkString("[", ", ", "]")}") + s"as table property keys may not start with '$SPARK_SQL_PREFIX': " + + invalidKeys.mkString("[", ", ", "]")) } // External users are not allowed to set/switch the table type. In Hive metastore, the table // type can be switched by changing the value of a case-sensitive table property `EXTERNAL`. @@ -189,11 +189,12 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat throw new TableAlreadyExistsException(db = db, table = table) } // Before saving data source table metadata into Hive metastore, we should: - // 1. Put table schema, partition column names and bucket specification in table properties. + // 1. Put table provider, schema, partition column names, bucket specification and partition + // provider in table properties. // 2. Check if this table is hive compatible // 2.1 If it's not hive compatible, set schema, partition columns and bucket spec to empty // and save table metadata to Hive. - // 2.1 If it's hive compatible, set serde information in table metadata and try to save + // 2.2 If it's hive compatible, set serde information in table metadata and try to save // it to Hive. If it fails, treat it as not hive compatible and go back to 2.1 if (DDLUtils.isDatasourceTable(tableDefinition)) { // data source table always have a provider, it's guaranteed by `DDLUtils.isDatasourceTable`. @@ -203,6 +204,9 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat val tableProperties = new scala.collection.mutable.HashMap[String, String] tableProperties.put(DATASOURCE_PROVIDER, provider) + if (tableDefinition.partitionProviderIsHive) { + tableProperties.put(TABLE_PARTITION_PROVIDER, "hive") + } // Serialized JSON schema string may be too long to be stored into a single metastore table // property. In this case, we split the JSON string and store each part as a separate table @@ -240,12 +244,12 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat } } - // converts the table metadata to Spark SQL specific format, i.e. set schema, partition column - // names and bucket specification to empty. + // converts the table metadata to Spark SQL specific format, i.e. set data schema, names and + // bucket specification to empty. Note that partition columns are retained, so that we can + // call partition-related Hive API later. def newSparkSQLSpecificMetastoreTable(): CatalogTable = { tableDefinition.copy( - schema = new StructType, - partitionColumnNames = Nil, + schema = tableDefinition.partitionSchema, bucketSpec = None, properties = tableDefinition.properties ++ tableProperties) } @@ -418,12 +422,17 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat // Sets the `schema`, `partitionColumnNames` and `bucketSpec` from the old table definition, // to retain the spark specific format if it is. Also add old data source properties to table // properties, to retain the data source table format. 
- val oldDataSourceProps = oldDef.properties.filter(_._1.startsWith(DATASOURCE_PREFIX)) + val oldDataSourceProps = oldDef.properties.filter(_._1.startsWith(SPARK_SQL_PREFIX)) + val partitionProviderProp = if (tableDefinition.partitionProviderIsHive) { + TABLE_PARTITION_PROVIDER -> "hive" + } else { + TABLE_PARTITION_PROVIDER -> "builtin" + } val newDef = withStatsProps.copy( schema = oldDef.schema, partitionColumnNames = oldDef.partitionColumnNames, bucketSpec = oldDef.bucketSpec, - properties = oldDataSourceProps ++ withStatsProps.properties) + properties = oldDataSourceProps ++ withStatsProps.properties + partitionProviderProp) client.alterTable(newDef) } else { @@ -447,7 +456,11 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat * properties, and filter out these special entries from table properties. */ private def restoreTableMetadata(table: CatalogTable): CatalogTable = { - val catalogTable = if (table.tableType == VIEW) { + if (conf.get(DEBUG_MODE)) { + return table + } + + val tableWithSchema = if (table.tableType == VIEW) { table } else { getProviderFromTableProperties(table).map { provider => @@ -466,41 +479,38 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat } else { table.storage } - val tableProps = if (conf.get(DEBUG_MODE)) { - table.properties - } else { - getOriginalTableProperties(table) - } table.copy( storage = storage, schema = getSchemaFromTableProperties(table), provider = Some(provider), partitionColumnNames = getPartitionColumnsFromTableProperties(table), bucketSpec = getBucketSpecFromTableProperties(table), - properties = tableProps) + partitionProviderIsHive = table.properties.get(TABLE_PARTITION_PROVIDER) == Some("hive")) } getOrElse { - table.copy(provider = Some("hive")) + table.copy(provider = Some("hive"), partitionProviderIsHive = true) } } + // construct Spark's statistics from information in Hive metastore - val statsProps = catalogTable.properties.filterKeys(_.startsWith(STATISTICS_PREFIX)) - if (statsProps.nonEmpty) { + val statsProps = tableWithSchema.properties.filterKeys(_.startsWith(STATISTICS_PREFIX)) + val tableWithStats = if (statsProps.nonEmpty) { val colStatsProps = statsProps.filterKeys(_.startsWith(STATISTICS_COL_STATS_PREFIX)) .map { case (k, v) => (k.drop(STATISTICS_COL_STATS_PREFIX.length), v) } - val colStats: Map[String, ColumnStat] = catalogTable.schema.collect { + val colStats: Map[String, ColumnStat] = tableWithSchema.schema.collect { case f if colStatsProps.contains(f.name) => val numFields = ColumnStatStruct.numStatFields(f.dataType) (f.name, ColumnStat(numFields, colStatsProps(f.name))) }.toMap - catalogTable.copy( - properties = removeStatsProperties(catalogTable), + tableWithSchema.copy( stats = Some(Statistics( - sizeInBytes = BigInt(catalogTable.properties(STATISTICS_TOTAL_SIZE)), - rowCount = catalogTable.properties.get(STATISTICS_NUM_ROWS).map(BigInt(_)), + sizeInBytes = BigInt(tableWithSchema.properties(STATISTICS_TOTAL_SIZE)), + rowCount = tableWithSchema.properties.get(STATISTICS_NUM_ROWS).map(BigInt(_)), colStats = colStats))) } else { - catalogTable + tableWithSchema } + + tableWithStats.copy(properties = getOriginalTableProperties(table)) } override def tableExists(db: String, table: String): Boolean = withClient { @@ -585,13 +595,30 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat // Partitions // -------------------------------------------------------------------------- + // Hive metastore is not case preserving and 
the partition columns are always lower cased. We need + // to lower case the column names in partition specification before calling partition related Hive + // APIs, to match this behaviour. + private def lowerCasePartitionSpec(spec: TablePartitionSpec): TablePartitionSpec = { + spec.map { case (k, v) => k.toLowerCase -> v } + } + + // Hive metastore is not case preserving and the column names of the partition specification we + // get from the metastore are always lower cased. We should restore them w.r.t. the actual table + // partition columns. + private def restorePartitionSpec( + spec: TablePartitionSpec, + partCols: Seq[String]): TablePartitionSpec = { + spec.map { case (k, v) => partCols.find(_.equalsIgnoreCase(k)).get -> v } + } + override def createPartitions( db: String, table: String, parts: Seq[CatalogTablePartition], ignoreIfExists: Boolean): Unit = withClient { requireTableExists(db, table) - client.createPartitions(db, table, parts, ignoreIfExists) + val lowerCasedParts = parts.map(p => p.copy(spec = lowerCasePartitionSpec(p.spec))) + client.createPartitions(db, table, lowerCasedParts, ignoreIfExists) } override def dropPartitions( @@ -601,7 +628,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat ignoreIfNotExists: Boolean, purge: Boolean): Unit = withClient { requireTableExists(db, table) - client.dropPartitions(db, table, parts, ignoreIfNotExists, purge) + client.dropPartitions(db, table, parts.map(lowerCasePartitionSpec), ignoreIfNotExists, purge) } override def renamePartitions( @@ -609,21 +636,24 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat table: String, specs: Seq[TablePartitionSpec], newSpecs: Seq[TablePartitionSpec]): Unit = withClient { - client.renamePartitions(db, table, specs, newSpecs) + client.renamePartitions( + db, table, specs.map(lowerCasePartitionSpec), newSpecs.map(lowerCasePartitionSpec)) } override def alterPartitions( db: String, table: String, newParts: Seq[CatalogTablePartition]): Unit = withClient { - client.alterPartitions(db, table, newParts) + val lowerCasedParts = newParts.map(p => p.copy(spec = lowerCasePartitionSpec(p.spec))) + client.alterPartitions(db, table, lowerCasedParts) } override def getPartition( db: String, table: String, spec: TablePartitionSpec): CatalogTablePartition = withClient { - client.getPartition(db, table, spec) + val part = client.getPartition(db, table, lowerCasePartitionSpec(spec)) + part.copy(spec = restorePartitionSpec(part.spec, getTable(db, table).partitionColumnNames)) } /** @@ -633,7 +663,9 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat db: String, table: String, spec: TablePartitionSpec): Option[CatalogTablePartition] = withClient { - client.getPartitionOption(db, table, spec) + client.getPartitionOption(db, table, lowerCasePartitionSpec(spec)).map { part => + part.copy(spec = restorePartitionSpec(part.spec, getTable(db, table).partitionColumnNames)) + } } /** @@ -643,7 +675,45 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat db: String, table: String, partialSpec: Option[TablePartitionSpec] = None): Seq[CatalogTablePartition] = withClient { - client.getPartitions(db, table, partialSpec) + client.getPartitions(db, table, partialSpec.map(lowerCasePartitionSpec)).map { part => + part.copy(spec = restorePartitionSpec(part.spec, getTable(db, table).partitionColumnNames)) + } + } + + override def listPartitionsByFilter( + db: String, + table: String, + predicates: 
Seq[Expression]): Seq[CatalogTablePartition] = withClient { + val rawTable = client.getTable(db, table) + val catalogTable = restoreTableMetadata(rawTable) + val partitionColumnNames = catalogTable.partitionColumnNames.toSet + val nonPartitionPruningPredicates = predicates.filterNot { + _.references.map(_.name).toSet.subsetOf(partitionColumnNames) + } + + if (nonPartitionPruningPredicates.nonEmpty) { + sys.error("Expected only partition pruning predicates: " + + predicates.reduceLeft(And)) + } + + val partitionSchema = catalogTable.partitionSchema + + if (predicates.nonEmpty) { + val clientPrunedPartitions = client.getPartitionsByFilter(rawTable, predicates).map { part => + part.copy(spec = restorePartitionSpec(part.spec, catalogTable.partitionColumnNames)) + } + val boundPredicate = + InterpretedPredicate.create(predicates.reduce(And).transform { + case att: AttributeReference => + val index = partitionSchema.indexWhere(_.name == att.name) + BoundReference(index, partitionSchema(index).dataType, nullable = true) + }) + clientPrunedPartitions.filter { p => boundPredicate(p.toRow(partitionSchema)) } + } else { + client.getPartitions(catalogTable).map { part => + part.copy(spec = restorePartitionSpec(part.spec, catalogTable.partitionColumnNames)) + } + } } // -------------------------------------------------------------------------- @@ -692,7 +762,9 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat } object HiveExternalCatalog { - val DATASOURCE_PREFIX = "spark.sql.sources." + val SPARK_SQL_PREFIX = "spark.sql." + + val DATASOURCE_PREFIX = SPARK_SQL_PREFIX + "sources." val DATASOURCE_PROVIDER = DATASOURCE_PREFIX + "provider" val DATASOURCE_SCHEMA = DATASOURCE_PREFIX + "schema" val DATASOURCE_SCHEMA_PREFIX = DATASOURCE_SCHEMA + "." @@ -706,21 +778,20 @@ object HiveExternalCatalog { val DATASOURCE_SCHEMA_BUCKETCOL_PREFIX = DATASOURCE_SCHEMA_PREFIX + "bucketCol." val DATASOURCE_SCHEMA_SORTCOL_PREFIX = DATASOURCE_SCHEMA_PREFIX + "sortCol." - val STATISTICS_PREFIX = "spark.sql.statistics." + val STATISTICS_PREFIX = SPARK_SQL_PREFIX + "statistics." val STATISTICS_TOTAL_SIZE = STATISTICS_PREFIX + "totalSize" val STATISTICS_NUM_ROWS = STATISTICS_PREFIX + "numRows" val STATISTICS_COL_STATS_PREFIX = STATISTICS_PREFIX + "colStats." - def removeStatsProperties(metadata: CatalogTable): Map[String, String] = { - metadata.properties.filterNot { case (key, _) => key.startsWith(STATISTICS_PREFIX) } - } + val TABLE_PARTITION_PROVIDER = SPARK_SQL_PREFIX + "partitionProvider" + def getProviderFromTableProperties(metadata: CatalogTable): Option[String] = { metadata.properties.get(DATASOURCE_PROVIDER) } def getOriginalTableProperties(metadata: CatalogTable): Map[String, String] = { - metadata.properties.filterNot { case (key, _) => key.startsWith(DATASOURCE_PREFIX) } + metadata.properties.filterNot { case (key, _) => key.startsWith(SPARK_SQL_PREFIX) } } // A persisted data source table always store its schema in the catalog. 
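
The case handling that the new lowerCasePartitionSpec/restorePartitionSpec helpers perform is easiest to see with a concrete value. Below is a minimal, self-contained sketch of that round trip; the object name, partition columns and spec are hypothetical and only illustrate the private helpers added to HiveExternalCatalog in the hunk above:

object PartitionSpecCaseDemo {
  type TablePartitionSpec = Map[String, String]

  // Mirrors lowerCasePartitionSpec: the Hive metastore only ever sees lower-cased keys.
  def lowerCase(spec: TablePartitionSpec): TablePartitionSpec =
    spec.map { case (k, v) => k.toLowerCase -> v }

  // Mirrors restorePartitionSpec: keys coming back from the metastore are re-cased
  // to match the table's declared partition column names.
  def restore(spec: TablePartitionSpec, partCols: Seq[String]): TablePartitionSpec =
    spec.map { case (k, v) => partCols.find(_.equalsIgnoreCase(k)).get -> v }

  def main(args: Array[String]): Unit = {
    val declaredCols = Seq("PartDate", "Country")                  // hypothetical partition columns
    val spec = Map("PartDate" -> "2016-10-01", "Country" -> "US")  // hypothetical partition spec
    val sentToHive = lowerCase(spec)                               // Map(partdate -> 2016-10-01, country -> US)
    val restored = restore(sentToHive, declaredCols)
    assert(restored == spec)                                       // original casing is recovered
  }
}

The same pattern is what lets listPartitionsByFilter above push predicates to Hive against lower-cased names while still returning partition specs that match the table's declared column casing.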
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index 1625116803505..e303065127c3b 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -473,10 +473,8 @@ private[hive] trait HiveInspectors { case mi: StandardConstantMapObjectInspector => val keyUnwrapper = unwrapperFor(mi.getMapKeyObjectInspector) val valueUnwrapper = unwrapperFor(mi.getMapValueObjectInspector) - val keyValues = mi.getWritableConstantValue.asScala.toSeq - val keys = keyValues.map(kv => keyUnwrapper(kv._1)).toArray - val values = keyValues.map(kv => valueUnwrapper(kv._2)).toArray - val constant = ArrayBasedMapData(keys, values) + val keyValues = mi.getWritableConstantValue + val constant = ArrayBasedMapData(keyValues, keyUnwrapper, valueUnwrapper) _ => constant case li: StandardConstantListObjectInspector => val unwrapper = unwrapperFor(li.getListElementObjectInspector) @@ -655,10 +653,7 @@ private[hive] trait HiveInspectors { if (map == null) { null } else { - val keyValues = map.asScala.toSeq - val keys = keyValues.map(kv => keyUnwrapper(kv._1)).toArray - val values = keyValues.map(kv => valueUnwrapper(kv._2)).toArray - ArrayBasedMapData(keys, values) + ArrayBasedMapData(map, keyUnwrapper, valueUnwrapper) } } else { null diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 8410a2e4a47ca..624ab747e442f 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.execution.command.DDLUtils -import org.apache.spark.sql.execution.datasources.{Partition => _, _} +import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat, ParquetOptions} import org.apache.spark.sql.hive.orc.OrcFileFormat import org.apache.spark.sql.types._ @@ -44,8 +44,6 @@ import org.apache.spark.sql.types._ */ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Logging { private val sessionState = sparkSession.sessionState.asInstanceOf[HiveSessionState] - private val client = - sparkSession.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client /** A fully qualified identifier for a table (i.e., database.tableName) */ case class QualifiedTableName(database: String, name: String) @@ -78,11 +76,10 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log partitionColumns = table.partitionColumnNames, bucketSpec = table.bucketSpec, className = table.provider.get, - options = table.storage.properties) + options = table.storage.properties, + catalogTable = Some(table)) - LogicalRelation( - dataSource.resolveRelation(), - catalogTable = Some(table)) + LogicalRelation(dataSource.resolveRelation(), catalogTable = Some(table)) } } @@ -104,7 +101,8 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log def hiveDefaultTableFilePath(tableIdent: TableIdentifier): String = { // Code based on: hiveWarehouse.getTablePath(currentDatabase, tableName) val 
QualifiedTableName(dbName, tblName) = getQualifiedTableName(tableIdent) - new Path(new Path(client.getDatabase(dbName).locationUri), tblName).toString + val dbLocation = sparkSession.sharedState.externalCatalog.getDatabase(dbName).locationUri + new Path(new Path(dbLocation), tblName).toString } def lookupRelation( @@ -129,19 +127,19 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log } else { val qualifiedTable = MetastoreRelation( - qualifiedTableName.database, qualifiedTableName.name)(table, client, sparkSession) + qualifiedTableName.database, qualifiedTableName.name)(table, sparkSession) alias.map(a => SubqueryAlias(a, qualifiedTable, None)).getOrElse(qualifiedTable) } } private def getCached( tableIdentifier: QualifiedTableName, - pathsInMetastore: Seq[String], + pathsInMetastore: Seq[Path], metastoreRelation: MetastoreRelation, schemaInMetastore: StructType, expectedFileFormat: Class[_ <: FileFormat], expectedBucketSpec: Option[BucketSpec], - partitionSpecInMetastore: Option[PartitionSpec]): Option[LogicalRelation] = { + partitionSchema: Option[StructType]): Option[LogicalRelation] = { cachedDataSourceTables.getIfPresent(tableIdentifier) match { case null => None // Cache miss @@ -153,12 +151,10 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log // If we have the same paths, same schema, and same partition spec, // we will use the cached relation. val useCached = - relation.location.paths.map(_.toString).toSet == pathsInMetastore.toSet && + relation.location.rootPaths.toSet == pathsInMetastore.toSet && logical.schema.sameType(schemaInMetastore) && relation.bucketSpec == expectedBucketSpec && - relation.partitionSpec == partitionSpecInMetastore.getOrElse { - PartitionSpec(StructType(Nil), Array.empty[PartitionDirectory]) - } + relation.partitionSchema == partitionSchema.getOrElse(StructType(Nil)) if (useCached) { Some(logical) @@ -197,61 +193,57 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log QualifiedTableName(metastoreRelation.databaseName, metastoreRelation.tableName) val bucketSpec = None // We don't support hive bucketed tables, only ones we write out. + val lazyPruningEnabled = sparkSession.sqlContext.conf.manageFilesourcePartitions val result = if (metastoreRelation.hiveQlTable.isPartitioned) { val partitionSchema = StructType.fromAttributes(metastoreRelation.partitionKeys) - val partitionColumnDataTypes = partitionSchema.map(_.dataType) - // We're converting the entire table into HadoopFsRelation, so predicates to Hive metastore - // are empty. - val partitions = metastoreRelation.getHiveQlPartitions().map { p => - val location = p.getLocation - val values = InternalRow.fromSeq(p.getValues.asScala.zip(partitionColumnDataTypes).map { - case (rawValue, dataType) => Cast(Literal(rawValue), dataType).eval(null) - }) - PartitionDirectory(values, location) - } - val partitionSpec = PartitionSpec(partitionSchema, partitions) - val partitionPaths = partitions.map(_.path.toString) - - // By convention (for example, see MetaStorePartitionedTableFileCatalog), the definition of a - // partitioned table's paths depends on whether that table has any actual partitions. - // Partitioned tables without partitions use the location of the table's base path. - // Partitioned tables with partitions use the locations of those partitions' data locations, - // _omitting_ the table's base path. 
- val paths = if (partitionPaths.isEmpty) { - Seq(metastoreRelation.hiveQlTable.getDataLocation.toString) + + val rootPaths: Seq[Path] = if (lazyPruningEnabled) { + Seq(metastoreRelation.hiveQlTable.getDataLocation) } else { - partitionPaths + // By convention (for example, see CatalogFileIndex), the definition of a + // partitioned table's paths depends on whether that table has any actual partitions. + // Partitioned tables without partitions use the location of the table's base path. + // Partitioned tables with partitions use the locations of those partitions' data + // locations,_omitting_ the table's base path. + val paths = metastoreRelation.getHiveQlPartitions().map { p => + new Path(p.getLocation) + } + if (paths.isEmpty) { + Seq(metastoreRelation.hiveQlTable.getDataLocation) + } else { + paths + } } val cached = getCached( tableIdentifier, - paths, + rootPaths, metastoreRelation, metastoreSchema, fileFormatClass, bucketSpec, - Some(partitionSpec)) - - val hadoopFsRelation = cached.getOrElse { - val fileCatalog = new MetaStorePartitionedTableFileCatalog( - sparkSession, - new Path(metastoreRelation.catalogTable.storage.locationUri.get), - partitionSpec) - - val inferredSchema = if (fileType.equals("parquet")) { - val inferredSchema = - defaultSource.inferSchema(sparkSession, options, fileCatalog.allFiles()) - inferredSchema.map { inferred => - ParquetFileFormat.mergeMetastoreParquetSchema(metastoreSchema, inferred) - }.getOrElse(metastoreSchema) - } else { - defaultSource.inferSchema(sparkSession, options, fileCatalog.allFiles()).get + Some(partitionSchema)) + + val logicalRelation = cached.getOrElse { + val sizeInBytes = metastoreRelation.statistics.sizeInBytes.toLong + val fileCatalog = { + val catalog = new CatalogFileIndex( + sparkSession, metastoreRelation.catalogTable, sizeInBytes) + if (lazyPruningEnabled) { + catalog + } else { + catalog.filterPartitions(Nil) // materialize all the partitions in memory + } } + val partitionSchemaColumnNames = partitionSchema.map(_.name.toLowerCase).toSet + val dataSchema = + StructType(metastoreSchema + .filterNot(field => partitionSchemaColumnNames.contains(field.name.toLowerCase))) val relation = HadoopFsRelation( location = fileCatalog, partitionSchema = partitionSchema, - dataSchema = inferredSchema, + dataSchema = dataSchema, bucketSpec = bucketSpec, fileFormat = defaultSource, options = options)(sparkSession = sparkSession) @@ -261,12 +253,12 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log created } - hadoopFsRelation + logicalRelation } else { - val paths = Seq(metastoreRelation.hiveQlTable.getDataLocation.toString) + val rootPath = metastoreRelation.hiveQlTable.getDataLocation val cached = getCached(tableIdentifier, - paths, + Seq(rootPath), metastoreRelation, metastoreSchema, fileFormatClass, @@ -277,14 +269,13 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log LogicalRelation( DataSource( sparkSession = sparkSession, - paths = paths, + paths = rootPath.toString :: Nil, userSpecifiedSchema = Some(metastoreRelation.schema), bucketSpec = bucketSpec, options = options, className = fileType).resolveRelation(), catalogTable = Some(metastoreRelation.catalogTable)) - cachedDataSourceTables.put(tableIdentifier, created) created } @@ -372,34 +363,3 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log } } } - -/** - * An override of the standard HDFS listing based catalog, that overrides the partition spec with - * the information from the 
metastore. - * - * @param tableBasePath The default base path of the Hive metastore table - * @param partitionSpec The partition specifications from Hive metastore - */ -private[hive] class MetaStorePartitionedTableFileCatalog( - sparkSession: SparkSession, - tableBasePath: Path, - override val partitionSpec: PartitionSpec) - extends ListingFileCatalog( - sparkSession, - MetaStorePartitionedTableFileCatalog.getPaths(tableBasePath, partitionSpec), - Map.empty, - Some(partitionSpec.partitionColumns)) { -} - -private[hive] object MetaStorePartitionedTableFileCatalog { - /** Get the list of paths to list files in the for a metastore table */ - def getPaths(tableBasePath: Path, partitionSpec: PartitionSpec): Seq[Path] = { - // If there are no partitions currently specified then use base path, - // otherwise use the paths corresponding to the partitions. - if (partitionSpec.partitions.isEmpty) { - Seq(tableBasePath) - } else { - partitionSpec.partitions.map(_.path) - } - } -} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala index 85ecf0ce70756..4f2910abfd216 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala @@ -27,7 +27,7 @@ import org.apache.hadoop.hive.ql.udf.generic.{AbstractGenericUDAFResolver, Gener import org.apache.spark.sql.{AnalysisException, SparkSession} import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} -import org.apache.spark.sql.catalyst.analysis.FunctionRegistry +import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, NoSuchTableException} import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.catalog.{FunctionResourceLoader, GlobalTempViewManager, SessionCatalog} import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, ExpressionInfo} @@ -57,7 +57,13 @@ private[sql] class HiveSessionCatalog( override def lookupRelation(name: TableIdentifier, alias: Option[String]): LogicalPlan = { val table = formatTableName(name.table) - if (name.database.isDefined || !tempTables.contains(table)) { + val db = formatDatabaseName(name.database.getOrElse(currentDb)) + if (db == globalTempViewManager.database) { + val relationAlias = alias.getOrElse(table) + globalTempViewManager.get(table).map { viewDef => + SubqueryAlias(relationAlias, viewDef, Some(name)) + }.getOrElse(throw new NoSuchTableException(db, table)) + } else if (name.database.isDefined || !tempTables.contains(table)) { val database = name.database.map(formatDatabaseName) val newName = name.copy(database = database, table = table) metastoreCatalog.lookupRelation(newName, alias) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/MetastoreRelation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/MetastoreRelation.scala index 33f0ecff63529..da809cf991de2 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/MetastoreRelation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/MetastoreRelation.scala @@ -43,7 +43,6 @@ private[hive] case class MetastoreRelation( databaseName: String, tableName: String) (val catalogTable: CatalogTable, - @transient private val client: HiveClient, @transient private val sparkSession: SparkSession) extends LeafNode with MultiInstanceRelation with FileRelation with CatalogRelation { @@ -59,7 +58,7 @@ private[hive] case class 
MetastoreRelation( Objects.hashCode(databaseName, tableName, output) } - override protected def otherCopyArgs: Seq[AnyRef] = catalogTable :: client :: sparkSession :: Nil + override protected def otherCopyArgs: Seq[AnyRef] = catalogTable :: sparkSession :: Nil private def toHiveColumn(c: StructField): FieldSchema = { new FieldSchema(c.name, c.dataType.catalogString, c.getComment.orNull) @@ -146,11 +145,18 @@ private[hive] case class MetastoreRelation( // When metastore partition pruning is turned off, we cache the list of all partitions to // mimic the behavior of Spark < 1.5 - private lazy val allPartitions: Seq[CatalogTablePartition] = client.getPartitions(catalogTable) + private lazy val allPartitions: Seq[CatalogTablePartition] = { + sparkSession.sharedState.externalCatalog.listPartitions( + catalogTable.database, + catalogTable.identifier.table) + } def getHiveQlPartitions(predicates: Seq[Expression] = Nil): Seq[Partition] = { val rawPartitions = if (sparkSession.sessionState.conf.metastorePartitionPruning) { - client.getPartitionsByFilter(catalogTable, predicates) + sparkSession.sharedState.externalCatalog.listPartitionsByFilter( + catalogTable.database, + catalogTable.identifier.table, + predicates) } else { allPartitions } @@ -234,8 +240,7 @@ private[hive] case class MetastoreRelation( val columnOrdinals = AttributeMap(attributes.zipWithIndex) override def inputFiles: Array[String] = { - val partLocations = client - .getPartitionsByFilter(catalogTable, Nil) + val partLocations = allPartitions .flatMap(_.storage.locationUri) .toArray if (partLocations.nonEmpty) { @@ -248,6 +253,6 @@ private[hive] case class MetastoreRelation( } override def newInstance(): MetastoreRelation = { - MetastoreRelation(databaseName, tableName)(catalogTable, client, sparkSession) + MetastoreRelation(databaseName, tableName)(catalogTable, sparkSession) } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala index 2a54163a04e9b..aaf30f41f29c2 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala @@ -149,8 +149,7 @@ class HadoopTableReader( * subdirectory of each partition being read. If None, then all files are accepted. */ def makeRDDForPartitionedTable( - partitionToDeserializer: Map[HivePartition, - Class[_ <: Deserializer]], + partitionToDeserializer: Map[HivePartition, Class[_ <: Deserializer]], filterOpt: Option[PathFilter]): RDD[InternalRow] = { // SPARK-5068:get FileStatus and do the filtering locally when the path is not exists diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClient.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClient.scala index 984d23bb09dbd..569a9c11398ea 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClient.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClient.scala @@ -184,12 +184,12 @@ private[hive] trait HiveClient { * If no partition spec is specified, all partitions are returned. */ def getPartitions( - table: CatalogTable, + catalogTable: CatalogTable, partialSpec: Option[TablePartitionSpec] = None): Seq[CatalogTablePartition] /** Returns partitions filtered by predicates for the given table. 
*/ def getPartitionsByFilter( - table: CatalogTable, + catalogTable: CatalogTable, predicates: Seq[Expression]): Seq[CatalogTablePartition] /** Loads a static partition into an existing table. */ diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index dd33d750a4d45..84873bbbb81ce 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -37,6 +37,7 @@ import org.apache.hadoop.security.UserGroupInformation import org.apache.spark.{SparkConf, SparkException} import org.apache.spark.internal.Logging +import org.apache.spark.metrics.source.HiveCatalogMetrics import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.{NoSuchDatabaseException, NoSuchPartitionException} @@ -528,17 +529,21 @@ private[hive] class HiveClientImpl( table: CatalogTable, spec: Option[TablePartitionSpec]): Seq[CatalogTablePartition] = withHiveState { val hiveTable = toHiveTable(table) - spec match { + val parts = spec match { case None => shim.getAllPartitions(client, hiveTable).map(fromHivePartition) case Some(s) => client.getPartitions(hiveTable, s.asJava).asScala.map(fromHivePartition) } + HiveCatalogMetrics.incrementFetchedPartitions(parts.length) + parts } override def getPartitionsByFilter( table: CatalogTable, predicates: Seq[Expression]): Seq[CatalogTablePartition] = withHiveState { val hiveTable = toHiveTable(table) - shim.getPartitionsByFilter(client, hiveTable, predicates).map(fromHivePartition) + val parts = shim.getPartitionsByFilter(client, hiveTable, predicates).map(fromHivePartition) + HiveCatalogMetrics.incrementFetchedPartitions(parts.length) + parts } override def listTables(dbName: String): Seq[String] = withHiveState { @@ -772,7 +777,7 @@ private[hive] class HiveClientImpl( val (partCols, schema) = table.schema.map(toHiveColumn).partition { c => table.partitionColumnNames.contains(c.getName) } - if (table.schema.isEmpty) { + if (schema.isEmpty) { // This is a hack to preserve existing behavior. 
Before Spark 2.0, we do not // set a default serde here (this was done in Hive), and so if the user provides // an empty schema Hive would automatically populate the schema with a single diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala index 32387707612f4..4bbbd66132b75 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala @@ -24,6 +24,7 @@ import java.util.{ArrayList => JArrayList, List => JList, Map => JMap, Set => JS import java.util.concurrent.TimeUnit import scala.collection.JavaConverters._ +import scala.util.Try import scala.util.control.NonFatal import org.apache.hadoop.fs.{FileSystem, Path} @@ -585,7 +586,19 @@ private[client] class Shim_v0_13 extends Shim_v0_12 { getAllPartitionsMethod.invoke(hive, table).asInstanceOf[JSet[Partition]] } else { logDebug(s"Hive metastore filter is '$filter'.") - getPartitionsByFilterMethod.invoke(hive, table, filter).asInstanceOf[JArrayList[Partition]] + try { + getPartitionsByFilterMethod.invoke(hive, table, filter) + .asInstanceOf[JArrayList[Partition]] + } catch { + case e: InvocationTargetException => + // SPARK-18167 retry to investigate the flaky test. This should be reverted before + // the release is cut. + val retry = Try(getPartitionsByFilterMethod.invoke(hive, table, filter)) + val full = Try(getAllPartitionsMethod.invoke(hive, table)) + logError("getPartitionsByFilter failed, retry success = " + retry.isSuccess) + logError("getPartitionsByFilter failed, full fetch success = " + full.isSuccess) + throw e + } } partitions.asScala.toSeq diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index 53bb3b93db738..c3c4e2925b90c 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -20,7 +20,6 @@ package org.apache.spark.sql.hive.execution import java.io.IOException import java.net.URI import java.text.SimpleDateFormat -import java.util import java.util.{Date, Random} import scala.collection.JavaConverters._ @@ -36,6 +35,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} import org.apache.spark.sql.hive._ import org.apache.spark.sql.hive.HiveShim.{ShimFileSinkDesc => FileSinkDesc} @@ -291,6 +291,8 @@ case class InsertIntoHiveTable( Seq.empty[InternalRow] } + override def outputPartitioning: Partitioning = child.outputPartitioning + override def executeCollect(): Array[InternalRow] = sideEffectResult.toArray protected override def doExecute(): RDD[InternalRow] = { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala index 1025b8f70d9ff..50855e48bc8fe 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala @@ -38,6 +38,7 @@ import 
org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.ScriptInputOutputSchema +import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.execution._ import org.apache.spark.sql.hive.HiveInspectors import org.apache.spark.sql.hive.HiveShim._ @@ -61,6 +62,8 @@ case class ScriptTransformation( override def producedAttributes: AttributeSet = outputSet -- inputSet + override def outputPartitioning: Partitioning = child.outputPartitioning + protected override def doExecute(): RDD[InternalRow] = { def processIterator(inputIterator: Iterator[InternalRow], hadoopConf: Configuration) : Iterator[InternalRow] = { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala index e94f49ea81177..eba7aa386ade2 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala @@ -83,11 +83,11 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable new OutputWriterFactory { override def newInstance( - path: String, - bucketId: Option[Int], + stagingDir: String, + fileNamePrefix: String, dataSchema: StructType, context: TaskAttemptContext): OutputWriter = { - new OrcOutputWriter(path, bucketId, dataSchema, context) + new OrcOutputWriter(stagingDir, fileNamePrefix, dataSchema, context) } } } @@ -210,15 +210,24 @@ private[orc] class OrcSerializer(dataSchema: StructType, conf: Configuration) } private[orc] class OrcOutputWriter( - path: String, - bucketId: Option[Int], + stagingDir: String, + fileNamePrefix: String, dataSchema: StructType, context: TaskAttemptContext) extends OutputWriter { - private[this] val conf = context.getConfiguration + override val path: String = { + val compressionExtension: String = { + val name = context.getConfiguration.get(OrcRelation.ORC_COMPRESSION) + OrcRelation.extensionsForCompressionCodecNames.getOrElse(name, "") + } + // It has the `.orc` extension at the end because (de)compression tools + // such as gunzip would not be able to decompress this as the compression + // is not applied on this whole file but on each "stream" in ORC format. + new Path(stagingDir, fileNamePrefix + compressionExtension + ".orc").toString + } - private[this] val serializer = new OrcSerializer(dataSchema, conf) + private[this] val serializer = new OrcSerializer(dataSchema, context.getConfiguration) // `OrcRecordWriter.close()` creates an empty file if no rows are written at all. We use this // flag to decide whether `OrcRecordWriter.close()` needs to be called. 
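
For reference, the path that the new OrcOutputWriter exposes is assembled only from the staging directory, the supplied file-name prefix and the codec extension. A short sketch under hypothetical inputs (the extension string is normally looked up from OrcRelation.extensionsForCompressionCodecNames, which this patch leaves unchanged):

import org.apache.hadoop.fs.Path

object OrcOutputPathSketch {
  def main(args: Array[String]): Unit = {
    val stagingDir = "/tmp/orc-staging"              // hypothetical staging directory
    val fileNamePrefix = "part-r-00000-0a1b2c3d"     // hypothetical prefix chosen by the write path
    val compressionExtension = ".snappy"             // assumed codec extension; "" when none applies
    // ".orc" stays at the end because compression is applied per ORC stream, not to the
    // whole file, so a generic tool such as gunzip could not decompress it anyway.
    val path = new Path(stagingDir, fileNamePrefix + compressionExtension + ".orc").toString
    println(path)                                    // /tmp/orc-staging/part-r-00000-0a1b2c3d.snappy.orc
  }
}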
@@ -226,23 +235,10 @@ private[orc] class OrcOutputWriter( private lazy val recordWriter: RecordWriter[NullWritable, Writable] = { recordWriterInstantiated = true - val uniqueWriteJobId = conf.get(WriterContainer.DATASOURCE_WRITEJOBUUID) - val taskAttemptId = context.getTaskAttemptID - val partition = taskAttemptId.getTaskID.getId - val bucketString = bucketId.map(BucketingUtils.bucketIdToString).getOrElse("") - val compressionExtension = { - val name = conf.get(OrcRelation.ORC_COMPRESSION) - OrcRelation.extensionsForCompressionCodecNames.getOrElse(name, "") - } - // It has the `.orc` extension at the end because (de)compression tools - // such as gunzip would not be able to decompress this as the compression - // is not applied on this whole file but on each "stream" in ORC format. - val filename = f"part-r-$partition%05d-$uniqueWriteJobId$bucketString$compressionExtension.orc" - new OrcOutputFormat().getRecordWriter( - new Path(path, filename).getFileSystem(conf), - conf.asInstanceOf[JobConf], - new Path(path, filename).toString, + new Path(path).getFileSystem(context.getConfiguration), + context.getConfiguration.asInstanceOf[JobConf], + path, Reporter.NULL ).asInstanceOf[RecordWriter[NullWritable, Writable]] } @@ -313,7 +309,17 @@ private[orc] object OrcRelation extends HiveInspectors { def setRequiredColumns( conf: Configuration, physicalSchema: StructType, requestedSchema: StructType): Unit = { - val ids = requestedSchema.map(a => physicalSchema.fieldIndex(a.name): Integer) + val caseInsensitiveFieldMap: Map[String, Int] = physicalSchema.fieldNames + .zipWithIndex + .map(f => (f._1.toLowerCase, f._2)) + .toMap + val ids = requestedSchema.map { a => + val exactMatch: Option[Int] = physicalSchema.getFieldIndex(a.name) + val res = exactMatch.getOrElse( + caseInsensitiveFieldMap.getOrElse(a.name, + throw new IllegalArgumentException(s"""Field "$a.name" does not exist."""))) + res: Integer + } val (sortedIDs, sortedNames) = ids.zip(requestedSchema.fieldNames).sorted.unzip HiveShim.appendReadColumns(conf, sortedIDs, sortedNames) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala index 7d4ef6f26a600..fc35304c80ecc 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala @@ -19,12 +19,15 @@ package org.apache.spark.sql.hive import java.io.File -import org.apache.spark.sql.{AnalysisException, QueryTest, SaveMode} +import org.apache.spark.sql.{AnalysisException, Dataset, QueryTest, SaveMode} import org.apache.spark.sql.catalyst.analysis.NoSuchTableException import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec +import org.apache.spark.sql.execution.datasources.{CatalogFileIndex, HadoopFsRelation, LogicalRelation} +import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.types.StructType import org.apache.spark.storage.RDDBlockId import org.apache.spark.util.Utils @@ -317,4 +320,40 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with TestHiveSingleto sql("DROP TABLE cachedTable") } + + test("cache a table using CatalogFileIndex") { + withTable("test") { + sql("CREATE TABLE test(i int) PARTITIONED BY (p int) STORED AS parquet") + val tableMeta = 
spark.sharedState.externalCatalog.getTable("default", "test") + val catalogFileIndex = new CatalogFileIndex(spark, tableMeta, 0) + + val dataSchema = StructType(tableMeta.schema.filterNot { f => + tableMeta.partitionColumnNames.contains(f.name) + }) + val relation = HadoopFsRelation( + location = catalogFileIndex, + partitionSchema = tableMeta.partitionSchema, + dataSchema = dataSchema, + bucketSpec = None, + fileFormat = new ParquetFileFormat(), + options = Map.empty)(sparkSession = spark) + + val plan = LogicalRelation(relation, catalogTable = Some(tableMeta)) + spark.sharedState.cacheManager.cacheQuery(Dataset.ofRows(spark, plan)) + + assert(spark.sharedState.cacheManager.lookupCachedData(plan).isDefined) + + val sameCatalog = new CatalogFileIndex(spark, tableMeta, 0) + val sameRelation = HadoopFsRelation( + location = sameCatalog, + partitionSchema = tableMeta.partitionSchema, + dataSchema = dataSchema, + bucketSpec = None, + fileFormat = new ParquetFileFormat(), + options = Map.empty)(sparkSession = spark) + val samePlan = LogicalRelation(sameRelation, catalogTable = Some(tableMeta)) + + assert(spark.sharedState.cacheManager.lookupCachedData(samePlan).isDefined) + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala index 9ce3338647398..d13e29b3029b1 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala @@ -30,10 +30,12 @@ import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{Generate, ScriptTransformation} import org.apache.spark.sql.execution.command._ import org.apache.spark.sql.execution.datasources.CreateTable -import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.hive.test.{TestHive, TestHiveSingleton} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types.StructType -class HiveDDLCommandSuite extends PlanTest { +class HiveDDLCommandSuite extends PlanTest with SQLTestUtils with TestHiveSingleton { val parser = TestHive.sessionState.sqlParser private def extractTableDesc(sql: String): (CatalogTable, Boolean) = { @@ -556,4 +558,38 @@ class HiveDDLCommandSuite extends PlanTest { assert(partition2.get.apply("c") == "1" && partition2.get.apply("d") == "2") } + test("Test the default fileformat for Hive-serde tables") { + withSQLConf("hive.default.fileformat" -> "orc") { + val (desc, exists) = extractTableDesc("CREATE TABLE IF NOT EXISTS fileformat_test (id int)") + assert(exists) + assert(desc.storage.inputFormat == Some("org.apache.hadoop.hive.ql.io.orc.OrcInputFormat")) + assert(desc.storage.outputFormat == Some("org.apache.hadoop.hive.ql.io.orc.OrcOutputFormat")) + assert(desc.storage.serde == Some("org.apache.hadoop.hive.ql.io.orc.OrcSerde")) + } + + withSQLConf("hive.default.fileformat" -> "parquet") { + val (desc, exists) = extractTableDesc("CREATE TABLE IF NOT EXISTS fileformat_test (id int)") + assert(exists) + val input = desc.storage.inputFormat + val output = desc.storage.outputFormat + val serde = desc.storage.serde + assert(input == Some("org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat")) + assert(output == Some("org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat")) + assert(serde == Some("org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe")) + } + } + + 
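
A usage-level sketch of what the default-fileformat test above verifies, written as a hypothetical spark-shell session (the table names are invented and the exact DESCRIBE output layout can differ between versions):

// With the ORC default, a plain CREATE TABLE picks the ORC formats and serde.
spark.sql("SET hive.default.fileformat=orc")
spark.sql("CREATE TABLE fileformat_demo_orc (id INT)")
spark.sql("DESCRIBE FORMATTED fileformat_demo_orc").show(100, truncate = false)
// Expect OrcInputFormat / OrcOutputFormat / OrcSerde in the storage section.

// Changing the conf affects subsequent CREATE TABLE statements.
spark.sql("SET hive.default.fileformat=parquet")
spark.sql("CREATE TABLE fileformat_demo_parquet (id INT)")
spark.sql("DESCRIBE FORMATTED fileformat_demo_parquet").show(100, truncate = false)
// Expect MapredParquetInputFormat / MapredParquetOutputFormat / ParquetHiveSerDe.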
test("table name with schema") { + // regression test for SPARK-11778 + spark.sql("create schema usrdb") + spark.sql("create table usrdb.test(c int)") + spark.read.table("usrdb.test") + spark.sql("drop table usrdb.test") + spark.sql("drop schema usrdb") + } + + test("SPARK-15887: hive-site.xml should be loaded") { + val hiveClient = spark.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client + assert(hiveClient.getConf("hive.in.test", "") == "true") + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala deleted file mode 100644 index 6477974fe713a..0000000000000 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala +++ /dev/null @@ -1,72 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.hive - -import org.scalatest.BeforeAndAfterAll - -import org.apache.spark.sql.{DataFrame, QueryTest, Row} -import org.apache.spark.sql.functions._ -import org.apache.spark.sql.hive.test.TestHiveSingleton - -// TODO ideally we should put the test suite into the package `sql`, as -// `hive` package is optional in compiling, however, `SQLContext.sql` doesn't -// support the `cube` or `rollup` yet. 
-class HiveDataFrameAnalyticsSuite extends QueryTest with TestHiveSingleton with BeforeAndAfterAll { - import spark.implicits._ - import spark.sql - - private var testData: DataFrame = _ - - override def beforeAll() { - super.beforeAll() - testData = Seq((1, 2), (2, 2), (3, 4)).toDF("a", "b") - testData.createOrReplaceTempView("mytable") - } - - override def afterAll(): Unit = { - try { - spark.catalog.dropTempView("mytable") - } finally { - super.afterAll() - } - } - - test("rollup") { - checkAnswer( - testData.rollup($"a" + $"b", $"b").agg(sum($"a" - $"b")), - sql("select a + b, b, sum(a - b) from mytable group by a + b, b with rollup").collect() - ) - - checkAnswer( - testData.rollup("a", "b").agg(sum("b")), - sql("select a, b, sum(b) from mytable group by a, b with rollup").collect() - ) - } - - test("cube") { - checkAnswer( - testData.cube($"a" + $"b", $"b").agg(sum($"a" - $"b")), - sql("select a + b, b, sum(a - b) from mytable group by a + b, b with cube").collect() - ) - - checkAnswer( - testData.cube("a", "b").agg(sum("b")), - sql("select a, b, sum(b) from mytable group by a, b with cube").collect() - ) - } -} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala index 26c2549820de6..efa0beb85030b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala @@ -21,6 +21,7 @@ import org.apache.hadoop.conf.Configuration import org.apache.spark.SparkConf import org.apache.spark.sql.catalyst.catalog._ +import org.apache.spark.sql.catalyst.dsl.expressions._ /** * Test suite for the [[HiveExternalCatalog]]. @@ -43,4 +44,12 @@ class HiveExternalCatalogSuite extends ExternalCatalogSuite { externalCatalog.client.reset() } + import utils._ + + test("list partitions by filter") { + val catalog = newBasicCatalog() + val selectedPartitions = catalog.listPartitionsByFilter("db2", "tbl2", Seq('a.int === 1)) + assert(selectedPartitions.length == 1) + assert(selectedPartitions.head.spec == part1.spec) + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetadataCacheSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetadataCacheSuite.scala index 3414f5e0409a1..6e887d95c0f09 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetadataCacheSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetadataCacheSuite.scala @@ -22,6 +22,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.SparkException import org.apache.spark.sql.QueryTest import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SQLTestUtils /** @@ -59,4 +60,62 @@ class HiveMetadataCacheSuite extends QueryTest with SQLTestUtils with TestHiveSi } } } + + def testCaching(pruningEnabled: Boolean): Unit = { + test(s"partitioned table is cached when partition pruning is $pruningEnabled") { + withSQLConf(SQLConf.HIVE_MANAGE_FILESOURCE_PARTITIONS.key -> pruningEnabled.toString) { + withTable("test") { + withTempDir { dir => + spark.range(5).selectExpr("id", "id as f1", "id as f2").write + .partitionBy("f1", "f2") + .mode("overwrite") + .parquet(dir.getAbsolutePath) + + spark.sql(s""" + |create external table test (id long) + |partitioned by (f1 int, f2 int) + |stored as parquet + |location "${dir.getAbsolutePath}"""".stripMargin) + spark.sql("msck 
repair table test") + + val df = spark.sql("select * from test") + assert(sql("select * from test").count() == 5) + + def deleteRandomFile(): Unit = { + val p = new Path(spark.table("test").inputFiles.head) + assert(p.getFileSystem(hiveContext.sessionState.newHadoopConf()).delete(p, true)) + } + + // Delete a file, then assert that we tried to read it. This means the table was cached. + deleteRandomFile() + val e = intercept[SparkException] { + sql("select * from test").count() + } + assert(e.getMessage.contains("FileNotFoundException")) + + // Test refreshing the cache. + spark.catalog.refreshTable("test") + assert(sql("select * from test").count() == 4) + assert(spark.table("test").inputFiles.length == 4) + + // Test refresh by path separately since it goes through different code paths than + // refreshTable does. + deleteRandomFile() + spark.catalog.cacheTable("test") + spark.catalog.refreshByPath("/some-invalid-path") // no-op + val e2 = intercept[SparkException] { + sql("select * from test").count() + } + assert(e2.getMessage.contains("FileNotFoundException")) + spark.catalog.refreshByPath(dir.getAbsolutePath) + assert(sql("select * from test").count() == 3) + } + } + } + } + } + + for (pruningEnabled <- Seq(true, false)) { + testCaching(pruningEnabled) + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala index d9ce1c3dc18ff..e3ddaf725424d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala @@ -370,17 +370,6 @@ class InsertIntoHiveTableSuite extends QueryTest with TestHiveSingleton with Bef assert(cause.getMessage.contains("insertInto() can't be used together with partitionBy().")) } - test("InsertIntoTable#resolved should include dynamic partitions") { - withSQLConf(("hive.exec.dynamic.partition.mode", "nonstrict")) { - sql("CREATE TABLE partitioned (id bigint, data string) PARTITIONED BY (part string)") - val data = (1 to 10).map(i => (i.toLong, s"data-$i")).toDF("id", "data") - - val logical = InsertIntoTable(spark.table("partitioned").logicalPlan, - Map("part" -> None), data.logicalPlan, overwrite = false, ifNotExists = false) - assert(!logical.resolved, "Should not resolve: missing partition data") - } - } - testPartitionedTable( "SPARK-16036: better error message when insert into a table with mismatch schema") { tableName => @@ -409,8 +398,8 @@ class InsertIntoHiveTableSuite extends QueryTest with TestHiveSingleton with Bef sql(s"INSERT INTO TABLE $tableName PARTITION (c=11, b=10) SELECT 9, 12") - // c is defined twice. Parser will complain. - intercept[ParseException] { + // c is defined twice. Analyzer will complain. 
+ intercept[AnalysisException] { sql(s"INSERT INTO TABLE $tableName PARTITION (b=14, c=15, c=16) SELECT 13") } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index 7cc6179d44977..eaa67d370db37 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -1321,20 +1321,32 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv sharedState.externalCatalog.getTable("default", "t") }.getMessage assert(e.contains(s"Could not read schema from the hive metastore because it is corrupted")) + + withDebugMode { + val tableMeta = sharedState.externalCatalog.getTable("default", "t") + assert(tableMeta.identifier == TableIdentifier("t", Some("default"))) + assert(tableMeta.properties(DATASOURCE_PROVIDER) == "json") + } } finally { hiveClient.dropTable("default", "t", ignoreIfNotExists = true, purge = true) } } test("should keep data source entries in table properties when debug mode is on") { - val previousValue = sparkSession.sparkContext.conf.get(DEBUG_MODE) - try { - sparkSession.sparkContext.conf.set(DEBUG_MODE, true) + withDebugMode { val newSession = sparkSession.newSession() newSession.sql("CREATE TABLE abc(i int) USING json") val tableMeta = newSession.sessionState.catalog.getTableMetadata(TableIdentifier("abc")) assert(tableMeta.properties(DATASOURCE_SCHEMA_NUMPARTS).toInt == 1) assert(tableMeta.properties(DATASOURCE_PROVIDER) == "json") + } + } + + private def withDebugMode(f: => Unit): Unit = { + val previousValue = sparkSession.sparkContext.conf.get(DEBUG_MODE) + try { + sparkSession.sparkContext.conf.set(DEBUG_MODE, true) + f } finally { sparkSession.sparkContext.conf.set(DEBUG_MODE, previousValue) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreRelationSuite.scala index 2f3055dcac4c5..91ff711445e82 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreRelationSuite.scala @@ -17,23 +17,39 @@ package org.apache.spark.sql.hive -import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.{QueryTest, Row} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable, CatalogTableType} +import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types.{IntegerType, StructField, StructType} -class MetastoreRelationSuite extends SparkFunSuite { +class MetastoreRelationSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { test("makeCopy and toJSON should work") { val table = CatalogTable( identifier = TableIdentifier("test", Some("db")), tableType = CatalogTableType.VIEW, storage = CatalogStorageFormat.empty, schema = StructType(StructField("a", IntegerType, true) :: Nil)) - val relation = MetastoreRelation("db", "test")(table, null, null) + val relation = MetastoreRelation("db", "test")(table, null) // No exception should be thrown relation.makeCopy(Array("db", "test")) // No exception should be thrown relation.toJSON } + + test("SPARK-17409: Do Not Optimize Query in CTAS (Hive Serde Table) More Than Once") { + withTable("bar") { + withTempView("foo") 
{ + sql("select 0 as id").createOrReplaceTempView("foo") + // If we optimize the query in CTAS more than once, the following saveAsTable will fail + // with the error: `GROUP BY position 0 is not in select list (valid range is [1, 1])` + sql("CREATE TABLE bar AS SELECT * FROM foo group by id") + checkAnswer(spark.table("bar"), Row(0) :: Nil) + val tableMetadata = spark.sessionState.catalog.getTableMetadata(TableIdentifier("bar")) + assert(tableMetadata.provider == Some("hive"), "the expected table is a Hive serde table") + } + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionProviderCompatibilitySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionProviderCompatibilitySuite.scala new file mode 100644 index 0000000000000..5f16960fb1496 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionProviderCompatibilitySuite.scala @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import java.io.File + +import org.apache.spark.metrics.source.HiveCatalogMetrics +import org.apache.spark.sql.{AnalysisException, QueryTest} +import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SQLTestUtils + +class PartitionProviderCompatibilitySuite + extends QueryTest with TestHiveSingleton with SQLTestUtils { + + private def setupPartitionedDatasourceTable(tableName: String, dir: File): Unit = { + spark.range(5).selectExpr("id as fieldOne", "id as partCol").write + .partitionBy("partCol") + .mode("overwrite") + .parquet(dir.getAbsolutePath) + + spark.sql(s""" + |create table $tableName (fieldOne long, partCol int) + |using parquet + |options (path "${dir.getAbsolutePath}") + |partitioned by (partCol)""".stripMargin) + } + + private def verifyIsLegacyTable(tableName: String): Unit = { + val unsupportedCommands = Seq( + s"ALTER TABLE $tableName ADD PARTITION (partCol=1) LOCATION '/foo'", + s"ALTER TABLE $tableName PARTITION (partCol=1) RENAME TO PARTITION (partCol=2)", + s"ALTER TABLE $tableName PARTITION (partCol=1) SET LOCATION '/foo'", + s"ALTER TABLE $tableName DROP PARTITION (partCol=1)", + s"DESCRIBE $tableName PARTITION (partCol=1)", + s"SHOW PARTITIONS $tableName") + + withSQLConf(SQLConf.HIVE_MANAGE_FILESOURCE_PARTITIONS.key -> "true") { + for (cmd <- unsupportedCommands) { + val e = intercept[AnalysisException] { + spark.sql(cmd) + } + assert(e.getMessage.contains("partition metadata is not stored in the Hive metastore"), e) + } + } + } + + test("convert partition provider to hive with repair table") { + withTable("test") { + withTempDir { dir => + withSQLConf(SQLConf.HIVE_MANAGE_FILESOURCE_PARTITIONS.key -> "false") { + 
setupPartitionedDatasourceTable("test", dir) + assert(spark.sql("select * from test").count() == 5) + } + withSQLConf(SQLConf.HIVE_MANAGE_FILESOURCE_PARTITIONS.key -> "true") { + verifyIsLegacyTable("test") + spark.sql("msck repair table test") + spark.sql("show partitions test").count() // check we are a new table + + // sanity check table performance + HiveCatalogMetrics.reset() + assert(spark.sql("select * from test where partCol < 2").count() == 2) + assert(HiveCatalogMetrics.METRIC_PARTITIONS_FETCHED.getCount() == 2) + assert(HiveCatalogMetrics.METRIC_FILES_DISCOVERED.getCount() == 2) + } + } + } + } + + test("when partition management is enabled, new tables have partition provider hive") { + withTable("test") { + withTempDir { dir => + withSQLConf(SQLConf.HIVE_MANAGE_FILESOURCE_PARTITIONS.key -> "true") { + setupPartitionedDatasourceTable("test", dir) + spark.sql("show partitions test").count() // check we are a new table + assert(spark.sql("select * from test").count() == 0) // needs repair + spark.sql("msck repair table test") + assert(spark.sql("select * from test").count() == 5) + } + } + } + } + + test("when partition management is disabled, new tables have no partition provider") { + withTable("test") { + withTempDir { dir => + withSQLConf(SQLConf.HIVE_MANAGE_FILESOURCE_PARTITIONS.key -> "false") { + setupPartitionedDatasourceTable("test", dir) + verifyIsLegacyTable("test") + assert(spark.sql("select * from test").count() == 5) + } + } + } + } + + test("when partition management is disabled, we preserve the old behavior even for new tables") { + withTable("test") { + withTempDir { dir => + withSQLConf(SQLConf.HIVE_MANAGE_FILESOURCE_PARTITIONS.key -> "true") { + setupPartitionedDatasourceTable("test", dir) + spark.sql("show partitions test").count() // check we are a new table + spark.sql("refresh table test") + assert(spark.sql("select * from test").count() == 0) + } + // disabled + withSQLConf(SQLConf.HIVE_MANAGE_FILESOURCE_PARTITIONS.key -> "false") { + val e = intercept[AnalysisException] { + spark.sql(s"show partitions test") + } + assert(e.getMessage.contains("filesource partition management is disabled")) + spark.sql("refresh table test") + assert(spark.sql("select * from test").count() == 5) + } + // then enabled again + withSQLConf(SQLConf.HIVE_MANAGE_FILESOURCE_PARTITIONS.key -> "true") { + spark.sql("refresh table test") + assert(spark.sql("select * from test").count() == 0) + } + } + } + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionedTablePerfStatsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionedTablePerfStatsSuite.scala new file mode 100644 index 0000000000000..d8e31c4e39a5c --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionedTablePerfStatsSuite.scala @@ -0,0 +1,304 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import java.io.File + +import org.scalatest.BeforeAndAfterEach + +import org.apache.spark.metrics.source.HiveCatalogMetrics +import org.apache.spark.sql.execution.datasources.FileStatusCache +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SQLTestUtils + +class PartitionedTablePerfStatsSuite + extends QueryTest with TestHiveSingleton with SQLTestUtils with BeforeAndAfterEach { + + override def beforeEach(): Unit = { + super.beforeEach() + FileStatusCache.resetForTesting() + } + + override def afterEach(): Unit = { + super.afterEach() + FileStatusCache.resetForTesting() + } + + private case class TestSpec(setupTable: (String, File) => Unit, isDatasourceTable: Boolean) + + /** + * Runs a test against both converted hive and native datasource tables. The test can use the + * passed TestSpec object for setup and inspecting test parameters. + */ + private def genericTest(testName: String)(fn: TestSpec => Unit): Unit = { + test("hive table: " + testName) { + fn(TestSpec(setupPartitionedHiveTable, false)) + } + test("datasource table: " + testName) { + fn(TestSpec(setupPartitionedDatasourceTable, true)) + } + } + + private def setupPartitionedHiveTable(tableName: String, dir: File): Unit = { + spark.range(5).selectExpr("id as fieldOne", "id as partCol1", "id as partCol2").write + .partitionBy("partCol1", "partCol2") + .mode("overwrite") + .parquet(dir.getAbsolutePath) + + spark.sql(s""" + |create external table $tableName (fieldOne long) + |partitioned by (partCol1 int, partCol2 int) + |stored as parquet + |location "${dir.getAbsolutePath}"""".stripMargin) + spark.sql(s"msck repair table $tableName") + } + + private def setupPartitionedDatasourceTable(tableName: String, dir: File): Unit = { + spark.range(5).selectExpr("id as fieldOne", "id as partCol1", "id as partCol2").write + .partitionBy("partCol1", "partCol2") + .mode("overwrite") + .parquet(dir.getAbsolutePath) + + spark.sql(s""" + |create table $tableName (fieldOne long, partCol1 int, partCol2 int) + |using parquet + |options (path "${dir.getAbsolutePath}") + |partitioned by (partCol1, partCol2)""".stripMargin) + spark.sql(s"msck repair table $tableName") + } + + genericTest("partitioned pruned table reports only selected files") { spec => + assert(spark.sqlContext.getConf(HiveUtils.CONVERT_METASTORE_PARQUET.key) == "true") + withTable("test") { + withTempDir { dir => + spec.setupTable("test", dir) + val df = spark.sql("select * from test") + assert(df.count() == 5) + assert(df.inputFiles.length == 5) // unpruned + + val df2 = spark.sql("select * from test where partCol1 = 3 or partCol2 = 4") + assert(df2.count() == 2) + assert(df2.inputFiles.length == 2) // pruned, so we have less files + + val df3 = spark.sql("select * from test where PARTCOL1 = 3 or partcol2 = 4") + assert(df3.count() == 2) + assert(df3.inputFiles.length == 2) + + val df4 = spark.sql("select * from test where partCol1 = 999") + assert(df4.count() == 0) + assert(df4.inputFiles.length == 0) + + // TODO(ekl) enable for hive tables as well once SPARK-17983 is fixed + if (spec.isDatasourceTable) { + val df5 = spark.sql("select * from test where fieldOne = 4") + assert(df5.count() == 1) + assert(df5.inputFiles.length == 5) + } + } + } + } + + genericTest("lazy partition pruning reads only necessary 
partition data") { spec => + withSQLConf( + SQLConf.HIVE_MANAGE_FILESOURCE_PARTITIONS.key -> "true", + SQLConf.HIVE_FILESOURCE_PARTITION_FILE_CACHE_SIZE.key -> "0") { + withTable("test") { + withTempDir { dir => + spec.setupTable("test", dir) + HiveCatalogMetrics.reset() + spark.sql("select * from test where partCol1 = 999").count() + assert(HiveCatalogMetrics.METRIC_PARTITIONS_FETCHED.getCount() == 0) + assert(HiveCatalogMetrics.METRIC_FILES_DISCOVERED.getCount() == 0) + + HiveCatalogMetrics.reset() + spark.sql("select * from test where partCol1 < 2").count() + assert(HiveCatalogMetrics.METRIC_PARTITIONS_FETCHED.getCount() == 2) + assert(HiveCatalogMetrics.METRIC_FILES_DISCOVERED.getCount() == 2) + + HiveCatalogMetrics.reset() + spark.sql("select * from test where partCol1 < 3").count() + assert(HiveCatalogMetrics.METRIC_PARTITIONS_FETCHED.getCount() == 3) + assert(HiveCatalogMetrics.METRIC_FILES_DISCOVERED.getCount() == 3) + + // should read all + HiveCatalogMetrics.reset() + spark.sql("select * from test").count() + assert(HiveCatalogMetrics.METRIC_PARTITIONS_FETCHED.getCount() == 5) + assert(HiveCatalogMetrics.METRIC_FILES_DISCOVERED.getCount() == 5) + + // read all should not be cached + HiveCatalogMetrics.reset() + spark.sql("select * from test").count() + assert(HiveCatalogMetrics.METRIC_PARTITIONS_FETCHED.getCount() == 5) + assert(HiveCatalogMetrics.METRIC_FILES_DISCOVERED.getCount() == 5) + + // cache should be disabled + assert(HiveCatalogMetrics.METRIC_FILE_CACHE_HITS.getCount() == 0) + } + } + } + } + + genericTest("lazy partition pruning with file status caching enabled") { spec => + withSQLConf( + SQLConf.HIVE_MANAGE_FILESOURCE_PARTITIONS.key -> "true", + SQLConf.HIVE_FILESOURCE_PARTITION_FILE_CACHE_SIZE.key -> "9999999") { + withTable("test") { + withTempDir { dir => + spec.setupTable("test", dir) + HiveCatalogMetrics.reset() + assert(spark.sql("select * from test where partCol1 = 999").count() == 0) + assert(HiveCatalogMetrics.METRIC_PARTITIONS_FETCHED.getCount() == 0) + assert(HiveCatalogMetrics.METRIC_FILES_DISCOVERED.getCount() == 0) + assert(HiveCatalogMetrics.METRIC_FILE_CACHE_HITS.getCount() == 0) + + HiveCatalogMetrics.reset() + assert(spark.sql("select * from test where partCol1 < 2").count() == 2) + assert(HiveCatalogMetrics.METRIC_PARTITIONS_FETCHED.getCount() == 2) + assert(HiveCatalogMetrics.METRIC_FILES_DISCOVERED.getCount() == 2) + assert(HiveCatalogMetrics.METRIC_FILE_CACHE_HITS.getCount() == 0) + + HiveCatalogMetrics.reset() + assert(spark.sql("select * from test where partCol1 < 3").count() == 3) + assert(HiveCatalogMetrics.METRIC_PARTITIONS_FETCHED.getCount() == 3) + assert(HiveCatalogMetrics.METRIC_FILES_DISCOVERED.getCount() == 1) + assert(HiveCatalogMetrics.METRIC_FILE_CACHE_HITS.getCount() == 2) + + HiveCatalogMetrics.reset() + assert(spark.sql("select * from test").count() == 5) + assert(HiveCatalogMetrics.METRIC_PARTITIONS_FETCHED.getCount() == 5) + assert(HiveCatalogMetrics.METRIC_FILES_DISCOVERED.getCount() == 2) + assert(HiveCatalogMetrics.METRIC_FILE_CACHE_HITS.getCount() == 3) + + HiveCatalogMetrics.reset() + assert(spark.sql("select * from test").count() == 5) + assert(HiveCatalogMetrics.METRIC_PARTITIONS_FETCHED.getCount() == 5) + assert(HiveCatalogMetrics.METRIC_FILES_DISCOVERED.getCount() == 0) + assert(HiveCatalogMetrics.METRIC_FILE_CACHE_HITS.getCount() == 5) + } + } + } + } + + genericTest("file status caching respects refresh table and refreshByPath") { spec => + withSQLConf( + SQLConf.HIVE_MANAGE_FILESOURCE_PARTITIONS.key -> "true", + 
SQLConf.HIVE_FILESOURCE_PARTITION_FILE_CACHE_SIZE.key -> "9999999") { + withTable("test") { + withTempDir { dir => + spec.setupTable("test", dir) + HiveCatalogMetrics.reset() + assert(spark.sql("select * from test").count() == 5) + assert(HiveCatalogMetrics.METRIC_FILES_DISCOVERED.getCount() == 5) + assert(HiveCatalogMetrics.METRIC_FILE_CACHE_HITS.getCount() == 0) + + HiveCatalogMetrics.reset() + spark.sql("refresh table test") + assert(spark.sql("select * from test").count() == 5) + assert(HiveCatalogMetrics.METRIC_FILES_DISCOVERED.getCount() == 5) + assert(HiveCatalogMetrics.METRIC_FILE_CACHE_HITS.getCount() == 0) + + spark.catalog.cacheTable("test") + HiveCatalogMetrics.reset() + spark.catalog.refreshByPath(dir.getAbsolutePath) + assert(spark.sql("select * from test").count() == 5) + assert(HiveCatalogMetrics.METRIC_FILES_DISCOVERED.getCount() == 5) + assert(HiveCatalogMetrics.METRIC_FILE_CACHE_HITS.getCount() == 0) + } + } + } + } + + genericTest("file status cache respects size limit") { spec => + withSQLConf( + SQLConf.HIVE_MANAGE_FILESOURCE_PARTITIONS.key -> "true", + SQLConf.HIVE_FILESOURCE_PARTITION_FILE_CACHE_SIZE.key -> "1" /* 1 byte */) { + withTable("test") { + withTempDir { dir => + spec.setupTable("test", dir) + HiveCatalogMetrics.reset() + assert(spark.sql("select * from test").count() == 5) + assert(HiveCatalogMetrics.METRIC_FILES_DISCOVERED.getCount() == 5) + assert(HiveCatalogMetrics.METRIC_FILE_CACHE_HITS.getCount() == 0) + assert(spark.sql("select * from test").count() == 5) + assert(HiveCatalogMetrics.METRIC_FILES_DISCOVERED.getCount() == 10) + assert(HiveCatalogMetrics.METRIC_FILE_CACHE_HITS.getCount() == 0) + } + } + } + } + + test("hive table: files read and cached when filesource partition management is off") { + withSQLConf(SQLConf.HIVE_MANAGE_FILESOURCE_PARTITIONS.key -> "false") { + withTable("test") { + withTempDir { dir => + setupPartitionedHiveTable("test", dir) + + // We actually query the partitions from hive each time the table is resolved in this + // mode. This is kind of terrible, but is needed to preserve the legacy behavior + // of doing plan cache validation based on the entire partition set. 
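        // Editor's note (illustrative, not part of the patch): the table set
        // up above has 5 partitions. With filesource partition management
        // disabled, each query fetches all 5 partitions when the relation is
        // resolved and the InMemoryFileIndex lists the same 5 again, which is
        // why the first assertion below expects 10 partition fetches while
        // only 5 files are discovered.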
+ HiveCatalogMetrics.reset() + assert(spark.sql("select * from test where partCol1 = 999").count() == 0) + // 5 from table resolution, another 5 from InMemoryFileIndex + assert(HiveCatalogMetrics.METRIC_PARTITIONS_FETCHED.getCount() == 10) + assert(HiveCatalogMetrics.METRIC_FILES_DISCOVERED.getCount() == 5) + + HiveCatalogMetrics.reset() + assert(spark.sql("select * from test where partCol1 < 2").count() == 2) + assert(HiveCatalogMetrics.METRIC_PARTITIONS_FETCHED.getCount() == 5) + assert(HiveCatalogMetrics.METRIC_FILES_DISCOVERED.getCount() == 0) + + HiveCatalogMetrics.reset() + assert(spark.sql("select * from test").count() == 5) + assert(HiveCatalogMetrics.METRIC_PARTITIONS_FETCHED.getCount() == 5) + assert(HiveCatalogMetrics.METRIC_FILES_DISCOVERED.getCount() == 0) + } + } + } + } + + test("datasource table: all partition data cached in memory when partition management is off") { + withSQLConf(SQLConf.HIVE_MANAGE_FILESOURCE_PARTITIONS.key -> "false") { + withTable("test") { + withTempDir { dir => + setupPartitionedDatasourceTable("test", dir) + HiveCatalogMetrics.reset() + assert(spark.sql("select * from test where partCol1 = 999").count() == 0) + + // not using metastore + assert(HiveCatalogMetrics.METRIC_PARTITIONS_FETCHED.getCount() == 0) + + // reads and caches all the files initially + assert(HiveCatalogMetrics.METRIC_FILES_DISCOVERED.getCount() == 5) + + HiveCatalogMetrics.reset() + assert(spark.sql("select * from test where partCol1 < 2").count() == 2) + assert(HiveCatalogMetrics.METRIC_PARTITIONS_FETCHED.getCount() == 0) + assert(HiveCatalogMetrics.METRIC_FILES_DISCOVERED.getCount() == 0) + + HiveCatalogMetrics.reset() + assert(spark.sql("select * from test").count() == 5) + assert(HiveCatalogMetrics.METRIC_PARTITIONS_FETCHED.getCount() == 0) + assert(HiveCatalogMetrics.METRIC_FILES_DISCOVERED.getCount() == 0) + } + } + } + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index 99dd080683d40..4f5ebc3d838b9 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -21,7 +21,7 @@ import java.io.{File, PrintWriter} import scala.reflect.ClassTag -import org.apache.spark.sql.{AnalysisException, QueryTest, Row, StatisticsTest} +import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier} import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Statistics} import org.apache.spark.sql.execution.command.{AnalyzeTableCommand, DDLUtils} @@ -310,39 +310,50 @@ class StatisticsSuite extends QueryTest with TestHiveSingleton with SQLTestUtils } } - test("test table-level statistics for data source table created in HiveExternalCatalog") { - val parquetTable = "parquetTable" - withTable(parquetTable) { - sql(s"CREATE TABLE $parquetTable (key STRING, value STRING) USING PARQUET") - val catalogTable = spark.sessionState.catalog.getTableMetadata(TableIdentifier(parquetTable)) - assert(DDLUtils.isDatasourceTable(catalogTable)) + private def testUpdatingTableStats(tableDescription: String, createTableCmd: String): Unit = { + test("test table-level statistics for " + tableDescription) { + val parquetTable = "parquetTable" + withTable(parquetTable) { + sql(createTableCmd) + val catalogTable = spark.sessionState.catalog.getTableMetadata( + TableIdentifier(parquetTable)) + assert(DDLUtils.isDatasourceTable(catalogTable)) - sql(s"INSERT INTO 
TABLE $parquetTable SELECT * FROM src") - checkTableStats( - parquetTable, isDataSourceTable = true, hasSizeInBytes = false, expectedRowCounts = None) + sql(s"INSERT INTO TABLE $parquetTable SELECT * FROM src") + checkTableStats( + parquetTable, isDataSourceTable = true, hasSizeInBytes = false, expectedRowCounts = None) - // noscan won't count the number of rows - sql(s"ANALYZE TABLE $parquetTable COMPUTE STATISTICS noscan") - val fetchedStats1 = checkTableStats( - parquetTable, isDataSourceTable = true, hasSizeInBytes = true, expectedRowCounts = None) + // noscan won't count the number of rows + sql(s"ANALYZE TABLE $parquetTable COMPUTE STATISTICS noscan") + val fetchedStats1 = checkTableStats( + parquetTable, isDataSourceTable = true, hasSizeInBytes = true, expectedRowCounts = None) - sql(s"INSERT INTO TABLE $parquetTable SELECT * FROM src") - sql(s"ANALYZE TABLE $parquetTable COMPUTE STATISTICS noscan") - val fetchedStats2 = checkTableStats( - parquetTable, isDataSourceTable = true, hasSizeInBytes = true, expectedRowCounts = None) - assert(fetchedStats2.get.sizeInBytes > fetchedStats1.get.sizeInBytes) + sql(s"INSERT INTO TABLE $parquetTable SELECT * FROM src") + sql(s"ANALYZE TABLE $parquetTable COMPUTE STATISTICS noscan") + val fetchedStats2 = checkTableStats( + parquetTable, isDataSourceTable = true, hasSizeInBytes = true, expectedRowCounts = None) + assert(fetchedStats2.get.sizeInBytes > fetchedStats1.get.sizeInBytes) - // without noscan, we count the number of rows - sql(s"ANALYZE TABLE $parquetTable COMPUTE STATISTICS") - val fetchedStats3 = checkTableStats( - parquetTable, - isDataSourceTable = true, - hasSizeInBytes = true, - expectedRowCounts = Some(1000)) - assert(fetchedStats3.get.sizeInBytes == fetchedStats2.get.sizeInBytes) + // without noscan, we count the number of rows + sql(s"ANALYZE TABLE $parquetTable COMPUTE STATISTICS") + val fetchedStats3 = checkTableStats( + parquetTable, + isDataSourceTable = true, + hasSizeInBytes = true, + expectedRowCounts = Some(1000)) + assert(fetchedStats3.get.sizeInBytes == fetchedStats2.get.sizeInBytes) + } } } + testUpdatingTableStats( + "data source table created in HiveExternalCatalog", + "CREATE TABLE parquetTable (key STRING, value STRING) USING PARQUET") + + testUpdatingTableStats( + "partitioned data source table", + "CREATE TABLE parquetTable (key STRING, value STRING) USING PARQUET PARTITIONED BY (key)") + test("statistics collection of a table with zero column") { val table_no_cols = "table_no_cols" withTable(table_no_cols) { @@ -358,53 +369,187 @@ class StatisticsSuite extends QueryTest with TestHiveSingleton with SQLTestUtils } } - test("generate column-level statistics and load them from hive metastore") { + private def getStatsBeforeAfterUpdate(isAnalyzeColumns: Boolean): (Statistics, Statistics) = { + val tableName = "tbl" + var statsBeforeUpdate: Statistics = null + var statsAfterUpdate: Statistics = null + withTable(tableName) { + val tableIndent = TableIdentifier(tableName, Some("default")) + val catalog = spark.sessionState.catalog.asInstanceOf[HiveSessionCatalog] + sql(s"CREATE TABLE $tableName (key int) USING PARQUET") + sql(s"INSERT INTO $tableName SELECT 1") + if (isAnalyzeColumns) { + sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS FOR COLUMNS key") + } else { + sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS") + } + // Table lookup will make the table cached. 
+ catalog.lookupRelation(tableIndent) + statsBeforeUpdate = catalog.getCachedDataSourceTable(tableIndent) + .asInstanceOf[LogicalRelation].catalogTable.get.stats.get + + sql(s"INSERT INTO $tableName SELECT 2") + if (isAnalyzeColumns) { + sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS FOR COLUMNS key") + } else { + sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS") + } + catalog.lookupRelation(tableIndent) + statsAfterUpdate = catalog.getCachedDataSourceTable(tableIndent) + .asInstanceOf[LogicalRelation].catalogTable.get.stats.get + } + (statsBeforeUpdate, statsAfterUpdate) + } + + test("test refreshing table stats of cached data source table by `ANALYZE TABLE` statement") { + val (statsBeforeUpdate, statsAfterUpdate) = getStatsBeforeAfterUpdate(isAnalyzeColumns = false) + + assert(statsBeforeUpdate.sizeInBytes > 0) + assert(statsBeforeUpdate.rowCount == Some(1)) + + assert(statsAfterUpdate.sizeInBytes > statsBeforeUpdate.sizeInBytes) + assert(statsAfterUpdate.rowCount == Some(2)) + } + + test("test refreshing column stats of cached data source table by `ANALYZE TABLE` statement") { + val (statsBeforeUpdate, statsAfterUpdate) = getStatsBeforeAfterUpdate(isAnalyzeColumns = true) + + assert(statsBeforeUpdate.sizeInBytes > 0) + assert(statsBeforeUpdate.rowCount == Some(1)) + StatisticsTest.checkColStat( + dataType = IntegerType, + colStat = statsBeforeUpdate.colStats("key"), + expectedColStat = ColumnStat(InternalRow(0L, 1, 1, 1L)), + rsd = spark.sessionState.conf.ndvMaxError) + + assert(statsAfterUpdate.sizeInBytes > statsBeforeUpdate.sizeInBytes) + assert(statsAfterUpdate.rowCount == Some(2)) + StatisticsTest.checkColStat( + dataType = IntegerType, + colStat = statsAfterUpdate.colStats("key"), + expectedColStat = ColumnStat(InternalRow(0L, 2, 1, 2L)), + rsd = spark.sessionState.conf.ndvMaxError) + } + + private lazy val (testDataFrame, expectedColStatsSeq) = { import testImplicits._ val intSeq = Seq(1, 2) val stringSeq = Seq("a", "bb") + val binarySeq = Seq("a", "bb").map(_.getBytes) val booleanSeq = Seq(true, false) - val data = intSeq.indices.map { i => - (intSeq(i), stringSeq(i), booleanSeq(i)) + (intSeq(i), stringSeq(i), binarySeq(i), booleanSeq(i)) } - val tableName = "table" - withTable(tableName) { - val df = data.toDF("c1", "c2", "c3") - df.write.format("parquet").saveAsTable(tableName) - val expectedColStatsSeq = df.schema.map { f => - val colStat = f.dataType match { - case IntegerType => - ColumnStat(InternalRow(0L, intSeq.max, intSeq.min, intSeq.distinct.length.toLong)) - case StringType => - ColumnStat(InternalRow(0L, stringSeq.map(_.length).sum / stringSeq.length.toDouble, - stringSeq.map(_.length).max.toLong, stringSeq.distinct.length.toLong)) - case BooleanType => - ColumnStat(InternalRow(0L, booleanSeq.count(_.equals(true)).toLong, - booleanSeq.count(_.equals(false)).toLong)) - } - (f, colStat) + val df: DataFrame = data.toDF("c1", "c2", "c3", "c4") + val expectedColStatsSeq: Seq[(StructField, ColumnStat)] = df.schema.map { f => + val colStat = f.dataType match { + case IntegerType => + ColumnStat(InternalRow(0L, intSeq.max, intSeq.min, intSeq.distinct.length.toLong)) + case StringType => + ColumnStat(InternalRow(0L, stringSeq.map(_.length).sum / stringSeq.length.toDouble, + stringSeq.map(_.length).max.toInt, stringSeq.distinct.length.toLong)) + case BinaryType => + ColumnStat(InternalRow(0L, binarySeq.map(_.length).sum / binarySeq.length.toDouble, + binarySeq.map(_.length).max.toInt)) + case BooleanType => + ColumnStat(InternalRow(0L, 
booleanSeq.count(_.equals(true)).toLong, + booleanSeq.count(_.equals(false)).toLong)) } + (f, colStat) + } + (df, expectedColStatsSeq) + } - sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS FOR COLUMNS c1, c2, c3") - val readback = spark.table(tableName) - val relations = readback.queryExecution.analyzed.collect { case rel: LogicalRelation => - val columnStats = rel.catalogTable.get.stats.get.colStats - expectedColStatsSeq.foreach { case (field, expectedColStat) => - assert(columnStats.contains(field.name)) - val colStat = columnStats(field.name) + private def checkColStats( + tableName: String, + isDataSourceTable: Boolean, + expectedColStatsSeq: Seq[(StructField, ColumnStat)]): Unit = { + val readback = spark.table(tableName) + val stats = readback.queryExecution.analyzed.collect { + case rel: MetastoreRelation => + assert(!isDataSourceTable, "Expected a Hive serde table, but got a data source table") + rel.catalogTable.stats.get + case rel: LogicalRelation => + assert(isDataSourceTable, "Expected a data source table, but got a Hive serde table") + rel.catalogTable.get.stats.get + } + assert(stats.length == 1) + val columnStats = stats.head.colStats + assert(columnStats.size == expectedColStatsSeq.length) + expectedColStatsSeq.foreach { case (field, expectedColStat) => + StatisticsTest.checkColStat( + dataType = field.dataType, + colStat = columnStats(field.name), + expectedColStat = expectedColStat, + rsd = spark.sessionState.conf.ndvMaxError) + } + } + + test("generate and load column-level stats for data source table") { + val dsTable = "dsTable" + withTable(dsTable) { + testDataFrame.write.format("parquet").saveAsTable(dsTable) + sql(s"ANALYZE TABLE $dsTable COMPUTE STATISTICS FOR COLUMNS c1, c2, c3, c4") + checkColStats(dsTable, isDataSourceTable = true, expectedColStatsSeq) + } + } + + test("generate and load column-level stats for hive serde table") { + val hTable = "hTable" + val tmp = "tmp" + withTable(hTable, tmp) { + testDataFrame.write.format("parquet").saveAsTable(tmp) + sql(s"CREATE TABLE $hTable (c1 int, c2 string, c3 binary, c4 boolean) STORED AS TEXTFILE") + sql(s"INSERT INTO $hTable SELECT * FROM $tmp") + sql(s"ANALYZE TABLE $hTable COMPUTE STATISTICS FOR COLUMNS c1, c2, c3, c4") + checkColStats(hTable, isDataSourceTable = false, expectedColStatsSeq) + } + } + + // When caseSensitive is on, for columns with only case difference, they are different columns + // and we should generate column stats for all of them. 
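  // Editor's note (inferred from the expected values constructed above, not
  // part of the patch): these suites encode column statistics positionally in
  // an InternalRow -- numeric columns as (numNulls, max, min, ndv), string
  // and binary columns as (numNulls, avgLength, maxLength[, ndv]), boolean
  // columns as (numNulls, numTrues, numFalses). For example, an int column
  // holding the single value 1 is expected to read back as
  //
  //   ColumnStat(InternalRow(0L /* nulls */, 1 /* max */, 1 /* min */, 1L /* ndv */))
  //
  // which is the literal asserted in the case-sensitivity helper below.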
+ private def checkCaseSensitiveColStats(columnName: String): Unit = { + val tableName = "tbl" + withTable(tableName) { + val column1 = columnName.toLowerCase + val column2 = columnName.toUpperCase + withSQLConf("spark.sql.caseSensitive" -> "true") { + sql(s"CREATE TABLE $tableName (`$column1` int, `$column2` double) USING PARQUET") + sql(s"INSERT INTO $tableName SELECT 1, 3.0") + sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS FOR COLUMNS `$column1`, `$column2`") + val readback = spark.table(tableName) + val relations = readback.queryExecution.analyzed.collect { case rel: LogicalRelation => + val columnStats = rel.catalogTable.get.stats.get.colStats + assert(columnStats.size == 2) + StatisticsTest.checkColStat( + dataType = IntegerType, + colStat = columnStats(column1), + expectedColStat = ColumnStat(InternalRow(0L, 1, 1, 1L)), + rsd = spark.sessionState.conf.ndvMaxError) StatisticsTest.checkColStat( - dataType = field.dataType, - colStat = colStat, - expectedColStat = expectedColStat, + dataType = DoubleType, + colStat = columnStats(column2), + expectedColStat = ColumnStat(InternalRow(0L, 3.0d, 3.0d, 1L)), rsd = spark.sessionState.conf.ndvMaxError) + rel } - rel + assert(relations.size == 1) } - assert(relations.size == 1) } } + test("check column statistics for case sensitive column names") { + checkCaseSensitiveColStats(columnName = "c1") + } + + test("check column statistics for case sensitive non-ascii column names") { + // scalastyle:off + // non ascii characters are not allowed in the source code, so we disable the scalastyle. + checkCaseSensitiveColStats(columnName = "列c") + // scalastyle:on + } + test("estimates the size of a test MetastoreRelation") { val df = sql("""SELECT * FROM src""") val sizes = df.queryExecution.analyzed.collect { case mr: MetastoreRelation => diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCommandSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCommandSuite.scala index 2c772ce2155ef..46ed18c70fb56 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCommandSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCommandSuite.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.NoSuchTableException import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable, CatalogTableType} import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types.StructType @@ -336,28 +337,6 @@ class HiveCommandSuite extends QueryTest with SQLTestUtils with TestHiveSingleto } } - test("show columns") { - checkAnswer( - sql("SHOW COLUMNS IN parquet_tab3"), - Row("col1") :: Row("col 2") :: Nil) - - checkAnswer( - sql("SHOW COLUMNS IN default.parquet_tab3"), - Row("col1") :: Row("col 2") :: Nil) - - checkAnswer( - sql("SHOW COLUMNS IN parquet_tab3 FROM default"), - Row("col1") :: Row("col 2") :: Nil) - - checkAnswer( - sql("SHOW COLUMNS IN parquet_tab4 IN default"), - Row("price") :: Row("qty") :: Row("year") :: Row("month") :: Nil) - - val message = intercept[NoSuchTableException] { - sql("SHOW COLUMNS IN badtable FROM default") - }.getMessage - assert(message.contains("'badtable' not found in database")) - } test("show partitions - show everything") { checkAnswer( @@ -436,10 +415,7 @@ class HiveCommandSuite extends QueryTest with SQLTestUtils with TestHiveSingleto 
.mode(SaveMode.Overwrite) .saveAsTable("part_datasrc") - val message1 = intercept[AnalysisException] { - sql("SHOW PARTITIONS part_datasrc") - }.getMessage - assert(message1.contains("is not allowed on a datasource table")) + assert(sql("SHOW PARTITIONS part_datasrc").count() == 3) } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala index 80e75aa898c38..13ceed7c79e35 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala @@ -167,7 +167,7 @@ abstract class HiveComparisonTest // and does not return it as a query answer. case _: SetCommand => Seq("0") case _: ExplainCommand => answer - case _: DescribeTableCommand | ShowColumnsCommand(_) => + case _: DescribeTableCommand | ShowColumnsCommand(_, _) => // Filter out non-deterministic lines and lines which do not have actual results but // can introduce problems because of the way Hive formats these lines. // Then, remove empty lines. Do not sort the results. diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index 3d1712e4354c0..e9268a922cf54 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -200,9 +200,8 @@ class HiveDDLSuite val message = intercept[AnalysisException] { sql(s"ALTER TABLE $externalTab DROP PARTITION (ds='2008-04-09', unknownCol='12')") } - assert(message.getMessage.contains( - "Partition spec is invalid. 
The spec (ds, unknowncol) must be contained within the " + - "partition spec (ds, hr) defined in table '`default`.`exttable_with_partitions`'")) + assert(message.getMessage.contains("unknownCol is not a valid partition column in table " + + "`default`.`exttable_with_partitions`")) sql( s""" diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 2b945dbbe03dd..6fbbed1d47e04 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.hive.execution import java.io.File +import java.net.URI import java.sql.Timestamp import java.util.{Locale, TimeZone} @@ -954,7 +955,8 @@ class HiveQuerySuite extends HiveComparisonTest with SQLTestUtils with BeforeAnd .mkString("/") // Loads partition data to a temporary table to verify contents - val path = s"${sparkSession.getWarehousePath}/dynamic_part_table/$partFolder/part-00000" + val warehousePathFile = new URI(sparkSession.getWarehousePath()).getPath + val path = s"$warehousePathFile/dynamic_part_table/$partFolder/part-00000" sql("DROP TABLE IF EXISTS dp_verify") sql("CREATE TABLE dp_verify(intcol INT)") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneFileSourcePartitionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneFileSourcePartitionsSuite.scala new file mode 100644 index 0000000000000..cdbc26cd5c576 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneFileSourcePartitionsSuite.scala @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.hive.execution + +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project} +import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.execution.datasources.{CatalogFileIndex, HadoopFsRelation, LogicalRelation, PruneFileSourcePartitions} +import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat +import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.types.StructType + +class PruneFileSourcePartitionsSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = Batch("PruneFileSourcePartitions", Once, PruneFileSourcePartitions) :: Nil + } + + test("PruneFileSourcePartitions should not change the output of LogicalRelation") { + withTable("test") { + withTempDir { dir => + sql( + s""" + |CREATE EXTERNAL TABLE test(i int) + |PARTITIONED BY (p int) + |STORED AS parquet + |LOCATION '${dir.getAbsolutePath}'""".stripMargin) + + val tableMeta = spark.sharedState.externalCatalog.getTable("default", "test") + val catalogFileIndex = new CatalogFileIndex(spark, tableMeta, 0) + + val dataSchema = StructType(tableMeta.schema.filterNot { f => + tableMeta.partitionColumnNames.contains(f.name) + }) + val relation = HadoopFsRelation( + location = catalogFileIndex, + partitionSchema = tableMeta.partitionSchema, + dataSchema = dataSchema, + bucketSpec = None, + fileFormat = new ParquetFileFormat(), + options = Map.empty)(sparkSession = spark) + + val logicalRelation = LogicalRelation(relation, catalogTable = Some(tableMeta)) + val query = Project(Seq('i, 'p), Filter('p === 1, logicalRelation)).analyze + + val optimized = Optimize.execute(query) + assert(optimized.missingInput.isEmpty) + } + } + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 6f2a16662bf10..2735d3a5267e3 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -17,11 +17,14 @@ package org.apache.spark.sql.hive.execution +import java.io.{File, PrintWriter} +import java.nio.charset.StandardCharsets import java.sql.{Date, Timestamp} import scala.sys.process.{Process, ProcessLogger} import scala.util.Try +import com.google.common.io.Files import org.apache.hadoop.fs.Path import org.apache.spark.sql._ @@ -65,6 +68,22 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { import hiveContext._ import spark.implicits._ + test("query global temp view") { + val df = Seq(1).toDF("i1") + df.createGlobalTempView("tbl1") + val global_temp_db = spark.conf.get("spark.sql.globalTempDatabase") + checkAnswer(spark.sql(s"select * from ${global_temp_db}.tbl1"), Row(1)) + spark.sql(s"drop view ${global_temp_db}.tbl1") + } + + test("non-existent global temp view") { + val global_temp_db = spark.conf.get("spark.sql.globalTempDatabase") + val message = intercept[AnalysisException] { + spark.sql(s"select * from ${global_temp_db}.nonexistentview") + }.getMessage + assert(message.contains("Table or view not found")) + } + test("script") { val scriptFilePath = 
getTestResourcePath("test_script.sh") if (testCommandAvailable("bash") && testCommandAvailable("echo | sed")) { @@ -355,7 +374,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { "# Partition Information", "# col_name", "Detailed Partition Information CatalogPartition(", - "Partition Values: [Us, 1]", + "Partition Values: [c=Us, d=1]", "Storage(Location:", "Partition Parameters") @@ -396,10 +415,8 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { .range(1).select('id as 'a, 'id as 'b, 'id as 'c, 'id as 'd).write .partitionBy("d") .saveAsTable("datasource_table") - val m4 = intercept[AnalysisException] { - sql("DESC datasource_table PARTITION (d=2)") - }.getMessage() - assert(m4.contains("DESC PARTITION is not allowed on a datasource table")) + + sql("DESC datasource_table PARTITION (d=0)") val m5 = intercept[AnalysisException] { spark.range(10).select('id as 'a, 'id as 'b).createTempView("view1") @@ -492,7 +509,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { def checkRelation( tableName: String, - isDataSourceParquet: Boolean, + isDataSourceTable: Boolean, format: String, userSpecifiedLocation: Option[String] = None): Unit = { val relation = EliminateSubqueryAliases( @@ -501,7 +518,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { sessionState.catalog.getTableMetadata(TableIdentifier(tableName)) relation match { case LogicalRelation(r: HadoopFsRelation, _, _) => - if (!isDataSourceParquet) { + if (!isDataSourceTable) { fail( s"${classOf[MetastoreRelation].getCanonicalName} is expected, but found " + s"${HadoopFsRelation.getClass.getCanonicalName}.") @@ -514,7 +531,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { assert(catalogTable.provider.get === format) case r: MetastoreRelation => - if (isDataSourceParquet) { + if (isDataSourceTable) { fail( s"${HadoopFsRelation.getClass.getCanonicalName} is expected, but found " + s"${classOf[MetastoreRelation].getCanonicalName}.") @@ -524,8 +541,15 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { assert(r.catalogTable.storage.locationUri.get === location) case None => // OK. } - // Also make sure that the format is the desired format. + // Also make sure that the format and serde are as desired. assert(catalogTable.storage.inputFormat.get.toLowerCase.contains(format)) + assert(catalogTable.storage.outputFormat.get.toLowerCase.contains(format)) + val serde = catalogTable.storage.serde.get + format match { + case "sequence" | "text" => assert(serde.contains("LazySimpleSerDe")) + case "rcfile" => assert(serde.contains("LazyBinaryColumnarSerDe")) + case _ => assert(serde.toLowerCase.contains(format)) + } } // When a user-specified location is defined, the table type needs to be EXTERNAL. @@ -587,6 +611,30 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } } + test("CTAS with default fileformat") { + val table = "ctas1" + val ctas = s"CREATE TABLE IF NOT EXISTS $table SELECT key k, value FROM src" + withSQLConf(SQLConf.CONVERT_CTAS.key -> "true") { + withSQLConf("hive.default.fileformat" -> "textfile") { + withTable(table) { + sql(ctas) + // We should use parquet here as that is the default datasource fileformat. The default + // datasource file format is controlled by `spark.sql.sources.default` configuration. 
+ // This testcase verifies that setting `hive.default.fileformat` has no impact on + // the target table's fileformat in case of CTAS. + assert(sessionState.conf.defaultDataSourceName === "parquet") + checkRelation(tableName = table, isDataSourceTable = true, format = "parquet") + } + } + withSQLConf("spark.sql.sources.default" -> "orc") { + withTable(table) { + sql(ctas) + checkRelation(tableName = table, isDataSourceTable = true, format = "orc") + } + } + } + } + test("CTAS without serde with location") { withSQLConf(SQLConf.CONVERT_CTAS.key -> "true") { withTempDir { dir => @@ -1886,6 +1934,33 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } } + test("SPARK-17796 Support wildcard character in filename for LOAD DATA LOCAL INPATH") { + withTempDir { dir => + for (i <- 1 to 3) { + Files.write(s"$i", new File(s"$dir/part-r-0000$i"), StandardCharsets.UTF_8) + } + for (i <- 5 to 7) { + Files.write(s"$i", new File(s"$dir/part-s-0000$i"), StandardCharsets.UTF_8) + } + + withTable("load_t") { + sql("CREATE TABLE load_t (a STRING)") + sql(s"LOAD DATA LOCAL INPATH '$dir/*part-r*' INTO TABLE load_t") + checkAnswer(sql("SELECT * FROM load_t"), Seq(Row("1"), Row("2"), Row("3"))) + + val m = intercept[AnalysisException] { + sql("LOAD DATA LOCAL INPATH '/non-exist-folder/*part*' INTO TABLE load_t") + }.getMessage + assert(m.contains("LOAD DATA input path does not exist")) + + val m2 = intercept[AnalysisException] { + sql(s"LOAD DATA LOCAL INPATH '$dir*/*part*' INTO TABLE load_t") + }.getMessage + assert(m2.contains("LOAD DATA input path allows only filename wildcard")) + } + } + } + def testCommandAvailable(command: String): Boolean = { val attempt = Try(Process(command).run(ProcessLogger(_ => ())).exitValue()) attempt.isSuccess && attempt.get == 0 diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala index a8e81d7a3c42a..0e837766e2ea4 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala @@ -24,6 +24,7 @@ import org.apache.spark.{SparkException, TaskContext} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} +import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest, UnaryExecNode} import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.types.StringType @@ -135,5 +136,8 @@ private case class ExceptionInjectingOperator(child: SparkPlan) extends UnaryExe throw new IllegalArgumentException("intentional exception") } } + override def output: Seq[Attribute] = child.output + + override def outputPartitioning: Partitioning = child.outputPartitioning } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala index b2ee49c441ef2..ecb5972984523 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala @@ -474,6 +474,28 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { } } + test("converted ORC table supports resolving mixed case field") { + 
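      // Editor's note (not part of the patch): this test exercises the
      // OrcRelation.setRequiredColumns change earlier in this diff, which
      // tries an exact field-name match against the physical ORC schema first
      // and only then falls back to a map keyed by lower-cased physical
      // names, roughly (requestedName stands for the metastore-provided,
      // lower-cased column name, e.g. "valuefield", while the ORC file
      // stores "valueField"):
      //
      //   val byLowerCase = physicalSchema.fieldNames.zipWithIndex
      //     .map { case (name, i) => name.toLowerCase -> i }.toMap
      //   val id = physicalSchema.getFieldIndex(requestedName)
      //     .getOrElse(byLowerCase(requestedName))
      //
      // This is a simplified restatement of the patched code, not a new API.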
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala
index b2ee49c441ef2..ecb5972984523 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala
@@ -474,6 +474,28 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest {
     }
   }
 
+  test("converted ORC table supports resolving mixed case field") {
+    withSQLConf(HiveUtils.CONVERT_METASTORE_ORC.key -> "true") {
+      withTable("dummy_orc") {
+        withTempPath { dir =>
+          val df = spark.range(5).selectExpr("id", "id as valueField", "id as partitionValue")
+          df.write
+            .partitionBy("partitionValue")
+            .mode("overwrite")
+            .orc(dir.getAbsolutePath)
+
+          spark.sql(s"""
+            |create external table dummy_orc (id long, valueField long)
+            |partitioned by (partitionValue int)
+            |stored as orc
+            |location "${dir.getAbsolutePath}"""".stripMargin)
+          spark.sql(s"msck repair table dummy_orc")
+          checkAnswer(spark.sql("select * from dummy_orc"), df)
+        }
+      }
+    }
+  }
+
   test("SPARK-14962 Produce correct results on array type with isnotnull") {
     withSQLConf(SQLConf.ORC_FILTER_PUSHDOWN_ENABLED.key -> "true") {
       val data = (0 until 10).map(i => Tuple1(Array(i)))
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala
index 2f6d9fb96b825..9fc62a389db4d 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala
@@ -175,7 +175,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest {
     (1 to 10).map(i => Tuple1(Seq(new Integer(i), null))).toDF("a")
       .createOrReplaceTempView("jt_array")
 
-    setConf(HiveUtils.CONVERT_METASTORE_PARQUET, true)
+    assert(spark.sqlContext.getConf(HiveUtils.CONVERT_METASTORE_PARQUET.key) == "true")
   }
 
   override def afterAll(): Unit = {
@@ -187,7 +187,6 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest {
       "jt",
       "jt_array",
       "test_parquet")
-    setConf(HiveUtils.CONVERT_METASTORE_PARQUET, false)
   }
 
   test(s"conversion is working") {
@@ -586,6 +585,23 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest {
       checkAnswer(
         sql("SELECT * FROM test_added_partitions"),
         Seq(("foo", 0), ("bar", 0), ("baz", 1)).toDF("a", "b"))
+
+      // Check it with pruning predicates
+      checkAnswer(
+        sql("SELECT * FROM test_added_partitions where b = 0"),
+        Seq(("foo", 0), ("bar", 0)).toDF("a", "b"))
+      checkAnswer(
+        sql("SELECT * FROM test_added_partitions where b = 1"),
+        Seq(("baz", 1)).toDF("a", "b"))
+      checkAnswer(
+        sql("SELECT * FROM test_added_partitions where b = 2"),
+        Seq[(String, Int)]().toDF("a", "b"))
+
+      // Also verify the inputFiles implementation
+      assert(sql("select * from test_added_partitions").inputFiles.length == 2)
+      assert(sql("select * from test_added_partitions where b = 0").inputFiles.length == 1)
+      assert(sql("select * from test_added_partitions where b = 1").inputFiles.length == 1)
+      assert(sql("select * from test_added_partitions where b = 2").inputFiles.length == 0)
     }
   }
 }
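// Illustrative sketch (not part of the patch): the idea behind the inputFiles assertions added
// above, shown on a plain partitioned Parquet directory rather than a metastore table. With one
// partition per distinct value of `b`, a pruning predicate should shrink the file list returned
// by inputFiles. The path is hypothetical and `spark` is assumed to be an existing SparkSession.
import spark.implicits._

val demoPath = "/tmp/partition_pruning_demo"
Seq(("foo", 0), ("bar", 0), ("baz", 1)).toDF("a", "b")
  .write.mode("overwrite").partitionBy("b").parquet(demoPath)

val partitioned = spark.read.parquet(demoPath)
// Without a predicate, files from every partition directory are listed.
println(partitioned.inputFiles.length)
// After pruning to b = 1, only that partition's file(s) should remain.
println(partitioned.where($"b" === 1).inputFiles.length)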
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
index 3ff85176de10e..d9ddcbd57ca83 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.sources
 
 import java.io.File
+import java.net.URI
 
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.catalog.BucketSpec
@@ -235,7 +236,8 @@ class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSinglet
   private def testBucketing(
       bucketSpecLeft: Option[BucketSpec],
       bucketSpecRight: Option[BucketSpec],
-      joinColumns: Seq[String],
+      joinType: String = "inner",
+      joinCondition: (DataFrame, DataFrame) => Column,
       shuffleLeft: Boolean,
       shuffleRight: Boolean,
       sortLeft: Boolean = true,
@@ -268,12 +270,12 @@ class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSinglet
         SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false") {
       val t1 = spark.table("bucketed_table1")
       val t2 = spark.table("bucketed_table2")
-      val joined = t1.join(t2, joinCondition(t1, t2, joinColumns))
+      val joined = t1.join(t2, joinCondition(t1, t2), joinType)
 
       // First check the result is corrected.
       checkAnswer(
         joined.sort("bucketed_table1.k", "bucketed_table2.k"),
-        df1.join(df2, joinCondition(df1, df2, joinColumns)).sort("df1.k", "df2.k"))
+        df1.join(df2, joinCondition(df1, df2), joinType).sort("df1.k", "df2.k"))
 
       assert(joined.queryExecution.executedPlan.isInstanceOf[SortMergeJoinExec])
       val joinOperator = joined.queryExecution.executedPlan.asInstanceOf[SortMergeJoinExec]
@@ -297,56 +299,102 @@ class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSinglet
     }
   }
 
-  private def joinCondition(left: DataFrame, right: DataFrame, joinCols: Seq[String]): Column = {
+  private def joinCondition(joinCols: Seq[String]) (left: DataFrame, right: DataFrame): Column = {
     joinCols.map(col => left(col) === right(col)).reduce(_ && _)
   }
 
   test("avoid shuffle when join 2 bucketed tables") {
     val bucketSpec = Some(BucketSpec(8, Seq("i", "j"), Nil))
-    testBucketing(bucketSpec, bucketSpec, Seq("i", "j"), shuffleLeft = false, shuffleRight = false)
+    testBucketing(
+      bucketSpecLeft = bucketSpec,
+      bucketSpecRight = bucketSpec,
+      joinCondition = joinCondition(Seq("i", "j")),
+      shuffleLeft = false,
+      shuffleRight = false
+    )
   }
 
   // Enable it after fix https://issues.apache.org/jira/browse/SPARK-12704
   ignore("avoid shuffle when join keys are a super-set of bucket keys") {
     val bucketSpec = Some(BucketSpec(8, Seq("i"), Nil))
-    testBucketing(bucketSpec, bucketSpec, Seq("i", "j"), shuffleLeft = false, shuffleRight = false)
+    testBucketing(
+      bucketSpecLeft = bucketSpec,
+      bucketSpecRight = bucketSpec,
+      joinCondition = joinCondition(Seq("i", "j")),
+      shuffleLeft = false,
+      shuffleRight = false
+    )
   }
 
   test("only shuffle one side when join bucketed table and non-bucketed table") {
     val bucketSpec = Some(BucketSpec(8, Seq("i", "j"), Nil))
-    testBucketing(bucketSpec, None, Seq("i", "j"), shuffleLeft = false, shuffleRight = true)
+    testBucketing(
+      bucketSpecLeft = bucketSpec,
+      bucketSpecRight = None,
+      joinCondition = joinCondition(Seq("i", "j")),
+      shuffleLeft = false,
+      shuffleRight = true
+    )
   }
 
   test("only shuffle one side when 2 bucketed tables have different bucket number") {
     val bucketSpec1 = Some(BucketSpec(8, Seq("i", "j"), Nil))
     val bucketSpec2 = Some(BucketSpec(5, Seq("i", "j"), Nil))
-    testBucketing(bucketSpec1, bucketSpec2, Seq("i", "j"), shuffleLeft = false, shuffleRight = true)
+    testBucketing(
+      bucketSpecLeft = bucketSpec1,
+      bucketSpecRight = bucketSpec2,
+      joinCondition = joinCondition(Seq("i", "j")),
+      shuffleLeft = false,
+      shuffleRight = true
+    )
   }
 
   test("only shuffle one side when 2 bucketed tables have different bucket keys") {
     val bucketSpec1 = Some(BucketSpec(8, Seq("i"), Nil))
     val bucketSpec2 = Some(BucketSpec(8, Seq("j"), Nil))
-    testBucketing(bucketSpec1, bucketSpec2, Seq("i"), shuffleLeft = false, shuffleRight = true)
+    testBucketing(
+      bucketSpecLeft = bucketSpec1,
+      bucketSpecRight = bucketSpec2,
+      joinCondition = joinCondition(Seq("i")),
+      shuffleLeft = false,
+      shuffleRight = true
+    )
  }
 
   test("shuffle when join keys are not equal to bucket keys") {
     val bucketSpec = Some(BucketSpec(8, Seq("i"), Nil))
-    testBucketing(bucketSpec, bucketSpec, Seq("j"), shuffleLeft = true, shuffleRight = true)
+    testBucketing(
+      bucketSpecLeft = bucketSpec,
+      bucketSpecRight = bucketSpec,
+      joinCondition = joinCondition(Seq("j")),
+      shuffleLeft = true,
+      shuffleRight = true
+    )
   }
 
   test("shuffle when join 2 bucketed tables with bucketing disabled") {
     val bucketSpec = Some(BucketSpec(8, Seq("i", "j"), Nil))
     withSQLConf(SQLConf.BUCKETING_ENABLED.key -> "false") {
-      testBucketing(bucketSpec, bucketSpec, Seq("i", "j"), shuffleLeft = true, shuffleRight = true)
+      testBucketing(
+        bucketSpecLeft = bucketSpec,
+        bucketSpecRight = bucketSpec,
+        joinCondition = joinCondition(Seq("i", "j")),
+        shuffleLeft = true,
+        shuffleRight = true
+      )
     }
   }
 
   test("avoid shuffle and sort when bucket and sort columns are join keys") {
     val bucketSpec = Some(BucketSpec(8, Seq("i", "j"), Seq("i", "j")))
     testBucketing(
-      bucketSpec, bucketSpec, Seq("i", "j"),
-      shuffleLeft = false, shuffleRight = false,
-      sortLeft = false, sortRight = false
+      bucketSpecLeft = bucketSpec,
+      bucketSpecRight = bucketSpec,
+      joinCondition = joinCondition(Seq("i", "j")),
+      shuffleLeft = false,
+      shuffleRight = false,
+      sortLeft = false,
+      sortRight = false
     )
   }
 
@@ -354,9 +402,13 @@ class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSinglet
     val bucketSpec1 = Some(BucketSpec(8, Seq("i"), Seq("i", "j")))
     val bucketSpec2 = Some(BucketSpec(8, Seq("i"), Seq("i", "k")))
     testBucketing(
-      bucketSpec1, bucketSpec2, Seq("i"),
-      shuffleLeft = false, shuffleRight = false,
-      sortLeft = false, sortRight = false
+      bucketSpecLeft = bucketSpec1,
+      bucketSpecRight = bucketSpec2,
+      joinCondition = joinCondition(Seq("i")),
+      shuffleLeft = false,
+      shuffleRight = false,
+      sortLeft = false,
+      sortRight = false
     )
   }
 
@@ -364,9 +416,13 @@ class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSinglet
     val bucketSpec1 = Some(BucketSpec(8, Seq("i", "j"), Seq("i", "j")))
     val bucketSpec2 = Some(BucketSpec(8, Seq("i", "j"), Seq("k")))
     testBucketing(
-      bucketSpec1, bucketSpec2, Seq("i", "j"),
-      shuffleLeft = false, shuffleRight = false,
-      sortLeft = false, sortRight = true
+      bucketSpecLeft = bucketSpec1,
+      bucketSpecRight = bucketSpec2,
+      joinCondition = joinCondition(Seq("i", "j")),
+      shuffleLeft = false,
+      shuffleRight = false,
+      sortLeft = false,
+      sortRight = true
     )
   }
 
@@ -374,9 +430,13 @@ class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSinglet
     val bucketSpec1 = Some(BucketSpec(8, Seq("i", "j"), Seq("i", "j")))
     val bucketSpec2 = Some(BucketSpec(8, Seq("i", "j"), Seq("j", "i")))
     testBucketing(
-      bucketSpec1, bucketSpec2, Seq("i", "j"),
-      shuffleLeft = false, shuffleRight = false,
-      sortLeft = false, sortRight = true
+      bucketSpecLeft = bucketSpec1,
+      bucketSpecRight = bucketSpec2,
+      joinCondition = joinCondition(Seq("i", "j")),
+      shuffleLeft = false,
+      shuffleRight = false,
+      sortLeft = false,
+      sortRight = true
     )
   }
 
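// Illustrative sketch (not part of the patch): the scenario the bucketing tests above cover,
// using only the public DataFrameWriter API. When both sides are bucketed (and sorted) by the
// join keys, the sort-merge join should need no exchange and no extra sort. Table names are
// hypothetical and `spark` is assumed to be an existing SparkSession.
import spark.implicits._

val demoDf = (0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k")
demoDf.write.format("parquet")
  .bucketBy(8, "i", "j").sortBy("i", "j").saveAsTable("bucketed_demo_1")
demoDf.write.format("parquet")
  .bucketBy(8, "i", "j").sortBy("i", "j").saveAsTable("bucketed_demo_2")

val t1 = spark.table("bucketed_demo_1")
val t2 = spark.table("bucketed_demo_2")
val joined = t1.join(t2, t1("i") === t2("i") && t1("j") === t2("j"))
// Inspect the physical plan: with matching bucketing there should be no Exchange below the join.
joined.explain()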
@@ -408,11 +468,30 @@ class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSinglet
     }
   }
 
+  test("SPARK-17698 Join predicates should not contain filter clauses") {
+    val bucketSpec = Some(BucketSpec(8, Seq("i"), Seq("i")))
+    testBucketing(
+      bucketSpecLeft = bucketSpec,
+      bucketSpecRight = bucketSpec,
+      joinType = "fullouter",
+      joinCondition = (left: DataFrame, right: DataFrame) => {
+        val joinPredicates = Seq("i").map(col => left(col) === right(col)).reduce(_ && _)
+        val filterLeft = left("i") === Literal("1")
+        val filterRight = right("i") === Literal("1")
+        joinPredicates && filterLeft && filterRight
+      },
+      shuffleLeft = false,
+      shuffleRight = false,
+      sortLeft = false,
+      sortRight = false
+    )
+  }
+
   test("error if there exists any malformed bucket files") {
     withTable("bucketed_table") {
       df1.write.format("parquet").bucketBy(8, "i").saveAsTable("bucketed_table")
-      val tableDir = new File(hiveContext
-        .sparkSession.getWarehousePath, "bucketed_table")
+      val warehouseFilePath = new URI(hiveContext.sparkSession.getWarehousePath).getPath
+      val tableDir = new File(warehouseFilePath, "bucketed_table")
       Utils.deleteRecursively(tableDir)
       df1.write.parquet(tableDir.getAbsolutePath)
 
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala
index 997445114ba58..2eafe18b85844 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala
@@ -54,11 +54,6 @@ class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingle
     intercept[AnalysisException](df.write.bucketBy(2, "i").sortBy("j").saveAsTable("tt"))
   }
 
-  test("write bucketed data to unsupported data source") {
-    val df = Seq(Tuple1("a"), Tuple1("b")).toDF("i")
-    intercept[SparkException](df.write.bucketBy(3, "i").format("text").saveAsTable("tt"))
-  }
-
   test("write bucketed data using save()") {
     val df = Seq(1 -> "a", 2 -> "b").toDF("i", "j")
 
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestSource.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestSource.scala
index 5a8a7f0ab5d7b..731540db17eeb 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestSource.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestSource.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.sources
 
+import org.apache.hadoop.fs.Path
 import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext}
 
 import org.apache.spark.TaskContext
@@ -39,17 +40,19 @@ class CommitFailureTestSource extends SimpleTextSource {
       dataSchema: StructType): OutputWriterFactory =
     new OutputWriterFactory {
       override def newInstance(
-          path: String,
-          bucketId: Option[Int],
+          stagingDir: String,
+          fileNamePrefix: String,
           dataSchema: StructType,
           context: TaskAttemptContext): OutputWriter = {
-        new SimpleTextOutputWriter(path, context) {
+        new SimpleTextOutputWriter(stagingDir, fileNamePrefix, context) {
           var failed = false
           TaskContext.get().addTaskFailureListener { (t: TaskContext, e: Throwable) =>
             failed = true
             SimpleTextRelation.callbackCalled = true
           }
 
+          override val path: String = new Path(stagingDir, fileNamePrefix).toString
+
           override def write(row: Row): Unit = {
             if (SimpleTextRelation.failWriter) {
               sys.error("Intentional task writer failure for testing purpose.")
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala
index 906de6bbcbee5..9896b9bde99c8 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala
@@ -23,7 +23,7 @@ import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.{FileStatus, Path}
 import org.apache.hadoop.io.{NullWritable, Text}
 import org.apache.hadoop.mapreduce.{Job, RecordWriter, TaskAttemptContext}
-import org.apache.hadoop.mapreduce.lib.output.{FileOutputFormat, TextOutputFormat}
+import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat
 
 import org.apache.spark.sql.{sources, Row, SparkSession}
 import org.apache.spark.sql.catalyst.{expressions, InternalRow}
@@ -51,11 +51,11 @@ class SimpleTextSource extends TextBasedFileFormat with DataSourceRegister {
     SimpleTextRelation.lastHadoopConf = Option(job.getConfiguration)
     new OutputWriterFactory {
       override def newInstance(
-          path: String,
-          bucketId: Option[Int],
+          stagingDir: String,
+          fileNamePrefix: String,
           dataSchema: StructType,
           context: TaskAttemptContext): OutputWriter = {
-        new SimpleTextOutputWriter(path, context)
+        new SimpleTextOutputWriter(stagingDir, fileNamePrefix, context)
       }
     }
   }
@@ -120,9 +120,14 @@ class SimpleTextSource extends TextBasedFileFormat with DataSourceRegister {
   }
 }
 
-class SimpleTextOutputWriter(path: String, context: TaskAttemptContext) extends OutputWriter {
+class SimpleTextOutputWriter(
+    stagingDir: String, fileNamePrefix: String, context: TaskAttemptContext)
+  extends OutputWriter {
+
+  override val path: String = new Path(stagingDir, fileNamePrefix).toString
+
   private val recordWriter: RecordWriter[NullWritable, Text] =
-    new AppendingTextOutputFormat(new Path(path)).getRecordWriter(context)
+    new AppendingTextOutputFormat(new Path(stagingDir), fileNamePrefix).getRecordWriter(context)
 
   override def write(row: Row): Unit = {
     val serialized = row.toSeq.map { v =>
@@ -136,19 +141,15 @@ class SimpleTextOutputWriter(path: String, context: TaskAttemptContext) extends
   }
 }
 
-class AppendingTextOutputFormat(outputFile: Path) extends TextOutputFormat[NullWritable, Text] {
-  val numberFormat = NumberFormat.getInstance()
+class AppendingTextOutputFormat(stagingDir: Path, fileNamePrefix: String)
+  extends TextOutputFormat[NullWritable, Text] {
+  val numberFormat = NumberFormat.getInstance()
   numberFormat.setMinimumIntegerDigits(5)
   numberFormat.setGroupingUsed(false)
 
   override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = {
-    val configuration = context.getConfiguration
-    val uniqueWriteJobId = configuration.get(WriterContainer.DATASOURCE_WRITEJOBUUID)
-    val taskAttemptId = context.getTaskAttemptID
-    val split = taskAttemptId.getTaskID.getId
-    val name = FileOutputFormat.getOutputName(context)
-    new Path(outputFile, s"$name-${numberFormat.format(split)}-$uniqueWriteJobId")
+    new Path(stagingDir, fileNamePrefix)
   }
 }
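// Illustrative sketch (not part of the patch): the pattern the rewritten test sources above
// follow. The writing framework now hands the writer a staging directory plus a file name
// prefix, so a Hadoop output format only has to pin its work file to that exact location
// instead of deriving a name from the task attempt. The class name below is hypothetical.
import org.apache.hadoop.fs.Path
import org.apache.hadoop.io.{NullWritable, Text}
import org.apache.hadoop.mapreduce.TaskAttemptContext
import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat

class FixedNameTextOutputFormat(stagingDir: Path, fileNamePrefix: String)
  extends TextOutputFormat[NullWritable, Text] {

  // All records of this task go to stagingDir/fileNamePrefix, chosen by the caller.
  override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path =
    new Path(stagingDir, fileNamePrefix)
}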
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
index ea4e1160b7672..55e4a833b6707 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
@@ -1059,9 +1059,11 @@ private[spark] class Client(
     } catch {
       case e: ApplicationNotFoundException =>
         logError(s"Application $appId not found.")
+        cleanupStagingDir(appId)
         return (YarnApplicationState.KILLED, FinalApplicationStatus.KILLED)
       case NonFatal(e) =>
         logError(s"Failed to contact YARN for application $appId.", e)
+        // Don't necessarily clean up staging dir because status is unknown
         return (YarnApplicationState.FAILED, FinalApplicationStatus.FAILED)
     }
     val state = report.getYarnApplicationState
@@ -1179,7 +1181,7 @@ private[spark] class Client(
       val pyArchivesFile = new File(pyLibPath, "pyspark.zip")
       require(pyArchivesFile.exists(),
         s"$pyArchivesFile not found; cannot run pyspark application in YARN mode.")
-      val py4jFile = new File(pyLibPath, "py4j-0.10.3-src.zip")
+      val py4jFile = new File(pyLibPath, "py4j-0.10.4-src.zip")
       require(py4jFile.exists(),
         s"$py4jFile not found; cannot run pyspark application in YARN mode.")
       Seq(pyArchivesFile.getAbsolutePath(), py4jFile.getAbsolutePath())
diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala
index d245acf49aa91..99fb58a28934a 100644
--- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala
+++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala
@@ -242,7 +242,7 @@ class YarnClusterSuite extends BaseYarnClusterSuite {
     // needed locations.
     val sparkHome = sys.props("spark.test.home")
     val pythonPath = Seq(
-        s"$sparkHome/python/lib/py4j-0.10.3-src.zip",
+        s"$sparkHome/python/lib/py4j-0.10.4-src.zip",
         s"$sparkHome/python")
     val extraEnvVars = Map(
       "PYSPARK_ARCHIVES_PATH" -> pythonPath.map("local:" + _).mkString(File.pathSeparator),